Почему моя модель Keras не обучается сегментации изображений?

Вопрос или проблема

Редактировать: как выясняется, даже первоначальному создателю модели не удалось успешно ее донастроить. Скорее всего, это проблема реализации или потенциально связана с неинтуитивным способом работы слоя пакетной нормализации Keras.

Я пытаюсь донастроить эту реализацию Keras модели DeepLab v3+ от Google на собственном наборе данных, полученном из неувеличенной эталонной выборки Pascal VOC 2012 (1449 обучающих примеров) для своих исследовательских целей.

Я решил сначала попробовать просто дообучить ее на оригинальном наборе данных Pascal VOC и попытаться получить результаты, близкие к тем, что в статье. Автор репозитория, похоже, смог это сделать, так что шансы на то, что модель Keras неверна, невелики. Я успешно загрузил предобученную на ImageNet модель (из официального зоопарка моделей Google), и карты признаков явно показывают, что модель способна различать все объекты на изображениях, которые ей подаются (см. рисунки ниже).

Я замораживаю первые 356 слоев, которые соответствуют исходной главной модели (в моем случае это Xception). Я добавил дополнительный финальный softmax слой в модель, поскольку модель из статьи изначально выдает логиты. Связано с этим выбором, в наборе данных есть класс фона. Я использую оптимизатор tf.keras.optimizers.Adadelta.

Тем не менее, после недель корректировок и изучения, я все еще не могу заставить модель учиться или делать что-то стоящее в сегментации.

Я пробовал использовать десятки различных функций потерь и точности, найденных здесь и там в интернете, в основном варианты потерь перекрестной энтропии по пикселям и мягких потерь Dice, а также изменял скорость обучения от $10^{-1}$ до $10^{-5}$ (авторы использовали $10^{-2}$ в оригинальной статье), и каждый раз получаю одно и то же; значение потерь фактически колеблется вокруг довольно маленького значения, и затем обратный вызов ранней остановки, который я использую, прекращает обучающий этап после $7$ или $8$ эпох.

Вот как меняется типичная метрика, если я не останавливаю процесс (в этом случае скорость обучения была установлена на $10^{-5}$, размер батча $10$):

введите описание изображения здесь

Я решил сделать предсказание на том же изображении после каждой эпохи, и вот как это выглядит после первой эпохи (в правом нижнем углу “метки” – это просто argmax на картах признаков):

введите описание изображения здесь

И после 20 эпох:

введите описание изображения здесь

Все промежуточные результаты выглядят примерно одинаково, по-видимому, независимо от гиперпараметров.

Я даже пробовал использовать функции точности и потерь, которые автор репозитория сказал, что он использовал для того, чтобы сделать именно то, что я пытаюсь сделать, но получаю те же хаотичные кривые метрик.

У меня действительно заканчиваются идеи о том, откуда это может происходить. Я был бы рад получить подсказки о том, где мне стоит поискать возможную ошибку, которую я мог сделать.


Детали потока данных

Я использую API наборов данных TensorFlow (по сути, следуя этому очень хорошему руководству) для загрузки набора данных в память. Указанный набор данных был заранее перемешан и разбит на $140$ фрагментов по $10$ примеров, что является максимальным размером батча, который я могу использовать на своем оборудовании. Затем я выбираю перемешанный набор фрагментов и предобрабатываю примеры в них, изменяя масштаб/дополнив/обрезая их до размеров $512 \times 512$ с интенсивностями значений между $-1$ и $1$, преобразую их в тензоры tf.float32 и создаю $21$ бинарных масок для каждого класса набора данных.

  • входной тензор имеет размер батча $(10, 512, 512, 3)$ со значениями в $[-1, 1$] и закодирован в float32;
  • ассоциированная истинная метка – это тензор размером $(10, 512, 512, 21)$ со значениями $0$, $1$ или $255$ (последнее значение используется для “двусмысленных” или дополненных областей; в свою очередь части изображения, которые нужно игнорировать).

Функции потерь и точности

Я начинаю с игнорирования меток и предсказаний в игнорируемых областях (см. значение $255$ выше):

def get_valid_labels_and_logits(y_true, y_pred):
    valid_labels_mask = tf.not_equal(y_true, 255.0)
    indices_to_keep = tf.where(valid_labels_mask)
    valid_labels = tf.gather_nd(params=y_true, indices=indices_to_keep)
    valid_logits = tf.gather_nd(params=y_pred, indices=indices_to_keep)

    return valid_labels, valid_logits

Я трижды проверял это на крошечном пользовательском $2 \times 3$ изображении, и это работает, как ожидалось.

Далее я вычисляю потери Dice, усредненные по всем классам, как это определено в этой статье:

def soft_dice_loss(y_true, y_pred):
    y_true, y_pred = get_valid_labels_and_logits(y_true, y_pred)
    # Следующие тензоры имеют размер (num_batches, num_classes)
    interception_volume = tf.reduce_sum(tf.reduce_sum(y_true * y_pred, axis=1), axis=1)
    y_true_sum_per_class = tf.reduce_sum(tf.reduce_sum(y_true, axis=1), axis=1)
    y_pred_sum_per_class = tf.reduce_sum(tf.reduce_sum(y_pred, axis=1), axis=1)

    return tf.reduce_mean(1.0 - 2.0 * interception_volume / (y_true_sum_per_class + y_pred_sum_per_class))

Я пробовал разные варианты этого, включая встроенную перекрестную энтропию и бинарную перекрестную энтропию, но это не изменило поведение.

Стандартная функция точности, похоже, не работает, поэтому я реализовал пользовательскую функцию точности средней IoU, которая хорошо работает на примерах, которые я ввел вручную.

Извините за этот большой объем текста, но я хотел сделать ситуацию ясной. Большое спасибо за вашу добрую помощь и советы!

Согласно комментарию Педро Энрике Монфорт, поскольку у пользователя было достаточно времени, чтобы сделать это самостоятельно, я превратил “ответ”, отредактированный в ответ на сообщение пользователя, в настоящий ответ:

как выясняется, даже первоначальному создателю модели не удалось успешно ее донастроить. Скорее всего, это проблема реализации или потенциально связана с неинтуитивным способом работы слоя пакетной нормализации Keras.

Ответ пользователя ссылается на эту страницу:

https://blog.datumbox.com/the-batch-normalization-layer-of-keras-is-broken/

…которая описывает проблемы со слоем пакетной нормализации (BN) Keras, особенно при использовании с Transfer Learning. Слой BN, критически важный в глубоком обучении для более быстрого обучения и уменьшения затухания градиентов, ведет себя иначе в режимах обучения и вывода. В Keras, когда слои BN заморожены, они неправильно используют статистику мини-батча во время обучения вместо заранее изученных средних значений и дисперсий, что вызывает проблемы в Transfer Learning, где слои повторно используются с новыми данными. Автор предлагает изменить слой BN, чтобы он полагался на статистику оригинального набора данных, когда заморожен, обеспечивая однородное масштабирование данных и более точные предсказания.

Наконец, страница автора на github, которая реализует их идеи здесь

Ответ или решение

Когда моделируется задача сегментации изображений с использованием Keras, разные аспекты реализации могут повлиять на способность модели к обучению. Рассмотрим основные причины, по которым ваша модель может не обучаться должным образом, а также советы по устранению проблем.

Анализ проблем

  1. Неправильная обработка данных:

    • Убедитесь, что входные изображения и маски имеют правильные размеры и значения. Ваши входные данные должны быть нормализованы, а маски содержать классы, приведённые к одному масштабу (например, 0 и 1 для бинарной классификации или от 0 до 20 для многоклассовой сегментации). Кроме того, стоит убедиться, что область с 255, которая обозначает игнорируемые пиксели, не участвует в расчёте потерь.
  2. Настройка модели:

    • Замораживание первых 356 слоев (Xception) может препятствовать обучению модели. Это может быть особенно критично, если ваша новая задача значительно отличается от задач, на которых оригинальная модель была обучена. Попробуйте замораживать меньшее количество слоев или вовсе не замораживать их.
  3. Проблемы с функцией потерь:

    • Ваша реализация soft_dice_loss требует проверки. Убедитесь, что она правильно учитывает действительные классы в масках, и что расчеты ведутся только на валидных пикселях без игнорируемых значений. Порой использование промежуточных функций, таких как cross-entropy, может давать более предсказуемые результаты, если информация о классе не сильно различается.
  4. Скорость обучения и оптимизатор:

    • Подбор Learning Rate имеет критическое значение. Если модель не обучается, возможно, стоит попробовать снизить его до более низких значений (например, $10^{-6}$), что может помочь избежать проблем с нестабильностью потерь.
  5. Проблемы с нормализацией:

    • Как указано в исходном тексте, проблема может быть связана с работой слоя пакетной нормализации (Batch Normalization). В Keras, если заморожены слои BN, они могут использовать статистику мини-пакетов вместо предварительно обученных значений. Это может помешать корректной адаптации модели к новой задаче. Попробуйте использовать альтернативные реализации, такие как те, что описаны на GitHub-странице Datumbox.
  6. Очень малая обучающая выборка:

    • Поврежденные или слишком малые данные для обучения могут быть причиной того, что модель не способна выявить закономерности. В вашем случае, 1449 примеров может не быть достаточно для сложных задач сегментации, особенно если данные сильно разнообразны.

Рекомендации по улучшению

  • Наличие валидационной выборки: Используйте валидационную выборку для контроля переобучения и оценки производительности модели на разных этапах обучения.

  • Аугментация данных: Применение различных методов аугментации (повороты, масштабирование, сдвиги) может значительно улучшить результаты и сделать модель более устойчивой.

  • Проверка на маленьком наборе данных: Перед обучением на полноразмерном наборе данных, протестируйте модель на меньшем, но разнообразном подмножества, чтобы быстро отладить архитектуру и параметры.

  • Визуализация результатов: Обязательно визуализируйте промежуточные результаты предсказания модели, чтобы понять, на каких элементах модель обучается, а на каких нет.

В целом, стоит подходить к проблеме с комплексом решений, изменяя сразу несколько факторов, которые могут влиять на результаты. Это многогранный процесс, требующий терпения и систематического подхода.

Оцените материал
Добавить комментарий

Капча загружается...