Как модельRetriever (кодировщик запросов) обучается от начала до конца в рамках увеличенного поиска для генерации (RAG)?

Вопрос или проблема

RAG architecture

Архитектура RAG из оригинальной статьи

Поскольку потеря рассчитывается на выходном слое генератора, как градиенты обратным распространением передаются в модель извлечения?

Потому что вводом для генератора является чистый текст, то есть текст извлеченного документа + вопрос.

Указано, что вся архитектура настраивается в режиме end-to-end. Обратите внимание на рисунок “end-to-end backprop through q”, где q — это кодировщик запроса. Мой вопрос: как ошибка обратно пропагируется к q? Потому что при расчете потерь на выходе генератора модель кодировщика запроса не играет никакой роли.

Статья ясно заявила, что

Извлекатель (DPR) предоставляет латентные документы, основанные на
входных данных, а модель seq2seq (BART [32]) затем основывается на этих
латентных документах
вместе с входными данными для генерации выхода.

Таким образом, было очевидно, что генератор принимает текст (идентификаторы ввода) в качестве входных данных. Но тогда возникает вопрос: “как градиенты распространяются до модели кодировщика запроса?”

Мне пришлось просмотреть исходный код проекта RAG на huggingface, чтобы найти ответ на этот вопрос.

Я начал с класса RagModel. Изучив исходный код этого класса, я заметил, что модель генератора действительно использовала input_ids в функции forward.

Чтобы найти, как градиенты обратно пропагируются до кодировщика запроса, нам нужно выяснить, как рассчитываются потери. Мы видим, что используется варьируемая версия потерь nll (строка 864).

loss_snippet

Поскольку кодировщик запроса участвует только благодаря doc_scores, мы будем следовать за ним. Проанализировав функции (get_nll –> marginalize), мы обнаружили, что log-softmax применяется к (a) логиту последовательности генератора и (b) doc_scores.

loss_marginalization

Затем обе эти величины складываются и передаются через трюк логармированный суммы экспонент. Почему лог-сумма-экспонена? Я думаю, это потому, что это будет эквивалентно суммированию вероятностей ($\sum_{i}{e^{log(p_i)}} = \sum_{i}{p_i}$) и взятию логарифма (для численной стабильности). Таким образом, градиенты обратно пропагируются до модели кодировщика запроса.

Поскольку у модели кодировщика запроса нет истинных значений, потери извлечения просто вычисляются как softmax от doc_scores. Мне кажется, что когда эта потеря извлечения добавляется к LM-потере (seq_logits), то результирующая потеря в основном представляет собой следующую философию:

“если окончательный сгенерированный ответ достаточно хорош (низкая LM потеря), то это
указывает на то, что извлеченные документы также были достаточно хороши. Следовательно,
хорошая репрезентация запроса! Нет необходимости в больших изменениях в кодировщике запроса.
Аналогично, большая LM потеря –> плохие извлеченные документы –> нужны большие изменения
(обновления параметров) в кодировщике запроса.”

Ответ или решение

Вопрос о том, как модель извлечения (Query Encoder) обучается в конечной системе Retrieval Augmented Generation (RAG) через обратное распространение градиентов, представляет собой интересную задачу, помимо самой архитектуры и процессов, которые происходят в ней. Давайте разберем ключевые аспекты этого процесса.

Архитектура RAG

Архитектура RAG, как указано в оригинальной статье, включает два ключевых компонента: модель извлечения (например, Dense Passage Retrieval – DPR) и генератор (например, BART). Модель извлечения предоставляет набор документов, соответствующих заданному запросу, на основе которого генератор затем формирует ответ.

Обратное распространение градиентов

При обучении сети в RAG важным моментом является способ расчета потерь (loss) и того, как градиенты успевают вернуться к модели извлечения. Потеря рассчитывается на выходном слое генератора, однако модель извлечения (Query Encoder) также участвует в этом процессе. По сути, родственные комитеты RAG утверждают, что архитектура обучается end-to-end, что подразумевает, что все компоненты испытывают сложный процесс обратного распространения.

Расчет потери

Как вы правильно заметили, модель извлечения вносит свой вклад в расчет потери через doc_scores, которые зависят от выходных данных модели извлечения. При помощи функции get_nll и marginalize создается лог-функция потерь, которая сочетает активность генератора и выходные документы, предоставленные моделью извлечения.

  • log-softmax — эта функция применяется как к логитам генератора, так и к doc_scores. Это позволяет создать совместное распределение, акцентируя внимание на том, как качество документов влияет на качество генерации.

  • Затем используется метод log-sum-exp, который обеспечивает числовую стабильность и позволяет суммировать вероятности, что критически важно для оптимизации.

Как градиенты достигают Query Encoder

Градиенты достигают Query Encoder благодаря тому, что как потери генератора, так и потери извлечения (от softmax doc_scores) требуют обратного распространения. Объединив эти два аспекта, мы способны передать информацию обратно в Query Encoder.

Объяснение через философию потери

Важно отметить, что модель извлечения не имеет прямой "истинной" метки, но мы можем выражать потери через doc_scores. Это производит интуитивное представление, где низкая потеря генерации (LM loss) указывает на то, что документы были уместными. В то время как высокая потеря указывает на необходимость значительного изменения в Query Encoder, подчеркивая, что извлеченные документы нуждаются в улучшении.

Так, конечная утрированная интерпретация потерь подсказывает, что "если сгенерированный ответ хорош (низкая LM loss), то документы соответствуют запросу, следовательно, вероятность успешного паттерна запроса велика".

Заключение

Таким образом, модель извлечения в системе RAG действительно обучается конечным образом, так как градиенты от функции потерь генератора в конечном итоге возвращаются к Query Encoder, обогатив процесс обучения. Этот механизм позволяет как улучшать качество извлечения, так и качество генерации, что делает подход RAG особенно эффективным для задач, связанных с извлечением информации и генерацией текста.

Эффективность этого процесса подчеркивается необходимостью интеграции информации из различных источников, что и является сердцем мощного инструмента RAG.

Оцените материал
Добавить комментарий

Капча загружается...