Использование KerasClassifier для обучения нейронной сети

Вопрос или проблема

Я создал простую нейронную сеть для бинарной классификации текста (спам/неспам) с использованием предобученного трансформера BERT. Текущая реализация на чистом Keras работает нормально. Однако я хотел построить график некоторых метрик обученной модели, в частности ROC-кривую. Согласно этому посту в блоге, я понял, что это возможно только с использованием KerasClassifier() из пакета keras.wrappers.scikit-learn, который теперь устарел и был заменен пакетом scikeras.

Таким образом, я создал функцию build_keras_nn() для построения своей кастомной нейронной сети на базе BERT. Затем я передал эту кастомную функцию в KerasClassifier(), как показано в документации, и обучил модель с использованием обучающих данных.

На этом этапе я получил следующее сообщение об ошибке:

ValueError: Ожидался 2D массив, вместо этого получен 1D массив: Измените форму ваших данных, используя array.reshape(-1, 1), если ваши данные имеют одну характеристику, или array.reshape(1, -1), если они содержат одну выборку.

Ладно, я подумал, что простое изменение формы массива сработает, но потом столкнулся со следующей ошибкой:

ValueError: не удалось преобразовать строку в число с плавающей запятой: 'ужасный биография очевидно подлый пресс-релиз совершенно предвзятый плохая грамматика'

Таким образом, по какой-то причине реализация KerasClassifier не позволяет мне напрямую вводить текст, хотя мои шаги предварительной обработки включены внутри кастомной функции build_keras_nn().

Полный воспроизводимый код ниже:

import tensorflow_hub as hub 
import tensorflow_text as text
import tensorflow as tf 
from tensorflow.keras.layers import Input, Dropout, Dense
from tensorflow.keras.metrics import BinaryAccuracy, AUC

bert_encoder_url = "https://tfhub.dev/tensorflow/bert_en_uncased_L-12_H-768_A-12/4"
bert_preprocessor_url = "https://tfhub.dev/tensorflow/bert_en_uncased_preprocess/3"

bert_preprocessor_model = hub.KerasLayer(bert_preprocessor_url)
bert_encoder_model = hub.KerasLayer(bert_encoder_url)

df_ = pd.read_json(spam_ham_json)   # spam_ham_json: данные в формате JSON как строка

X_train_, X_test_, y_train_, y_test_ = train_test_split(df_['comment_text'], df_['label'])

def build_keras_nn():

    text_input = Input(shape=(), dtype=tf.string, name="text")
    preprocessed_text = bert_preprocessor_model(text_input)
    bert_output = bert_encoder_model(preprocessed_text)
    dropout = Dropout(0.1, name="dropout")(bert_output['pooled_output'])
    classification_output = Dense(1, activation='sigmoid', name="classification_output")(dropout)

    model = tf.keras.Model(inputs=[text_input], outputs=[classification_output])

    metrics_list = [AUC(name="auc"), BinaryAccuracy(name="accuracy")]
    model.compile(optimizer="adam", loss="binary_crossentropy", metrics = metrics_list)
    return model 

# Следующие две строки показывают реализацию на чистом Keras: это работает.
# model = build_keras_nn()
# history = model.fit(X_train_, y_train_, epochs=5);

# Теперь давайте посмотрим KerasClassifier
model = KerasClassifier(build_fn=build_keras_nn)
# history = model.fit(X_train_, y_train_, epochs=5);     # <-- Ошибка значения 1
# history = model.fit(np.array(X_train_).reshape(-1,1), np.array(y_train_).reshape(-1,1), epochs=5);     # <-- Ошибка значения 2

Данные в формате json:

'{"comment_text":{"0":"проблема хотела сказать проблема пыталась перенаправить событие расписание пакистана NUMBERTAG NUMBERTAG пакистан мать ебал мальчик хочет жениться на сестре оххх любовь сестры сиська хммммм внууу","1":"получить жизнь ебанный лузер вопрос спросить спросить кэти гулет картина","2":"конченая выпивка эй, какой ниггер думал, что ты сможешь забанить долго, потому что был занят как ад недавно, я продолжу приходить обратно, возьму слово конченая","3":"лжец лжец штанга горит серьезно смотрела вкладку теннисный портал страница теннисная страница ха дискуссия когда-либо пожалуйста обмануть NUMBERTAG NUMBERTAG NUMBERTAG NUMBERTAG","4":"перестань писать п ничего не обсуждать дано отсутствие образования диплома","5":"я ебаный редактирую страницу","6":"вопрос сумасшедший гей","7":"внимание страница nerd, пожалуйста, оставьте одну девушку, пожалуйста, оставайтесь, нееееееееasseeсяяявхррр NUMBERTAG одиннннн","8":"полная хуйня","9":"пошел на хуй конрад чёрный обманул тысячу людей пенсию кто-либо защищает его хуй апологет злодей","10":"список должностного лица национального союза студента австралии wp userfy userfied страница расположена","11":"говорят о истории шотландской национальной партии утверждают, что шпионят, кто бы ни хотел, можно ли кто-то верить, это значение кто-то верит npov утверждение","12":"раздел, предназначенный для рецензии вице-редактора, между прочим, сайт журнала имя писателя также прикреплено, как richardwilson NUMBERTAG даже не знает вопроса ninjarobotpirate, эй, счастливо критиковать ответ \\u2026 \\u2026 между прочим, NUMBERTAG далеко не знаю, ни редактора, ни албанца, ни хорвата может быть, самолёт видения довольно хороший, думаю, что лучше следить","13":"в следующий раз твитливый","14":"физика йо йо йо собака","15":"самоцензура телевизионного шоу может быть, может быть, заметное телевидение, прерывающее новости, заметное происходит в это время","16":"статья содержит информацию, которая soursed huddersfield aa улица улица","17":"utc на что-то центробежная сила испытывается массой, exhibiting инерция результат крошечная маленькая пуля ударит сторону катающегося музыка merry go round rueda puthoff haisch описывается ноль точка поля электронное уравнение лоренца coupling инерциальная рамка ссылку дать массу инерциальное сопротивление, а скорее, сопротивление позволяет описать изменение скорости направления по сравнению с ac v dc tesla v edison NUMBERTAG NUMBERTAG NUMBERTAG июня NUMBERTAG","18":"означал, что я имел в виду блокировать создать новую учетную записьRendering заблокировать бесполезный простой","19":"NUMBERTAG utc привет NUMBERTAG, вероятно, по ошибке думал, что иан был оригинальным участником, потому что всегда думал необходимая группа дисциплинарный гитарист NUMBERTAG, да, почти купил akai headrush looper год назад, знаменитую роль кабель одного гитариста записи settled bos loop station вместо этого скорее headrush boomerang из-за двух надежности проблемы цены соответственно check hovercraft southpacific auburn lull kind hallucinitory гитара looping думал кабель нового состава был невероятным, увидел, что NUMBERTAG пропустил классический состав NUMBERTAG сравнивает две производительности-wise best NUMBERTAG NUMBERTAG NUMBERTAG может"},"label":{"0":1,"1":1,"2":1,"3":1,"4":1,"5":1,"6":1,"7":1,"8":1,"9":1,"10":0,"11":0,"12":0,"13":0,"14":0,"15":0,"16":0,"17":0,"18":0,"19":0}}'

Мне удалось решить проблему!

Основная ошибка заключалась в том, что scikeras не принимает ненумерический ввод. Таким образом, мне пришлось прибегнуть к обходному пути, устранив шаги предварительной обработки текста из моего кастомного build_keras_nn() и преобразовав их в шаги предварительной обработки.

Кастомные шаги предварительной обработки, однако, требуют своих классов, которые наследуются от scikit-learn TransformerMixin и BaseEstimator. Я создал два кастомных класса предварительной обработки: один для выполнения задачи предварительной обработки BERT (BertPreprocessor) и другой для выполнения задачи кодирования BERT (BertEncoder). Все кастомные трансформаторы в стиле scikit-learn должны реализовывать метод fit() вместе с transform(), в противном случае они не могут быть интегрированы в другие встроенные трансформаторы и модели scikit-learn.

Одно последнее изменение заключалось в переписывании входного слоя в build_keras_nn(). Важно отметить, что BERT преобразует слова/токены в представление размерности 768, поэтому аргумент shape слоест Input установлен на 768.

Финальный полный код выглядит следующим образом:

class BertPreprocessor(TransformerMixin, BaseEstimator):
    def __init__(self, preprocessor_model):
        super().__init__()
        self.preprocessor_model = preprocessor_model
    def fit(self, X, y=None):
        return self 
    def transform(self, X, y=None):
        return self.preprocessor_model(X)

class BertEncoder(TransformerMixin, BaseEstimator):
    def __init__(self, encoder_model):
        super().__init__()
        self.encoder_model = encoder_model
    def fit(self, X, y=None):
        return self
    def transform(self, X, y=None):
        return self.encoder_model(X)['pooled_output']

def build_keras_nn():
    input_layer = Input(shape=768, name="bert_pooled_output")
    dropout = Dropout(0.1, name="dropout")(input_layer)
    classification_output = Dense(1, activation='sigmoid', name="classification_output")(dropout)

    model = tf.keras.Model(inputs=[input_layer], outputs=[classification_output])

    metrics_list = [AUC(name="auc"), BinaryAccuracy(name="accuracy")]
    model.compile(optimizer="adam", loss="binary_crossentropy", metrics = metrics_list)
    return model 

pipe = Pipeline([
    ('preprocess', BertPreprocessor(bert_preprocessor_model)),
    ('encode', BertEncoder(bert_encoder_model)),
    ('model', KerasClassifier(build_fn=build_keras_nn, epochs=2)),
])

history = pipe.fit(X_train_, y_train_);
```

Ответ или решение

Использование KerasClassifier для обучения нейронной сети: Руководство по spam/ham классификации с BERT

В данной статье мы рассмотрим, как правильно использовать KerasClassifier в сочетании с предобученной трансформерной моделью BERT для решения задачи бинарной классификации текстов (спам/хам). Мы также разберем возникающие ошибки и предложим способ их устранения.

Задача

У вас имеется задача классификации текстов, где необходимо отделить спам от хамов (недозволенных сообщений). Вы уже создали функциональную модель на базе Keras и BERT, но столкнулись с проблемами при использовании KerasClassifier из пакета scikeras, поскольку он теперь заменяет устаревшую реализацию.

Проблемы, с которыми вы столкнулись

  1. Ошибка плоского массива:
    Вы получили сообщение об ошибке, связанное с тем, что KerasClassifier ожидает 2D массив, но получает 1D. Эта проблема возникает, когда вы пытаетесь подать текстовые данные в модель напрямую.

  2. Проблемы с преобразованием строк:
    При попытке выполнить преобразование массива вы столкнулись с еще одной ошибкой, указывающей на неудовлетворительное преобразование строки в число. Это свидетельствует о том, что необходимо правильно применить этапы предобработки текста перед подачей данных в модель.

Решение

Основная ошибка заключается в том, что scikeras не принимает текстовые данные и ожидает числовые. Вам необходимо извлекать функции (фичи) с помощью кастомных классов, которые будут реализовать преобразование текста, использующее возможности BERT.

Шаги для решения

  1. Создание кастомных классов предобработки:
    Для использования BERT вам необходимо создать два класса, которые будут наследовать TransformerMixin и BaseEstimator из библиотеки scikit-learn.

    • BertPreprocessor: Для предварительной обработки текста.
    • BertEncoder: Для кодирования текста в представление на 768 измерений, соответствующее выходам модели BERT.
  2. Переписывание функции build_keras_nn:
    Убедитесь, что входной слой вашей нейронной сети принимает 768 измерений, и уберите все шаги предобработки текстов из этой функции.

  3. Создание пайплайна с использованием Pipeline:
    С помощью Pipeline вы сможете интегрировать этапы предобработки и обучения модели в одну цепочку.

Полный код

import pandas as pd
import numpy as np
import tensorflow_hub as hub 
import tensorflow_text as text
import tensorflow as tf 
from sklearn.pipeline import Pipeline
from sklearn.base import TransformerMixin, BaseEstimator
from tensorflow.keras.layers import Input, Dropout, Dense
from tensorflow.keras.metrics import BinaryAccuracy, AUC
from scikeras.wrappers import KerasClassifier
from sklearn.model_selection import train_test_split

# Загружаем модели BERT
bert_encoder_url = "https://tfhub.dev/tensorflow/bert_en_uncased_L-12_H-768_A-12/4"
bert_preprocessor_url = "https://tfhub.dev/tensorflow/bert_en_uncased_preprocess/3"

bert_preprocessor_model = hub.KerasLayer(bert_preprocessor_url)
bert_encoder_model = hub.KerasLayer(bert_encoder_url)

# Загрузка и разделение данных
df_ = pd.read_json(spam_ham_json)   # spam_ham_json: данные в формате JSON
X_train_, X_test_, y_train_, y_test_ = train_test_split(df_['comment_text'], df_['label'], test_size=0.2)

class BertPreprocessor(TransformerMixin, BaseEstimator):
    def __init__(self, preprocessor_model):
        super().__init__()
        self.preprocessor_model = preprocessor_model

    def fit(self, X, y=None):
        return self 

    def transform(self, X, y=None):
        return self.preprocessor_model(X)

class BertEncoder(TransformerMixin, BaseEstimator):
    def __init__(self, encoder_model):
        super().__init__()
        self.encoder_model = encoder_model

    def fit(self, X, y=None):
        return self

    def transform(self, X, y=None):
        return self.encoder_model(X)['pooled_output']

def build_keras_nn():
    input_layer = Input(shape=768, name="bert_pooled_output")
    dropout = Dropout(0.1, name="dropout")(input_layer)
    classification_output = Dense(1, activation='sigmoid', name="classification_output")(dropout)

    model = tf.keras.Model(inputs=[input_layer], outputs=[classification_output])
    metrics_list = [AUC(name="auc"), BinaryAccuracy(name="accuracy")]
    model.compile(optimizer="adam", loss="binary_crossentropy", metrics=metrics_list)
    return model 

# Создаем пайплайн
pipe = Pipeline([
    ('preprocess', BertPreprocessor(bert_preprocessor_model)),
    ('encode', BertEncoder(bert_encoder_model)),
    ('model', KerasClassifier(build_fn=build_keras_nn, epochs=2)),
])

# Обучаем модель
history = pipe.fit(X_train_, y_train_)

Заключение

Использование KerasClassifier с BERT для классификации текстов – это мощный и эффективный подход, хотя он требует тщательной настройки этапов предобработки и интеграции с библиотекой scikit-learn. В данном руководстве мы рассмотрели важные аспекты, связанные с настройкой модели, и предложили решение для устранения возникших проблем. Теперь вы готовы справиться с задачей бинарной классификации текстов, используя мощь предобученных моделей на базе BERT.

Оцените материал
Добавить комментарий

Капча загружается...