scipy bootstrap генерирует входные данные с неконсистентным количеством образцов

Вопрос или проблема

У меня есть набор данных из 77 образцов, и я использую scipy bootstrap, чтобы получить доверительный интервал для оценки точности. Я озадачен тем, что вижу, как он генерирует входные переменные с неконсистентным количеством образцов, и подозреваю, что делаю что-то неправильно. Он начинается с pandas dataframe с колонкой “label” (1 или 0) и колонкой “test” (1 или 0). Нет NaN, есть 77 образцов, все значения в label и test – это int.

Затем я изначально запустил:

import pandas as pd
from sklearn.metrics import precision_score
from scipy.stats import bootstrap

prec = precision_score(df["labels"], df["test"]) # работает как ожидалось
res = bootstrap((df["labels"], df["test"]), precision_score, n_resamples=1000)

который в конечном итоге приводит к:

Traceback (most recent call last):
  File "/home/wdecoster/chr15q14/scripts/precision_recall.py", line 40, in <module>
    main()
  File "/home/wdecoster/chr15q14/scripts/precision_recall.py", line 25, in main
    res = bootstrap((df["labels"], df["test"]), precision_score, n_resamples=1000)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/wdecoster/miniconda3/envs/data_analysis/lib/python3.12/site-packages/scipy/_lib/_util.py", line 440, in wrapper
    return fun(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^
  File "/home/wdecoster/miniconda3/envs/data_analysis/lib/python3.12/site-packages/scipy/stats/_resampling.py", line 657, in bootstrap
    interval = _bca_interval(data, statistic, axis=-1, alpha=alpha,
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/wdecoster/miniconda3/envs/data_analysis/lib/python3.12/site-packages/scipy/stats/_resampling.py", line 130, in _bca_interval
    theta_hat_i.append(statistic(*broadcasted, axis=-1))
                       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/wdecoster/miniconda3/envs/data_analysis/lib/python3.12/site-packages/scipy/stats/_resampling.py", line 38, in stat_nd
    return np.apply_along_axis(stat_1d, 0, z)[()]
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/wdecoster/miniconda3/envs/data_analysis/lib/python3.12/site-packages/numpy/lib/shape_base.py", line 379, in apply_along_axis
    res = asanyarray(func1d(inarr_view[ind0], *args, **kwargs))
                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/wdecoster/miniconda3/envs/data_analysis/lib/python3.12/site-packages/scipy/stats/_resampling.py", line 36, in stat_1d
    return statistic(*data)
           ^^^^^^^^^^^^^^^^
  File "/home/wdecoster/miniconda3/envs/data_analysis/lib/python3.12/site-packages/sklearn/utils/_param_validation.py", line 216, in wrapper
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/home/wdecoster/miniconda3/envs/data_analysis/lib/python3.12/site-packages/sklearn/metrics/_classification.py", line 2247, in precision_score
    p, _, _, _ = precision_recall_fscore_support(
                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/wdecoster/miniconda3/envs/data_analysis/lib/python3.12/site-packages/sklearn/utils/_param_validation.py", line 189, in wrapper
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/home/wdecoster/miniconda3/envs/data_analysis/lib/python3.12/site-packages/sklearn/metrics/_classification.py", line 1830, in precision_recall_fscore_support
    labels = _check_set_wise_labels(y_true, y_pred, average, labels, pos_label)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/wdecoster/miniconda3/envs/data_analysis/lib/python3.12/site-packages/sklearn/metrics/_classification.py", line 1596, in _check_set_wise_labels
    y_type, y_true, y_pred = _check_targets(y_true, y_pred)
                             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/wdecoster/miniconda3/envs/data_analysis/lib/python3.12/site-packages/sklearn/metrics/_classification.py", line 98, in _check_targets
    check_consistent_length(y_true, y_pred)
  File "/home/wdecoster/miniconda3/envs/data_analysis/lib/python3.12/site-packages/sklearn/utils/validation.py", line 475, in check_consistent_length
    raise ValueError(
ValueError: Found input variables with inconsistent numbers of samples: [76, 77]

В попытке отладить это, я написал обертку вокруг функции precision_score, как показано ниже:


def precision(y_true, y_pred):
    print(len(y_true), len(y_pred))
    return precision_score(y_true, y_pred)

res = bootstrap((df["labels"], df["test"]), precision, n_resamples=1000)

При запуске это выводит “77 77” снова и снова, как должно быть, но действительно, прямо перед ошибкой, он выводит “76 77”:

(...)
77 77
77 77
77 77
77 77
77 77
77 77
77 77
77 77
77 77
77 77
77 77
77 77
77 77
76 77
Traceback (most recent call last):
  File "/home/wdecoster/chr15q14/scripts/precision_recall.py", line 40, in <module>
    main()
  File "/home/wdecoster/chr15q14/scripts/precision_recall.py", line 25, in main
    res = bootstrap((df["labels"], df["test"]), precision, n_resamples=1000)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/wdecoster/miniconda3/envs/data_analysis/lib/python3.12/site-packages/scipy/_lib/_util.py", line 440, in wrapper
    return fun(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^
  File "/home/wdecoster/miniconda3/envs/data_analysis/lib/python3.12/site-packages/scipy/stats/_resampling.py", line 657, in bootstrap
    interval = _bca_interval(data, statistic, axis=-1, alpha=alpha,
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/wdecoster/miniconda3/envs/data_analysis/lib/python3.12/site-packages/scipy/stats/_resampling.py", line 130, in _bca_interval
    theta_hat_i.append(statistic(*broadcasted, axis=-1))
                       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/wdecoster/miniconda3/envs/data_analysis/lib/python3.12/site-packages/scipy/stats/_resampling.py", line 38, in stat_nd
    return np.apply_along_axis(stat_1d, 0, z)[()]
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/wdecoster/miniconda3/envs/data_analysis/lib/python3.12/site-packages/numpy/lib/shape_base.py", line 379, in apply_along_axis
    res = asanyarray(func1d(inarr_view[ind0], *args, **kwargs))
                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/wdecoster/miniconda3/envs/data_analysis/lib/python3.12/site-packages/scipy/stats/_resampling.py", line 36, in stat_1d
    return statistic(*data)
           ^^^^^^^^^^^^^^^^
  File "/home/wdecoster/chr15q14/scripts/precision_recall.py", line 16, in precision
    return precision_score(y_true, y_pred)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/wdecoster/miniconda3/envs/data_analysis/lib/python3.12/site-packages/sklearn/utils/_param_validation.py", line 216, in wrapper
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/home/wdecoster/miniconda3/envs/data_analysis/lib/python3.12/site-packages/sklearn/metrics/_classification.py", line 2247, in precision_score
    p, _, _, _ = precision_recall_fscore_support(
                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/wdecoster/miniconda3/envs/data_analysis/lib/python3.12/site-packages/sklearn/utils/_param_validation.py", line 189, in wrapper
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/home/wdecoster/miniconda3/envs/data_analysis/lib/python3.12/site-packages/sklearn/metrics/_classification.py", line 1830, in precision_recall_fscore_support
    labels = _check_set_wise_labels(y_true, y_pred, average, labels, pos_label)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/wdecoster/miniconda3/envs/data_analysis/lib/python3.12/site-packages/sklearn/metrics/_classification.py", line 1596, in _check_set_wise_labels
    y_type, y_true, y_pred = _check_targets(y_true, y_pred)
                             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/wdecoster/miniconda3/envs/data_analysis/lib/python3.12/site-packages/sklearn/metrics/_classification.py", line 98, in _check_targets
    check_consistent_length(y_true, y_pred)
  File "/home/wdecoster/miniconda3/envs/data_analysis/lib/python3.12/site-packages/sklearn/utils/validation.py", line 475, in check_consistent_length
    raise ValueError(
ValueError: Found input variables with inconsistent numbers of samples: [76, 77]

Кроме того, когда я перенаправляю вывод скрипта, печатая его, в wc, он всегда возвращает 1002 строки. Если я перенаправляю вывод скрипта в sort | uniq -c, я получаю:

      1 76 77
   1001 77 77

Если я изменяю “n_resamples” на 100, он выдает ошибку после 102 строк.

Что происходит?!

Я понял, что мне также нужно paired=True, и я не знаю, почему это помогает, но у меня больше нет ошибки…

Ответ или решение

Проблема, с которой вы столкнулись при использовании функции bootstrap из библиотеки SciPy, связана с несоответствием количества образцов между двумя входными переменными — labels и test. Эта проблема особенно актуальна при использовании функции precision_score из библиотеки scikit-learn, которая требует, чтобы входные данные имели одинаковое количество образцов для корректной работы. Рассмотрим подробнее, с чем это связано и как можно решить эту проблему.

Теория

Когда вы выполняете операцию бутстрепинга на паре данных с использованием Scipy, важно убедиться, что данные согласованы. Функция бутстрепинга создает множество подвыборок из исходных данных, и одна из основных задач — это сохранить согласованость размерности данных в каждой итерации.

Бутстрепинг — это статистический метод, который используется для оценки распределения некоторой статистики (например, средней или стандартного отклонения) путем многократного выборочного повторного отбора из данных. Обычно предполагается, что каждая подвыборка имеет тот же размер, что и исходный набор данных, и в большинстве случаев это происходит автоматически. Однако в некоторых случаях, как в вашем, могут возникать ошибки связаны с тем, что scipy никак не контролирует метки (например, использование параметра paired=False по умолчанию означает, что каждая переменная обрабатывается независимо).

Пример

Рассмотрим примерный код, вызвавший ошибку:

import pandas as pd
from sklearn.metrics import precision_score
from scipy.stats import bootstrap

df = pd.DataFrame({'labels': labels_values, 'test': test_values})
prec = precision_score(df["labels"], df["test"]) # рабочий код

res = bootstrap((df["labels"], df["test"]), precision_score, n_resamples=1000)

Здесь precision_score рассчитывается корректно при первичном использовании, поскольку исходные данные согласованы по размерности. Однако при попытке использовать bootstrap для многократной выборки эта согласованность становится критически важной.

Применение

Ваш случай демонстрирует, как использование параметра paired=True решает проблему. Вот почему это происходит. По умолчанию функция bootstrap в SciPy предполагает, что данные независимы, то есть они не имеют внутренней структуры, требующей согласованности по индексам (например, каждым парам точек соответствует собственный идентификатор выборки). Использование paired=True указывает функции на необходимость симметричного отбора пар элементов, что обеспечивает их синхронность и согласованность в размерах.

Применение решения:

import pandas as pd
from sklearn.metrics import precision_score
from scipy.stats import bootstrap

df = pd.DataFrame({'labels': labels_values, 'test': test_values})

# Оберните функцию precision_score в другую, которая принимает дополнительные аргументы
def precision(y_true, y_pred):
    return precision_score(y_true, y_pred)

# Выполните бутстрепинг, указав paired=True
res = bootstrap((df["labels"], df["test"]), precision, n_resamples=1000, paired=True)

Используя paired=True, вы предотвращаете потерю согласованных пар данных, вследствие чего не возникает различий в количестве элементов внутри каждой подвыборки. Это позволяет гарантировать, что обе выборки всегда имеют один и тот же размер и структуру на протяжении всего процесса бутстрепинга, что устраняет ошибки, связанные с несоответствием размерности.

Заключение

Метод согласования данных путем использования параметра paired=True особенно полезен, когда вы работаете с набором данных, где каждая точка данных имеет сопоставленное значение, например, в задачах бинарной классификации или оценке производительности алгоритмов машинного обучения. Это гарантирует, что во время бутстрепинга сохранится корректная структура данных, позволяя функции производить статистически значимые оценки.

Надеюсь, это объяснение поможет вам лучше понять, как обрабатывать подобные ошибки и как грамотно использовать инструменты SciPy в ваших будущих аналитических проектах.

Оцените материал
Добавить комментарий

Капча загружается...