---
title: MNIST 2x2 Transformer Demo
emoji: π’
colorFrom: gray
colorTo: blue
sdk: gradio
sdk_version: 5.29.0 # or your current gradio version
app_file: app.py
pinned: false
---
# Transformer MNIST 2×2 – Image-to-Sequence Prediction
This project implements a minimal Transformer-based model that takes a 2×2 grid of MNIST digits as input and autoregressively predicts the corresponding 4-digit sequence. It serves as a practical deep dive into the inner workings of the Transformer architecture and basic multimodality concepts, combining vision (image patches) with language modeling (digit sequences).
## 1. Project Overview
The goal is to understand how a vanilla Transformer encoder-decoder can be applied to a simple multimodal task: mapping an image input to a discrete token sequence. This project focuses on building each architectural component from scratch and wiring them together cleanly.
## 2. Task Definition
- **Input:** a 2×2 grid composed of 4 random MNIST digits, forming a 56×56 grayscale image.
- **Output:** the 4-digit sequence corresponding to the digits in the grid (top-left → bottom-right).
- **Modeling approach:** sequence-to-sequence using an autoregressive decoder with special `<start>` and `<finish>` tokens; a sketch of how a sample is assembled follows below.
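
To make the task concrete, a single sample could be assembled roughly as follows. This is a minimal sketch using torchvision's MNIST; the `<start>`/`<finish>` token ids (10 and 11) and the helper name `make_2x2_sample` are illustrative assumptions, not necessarily what the repository uses.

```python
import torch
from torchvision import datasets, transforms

START, FINISH = 10, 11   # illustrative token ids for <start> and <finish>

mnist = datasets.MNIST(root="data", train=True, download=True,
                       transform=transforms.ToTensor())

def make_2x2_sample(indices):
    """Stitch four 28x28 MNIST digits into a 56x56 grid and build the 6-token target."""
    imgs, labels = zip(*(mnist[i] for i in indices))     # each img: (1, 28, 28)
    top = torch.cat(imgs[:2], dim=2)                     # (1, 28, 56)
    bottom = torch.cat(imgs[2:], dim=2)                  # (1, 28, 56)
    grid = torch.cat([top, bottom], dim=1)               # (1, 56, 56) stitched image
    target = torch.tensor([START, *labels, FINISH])      # <start> d1 d2 d3 d4 <finish>
    return grid, target

grid, target = make_2x2_sample([0, 1, 2, 3])
print(grid.shape, target.shape)   # torch.Size([1, 56, 56]) torch.Size([6])
```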
## 3. Model Architecture
The model follows a clean encoder-decoder Transformer architecture (a code sketch follows this list):
- **Feature Extractor:** splits the 56×56 image into 16 non-overlapping patches of 14×14 pixels and projects each to a 64-dimensional embedding.
- **Transformer Encoder:** processes the 16 patch embeddings using standard multi-head self-attention, positional embeddings, and MLP blocks.
- **Transformer Decoder:** autoregressively predicts the digit sequence:
- Uses masked self-attention over token embeddings
- Attends to encoder output via cross-attention
  - Outputs a sequence of logits over a vocabulary of 13 tokens (digits 0–9, `<start>`, `<finish>`)
- **Tokenizer:** handles token ↔ digit conversions and input preparation.
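
For concreteness, these pieces can be wired together roughly as shown below. This is a minimal sketch that leans on PyTorch's built-in `nn.TransformerEncoder`/`nn.TransformerDecoder` rather than the from-scratch blocks in this repository; the class name, layer count, and head count are illustrative assumptions.

```python
import torch
import torch.nn as nn

class Mnist2x2Transformer(nn.Module):
    """Illustrative encoder-decoder: 16 patch embeddings -> digit-token logits."""
    def __init__(self, d_model=64, nhead=4, num_layers=2, vocab_size=13, max_len=6):
        super().__init__()
        # Feature extractor: 14x14 non-overlapping patches -> 64-dim embeddings (16 per image).
        self.patch_embed = nn.Conv2d(1, d_model, kernel_size=14, stride=14)
        self.enc_pos = nn.Parameter(torch.zeros(1, 16, d_model))
        self.tok_embed = nn.Embedding(vocab_size, d_model)
        self.dec_pos = nn.Parameter(torch.zeros(1, max_len, d_model))
        self.encoder = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(d_model, nhead, dim_feedforward=4 * d_model,
                                       batch_first=True),
            num_layers)
        self.decoder = nn.TransformerDecoder(
            nn.TransformerDecoderLayer(d_model, nhead, dim_feedforward=4 * d_model,
                                       batch_first=True),
            num_layers)
        self.head = nn.Linear(d_model, vocab_size)

    def forward(self, images, tokens):
        # images: (B, 1, 56, 56); tokens: (B, T) decoder input token ids
        patches = self.patch_embed(images).flatten(2).transpose(1, 2)   # (B, 16, 64)
        memory = self.encoder(patches + self.enc_pos)                   # self-attention over patches
        tgt = self.tok_embed(tokens) + self.dec_pos[:, : tokens.size(1)]
        causal = nn.Transformer.generate_square_subsequent_mask(tokens.size(1)).to(tokens.device)
        out = self.decoder(tgt, memory, tgt_mask=causal)                # masked self-attn + cross-attn
        return self.head(out)                                           # (B, T, 13) logits
```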
## 4. Training Setup
- **Dataset:** MNIST, wrapped into a custom `MNIST_2x2` PyTorch dataset that returns the stitched image and 4-digit target.
- **Batch size:** 64
- **Epochs:** 10
- **Loss:** `CrossEntropyLoss` over vocabulary tokens
- **Optimizer:** Adam
- **Hardware:** Apple M4 with `mps` acceleration
- **Logging:** `tqdm` per-batch loss tracking for clear training progress (see the training-loop sketch after this list)
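
Under these settings, the training loop looks roughly like the sketch below. It reuses the `Mnist2x2Transformer` sketch from the previous section and assumes `train_dataset` is an instance of the repository's `MNIST_2x2` dataset; the learning rate and the 6-token target layout (`<start>`, 4 digits, `<finish>`) with a right-shifted decoder input are assumptions.

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from tqdm import tqdm

device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
model = Mnist2x2Transformer().to(device)                     # sketch from the architecture section
train_loader = DataLoader(train_dataset, batch_size=64,      # train_dataset: an MNIST_2x2 instance
                          shuffle=True)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)    # illustrative learning rate

model.train()
for epoch in range(10):
    pbar = tqdm(train_loader, desc=f"epoch {epoch + 1}")
    for images, targets in pbar:                             # targets: (B, 6) = <start> d1..d4 <finish>
        images, targets = images.to(device), targets.to(device)
        logits = model(images, targets[:, :-1])               # teacher forcing: feed everything but <finish>
        loss = criterion(logits.reshape(-1, logits.size(-1)),  # (B*5, 13) logits vs (B*5,) labels
                         targets[:, 1:].reshape(-1))            # predict everything after <start>
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        pbar.set_postfix(loss=f"{loss.item():.4f}")
```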
## 5. Evaluation
Evaluation is done on the held-out MNIST test set using greedy decoding (a code sketch follows the steps below):
- Starts with the `<start>` token
- Predicts one token at a time (no teacher forcing)
- Stops after 4 tokens or if `<finish>` is predicted
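
A minimal sketch of that greedy loop, reusing the model sketch and the illustrative token ids (10 for `<start>`, 11 for `<finish>`) from earlier:

```python
import torch

@torch.no_grad()
def greedy_decode(model, image, start_id=10, finish_id=11, max_digits=4):
    """Greedy-decode one (1, 56, 56) grid image into up to 4 digit tokens."""
    model.eval()
    tokens = torch.tensor([[start_id]], device=image.device)       # (1, 1): just <start>
    digits = []
    for _ in range(max_digits):
        logits = model(image.unsqueeze(0), tokens)                  # (1, T, 13)
        next_tok = logits[:, -1].argmax(dim=-1)                     # most likely next token
        if next_tok.item() == finish_id:
            break                                                   # model signalled end of sequence
        digits.append(next_tok.item())
        tokens = torch.cat([tokens, next_tok.unsqueeze(0)], dim=1)  # append and decode again
    return digits
```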
### Evaluation Metrics
- **Sequence accuracy:** % of samples where all 4 digits are predicted correctly
- **Per-digit accuracy:** % of individual digits predicted correctly across all positions (both metrics are sketched below)
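
Both metrics follow directly from the decoded sequences. A minimal sketch, assuming `preds` and `targets` are lists of digit lists collected over the test set:

```python
def sequence_and_digit_accuracy(preds, targets):
    """preds, targets: lists of digit lists, one entry per test sample (targets always length 4)."""
    seq_correct = sum(p == t for p, t in zip(preds, targets))
    # zip truncates short predictions, so missing positions simply count as wrong.
    digit_correct = sum(pd == td for p, t in zip(preds, targets) for pd, td in zip(p, t))
    return seq_correct / len(targets), digit_correct / (4 * len(targets))
```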
### Final Results after 10 Epochs of Training
- **Training loss at epoch 10:** 0.0101
- **Sequence accuracy:** 93.77% on the held-out test set
- **Per-digit accuracy:** 98.43% on the held-out test set