Сдерживающие значения или сглаживание результатов при подгонке модели

Вопрос или проблема

Я работаю над обучением сети для предсказания направления прихода, и у меня возникла проблема: независимо от того, какая у меня сеть (ResNet 18 – 101, CRNN, CNN и т. д.), мои результаты склоняются к одному небольшому диапазону значений, как видно на изображении ниже

Гистограмма истинных значений и предсказаний

Это, очевидно, приводит к следующим ошибкам:

Полярная диаграмма ошибок

Я пытался просто “подождать”, пока моя сеть наконец не обучится, но моя валидационная потеря сразу же начинает увеличиваться. Пример можно увидеть ниже.

Потери на обучении и валидации

Странно, что даже моя потеря обучения не стремится к 0. Я думал, что если бы моя сеть переобучалась, она просто идеально выучила бы мой набор данных, но этого не происходит, независимо от того, насколько сложной я делаю свою модель. Единственное, что приходит мне в голову, это то, что моё представление признаков совершенно бессмысленно или у меня есть опечатка где-то в функции обучения, и происходит что-то странное.

Я пытался поиграться с функцией потерь, пробовал разные функции активации на выходе сети, такие как Tanh, сигмоид, ReLU и вообще без функции активации. В данный момент я максимально упростил свои обучающие данные и работаю с 8-канальным звуковым сигналом Chirp длительностью 1 секунда, который можно найти (по крайней мере временно) здесь: https://file.re/2021/06/20/chirp/

Как упоминалось выше, я пробовал стандартные ResNet всех размеров и различные представления признаков, последнее из которых — это комплексное STFT всех 8 каналов, где величины сложены вертикально, а информация об угле добавлена к оси X, как видно ниже:

Пример признака

Если у кого-то есть идеи, я с радостью их попробую. В данный момент я использую CNN с вертикальными свертками и без операций подвыборки в попытке сохранить временную информацию.

Мой основной метод обучения можно увидеть ниже:

def train(self):
        steps, losses, metrics = TrainingUtilities.get_training_variables(self.parameters)
        patience_counter = 0
        best_epoch_data = None
        best_epoch_validation_loss = 999
        best_epoch = 0
        exit_training = False
        try:
            for epoch in range(steps, self.epoch_count):
                epoch_metrics = TrainingUtilities.initialize_metrics(self.mode)
                if exit_training:
                    break
                for idx, phase in enumerate(['train', 'val']):
                    if phase == 'train':
                        self.model.train()
                    else:
                        self.model.eval()
                    for _, data in tqdm(enumerate(self.data_loaders[idx])):

                        self.optimizer.zero_grad()
                        inputs, labels = data

                        outputs = self.model(inputs.to(self.device))
                        # labels = azi_class.squeeze_().to(self.device)
                        loss = self.criterion(outputs.squeeze(), labels.to(self.device))

                        epoch_metrics = TrainingUtilities.get_epoch_metrics(
                            outputs, labels, loss, epoch_metrics, phase, self.mode)

                        if phase == 'train':
                            loss.backward()
                            self.optimizer.step()

                    TrainingUtilities.report_metrics(self.writer, epoch_metrics, epoch, phase, self.parameters, self.mode)
                    if phase == "val":
                        TrainingUtilities.step_scheduler(
                            self.scheduler, np.mean(epoch_metrics[0][phase]), self.parameters)
                  
                losses.append(epoch_metrics[0])
                metrics.append(epoch_metrics)
                if epoch % self.epoch_save_count == 0:
                    TrainingUtilities.save_checkpoint(self.model, losses, metrics, self.training_dir, epoch, self.mode, self.model_name, self.size)

                steps += 1

        except (KeyboardInterrupt, RuntimeError) as error:
            print(f"Ошибка: {error}")
            TrainingUtilities.save_checkpoint(self.model, losses, metrics, self.training_dir, steps, self.mode, self.model_name, self.size)

Возможно, я просто что-то упустил, но основная идея, как мне кажется, довольно проста. Я абстрагировал много вещей, так как тестировал множество различных моделей с разными наборами данных и хотел легко переключаться между классификацией и регрессией.

Я не могу получить доступ к файлу. Если вы не можете его поделиться, не могли бы вы поделиться небольшим образцом?

Несколько идей:

1- Вы пробовали нормализовать свои данные? Некоторые данные с необычными значениями нужно нормализовать, чтобы их можно было правильно обучить.

2- Также может быть проблема в обучении, возможно, переобучение. Вы применяли дропаут для увеличения обобщающей способности.

3- Если ошибка одинаковая для любой используемой вами модели, возможно, что-то не так в вашем процессе обучения: возможно, начальная трансформация данных, которую вы не применяете к данным для предсказания.

Ответ или решение

Для решения проблемы, связанной с предсказанием направления прихода сигнала (direction of arrival prediction) и стагнацией результата вашей сети, важно рассмотреть несколько аспектов, которые могут способствовать получению результатов в узком диапазоне значений, а также привести к проблемам с обучением и валидацией.

### 1. Нормализация данных
Первое, что стоит проверить, это нормализация данных. Если ваши входные данные имеют большие диапазоны значений или нестандартные распределения, это может привести к проблемам с обучением модели. Рекомендуется нормализовать ваши данные, чтобы они располагались в пределах [0, 1] или [-1, 1]. Вы можете использовать стандартные методы нормализации, такие как Min-Max Scaling или Z-score Normalization.

### 2. Проверка функции потерь
Также убедитесь, что ваша функция потерь адекватна для вашей задачи. Если вы решаете задачу регрессии, используйте такие функции потерь, как MSE (Mean Squared Error) или MAE (Mean Absolute Error). Если ваша задача классификации, стоит рассмотреть кросс-энтропию. Также проверьте, правильно ли обрабатываются метки целевого значения (labels).

### 3. Регуляризация и обрезка
Если вы подозреваете переобучение (overfitting), добавление регуляризации, такой как Dropout, поможет вашей модели лучше обобщать данные. Это особенно важно, если у вас небольшой набор данных. Попробуйте разные уровни Dropout и посмотрите, как это влияет на вашу модель.

### 4. Архитектура модели
Попробуйте упростить архитектуру вашей сети, если у вас есть сложности с обучением. Сложные модели, такие как ResNet, требуют значительного объема данных для обучения. В начале использования простых моделей (например, базового CNN) может дать более устойчивые результаты, и в дальнейшем вы сможете постепенно усложнять модель.

### 5. Инициализация весов
Убедитесь, что вы используете эффективную инициализацию весов в вашей модели. Плохая инициализация может привести к мертвым нейронам (например, если вы используете ReLU) и ухудшению сходимости модели.

### 6. Обратная связь и мониторинг
Необходимо внимательно следить за процессом обучения (train and validation loss). Ваш график потерь говорит о том, что модель не только не сходится, но и валидационная потеря diverges. Это может указывать на неправильную настройку гиперпараметров, особенно скорости обучения (learning rate). Попробуйте изменить её, используя метод, такие как пространственный поиск (grid search) или метод случайного поиска (random search).

### 7. Проверка датасета
Если вы тестируете алгоритм на разных архитектурах, имеет смысл проверить ваш датасет и убедиться, что он содержит достаточно разнообразия и правильно размечен. Проверьте, нет ли ошибок в предобработке или разделении данных на обучающую и валидационную выборки.

### 8. Правильная трансформация данных
Убедитесь, что преобразования данных, применяемые к обучающим данным, также применяются к тестовым данным. Неправильная предобработка может привести к несоответствию между обучающими и тестовыми данными.

### Заключение
Если вы уже попробовали все вышеперечисленные рекомендации и проблема остается нерешенной, возможно, стоит рассмотреть возможность использования других методов предобучения или более экспериментального подхода к архитектурами сети. Возможно, стоит изучить такие техники, как Transfer Learning, особенно если у вас есть доступ к подобным задачам с предварительно обученными моделями.

Оцените материал
Добавить комментарий

Капча загружается...