import fire from safetensors.torch import save_file import os def save_model_in_chunks(state_dict, directory, num_parts): total_size = sum(tensor.nelement() * tensor.element_size() for tensor in state_dict.values()) max_size = total_size // num_parts + (total_size % num_parts > 0) # Ensure each part is roughly of max_size current_size = 0 part_number = 1 current_dict = {} for key, tensor in state_dict.items(): tensor_size = tensor.element_size() * tensor.nelement() if current_size + tensor_size > max_size and part_number < num_parts: save_model(current_dict, os.path.join(directory, f'model-{str(part_number).zfill(5)}-of-{str(num_parts).zfill(5)}.safetensors')) current_dict = {} current_size = 0 part_number += 1 current_dict[key] = tensor current_size += tensor_size # Save the last part if current_dict: save_model(current_dict, os.path.join(directory, f'model-{str(part_number).zfill(5)}-of-{str(num_parts).zfill(5)}.safetensors')) def vlm( hf_dir: str = '/share/home/zyx/Models/cogvlm-1', sat_dir: str = '/share/wwh/cogvlm2_sat', ): import os import json import torch from pathlib import Path Path(hf_dir).mkdir(exist_ok=True) # state dict print("Loading state dict") state_dict = torch.load(os.path.expanduser(os.path.join(sat_dir, '10000', 'mp_rank_00_model_states.pt')), map_location='cpu') state_dict = state_dict['module'] new_state_dict = {} for k, v in state_dict.items(): print(k) if k.startswith('mixins.eva.vit_model.mixins.patch_embedding'): new_state_dict[k.replace('mixins.eva.vit_model.mixins.', '', 1)] = v elif k.startswith('mixins.eva.vit_model.transformer.position_embeddings'): new_state_dict[ k.replace('mixins.eva.vit_model.transformer.position_embeddings', 'patch_embedding.position_embedding', 1)] = v elif k.startswith('mixins.eva.vit_model.transformer.layers'): k = k.replace('mlp.dense_4h_to_h', 'mlp.fc2').replace('mlp.dense_h_to_4h', 'mlp.fc1') new_state_dict[k.replace('mixins.eva.vit_model.transformer.layers', 'transformer.layers', 1)] = v elif k.startswith('mixins.eva.linear_proj'): new_state_dict[k.replace('mixins.eva.linear_proj', 'linear_proj', 1)] = v elif k.startswith('mixins.eva.conv'): new_state_dict[k.replace('mixins.eva.conv', 'conv', 1)] = v elif k in ['mixins.eva.vit_model.transformer.word_embeddings.weight']: new_state_dict['patch_embedding.cls_embedding'] = v elif k in ['mixins.eva.boi', 'mixins.eva.eoi']: new_state_dict[k.replace('mixins.eva.', '', 1)] = v else: assert not str(k).startswith('mixins.eva'), f"{k}" vision_state_dict = {f"model.vision.{k}": v for k, v in new_state_dict.items()} new_state_dict = {} for k, v in state_dict.items(): if k == 'mixins.lm.lm_head.weight': new_state_dict['lm_head.weight'] = v elif k.startswith("mixins.eva"): continue # mlp elif k.startswith('mixins.mlp.vision_dense_h_to_4h_list.') and str(k).endswith('.weight'): idx = str(k).replace('mixins.mlp.vision_dense_h_to_4h_list.', '').replace('.weight', '') new_state_dict[f"model.layers.{idx}.mlp.vision_mlp.up_proj.weight"] = v elif k.startswith('mixins.mlp.vision_dense_4h_to_h_list.') and str(k).endswith('.weight'): idx = str(k).replace('mixins.mlp.vision_dense_4h_to_h_list.', '').replace('.weight', '') new_state_dict[f"model.layers.{idx}.mlp.vision_mlp.down_proj.weight"] = v elif k.startswith('mixins.mlp.vision_gate_proj.') and str(k).endswith('.weight'): idx = str(k).replace('mixins.mlp.vision_gate_proj.', '').replace('.weight', '') new_state_dict[f"model.layers.{idx}.mlp.vision_mlp.gate_proj.weight"] = v elif k.startswith('mixins.mlp.gate_proj.') and str(k).endswith('.weight'): idx = str(k).replace('mixins.mlp.gate_proj.', '').replace('.weight', '') new_state_dict[f"model.layers.{idx}.mlp.language_mlp.gate_proj.weight"] = v elif k.startswith('transformer.layers.') and str(k).endswith('.mlp.dense_h_to_4h.weight'): idx = str(k).replace('transformer.layers.', '').replace('.mlp.dense_h_to_4h.weight', '') new_state_dict[f"model.layers.{idx}.mlp.language_mlp.up_proj.weight"] = v elif k.startswith('transformer.layers.') and str(k).endswith('.mlp.dense_4h_to_h.weight'): idx = str(k).replace('transformer.layers.', '').replace('.mlp.dense_4h_to_h.weight', '') new_state_dict[f"model.layers.{idx}.mlp.language_mlp.down_proj.weight"] = v # attn elif k.startswith('transformer.layers.') and str(k).endswith('.attention.query_key_value.weight'): idx = str(k).replace('transformer.layers.', '').replace('.attention.query_key_value.weight', '') new_state_dict[f"model.layers.{idx}.self_attn.language_expert_query_key_value.weight"] = v elif k.startswith('transformer.layers.') and str(k).endswith('.attention.dense.weight'): idx = str(k).replace('transformer.layers.', '').replace('.attention.dense.weight', '') new_state_dict[f"model.layers.{idx}.self_attn.language_expert_dense.weight"] = v elif k.startswith('mixins.rotary.vision_query_key_value_list.') and str(k).endswith('.weight'): idx = str(k).replace('mixins.rotary.vision_query_key_value_list.', '').replace('.weight', '') new_state_dict[f"model.layers.{idx}.self_attn.vision_expert_query_key_value.weight"] = v elif k.startswith('mixins.rotary.vision_dense_list.') and str(k).endswith('.weight'): idx = str(k).replace('mixins.rotary.vision_dense_list.', '').replace('.weight', '') new_state_dict[f"model.layers.{idx}.self_attn.vision_expert_dense.weight"] = v elif k.startswith('mixins.rotary.vision_query_key_value_list.') and str(k).endswith('.weight'): idx = str(k).replace('mixins.rotary.vision_query_key_value_list.', '').replace('.weight', '') new_state_dict[f"model.layers.{idx}.self_attn.vision_expert_query_key_value.weight"] = v elif k.startswith('mixins.rotary.vision_query_key_value_list.') and str(k).endswith('.bias'): idx = str(k).replace('mixins.rotary.vision_query_key_value_list.', '').replace('.bias', '') new_state_dict[f"model.layers.{idx}.self_attn.vision_expert_query_key_value.bias"] = v elif k.startswith('transformer.layers.') and str(k).endswith('.input_layernorm.weight'): idx = str(k).replace('transformer.layers.', '').replace('.input_layernorm.weight', '') new_state_dict[f"model.layers.{idx}.input_layernorm.weight"] = v elif k.startswith('transformer.layers.') and str(k).endswith('.post_attention_layernorm.weight'): idx = str(k).replace('transformer.layers.', '').replace('.post_attention_layernorm.weight', '') new_state_dict[f"model.layers.{idx}.post_attention_layernorm.weight"] = v elif k == 'transformer.word_embeddings.weight': new_state_dict[f"model.embed_tokens.weight"] = v elif k == 'transformer.final_layernorm.weight': new_state_dict[f"model.norm.weight"] = v elif k == 'mixins.rotary.rotary_emb.inv_freq': for idx in range(32): new_state_dict[f"model.layers.{idx}.self_attn.rotary_emb.inv_freq"] = v else: assert False, f"{k}" new_state_dict.update(vision_state_dict) # save_model_in_chunks(new_state_dict, hf_dir) save_file(new_state_dict, "model.safetensors") # configs config = json.load(open(os.path.expanduser(os.path.join(sat_dir, 'model_config.json')))) vision_config = { 'dropout_prob': 0.0, 'hidden_act': 'gelu', 'in_channels': config['eva_args']['in_channels'], 'num_hidden_layers': config['eva_args']['num_layers'], 'hidden_size': config['eva_args']['hidden_size'], 'patch_size': config['eva_args']['patch_size'], 'num_heads': config['eva_args']['num_attention_heads'], 'intermediate_size': config['eva_args']['inner_hidden_size'], 'layer_norm_eps': config['eva_args']['layernorm_epsilon'], 'num_positions': int(1 + (config['eva_args']['image_size'][0] / config['eva_args']['patch_size']) * ( config['eva_args']['image_size'][0] / config['eva_args']['patch_size'])), # 'image_size': config['eva_args']['image_size'][0], # # 'use_final_layernorm': config['eva_args']['use_final_layernorm'], # 'layernorm_order': config['eva_args']['layernorm_order'], } final_config = { 'vision_config': vision_config, 'hidden_size': config['hidden_size'], # 'intermediate_size': config['inner_hidden_size'], 'num_attention_heads': config['num_attention_heads'], 'max_position_embeddings': 8192, 'rms_norm_eps': 1e-5, 'template_version': 'chat' if 'chat' in sat_dir else 'base', 'initializer_range': 0.02, 'pad_token_id': 128002, "bos_token_id": 128000, "eos_token_id": 128001, # 'vocab_size': config['vocab_size'], 'num_hidden_layers': config['num_layers'], 'hidden_act': 'silu', 'use_cache': True, } with open(os.path.join(hf_dir, 'config.json'), 'w') as f: json.dump(final_config, f, indent=2) if __name__ == '__main__': fire.Fire()