Использование Keras с sklearn: применение class_weight с cross_val_score

Вопрос или проблема

У меня имеется сильно несбалансированный набор данных (± 5% положительных экземпляров), для которого я обучаю бинарные классификаторы. Я использую вложенную 5-кратную кросс-валидацию с поиском по сетке для настройки гиперпараметров.

Я хочу избежать недовыборки, поэтому я исследую гиперпараметр class_weight. Для классификатора decisiontree из sklearn это работает очень хорошо и легко задается как гиперпараметр. Однако для нейронной сети от sklearn (многослойного перцептрона) это, насколько я могу судить, не является вариантом. Вместо этого я использую Keras, и могу применить class_weight с gridsearchCV, но не с cross_val_score.

Существует ли способ использовать class_weights в Keras с кросс-валидацией?

На данный момент MLPClassifier из sklearn не реализует никаких опций для весов классов. У вас есть как минимум два пути. Вы можете написать пользовательскую функцию потерь, которая позволит вам оставаться в рамках sklearn, не обращаясь к keras. Другой вариант — реализовать кросс-валидацию самостоятельно, что не сложно сделать, и запускать вашу модель keras для каждого сложения. Пример можно найти здесь.

Вы должны иметь возможность передавать class_weights через аргумент fit_params функции cross_val_score.

Ответ или решение

Для решения проблемы применения весов классов в Keras при использовании перекрестной валидации с помощью функции cross_val_score из библиотеки Scikit-Learn, мы можем воспользоваться некоторыми подходами, позволяющими передавать веса классов в модель Keras.

Вот пошаговое решение:

1. Определение модели Keras

Сначала мы должны определить архитектуру нашей нейронной сети в Keras. Например:

from keras.models import Sequential
from keras.layers import Dense

def create_model():
    model = Sequential()
    model.add(Dense(64, input_dim=10, activation='relu'))  # Пример: 10 признаков на входе
    model.add(Dense(1, activation='sigmoid'))  # Выход для бинарной классификации
    model.compile(loss='binary_crossentropy', optimizer='adam', metrics=['accuracy'])
    return model

2. Обертка для Keras-классификатора

Мы создадим класс-обертку, который позволит интегрировать модель Keras в Scikit-Learn. Это можно сделать с использованием KerasClassifier.

from keras.wrappers.scikit_learn import KerasClassifier

model = KerasClassifier(build_fn=create_model, epochs=100, batch_size=10, verbose=0)

3. Определение функции для получения весов классов

Чтобы передать вес классов в fit во время перекрестной валидации, мы создадим функцию, которая будет принимать ваши данные и веса классов.

from sklearn.utils.class_weight import compute_class_weight
import numpy as np

def get_class_weights(y):
    classes = np.unique(y)
    class_weights = compute_class_weight('balanced', classes=classes, y=y)
    return dict(enumerate(class_weights))

4. Использование cross_val_score с параметрами fit

Теперь мы можем применить cross_val_score с параметрами для передачи весов классов. Для этого воспользуемся функцией make_scorer, если это необходимо, и передадим веса через fit_params.

from sklearn.model_selection import cross_val_score, StratifiedKFold
from sklearn.datasets import make_classification

# Создание искусственного набора данных
X, y = make_classification(n_samples=1000, n_features=10, n_classes=2, weights=[0.95, 0.05], random_state=42)

# Стратегия разбиения на фолды
cv = StratifiedKFold(n_splits=5)

# Оценка модели с использованием cross_val_score
scores = cross_val_score(model, X, y, cv=cv, fit_params={'class_weight': get_class_weights(y)}, scoring='accuracy')

print("Точность для каждой фолды:", scores)
print("Средняя точность:", np.mean(scores))

Заключение

Таким образом, мы успешно интегрировали Keras с функцией cross_val_score и передали веса классов в метод fit. Этот подход позволяет эффективно обрабатывать несбалансированные данные без необходимости в подвыборе. Не забудьте настроить количество эпох и размер пакета в зависимости от ваших данных и вычислительных возможностей.

Оцените материал
Добавить комментарий

Капча загружается...