Как справиться с сильным переобучением в сверточной нейронной сети UNet с кодировщиком/декодировщиком в задаче, очень похожей на перевод изображений?

Вопрос или проблема

Я пытаюсь подогнать CNN UNet под задачу, очень похожую на перевод изображений. Входные данные сети – это бинарная матрица размером (64,256), а выход – размером (64,32). Колонки представляют собой статус канала связи, где каждое значение в колонне – это статус подсистемы. 1 означает, что подсистема занята, а 0 – что подсистема свободна. Горизонтальная ось представляет поток времени. Таким образом, первая колонка – это статус канала в 1-й временной интервал, вторая колонка – статус во 2-й временной интервал и так далее. Задача состоит в том, чтобы предсказать статус канала в следующих 32 временных интервалах, исходя из предыдущих 256 временных интервалов, которые я рассматривал как перевод изображений.
Точность на обучающих данных составляет около 90%, в то время как точность на тестовых данных около 50%. Под точностью я здесь имею в виду средний процент правильных значений в каждом изображении. Также во время обучения значение потерь на валидации увеличивается, в то время как потери уменьшаются, что является явным признаком переобучения. Я пробовал большинство техник регуляризации и также пытался снизить мощность модели, но это лишь снижает ошибку обучения, не улучшая ошибку обобщения. Есть ли какие-либо советы или идеи? Я включил в следующую часть кривую обучения для обучения на 1000 образцах, реализацию сети и образцы из обучающего и тестового наборов.

Кривые обучения на 1000 образцах

3 образца из обучающего набора

3 образца из тестового набора

Вот реализация сети:

def define_encoder_block(layer_in, n_filters, batchnorm=True):
    # инициализация весов
    init = RandomNormal(stddev=0.02)
    # добавление слоя downsampling
    g = Conv2D(n_filters, (4,4), strides=(2,2), padding='same',
               kernel_initializer=init)(layer_in)
    # условное добавление batch normalization
    if batchnorm:
        g = BatchNormalization()(g, training=True)
    # активация leaky relu
    g = LeakyReLU(alpha=0.2)(g)
    return g
 
# определение блока декодера
def decoder_block(layer_in, skip_in, n_filters, filter_strides, dropout=True, skip=True):
  # инициализация весов
  init = RandomNormal(stddev=0.02)
    # добавление слоя upsampling
  g = Conv2DTranspose(n_filters, (4,4), strides=filter_strides, padding='same', 
                         kernel_initializer=init)(layer_in)
    # добавление batch normalization
  g = BatchNormalization()(g, training=True)
    # условное добавление dropout
  if dropout:
    g = Dropout(0.5)(g, training=True)
  if skip:
    g = Concatenate()([g, skip_in])
    # активация relu
  g = Activation('relu')(g)
  return g
 
# определение отдельно стоящей модели генератора
def define_generator(image_shape=(64,256,1)):
    # инициализация весов
    init = RandomNormal(stddev=0.02)
    # вход изображения
    in_image = Input(shape=image_shape)
    e1 = define_encoder_block(in_image, 64, batchnorm=False)
    e2 = define_encoder_block(e1, 128)
    e3 = define_encoder_block(e2, 256)
    e4 = define_encoder_block(e3, 512)
    e5 = define_encoder_block(e4, 512)
    e6 = define_encoder_block(e5, 512)
    e7 = define_encoder_block(e6, 512)
    # узкое место, без batch norm и relu
    b = Conv2D(512, (4,4), strides=(2,2), padding='same', kernel_initializer=init)(e7)
    b = Activation('relu')(b)
    # модель декодера
    d1 = decoder_block(b, e7, 512, (1,2))
    d2 = decoder_block(d1, e6, 512, (1,2))
    d3 = decoder_block(d2, e5, 512, (2,2))
    d4 = decoder_block(d3, e4, 512, (2,2), dropout=False)
    d5 = decoder_block(d4, e3, 256, (2,2), dropout=False)
    d6 = decoder_block(d5, e2, 128, (2,1), dropout=False, skip=False)
    d7 = decoder_block(d6, e1, 64, (2,1), dropout=False, skip=False)
    # выход
    g = Conv2DTranspose(1, (4,4), strides=(2,1), padding='same', kernel_initializer=init)(d7)
    out_image = Activation('sigmoid')(g)
    # определение модели
    model = Model(in_image, out_image)
    return model

Я не уверен, что то, что вы видите выше, на самом деле является проблемой переобучения. То, что я вижу, это то, что кривая валидации на самом деле вообще не снижается, трудно сказать, что это переобучение. Обычно в случае переобучения мы видим, что значение потерь на валидации снижается, но в какой-то момент оно начинает снова расти.

Все это указывает на то, что либо валидационный набор не похож на обучающий набор, либо у модели нет мощности для того, чтобы действительно изучить паттерны, но она способна их запомнить (это на самом деле проблема тонкого смещения).

Одна проблема, которую я также вижу в вашей реализации, это то, что вы сначала применяете BatchNorm, а затем применяете dropout. Dropout применяется после активации, поэтому переместите это. Это может вызывать проблемы.

Для начала, что произойдет, если регуляризацию полностью снизить? Например, если вы уберете dropout и термины batchnorm?

Ответ или решение

Как справиться с серьезным переобучением в UNet Encoder/Decoder CNN для задач, схожих с переводом изображений

Ваша задача заключается в предсказании статуса коммуникационного канала в будущих временных интервалах на основе данных предыдущих временных отсечек, задаваемых в виде бинарной матрицы. При этом вы столкнулись с проблемой серьезного переобучения вашей модели UNet, что видно по разнице в точности между обучающей и тестовой выборками, а также по росту значения валидационной потери во время обучения, что указывает на явное переобучение.

Проблема переобучения

Вы наблюдаете высокое значение точности на обучающей выборке (~90%) и значительно более низкое на тестовой выборке (~50%), что указывает на то, что модель отлично запоминает обучающие данные, но не может обобщать информацию. Это может быть следствием нескольких факторов:

  1. Недостаточная подгонка модели к произошедшим данным
  2. Недостаточная сложность модели для обучения специфическим паттернам.
  3. Неправильное распределение данных в обучающем и тестовом наборах, включая потенциальную разницу в их форме и содержании.

Распределение данных

Перед тем как применять какие-либо методы решения, убедитесь, что:

  • Данные для обучения и тестирования полностью соответствуют друг другу по распределению, как по количеству, так и по особенностям. Если ваши входные данные (64, 256) и выходные данные (64, 32) сильно различаются, это может быть основной причиной.

Техника регуляризации

Вы упомянули, что пробовали различные методы регуляризации, но безуспешно. Есть несколько аспектов, которые стоит пересмотреть:

  1. Групповая нормализация (Batch Normalization):

    • В вашей реализации Batch Normalization применяется до Dropout. Рекомендуется изменять местоположение этих операций, переключив порядок на Dropout перед Batch Normalization.
  2. Dropout:

    • В вашем случае Dropout установлен на 0.5. Поэкспериментируйте с разными значениями (например, 0.3 или 0.1) или попробуйте полностью отключить Dropout, чтобы увидеть, как это повлияет на переобучение.
  3. Увеличение данных (Data Augmentation):

    • Примените техники увеличения данных, такие как случайные сдвиги, повороты и масштабирование, чтобы обогатить набор обучающих данных.
  4. Снижение сложности модели:

    • Возможно, ваша модель слишком сложна для задачи. Уменьшение количества фильтров в свёрточных слоях поможет предотвратить переобучение, но при этом сохранит возможность обучения.
  5. Ранние остановки и кросс-валидация:

    • Добавление механизма ранних остановок поможет остановить обучение до того, как произойдет переобучение. Кросс-валидация также может помочь оценить обобщающую способность модели.

Параметры обучения

  • Размеры пакетов и скорость обучения (Learning Rate):

    • Попробуйте разные размеры пакетов (batch sizes) и скорости обучения. Использование адаптивных способов, таких как Adam или RMSprop, может улучшить динамику обучения.
  • Функция потерь:

    • Убедитесь, что функция потерь справедливо учитывает ваш набор данных. Возможно, стоит рассмотреть кастомизированные функции потерь или использование скомбинированных подходов.

Наблюдения и вывод

Проблема с переобучением модели UNet, вероятно, связана как с распределением данных, так и с настройками модели. Проверьте корректность данных, пересмотрите способы регуляризации и адаптируйте параметры обучения. Уделите внимание всем подходам к решению, которые могут повлиять на результат.

Эти шаги помогут вам не только снизить степень переобучения, но и улучшить обобщающую способность модели на тестовых данных, что в конечном итоге приведет к более высоким показателям точности в задачах перевода изображений.

Оцените материал
Добавить комментарий

Капча загружается...