---
license: apache-2.0
datasets:
- imagenet-1k
metrics:
- accuracy
pipeline_tag: image-classification
tags:
- pytorch
- torch-dag
---
# Model Card for beit_base_patch16_224_pruned_65
This is a pruned version of the timm/beit_base_patch16_224.in22k_ft_in22k_in1k model in the torch-dag format.

This model has roughly 65% of the original model's FLOPs with a minimal drop in accuracy.
| Model | KMAPPs* | Parameters (M) | Accuracy (224x224) |
| --- | --- | --- | --- |
| timm/beit_base_patch16_224.in22k_ft_in22k_in1k (baseline) | 673.2 | 86.5 | 85.23% |
| beit_base_patch16_224_pruned_65 (ours) | 438 (65%) | 56.7 (66%) | 84.53% (↓ 0.7%) |
\*KMAPPs: thousands of FLOPs per input pixel

`KMAPPs(model) = FLOPs(model) / (H * W * 1000)`, where `(H, W)` is the input resolution.
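As a worked example, the KMAPPs figures in the table above can be converted back to total FLOPs at the 224x224 input resolution (the GFLOP values below follow directly from the formula and the table):

```python
# Worked example of the KMAPPs definition above for a 224x224 input.
H, W = 224, 224
kmapps = {'baseline': 673.2, 'pruned_65': 438}
for name, k in kmapps.items():
    flops = k * H * W * 1000  # KMAPPs -> total FLOPs
    print(f'{name}: {flops / 1e9:.1f} GFLOPs')
# baseline: 33.8 GFLOPs, pruned_65: 22.0 GFLOPs
```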
The accuracy was calculated on the ImageNet-1k validation dataset. For details about image pre-processing, please refer to the original repository.
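If you want to reproduce the accuracy number yourself, a minimal evaluation loop could look like the sketch below (after installing torch-dag and cloning the repository as described in the "How to Get Started" section). The exact pre-processing is defined in the original repository; the resize/crop/normalization values, dataset path, and batch size used here are assumptions.

```python
# Minimal top-1 accuracy sketch (not the authors' exact evaluation pipeline).
# Assumes an ImageNet-1k validation set in torchvision ImageFolder layout and
# BEiT-style preprocessing (resize -> 224x224 center crop -> 0.5/0.5 normalization).
import torch
import torch_dag
from torchvision import datasets, transforms

model = torch_dag.io.load_dag_from_path('./beit_base_patch16_224_pruned_65')
model.eval()

preprocess = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),  # assumed stats
])
val_set = datasets.ImageFolder('/path/to/imagenet/val', transform=preprocess)  # assumed path
loader = torch.utils.data.DataLoader(val_set, batch_size=64, num_workers=8)

correct = total = 0
with torch.no_grad():
    for images, labels in loader:
        preds = model(images).argmax(dim=1)
        correct += (preds == labels).sum().item()
        total += labels.numel()
print(f'top-1 accuracy: {correct / total:.4f}')
```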
## Model Details

### Model Description
- Developed by: TCL Research Europe
- Model type: Classification / feature backbone
- License: Apache 2.0
- Finetuned from model: timm/beit_base_patch16_224.in22k_ft_in22k_in1k
### Model Sources
- Repository: timm/beit_base_patch16_224.in22k_ft_in22k_in1k
## How to Get Started with the Model
To load the model, you need to install the `torch-dag` library, which can be done with pip:

```bash
pip install torch-dag
```
Then, clone this repository:

```bash
# Make sure you have git-lfs installed (https://git-lfs.com)
git lfs install
git clone https://huggingface.co/TCLResearchEurope/beit_base_patch16_224_pruned_65
```
Now you are ready to load the model:

```python
import torch_dag
import torch

# Load the pruned model from the cloned repository directory.
model = torch_dag.io.load_dag_from_path('./beit_base_patch16_224_pruned_65')
model.eval()

# Run a forward pass on a dummy 224x224 RGB input.
out = model(torch.ones(1, 3, 224, 224))
print(out.shape)  # class logits for the ImageNet-1k labels
```
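As a quick sanity check, you can also count the parameters of the loaded model; assuming the loaded DAG object exposes the standard `torch.nn.Module` parameter API, the total should roughly match the 56.7 M listed in the table above:

```python
# Hypothetical sanity check, assuming the loaded model behaves like a
# standard torch.nn.Module with a .parameters() iterator.
num_params = sum(p.numel() for p in model.parameters())
print(f'{num_params / 1e6:.1f} M parameters')  # expected: roughly 56.7 M
```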