File size: 1,917 Bytes
b54146b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from tqdm import tqdm
import os

from dataset import MNIST_2x2
from model.model import ImageToDigitTransformer

# Pick an accelerator: MPS (Apple Silicon) first, then CUDA, otherwise CPU.
if torch.backends.mps.is_available():
    device = torch.device("mps")
elif torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")

# Config
BATCH_SIZE = 64
EPOCHS = 10
LR = 1e-3
# Output vocabulary size of the decoder; presumably digits 0-9 plus special
# tokens — NOTE(review): confirm it matches the tokens emitted by MNIST_2x2.
VOCAB_SIZE = 13
SEED = 42

# Seed torch's global RNG so weight initialisation and DataLoader shuffle order
# repeat across runs. The dataset wrapper below is seeded with the same value.
# (Some accelerator kernels may still be nondeterministic.)
torch.manual_seed(SEED)

# Transforms: PIL image -> float tensor, then normalise with the commonly used
# MNIST mean/std constants.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,)),
])

# Dataset & DataLoader
train_base = datasets.MNIST('./data', train=True, download=True, transform=transform)
train_dataset = MNIST_2x2(train_base, seed=SEED)
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)

# Model, Loss, Optimizer
model = ImageToDigitTransformer(vocab_size=VOCAB_SIZE).to(device)
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=LR)

# Training Loop
# Put the model in training mode once; nothing below switches it to eval.
model.train()
for epoch in range(EPOCHS):
    total_loss = 0.0
    loop = tqdm(train_loader, desc=f"Epoch {epoch+1}/{EPOCHS}", leave=False)

    for images, dec_input, dec_target in loop:
        images = images.to(device)
        dec_input = dec_input.to(device)
        dec_target = dec_target.to(device)

        # Clear gradients left over from the previous step before backprop.
        optimizer.zero_grad()

        # The decoder consumes dec_input and is scored against dec_target
        # (presumably dec_input shifted by one token — confirm in MNIST_2x2).
        logits = model(images, dec_input)

        # Flatten to (N, VOCAB_SIZE) / (N,) for CrossEntropyLoss. Assumes logits
        # is (batch, seq_len, VOCAB_SIZE) — TODO confirm. reshape() (unlike
        # view()) also works if the model returns a non-contiguous tensor.
        loss = loss_fn(logits.reshape(-1, VOCAB_SIZE), dec_target.reshape(-1))

        loss.backward()
        optimizer.step()

        # .item() synchronises with the device; read the scalar only once.
        batch_loss = loss.item()
        total_loss += batch_loss

        # Update tqdm every batch
        loop.set_postfix(batch_loss=batch_loss)

    # max(..., 1) avoids ZeroDivisionError if the loader yields no batches.
    avg_loss = total_loss / max(len(train_loader), 1)
    print(f"Epoch {epoch + 1}/{EPOCHS} - Loss: {avg_loss:.4f}")

# Persist only the learned parameters (the state_dict), not the module object.
_CKPT_DIR = "checkpoints"
_CKPT_PATH = f"{_CKPT_DIR}/transformer_mnist.pt"

os.makedirs(_CKPT_DIR, exist_ok=True)  # no-op if the directory already exists
torch.save(model.state_dict(), _CKPT_PATH)
print(f"Model saved to {_CKPT_PATH}")