Числовая точность в Flux.jl

Вопрос или проблема

Я пытаюсь изучить обучение нейронных сетей в рамках динамических систем, рассматривая модель как систему, а обучение как динамику во временной эволюции. В качестве расширения я попытался сделать так, чтобы обучение проходило в “гамильтоновом” режиме – я реализовал это, сделав так, чтобы модель состояла из комплексных чисел, веса и смещения были комплекснозначными матрицами. Теперь, изменяя скорость обучения в градиентном спуске на чисто мнимую (участвуют и другие числовые факторы, которые я опускаю в описании для краткости, но которые реализованы в моем коде), динамика становится гамильтоновой, т.е. функция потерь остается постоянной во время обучения/динамики.

Когда я смотрю на функцию потерь для такой конфигурации, она действительно остается постоянной, как и предполагалось, но лишь в течение нескольких эпох, после чего она либо увеличивается, либо уменьшается и ведет себя, в общем, хаотично. Сначала я предположил, что это связано с сложной структурой поверхности потерь и попытался реализовать более эффективные методы для “обучения” системы – заставив оптимизатор градиентного спуска использовать симплектический метод Эйлера вместо обновления, подобного Эйлеру, использующегося в обычном обновлении градиентного спуска (я фактически рассматриваю обучение как численное интегрирование динамической системы). Но это все равно не исправило проблему хаотичного поведения функции потерь.

На графике ниже, не обращайте внимания на синюю линию и обращайте внимание только на красную. Я обучаю систему в режиме градиентного спуска в течение 100 эпох, после чего изменяю скорость обучения на чисто мнимую. Функция потерь остается постоянной в течение следующих около 300 эпох, после чего она начинает увеличиваться. Другие примеры такого обучения также демонстрируют немонотонное поведение.

Erratic loss in Hamiltonian mode

Затем я попытался построить одномерные срезы поверхности потерь вокруг минимума и обнаружил, что поверхность потерь действительно довольно гладкая. Теперь моя теория заключается в том, что причиной хаотичного поведения после некоторых эпох обучения являются эффекты численной точности. Прав ли я в этом предположении? Если да, то как минимизировать ошибки численной точности, а если нет, в чем моя ошибка?

Я работаю с MNIST для всех анализов, у меня имеется один скрытый слой с 16 нейронами, с перекрестной энтропией в качестве функции потерь. Я использую пользовательские функции активации: zrelu(z) = z, если angle(z) < pi и 0, если angle(z) > pi, и abslu(z) = abs(z). Таким образом, архитектура моей модели выглядит следующим образом:

    model = Chain(
        Dense(28^2 => 16, zrelu; init=complex_glorot_uniform),
        Dense(16 => 16, zrelu; init=complex_glorot_uniform),
        Dense(16 => 10, abslu; init=complex_glorot_uniform),
        softmax
    )

Где complex_glorot_uniform – это модифицированная версия стандартного glorot_uniform, где как действительная часть, так и мнимая часть отбираются из glorot_uniform.

Проблема имеет еще одно интересное поведение – если я тренирую модель в гамильтоновом режиме напрямую, не давая ей сначала обучаться на градиентном спуске, она ведет себя хаотично с самого начала, в отличие от случая, когда я сначала обучаю её на градиентном спуске, после чего переключаю на гамильтонов режим. Первый случай показан на рисунке ниже (рассматривается только красная линия)no gradient descent

Я могу предоставить дополнительную информацию, если это потребуется.

Ответ или решение

Когда вы рассматриваете численную точность в Flux.jl, особенно в контексте обучения нейронных сетей на основе гамильтоновой динамики, важно учитывать несколько ключевых аспектов. Ваша идея о том, что вы рассматриваете модель как систему, а обучение как временную эволюцию, весьма инновативна и заслуживает внимания. Однако, как вы заметили, возникающее со временем неустойчивое поведение потерь может свидетельствовать о проблемах с численной точностью.

Проблема численной точности

Переход на гамильтоновую динамику, где скорость обучения является чисто мнимой, предполагает, что потери должны оставаться постоянными. Тем не менее, отклонения, которые вы наблюдаете через несколько эпох, вероятно, связаны с численными эффектами. Варианты, такие как неточности в выравнивании чисел с плавающей запятой и накопление ошибок при вычислениях, могут играть ключевую роль.

Причины и решения

  1. Численная стабильность: Проблемы с численной стабильностью могут возникать из-за представления комплексных чисел и аналитических градиентов. Вам следует убедиться, что используете числовые типы с достаточной точностью. Возможно, использование Float64 или даже более высоких уровней точности в Julia может снизить ошибку округления.

  2. Алгоритмы оптимизации: Симплектические методы часто используются для интеграции гамильтоновых систем, однако они могут испытывать сложности при применении в сетях с необычными ландшафтами потерь. Рассмотрите проверку ваших симплектических обновлений и исследуйте альтернативные методы для повышения их стабильности.

  3. Инициализация весов: Изменения в инициализации, даже в комплексной модификации glorot_uniform, могут влиять на динамику обучения. Параметризация инициализации может помочь минимизировать начальные ошибки, оказывающие значительное влияние на долгосрочную точность.

Дополнительные рекомендации

  • Диагностика: Рассмотрите возможность создания подробных диагностических записей процесса, чтобы выявить шаги, на которых происходит отклонение. Это может включать промежуточные значения потерь, изменений в значениях параметров и их распределение.

  • Увеличение точности данных: В некоторых случаях перевод всего процесса в режим высокой точности может помочь. Это может включать не только обучение, но и функцию потерь, проверку численной устойчивости для всех частей вашей архитектуры.

Ваша гипотеза о том, что численная точность может быть виновником нестабильности, заслуживает дальнейшего расследования. Удачи в вашей исследовательской работе, и надеемся, что данные рекомендации помогут вам стабилизировать гамильтоново обучение в Flux.jl.

Оцените материал
Добавить комментарий

Капча загружается...