---
license: apache-2.0
pipeline_tag: image-classification
---

PyTorch weights for Kornia's ViT implementation, converted from Google's original JAX vision_transformer repository.

Original weights: https://github.com/google-research/vision_transformer. This checkpoint is based on the [original AugReg ViT-B/16 pretrained on ImageNet-21k](https://storage.googleapis.com/vit_models/augreg/B_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.0-sd_0.0.npz).

The weights were converted to PyTorch for the Kornia ViT implementation by [@gau-nernst](https://github.com/gau-nernst) in [kornia/kornia#2786](https://github.com/kornia/kornia/pull/2786#discussion_r1482339811).

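
The conversion function reads flat, `/`-separated keys such as `Transformer/encoder_norm/scale`. Assuming the checkpoint is read directly with `numpy.load`, these are the array names stored in the linked `.npz` file. A minimal sketch of loading it into the `dict[str, np.ndarray]` that `convert_jax_checkpoint` expects; the helper name `load_jax_npz` is hypothetical and not part of Kornia:

```python
import numpy as np


def load_jax_npz(path: str) -> dict[str, np.ndarray]:
    """Hypothetical helper: read a JAX/Flax ``.npz`` checkpoint into a plain dict.

    Assumes arrays are stored under flat "/"-separated names
    (e.g. "Transformer/encoder_norm/scale"). ``np.load`` returns a lazy
    ``NpzFile``; copying every array out lets the file be closed afterwards.
    """
    with np.load(path) as npz:
        return {key: npz[key] for key in npz.files}
```

With the downloaded `B_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.0-sd_0.0.npz`, the result of `load_jax_npz(...)` would then be passed to `convert_jax_checkpoint`.
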
<details>
  
  <summary>JAX checkpoint conversion function</summary>

  ```python
  import numpy as np
  import torch


  def convert_jax_checkpoint(np_state_dict: dict[str, np.ndarray]) -> dict[str, torch.Tensor]:
      """Map flat JAX/Flax ViT checkpoint arrays onto Kornia ViT state_dict keys."""

      def get_weight(key: str) -> torch.Tensor:
          return torch.from_numpy(np_state_dict[key])

      state_dict = {}
      state_dict["patch_embedding.cls_token"] = get_weight("cls")
      # Flax conv kernels are (H, W, in, out); PyTorch Conv2d weights are (out, in, H, W).
      state_dict["patch_embedding.backbone.weight"] = get_weight("embedding/kernel").permute(3, 2, 0, 1)
      state_dict["patch_embedding.backbone.bias"] = get_weight("embedding/bias")
      # Drop the leading batch axis of the position embeddings.
      state_dict["patch_embedding.positions"] = get_weight("Transformer/posembed_input/pos_embedding").squeeze(0)

      # Visit encoder blocks in order until the checkpoint has no block i (100 is only an upper bound).
      for i in range(100):
          prefix1 = f"encoder.blocks.{i}"  # Kornia key prefix
          prefix2 = f"Transformer/encoderblock_{i}"  # JAX key prefix

          if f"{prefix2}/LayerNorm_0/scale" not in np_state_dict:
              break

          # Attention sub-block: pre-attention LayerNorm.
          state_dict[f"{prefix1}.0.fn.0.weight"] = get_weight(f"{prefix2}/LayerNorm_0/scale")
          state_dict[f"{prefix1}.0.fn.0.bias"] = get_weight(f"{prefix2}/LayerNorm_0/bias")

          # Fuse the separate query/key/value kernels, each (D, heads, head_dim) for hidden size D,
          # into one (3 * D, D) Linear weight; the (heads, head_dim) biases become one (3 * D,) vector.
          mha_prefix = f"{prefix2}/MultiHeadDotProductAttention_1"
          qkv_weight = [get_weight(f"{mha_prefix}/{x}/kernel") for x in ["query", "key", "value"]]
          qkv_bias = [get_weight(f"{mha_prefix}/{x}/bias") for x in ["query", "key", "value"]]
          state_dict[f"{prefix1}.0.fn.1.qkv.weight"] = torch.cat(qkv_weight, 1).flatten(1).T
          state_dict[f"{prefix1}.0.fn.1.qkv.bias"] = torch.cat(qkv_bias, 0).flatten()
          # Output kernel (heads, head_dim, D) -> (D, D), transposed to PyTorch's (out, in) layout.
          state_dict[f"{prefix1}.0.fn.1.projection.weight"] = get_weight(f"{mha_prefix}/out/kernel").flatten(0, 1).T
          state_dict[f"{prefix1}.0.fn.1.projection.bias"] = get_weight(f"{mha_prefix}/out/bias")

          # MLP sub-block: pre-MLP LayerNorm, then two Dense layers (Flax kernels are (in, out), hence .T).
          state_dict[f"{prefix1}.1.fn.0.weight"] = get_weight(f"{prefix2}/LayerNorm_2/scale")
          state_dict[f"{prefix1}.1.fn.0.bias"] = get_weight(f"{prefix2}/LayerNorm_2/bias")
          state_dict[f"{prefix1}.1.fn.1.0.weight"] = get_weight(f"{prefix2}/MlpBlock_3/Dense_0/kernel").T
          state_dict[f"{prefix1}.1.fn.1.0.bias"] = get_weight(f"{prefix2}/MlpBlock_3/Dense_0/bias")
          state_dict[f"{prefix1}.1.fn.1.3.weight"] = get_weight(f"{prefix2}/MlpBlock_3/Dense_1/kernel").T
          state_dict[f"{prefix1}.1.fn.1.3.bias"] = get_weight(f"{prefix2}/MlpBlock_3/Dense_1/bias")

      # Final encoder LayerNorm.
      state_dict["norm.weight"] = get_weight("Transformer/encoder_norm/scale")
      state_dict["norm.bias"] = get_weight("Transformer/encoder_norm/bias")
      return state_dict
  ```
</details>
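
The reshapes in `convert_jax_checkpoint` follow from layout differences between Flax and PyTorch layers. The NumPy sketch below uses small random arrays (toy sizes, far smaller than a real ViT) to check that the same transforms reproduce the Flax computations for the fused query/key/value `Linear`, the output projection, and the patch-embedding convolution evaluated on a single patch. It assumes Flax's usual kernel layouts (`(D, heads, head_dim)` for query/key/value, `(heads, head_dim, D)` for the output, HWIO for convolutions) and that the model splits its fused QKV output in `(q|k|v, head, head_dim)` order; it does not import Kornia or PyTorch.

```python
import numpy as np

rng = np.random.default_rng(0)

# Toy sizes, chosen for illustration only.
D, H = 8, 2        # hidden size, number of attention heads
Dh = D // H        # per-head size
N = 5              # number of tokens
P, C = 4, 3        # patch size, input channels

x = rng.standard_normal((N, D))

# --- Fused QKV projection ---
# Assumed Flax layout: query/key/value kernels are (D, H, Dh), biases are (H, Dh).
kernels = [rng.standard_normal((D, H, Dh)) for _ in range(3)]
biases = [rng.standard_normal((H, Dh)) for _ in range(3)]
# Reference: contract over D per head, then add the bias.
ref_qkv = [np.einsum("nd,dhk->nhk", x, kern) + bias for kern, bias in zip(kernels, biases)]

# Same steps as the converter: concatenate on axis 1, flatten trailing axes, transpose.
qkv_w = np.concatenate(kernels, axis=1).reshape(D, 3 * D).T  # (3D, D) = (out, in)
qkv_b = np.concatenate(biases, axis=0).reshape(-1)           # (3D,)
fused = x @ qkv_w.T + qkv_b                                  # what nn.Linear computes: x W^T + b
# Fused columns are ordered (q|k|v, head, head_dim).
q, k, v = fused.reshape(N, 3, H, Dh).transpose(1, 0, 2, 3)

# --- Output projection ---
# Assumed Flax layout: out kernel is (H, Dh, D); flatten the head axes, then transpose.
out_k = rng.standard_normal((H, Dh, D))
out_b = rng.standard_normal(D)
ref_out = np.einsum("nhk,hkd->nd", v, out_k) + out_b
proj_w = out_k.reshape(H * Dh, D).T                          # (out, in)
out = v.reshape(N, H * Dh) @ proj_w.T + out_b

# --- Patch-embedding convolution on one patch ---
# Flax conv kernels are HWIO; transpose(3, 2, 0, 1) gives PyTorch's OIHW.
conv_k = rng.standard_normal((P, P, C, D))
conv_w = conv_k.transpose(3, 2, 0, 1)                        # (D, C, P, P)
patch = rng.standard_normal((P, P, C))                       # one HWC patch
ref_patch = np.einsum("ijc,ijco->o", patch, conv_k)
patch_out = np.einsum("cij,ocij->o", patch.transpose(2, 0, 1), conv_w)

for name, expected, got in [("query", ref_qkv[0], q), ("key", ref_qkv[1], k),
                            ("value", ref_qkv[2], v), ("out", ref_out, out),
                            ("conv", ref_patch, patch_out)]:
    print(name, np.allclose(expected, got))
```

Each line prints `True`: transposing a Flax `(in, out)` kernel and applying `x @ W.T` is the same linear map, and the concatenate-then-flatten order keeps query, key, and value in separate contiguous `D`-wide slices of the fused output.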