Вопрос или проблема
Недавно я изучал модель Transformer, используя реализацию на Pytorch, но моя модель не сходилась. Я задумался, была ли проблема в моем коде или в чем-то еще. Поэтому я подумал, что если я “упрощу” задачу для обучения, то ее будет легче обучать, и создал предложение и повторил его несколько раз (чтобы заставить модель переобучиться), но моя модель все равно не сходилась. Так что мой вопрос заключается в том, есть ли какая-то ошибка в моей реализации/цикле обучения или переобучение не так тривиально?
Модель:
class PositionalEncoding(nn.Module):
def __init__(self, d_model, dropout=0.1, max_len=5000):
super(PositionalEncoding, self).__init__()
self.dropout = nn.Dropout(p=dropout)
pe = torch.zeros(max_len, d_model)
position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-np.log(10000.0) / d_model))
pe[:, 0::2] = torch.sin(position * div_term)
pe[:, 1::2] = torch.cos(position * div_term)
pe = pe.unsqueeze(0).transpose(0, 1)
self.register_buffer('pe', pe)
def forward(self, x):
x = x + self.pe[:x.size(0), :]
return self.dropout(x)
class TransformerModel(nn.Module):
def __init__(self, d_model, max_seq_len, enc_vocab_size, dec_vocab_size):
super(TransformerModel, self).__init__()
self.transformer = nn.Transformer(d_model=d_model, nhead=8, num_encoder_layers=4, num_decoder_layers=8, dim_feedforward=2048, batch_first=True)
self.d_model = d_model
self.enc_vocab_size = enc_vocab_size
self.dec_vocab_size = dec_vocab_size
self.encoder_emb = nn.Embedding(enc_vocab_size, d_model)
self.decoder_emb = nn.Embedding(dec_vocab_size, d_model)
self.pe = PositionalEncoding(d_model)
self.ol = nn.Linear(d_model, dec_vocab_size)
def forward(self, enc_input, enc_mask, dec_input, dec_mask, causal_mask):
enc_input = self.encoder_emb(enc_input)
enc_input = self.pe(enc_input)
dec_input = self.decoder_emb(dec_input)
dec_input = self.pe(dec_input)
out = self.transformer(enc_input, dec_input, tgt_mask=causal_mask, src_key_padding_mask=enc_mask, tgt_key_padding_mask=dec_mask)
return self.ol(out)
Цикл обучения:
en_vocab_size = len(tokenizer_en.get_vocab())
pt_vocab_size = len(tokenizer_pt.get_vocab())
seq_len = 40
model = TransformerModel(d_model=256, max_seq_len=seq_len, enc_vocab_size=en_vocab_size, dec_vocab_size=pt_vocab_size)
model.to(device)
causal_mask = torch.triu(torch.ones((seq_len, seq_len)), diagonal=1) * -10**9
causal_mask = causal_mask.to(device)
opt = torch.optim.Adam(model.parameters(), lr=10e-3)
loss_fn = nn.CrossEntropyLoss()
model.train()
def get_input_and_mask(sequence_batch, tokenizer, device):
sequence_batch = tokenizer.encode_batch(sequence_batch)
mask = torch.tensor([e.attention_mask for e in sequence_batch], dtype=bool, device=device)
sequence_batch = torch.tensor([e.ids for e in sequence_batch], device=device)
return sequence_batch, mask
num_batches = len(train_data)
print(f"num_batches = {num_batches}")
for epoch in range(100):
print(f" EPOCH {epoch} ".center(80, '-'))
for iteration, (enc, dec) in enumerate(train_data):
enc_input, enc_mask = get_input_and_mask(enc, tokenizer_en, device)
dec_input, dec_mask = get_input_and_mask(dec, tokenizer_pt, device)
y = model(enc_input, enc_mask, dec_input, dec_mask, causal_mask)
opt.zero_grad()
loss = loss_fn(y.view(-1, pt_vocab_size), dec_input.view(-1))
loss.backward()
Набор данных:
Input> "Porque semeiam vento, colherão tempestade; não haverá seara, pois o talo não produzirá cereal; o pouco que der, os estrangeiros o devorarão."
Target> "For they have sown the wind, and they shall reap the whirlwind: it hath no stalk: the bud shall yield no meal: if so be it yield, the strangers shall swallow it up."
Я повторил это предложение 20 тысяч раз.
Ответ или решение
Как заставить модель переобучиться: детали и рекомендации
Переобучение модели — это процесс, при котором она начинает лучше всего представлять обучающие данные, но теряет способность обобщать на невиданных данных. Это может быть полезно в некоторых случаях, когда нужно выявить ошибки в коде или настройках модели. Однако важным является понимание методологии переобучения и определение причин, по которым ваша модель не сходится, даже при попытках переобучения.
Анализ вашего подхода
Вы уже предприняли шаги к созданию примитивного набора данных для переобучения, путем многократного повторения одной и той же пары фраз. Тем не менее, это не всегда гарантирует переобучение, особенно учитывая, что ваше решение не достигло желаемого результата. Давайте рассмотрим несколько аспектов, которые могли повлиять на этот процесс.
1. Проверка архитектуры модели
Ваша модель TransformerModel
выглядит правильно, и вы используете nn.Transformer
с необходимыми параметрами. Однако важно убедиться, что:
- Размерность входных данных соответствует ожидаемому формату.
- Параметры модели, например, размер модели (
d_model
), соответствуют сложности задачи.
Переобучение также может произойти из-за:
- Слишком больших или слишком маленьких значений параметров и гиперпараметров, таких как скорость обучения.
2. Гиперпараметры
Ваша скорость обучения (lr=10e-3
) кажется достаточно высокой для большинства задач. Рекомендуется:
- Попробовать уменьшить скорость обучения (например, 1e-4 или 1e-5) и посмотреть, приведет ли это к большей стабильности во время обучения.
- Также стоит поэкспериментировать с разными оптимизаторами (например, AdamW, RMSprop) и настройками.
3. Потеря и обработка данных
Вы используете функцию потерь nn.CrossEntropyLoss()
, что является стандартным выбором для задач классификации. Но:
- Убедитесь, что порядок данных соответствует целям модели.
- Проверьте, корректно ли подготовлены входные и масочные тензоры перед подачей в модель.
4. Изучение тщательности подготовки данных
Вы упомянули, что повторяли одно предложение 20,000 раз. Это может быть недостаточно для создания полезного многообразия данных. Даже для переобучения важно понимать:
- Модель может не переобучиться на одном предложении, если оно слишком короткое.
- Рассмотрите возможность добавления резкого искажения или шума в данные для создания "грустных" примеров.
5. Логика построения цикла обучения
Ваш цикл обучения в целом выглядит корректно, но там может быть множество деталей, на которые следует обратить внимание:
- Проверьте, вызывается ли
opt.step()
в конце каждой итерации, чтобы обновить веса модели. - Оцените, выводится ли потеря каждый раз и сохраняется ли очень высокая потеря или колебания.
Заключение
Переобучение модели — это более сложный процесс, чем просто продублирование данных. Необходимо учесть множество факторов, включая архитектуру модели, гиперпараметры, обработку данных и логику цикла обучения. Если и после этих проверок проблема останется актуальной, возможно, имеет смысл более детально проанализировать код, использованный для обработки данных, или переосмыслить подход к архитектуре модели.