| | import os |
| | import json |
| | import torch |
| | from torch import nn, optim |
| | from pathlib import Path |
| | from torch.utils.data import DataLoader |
| | from services.word_generation_dataset import WordGenDataset, collate_fn |
| | from services.tokenizer import tokenize_dataset, build_vocab |
| | from services.transformer import TinyTransformer |
| | from services.model import save_model |
| |
|
| | |
# --- Preprocessing: load raw definitions, tokenize, and build the dataloader ---
print("Preprocessing...")

# Locate the data file relative to this script. `Path` is already
# imported at the top of the file but was unused; it replaces the
# os.path.dirname/os.path.join dance and also behaves correctly when
# __file__ has an empty dirname (script run from its own directory).
file_path = Path(__file__).parent / "data" / "definitions-2.json"
with open(file_path, "r", encoding="utf-8") as f:
    data = json.load(f)
print("Data loaded!")

# Convert the raw records into parallel token sequences.
# NOTE(review): assumes tokenize_dataset returns (inputs, outputs)
# aligned pairs — confirm against services.tokenizer.
inputs, outputs = tokenize_dataset(data)
print("Data Tokenized!")

# vocab: token -> id; inv_vocab: id -> token (kept for later decoding).
vocab, inv_vocab = build_vocab(inputs, outputs)
print("Vocabulary built!")

# Wrap the tokenized pairs in a Dataset; collate_fn is presumably what
# pads variable-length sequences into uniform batches — verify in
# services.word_generation_dataset.
dataset = WordGenDataset(inputs, outputs, vocab)
print("Dataset built!")
dataloader = DataLoader(dataset, batch_size=32, shuffle=True, collate_fn=collate_fn)
print("Dataloader built!")
| |
|
| | |
# --- Model setup ---
print("Training...")

# Pick the best available device. The original only checked MPS and
# silently fell back to CPU on CUDA machines; CUDA is added as a
# backward-compatible generalization (MPS machines behave as before).
if torch.backends.mps.is_available():
    device = torch.device("mps")
elif torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

vocab_size = len(vocab)
model = TinyTransformer(vocab_size=vocab_size).to(device)

# ignore_index keeps padding positions out of the loss, so the model is
# neither penalized for nor taught to emit <pad> tokens.
criterion = nn.CrossEntropyLoss(ignore_index=vocab["<pad>"])
optimizer = optim.Adam(model.parameters(), lr=1e-4)
| |
|
| | |
# --- Training loop: teacher-forced next-token prediction ---
num_epochs = 10

for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0

    for src, tgt in dataloader:
        src = src.to(device)
        tgt = tgt.to(device)

        # Teacher forcing: decoder consumes tokens [0..T-2] and is
        # scored on predicting tokens [1..T-1].
        decoder_in = tgt[:, :-1]
        decoder_target = tgt[:, 1:]

        logits = model(src, decoder_in)

        # Flatten (batch, seq, vocab) -> (batch*seq, vocab) for CE loss.
        loss = criterion(
            logits.reshape(-1, vocab_size),
            decoder_target.reshape(-1),
        )

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        running_loss += loss.item()

    print(f"Epoch {epoch+1}/{num_epochs} - Loss: {running_loss / len(dataloader):.4f}")
| |
|
| | |
# Persist the trained weights together with the vocabulary so the model
# can be reloaded for inference (see services.model.save_model).
print("Saving...")
save_model(model, vocab)
print("Finished!")
| |
|