ProtoViT: Interpretable Vision Transformer with Adaptive Prototype Learning

This repository contains pretrained ProtoViT models for interpretable image classification, as described in our paper "Interpretable Image Classification with Adaptive Prototype-based Vision Transformers".

Model Description

ProtoViT combines Vision Transformers with prototype-based learning to create models that are both highly accurate and interpretable. Rather than functioning as a black box, ProtoViT learns interpretable prototypes that explain its classification decisions through visual similarities.
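
In broad strokes, the ViT backbone encodes an image into patch tokens, each prototype is compared against those tokens, and the resulting similarity scores are combined linearly into class logits. The snippet below is a simplified sketch of this general scoring idea only; ProtoViT's actual mechanism additionally matches adaptive sub-prototypes, and all function and tensor names here are illustrative rather than the repository's API.

```python
import torch
import torch.nn.functional as F

def prototype_logits(patch_features, prototypes, class_weights):
    """Simplified prototype-based scoring (shapes and names are assumptions).

    patch_features: (B, N, D) patch tokens from the ViT backbone
    prototypes:     (P, D)    learned prototype vectors
    class_weights:  (C, P)    linear layer from similarities to classes
    """
    # Cosine similarity between every patch token and every prototype.
    sims = F.cosine_similarity(
        patch_features.unsqueeze(2),           # (B, N, 1, D)
        prototypes.unsqueeze(0).unsqueeze(0),  # (1, 1, P, D)
        dim=-1,
    )                                          # (B, N, P)
    # A prototype's activation is its best-matching patch in the image.
    proto_scores = sims.max(dim=1).values      # (B, P)
    # Class logits are a weighted sum of prototype activations, so each
    # logit decomposes into per-prototype evidence that can be inspected.
    return proto_scores @ class_weights.t()    # (B, C)
```

Because each logit is a transparent sum over prototype activations, the evidence for a prediction can be traced back to individual prototypes and the image patches that activated them.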

Supported Architectures

We provide three variants of ProtoViT:

  • ProtoViT-T: Built on DeiT-Tiny backbone
  • ProtoViT-S: Built on DeiT-Small backbone
  • ProtoViT-CaiT: Built on CaiT-XXS24 backbone

Performance

All models were trained and evaluated on the CUB-200-2011 fine-grained bird species classification dataset.

| Model Version | Backbone | Resolution | Top-1 Accuracy | Checkpoint |
|---------------|------------|------------|----------------|------------|
| ProtoViT-T    | DeiT-Tiny  | 224×224    | 83.36%         | Download   |
| ProtoViT-S    | DeiT-Small | 224×224    | 85.30%         | Download   |
| ProtoViT-CaiT | CaiT-XXS24 | 224×224    | 86.02%         | Download   |

Features

  • 🔍 Interpretable Decisions: The model performs classification with self-explanatory reasoning based on the input's similarity to learned prototypes, which capture the key features of each class.
  • 🎯 High Accuracy: Achieves competitive performance on fine-grained classification tasks
  • 🚀 Multiple Architectures: Supports various Vision Transformer backbones
  • 📊 Analysis Tools: Comes with tools for both local and global prototype analysis (a rough local-analysis sketch follows this list)
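
As a rough illustration of local analysis (explaining a single prediction), the hypothetical helper below retrieves the most-activated prototypes for one image, together with the patch that triggered each; the repository's actual analysis scripts may work differently.

```python
import torch.nn.functional as F

def top_prototypes(patch_features, prototypes, k=5):
    """patch_features: (N, D) tokens of one image; prototypes: (P, D)."""
    sims = F.cosine_similarity(
        patch_features.unsqueeze(1),  # (N, 1, D)
        prototypes.unsqueeze(0),      # (1, P, D)
        dim=-1,
    )                                 # (N, P)
    per_proto, best_patch = sims.max(dim=0)  # best patch for each prototype
    scores, proto_ids = per_proto.topk(k)    # k most activated prototypes
    return proto_ids, best_patch[proto_ids], scores
```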

Requirements

  • Python 3.8+
  • PyTorch 1.8+
  • timm==0.4.12
  • torchvision
  • numpy
  • pillow
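
After installing the dependencies (e.g. `pip install torch torchvision timm==0.4.12 numpy pillow`), a downloaded checkpoint can typically be loaded as shown below. The file name and checkpoint format are assumptions; consult the repository for the exact procedure.

```python
import torch

# Assumption: the checkpoint stores the full pickled model (common in
# ProtoPNet-style code). If it stores only a state_dict, construct the
# model first and call model.load_state_dict(...) instead.
model = torch.load("protovit_deit_small.pth", map_location="cpu")
model.eval()
```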

Limitations and Bias

  • Data Bias: These models are trained on CUB-200-2011 and may not generalize well to images outside this dataset.
  • Resolution Constraints: The models are trained at a resolution of 224×224; higher or lower input resolutions may degrade performance (see the preprocessing sketch after this list).
  • Location Misalignment: Like CNN-based prototype models, these models are not perfectly immune to location misalignment under adversarial attack.
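
Inputs should therefore be preprocessed to the training resolution. A minimal sketch, assuming the standard ImageNet normalization used by DeiT-style backbones (not confirmed by this card):

```python
from PIL import Image
import torchvision.transforms as T

# Assumed ImageNet statistics, standard for DeiT-family models.
preprocess = T.Compose([
    T.Resize(256),
    T.CenterCrop(224),
    T.ToTensor(),
    T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

img = preprocess(Image.open("bird.jpg").convert("RGB")).unsqueeze(0)  # (1, 3, 224, 224)
```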

Citation

If you use this model in your research, please cite:

@article{ma2024interpretable,
  title={Interpretable Image Classification with Adaptive Prototype-based Vision Transformers},
  author={Ma, Chiyu and Donnelly, Jon and Liu, Wenjun and Vosoughi, Soroush and Rudin, Cynthia and Chen, Chaofan},
  journal={arXiv preprint arXiv:2410.20722},
  year={2024}
}

Acknowledgements

This implementation builds upon the following excellent repositories:

License

This project is released under the MIT license.

Contact

For any questions or feedback, please:

  1. Open an issue in the GitHub repository
  2. Contact Chiyu.ma.gr@dartmouth.edu