File size: 2,382 Bytes
b54146b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
import torch
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from tqdm import tqdm

from dataset import MNIST_2x2
from model.model import ImageToDigitTransformer
from utils.tokenizer import START, FINISH, decode

# Pick the compute device: Apple's MPS backend when torch reports it as
# available, otherwise the CPU.
if torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")

# Model / decoding configuration.
VOCAB_SIZE = 13  # token vocabulary size handed to the model constructor
SEQ_LEN = 4  # number of digit tokens predicted per image
MAX_LEN = SEQ_LEN + 1  # decoder input length: [<start>, d1, d2, d3, d4]

# Input preprocessing: PIL image -> tensor, then per-channel normalisation
# (mean/std presumably the MNIST training-set statistics — confirm).
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.1307,), std=(0.3081,)),
    ]
)

# MNIST test split (downloaded into ./data if missing), wrapped by the
# project's 2x2 composite dataset with a fixed seed so evaluation is repeatable.
mnist_test = datasets.MNIST(
    root="./data",
    train=False,
    download=True,
    transform=transform,
)
test_dataset = MNIST_2x2(mnist_test, seed=42)
# batch_size=1: the evaluation loop decodes one image at a time.
test_loader = DataLoader(dataset=test_dataset, batch_size=1, shuffle=False)

# Build the model on the chosen device, restore trained weights, and switch
# to inference mode (disables dropout etc.).
# NOTE(review): consider torch.load(..., weights_only=True) if the checkpoint
# could ever come from an untrusted source.
model = ImageToDigitTransformer(vocab_size=VOCAB_SIZE).to(device)
model.load_state_dict(
    torch.load("checkpoints/transformer_mnist.pt", map_location=device)
)
model.eval()

def _greedy_decode(model, image, max_steps, start_token, finish_token, device):
    """Greedily decode up to ``max_steps`` tokens for one image.

    At every step the model is re-run on the whole prefix generated so far.
    The arg-max of the last position's logits is appended. Decoding stops
    early once ``finish_token`` is produced; that token is kept in the output.

    Args:
        model: callable as ``model(image, input_ids)`` returning logits.
        image: the (already device-placed) input image batch.
        max_steps: maximum number of tokens to generate.
        start_token: token id that seeds the decoder.
        finish_token: token id that ends decoding early.
        device: device on which to build the decoder input tensor.

    Returns:
        list[int]: generated token ids, without the leading ``start_token``
        (at most ``max_steps`` items).
    """
    decoded = [start_token]
    for _ in range(max_steps):
        input_ids = torch.tensor([decoded], dtype=torch.long, device=device)
        logits = model(image, input_ids)
        # assumes logits is (batch, seq, vocab) — TODO confirm against the model
        next_token = torch.argmax(logits[:, -1, :], dim=-1).item()
        decoded.append(next_token)
        if next_token == finish_token:
            break
    return decoded[1:]


def _percent(numerator, denominator):
    """Return ``numerator / denominator`` as a percentage (NaN if denominator is 0)."""
    return 100.0 * numerator / denominator if denominator else float("nan")


# Evaluation Loop
total_samples = 0  # counted explicitly; len(test_loader) counts batches
correct_sequences = 0
digit_correct = 0
digit_total = 0

with torch.no_grad():
    loop = tqdm(test_loader, desc="Evaluating", leave=False)

    for image, _, target_ids in loop:
        image = image.to(device)
        # Drop the batch dim (batch_size=1) and the trailing target token,
        # which is assumed to be <finish> — confirm against MNIST_2x2.
        target = target_ids.squeeze(0).tolist()[:-1]

        pred = _greedy_decode(model, image, SEQ_LEN, START, FINISH, device)

        total_samples += 1
        if pred == target:
            correct_sequences += 1
        # zip stops at the shorter list, so positions missing from an
        # early-stopped prediction still count as wrong via digit_total.
        digit_correct += sum(p == t for p, t in zip(pred, target))
        digit_total += len(target)

        loop.set_postfix(
            seq_acc=f"{_percent(correct_sequences, total_samples):.2f}%",
            digit_acc=f"{_percent(digit_correct, digit_total):.2f}%",
        )


# final results
seq_acc = _percent(correct_sequences, total_samples)
digit_acc = _percent(digit_correct, digit_total)

print("\nFinal Evaluation Results:")
print(f"  Sequence accuracy: {seq_acc:.2f}%")
print(f"  Per-digit accuracy: {digit_acc:.2f}%")