Вопрос или проблема
Я создаю модель Res2Net50 для задачи верификации говорящего. Я использую AM-Softmax в качестве функции потерь с следующими параметрами:
Маржа: 0.2
Масштаб: 10
После обучения модели в течение 5 эпох на наборе данных с 1000 точками данных (каждая содержит идентификатор говорящего и соответствующий WAV-файл, обработанный в MFCC-функции), производительность выглядит странно. В частности, метрика average_accuracy всегда возвращает только 0% или 100%, что не имеет смысла.
Вот моя настройка:
(На самом деле эта настройка была после нескольких манипуляций с фреймом данных и файлами csv – я бы включил свою записную книжку, чтобы вы могли узнать больше)
Набор данных: 1000 точек данных, состоящих из идентификаторов говорящих и файла WAV. Каждый WAV-файл преобразуется в MFCC, которые являются входными данными для модели.
Модель: Res2Net50. (с размерностью встраивания = 512)
Функция потерь: AM-Softmax с маржей 0.2 и масштабом 10.
Я пытался решить проблему с AM-Softmax (например, неправильные значения масштаба или маржи – изначально я использовал масштаб = 30, затем снизил его, но это не сработало) и экспериментировал с незначительными изменениями, но проблема сохраняется. Метрики, похоже, не улучшаются, и точность либо идеальная, либо полностью неправильная.
Что я пробовал:
Проверил, что предобработка MFCC последовательна и дает разумные характеристики.
Проверил, что идентификаторы говорящих корректно закодированы.
Экспериментировал с коэффициентами обучения и настройками оптимизатора, но поведение average_accuracy не изменяется.
Итак, мой вопрос:
Что может вызывать такое необычное поведение метрики average_accuracy, и как я могу отладить или решить эту проблему?
Буду признателен за любые идеи по настройке AM-Softmax, подготовке данных или оценочным метрикам!
Записная книжка Google Colab
Вот записная книжка для справки:
Модель верификации говорящего Res2Net50.
Мой код AM-Softmax:
import torch
import torch.nn as nn
import torch.nn.functional as F
class AMSoftmax(nn.Module):
def __init__(self, embedding_dim, num_classes, margin=0.2, scale=30, **kwargs):
super(AMSoftmax, self).__init__()
self.m = margin
self.s = scale
self.in_feats = embedding_dim
self.W = torch.nn.Parameter(torch.randn(embedding_dim, num_classes), requires_grad=True)
self.ce = nn.CrossEntropyLoss()
nn.init.xavier_normal_(self.W, gain=1)
print('Инициализирован AM-Softmax m=%.3f s=%.3f'%(self.m, self.s))
def forward(self, x, label=None):
assert len(x.shape) == 3
label = label.repeat_interleave(x.shape[1]) # Повторить метки по размеру последовательности
x = x.reshape(-1, self.in_feats)
assert x.size()[0] == label.size()[0]
assert x.size()[1] == self.in_feats
device = x.device
x_norm = torch.norm(x, p=2, dim=1, keepdim=True).clamp(min=1e-12)
x_norm = torch.div(x, x_norm)
w_norm = torch.norm(self.W, p=2, dim=0, keepdim=True).clamp(min=1e-12)
w_norm = torch.div(self.W, w_norm)
# Убедитесь, что и x_norm, и w_norm находятся на одном устройстве
w_norm = w_norm.to(device) # Переместить w_norm на то же устройство, что и x_norm, если они не на одном устройстве
# Выполнить матричное умножение
costh = torch.mm(x_norm, w_norm) # Косинусное сходство между x_norm и w_norm
label_view = label.view(-1, 1)
if label_view.is_cuda:
label_view = label_view.cpu()
# Рассеять, чтобы изменить costh
delt_costh = torch.zeros(costh.size()).scatter_(1, label_view, self.m) # Добавить маржу
if x.is_cuda:
delt_costh = delt_costh.cuda()
costh_m = costh - delt_costh # Применить маржу
costh_m_s = self.s * costh_m # Масштабировать по s
loss = self.ce(costh_m_s, label) # Потеря кросс-энтропии
# Точность
prec1 = accuracy(costh_m_s.detach(), label.detach(), topk=(1,))[0]
return loss, prec1
def accuracy(output, target, topk=(1,)):
"""Вычисляет точность@k для заданных значений k"""
maxk = max(topk)
batch_size = target.size(0)
_, pred = output.topk(maxk, 1, True, True)
pred = pred.t()
correct = pred.eq(target.view(1, -1).expand_as(pred))
res = []
for k in topk:
correct_k = correct[:k].view(-1).float().sum(0, keepdim=True)
res.append(correct_k.mul_(100.0 / batch_size))
return res
Ответ или решение
Почему average_accuracy возвращает только 0 или 100% при использовании Res2Net50 с AM-Softmax для верификации говорящего
Использование модели Res2Net50 в сочетании с AM-Softmax для задачи верификации говорящего создает интересные возможности, однако проблема, с которой вы столкнулись — получение average_accuracy, равного либо 0%, либо 100% — вызывает серьёзные опасения. Давайте разберем возможные причины такого поведения и варианты их устранения.
1. Проблемы с данными
Ваша выборка состоит из 1000 примеров, что может быть недостаточно для обучения модели по своей сложности. Обратите внимание на следующее:
- Качество и разнообразие данных: Убедитесь, что ваши примеры представляют разнообразие говорящих. Если выборка содержит избыточные или повторяющиеся примеры одного и того же говорящего, это может привести к переобучению.
- Корректность меток: Проверьте, правильно ли закодированы идентификаторы спикеров. Неверные метки могут привести к несоответствиям между истинными метками и предсказанными, что, в свою очередь, будет порождать 0 или 100% точность.
2. Параметры обучающего процесса
Параметры AM-Softmax, такие как margin
и scale
, играют огромную роль. Эксперименты с их значениями показали, что:
- Margin 0.2: Может нарушать обучаемость, если данные не обеспечивают достаточную вариативность. Попробуйте изменить margin на большее значение для более требовательного обучения.
- Scale 10: Это значение также может быть слишком высоким или слишком низким в зависимости от вашей задачи. Попробуйте использовать диапазон от 15 до 30.
3. Архитектура и функция потерь
Ваш код для AM-Softmax выглядит хорошо, но следует обратить внимание на несколько деталей:
- Обработка нормализации: Нормализация векторов входных данных и весов — важный шаг. Убедитесь, что вы правильно нормализуете входные данные перед использованием функции потерь.
- Функция потерь: Проверьте, правильно ли формируются входные данные для функции потерь. Убедитесь, что размерности соответствуют необходимым, особенно в случае одновременного расчета потерь и точности.
4. Мониторинг обучения
Проверьте, как изменяются метрики во время обучения:
- Лоссы: Отображение потерь на протяжении эпох может помочь понять, происходит ли переобучение или нет. Если потери быстро падают до нуля, это может говорить о переобучении.
- Визуализация: Использование графиков для отслеживания точности и потерь поможет вам понять поведенческий паттерн вашей модели. Если метрики остаются статичными, вероятно, модель не обучается должным образом.
5. Stochastic Training
Использование стохастического градиентного спуска с изменяемым размером шага может помочь улучшить результаты. Измените learning rate
и протестируйте несколько оптимизаторов, таких как Adam или SGD с циклическим обучением. Проверьте, как элевированные или сниженные значения влияют на производительность.
Заключение
Приведенные выше вещи — лишь некоторые из учений, которые могут помочь вам решить проблему с 0 и 100% average_accuracy. Основное внимание следует уделить данным, параметрам модели и подходу к мониторингу обучения. Убеждено, что последовательное применение этих методов приведёт к более адекватным и стабильным результатам.
Если вы продолжаете сталкиваться с этой проблемой, раскиньте ваши подходы на обсуждение, возможно, коллаборация с коллегами поможет выявить другие аспекты, которые стоит взять на заметку.