Увеличение потерь при обучении с каждой эпохой в реализации PyTorch ResNet

Вопрос или проблема

Я реализую сеть ResNet с нуля, используя PyTorch. Эта сеть уникальна для моих требований, так как мне нужно выполнять классификацию изображений спутниковой съемки с 14 различными каналами и размерами 8×8 пикселей. Мои обучающие данные состоят из 588 обучающих изображений для 7 различных классов

Однако у меня проблема в том, что во время обучения потери начинают сильно увеличиваться сразу после первой эпохи, и достигают бесконечности в течение 10 эпох, а точность застревает на уровне 0.17.

Первоначальное преобразование данных заключается в изменении размера изображений с 8×8 пикселей до 64×64 пикселей.

Моя сеть ResNet состоит из:

  1. Одного начального блока свертки – начальный слой для уменьшения размеров изображения при увеличении количества каналов.

  2. Трех остаточных блоков – эти остаточные блоки имеют структура Ботлнек (1×1 с x/2 каналами, 3×3 с x/2 каналами, 1×1 с x каналами). Здесь первые 2 блока имеют одинаковую структуру, после чего я уменьшаю размеры изображения, удваивая количество каналов, и затем передаю его в третий остаточный блок. Все мои остаточные блоки имеют слои ReLU, слои нормализации (BN) между операциями свертки, а также слой Dropout в конце каждого остаточного блока.

  3. Полносвязный слой – состоит из слоя выпрямления, за которым следуют 2 последовательных линейных слоя, разделенных слоем ReLU. Последний слой выводит вероятности классов.

Мой вопрос – каковы могут быть возможные причины таких всплесков потерь во время обучения в сети ResNet? Например, может ли слишком малое количество остаточных блоков вызвать такую проблему? Или это может быть связано с малым размером датасета (588 изображений)? Ниже приведен метод “forward” моего объектного класса ResNet (могу предоставить остальной код, если нужно):

def forward(self, x, conv_channels):

        # Initial Convolution Block
        self.conv1Out = self.conv1block(x)

        # Residual Block 1
        self.conv2Out_1 = self.conv2block(self.conv1Out)
        self.res1 = self.conv2Out_1 + self.conv1Out

        # Residual Block 2 - здесь суммирование происходит после изменения размера
        self.conv2Out_2 = self.conv2block(self.res1)
        self.adjBlock = nn.Conv2d(in_channels=conv_channels,
                                  out_channels=conv_channels*2,
                                  kernel_size=1,
                                  stride=2)
        self.conv2Out_2 = self.adjBlock(self.conv2Out_2)                                 
        self.res1_adj = self.adjBlock(self.res1)
        self.res2 = self.conv2Out_2 + self.res1_adj

        # Residual Block 3
        self.conv3Out = self.conv3block(self.res2)
        self.res3 = self.conv3Out + self.res2

        # Fully Connected Layer
        self.final = self.linear_block(self.res3)
        return self.final
```

Вы проверяли, не происходит ли взрыв градиентов? Думаю, это одна из возможных причин. Попробуйте использовать обрезку градиентов.

Ответ или решение

Увеличение значения функции потерь при обучении нейронной сети ResNet с каждым эпохом, особенно до достижения бесконечности, может быть весьма тревожным симптомом. Разберёмся в возможных причинах и способах их устранения.

Возможные причины:

  1. Проблема с градиентами ("Exploding Gradients"):

    • Ваша сеть может сталкиваться с проблемой разрастающихся градиентов, что и вызывает увеличение функции потерь. Необходимо внимательно посмотреть на динамику обновления весов и попробовать использовать ограничение градиента (gradient clipping), чтобы предотвратить излишнее увеличение градиента.
  2. Маленький размер датасета:

    • Наличие всего 588 изображений может способствовать переобучению. Сеть могла быстро "запомнить" тренировочные данные, но не может обобщить их. Попробуйте увеличить датасет с помощью методов увеличения данных (data augmentation) и убедитесь, что данные правильно перемешиваются.
  3. Неправильная нормализация данных:

    • Проверьте, правильно ли нормализуются ваши данные. Поскольку у вас изображения с 14-ю каналами и размером 8×8 пикселей, переконвертация их на 64×64 может приводить к изменению статистик данных.
  4. Архитектура модели:

    • Возможно, ваша сеть слишком сложная для объёма данных, или наоборот, слишком упрощённая для задачи. Пересмотрите количество блоков и их конфигурацию. Некорректное использование структуры ResNet может привести к возникновению ошибок в вычислениях вычитаемых блоков.
  5. Гиперпараметры обучения:

    • Проверьте выбранный коэффициент обучения (learning rate). Слишком высокий коэффициент может вызвать нестабильность в обучении. Попробуйте использовать адаптивные оптимизаторы, такие как Adam, с уменьшением шага обучения.
  6. Инициализация весов:

    • Неправильная инициализация весов может быть причиной расхождений. Убедитесь, что веса инициализируются методами, подходящими для нелинейных функций (например, по стратегии He).

Рекомендуемые действия:

  1. Мониторинг и ограничение градиентов:

    • Используйте ограничения на величину градиентов (gradient clipping): например, можно добавить torch.nn.utils.clip_grad_norm_.
  2. Улучшение предобработки:

    • Примените улучшенные методы нормализации и увеличения данных, чтобы расширить вашу обучающую выборку.
  3. Анализ архитектуры сети:

    • Проведите эксперименты с количеством и параметрами блоков ResNet, так как их неправильная настройка может значимо повлиять на результаты.
  4. Оптимизация гиперпараметров:

    • Используйте методы оптимизации, обеспечивающие бо́льшую устойчивость, например, шаговое снижение learning rate на плато (ReduceLROnPlateau).

Внимание к этим аспектам и их проработка помогут стабилизировать процесс обучения модели и улучшить её результаты.

Оцените материал
Добавить комментарий

Капча загружается...