
TensorFlow/Keras implementation of "Learning to tokenize in Vision Transformers"

Full credit to Sayak Paul and Aritra Roy Gosthipaty for this work.

Intended uses & limitations

Vision Transformers (Dosovitskiy et al.) and many other Transformer-based architectures (Liu et al., Yuan et al., etc.) have shown strong results in image recognition. The following provides a brief overview of the components involved in the Vision Transformer architecture for image classification:

  • Extract small patches from input images.
  • Linearly project those patches.
  • Add positional embeddings to these linear projections.
  • Run these projections through a series of Transformer (Vaswani et al.) blocks.
  • Finally, take the representation from the final Transformer block and add a classification head.

If we take 224x224 images and extract 16x16 patches, we get a total of 196 patches (also called tokens) for each image: 224 / 16 = 14 patches per side, and 14 × 14 = 196. For a fixed patch size, the number of patches grows with the square of the image side length, which leads to a higher memory footprint. Could we use fewer patches without compromising performance? Ryoo et al. investigate this question in TokenLearner: Adaptive Space-Time Tokenization for Videos. They introduce a novel module called TokenLearner that reduces the number of patches used by a Vision Transformer (ViT) in an adaptive manner. With TokenLearner incorporated in the standard ViT architecture, they reduce the amount of compute (measured in FLOPs) used by the model.

In this example, we implement the TokenLearner module and demonstrate its performance with a mini ViT and the CIFAR-10 dataset. We make use of the following references:
  • Official TokenLearner code
  • Image Classification with ViTs on keras.io
  • TokenLearner slides from NeurIPS 2021
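The patch counting and the TokenLearner idea described above can be sketched in plain Python. This is a minimal, dependency-free illustration under stated assumptions, not the Keras implementation this card describes: it counts non-overlapping patches, then reduces N patch tokens to S learned tokens in the way Ryoo et al. describe TokenLearner, by computing one sigmoid spatial attention map per output token and average-pooling the inputs weighted by that map. The single linear scoring map `weights` is an assumption standing in for the convolutional/MLP layers used in practice, and `num_patches`, `token_learner`, and the sizes C = 8 and S = 4 are hypothetical.

```python
import math
import random


def num_patches(image_size: int, patch_size: int) -> int:
    """Count non-overlapping patch_size x patch_size patches in a square image."""
    per_side = image_size // patch_size
    return per_side * per_side


def sigmoid(z: float) -> float:
    return 1.0 / (1.0 + math.exp(-z))


def token_learner(tokens, weights):
    """Reduce N input tokens to S learned tokens (hypothetical helper).

    tokens:  N x C list of lists (flattened spatial positions, C channels each).
    weights: S x C list of lists; an assumed linear map that scores every
             position once per output token (stand-in for conv/MLP layers).
    Returns an S x C list: output token s is the spatial mean of the inputs,
    each position weighted by the sigmoid attention map for token s.
    """
    n = len(tokens)
    c = len(tokens[0])
    outputs = []
    for w in weights:
        # One attention map: a score in (0, 1) for every spatial position.
        alpha = [sigmoid(sum(wi * xi for wi, xi in zip(w, x))) for x in tokens]
        # Weighted spatial average pooling gives one C-dimensional token.
        pooled = [sum(alpha[p] * tokens[p][ch] for p in range(n)) / n for ch in range(c)]
        outputs.append(pooled)
    return outputs


random.seed(0)
N = num_patches(224, 16)  # 196 patch tokens per 224x224 image
C, S = 8, 4               # hypothetical embedding width and number of learned tokens
tokens = [[random.gauss(0, 1) for _ in range(C)] for _ in range(N)]
weights = [[random.gauss(0, 1) for _ in range(C)] for _ in range(S)]
learned = token_learner(tokens, weights)
print(N, "->", len(learned), "tokens of width", len(learned[0]))  # 196 -> 4 tokens of width 8
```

In this illustrative setting, any self-attention layer that runs on the 4 learned tokens instead of the 196 patch tokens works with a 4² = 16-entry attention matrix rather than 196² = 38,416 entries.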

Training and evaluation data

More information needed

Training procedure

Training hyperparameters

The following hyperparameters were used during training:

  • name: AdamW
  • learning_rate: 0.0010000000474974513
  • decay: 0.0
  • beta_1: 0.8999999761581421
  • beta_2: 0.9990000128746033
  • epsilon: 1e-07
  • amsgrad: False
  • weight_decay: 9.999999747378752e-05
  • exclude_from_weight_decay: None
  • training_precision: float32
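The optimizer settings listed for training can be exercised with a small pure-Python sketch. It checks that the long decimals are the float32 roundings of 1e-3, 0.9, 0.999 and 1e-4, and applies one AdamW step (Adam with decoupled weight decay) to a scalar using those nominal values. This is an illustrative assumption, not the TensorFlow optimizer: libraries differ on details such as whether weight decay is scaled by the learning rate (it is in this sketch) and exactly where `epsilon` enters the update, and the helpers `as_float32` and `adamw_step` are hypothetical.

```python
import math
import struct

# Nominal hyperparameters (assumed) behind the float32 values listed above.
LEARNING_RATE = 1e-3
BETA_1 = 0.9
BETA_2 = 0.999
EPSILON = 1e-7
WEIGHT_DECAY = 1e-4


def as_float32(x: float) -> float:
    """Round a Python float (float64) to the nearest float32 and back."""
    return struct.unpack("f", struct.pack("f", x))[0]


def adamw_step(w, g, m, v, t, lr=LEARNING_RATE, wd=WEIGHT_DECAY):
    """Apply one AdamW update to scalar parameter w with gradient g at step t (1-based).

    Returns the updated (w, m, v).
    """
    m = BETA_1 * m + (1 - BETA_1) * g      # running mean of gradients
    v = BETA_2 * v + (1 - BETA_2) * g * g  # running mean of squared gradients
    m_hat = m / (1 - BETA_1 ** t)          # bias-corrected estimates
    v_hat = v / (1 - BETA_2 ** t)
    w = w - lr * wd * w                    # decoupled weight decay, scaled by lr in this sketch
    w = w - lr * m_hat / (math.sqrt(v_hat) + EPSILON)
    return w, m, v


print(as_float32(LEARNING_RATE))  # 0.0010000000474974513
print(as_float32(WEIGHT_DECAY))   # 9.999999747378752e-05

w, m, v = adamw_step(1.0, g=0.5, m=0.0, v=0.0, t=1)
print(round(w, 7))                # 0.9989999
```

Because the step is normalized by the square root of `v_hat`, the first update moves `w` by roughly `lr` regardless of the gradient's scale, while the decay term removes a further `lr * wd * w`, which is 1e-7 here.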