Overview

This experiment analyzes the training dynamics of vision transformers under the Prisma project. Small Vision Transformers (ViTs) were trained and evaluated on shape classification using the dSprites dataset, which consists of procedurally generated 2D shapes that vary along six independent latent factors. The task is to classify the three distinct shapes in dSprites (square, ellipse, and heart). All training checkpoints are available on the Hugging Face Hub and are summarised in the following table, with links to the models on the Hub:

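| Size   | No. Layers | AttentionOnly | Attention-and-MLP |
|--------|------------|---------------|-------------------|
| tiny   | 1          | link          | link              |
| base   | 2          | link          | link              |
| small  | 3          | link          | link              |
| medium | 4          | link          | link              |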
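The table entries differ only in depth and in whether MLP blocks are included, so the `update_dict` from the loading example in How to Use can be generalised to any entry. The sketch below assumes this: the layer counts come from the table, but `hidden_dim=512` and `num_heads=8` are copied from the tiny example and are only assumed to hold for the other sizes, and mapping the Attention-and-MLP column to `attention_only=False` is likewise an assumption. `make_update_dict` is a hypothetical helper, not part of ViT-Prisma.

```python
# Layer counts taken from the checkpoint table.
NUM_LAYERS = {"tiny": 1, "base": 2, "small": 3, "medium": 4}


def make_update_dict(size, attention_only=True, hidden_dim=512, num_heads=8):
    """Build a transformer override dict for one entry of the checkpoint table.

    ASSUMPTION: hidden_dim=512 and num_heads=8 are copied from the tiny
    example and may not match the other sizes.
    """
    if size not in NUM_LAYERS:
        raise ValueError(f"unknown size {size!r}; expected one of {sorted(NUM_LAYERS)}")
    return {
        "transformer": {
            "attention_only": attention_only,  # False is assumed for Attention-and-MLP
            "hidden_dim": hidden_dim,
            "num_heads": num_heads,
            "num_layers": NUM_LAYERS[size],
        }
    }
```

The result would be passed to `update_dataclass_from_dict(config, ...)` in the same way as the hand-written dict in the loading example.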

Each repo contains multiple intermediate checkpoints. Each checkpoint is stored as "checkpoint_{i}.pth", where i is the number of training samples the model has been trained on.
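Because i counts training samples, checkpoint files should be ordered numerically rather than alphabetically (a plain string sort puts "checkpoint_1000.pth" before "checkpoint_200.pth"). The following is a minimal sketch of a hypothetical helper, `sorted_checkpoints`, that does this. Since the loading example uses the filename "model_0.pth", it accepts both the `checkpoint_` and `model_` prefixes; which prefix a given repo actually uses is unverified.

```python
import re

# ASSUMPTION: files are named "checkpoint_{i}.pth" (as described) or "model_{i}.pth"
# (as in the loading example), where i is the number of training samples seen.
CHECKPOINT_RE = re.compile(r"^(?:checkpoint|model)_(\d+)\.pth$")


def sorted_checkpoints(filenames):
    """Return (num_training_samples, filename) pairs sorted by training progress."""
    found = []
    for name in filenames:
        match = CHECKPOINT_RE.match(name)
        if match:  # skip README, config files, and anything else
            found.append((int(match.group(1)), name))
    return sorted(found)  # numeric sort on the sample count
```

The filenames could, for example, come from `huggingface_hub.list_repo_files(REPO_ID)`, with each returned name then passed as `filename` to `hf_hub_download`.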

Further details on training and results are described here.
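For background on the classification task: dSprites images are indexed by their latent factors, so the shape label can be recovered from an image's position in the dataset. The sketch below assumes the ordering and sizes of the public DeepMind dSprites release (color 1, shape 3, scale 6, orientation 40, posX 32, posY 32, stored in row-major order, with shapes ordered square, ellipse, heart). Whether the Prisma data pipeline derives its labels this way is an assumption, and the helper names are hypothetical.

```python
from math import prod

# ASSUMPTION: factor order and sizes follow the public dSprites release.
LATENT_NAMES = ("color", "shape", "scale", "orientation", "posX", "posY")
LATENT_SIZES = (1, 3, 6, 40, 32, 32)
SHAPES = ("square", "ellipse", "heart")

# Stride of factor k in the flat image array = product of the sizes of all later factors.
BASES = tuple(prod(LATENT_SIZES[k + 1:]) for k in range(len(LATENT_SIZES)))


def latent_to_index(latents):
    """Map a tuple of latent class indices (one per factor) to a flat image index."""
    return sum(c * b for c, b in zip(latents, BASES))


def index_to_shape(index):
    """Recover the shape name (the classification target) from a flat image index."""
    return SHAPES[(index // BASES[1]) % LATENT_SIZES[1]]
```

With these sizes there are 737,280 images in total, and each shape occupies a contiguous block of 245,760 of them.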

How to Use

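!git clone https://github.com/soniajoseph/ViT-Prisma
# "!cd" runs in a throwaway subshell, so use the %cd magic to change the notebook's working directory
%cd ViT-Prisma
!pip install -e .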
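from huggingface_hub import hf_hub_download
import torch

# Repo and checkpoint file to load (tiny, attention-only model)
REPO_ID = "IamYash/dSprites-tiny-AttentionOnly"
FILENAME = "model_0.pth"

# Download the checkpoint from the Hub (cached locally) and load it onto the CPU
checkpoint = torch.load(
    hf_hub_download(repo_id=REPO_ID, filename=FILENAME),
    map_location="cpu",
)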

from vit_prisma.models.base_vit import BaseViT
from vit_prisma.configs.DSpritesConfig import GlobalConfig
from vit_prisma.utils.wandb_utils import update_dataclass_from_dict

# Start from the default dSprites config
config = GlobalConfig()
print(config)

# Override the transformer settings to match the checkpoint (tiny, attention-only: 1 layer)
update_dict = {
    'transformer': {
        'attention_only': True,
        'hidden_dim': 512,
        'num_heads': 8,
        'num_layers': 1
    }
}
update_dataclass_from_dict(config, update_dict)

# Build the model and load the trained weights
model = BaseViT(config)
model.load_state_dict(checkpoint['model_state_dict'])
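The `update_dict` above lists only the fields being overridden, which relies on the config being updated recursively rather than replaced. We have not checked how `update_dataclass_from_dict` is implemented in ViT-Prisma; the following is a hypothetical stand-in, `update_nested_dataclass`, operating on a toy config (`ToyConfig` and `TransformerConfig` are illustrative, not ViT-Prisma classes), to show what such a recursive merge looks like.

```python
from dataclasses import dataclass, field, is_dataclass


@dataclass
class TransformerConfig:  # hypothetical toy config, not vit_prisma's
    attention_only: bool = False
    hidden_dim: int = 128
    num_heads: int = 4
    num_layers: int = 2


@dataclass
class ToyConfig:  # hypothetical toy config, not vit_prisma's GlobalConfig
    transformer: TransformerConfig = field(default_factory=TransformerConfig)
    lr: float = 1e-3


def update_nested_dataclass(obj, updates):
    """Recursively apply a (possibly nested) dict of overrides to a dataclass in place."""
    for key, value in updates.items():
        if not hasattr(obj, key):
            raise AttributeError(f"{type(obj).__name__} has no field {key!r}")
        current = getattr(obj, key)
        if isinstance(value, dict) and is_dataclass(current):
            update_nested_dataclass(current, value)  # descend into the sub-config
        else:
            setattr(obj, key, value)  # leaf value: overwrite
    return obj


config = ToyConfig()
update_nested_dataclass(config, {"transformer": {"attention_only": True, "num_layers": 1}})
```

Fields not named in the dict (here `hidden_dim`, `num_heads`, and `lr`) keep their defaults, and a misspelled key raises an error instead of being silently ignored.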

license: mit
