from collections import OrderedDict

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_path = 'model_path'  # source checkpoint with fused self_attn.W_pack weights
out_path = 'out_path'      # destination for the converted checkpoint

model = AutoModelForCausalLM.from_pretrained(
    model_path, trust_remote_code=True, torch_dtype=torch.bfloat16
)
tokenizer = AutoTokenizer.from_pretrained(
    model_path, use_fast=False, trust_remote_code=True
)

# Split each fused self_attn.W_pack weight (q, k, v stacked along dim 0)
# into separate q_proj/k_proj/v_proj weights; copy all other tensors unchanged.
hidden_size = model.config.hidden_size
new_dict = OrderedDict()
for name, tensor in model.state_dict().items():
    if 'self_attn.W_pack' not in name:
        new_dict[name] = tensor
        continue
    name_base = name[:name.find('W_pack.weight')]
    q, k, v = (tensor[hidden_size * i:hidden_size * (i + 1), :] for i in range(3))
    new_dict[name_base + 'q_proj.weight'] = q
    new_dict[name_base + 'k_proj.weight'] = k
    new_dict[name_base + 'v_proj.weight'] = v

model.save_pretrained(out_path, state_dict=new_dict)
tokenizer.save_pretrained(out_path)
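
# Optional sanity check, a minimal sketch rather than part of the original
# conversion: it assumes q, k, and v each occupy hidden_size rows of W_pack
# (standard multi-head attention, no grouped-query layout). It re-reads the
# fused weights and asserts each split projection equals its W_pack slice.
for name, fused in model.state_dict().items():
    if 'self_attn.W_pack.weight' not in name:
        continue
    base = name[:name.find('W_pack.weight')]
    assert torch.equal(new_dict[base + 'q_proj.weight'], fused[:hidden_size, :])
    assert torch.equal(new_dict[base + 'k_proj.weight'], fused[hidden_size:2 * hidden_size, :])
    assert torch.equal(new_dict[base + 'v_proj.weight'], fused[2 * hidden_size:3 * hidden_size, :])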