# MNIST multimodal transformer - Overview
## Task 1
Our goal is to build and train a multimodal transformer from scratch.
The task is to predict the sequence of digits from an image composed of 2x2 MNIST images tiled together.
The transformer should predict the labels in order: top-left, top-right, bottom-left, bottom-right.
## Outcome
- a clean, minimal, and well-organized project folder structure
- clean and minimal PyTorch code, well organized across dataset, dataloader, and model classes
- clear evaluation metrics
# Execution
## Dataset
Create a dataset class (sketched below) that returns a single example of:
- an image made of 2x2 MNIST images picked at random (from the training split) and stitched together
- the 4 labels ordered top-left, top-right, bottom-left, bottom-right
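
A minimal sketch of such a dataset class, assuming torchvision's MNIST is available; the class name `TiledMNIST` and its arguments are illustrative, not part of the plan:

```python
import torch
from torch.utils.data import Dataset
from torchvision import datasets, transforms

class TiledMNIST(Dataset):
    """Returns a (1, 56, 56) image of four random MNIST digits and their 4 labels."""

    def __init__(self, root="data", train=True):
        self.mnist = datasets.MNIST(root, train=train, download=True,
                                    transform=transforms.ToTensor())

    def __len__(self):
        # One epoch = one random 2x2 tile per underlying MNIST example.
        return len(self.mnist)

    def __getitem__(self, idx):
        # Four digits picked at random; idx only defines the epoch length.
        ids = torch.randint(len(self.mnist), (4,)).tolist()
        imgs, labels = zip(*(self.mnist[i] for i in ids))
        top = torch.cat((imgs[0], imgs[1]), dim=2)     # (1, 28, 56)
        bottom = torch.cat((imgs[2], imgs[3]), dim=2)  # (1, 28, 56)
        image = torch.cat((top, bottom), dim=1)        # (1, 56, 56)
        # Labels ordered top-left, top-right, bottom-left, bottom-right.
        return image, torch.tensor(labels)
```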
## Model
Create an encoder-decoder transformer architecture for this task. The architecture is made of three main elements:
- Feature extractor
- Encoder
- Decoder
### Feature extractor
Each image is cut into 16 patches of 14x14 px (since the stitched 2x2 image is 56x56 pixels)
and linearly projected to a dimension of 64, which is the constant latent vector size D for the encoder.
These represent the image embeddings that are fed as input to the encoder block.
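
A minimal sketch of this patch embedding, using the standard ViT trick of a convolution whose kernel size equals its stride (equivalent to cutting non-overlapping patches and applying a shared linear projection); the name `PatchEmbed` is illustrative:

```python
import torch
from torch import nn

class PatchEmbed(nn.Module):
    """Cuts a 56x56 image into 16 non-overlapping 14x14 patches, each projected to D=64."""

    def __init__(self, patch=14, dim=64):
        super().__init__()
        # Conv with kernel_size == stride acts as a per-patch linear projection.
        self.proj = nn.Conv2d(1, dim, kernel_size=patch, stride=patch)

    def forward(self, x):                      # x: (B, 1, 56, 56)
        x = self.proj(x)                       # (B, 64, 4, 4)
        return x.flatten(2).transpose(1, 2)    # (B, 16, 64)
```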
### Encoder
should follow closely the "attention is all you need" vanilla implementation, similarly to the ViT Vision Transformer paper
- positional embeddings are added to the patch embeddings to retain positional information, using standard learnable 1D position embeddings
- the encoder then consists of alternating layers of multi-headed self-attention and MLP blocks
- layernorm is applied before every attention block and MLP block
- residual connections are applied after every block
The output of the encoder is a set of encoded representations of the image patches (16x64), as sketched below.
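
A minimal sketch of such a pre-LN encoder, taking the (B, 16, 64) patch embeddings from the feature extractor; the depth, head count, and MLP ratio are illustrative choices not fixed by this plan:

```python
import torch
from torch import nn

class EncoderBlock(nn.Module):
    """Pre-LN encoder layer: LayerNorm before each block, residual connection after."""

    def __init__(self, dim=64, heads=4, mlp_ratio=4):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim)
        self.attn = nn.MultiheadAttention(dim, heads, batch_first=True)
        self.norm2 = nn.LayerNorm(dim)
        self.mlp = nn.Sequential(
            nn.Linear(dim, dim * mlp_ratio),
            nn.GELU(),
            nn.Linear(dim * mlp_ratio, dim),
        )

    def forward(self, x):                       # x: (B, 16, 64)
        h = self.norm1(x)                       # layernorm before attention
        x = x + self.attn(h, h, h, need_weights=False)[0]
        x = x + self.mlp(self.norm2(x))         # layernorm before MLP
        return x

class Encoder(nn.Module):
    def __init__(self, dim=64, depth=4, num_patches=16):
        super().__init__()
        # Standard learnable 1D position embeddings, added to the patch embeddings.
        self.pos = nn.Parameter(torch.zeros(1, num_patches, dim))
        self.blocks = nn.ModuleList([EncoderBlock(dim) for _ in range(depth)])

    def forward(self, x):                       # x: (B, 16, 64) patch embeddings
        x = x + self.pos
        for blk in self.blocks:
            x = blk(x)
        return x                                # (B, 16, 64) encoded patches
```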
### Decoder