import torch
from safetensors import safe_open

def transform(open_clip_safe_tensor_path):
    """Remap an open_clip/timm safetensors checkpoint into Hugging Face CLIP-style parameter names."""
    tensors = {}
    with safe_open(open_clip_safe_tensor_path, framework="pt", device=0) as f:
        metadata = f.metadata()
        for k in f.keys():
            ignore_tensor = False
            first_prefix = k.replace('visual.', 'vision_model.').replace('text.', 'text_model.')
            new_key = first_prefix.replace('.trunk.', '.encoder.')
            new_key = new_key.replace('.blocks.', '.layers.')
            new_key = new_key.replace('.transformer.resblocks.', '.encoder.layers.')
            if 'vision' in new_key:
                new_key = new_key.replace('.self_attn.out_proj.', '.attn.proj.')
                new_key = new_key.replace('.norm', '.layer_norm')
                # mappings extracted from timm code
                # ('mlp.c_fc', 'mlp.fc1'), ('mlp.c_proj', 'mlp.fc2')
                new_key = new_key.replace('.proj.c_fc', '.mlp.fc1')
                new_key = new_key.replace('.proj.c_proj', '.mlp.fc2')
                new_key = new_key.replace('.attn.proj', '.self_attn.out_proj')
                if 'qkv' in new_key:
                    # the fused qkv projection is split into separate q/k/v projections;
                    # avoid reusing `k`, which is the checkpoint key being iterated
                    qkv_weight = f.get_tensor(k)
                    q_w, k_w, v_w = torch.chunk(qkv_weight, 3, dim=0)
                    tensors[new_key.replace('.attn.qkv', '.self_attn.q_proj')] = q_w.clone().detach()
                    tensors[new_key.replace('.attn.qkv', '.self_attn.k_proj')] = k_w.clone().detach()
                    tensors[new_key.replace('.attn.qkv', '.self_attn.v_proj')] = v_w.clone().detach()
                    ignore_tensor = True
                # ("patch_embed.proj.weight", "vit.embeddings.patch_embeddings.projection.weight"),
                #             ("patch_embed.proj.bias", "vit.embeddings.patch_embeddings.projection.bias"),
                #             ("pos_embed", "vit.embeddings.position_embeddings"),
                # ['vision_model.embeddings.patch_embedding.weight', 
                #  'vision_model.post_layernorm.weight', 'vision_model.embeddings.position_embedding.weight', 
                # 'vision_model.embeddings.class_embedding', 'vision_model.pre_layrnorm.weight', 
                # 'vision_model.pre_layrnorm.bias', 'vision_model.post_layernorm.bias']
                #             vision_model.encoder.layer_norm.bias
                #             vision_model.encoder.layer_norm.weight
                #             vision_model.encoder.patch_embed.proj.bias
                #             vision_model.encoder.patch_embed.proj.weight
                #             vision_model.encoder.pos_embed
                replacement_keys = [
                    ('vision_model.encoder.patch_embed.proj.weight', 'vision_model.embeddings.patch_embedding.weight'), 
                    ('vision_model.encoder.pos_embed', 'vision_model.embeddings.position_embedding.weight'),
                    ('vision_model.encoder.patch_embed.proj.bias', 'vision_model.pre_layrnorm.bias'),
                    ('vision_model.encoder.layer_norm.bias', 'vision_model.post_layernorm.bias'),
                    ('vision_model.encoder.layer_norm.weight', 'vision_model.post_layernorm.weight'),
                ]
                for old_, new_ in replacement_keys:
                    if old_ in new_key:
                        new_key = new_key.replace(old_, new_)
            elif 'text' in new_key:
                # examples of the renames performed below:
                #   text_model.encoder.layers.0.ln_1.bias     -> text_model.encoder.layers.0.layer_norm1.bias
                #   text_model.encoder.layers.1.mlp.c_fc.bias -> text_model.encoder.layers.1.mlp.fc1.bias
                new_key = new_key.replace('.ln_2.', '.layer_norm2.')
                new_key = new_key.replace('.ln_1.', '.layer_norm1.')
                new_key = new_key.replace('.mlp.c_fc', '.mlp.fc1')
                new_key = new_key.replace('.mlp.c_proj', '.mlp.fc2')
                new_key = new_key.replace('.attn.in_proj_', '.self_attn.qkv.')
                new_key = new_key.replace('.attn.out_proj', '.self_attn.out_proj')
                if 'qkv' in new_key:
                    # e.g. text_model.encoder.layers.0.self_attn.qkv.weight is split into
                    #      text_model.encoder.layers.0.self_attn.{q,k,v}_proj.weight;
                    # avoid reusing `k`, which is the checkpoint key being iterated
                    qkv_weight = f.get_tensor(k)
                    q_w, k_w, v_w = torch.chunk(qkv_weight, 3, dim=0)
                    tensors[new_key.replace('.self_attn.qkv', '.self_attn.q_proj')] = q_w.clone().detach()
                    tensors[new_key.replace('.self_attn.qkv', '.self_attn.k_proj')] = k_w.clone().detach()
                    tensors[new_key.replace('.self_attn.qkv', '.self_attn.v_proj')] = v_w.clone().detach()
                    ignore_tensor = True
                replacement_keys = [
                    ('text_model.positional_embedding', 'text_model.embeddings.position_embedding.weight'),
                    ('text_model.token_embedding.weight', 'text_model.embeddings.token_embedding.weight'),
                    ('text_model.ln_final.bias', 'text_model.final_layer_norm.bias'),
                    ('text_model.ln_final.weight', 'text_model.final_layer_norm.weight'),                
                    ('text_model.text_projection.weight', 'text_projection.weight'),
                ]
                for old_, new_ in replacement_keys:
                    if old_ in new_key:
                        new_key = new_key.replace(old_, new_)
            # debug: surface any vision projector keys that are not remapped above
            if 'vision' in new_key and 'img_projector' in new_key:
                print(new_key)

            if ignore_tensor:
                continue
            tensors[new_key] = f.get_tensor(k)
            if 'vision_model.embeddings.position_embedding' in new_key:
                # open_clip stores pos_embed with a leading batch dim, (1, num_patches, dim);
                # HF CLIP expects (num_patches + 1, dim), with an extra row for the class token,
                # so drop the batch dim and append a copy of the last position embedding
                tensor = tensors[new_key][0]
                new_tensor = torch.zeros((tensor.shape[0] + 1, tensor.shape[1]),
                                         dtype=tensor.dtype, device=tensor.device)
                new_tensor[:tensor.shape[0], :] = tensor
                new_tensor[-1, :] = tensor[-1, :]
                tensors[new_key] = new_tensor
    # siglip doesn't seem to have a pre-norm layer, so make pre_layrnorm identity-like for now (weights of ones)
    tensors['vision_model.pre_layrnorm.weight'] = torch.ones(
        tensors['vision_model.pre_layrnorm.bias'].shape,
        dtype=tensors['vision_model.pre_layrnorm.bias'].dtype,
        device=tensors['vision_model.pre_layrnorm.bias'].device)
    # the class embedding isn't in the open_clip checkpoint (and isn't really used), so initialise it randomly
    tensors['vision_model.embeddings.class_embedding'] = torch.normal(
        mean=0.0, std=0.02, size=tensors['vision_model.pre_layrnorm.bias'].shape)
    return tensors
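

# Usage sketch (not part of the original script): the checkpoint path, output path, and
# CLIPVisionConfig hyperparameters below are hypothetical placeholders and must match the
# checkpoint being converted. It shows how the converted dict could be written back out
# with safetensors and loaded into a Hugging Face CLIP vision tower as a sanity check.
if __name__ == "__main__":
    from safetensors.torch import save_file
    from transformers import CLIPVisionConfig, CLIPVisionModel

    converted = transform("open_clip_model.safetensors")  # hypothetical input path

    # safetensors expects contiguous tensors; move everything to CPU before writing
    save_file({name: t.contiguous().cpu() for name, t in converted.items()},
              "converted_hf_clip.safetensors")  # hypothetical output path

    # instantiate a CLIP vision model with assumed hyperparameters and load the vision keys
    config = CLIPVisionConfig(hidden_size=1024, intermediate_size=4096,
                              num_hidden_layers=24, num_attention_heads=16,
                              image_size=224, patch_size=14)  # hypothetical values
    model = CLIPVisionModel(config)
    vision_state = {name: t.cpu() for name, t in converted.items()
                    if name.startswith('vision_model.')}
    missing, unexpected = model.load_state_dict(vision_state, strict=False)
    print('missing keys:', missing)
    print('unexpected keys:', unexpected)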