transformer_model = transformer_model.to(device)

# Ignore padding positions when computing the loss (assuming 0 is the
# padding index, consistent with create_padding_mask below).
loss_fn = nn.CrossEntropyLoss(ignore_index=0)
optimizer = torch.optim.Adam(transformer_model.parameters(), lr=0.001)


def generate_square_subsequent_mask(sz):
    # Boolean causal mask: True above the diagonal marks the future
    # positions that attention is not allowed to attend to.
    mask = torch.triu(torch.ones(sz, sz, device=device), diagonal=1).bool()
    return mask


def create_padding_mask(seq):
    # seq is batch-first (batch_size, seq_len), and key-padding masks are
    # expected in that same (batch_size, seq_len) shape, so no transpose
    # is needed. True marks padding positions to be ignored.
    mask = (seq == 0)
    return mask
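
# Quick sanity check (a minimal sketch, not part of the training run): for a
# toy 2-sequence batch padded with 0s, the causal mask comes out
# (seq_len, seq_len) and the padding mask (batch_size, seq_len), matching
# what nn.Transformer expects for attn_mask and key_padding_mask.
_toy = torch.tensor([[5, 7, 2, 0], [4, 9, 0, 0]], device=device)
print(generate_square_subsequent_mask(_toy.size(1)).shape)  # torch.Size([4, 4])
print(create_padding_mask(_toy).shape)                      # torch.Size([2, 4])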

num_epochs = 10
for epoch in range(num_epochs):
    transformer_model.train()
    epoch_loss = 0

    for batch_idx, batch in enumerate(train_data_loader):
        # batch is batch-first: (batch_size, seq_len). The model sees the
        # sequence without its last token and is trained to predict the
        # sequence shifted left by one, so the targets drop the first token.
        src_batch = batch[:, :-1].to(device)
        tgt_input = batch[:, :-1].to(device)
        targets = batch[:, 1:].contiguous().view(-1).to(device)

        optimizer.zero_grad()

        # Sequence length is dimension 1 of a batch-first tensor
        # (dimension 0 is the batch size).
        src_seq_len = src_batch.size(1)
        tgt_seq_len = tgt_input.size(1)
        src_mask = generate_square_subsequent_mask(src_seq_len)
        tgt_mask = generate_square_subsequent_mask(tgt_seq_len)
        src_padding_mask = create_padding_mask(src_batch)
        tgt_padding_mask = create_padding_mask(tgt_input)

        # The transposes switch the inputs to the (seq_len, batch) layout
        # nn.Transformer expects by default. The masks are passed by keyword:
        # forward() takes memory_mask as its fifth positional argument, so
        # passing the padding masks positionally would land them in the
        # wrong parameters (assuming transformer_model mirrors the
        # nn.Transformer signature).
        output = transformer_model(
            src_batch.transpose(0, 1), tgt_input.transpose(0, 1),
            src_mask=src_mask, tgt_mask=tgt_mask,
            src_key_padding_mask=src_padding_mask,
            tgt_key_padding_mask=tgt_padding_mask,
            memory_key_padding_mask=src_padding_mask,
        )

        # The model output is (seq_len, batch_size, vocab_size); move back
        # to batch-first before flattening so each row lines up with the
        # batch-first targets computed above.
        output = output.transpose(0, 1).contiguous().view(-1, vocab_size)
        loss = loss_fn(output, targets)

        loss.backward()
        optimizer.step()

        epoch_loss += loss.item()

        if batch_idx % 100 == 0:
            print(f"Epoch {epoch+1}/{num_epochs}, Batch {batch_idx}, "
                  f"Avg Loss: {epoch_loss / (batch_idx+1):.4f}")

    print(f"Epoch {epoch+1}/{num_epochs} finished, "
          f"Average Loss: {epoch_loss / len(train_data_loader):.4f}")

# Save the trained weights and the vocabulary so they can be reloaded later.
torch.save(transformer_model.state_dict(), 'transformer_model.pth')

import json

with open('vocabulary.json', 'w') as vocab_file:
    json.dump(dataset.vocab, vocab_file)
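
# Later, the trained weights and vocabulary can be reloaded along these
# lines (a minimal sketch, assuming a model object with the same
# architecture has already been constructed):
#
#     transformer_model.load_state_dict(
#         torch.load('transformer_model.pth', map_location=device))
#     transformer_model.eval()
#     with open('vocabulary.json') as vocab_file:
#         vocab = json.load(vocab_file)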