Вопрос или проблема
Я заметил, что трансформер, как правило, оптимизируется для генерации наиболее часто встречающегося символа.
Например, у меня есть следующие входные токены: ["a", "1", "a", "a", "2", "a", "a", "a", "3"]
.
И выход должен быть: ["<sos>", "b", "1", "b", "b", "2", "b", "b", "b", "3", "<eos>"]
.
После обучения с вышеуказанными данными модель просто генерирует самый частый символ: ['<sos>', 'b', 'b', 'b', 'b', 'b', 'b', 'b', 'b', 'b']
.
Как я могу избежать такой проблемы?
Для вашего сведения, вот код, который я использовал:
import torch
import torch.nn as nn
import torch.optim as optim
seed = 3
device = torch.device("cpu")
torch.manual_seed(seed)
if torch.cuda.is_available():
device = torch.device("cuda")
torch.cuda.manual_seed_all(seed)
class OurModel(nn.Module):
def __init__(self, token_to_index):
"""
Конструктор.
:param token_to_index: Мы предполагаем, что индекс=0 представляет токен "<sos>", а индекс=1 представляет токен "<eos>".
"""
super().__init__()
self.token_to_index = token_to_index
self.index_to_token = {index: token for token, index in token_to_index.items()}
self.vocab_size = len(token_to_index)
# Для информации, 512 — это значение по умолчанию для трансформера.
token_vec_dim = 64
self.embeds = nn.Embedding(self.vocab_size, token_vec_dim)
self.transformer = nn.Transformer(
d_model=token_vec_dim,
# `d_model` должен делиться на `nhead`.
# Для информации, значение по умолчанию для `nhead` — 8.
# Также вы получите предупреждения, если установите `nhead` на нечетное число.
nhead=8 if token_vec_dim % 8 == 0 else 2,
batch_first=True
)
self.fc = nn.Sequential(
nn.Linear(token_vec_dim, 30),
nn.ReLU(),
nn.Linear(30, self.vocab_size)
)
def forward(self, sequence, max_iteration=10):
"""
Генерировать последовательность токенов в авторегрессивном режиме.
Генерация заканчивается, когда токен "<eos>" сгенерирован или достигнут максимальный итераций.
Вот как это работает:
Сначала мы преобразуем токены в вектора.
Во-вторых, нам нужны два параметра для использования `nn.Transformer(src, tgt)`.
`src` представляет входные векторы.
Пример: Тензор (векторов) для ["roses", "are"].
`tgt` представляет сгенерированные векторы на данный момент.
Сначала мы используем тензор для ["<sos>"].
Выход `nn.Transformer(src, tgt)` — это тензор такой же формы, что и `tgt`.
Например, если `tgt` был тензором для ["<sos>"], выход будет тензором векторов, например [[1, 2]].
Мы берем [1, 2], помещаем его в полносвязный слой для предсказания следующего токена.
Как только мы получили следующий токен, мы добавляем его в `tgt`.
Например, `tgt` для следующей итерации может быть тензором для ["<sos>", "red"].
Выход трансформера может выглядеть как тензор векторов, например [[1, 2], [3, 4]].
Обратите внимание, что мы берем только последний элемент [3, 4] и игнорируем остальные.
Мы помещаем последний элемент в полносвязный слой для предсказания следующего токена.
Если предсказанный токен — "<eos>", мы прекращаем генерацию.
:param sequence: Список токенов. Пример: ["roses", "are"].
:param max_iteration: Мы прекращаем генерировать токены, если достигнута max_interation, не получив токен "<eos>".
:return: Список токенов. Пример: ["<sos>", "red", "<eos>"].
"""
token_indexes_raw = [self.token_to_index[token] for token in sequence]
# Форма: (количество_токенов).
token_indexes_tensor = torch.LongTensor([index for index in token_indexes_raw]).to(device)
# Форма: (количество_токенов, размер_вектора_токена).
token_vectors = self.embeds(token_indexes_tensor)
# Форма: (1, размер_вектора_токена).
sos_token_vec = self.embeds(torch.LongTensor([0]).to(device))
# Форма: (количество_сгенерированных_токенов_до_сего_времени, размер_вектора_токена).
current_target = sos_token_vec
generated_token_indexes = [0]
iteration = 1
predicted_token_index = -1
while predicted_token_index != 1 and iteration < max_iteration:
# Форма такая же, как у цели.
transformer_output = self.transformer(token_vectors, current_target)
# Форма: (размер_вектора_токена).
last_element = transformer_output[-1]
# Форма: (размер_словаря)
raw_scores = self.fc(last_element)
predicted_token_index = torch.argmax(raw_scores).item()
generated_token_indexes.append(predicted_token_index)
# Форма: (1, размер_вектора_токена).
last_generated_token_vec = self.embeds(torch.LongTensor([predicted_token_index]).to(device))
current_target = torch.cat((current_target, last_generated_token_vec), dim=0).to(device)
iteration += 1
return [self.index_to_token[index] for index in generated_token_indexes]
def fit(self, sources, targets):
"""
Обучить модель.
Вот как это работает:
Сначала мы преобразуем токены в вектора.
Во-вторых, мы получаем выход из `nn.Transformer(sources, targets)`.
Выход трансформера будет батчем сгенерированных токенов.
Обратите внимание, что форма выхода такая же, как форма целей.
Затем мы помещаем выход трансформера в качестве входа для полносвязного слоя.
Полносвязный слой даст батч сырых оценок.
С каждой сырой оценкой мы можем определить предсказанный следующий токен.
Для потерь мы можем сравнить предсказанный следующий токен с фактическим следующим токеном из цели трансформера.
Например, цели были: [["<sos>", "red"], ["<sos>", "green"]].
Выход трансформера был: [[value_1, value_2], [value_3, value_4]].
Выход полносвязного слоя был: [[raw_score_1, raw_score_2], [raw_score_3, raw_score_4]].
raw_score_1 должен был предсказать токен "red".
raw_score_2 должен был предсказать токен "<eos>".
Как видите, последний балл соответствует токену "<eos>".
Вот почему мы исключили токен "<eos>" из целей.
:param sources: Батч источников в простом списке. Пример: [["roses", "are"], ["limes", "are"]].
:param targets: Батч целей в простом списке. Пример: [["red"], ["green"]].
На самом деле, он должен выглядеть так: [["<sos>", "red"], ["<sos>", "green"]].
Тем не менее, этот метод сам добавит "<sos>".
Обратите внимание, что мы исключаем токен "<eos>", потому что после этого токена нечего предсказывать.
:return: None.
"""
loss_function = nn.CrossEntropyLoss()
optimizer = optim.Adam(self.parameters())
for epoch in range(1000):
batch_source_indexes = [[self.token_to_index[token] for token in source] for source in sources]
# Форма: (размер_пакета, количество_источников_токенов_индексов, размер_вектора_токена).
batch_source_vectors = self.embeds(torch.LongTensor(batch_source_indexes).to(device))
# Обратите внимание, что мы добавляем индекс токена "<sos>" в начале для каждого целевого токена.
batch_target_indexes = [[0] + [self.token_to_index[token] for token in target] for target in targets]
# Форма: (размер_пакета, количество_целей_токенов_индексов, размер_вектора_токена).
batch_target_vectors = self.embeds(torch.LongTensor(batch_target_indexes).to(device))
# Форма: (размер_пакета, количество_целей_токенов_индексов, размер_вектора_токена).
transformer_output = self.transformer(batch_source_vectors, batch_target_vectors)
# Форма: (размер_пакета, количество_целей_токенов_индексов, размер_словаря).
batch_raw_scores = self.fc(transformer_output)
# Обратите внимание, что мы добавляем индекс токена "<eos>" в конце для каждого целевого токена.
batch_correct_token_indexes = [[self.token_to_index[token] for token in target] + [1] for target in targets]
loss = 0
for batch_index in range(len(batch_raw_scores)):
# Форма: (количество_целевых_токенов, размер_словаря).
raw_scores = batch_raw_scores[batch_index]
# Форма: (количество_целевых_токенов).
correct_token_indexes = torch.LongTensor(batch_correct_token_indexes[batch_index]).to(device)
loss += loss_function(raw_scores, correct_token_indexes)
print(f"epoch: {epoch}, loss: {loss}")
if loss < 0.001:
break
loss.backward()
optimizer.step()
optimizer.zero_grad()
def main():
token_to_index = {
# sos означает начало последовательности.
"<sos>": 0,
# eos означает конец последовательности.
"<eos>": 1,
"a": 2,
"b": 3,
"1": 4,
"2": 5,
"3": 6
}
model = OurModel(token_to_index).to(device)
train_sources = [
["a", "1", "a", "a", "2", "a", "a", "a", "3"],
]
train_targets = [
["b", "1", "b", "b", "2", "b", "b", "b", "3"],
]
model.fit(train_sources, train_targets)
input_tokens = ["a", "a", "2"]
generated_tokens = model(input_tokens)
print(f"input_tokens: {input_tokens}, generated_tokens: {generated_tokens}")
main()
Ответ или решение
Чтобы решить вопрос о том, почему ваша модель трансформера воспроизводит наиболее часто встречающийся токен и как избежать этого поведения, необходимо рассмотреть несколько ключевых аспектов активного обучения моделью и настройки ее параметров.
Почему модель "спамит"?
-
Смещение частоты символов: Если у вас в обучающем наборе данных преобладает какой-либо токен (например,
b
), модель может "закрепиться" за ним, так как он оптимизирует функцию потерь. Когда модель обучается, она скорее будет выбирать самые вероятные токены, что в вашем случае приводит к постоянному выборуb
. -
Отсутствие разнообразия в обучении: Из-за однообразия входных и выходных данных, модель не изучает разнообразные способы генерации токенов, что приводит к тому, что она начинает "запоминать" чаще встречающиеся токены.
Как решить проблему?
1. Изменение функции потерь:
Используйте модифицированные функции потерь, такие как label smoothing. Эта техника предполагает небольшое размывание вероятностей правильных классов, что позволяет избежать переобучения на частоте токенов.
2. Увеличение данных:
Попробуйте увеличить набор данных, добавив больше примеров, где токены разбираются более разнообразно. Это поможет модели вспомнить о других возможных токенах и не сосредоточиться только на наиболее вероятных.
3. Использование температурного сэмплирования:
Добавьте temperature sampling, что позволяет регулировать степень, с которой выбираются вероятные токены. Высокая температура приведёт к более разнообразному выбору токенов, а низкая – к более консервативному подходу.
Например, вы можете изменить выбор токена следующим образом:
def sample_with_temperature(raw_scores, temperature=1.0):
probabilities = torch.softmax(raw_scores / temperature, dim=-1)
return torch.multinomial(probabilities, num_samples=1).item()
- Аугментация данных:
Применение методов аугментации данных, таких как перетасовка токенов в последовательности или генерация синтетических данных, может увеличить разнообразие тренировочного набора.
4. Использование механизма отбора:
В процессе генерации токенов помимо использования argmax
, можно использовать семплирование из распределения вероятностей, что увеличит вероятность выбора менее частых токенов.
Пример изменения кода
Вот как вы могли бы изменить часть вашего кода, чтобы использовать семплирование с температурой:
def forward(self, sequence, max_iteration=10, temperature=1.0):
# Существующий код остается без изменения
while predicted_token_index != 1 and iteration < max_iteration:
transformer_output = self.transformer(token_vectors, current_target)
last_element = transformer_output[-1]
raw_scores = self.fc(last_element)
# Используем температурное сэмплирование
predicted_token_index = sample_with_temperature(raw_scores, temperature)
generated_token_indexes.append(predicted_token_index)
last_generated_token_vec = self.embeds(torch.LongTensor([predicted_token_index]).to(device))
current_target = torch.cat((current_target, last_generated_token_vec), dim=0).to(device)
iteration += 1
return [self.index_to_token[index] for index in generated_token_indexes]
Заключение
Изменяя подход к обучению и генерации токенов, можно значительно улучшить качество работы модели трансформера и избежать "спама" наиболее частых токенов. Экспериментируйте с различными методами, чтобы достичь наилучшего результата.