Простановка обучения на основе среднего значения val_loss за последние десять эпох с некоторым количеством терпимости.

Вопрос или проблема

Я обучаю DNN с помощью CNN в Keras. Хотя я могу написать критерий EarlyStopping, основанный на val_loss, из-за незначительных колебаний в val_loss я хочу отслеживать среднюю валидационную потерю за последние n эпох и с n терпением.

Как я могу это сделать в Keras?

Вы всегда можете написать собственный callback, наследуя класс tf.keras.callbacks.Callback(). Вам нужно будет переопределить метод on_epoch_end() с вашей собственной логикой.

.

Ответ или решение

Чтобы реализовать раннюю остановку (Early Stopping) на основе среднего значения потерь валидации (val_loss) за последние n эпох и с заданным количеством эпох терпения (patience), вы можете использовать возможности Keras для написания пользовательских обратных вызовов. Это позволит контролировать метрики обучения более гибко, что особенно актуально в случае незначительных колебаний val_loss.

Реализация пользовательского обратного вызова в Keras

Ниже представлена реализация на Python, которая поможет вам создать кастомный обратный вызов для достижения вашей цели:

import numpy as np
from tensorflow.keras.callbacks import Callback

class AverageValLossEarlyStopping(Callback):
    def __init__(self, patience=5, n_last_epochs=10):
        super(AverageValLossEarlyStopping, self).__init__()
        self.patience = patience
        self.n_last_epochs = n_last_epochs
        self.wait = 0
        self.stopped_epoch = 0
        self.best_avg_val_loss = np.Inf
        self.val_losses = []

    def on_epoch_end(self, epoch, logs=None):
        current_val_loss = logs.get('val_loss')
        if current_val_loss is None:
            raise ValueError("val_loss не найден в логах. Убедитесь, что ваша модель настроена на использование валидационного набора.")

        # Обновляем список значений потерь
        self.val_losses.append(current_val_loss)
        if len(self.val_losses) > self.n_last_epochs:
            self.val_losses.pop(0)

        # Рассчитываем среднее значение потерь за последние n эпох
        avg_val_loss = np.mean(self.val_losses)

        # Проверка на лучшую среднюю потерю
        if avg_val_loss < self.best_avg_val_loss:
            self.best_avg_val_loss = avg_val_loss
            self.wait = 0
        else:
            self.wait += 1
            if self.wait >= self.patience:
                self.stopped_epoch = epoch
                self.model.stop_training = True

    def on_train_end(self, logs=None):
        if self.stopped_epoch > 0:
            print(f'Остановка обучения на {self.stopped_epoch + 1} эпохе за счет ранней остановки.')

# Использование
# callbacks = [AverageValLossEarlyStopping(patience=5, n_last_epochs=10)]
# model.fit(..., callbacks=callbacks)

Объяснение кода

  1. Инициализация класса: Мы принимаем два параметра: patience, который определяет, сколько эпох модель должна терпеть ухудшение метрики, и n_last_epochs, который задает количество последних эпох для вычисления среднего значения потерь.

  2. Функция on_epoch_end: Этот метод вызывается в конце каждой эпохи. Мы обновляем список последних значений val_loss, вычисляем их среднее значение и проверяем, улучшилось ли среднее по сравнению с предыдущим лучшим.

  3. Логика остановки: Если среднее значение val_loss не улучшается на протяжении patience эпох, мы останавливаем обучение.

  4. Вывод в конце тренировки: Если обучение было остановлено преждевременно, выводится сообщение с указанием на какую эпоху это произошло.

Этот кастомный callback позволяет более устойчиво реагировать на колебания в метрике валидации, сохраняя обучение модели более надежным и эффективным.

Оцените материал
Добавить комментарий

Капча загружается...