Вопрос или проблема
У меня есть набор данных с 3 классами со следующими элементами:
- Класс 1: 900 элементов
- Класс 2: 15000 элементов
- Класс 3: 800 элементов
Мне нужно предсказать классы 1 и 3, которые сигнализируют о важных отклонениях от нормы. Класс 2 – это вариант по умолчанию, который меня не интересует.
Какую функцию потерь я могла бы использовать здесь? Я думала использовать CrossEntropyLoss, но поскольку это дисбаланс классов, это, наверное, нужно взвесить? Как это работает на практике? Например, так (используя PyTorch)?
summed = 900 + 15000 + 800
weight = torch.tensor([900, 15000, 800]) / summed
crit = nn.CrossEntropyLoss(weight=weight)
Или вес нужно инвертировать? Т.е. 1 / вес?
Это правильный подход с самого начала или есть другие/лучшие методы, которые я могла бы использовать?
Спасибо
Какую функцию потерь я могла бы использовать здесь?
Кросс-энтропия – это стандартная функция потерь для задач классификации, как балансных, так и дисбалансных. Это первый выбор, если предпочтения из области ещё не сформированы.
Это, наверное, нужно взвесить? Как это работает на практике?
Да. Вес класса $c$ – это размер наибольшего класса, делённый на размер класса $c$.
Например, если в классе 1 есть 900, в классе 2 – 15000, и в классе 3 – 800 образцов, их веса будут 16.67, 1.0 и 18.75 соответственно.
Вы также можете использовать наименьший класс в качестве числителя, что даст 0.889, 0.053 и 1.0 соответственно. Это только пересчёт, относительные веса остаются теми же.
Это правильный подход с самого начала или есть другие/лучшие
методы, которые я могла бы использовать?
Да, это правильный подход.
ИЗМЕНЕНИЕ:
Благодаря @Muppet, мы также можем использовать пересэмплинг классов, который эквивалентен использованию весов классов. Это достигается с помощью WeightedRandomSampler
в PyTorch, используя те же самые упомянутые веса.
Полностью согласен с @Esmailian
def compute_pos_weights(cls_repr: torch.Tensor) -> torch.Tensor:
total_weight = cls_repr.sum()
weights = 1/torch.div(cls_repr, total_weight)
# Стандартизация весов
return torch.div(weights, torch.min(weights))
Большинство моделей, таких как catboost/xgboost, будут работать с этим соотношением классов, нет необходимости изменять функцию потерь.
Тем не менее, если хотите увидеть небольшие улучшения, можете попробовать использовать фокусные потери/балансированные потери классов (pytorch).
Увеличение класса меньшинства и уменьшение класса большинства – это другой способ управления дисбалансом классов в самом наборе данных.
Ответ или решение
Когда вы сталкиваетесь с проблемой дисбаланса классов в задаче классификации, одна из первых вещей, которую стоит продумать, — это использование подходящего метода для обработки этого дисбаланса. В вашем случае, когда классы распределены следующим образом: Класс 1 — 900 элементов, Класс 2 — 15000 элементов и Класс 3 — 800 элементов, необходимо уделить особое внимание корректировке потерь, чтобы модель не была предвзята в сторону доминирующего класса 2, который является "нормой".
Теория
Выбор функции потерь — это ключевой элемент в обучении модели классификации. Наиболее распространенной функцией для задач многоклассовой классификации является CrossEntropyLoss
, которая хорошо работает как на сбалансированных, так и на несбалансированных наборах данных. Однако при дисбалансе классов важно взвешивать потери, чтобы скомпенсировать влияние доминирующего класса, в ином случае модель будет обучаться преимущественно на большом классе.
Взвешивание потерь в CrossEntropyLoss
можно произвести с использованием весов, которые делаются обратно пропорциональными количеству примеров в каждом классе. Эта практика позволяет учесть редкие классы и нивелировать их меньшую представленность. Вес класса (c) можно рассчитать как: (\text{вес класса c} = \frac{\text{размер самого большого класса}}{\text{размер класса c}}). В вашем случае, чтобы привнести баланс:
- Вес класса 1 = (\frac{15000}{900} \approx 16.67)
- Вес класса 2 = (\frac{15000}{15000} = 1.0)
- Вес класса 3 = (\frac{15000}{800} \approx 18.75)
Пример
На практике, в PyTorch вычисление таких весов и их применение может выглядеть следующим образом:
import torch
import torch.nn as nn
# Расчет весов
summed = 900 + 15000 + 800
class_weights = torch.tensor([15000/900, 15000/15000, 15000/800])
class_weights = class_weights / class_weights.min()
# Инициализация функции потерь с заданными весами
criterion = nn.CrossEntropyLoss(weight=class_weights)
Обратите внимание на нормализацию весов, которая сохраняет их относительные соотношения, что улучшает стабильность обучения.
Применение
Если вы хотите пойти дальше и исследовать альтернативные подходы, помимо взвешивания потерь, существует несколько других методов для работы с дисбалансом классов:
-
Переобучение на редких классах: Использование
WeightedRandomSampler
в PyTorch, позволяет повторять обучение на недостаточно представленных классах, тем самым уравновешивая тренировочное распределение за счет увеличения вероятности выбора примеров из редких классов. -
Аугментация данных: Создание новых данных на основе существующих путем модификаций, таких как вращение, изменения яркости и другие приемы, чтобы увеличивать выборки из недостаточно представленных классов.
-
Фокальная потеря (Focal Loss): Это улучшение
CrossEntropyLoss
, которое более акцентированно на сложных примерах и примерах из редких классов, что делает ее менее чувствительной к дисбалансу. -
Увеличение данных для редких классов: Еще один способ преодолеть дисбаланс — это создание дополнительных данных для редких классов, что может быть выполнено с использованием техник, таких как аугментация или генерация синтетических данных.
Понимание и корректировка дисбаланса классов — важный шаг в построении надежной модели классификации. Подход, который вы выберете, может существенно повлиять на окончательные результаты, и часто стоит попробовать несколько разных стратегий, чтобы определить, какая из них лучше всего работает для вашей конкретной задачи и данных.