Как улучшить модель предсказания видео в Keras?

Вопрос или проблема

Я работаю над моделью прогнозирования преступлений. У меня есть изображения того, как преступления выглядят каждый день в городе в течение года, и я хочу использовать 30 дней преступности для прогнозирования 31-го дня (так же, как пытаюсь предсказать следующий кадр в видео, используя предыдущие кадры). Я создал простую модель в качестве отправной точки, используя слои ConvLSTM, аналогичные этой тетради: https://keras.io/examples/vision/conv_lstm/ (я размещу код в конце). Обучение использует другую обработку пакетов, так как я использую скользящее окно для обучения и тестирования, оно выглядит так:

Эпоха 1: Использую изображения с 1 по 30, чтобы предсказать изображение 31 и настроить параметры сети. Затем перемещаю скользящее окно и использую изображения с 2 по 31 в качестве входных данных и тестирую результаты с изображением 32. Последующие эпохи похожи, пока я не достигну конца своих данных.

Для функции потерь я использую маскированную MSE (вычисляю потери только для индексов, где вектор y_true ненулевой). Я использую эту функцию потерь, так как изображения очень большие (42×46), и некоторые их части всегда равны 0, поэтому я не хочу, чтобы нулевые значения влияли на прогнозы модели.

Проблема в том, что модель совершенно не обучается, и я не знаю, что может повлиять на модель.

Примечание: причина, по которой я начал с сети ConvLSTM, заключается в том, что в конце мы хотим получить сеть GAN + VAE, где энкодер функции — это сеть, похожая на ту, что у меня есть.

У вас есть какие-либо предложения по улучшению модели? Заранее спасибо.

Модель:

def model():
    model = tf.keras.Sequential()
    model.add(ConvLSTM2D(filters=64, kernel_size=(5, 5), padding='same', return_sequences=True, activation="relu", input_shape=(lag, height, width, 1)))
    model.add(BatchNormalization())

    model.add(ConvLSTM2D(filters=64, kernel_size=(3, 3), padding='same', return_sequences=True, activation="relu"))
    model.add(BatchNormalization())

    model.add(ConvLSTM2D(filters=64, kernel_size=(1, 1), padding='same', return_sequences=True, activation="relu"))
    model.add(BatchNormalization())
    
    model.add(Conv3D(filters=1, kernel_size=(30, 1, 1), activation="sigmoid", padding="valid"))
    assert model.output_shape == (None, 1, 42, 46, 1)

    return model

Функция потерь:

mse_loss = MeanSquaredError()

def create_mask(image):
    # Создать маску, где ненулевые значения помечены как 1, а нулевые значения как 0
    return tf.cast(image > 0, tf.bool)

def masked_loss(y_true, y_pred):
    
    mask = create_mask(y_true)

    # Фильтровать y_true и y_pred по маске
    y_true_mask = tf.boolean_mask(y_true, mask)    
    y_pred_mask = tf.boolean_mask(y_pred, mask)
    
    # Вычислить среднюю квадратическую ошибку
    loss = mse_loss(y_true_mask, y_pred_mask)       
    
    return loss

Я пытался добавить больше слоев ConvLSTM, но это не сработало.

Ответ или решение

Для улучшения модели предсказания видео, разработанной с использованием Keras, существуют несколько подходов, которые можно рассмотреть в контексте вашей задачи по предсказанию преступлений на основе изображений. Давайте остановимся на нескольких стратегиях, которые могут помочь вашей модели достичь лучших результатов.

1. Оптимизация архитектуры сети

a. Изменение количества фильтров и объектов

Убедитесь, что вы экспериментируете с различными значениями фильтров в ваших слоях ConvLSTM. Начните с меньших значений (например, 32) и постепенно увеличивайте их, чтобы найти химическую реакцию вашей модели.

b. Добавление слоев

Вместо того, чтобы увеличивать количество слоев ConvLSTM, попробуйте добавить несколько полносвязных слоев или слоев сжатия после ConvLSTM, которые помогут обработать выходные данные.

c. Использование различных функций активации

Экспериментируйте с другими функциями активации, такими как Leaky ReLU или PReLU, чтобы избежать проблемы "исчезающего градиента".

2. Гиперпараметры и оптимизация обучения

a. Размер пакета

Попробуйте изменить размер пакета (batch size). Иногда использование меньших пакетов помогает модели лучше обобщать данные.

b. Изменение скорости обучения

Обратите внимание на гиперпараметры, такие как скорость обучения. Попробуйте использовать уменьшение скорости обучения при помощи методов, таких как ReduceLROnPlateau, чтобы адаптировать скорость обучения на основе результатов валидации.

3. Аугментация данных

Для улучшения обобщающих свойств модели можно использовать аугментацию данных. Например, вы можете применить повороты, сдвиги и изменения яркости, чтобы увеличить разнообразие обучающих данных и увеличить плотность используемых образцов.

4. Улучшение функции потерь

Существует вероятность, что ваша маскированная функция потерь не полностью захватывает поведение модели в случае неполных данных. Рассмотрите возможность использования других функций потерь, таких как Huber Loss, которая может быть менее чувствительна к выбросам.

5. Мониторинг обучения

a. Запись функции потерь

Следите за значениями функции потерь на каждой эпохе, чтобы удостовериться, что модель действительно обучается. Если значения потерь не уменьшаются, возможно, ваша модель переобучается или неоптимально инициализирована.

b. Визуализация

Используйте визуализацию, чтобы наблюдать за изменениями в градиентах и потерях. Библиотеки, такие как TensorBoard, могут оказаться полезными для визуализации.

6. Обратная связь и переобучение

Наконец, не забывайте о переносе обучения (Transfer Learning). Использование предварительно обученных моделей может значительно ускорить процесс обучения и улучшить качество предсказания. Вы можете использовать начальные слои CNN предварительно обученной модели (такие как VGG16 или ResNet) для извлечения признаков, а затем подключить их к вашей модели.

Заключение

Итак, ваша задача заключается в использовании различных комбинаций этих подходов для улучшения модели. Экспериментируйте, следите за процессом обучения и не бойтесь пересматривать свою архитектуру. Удачи в вашей работе над предсказанием преступлений!

Оцените материал
Добавить комментарий

Капча загружается...