Вопрос или проблема
Я обучаю DNN с помощью CNN в Keras. Хотя я могу написать критерий EarlyStopping, основанный на val_loss, из-за незначительных колебаний в val_loss я хочу отслеживать среднюю валидационную потерю за последние n эпох и с n терпением.
Как я могу это сделать в Keras?
Вы всегда можете написать собственный callback, наследуя класс tf.keras.callbacks.Callback(). Вам нужно будет переопределить метод on_epoch_end()
с вашей собственной логикой.
.
Ответ или решение
Чтобы реализовать раннюю остановку (Early Stopping) на основе среднего значения потерь валидации (val_loss) за последние n эпох и с заданным количеством эпох терпения (patience), вы можете использовать возможности Keras для написания пользовательских обратных вызовов. Это позволит контролировать метрики обучения более гибко, что особенно актуально в случае незначительных колебаний val_loss.
Реализация пользовательского обратного вызова в Keras
Ниже представлена реализация на Python, которая поможет вам создать кастомный обратный вызов для достижения вашей цели:
import numpy as np
from tensorflow.keras.callbacks import Callback
class AverageValLossEarlyStopping(Callback):
def __init__(self, patience=5, n_last_epochs=10):
super(AverageValLossEarlyStopping, self).__init__()
self.patience = patience
self.n_last_epochs = n_last_epochs
self.wait = 0
self.stopped_epoch = 0
self.best_avg_val_loss = np.Inf
self.val_losses = []
def on_epoch_end(self, epoch, logs=None):
current_val_loss = logs.get('val_loss')
if current_val_loss is None:
raise ValueError("val_loss не найден в логах. Убедитесь, что ваша модель настроена на использование валидационного набора.")
# Обновляем список значений потерь
self.val_losses.append(current_val_loss)
if len(self.val_losses) > self.n_last_epochs:
self.val_losses.pop(0)
# Рассчитываем среднее значение потерь за последние n эпох
avg_val_loss = np.mean(self.val_losses)
# Проверка на лучшую среднюю потерю
if avg_val_loss < self.best_avg_val_loss:
self.best_avg_val_loss = avg_val_loss
self.wait = 0
else:
self.wait += 1
if self.wait >= self.patience:
self.stopped_epoch = epoch
self.model.stop_training = True
def on_train_end(self, logs=None):
if self.stopped_epoch > 0:
print(f'Остановка обучения на {self.stopped_epoch + 1} эпохе за счет ранней остановки.')
# Использование
# callbacks = [AverageValLossEarlyStopping(patience=5, n_last_epochs=10)]
# model.fit(..., callbacks=callbacks)
Объяснение кода
-
Инициализация класса: Мы принимаем два параметра:
patience
, который определяет, сколько эпох модель должна терпеть ухудшение метрики, иn_last_epochs
, который задает количество последних эпох для вычисления среднего значения потерь. -
Функция
on_epoch_end
: Этот метод вызывается в конце каждой эпохи. Мы обновляем список последних значенийval_loss
, вычисляем их среднее значение и проверяем, улучшилось ли среднее по сравнению с предыдущим лучшим. -
Логика остановки: Если среднее значение
val_loss
не улучшается на протяженииpatience
эпох, мы останавливаем обучение. -
Вывод в конце тренировки: Если обучение было остановлено преждевременно, выводится сообщение с указанием на какую эпоху это произошло.
Этот кастомный callback позволяет более устойчиво реагировать на колебания в метрике валидации, сохраняя обучение модели более надежным и эффективным.