Трансформер спамит самый частый символ.

Вопрос или проблема

Я заметил, что трансформер, как правило, оптимизируется для генерации наиболее часто встречающегося символа.

Например, у меня есть следующие входные токены: ["a", "1", "a", "a", "2", "a", "a", "a", "3"].

И выход должен быть: ["<sos>", "b", "1", "b", "b", "2", "b", "b", "b", "3", "<eos>"].

После обучения с вышеуказанными данными модель просто генерирует самый частый символ: ['<sos>', 'b', 'b', 'b', 'b', 'b', 'b', 'b', 'b', 'b'].

Как я могу избежать такой проблемы?

Для вашего сведения, вот код, который я использовал:

import torch
import torch.nn as nn
import torch.optim as optim

seed = 3
device = torch.device("cpu")
torch.manual_seed(seed)
if torch.cuda.is_available():
    device = torch.device("cuda")
    torch.cuda.manual_seed_all(seed)

class OurModel(nn.Module):

    def __init__(self, token_to_index):
        """
        Конструктор.
        :param token_to_index: Мы предполагаем, что индекс=0 представляет токен "<sos>", а индекс=1 представляет токен "<eos>".
        """
        super().__init__()
        self.token_to_index = token_to_index
        self.index_to_token = {index: token for token, index in token_to_index.items()}
        self.vocab_size = len(token_to_index)
        # Для информации, 512 — это значение по умолчанию для трансформера.
        token_vec_dim = 64
        self.embeds = nn.Embedding(self.vocab_size, token_vec_dim)
        self.transformer = nn.Transformer(
            d_model=token_vec_dim,
            # `d_model` должен делиться на `nhead`.
            # Для информации, значение по умолчанию для `nhead` — 8.
            # Также вы получите предупреждения, если установите `nhead` на нечетное число.
            nhead=8 if token_vec_dim % 8 == 0 else 2,
            batch_first=True
        )
        self.fc = nn.Sequential(
            nn.Linear(token_vec_dim, 30),
            nn.ReLU(),
            nn.Linear(30, self.vocab_size)
        )

    def forward(self, sequence, max_iteration=10):
        """
        Генерировать последовательность токенов в авторегрессивном режиме.
        Генерация заканчивается, когда токен "<eos>" сгенерирован или достигнут максимальный итераций.
        Вот как это работает:
        Сначала мы преобразуем токены в вектора.
        Во-вторых, нам нужны два параметра для использования `nn.Transformer(src, tgt)`.
        `src` представляет входные векторы.
        Пример: Тензор (векторов) для ["roses", "are"].
        `tgt` представляет сгенерированные векторы на данный момент.
        Сначала мы используем тензор для ["<sos>"].
        Выход `nn.Transformer(src, tgt)` — это тензор такой же формы, что и `tgt`.
        Например, если `tgt` был тензором для ["<sos>"], выход будет тензором векторов, например [[1, 2]].
        Мы берем [1, 2], помещаем его в полносвязный слой для предсказания следующего токена.
        Как только мы получили следующий токен, мы добавляем его в `tgt`.
        Например, `tgt` для следующей итерации может быть тензором для ["<sos>", "red"].
        Выход трансформера может выглядеть как тензор векторов, например [[1, 2], [3, 4]].
        Обратите внимание, что мы берем только последний элемент [3, 4] и игнорируем остальные.
        Мы помещаем последний элемент в полносвязный слой для предсказания следующего токена.
        Если предсказанный токен — "<eos>", мы прекращаем генерацию.
        :param sequence: Список токенов. Пример: ["roses", "are"].
        :param max_iteration: Мы прекращаем генерировать токены, если достигнута max_interation, не получив токен "<eos>".
        :return: Список токенов. Пример: ["<sos>", "red", "<eos>"].
        """
        token_indexes_raw = [self.token_to_index[token] for token in sequence]
        # Форма: (количество_токенов).
        token_indexes_tensor = torch.LongTensor([index for index in token_indexes_raw]).to(device)
        # Форма: (количество_токенов, размер_вектора_токена).
        token_vectors = self.embeds(token_indexes_tensor)
        # Форма: (1, размер_вектора_токена).
        sos_token_vec = self.embeds(torch.LongTensor([0]).to(device))
        # Форма: (количество_сгенерированных_токенов_до_сего_времени, размер_вектора_токена).
        current_target = sos_token_vec
        generated_token_indexes = [0]
        iteration = 1
        predicted_token_index = -1
        while predicted_token_index != 1 and iteration < max_iteration:
            # Форма такая же, как у цели.
            transformer_output = self.transformer(token_vectors, current_target)
            # Форма: (размер_вектора_токена).
            last_element = transformer_output[-1]
            # Форма: (размер_словаря)
            raw_scores = self.fc(last_element)
            predicted_token_index = torch.argmax(raw_scores).item()
            generated_token_indexes.append(predicted_token_index)
            # Форма: (1, размер_вектора_токена).
            last_generated_token_vec = self.embeds(torch.LongTensor([predicted_token_index]).to(device))
            current_target = torch.cat((current_target, last_generated_token_vec), dim=0).to(device)
            iteration += 1
        return [self.index_to_token[index] for index in generated_token_indexes]

    def fit(self, sources, targets):
        """
        Обучить модель.
        Вот как это работает:
        Сначала мы преобразуем токены в вектора.
        Во-вторых, мы получаем выход из `nn.Transformer(sources, targets)`.
        Выход трансформера будет батчем сгенерированных токенов.
        Обратите внимание, что форма выхода такая же, как форма целей.
        Затем мы помещаем выход трансформера в качестве входа для полносвязного слоя.
        Полносвязный слой даст батч сырых оценок.
        С каждой сырой оценкой мы можем определить предсказанный следующий токен.
        Для потерь мы можем сравнить предсказанный следующий токен с фактическим следующим токеном из цели трансформера.
        Например, цели были: [["<sos>", "red"], ["<sos>", "green"]].
        Выход трансформера был: [[value_1, value_2], [value_3, value_4]].
        Выход полносвязного слоя был: [[raw_score_1, raw_score_2], [raw_score_3, raw_score_4]].
        raw_score_1 должен был предсказать токен "red".
        raw_score_2 должен был предсказать токен "<eos>".
        Как видите, последний балл соответствует токену "<eos>".
        Вот почему мы исключили токен "<eos>" из целей.
        :param sources: Батч источников в простом списке. Пример: [["roses", "are"], ["limes", "are"]].
        :param targets: Батч целей в простом списке. Пример: [["red"], ["green"]].
                        На самом деле, он должен выглядеть так: [["<sos>", "red"], ["<sos>", "green"]].
                        Тем не менее, этот метод сам добавит "<sos>".
                        Обратите внимание, что мы исключаем токен "<eos>", потому что после этого токена нечего предсказывать.
        :return: None.
        """
        loss_function = nn.CrossEntropyLoss()
        optimizer = optim.Adam(self.parameters())
        for epoch in range(1000):
            batch_source_indexes = [[self.token_to_index[token] for token in source] for source in sources]
            # Форма: (размер_пакета, количество_источников_токенов_индексов, размер_вектора_токена).
            batch_source_vectors = self.embeds(torch.LongTensor(batch_source_indexes).to(device))
            # Обратите внимание, что мы добавляем индекс токена "<sos>" в начале для каждого целевого токена.
            batch_target_indexes = [[0] + [self.token_to_index[token] for token in target] for target in targets]
            # Форма: (размер_пакета, количество_целей_токенов_индексов, размер_вектора_токена).
            batch_target_vectors = self.embeds(torch.LongTensor(batch_target_indexes).to(device))
            # Форма: (размер_пакета, количество_целей_токенов_индексов, размер_вектора_токена).
            transformer_output = self.transformer(batch_source_vectors, batch_target_vectors)
            # Форма: (размер_пакета, количество_целей_токенов_индексов, размер_словаря).
            batch_raw_scores = self.fc(transformer_output)
            # Обратите внимание, что мы добавляем индекс токена "<eos>" в конце для каждого целевого токена.
            batch_correct_token_indexes = [[self.token_to_index[token] for token in target] + [1] for target in targets]
            loss = 0
            for batch_index in range(len(batch_raw_scores)):
                # Форма: (количество_целевых_токенов, размер_словаря).
                raw_scores = batch_raw_scores[batch_index]
                # Форма: (количество_целевых_токенов).
                correct_token_indexes = torch.LongTensor(batch_correct_token_indexes[batch_index]).to(device)
                loss += loss_function(raw_scores, correct_token_indexes)
            print(f"epoch: {epoch}, loss: {loss}")
            if loss < 0.001:
                break
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()

def main():
    token_to_index = {
        # sos означает начало последовательности.
        "<sos>": 0,
        # eos означает конец последовательности.
        "<eos>": 1,
        "a": 2,
        "b": 3,
        "1": 4,
        "2": 5,
        "3": 6
    }

    model = OurModel(token_to_index).to(device)

    train_sources = [
        ["a", "1", "a", "a", "2", "a", "a", "a", "3"],
    ]
    train_targets = [
        ["b", "1", "b", "b", "2", "b", "b", "b", "3"],
    ]
    model.fit(train_sources, train_targets)

    input_tokens = ["a", "a", "2"]
    generated_tokens = model(input_tokens)
    print(f"input_tokens: {input_tokens}, generated_tokens: {generated_tokens}")

main()

Ответ или решение

Чтобы решить вопрос о том, почему ваша модель трансформера воспроизводит наиболее часто встречающийся токен и как избежать этого поведения, необходимо рассмотреть несколько ключевых аспектов активного обучения моделью и настройки ее параметров.

Почему модель "спамит"?

  1. Смещение частоты символов: Если у вас в обучающем наборе данных преобладает какой-либо токен (например, b), модель может "закрепиться" за ним, так как он оптимизирует функцию потерь. Когда модель обучается, она скорее будет выбирать самые вероятные токены, что в вашем случае приводит к постоянному выбору b.

  2. Отсутствие разнообразия в обучении: Из-за однообразия входных и выходных данных, модель не изучает разнообразные способы генерации токенов, что приводит к тому, что она начинает "запоминать" чаще встречающиеся токены.

Как решить проблему?

1. Изменение функции потерь:

Используйте модифицированные функции потерь, такие как label smoothing. Эта техника предполагает небольшое размывание вероятностей правильных классов, что позволяет избежать переобучения на частоте токенов.

2. Увеличение данных:

Попробуйте увеличить набор данных, добавив больше примеров, где токены разбираются более разнообразно. Это поможет модели вспомнить о других возможных токенах и не сосредоточиться только на наиболее вероятных.

3. Использование температурного сэмплирования:

Добавьте temperature sampling, что позволяет регулировать степень, с которой выбираются вероятные токены. Высокая температура приведёт к более разнообразному выбору токенов, а низкая – к более консервативному подходу.

Например, вы можете изменить выбор токена следующим образом:

   def sample_with_temperature(raw_scores, temperature=1.0):
       probabilities = torch.softmax(raw_scores / temperature, dim=-1)
       return torch.multinomial(probabilities, num_samples=1).item()
  1. Аугментация данных:
    Применение методов аугментации данных, таких как перетасовка токенов в последовательности или генерация синтетических данных, может увеличить разнообразие тренировочного набора.

4. Использование механизма отбора:

В процессе генерации токенов помимо использования argmax, можно использовать семплирование из распределения вероятностей, что увеличит вероятность выбора менее частых токенов.

Пример изменения кода

Вот как вы могли бы изменить часть вашего кода, чтобы использовать семплирование с температурой:

def forward(self, sequence, max_iteration=10, temperature=1.0):
    # Существующий код остается без изменения
    while predicted_token_index != 1 and iteration < max_iteration:
        transformer_output = self.transformer(token_vectors, current_target)
        last_element = transformer_output[-1]
        raw_scores = self.fc(last_element)

        # Используем температурное сэмплирование
        predicted_token_index = sample_with_temperature(raw_scores, temperature)

        generated_token_indexes.append(predicted_token_index)
        last_generated_token_vec = self.embeds(torch.LongTensor([predicted_token_index]).to(device))
        current_target = torch.cat((current_target, last_generated_token_vec), dim=0).to(device)
        iteration += 1
    return [self.index_to_token[index] for index in generated_token_indexes]

Заключение

Изменяя подход к обучению и генерации токенов, можно значительно улучшить качество работы модели трансформера и избежать "спама" наиболее частых токенов. Экспериментируйте с различными методами, чтобы достичь наилучшего результата.

Оцените материал
Добавить комментарий

Капча загружается...