Какую функцию потерь использовать для несбалансированных классов (с использованием PyTorch)?

Вопрос или проблема

У меня есть набор данных с 3 классами со следующими элементами:

  • Класс 1: 900 элементов
  • Класс 2: 15000 элементов
  • Класс 3: 800 элементов

Мне нужно предсказать классы 1 и 3, которые сигнализируют о важных отклонениях от нормы. Класс 2 – это вариант по умолчанию, который меня не интересует.

Какую функцию потерь я могла бы использовать здесь? Я думала использовать CrossEntropyLoss, но поскольку это дисбаланс классов, это, наверное, нужно взвесить? Как это работает на практике? Например, так (используя PyTorch)?

summed = 900 + 15000 + 800
weight = torch.tensor([900, 15000, 800]) / summed
crit = nn.CrossEntropyLoss(weight=weight)

Или вес нужно инвертировать? Т.е. 1 / вес?

Это правильный подход с самого начала или есть другие/лучшие методы, которые я могла бы использовать?

Спасибо

Какую функцию потерь я могла бы использовать здесь?

Кросс-энтропия – это стандартная функция потерь для задач классификации, как балансных, так и дисбалансных. Это первый выбор, если предпочтения из области ещё не сформированы.

Это, наверное, нужно взвесить? Как это работает на практике?

Да. Вес класса $c$ – это размер наибольшего класса, делённый на размер класса $c$.

Например, если в классе 1 есть 900, в классе 2 – 15000, и в классе 3 – 800 образцов, их веса будут 16.67, 1.0 и 18.75 соответственно.

Вы также можете использовать наименьший класс в качестве числителя, что даст 0.889, 0.053 и 1.0 соответственно. Это только пересчёт, относительные веса остаются теми же.

Это правильный подход с самого начала или есть другие/лучшие
методы, которые я могла бы использовать?

Да, это правильный подход.

ИЗМЕНЕНИЕ:

Благодаря @Muppet, мы также можем использовать пересэмплинг классов, который эквивалентен использованию весов классов. Это достигается с помощью WeightedRandomSampler в PyTorch, используя те же самые упомянутые веса.

Полностью согласен с @Esmailian

def compute_pos_weights(cls_repr: torch.Tensor) -> torch.Tensor:
    total_weight = cls_repr.sum()
    weights = 1/torch.div(cls_repr, total_weight)
    # Стандартизация весов
    return torch.div(weights, torch.min(weights))

Большинство моделей, таких как catboost/xgboost, будут работать с этим соотношением классов, нет необходимости изменять функцию потерь.

Тем не менее, если хотите увидеть небольшие улучшения, можете попробовать использовать фокусные потери/балансированные потери классов (pytorch).

Увеличение класса меньшинства и уменьшение класса большинства – это другой способ управления дисбалансом классов в самом наборе данных.

Ответ или решение

Когда вы сталкиваетесь с проблемой дисбаланса классов в задаче классификации, одна из первых вещей, которую стоит продумать, — это использование подходящего метода для обработки этого дисбаланса. В вашем случае, когда классы распределены следующим образом: Класс 1 — 900 элементов, Класс 2 — 15000 элементов и Класс 3 — 800 элементов, необходимо уделить особое внимание корректировке потерь, чтобы модель не была предвзята в сторону доминирующего класса 2, который является "нормой".

Теория

Выбор функции потерь — это ключевой элемент в обучении модели классификации. Наиболее распространенной функцией для задач многоклассовой классификации является CrossEntropyLoss, которая хорошо работает как на сбалансированных, так и на несбалансированных наборах данных. Однако при дисбалансе классов важно взвешивать потери, чтобы скомпенсировать влияние доминирующего класса, в ином случае модель будет обучаться преимущественно на большом классе.

Взвешивание потерь в CrossEntropyLoss можно произвести с использованием весов, которые делаются обратно пропорциональными количеству примеров в каждом классе. Эта практика позволяет учесть редкие классы и нивелировать их меньшую представленность. Вес класса (c) можно рассчитать как: (\text{вес класса c} = \frac{\text{размер самого большого класса}}{\text{размер класса c}}). В вашем случае, чтобы привнести баланс:

  • Вес класса 1 = (\frac{15000}{900} \approx 16.67)
  • Вес класса 2 = (\frac{15000}{15000} = 1.0)
  • Вес класса 3 = (\frac{15000}{800} \approx 18.75)

Пример

На практике, в PyTorch вычисление таких весов и их применение может выглядеть следующим образом:

import torch
import torch.nn as nn

# Расчет весов
summed = 900 + 15000 + 800
class_weights = torch.tensor([15000/900, 15000/15000, 15000/800])
class_weights = class_weights / class_weights.min()

# Инициализация функции потерь с заданными весами
criterion = nn.CrossEntropyLoss(weight=class_weights)

Обратите внимание на нормализацию весов, которая сохраняет их относительные соотношения, что улучшает стабильность обучения.

Применение

Если вы хотите пойти дальше и исследовать альтернативные подходы, помимо взвешивания потерь, существует несколько других методов для работы с дисбалансом классов:

  1. Переобучение на редких классах: Использование WeightedRandomSampler в PyTorch, позволяет повторять обучение на недостаточно представленных классах, тем самым уравновешивая тренировочное распределение за счет увеличения вероятности выбора примеров из редких классов.

  2. Аугментация данных: Создание новых данных на основе существующих путем модификаций, таких как вращение, изменения яркости и другие приемы, чтобы увеличивать выборки из недостаточно представленных классов.

  3. Фокальная потеря (Focal Loss): Это улучшение CrossEntropyLoss, которое более акцентированно на сложных примерах и примерах из редких классов, что делает ее менее чувствительной к дисбалансу.

  4. Увеличение данных для редких классов: Еще один способ преодолеть дисбаланс — это создание дополнительных данных для редких классов, что может быть выполнено с использованием техник, таких как аугментация или генерация синтетических данных.

Понимание и корректировка дисбаланса классов — важный шаг в построении надежной модели классификации. Подход, который вы выберете, может существенно повлиять на окончательные результаты, и часто стоит попробовать несколько разных стратегий, чтобы определить, какая из них лучше всего работает для вашей конкретной задачи и данных.

Оцените материал
Добавить комментарий

Капча загружается...