ikala-ray committed
Commit 3055d59
1 Parent(s): ca0556b

Upload transform.py

Files changed (1)
  1. transform.py +100 -0
transform.py ADDED
@@ -0,0 +1,100 @@
import torch
from safetensors import safe_open


def transform(open_clip_safe_tensor_path):
    """Remap an open_clip / SigLIP SafeTensors checkpoint into HF CLIP-style state-dict keys."""
    tensors = {}
    with safe_open(open_clip_safe_tensor_path, framework="pt", device=0) as f:
        metadata = f.metadata()
        for k in f.keys():
            ignore_tensor = False
            first_prefix = k.replace('visual.', 'vision_model.').replace('text.', 'text_model.')
            new_key = first_prefix.replace('.trunk.', '.encoder.')
            new_key = new_key.replace('.blocks.', '.layers.')
            new_key = new_key.replace('.transformer.resblocks.', '.encoder.layers.')
            if 'vision' in new_key:
                new_key = new_key.replace('.self_attn.out_proj.', '.attn.proj.')
                new_key = new_key.replace('.norm', '.layer_norm')
                # mappings extracted from timm code
                # ('mlp.c_fc', 'mlp.fc1'), ('mlp.c_proj', 'mlp.fc2')
                new_key = new_key.replace('.proj.c_fc', '.mlp.fc1')
                new_key = new_key.replace('.proj.c_proj', '.mlp.fc2')
                new_key = new_key.replace('.attn.proj', '.self_attn.out_proj')
                if 'qkv' in new_key:
                    # split the fused qkv weight/bias into separate q/k/v projections
                    # (named q_, k_, v_ so the loop variable k is not shadowed)
                    qkv_weight = f.get_tensor(k)
                    q_, k_, v_ = torch.chunk(qkv_weight, 3, dim=0)
                    tensors[new_key.replace('.attn.qkv', '.self_attn.q_proj')] = q_.clone().detach()
                    tensors[new_key.replace('.attn.qkv', '.self_attn.k_proj')] = k_.clone().detach()
                    tensors[new_key.replace('.attn.qkv', '.self_attn.v_proj')] = v_.clone().detach()
                    ignore_tensor = True
                # ("patch_embed.proj.weight", "vit.embeddings.patch_embeddings.projection.weight"),
                # ("patch_embed.proj.bias", "vit.embeddings.patch_embeddings.projection.bias"),
                # ("pos_embed", "vit.embeddings.position_embeddings"),
                # ['vision_model.embeddings.patch_embedding.weight',
                #  'vision_model.post_layernorm.weight', 'vision_model.embeddings.position_embedding.weight',
                #  'vision_model.embeddings.class_embedding', 'vision_model.pre_layrnorm.weight',
                #  'vision_model.pre_layrnorm.bias', 'vision_model.post_layernorm.bias']
                # vision_model.encoder.layer_norm.bias
                # vision_model.encoder.layer_norm.weight
                # vision_model.encoder.patch_embed.proj.bias
                # vision_model.encoder.patch_embed.proj.weight
                # vision_model.encoder.pos_embed
                replacement_keys = [
                    ('vision_model.encoder.patch_embed.proj.weight', 'vision_model.embeddings.patch_embedding.weight'),
                    ('vision_model.encoder.pos_embed', 'vision_model.embeddings.position_embedding.weight'),
                    # the patch-embedding bias is mapped onto pre_layrnorm.bias (HF CLIP's patch embedding has no bias)
                    ('vision_model.encoder.patch_embed.proj.bias', 'vision_model.pre_layrnorm.bias'),
                    ('vision_model.encoder.layer_norm.bias', 'vision_model.post_layernorm.bias'),
                    ('vision_model.encoder.layer_norm.weight', 'vision_model.post_layernorm.weight'),
                ]
                for old_, new_ in replacement_keys:
                    if old_ in new_key:
                        new_key = new_key.replace(old_, new_)
            elif 'text' in new_key:
                # text_model.encoder.layers.0.ln_1.bias ->
                #   text_model.encoder.layers.0.layer_norm1.bias
                # text_model.encoder.layers.1.mlp.c_fc.bias ->
                #   text_model.encoder.layers.1.mlp.fc1.bias
                new_key = new_key.replace('.ln_2.', '.layer_norm2.')
                new_key = new_key.replace('.ln_1.', '.layer_norm1.')
                new_key = new_key.replace('.mlp.c_fc', '.mlp.fc1')
                new_key = new_key.replace('.mlp.c_proj', '.mlp.fc2')
                new_key = new_key.replace('.attn.in_proj_', '.self_attn.qkv.')
                new_key = new_key.replace('.attn.out_proj', '.self_attn.out_proj')
                if 'qkv' in new_key:
                    # e.g. text_model.encoder.layers.0.self_attn.qkv.weight ->
                    #      text_model.encoder.layers.0.self_attn.{q,k,v}_proj.weight
                    qkv_weight = f.get_tensor(k)
                    q_, k_, v_ = torch.chunk(qkv_weight, 3, dim=0)
                    tensors[new_key.replace('.self_attn.qkv', '.self_attn.q_proj')] = q_.clone().detach()
                    tensors[new_key.replace('.self_attn.qkv', '.self_attn.k_proj')] = k_.clone().detach()
                    tensors[new_key.replace('.self_attn.qkv', '.self_attn.v_proj')] = v_.clone().detach()
                    ignore_tensor = True
                replacement_keys = [
                    ('text_model.positional_embedding', 'text_model.embeddings.position_embedding.weight'),
                    ('text_model.token_embedding.weight', 'text_model.embeddings.token_embedding.weight'),
                    ('text_model.ln_final.bias', 'text_model.final_layer_norm.bias'),
                    ('text_model.ln_final.weight', 'text_model.final_layer_norm.weight'),
                    ('text_model.text_projection.weight', 'text_projection.weight'),
                ]
                for old_, new_ in replacement_keys:
                    if old_ in new_key:
                        new_key = new_key.replace(old_, new_)
            if 'vision' in new_key and 'img_projector' in new_key:
                print(new_key)

            if ignore_tensor:
                continue
            tensors[new_key] = f.get_tensor(k)
            if 'vision_model.embeddings.position_embedding' in new_key:
                # append one extra position (a copy of the last row) so the number of positions
                # matches HF CLIP, which reserves an additional slot for the class token
                tensor = tensors[new_key][0]
                new_tensor = torch.zeros((tensor.shape[0] + 1, tensor.shape[1]))
                new_tensor[:tensor.shape[0], :] = tensor
                new_tensor[-1, :] = tensor[-1, :]
                tensors[new_key] = new_tensor
    # SigLIP doesn't seem to have a pre-norm layer, so make pre_layrnorm an identity for now
    tensors['vision_model.pre_layrnorm.weight'] = torch.ones(tensors['vision_model.pre_layrnorm.bias'].shape,
                                                             dtype=tensors['vision_model.pre_layrnorm.bias'].dtype,
                                                             device=tensors['vision_model.pre_layrnorm.bias'].device)
    # this wasn't used; initialise the class embedding randomly since open_clip has no equivalent
    tensors['vision_model.embeddings.class_embedding'] = torch.normal(mean=0.0, std=0.02,
                                                                      size=tensors['vision_model.pre_layrnorm.bias'].shape)
    return tensors