Potential discrepancy between the Hugging Face and timm weights for google/vit-base-patch16-224-in21k

#7
by turtleman - opened

If I understand correctly, the google/vit-base-patch16-224-in21k corresponds to timm's vit_base_patch16_224.augreg_in21k.

However, I found that Hugging Face's model has a Pooler layer that timm's doesn't have.

Besides, I checked some specific weights, e.g.:

  1. Hugging Face: embeddings.patch_embeddings.projection.weight
  2. timm: patch_embed.proj.weight

They are not equal.

Another minor difference could be the eps of `LayerNorm`.

I'm wondering if the correct weights have been converted.
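For reference, this is roughly how I compared them (a minimal sketch; the key names are the ones quoted above):

```python
import timm
import torch
from transformers import ViTModel

hf = ViTModel.from_pretrained("google/vit-base-patch16-224-in21k")
tm = timm.create_model("vit_base_patch16_224.augreg_in21k", pretrained=True)

hf_w = hf.state_dict()["embeddings.patch_embeddings.projection.weight"]
tm_w = tm.state_dict()["patch_embed.proj.weight"]
print(torch.allclose(hf_w, tm_w))   # False
print(hf.config.layer_norm_eps)     # 1e-12 in the HF config
```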

Google org

cc @nielsr

Hi,

The weights of this model correspond to the original "vit_base_patch16_224_21k" checkpoint which Google released in their JAX repository and which was ported to timm (and is now deprecated in timm, as seen here). They do not correspond to the "augreg" checkpoint ("augreg" is from a follow-up paper called "How to train your ViT?"). It looks like that legacy name in timm now points to the augreg version (which is a better-trained version). cc'ing @rwightman for confirmation.

This is the conversion script that was used to convert the checkpoint: https://github.com/huggingface/transformers/blob/main/src/transformers/models/vit/convert_vit_timm_to_pytorch.py. It would definitely be great to convert newer, more capable ViT checkpoints.

All transformers ViT models I'm aware of are the originals from "An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale".

timm includes both the originals and the best checkpoints from How to train your ViT?. The script @nielsr mentioned can be used to convert the 'augreg' models.

It should be noted that the original '21k' vit models have a zero'd out classifier head. They cannot be used for classification.

The augreg 21k weights have a valid classifier head, they can be used for classification, and in timm have appropriate class mappings (try the classification widget https://huggingface.co/timm/vit_large_patch16_224.augreg_in21k).

timm deprecations are supposed to just remap the old naming, but for some reason I lost the original base 21k model along the way (it was not deprecated, it just went missing), whoops.

  • the augreg 21k model is vit_large_patch16_224.augreg_in21k
  • the original 21k model, which matches the equivalent HF one (except that I removed the pre_logits, which are no longer being used), is vit_large_patch32_224.orig_in21k
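For example, both resolve via timm's multi-weight tags (a minimal sketch):

```python
import timm

# Tags as listed above: augreg vs. the original 21k release.
augreg = timm.create_model("vit_large_patch16_224.augreg_in21k", pretrained=True)
orig = timm.create_model("vit_large_patch32_224.orig_in21k", pretrained=True)
```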

I also fine-tuned several of the 21k 'How to train your ViT' models with better recipes; those are available under augreg2 tags:
https://github.com/huggingface/pytorch-image-models/blob/ef72c3cd470dd67836eebf95ec567199c890a6a2/timm/models/vision_transformer.py#L1048-L1052
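Something like this should find and load them (a sketch; the exact tag below is an assumed example, use whatever list_models returns in your timm version):

```python
import timm

# Discover the augreg2-tagged weights available in your timm version.
print(timm.list_models("*augreg2*", pretrained=True))

# Assumed example tag; replace with one of the names printed above.
model = timm.create_model("vit_base_patch16_224.augreg2_in21k_ft_in1k", pretrained=True)
```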

EDIT:
There are also ~50K other checkpoints from the 'How to train your ViT?' paper that can be loaded directly (as .npz files) into the timm models; they're listed in this table, including fine-tunes on CIFAR, resisc, KITTI, Oxford Pets, etc.: gs://vit_models/augreg/index.csv

https://colab.research.google.com/github/google-research/vision_transformer/blob/main/vit_jax_augreg.ipynb#scrollTo=yy-cuGxyD6Xw
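Roughly, loading one of those .npz files into timm looks like this (a sketch; the filename below is just an example of the naming in index.csv, download the real file first, e.g. with gsutil, and make sure the architecture matches the checkpoint):

```python
import timm

# Example filename following the augreg naming scheme in index.csv (B/16, in21k).
npz_path = "B_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.0-sd_0.0.npz"

model = timm.create_model(
    "vit_base_patch16_224",
    pretrained=False,
    num_classes=21843,        # assumes an in21k head; adjust for the fine-tuned checkpoints
    checkpoint_path=npz_path,  # .npz paths are routed to the ViT-specific npz loader in recent timm
)
```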

Note that when augreg models came out, I originally replaced all overlapping original models with the new weights (old weights had an L32 and new did not, so that remained). Then, when I added multi-weight support I added some originals back, but forgot to add a few of the 21k back.

And yeah, the LN eps is wrong in most transformers ViT configs that I'm aware of; it should be 1e-6, not 1e-12. The impact on validation results is relatively small, but it can affect training stability, as 1e-6 is in the range that's okay for lower-precision training and 1e-12 is not.
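For what it's worth, the eps can be overridden when loading the transformers model (a minimal sketch):

```python
from transformers import ViTModel

# The checkpoint's config ships with layer_norm_eps=1e-12; override it at load
# time to match the original 1e-6 if desired.
model = ViTModel.from_pretrained("google/vit-base-patch16-224-in21k", layer_norm_eps=1e-6)
print(model.config.layer_norm_eps)  # 1e-6
```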

Wow, many thanks for your detailed answers, @lysandre @nielsr @rwightman !

Yeah, I can confirm that most (I'd say 99.9%) of the weights are the same between HF's google/vit-large-patch32-224-in21k and timm's vit_large_patch32_224.orig_in21k.

However, I'm still curious why HF's model has a pooler layer at the end while timm's doesn't. I understand that timm's last layer is the original classification head. HF's model card says:

"""
 However, the model does include the pre-trained pooler, which can be used for downstream tasks (such as image classification).
"""
(pooler): ViTPooler(
    (dense): Linear(in_features=1024, out_features=1024, bias=True)
    (activation): Tanh()
  )

I'm wondering

  1. Does timm include this layer? If not, how can I convert HF's weights to timm's? It looks like all of the original ViTs have this mismatch between HF and timm.
    (at a quick glance, this might be easy: directly borrow HF's weights into timm)
  2. Is there a quick way to convert HF's ViT base model to timm's, since there is no timm support for the base ViT models? (A rough sketch of what I have in mind is below.)
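For question 2, this is the kind of key remapping I have in mind (illustrative only; the key names follow the current transformers/timm naming and may differ between versions, and the HF pooler is simply dropped since timm has no counterpart):

```python
import re
import torch

def hf_vit_to_timm(hf_sd):
    """Rename a transformers ViTModel state dict to timm VisionTransformer keys.
    Rough sketch only -- verify against the conversion script linked above."""
    out, qkv = {}, {}
    for k, v in hf_sd.items():
        if k.startswith("pooler."):
            continue  # timm ViTs have no pooler / pre-logits layer
        k = k.replace("embeddings.patch_embeddings.projection", "patch_embed.proj")
        k = k.replace("embeddings.cls_token", "cls_token")
        k = k.replace("embeddings.position_embeddings", "pos_embed")
        k = k.replace("encoder.layer.", "blocks.")
        k = k.replace("attention.attention.", "attn.")
        k = k.replace("attention.output.dense", "attn.proj")
        k = k.replace("layernorm_before", "norm1")
        k = k.replace("layernorm_after", "norm2")
        k = k.replace("intermediate.dense", "mlp.fc1")
        k = k.replace("output.dense", "mlp.fc2")
        k = re.sub(r"^layernorm\.", "norm.", k)  # final LayerNorm
        # HF keeps separate q/k/v projections; timm fuses them into a single qkv layer.
        m = re.match(r"(blocks\.\d+\.attn)\.(query|key|value)\.(weight|bias)$", k)
        if m:
            qkv.setdefault((m.group(1), m.group(3)), {})[m.group(2)] = v
        else:
            out[k] = v
    for (prefix, kind), p in qkv.items():
        out[f"{prefix}.qkv.{kind}"] = torch.cat([p["query"], p["key"], p["value"]], dim=0)
    return out
```

Then something like `timm_model.load_state_dict(hf_vit_to_timm(hf_model.state_dict()), strict=False)` should get most of the way there (strict=False because the HF model has no classifier head).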

Thanks a lot!

@turtleman

For the transformers ViTPooler, it is indeed a bit confusing. It looks like the 'pooler' in this case, in addition to selecting the class token, applies the 'pre-logits' representation part of the MLP head (nn.Linear + tanh activation) that is described in the original paper and was used for pretraining (https://github.com/google-research/vision_transformer/blob/10ffdebb01aa40714b175a7c3be700c872efb2f4/vit_jax/models_vit.py#L291-L295).

Confusingly, the ViTForImageClassification head also 'pools' in that it does class token selection again (https://github.com/huggingface/transformers/blob/main/src/transformers/models/vit/modeling_vit.py#L805). The pooled output is also not used for classification; it exists only as a separate tuple/dict output ('pooler_output'). If the classifier for those original 21k models had been valid, it would have been using the full MLP head.
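A minimal sketch of what that pooler actually computes (just Linear + tanh on the class token of ViTModel's output):

```python
import torch
from transformers import ViTModel

model = ViTModel.from_pretrained("google/vit-large-patch32-224-in21k").eval()

with torch.no_grad():
    out = model(pixel_values=torch.randn(1, 3, 224, 224))
    # Select the class token, then apply the pooler's dense layer + tanh manually.
    cls_token = out.last_hidden_state[:, 0]
    manual = torch.tanh(model.pooler.dense(cls_token))

print(torch.allclose(manual, out.pooler_output, atol=1e-6))  # expected: True
```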

In timm, I removed the hidden MLP representation altogether since there were no valid classification models that used it or would use it.

  • Those pre-logits were only ever used for the original 21k pretrained (and unreleased JFT) models that had the classifier zeroed out
  • All of the original fine-tuned models removed the nn.Linear + tanh activation before fine-tuning
  • Prior to collaborating on the 'How to train your ViT' paper, I suggested the pre-logits should be removed altogether, as my experiments showed they appeared to worsen in1k pretraining
  • The Google authors checked that this had little to no impact on 21k pretraining, so all augreg models, 21k and fine-tunes, were trained without the MLP head

FYI, one of my tasks today: I'm uploading the missing orig_21k models for timm. I'm also explicitly removing the empty head (num_classes=0) to avoid future confusion, since having it there with zeroed-out weights is misleading.

EDIT: The missing models are there now, e.g. B/16 https://huggingface.co/timm/vit_base_patch16_224.orig_in21k ... with an update on the main branch of timm for them. I do recommend using the augreg 21k weights though, they're much better. Also, the vit_xxxx_patchxx_clip_xxx.[laion2b/openai/dfn/metaclip] image tower weights from CLIP models, or vit_xxx_patchxx_siglip_xxx.webli from SigLIP, are even stronger for feature and fine-tune use.
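For example, one of those towers can be used purely as a feature extractor (a sketch; the tag below is an assumed example, check timm.list_models('vit_*_clip_*', pretrained=True) for what's available):

```python
import timm
import torch

# num_classes=0 drops the classifier head and returns pooled features.
model = timm.create_model("vit_base_patch16_clip_224.openai", pretrained=True, num_classes=0)
features = model(torch.randn(1, 3, 224, 224))
```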
