
PyTorch weights for the Kornia ViT, converted from the original Google JAX vision_transformer repo.

```python
from kornia.contrib import VisionTransformer

vit_model = VisionTransformer.from_config('vit_b/32', pretrained=True)
...
```
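
A minimal inference sketch. The 224x224 input resolution and the output shape are assumptions based on the standard ViT-B/32 configuration (7x7 = 49 patch tokens plus one class token, 768-dim embeddings), not something this card specifies:

```python
import torch
from kornia.contrib import VisionTransformer

vit_model = VisionTransformer.from_config('vit_b/32', pretrained=True)
vit_model.eval()

# Assumed default input size for ViT-B/32: 224x224 RGB.
img = torch.randn(1, 3, 224, 224)
with torch.no_grad():
    out = vit_model(img)  # expected token sequence shape: (1, 50, 768)
```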

Original weights from AugReg, as recommended by the google-research vision_transformer repo: these weights are the AugReg ViT-B/32 checkpoint pretrained on ImageNet-21k.

Weights converted to PyTorch for the Kornia ViT implementation (by @gau-nernst in kornia/kornia#2786).

The conversion function that translates the JAX checkpoint into a Kornia-compatible PyTorch state dict:
```python
import numpy as np
import torch


def convert_jax_checkpoint(np_state_dict: dict[str, np.ndarray]) -> dict[str, torch.Tensor]:
    def get_weight(key: str) -> torch.Tensor:
        return torch.from_numpy(np_state_dict[key])

    state_dict = dict()
    state_dict["patch_embedding.cls_token"] = get_weight("cls")
    # JAX conv kernels are (H, W, in_channels, out_channels);
    # PyTorch expects (out_channels, in_channels, H, W).
    state_dict["patch_embedding.backbone.weight"] = get_weight("embedding/kernel").permute(3, 2, 0, 1)
    state_dict["patch_embedding.backbone.bias"] = get_weight("embedding/bias")
    state_dict["patch_embedding.positions"] = get_weight("Transformer/posembed_input/pos_embedding").squeeze(0)

    # Iterate over encoder blocks until the checkpoint runs out of them.
    for i in range(100):
        prefix1 = f"encoder.blocks.{i}"
        prefix2 = f"Transformer/encoderblock_{i}"

        if f"{prefix2}/LayerNorm_0/scale" not in np_state_dict:
            break

        state_dict[f"{prefix1}.0.fn.0.weight"] = get_weight(f"{prefix2}/LayerNorm_0/scale")
        state_dict[f"{prefix1}.0.fn.0.bias"] = get_weight(f"{prefix2}/LayerNorm_0/bias")

        # JAX stores query/key/value as separate (embed_dim, num_heads, head_dim)
        # kernels; Kornia uses a single fused qkv linear layer.
        mha_prefix = f"{prefix2}/MultiHeadDotProductAttention_1"
        qkv_weight = [get_weight(f"{mha_prefix}/{x}/kernel") for x in ["query", "key", "value"]]
        qkv_bias = [get_weight(f"{mha_prefix}/{x}/bias") for x in ["query", "key", "value"]]
        state_dict[f"{prefix1}.0.fn.1.qkv.weight"] = torch.cat(qkv_weight, 1).flatten(1).T
        state_dict[f"{prefix1}.0.fn.1.qkv.bias"] = torch.cat(qkv_bias, 0).flatten()
        # Output projection kernel is (num_heads, head_dim, embed_dim).
        state_dict[f"{prefix1}.0.fn.1.projection.weight"] = get_weight(f"{mha_prefix}/out/kernel").flatten(0, 1).T
        state_dict[f"{prefix1}.0.fn.1.projection.bias"] = get_weight(f"{mha_prefix}/out/bias")

        state_dict[f"{prefix1}.1.fn.0.weight"] = get_weight(f"{prefix2}/LayerNorm_2/scale")
        state_dict[f"{prefix1}.1.fn.0.bias"] = get_weight(f"{prefix2}/LayerNorm_2/bias")
        # JAX Dense kernels are (in, out); PyTorch Linear weights are (out, in).
        state_dict[f"{prefix1}.1.fn.1.0.weight"] = get_weight(f"{prefix2}/MlpBlock_3/Dense_0/kernel").T
        state_dict[f"{prefix1}.1.fn.1.0.bias"] = get_weight(f"{prefix2}/MlpBlock_3/Dense_0/bias")
        state_dict[f"{prefix1}.1.fn.1.3.weight"] = get_weight(f"{prefix2}/MlpBlock_3/Dense_1/kernel").T
        state_dict[f"{prefix1}.1.fn.1.3.bias"] = get_weight(f"{prefix2}/MlpBlock_3/Dense_1/bias")

    state_dict["norm.weight"] = get_weight("Transformer/encoder_norm/scale")
    state_dict["norm.bias"] = get_weight("Transformer/encoder_norm/bias")
    return state_dict
```
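
A sketch of how the function might be applied end to end. AugReg checkpoints from the google-research/vision_transformer repo ship as .npz archives with the slash-separated keys used above; the filename here is hypothetical, and strict loading assumes the converted keys cover the Kornia model exactly:

```python
import numpy as np
from kornia.contrib import VisionTransformer

# Hypothetical local path to an AugReg .npz checkpoint.
jax_ckpt = np.load("ViT-B_32.npz")
state_dict = convert_jax_checkpoint(dict(jax_ckpt))

vit_model = VisionTransformer.from_config('vit_b/32')
vit_model.load_state_dict(state_dict)
```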