# MNIST multimodal transformer - Overview

## Task 1
Our goal is to build and train a multimodal transformer from scratch.
The task is to predict the sequence of digits from an image composed of 2x2 MNIST images tiled together.
The transformer should predict the labels in order: top-left, top-right, bottom-left, bottom-right.

## Outcome

- a clean, minimal, and well-organized project folder structure
- clean and minimal PyTorch code, well organized across dataset, dataloader, and model classes
- clear evaluation metrics

# Execution

## Dataset 

Create a dataset class that returns a single example consisting of:
- an image made of 2x2 MNIST images picked at random (from the training split) and stitched together
- the 4 labels ordered top-left, top-right, bottom-left, bottom-right
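
A minimal sketch of such a dataset class, assuming torchvision's `datasets.MNIST`; the class name `TiledMNIST` and the `length` parameter are illustrative choices, not part of the spec:

```python
import torch
from torch.utils.data import Dataset
from torchvision import datasets, transforms


class TiledMNIST(Dataset):
    """Each item: a 2x2 grid of random MNIST digits stitched into one
    56x56 image, plus the 4 labels in top-left, top-right,
    bottom-left, bottom-right order."""

    def __init__(self, root="data", train=True, length=60_000):
        self.mnist = datasets.MNIST(root, train=train, download=True,
                                    transform=transforms.ToTensor())
        self.length = length  # synthetic examples per epoch (illustrative)

    def __len__(self):
        return self.length

    def __getitem__(self, idx):
        # Pick 4 digits at random from the underlying split.
        ids = torch.randint(len(self.mnist), (4,)).tolist()
        imgs, labels = zip(*(self.mnist[i] for i in ids))
        # Stitch along width, then height: (1, 28, 28) -> (1, 56, 56).
        top = torch.cat([imgs[0], imgs[1]], dim=2)
        bottom = torch.cat([imgs[2], imgs[3]], dim=2)
        image = torch.cat([top, bottom], dim=1)
        return image, torch.tensor(labels)  # labels: shape (4,)
```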

## Model 

Create an encoder-decoder transformer architecture for this task. The architecture is made of three main elements:
- Feature extractor 
- Encoder 
- Decoder

### Feature extractor

Each image is cut into 16 patches of 14x14 px (since the stitched 2x2 image is 56x56 pixels),
and each patch is linearly projected to a dimension of 64, the constant latent vector size D for the encoder.
These patch embeddings are fed as input to the encoder block.
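
A sketch of this patchify-and-project step; a `Conv2d` with kernel = stride = patch size is the usual one-op implementation of the linear patch projection, and the name `PatchEmbed` is illustrative:

```python
import torch
from torch import nn


class PatchEmbed(nn.Module):
    """Splits a (B, 1, 56, 56) image into 16 non-overlapping 14x14
    patches and linearly projects each one to D=64."""

    def __init__(self, patch=14, dim=64):
        super().__init__()
        # Conv2d with kernel = stride = patch size performs the cut
        # and the linear projection in a single operation.
        self.proj = nn.Conv2d(1, dim, kernel_size=patch, stride=patch)

    def forward(self, x):                    # (B, 1, 56, 56)
        x = self.proj(x)                     # (B, 64, 4, 4)
        return x.flatten(2).transpose(1, 2)  # (B, 16, 64)
```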

### Encoder 

The encoder should closely follow the vanilla "Attention Is All You Need" implementation, as in the ViT (Vision Transformer) paper:

- positional embeddings are added to the patch embeddings to retain positional information, using standard learnable 1D position embeddings
- the encoder then consists of alternating layers of multi-headed self-attention and MLP blocks
- layernorm is applied before every attention block and MLP block
- residual connections are applied after every block

The output of the encoder is a set of encoded representations of the image patches (16x64).
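
A sketch of what this could look like, following the pre-LN ViT layout described above; the `depth`, `heads`, and `mlp_ratio` values are illustrative assumptions, not part of the spec:

```python
import torch
from torch import nn


class EncoderBlock(nn.Module):
    """Pre-LN encoder layer: LayerNorm before each sub-block,
    residual connection after it."""

    def __init__(self, dim=64, heads=4, mlp_ratio=4):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim)
        self.attn = nn.MultiheadAttention(dim, heads, batch_first=True)
        self.norm2 = nn.LayerNorm(dim)
        self.mlp = nn.Sequential(
            nn.Linear(dim, dim * mlp_ratio),
            nn.GELU(),
            nn.Linear(dim * mlp_ratio, dim),
        )

    def forward(self, x):
        h = self.norm1(x)
        x = x + self.attn(h, h, h, need_weights=False)[0]
        x = x + self.mlp(self.norm2(x))
        return x


class Encoder(nn.Module):
    def __init__(self, num_patches=16, dim=64, depth=4):
        super().__init__()
        # Standard learnable 1D position embeddings, one per patch.
        self.pos = nn.Parameter(torch.zeros(1, num_patches, dim))
        self.blocks = nn.ModuleList(EncoderBlock(dim) for _ in range(depth))

    def forward(self, x):          # x: (B, 16, 64) patch embeddings
        x = x + self.pos
        for blk in self.blocks:
            x = blk(x)
        return x                   # (B, 16, 64) encoded patches
```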

### Decoder