Почему функция активации “tanh” работает лучше с Pytorch, чем с Keras.

Вопрос или проблема

Итак, я создаю нейронную сеть, которая должна распознавать написанные кириллические буквы, и я обнаружил, что по какой-то причине, когда я использую функцию tanh, она работает значительно лучше с PyTorch, чем с Keras.

import tensorflow as tf
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Flatten, Lambda, BatchNormalization
from tensorflow.keras.callbacks import ReduceLROnPlateau, EarlyStopping

a_param = 0.5
activation_type="tanh"

custom_activation = create_activation_function(a_param, activation_type)

input_shape = (100, 100, 4)

model = Sequential()
model.add(tf.keras.layers.InputLayer(input_shape=input_shape))
model.add(Flatten())

model.add(Dense(784))
model.add(BatchNormalization())
model.add(Lambda(custom_activation))

model.add(Dense(30))
model.add(BatchNormalization())
model.add(Lambda(custom_activation))

model.add(Dense(10))
model.add(Lambda(custom_activation))

model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=1e-3),
              loss="categorical_crossentropy",
              metrics=['accuracy'])

model.summary()

epochs = 30
batch_size = 128

lr_schedule = ReduceLROnPlateau(
    monitor="val_loss", factor=0.5, patience=3, min_lr=1e-6
)

early_stopping = EarlyStopping(
    monitor="val_loss", patience=3, restore_best_weights=True
)

history = model.fit(
    X_train, y_train,
    epochs=epochs,
    batch_size=batch_size,
    validation_data=(X_test, y_test),
    shuffle=True,
    callbacks=[lr_schedule, early_stopping]
)

Код выше написан с использованием Keras

import torch
import torch.nn as nn

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        # 100*100 изображения с 4 каналами (RGBA)
        self.fc1 = nn.Linear(100 * 100 * 4, 128)  # Входной слой для изображений 100x100 с 4 каналами
        self.fc2 = nn.Linear(128, 64)
        self.fc3 = nn.Linear(64, 10)

    def forward(self, x):
        # Изменение формы тензора в вектор длины 100 * 100 * 4
        x = x.view(-1, 100 * 100 * 4)
        x = torch.tanh(self.fc1(x))  # Использование tanh
        x = torch.tanh(self.fc2(x))
        x = self.fc3(x)
        return torch.tanh(x)  # Применение tanh на выходе

Этот код написан с использованием Pytorch

Когда Keras дает около +/- 10%, Pytorch значительно выше — до 90%. В чем может быть причина? Я сделал что-то не так или это должно работать именно так?

Ответ или решение

Почему функция активации "tanh" работает лучше в PyTorch, чем в Keras

При разработке нейронных сетей для распознавания письменных кириллических букв вы заметили, что функция активации "tanh" демонстрирует значительно лучшие результаты в PyTorch по сравнению с Keras. Давайте детально рассмотрим возможные причины этого явления.

1. Структура и внутреннее устройство фреймворков

PyTorch и Keras (который является частью TensorFlow) имеют различные архитектуры и внутренние механизмы выполнения, что может повлиять на производительность нейронной сети:

  • Динамическое вычисление графа в PyTorch позволяет более эффективно управлять памятью и оптимизировать маршруты распространения ошибок во время обратного распространения. Это может улучшить эффективность обучения, особенно для сложных моделей.

  • TensorFlow (в том числе Keras) использует статический вычислительный граф, который требует предварительного определения структуры модели. Это может ограничивать гибкость в некоторых случаях и снижать производительность при динамическом изменении параметров модели.

2. Инициализация весов и нормализация

Эффективность нейронной сети во многом зависит от метода инициализации весов и нормализации. В вашем коде Keras используется слой BatchNormalization, который может влиять на распределение активаций в сети:

  • В PyTorch вы не используете BatchNormalization, а только применяете активацию "tanh" сразу после линейных слоев. Это может привести к лучшему распространению градиента и более эффективному обучению. В Keras, добавление BatchNormalization может изменять статистику активаций, что в свою очередь может влиять на эффективность работы активационной функции "tanh".

3. Параметры обучения и настройка гиперпараметров

Параметры обучения, такие как скорость обучения и алгоритм оптимизации, играют ключевую роль в успешности обучения модели. В вашем примере, вы используете Adam в Keras с фиксированными параметрами:

  • Проверьте, как настройки скорости обучения и других оптимизаторов в Keras могут влиять на скорость сходимости. Возможно, случайная инициализация в одном из запусков в Keras привела к более худшему началу и колебаниям, в отличие от более стабильного поведения в PyTorch.

4. Характеристики данных

Не стоит недооценивать влияние самих данных на эффективность модели. В вашем случае к таким характеристикам относятся:

  • Размер выборки: Если в PyTorch вы использовали другую выборку данных или изменили предобработку изображений, это может объяснить значительный прирост точности. Проверьте, совпадают ли размеры и предобработка данных при запуске модели в Keras и PyTorch.

  • Шум и искажения: Если ваши данные содержат шум или искажения, то различия в том, как каждый фреймворк обрабатывает данные на этапе обучения, могут приводить к различным результатам.

5. Правильность реализации

Важно гарантировать, что архитектуры моделей в PyTorch и Keras действительно аналогичны. Ко всему прочему, существует вероятность, что в Keras вы могли сделать ошибку при определении модели или в параметрах вызовов функций.

Заключение

Причины, по которым функция активации "tanh" работает лучше в PyTorch, чем в Keras, множество. Это касается как особенностей архитектуры фреймворков, так и параметров обучения, обработанной информации и особенностей реализации кода. Рекомендуется тщательно проанализировать и сопоставить код обеих реализаций и провести несколько экспериментов по изменению гиперпараметров и способов обработки данных. Вы можете добиться повышенной производительности вашей модели независимо от предпочтительного фреймворка.

Оцените материал
Добавить комментарий

Капча загружается...