Является ли это подходящим способом для расчета диаграммы надежности многоклассовой модели для калибровки?

Вопрос или проблема

Я пытаюсь обобщить диаграммы надежности [1] для многоклассового классификатора и реализовать это с использованием pytorch и pytorch-metrics.

Пока все хорошо, но я немного запутан в определении точности и том, как это применимо к пересечению нескольких классов и нескольких диапазонов уверенности, позвольте мне проиллюстрировать это на примере:

Предположим, у меня есть 3 класса A, B и C (который есть только для того, чтобы убедиться, что проблема не связана с двоичным классификатором). Для примера моя модель всегда выдает уверенность 0 для класса C.

Предположим, что я вижу 20 образцов, где:

  • 10 образцов классифицированы как A с уверенностью 0.8 и их истинное значение на самом деле A
  • 9 образцов классифицированы как B с уверенностью 0.8 и их истинное значение на самом деле B
  • 1 образец классифицирован как B с уверенностью 0.8, но его истинное значение A

Предположим, я рисую многоклассовую диаграмму надежности с 2 диапазонами: [0.0-0.5) и [0.5-1].
Мой текущий код выдал бы следующее:
Класс A: [0, 1]
Класс B: [0, 9/10]

Это кажется мне верным, так как определение точности по диапазону уверенности от Guo и др. гласит: “точность B_m равна”:
acc(B_m) = 1/B_m sum(1 * (ŷ_i = y_i)) for i in B_m

Но меня очень шокирует приписывать точность 1 к классу A, зная, что у него было ложное срабатывание. Я пытался рационализировать это, говоря себе, что все ложные срабатывания уже учтены как ложные положительные для другого класса, но я не уверен, что это имеет смысл. Также нужно учитывать, что это ложное срабатывание должно быть учтено в диапазоне 0.0-0.5… и это действительно так, поскольку оно способствует увеличению количества образцов для этого диапазона, но не надежности.

Поэтому мои вопросы заключаются в следующем:

  • Является ли это правильным способом расчета диаграммы надежности для каждого класса, когда существует множество различных классов?
  • Существует ли какая-либо литература на эту тему?
  • Если нет, является ли это разумным способом сделать это?

Это мой текущий код, если вам интересно:

from typing import Any, List, Optional, Tuple

import torch
from torch import Tensor

from torchmetrics.metric import Metric

class MulticlassReliabilityDiagram(Metric):
    r"""Вычисляет диаграмму надежности для задач классификации.

    Диаграмма надежности отображает точность по оси Y в интервалах уверенности, разделенных на несколько диапазонов по оси X.

    Диаграммы надежности полезны для визуализации ошибки калибровки классификатора по всем диапазонам уверенности,
    захватывая тонкую информацию о калибровке в легко интерпретируемом графике.

    В качестве входа для ``forward`` и ``update`` метрика принимает следующие входные данные:

    - ``preds`` (:class:`~torch.Tensor`): Плавающий тензор формы ``(N, C, ...)``. preds должен быть тензором, содержащим
      вероятности или логиты для каждого наблюдения. Если preds имеет значения, выходящие за пределы диапазона [0,1], мы считаем входные данные логитами и автоматически применим softmax для каждого образца.
    - ``target`` (:class:`~torch.Tensor`): Целевой тензор формы ``(N, ...)``. target должен быть тензором, содержащим
      истинные метки, и поэтому должен содержать только значения в диапазоне [0, n_classes-1] (если не указано `ignore_index`).

    .. примечание::
       Дополнительное измерение ``...`` будет сплющено в размерность батча.

    В качестве результата для ``forward`` и ``compute`` метрика возвращает:
     - ``reliability`` (:class:`~torch.Tensor`): тензор, содержащий гистограмму надежности
     - ``frequency`` (:class:`~torch.Tensor`): тензор, содержащий частоту, наблюдаемую для каждого отдельного интервала уверенности
     - ``class_reliability`` (:class:`~torch.Tensor`): тензор, содержащий гистограмму надежности по классам, игнорируемый индекс не *влияет* 
        на этот вывод, чтобы сохранить ту же нумерацию для классов, вам нужно будет вручную игнорировать индекс при использовании этого тензора
     - ``class_frequency`` (:class:`~torch.Tensor`): тензор, содержащий частоту, наблюдаемую для каждого отдельного интервала уверенности по классам

    Аргументы:
        num_classes: Целое число, указывающее количество классов
        bins: Целое число, указывающее количество диапазонов, на которые нужно разделить область уверенности, по умолчанию: 10
        kwargs: Дополнительные именованные аргументы, см. :ref:`Metric kwargs` для получения дополнительной информации.

    """
    is_differentiable: bool = False
    higher_is_better: Optional[bool] = None
    full_state_update: bool = False

    preds: List[Tensor]
    target: List[Tensor]

    def __init__(
        self,
        num_classes: int,
        bins: Optional[int] = None,
        ignore_index: Optional[int] = None,
        **kwargs: Any,
    ) -> None:
        super().__init__(**kwargs)

        self.num_classes = num_classes
        self.ignore_index = ignore_index
        self.bins = bins or 10

        shape = (self.bins,)
        # Сохраняет частоту каждого диапазона
        self.add_state(
            "frequency",
            default=torch.zeros(shape, dtype=torch.int64),
            dist_reduce_fx="sum",
        )
        # Сохраняет частоту успехов каждого диапазона
        self.add_state(
            "success",
            default=torch.zeros(shape, dtype=torch.int64),
            dist_reduce_fx="sum",
        )

        shape_classes = (self.num_classes, self.bins)
        # Сохраняет частоту каждого диапазона по классам
        self.add_state(
            "class_frequency",
            default=torch.zeros(shape_classes, dtype=torch.int64),
            dist_reduce_fx="sum",
        )
        # Сохраняет частоту успехов каждого диапазона по классам
        self.add_state(
            "class_success",
            default=torch.zeros(shape_classes, dtype=torch.int64),
            dist_reduce_fx="sum",
        )

    def update(self, preds: Tensor, target: Tensor) -> None:
        """Обновляет состояния метрики."""
        preds, target = _multiclass_reliability_diagram_format(
            preds, target, self.num_classes, self.ignore_index
        )
        confidences, classes = torch.max(preds, dim=1)

        # ЗАМЕТКА: Нет детерминированной реализации histc, это не очень актуально
        # поскольку это используется для генерации диаграммы надежности, но это заставляет нас
        # отключить детерминизм для этой функции, см. также следующее появление через несколько строк
        # далее
        torch.use_deterministic_algorithms(False)

        # Обновите частоты и количество успехов как общие, так и по классам
        freq_histogram = torch.histc(preds, bins=self.bins, min=0, max=1)
        self.frequency = torch.add(self.frequency, freq_histogram)
        succ_idx = classes == target
        succ_confidences = confidences[succ_idx]
        success_histogram = torch.histc(succ_confidences, bins=self.bins, min=0, max=1)
        self.success = torch.add(self.success, success_histogram)

        class_freq_histogram = torch.zeros(
            (self.num_classes, self.bins), device=self.device
        )
        class_succ_histogram = torch.zeros(
            (self.num_classes, self.bins), device=self.device
        )
        for index in range(self.num_classes):
            class_freq_histogram[index] = torch.histc(
                preds[:, index], bins=self.bins, min=0, max=1
            )
            class_succ_idx = torch.logical_and(succ_idx, classes == index)
            class_succ_confidences = confidences[class_succ_idx]
            class_succ_histogram[index] = torch.histc(
                class_succ_confidences, bins=self.bins, min=0, max=1
            )
        self.class_frequency = torch.add(self.class_frequency, class_freq_histogram)
        self.class_success = torch.add(self.class_success, class_succ_histogram)

        torch.use_deterministic_algorithms(True)

    def compute(
        self,
    ) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
        """Вычисляет метрику."""
        accuracy = _safe_divide(self.success, self.frequency)
        frequency = _safe_divide(self.frequency, self.frequency.sum())
        class_accuracy = _safe_divide(self.class_success, self.class_frequency)
        class_frequency = _safe_divide(
            self.class_frequency, self.class_frequency.sum(dim=1, keepdim=True)
        )
        return (accuracy, frequency, class_accuracy, class_frequency)

def _safe_divide(num: Tensor, denom: Tensor) -> Tensor:
    """Безопасное деление, предотвращающее деление на ноль.

    Дополнительно приводит к плавающему типу, если входные данные еще не являются таковыми, чтобы обеспечить обратную совместимость.

    """
    num = num if num.is_floating_point() else num.float()
    denom = denom if denom.is_floating_point() else denom.float()
    return num / torch.where(denom == 0.0, 1.0, denom)

def _multiclass_reliability_diagram_format(
    preds: Tensor,
    target: Tensor,
    num_classes: int,
    ignore_index: Optional[int] = None,
) -> Tuple[Tensor, Tensor]:
    """Преобразует все входные данные в правильный формат.

    - сплющивает дополнительные размеры
    - Удаляет все точки данных, которые должны быть игнорированы
    - Применяет softmax, если тензор предсказаний не в диапазоне [0,1]

    """
    preds = preds.transpose(0, 1).reshape(num_classes, -1).T
    target = target.flatten()

    if ignore_index is not None:
        idx = target != ignore_index
        preds = preds[idx]
        target = target[idx]

    if not torch.all((preds >= 0) * (preds <= 1)):
        preds = preds.softmax(1)

    return preds, target

Я также написал несколько тестов, чтобы проверить, что это правильно:

import math
import torch
from ..src.reliability_diagram import MulticlassReliabilityDiagram

def array_is_close(array1, array2, index) -> bool:
    return math.isclose(array1[index], array2[index])

def test_all_right():
    # Размер батча: 2
    # Количество классов: 3
    # Ширина, высота: 2, 2
    target = torch.tensor([[[0, 1], [1, 2]], [[2, 1], [0, 0]]], dtype=torch.float64)

    prediction = torch.tensor(
        [
            # батч 0
            [
                # класс 0
                [[0.6, 0.2], [0.2, 0.3]],
                # класс 1
                [[0.1, 0.7], [0.7, 0.2]],
                # класс 2
                [[0.3, 0.1], [0.1, 0.5]],
            ],
            # батч 1
            [
                # класс 0
                [[0.3, 0.2], [0.6, 0.6]],
                # класс 1
                [[0.2, 0.7], [0.1, 0.1]],
                # класс 2
                [[0.5, 0.1], [0.3, 0.3]],
            ],
        ],
        dtype=torch.float64,
    )

    # диапазоны = 4 0, 0.25, 0.50, 0.75
    expected_reliability = [0, 0, 1, 0]
    expected_frequency = [11 / 24, 5 / 24, 8 / 24, 0]
    expected_class_reliability = [
        [0, 0, 1, 0],
        [0, 0, 1, 0],
        [0, 0, 1, 0],
    ]
    expected_class_frequency = [
        [3 / 8, 2 / 8, 3 / 8, 0],
        [5 / 8, 0 / 8, 3 / 8, 0],
        [3 / 8, 3 / 8, 2 / 8, 0],
    ]

    num_classes = 3
    num_bins = 4

    metric = MulticlassReliabilityDiagram(num_classes, num_bins)

    metric.update(prediction, target)
    reliability, frequency, class_reliability, class_frequency = metric.compute()

    for index in range(len(frequency)):
        assert array_is_close(frequency.tolist(), expected_frequency, index)

    for index in range(len(reliability)):
        assert array_is_close(reliability.tolist(), expected_reliability, index)

    for class_number in range(num_classes):
        for index in range(len(frequency)):
            assert array_is_close(
                class_frequency[class_number].tolist(),
                expected_class_frequency[class_number],
                index,
            )

        for index in range(len(reliability)):
            assert array_is_close(
                class_reliability[class_number].tolist(),
                expected_class_reliability[class_number],
                index,
            )

def test_all_wrong():
    # Размер батча: 2
    # Количество классов: 3
    # Ширина, высота: 2, 2
    target = torch.tensor([[[0, 1], [1, 2]], [[2, 1], [0, 0]]], dtype=torch.float64)

    prediction = torch.tensor(
        [
            # батч 0
            [
                # класс 0
                [[0.1, 0.2], [0.2, 0.5]],
                # класс 1
                [[0.6, 0.1], [0.1, 0.2]],
                # класс 2
                [[0.3, 0.7], [0.7, 0.3]],
            ],
            # батч 1
            [
                # класс 0
                [[0.5, 0.2], [0.1, 0.1]],
                # класс 1
                [[0.2, 0.1], [0.6, 0.6]],
                # класс 2
                [[0.3, 0.7], [0.3, 0.3]],
            ],
        ],
        dtype=torch.float64,
    )

    # диапазоны = 4 0, 0.25, 0.50, 0.75
    expected_reliability = [0, 0, 0, 0]
    expected_frequency = [11 / 24, 5 / 24, 8 / 24, 0]
    expected_class_reliability = [
        [0, 0, 0, 0],
        [0, 0, 0, 0],
        [0, 0, 0, 0],
    ]
    expected_class_frequency = [
        [6 / 8, 0, 2 / 8, 0],
        [5 / 8, 0, 3 / 8, 0],
        [0, 5 / 8, 3 / 8, 0],
    ]

    num_classes = 3
    num_bins = 4

    metric = MulticlassReliabilityDiagram(num_classes, num_bins)

    metric.update(prediction, target)
    reliability, frequency, class_reliability, class_frequency = metric.compute()

    for index in range(len(frequency)):
        assert array_is_close(frequency.tolist(), expected_frequency, index)

    for index in range(len(reliability)):
        assert array_is_close(reliability.tolist(), expected_reliability, index)

    for class_number in range(num_classes):
        for index in range(len(frequency)):
            assert array_is_close(
                class_frequency[class_number].tolist(),
                expected_class_frequency[class_number],
                index,
            )

        for index in range(len(reliability)):
            assert array_is_close(
                class_reliability[class_number].tolist(),
                expected_class_reliability[class_number],
                index,
            )

def test_mixed():
    # Размер батча: 2
    # Количество классов: 3
    # Ширина, высота: 2, 2
    target = torch.tensor([[[0, 1], [1, 2]], [[2, 1], [0, 0]]], dtype=torch.float64)

    prediction = torch.tensor(
        [
            # батч 0
            [
                # класс 0
                [[0.6, 0.2], [0.2, 0.3]],
                # класс 1
                [[0.1, 0.7], [0.7, 0.2]],
                # класс 2
                [[0.3, 0.1], [0.1, 0.5]],
            ],
            # батч 1
            [
                # класс 0
                [[0.5, 0.2], [0.6, 0.6]],
                # класс 1
                [[0.2, 0.7], [0.1, 0.1]],
                # класс 2
                [[0.3, 0.1], [0.3, 0.3]],
            ],
        ],
        dtype=torch.float64,
    )

    # диапазоны = 4 0, 0.25, 0.50, 0.75
    expected_reliability = [0, 0, 7 / 8, 0]
    expected_frequency = [11 / 24, 5 / 24, 8 / 24, 0]
    expected_class_reliability = [
        [0, 0, 3 / 4, 0],
        [0, 0, 1, 0],
        [0, 0, 1, 0],
    ]
    expected_class_frequency = [
        [3 / 8, 1 / 8, 4 / 8, 0],
        [5 / 8, 0 / 8, 3 / 8, 0],
        [3 / 8, 4 / 8, 1 / 8, 0],
    ]

    num_classes = 3
    num_bins = 4

    metric = MulticlassReliabilityDiagram(num_classes, num_bins)

    metric.update(prediction, target)
    reliability, frequency, class_reliability, class_frequency = metric.compute()

    for index in range(len(frequency)):
        assert array_is_close(frequency.tolist(), expected_frequency, index)

    for index in range(len(reliability)):
        assert array_is_close(reliability.tolist(), expected_reliability, index)

    for class_number in range(num_classes):
        for index in range(len(frequency)):
            assert array_is_close(
                class_frequency[class_number].tolist(),
                expected_class_frequency[class_number],
                index,
            )

        for index in range(len(reliability)):
            assert array_is_close(
                class_reliability[class_number].tolist(),
                expected_class_reliability[class_number],
                index,
            )

def test_mixed_10_bins():
    # Размер батча: 2
    # Количество классов: 3
    # Ширина, высота: 2, 2
    target = torch.tensor([[[0, 1], [1, 2]], [[2, 1], [0, 0]]], dtype=torch.float64)

    prediction = torch.tensor(
        [
            # батч 0
            [
                # класс 0
                [[0.1, 0.2], [0.2, 0.5]],
                # класс 1
                [[0.6, 0.1], [0.1, 0.2]],
                # класс 2
                [[0.3, 0.7], [0.7, 0.3]],
            ],
            # батч 1
            [
                # класс 0
                [[0.3, 0.2], [0.1, 0.1]],
                # класс 1
                [[0.2, 0.1], [0.6, 0.6]],
                # класс 2
                [[0.5, 0.7], [0.3, 0.3]],
            ],
        ],
        dtype=torch.float64,
    )

    # диапазоны = 10
    expected_reliability = [0, 0, 0, 0, 0, 1 / 2, 0, 0, 0, 0]
    expected_frequency = [
        0,
        6 / 24,
        5 / 24,
        5 / 24,
        0,
        2 / 24,
        3 / 24,
        3 / 24,
        0,
        0,
    ]

    num_classes = 3
    num_bins = 10

    metric = MulticlassReliabilityDiagram(num_classes, num_bins)

    metric.update(prediction, target)
    reliability, frequency, _class_reliability, _class_frequency = metric.compute()

    for index in range(len(frequency)):
        assert array_is_close(frequency.tolist(), expected_frequency, index)

    for index in range(len(reliability)):
        assert array_is_close(reliability.tolist(), expected_reliability, index)

Я думаю, что выход для класса A в вашем случае должен быть [1/10, 1]. Это потому, что у вас есть 9 образцов, которые с уверенностью 0.2 предсказывают A (но их истинное значение на самом деле B) и 1 образец, который с уверенностью 0.2 предсказывает A (и его истинное значение также A). Так что это 1/(1+9)=1/10.

Одна из ссылок, которая помогла мне понять это, это https://towardsdatascience.com/introduction-to-reliability-diagrams-for-probability-calibration-ed785b3f5d44

Ответ или решение

Чтобы правильно рассчитать диаграмму надежности для многоклассового классификатора, необходимо учитывать несколько факторов, связанных с определением точности и обработкой классов. Давайте обсудим ваш подход и внесем некоторые уточнения.

1. Правильность расчета диаграмм надежности

Вы правильно основываетесь на методе Guo et al., который описывает, как рассчитывать точность для каждой классификационной модели, но важно четко понимать, как интерпретировать полученные значения.

В вашем примере с классами A, B и C, действительно возникает путаница с учётом ложных отрицательных и ложных положительных значений, поскольку точность для класса A по вашей логике может казаться равной 1. Однако, как вы уже заметили, один из образцов с предсказанной меткой A на самом деле относится к классу B, что приводит к ложному положительному результату для этого класса.

2. Точность и двусмысленность

Когда вы говорите о точности для класса A, вы правильно подметили, что в этом случае 1 из 10 образцов действительно является истинным положительным (обычно мы отчитываем по истине, а не по предсказаниям). Однако, ваше предположение о высчитывании точности как 1/(1+9) равно 1/10 было бы уместным, если бы вы считали класс A и B одновременно.

В вашем коде точность для класса A не учитывает предсказания для класса B, и это становится причиной недопонимания, который вы испытываете. Для более четкой диаграммы надежности всё же стоит учитывать не только правильные предсказания, но и то, как они повлияли на общую калibriруемость модели.

3. Согласно литературе

Учитывая недавние статьи и исследования, такие как ваша ссылка на работу Guo et al., мы видим, что для многоклассовой классификации нельзя рассматривать каждую метку изолированно. Вместо этого предложено учитывать взаимосвязи между классами.

Также вы можете найти полезные ресурсы, такие как Towards Data Science, которые дадут более подробные примеры и пояснения, как правильно строить диаграммы надежности, особенно в многоклассовых задах.

4. Ваше решение

Согласно вашему коду, вы можете внести несколько изменений, чтобы учесть ложные предсказания:

  • Для каждого класса считайте как количество истинных положительных, так и количество ложных положительных.
  • При расчете точности не забывайте общее количество поданных предсказаний, чтобы это не отразилось на вашей диаграмме надежности.

Ваш подход к расчетам в коде работает, но учтите также, что есть некоторые нюансы в детализации расчета, чтобы вы могли использовать данные более эффективно.

Заключение

Ваша идея привлечь внимание к многоклассовой диаграмме надежности очень актуальна и важна для понимания работы машинного обучения. Применив рекомендации, вы сможете улучшить вашу реализацию и внести ясность в задачу калибровки, что обеспечит более качественный выход предсказаний.

Оцените материал
Добавить комментарий

Капча загружается...