Керас ранняя остановка по цели

Вопрос или проблема

Мне действительно трудно понять, как параметры обратных вызовов ранней остановки Keras работают, особенно в присутствии базового уровня. Что я хочу, так это просто остановить обучение в течение 2 эпох после того, как точность валидации достигнет 95%. Поэтому я пробую:

(trainX, testX, trainY, testY) = train_test_split(Tensor, Labels1Hot, test_size=0.2)
when2stop = EarlyStopping(mode="max", monitor="val_accuracy", verbose=1, patience=2, baseline=0.95)
history = model.fit(trainX, trainY, epochs = 100, batch_size = 500, validation_data = (testX, testY), callbacks=when2stop)

Это просто останавливается после всего лишь 2 эпох, даже если точность валидации на самом деле улучшилась. Но точность валидации значительно ниже 0.95. Это ошибка или я неправильно понимаю настройки базового уровня и терпимости?

Используя: Tensorflow 2.4

Комментарий выше к посту на SO точен. Пользовательский обратный вызов, как указано в этом посте, является решением, и термин ‘базовый уровень’ не должен интерпретироваться как порог.

Ответ или решение

Понимание механизма работы обратных вызовов ранней остановки в Keras, особенно в контексте применения базового значения, может быть непростой задачей, особенно для неопытных пользователей. В данном ответе мы детально рассмотрим, как работают различные параметры, включая baseline, patience и другие связанные аспекты, которые могут повлиять на остановку обучения вашей модели.

Параметры обратного вызова ранней остановки

  1. mode: Этот параметр определяет, как Keras будет интерпретировать значения мониторинга. Если mode='max', это значит, что Keras будет ожидать увеличения значения метрики. В вашем случае, с метрикой val_accuracy, модель будет искать максимальное значение точности.

  2. monitor: Здесь вы указали "val_accuracy", что позволяет Keras отслеживать точность на валидационной выборке. Это правильный выбор, если вы хотите, чтобы модель могла завершить обучение в зависимости от ее производительности на тестовых данных.

  3. patience: Этот параметр указывает количество эпох, которое Keras подождет, прежде чем остановить обучение после того, как не будет наблюдаться улучшение метрики, указанной в параметре monitor. В вашем случае, вы задали значение 2, что означает, что Keras будет ждать 2 эпохи для улучшения val_accuracy перед остановкой.

  4. baseline: Важно понять, что baseline не интерпретируется как пороговое значение для остановки. Вместо этого, оно служит эталоном, с которым Keras будет сравнивать. Если производительность модели не превышает указанное значение baseline в течение patience эпох, обучение будет остановлено.

Почему ваша модель останавливается слишком рано?

В вашем коде указано значение baseline=0.95. Если val_accuracy не превышает это значение хотя бы раз в течение двух эпох, модель будет остановлена. Если значения вашего val_accuracy находятся существенно ниже 0.95, даже если они улучшаются, это объясняет преждевременную остановку.

Решение проблемы: Использование пользовательского обратного вызова

Как указано в комментарии к вашему вопросу, более целесообразным подходом будет использование пользовательского обратного вызова. Это позволит вам настроить логику остановки на основе конкретных требований. Ниже приведен пример простого пользовательского обратного вызова для достижения вашей цели:

import tensorflow as tf

class CustomEarlyStopping(tf.keras.callbacks.Callback):
    def __init__(self, target_accuracy, patience=2):
        super(CustomEarlyStopping, self).__init__()
        self.target_accuracy = target_accuracy
        self.patience = patience
        self.wait = 0
        self.best = 0.0

    def on_epoch_end(self, epoch, logs=None):
        current = logs.get('val_accuracy')
        if current is not None:
            if current >= self.target_accuracy:
                print(f"\nStopping training: reached {self.target_accuracy} validation accuracy.")
                self.model.stop_training = True
            elif self.best < current:
                self.wait = 0
                self.best = current
            else:
                self.wait += 1
                if self.wait >= self.patience:
                    print(f"\nStopping training: patience exceeded {self.patience} epochs.")
                    self.model.stop_training = True

Использование пользовательского обратного вызова

Теперь вы можете использовать ваш новый обратный вызов:

when2stop = CustomEarlyStopping(target_accuracy=0.95, patience=2)
history = model.fit(trainX, trainY, epochs=100, batch_size=500, validation_data=(testX, testY), callbacks=[when2stop])

Заключение

Правильное понимание и использование параметров обратного вызова ранней остановки в Keras может значительно повлиять на обучение вашей модели. Использование пользовательских решений может предоставить вам необходимую гибкость для достижения ваших целей. Не забывайте, что всегда полезно тестировать различные настройки, чтобы увидеть, что лучше всего подходит для вашей конкретной задачи и данных.

Оцените материал
Добавить комментарий

Капча загружается...