Upload transform.py
transform.py  (ADDED, +100 -0)
@@ -0,0 +1,100 @@
import torch
from safetensors import safe_open


def transform(open_clip_safe_tensor_path):
    """Remap an OpenCLIP/timm-style safetensors checkpoint onto HF CLIP-style state-dict keys."""
    tensors = {}
    with safe_open(open_clip_safe_tensor_path, framework="pt", device=0) as f:
        metadata = f.metadata()  # file-level metadata (unused below)
        for k in f.keys():
            ignore_tensor = False
            # Tower prefixes: visual.* -> vision_model.*, text.* -> text_model.*
            first_prefix = k.replace('visual.', 'vision_model.').replace('text.', 'text_model.')
            new_key = first_prefix.replace('.trunk.', '.encoder.')
            new_key = new_key.replace('.blocks.', '.layers.')
            new_key = new_key.replace('.transformer.resblocks.', '.encoder.layers.')
            if 'vision' in new_key:
                new_key = new_key.replace('.self_attn.out_proj.', '.attn.proj.')
                new_key = new_key.replace('.norm', '.layer_norm')
                # mappings extracted from timm code:
                # ('mlp.c_fc', 'mlp.fc1'), ('mlp.c_proj', 'mlp.fc2')
                new_key = new_key.replace('.proj.c_fc', '.mlp.fc1')
                new_key = new_key.replace('.proj.c_proj', '.mlp.fc2')
                new_key = new_key.replace('.attn.proj', '.self_attn.out_proj')
                if 'qkv' in new_key:
                    # Fused qkv projection: split into separate q/k/v projections.
                    qkv_weight = f.get_tensor(k)
                    q_w, k_w, v_w = torch.chunk(qkv_weight, 3, dim=0)  # renamed to avoid shadowing the loop key `k`
                    tensors[new_key.replace('.attn.qkv', '.self_attn.q_proj')] = q_w.clone().detach()
                    tensors[new_key.replace('.attn.qkv', '.self_attn.k_proj')] = k_w.clone().detach()
                    tensors[new_key.replace('.attn.qkv', '.self_attn.v_proj')] = v_w.clone().detach()
                    ignore_tensor = True
                # ("patch_embed.proj.weight", "vit.embeddings.patch_embeddings.projection.weight"),
                # ("patch_embed.proj.bias", "vit.embeddings.patch_embeddings.projection.bias"),
                # ("pos_embed", "vit.embeddings.position_embeddings"),
                # ['vision_model.embeddings.patch_embedding.weight',
                #  'vision_model.post_layernorm.weight', 'vision_model.embeddings.position_embedding.weight',
                #  'vision_model.embeddings.class_embedding', 'vision_model.pre_layrnorm.weight',
                #  'vision_model.pre_layrnorm.bias', 'vision_model.post_layernorm.bias']
                # vision_model.encoder.layer_norm.bias
                # vision_model.encoder.layer_norm.weight
                # vision_model.encoder.patch_embed.proj.bias
                # vision_model.encoder.patch_embed.proj.weight
                # vision_model.encoder.pos_embed
                replacement_keys = [
                    ('vision_model.encoder.patch_embed.proj.weight', 'vision_model.embeddings.patch_embedding.weight'),
                    ('vision_model.encoder.pos_embed', 'vision_model.embeddings.position_embedding.weight'),
                    ('vision_model.encoder.patch_embed.proj.bias', 'vision_model.pre_layrnorm.bias'),
                    ('vision_model.encoder.layer_norm.bias', 'vision_model.post_layernorm.bias'),
                    ('vision_model.encoder.layer_norm.weight', 'vision_model.post_layernorm.weight'),
                ]
                for old_, new_ in replacement_keys:
                    if old_ in new_key:
                        new_key = new_key.replace(old_, new_)
            elif 'text' in new_key:
                # text_model.encoder.layers.0.ln_1.bias     -> text_model.encoder.layers.0.layer_norm1.bias
                # text_model.encoder.layers.1.mlp.c_fc.bias -> text_model.encoder.layers.1.mlp.fc1.bias
                new_key = new_key.replace('.ln_2.', '.layer_norm2.')
                new_key = new_key.replace('.ln_1.', '.layer_norm1.')
                new_key = new_key.replace('.mlp.c_fc', '.mlp.fc1')
                new_key = new_key.replace('.mlp.c_proj', '.mlp.fc2')
                new_key = new_key.replace('.attn.in_proj_', '.self_attn.qkv.')
                new_key = new_key.replace('.attn.out_proj', '.self_attn.out_proj')
                if 'qkv' in new_key:
                    # text_model.encoder.layers.0.self_attn.qkv.weight
                    #   -> text_model.encoder.layers.0.self_attn.{q,k,v}_proj.weight
                    qkv_weight = f.get_tensor(k)
                    q_w, k_w, v_w = torch.chunk(qkv_weight, 3, dim=0)
                    tensors[new_key.replace('.self_attn.qkv', '.self_attn.q_proj')] = q_w.clone().detach()
                    tensors[new_key.replace('.self_attn.qkv', '.self_attn.k_proj')] = k_w.clone().detach()
                    tensors[new_key.replace('.self_attn.qkv', '.self_attn.v_proj')] = v_w.clone().detach()
                    ignore_tensor = True
                replacement_keys = [
                    ('text_model.positional_embedding', 'text_model.embeddings.position_embedding.weight'),
                    ('text_model.token_embedding.weight', 'text_model.embeddings.token_embedding.weight'),
                    ('text_model.ln_final.bias', 'text_model.final_layer_norm.bias'),
                    ('text_model.ln_final.weight', 'text_model.final_layer_norm.weight'),
                    ('text_model.text_projection.weight', 'text_projection.weight'),
                ]
                for old_, new_ in replacement_keys:
                    if old_ in new_key:
                        new_key = new_key.replace(old_, new_)
            if 'vision' in new_key and 'img_projector' in new_key:
                print(new_key)  # debug: projector keys passed through unchanged

            if ignore_tensor:
                continue
            tensors[new_key] = f.get_tensor(k)
            if 'vision_model.embeddings.position_embedding' in new_key:
                # timm stores pos_embed as (1, num_patches, dim): drop the batch dim and append
                # one extra row (copied from the last position) for the class-token slot.
                tensor = tensors[new_key][0]
                new_tensor = torch.zeros((tensor.shape[0] + 1, tensor.shape[1]),
                                         dtype=tensor.dtype, device=tensor.device)
                new_tensor[:tensor.shape[0], :] = tensor
                new_tensor[-1, :] = tensor[-1, :]
                tensors[new_key] = new_tensor
    # siglip doesn't seem to have any pre norm layer so we have to make it identity for now
    tensors['vision_model.pre_layrnorm.weight'] = torch.ones(
        tensors['vision_model.pre_layrnorm.bias'].shape,
        dtype=tensors['vision_model.pre_layrnorm.bias'].dtype,
        device=tensors['vision_model.pre_layrnorm.bias'].device)
    # this wasn't used; fill in a dummy class embedding so the HF-style state dict is complete
    tensors['vision_model.embeddings.class_embedding'] = torch.normal(
        mean=0.0, std=0.02, size=tensors['vision_model.pre_layrnorm.bias'].shape)
    return tensors
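
A minimal usage sketch, not part of the uploaded file: it assumes `transform.py` is importable from the working directory and uses made-up placeholder paths (`open_clip_model.safetensors`, `clip_style_model.safetensors`) for the source checkpoint and the converted output, saved with safetensors' save_file.

    # hypothetical driver script; paths below are placeholders, not files from this repo
    from safetensors.torch import save_file

    from transform import transform

    if __name__ == "__main__":
        converted = transform("open_clip_model.safetensors")
        # save_file expects contiguous tensors on a single device
        converted = {name: t.contiguous().cpu() for name, t in converted.items()}
        save_file(converted, "clip_style_model.safetensors")
        print(f"wrote {len(converted)} tensors")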