# File size: 10,061 Bytes
# 19e72b3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
import fire
from safetensors.torch import save_file
import os


def save_model_in_chunks(state_dict, directory, num_parts):
    """Split ``state_dict`` into roughly equal-sized safetensors shards.

    Tensors are assigned greedily in iteration order. A new shard is started
    when adding the next tensor would push the current shard past
    ``ceil(total_bytes / num_parts)`` bytes. At most ``num_parts`` shards are
    created; the last shard absorbs any overflow. A single tensor larger than
    the per-shard budget never produces an empty shard.

    Shards are named ``model-XXXXX-of-YYYYY.safetensors``, where ``YYYYY`` is
    the number of shards actually written. This can be smaller than
    ``num_parts`` when a few large tensors dominate. A
    ``model.safetensors.index.json`` file mapping each key to its shard is
    written alongside.

    Args:
        state_dict: mapping of parameter name -> tensor (anything exposing
            ``nelement()`` / ``element_size()`` and accepted by ``save_file``).
        directory: output directory; created if missing.
        num_parts: maximum number of shards; must be >= 1.

    Raises:
        ValueError: if ``num_parts`` < 1.
    """
    import json

    if num_parts < 1:
        raise ValueError(f"num_parts must be >= 1, got {num_parts}")
    if not state_dict:
        return  # nothing to write

    os.makedirs(directory, exist_ok=True)

    def nbytes(tensor):
        return tensor.nelement() * tensor.element_size()

    total_size = sum(nbytes(t) for t in state_dict.values())
    max_size = -(-total_size // num_parts)  # ceil(total_size / num_parts)

    # Group keys first so the real shard count is known before naming files.
    shards = [{}]
    current_size = 0
    for key, tensor in state_dict.items():
        tensor_size = nbytes(tensor)
        # Only open a new shard if the current one already holds something,
        # otherwise an oversized first tensor would yield an empty file.
        if shards[-1] and current_size + tensor_size > max_size and len(shards) < num_parts:
            shards.append({})
            current_size = 0
        shards[-1][key] = tensor
        current_size += tensor_size

    num_written = len(shards)
    weight_map = {}
    for part_number, shard in enumerate(shards, start=1):
        filename = f"model-{part_number:05d}-of-{num_written:05d}.safetensors"
        # "format": "pt" metadata is what Hugging Face loaders check for.
        save_file(shard, os.path.join(directory, filename), metadata={"format": "pt"})
        weight_map.update(dict.fromkeys(shard, filename))

    # Index file lets sharded-checkpoint loaders locate each tensor.
    index = {"metadata": {"total_size": total_size}, "weight_map": weight_map}
    with open(os.path.join(directory, "model.safetensors.index.json"), "w") as f:
        json.dump(index, f, indent=2)

# Per-layer language-model tensors: (SAT key prefix, SAT key suffix, HF suffix).
# The text between prefix and suffix is the layer index; the HF key becomes
# ``model.layers.{idx}.{HF suffix}``.
_LANGUAGE_LAYER_RULES = (
    # mlp
    ('mixins.mlp.vision_dense_h_to_4h_list.', '.weight', 'mlp.vision_mlp.up_proj.weight'),
    ('mixins.mlp.vision_dense_4h_to_h_list.', '.weight', 'mlp.vision_mlp.down_proj.weight'),
    ('mixins.mlp.vision_gate_proj.', '.weight', 'mlp.vision_mlp.gate_proj.weight'),
    ('mixins.mlp.gate_proj.', '.weight', 'mlp.language_mlp.gate_proj.weight'),
    ('transformer.layers.', '.mlp.dense_h_to_4h.weight', 'mlp.language_mlp.up_proj.weight'),
    ('transformer.layers.', '.mlp.dense_4h_to_h.weight', 'mlp.language_mlp.down_proj.weight'),
    # attention
    ('transformer.layers.', '.attention.query_key_value.weight',
     'self_attn.language_expert_query_key_value.weight'),
    ('transformer.layers.', '.attention.dense.weight', 'self_attn.language_expert_dense.weight'),
    ('mixins.rotary.vision_query_key_value_list.', '.weight',
     'self_attn.vision_expert_query_key_value.weight'),
    ('mixins.rotary.vision_dense_list.', '.weight', 'self_attn.vision_expert_dense.weight'),
    ('mixins.rotary.vision_query_key_value_list.', '.bias',
     'self_attn.vision_expert_query_key_value.bias'),
    # norms
    ('transformer.layers.', '.input_layernorm.weight', 'input_layernorm.weight'),
    ('transformer.layers.', '.post_attention_layernorm.weight', 'post_attention_layernorm.weight'),
)


def _convert_vision_state_dict(state_dict):
    """Rename SAT ``mixins.eva.*`` vision-tower keys to HF ``model.vision.*`` keys.

    Keys outside ``mixins.eva`` are skipped (they belong to the language model).
    Every key in ``state_dict`` is printed as a progress trace.

    Raises:
        ValueError: if a ``mixins.eva`` key matches no known rule.
    """
    converted = {}
    for k, v in state_dict.items():
        print(k)
        if k.startswith('mixins.eva.vit_model.mixins.patch_embedding'):
            new_k = k.replace('mixins.eva.vit_model.mixins.', '', 1)
        elif k.startswith('mixins.eva.vit_model.transformer.position_embeddings'):
            new_k = k.replace('mixins.eva.vit_model.transformer.position_embeddings',
                              'patch_embedding.position_embedding', 1)
        elif k.startswith('mixins.eva.vit_model.transformer.layers'):
            new_k = k.replace('mlp.dense_4h_to_h', 'mlp.fc2').replace('mlp.dense_h_to_4h', 'mlp.fc1')
            new_k = new_k.replace('mixins.eva.vit_model.transformer.layers', 'transformer.layers', 1)
        elif k.startswith('mixins.eva.linear_proj'):
            new_k = k.replace('mixins.eva.linear_proj', 'linear_proj', 1)
        elif k.startswith('mixins.eva.conv'):
            new_k = k.replace('mixins.eva.conv', 'conv', 1)
        elif k == 'mixins.eva.vit_model.transformer.word_embeddings.weight':
            new_k = 'patch_embedding.cls_embedding'
        elif k in ('mixins.eva.boi', 'mixins.eva.eoi'):
            new_k = k.replace('mixins.eva.', '', 1)
        elif k.startswith('mixins.eva'):
            raise ValueError(f"Unhandled vision key: {k}")
        else:
            continue  # language-model key; handled by _convert_language_state_dict
        converted[f"model.vision.{new_k}"] = v
    return converted


def _convert_language_state_dict(state_dict, num_layers):
    """Rename SAT language-model keys to HF CogVLM2 ``model.*`` / ``lm_head`` keys.

    ``mixins.eva`` (vision) keys are skipped. ``num_layers`` controls how many
    per-layer copies of the shared rotary ``inv_freq`` buffer are emitted.

    Raises:
        ValueError: if a non-vision key matches no known rule.
    """
    converted = {}
    for k, v in state_dict.items():
        if k.startswith('mixins.eva'):
            continue
        if k == 'mixins.lm.lm_head.weight':
            converted['lm_head.weight'] = v
        elif k == 'transformer.word_embeddings.weight':
            converted['model.embed_tokens.weight'] = v
        elif k == 'transformer.final_layernorm.weight':
            converted['model.norm.weight'] = v
        elif k == 'mixins.rotary.rotary_emb.inv_freq':
            # safetensors refuses to serialize tensors that share storage,
            # so every layer gets its own copy.
            for idx in range(num_layers):
                converted[f"model.layers.{idx}.self_attn.rotary_emb.inv_freq"] = v.clone()
        else:
            for prefix, suffix, target in _LANGUAGE_LAYER_RULES:
                if k.startswith(prefix) and k.endswith(suffix):
                    idx = k[len(prefix):len(k) - len(suffix)]
                    converted[f"model.layers.{idx}.{target}"] = v
                    break
            else:
                raise ValueError(f"Unhandled language-model key: {k}")
    return converted


def _build_hf_config(config, sat_dir):
    """Translate a SAT ``model_config.json`` dict into an HF ``config.json`` dict."""
    eva_args = config['eva_args']
    # NOTE(review): only image_size[0] is read, i.e. a square image is assumed.
    image_size = eva_args['image_size'][0]
    patch_size = eva_args['patch_size']
    vision_config = {
        'dropout_prob': 0.0,
        'hidden_act': 'gelu',
        'in_channels': eva_args['in_channels'],
        'num_hidden_layers': eva_args['num_layers'],
        'hidden_size': eva_args['hidden_size'],
        'patch_size': patch_size,
        'num_heads': eva_args['num_attention_heads'],
        'intermediate_size': eva_args['inner_hidden_size'],
        'layer_norm_eps': eva_args['layernorm_epsilon'],
        # patches per image + 1 (presumably for the cls embedding)
        'num_positions': int(1 + (image_size / patch_size) * (image_size / patch_size)),
        'image_size': image_size,
    }
    return {
        'vision_config': vision_config,
        'hidden_size': config['hidden_size'],
        'intermediate_size': config['inner_hidden_size'],
        'num_attention_heads': config['num_attention_heads'],
        'max_position_embeddings': 8192,
        'rms_norm_eps': 1e-5,
        'template_version': 'chat' if 'chat' in sat_dir else 'base',
        'initializer_range': 0.02,
        'pad_token_id': 128002,
        "bos_token_id": 128000,
        "eos_token_id": 128001,
        'vocab_size': config['vocab_size'],
        'num_hidden_layers': config['num_layers'],
        'hidden_act': 'silu',
        'use_cache': True,
    }


def vlm(
        hf_dir: str = '/share/home/zyx/Models/cogvlm-1',
        sat_dir: str = '/share/wwh/cogvlm2_sat',
        iteration: str = '10000',
):
    """Convert a SAT CogVLM2 checkpoint into an HF ``model.safetensors`` + ``config.json``.

    Reads ``<sat_dir>/model_config.json`` and
    ``<sat_dir>/<iteration>/mp_rank_00_model_states.pt``. Both output files
    are written into ``hf_dir``, which is created if needed.

    Args:
        hf_dir: output directory for the HF checkpoint.
        sat_dir: SAT checkpoint root; if its path contains "chat",
            ``template_version`` is set to "chat", otherwise "base".
        iteration: checkpoint sub-directory to load (default "10000").

    Raises:
        ValueError: if the checkpoint contains a key with no conversion rule.
    """
    import json
    import torch
    from pathlib import Path

    hf_dir = os.path.expanduser(hf_dir)
    Path(hf_dir).mkdir(parents=True, exist_ok=True)

    # Read the config first: it is needed for the layer count, and a missing
    # file should fail before the expensive checkpoint load.
    with open(os.path.expanduser(os.path.join(sat_dir, 'model_config.json'))) as f:
        config = json.load(f)

    print("Loading state dict")
    # NOTE(review): torch>=2.6 defaults torch.load(weights_only=True); if this
    # checkpoint stores non-tensor objects, pass weights_only=False (trusted files only).
    checkpoint = torch.load(
        os.path.expanduser(os.path.join(sat_dir, str(iteration), 'mp_rank_00_model_states.pt')),
        map_location='cpu',
    )
    state_dict = checkpoint['module']

    vision_state_dict = _convert_vision_state_dict(state_dict)
    new_state_dict = _convert_language_state_dict(state_dict, config['num_layers'])
    new_state_dict.update(vision_state_dict)
    save_file(new_state_dict, os.path.join(hf_dir, 'model.safetensors'), metadata={'format': 'pt'})

    with open(os.path.join(hf_dir, 'config.json'), 'w') as f:
        json.dump(_build_hf_config(config, sat_dir), f, indent=2)


if __name__ == '__main__':
    # Expose only the real conversion command; a bare fire.Fire() would also
    # publish every module global (os, fire, save_file, ...) as a subcommand.
    # Usage stays: python <script> vlm --hf_dir ... --sat_dir ...
    fire.Fire({'vlm': vlm})