# Baichuan2-13B-QKV / baichuan_reformat.py
#
# Splits the fused `self_attn.W_pack` QKV weight matrices of a Baichuan2
# checkpoint into separate q_proj / k_proj / v_proj weights and saves the
# reformatted model and tokenizer.
from transformers import AutoModelForCausalLM, AutoTokenizer
from collections import OrderedDict
import torch
model_path = 'model_path'
out_path = 'out_path'


def split_w_pack(state_dict, hidden_size=None):
    """Split fused Baichuan ``self_attn.W_pack`` weights into q/k/v projections.

    Args:
        state_dict: mapping of parameter name -> tensor. Entries whose name
            contains ``self_attn.W_pack`` are fused QKV weight matrices,
            assumed to stack the three projections along dim 0 with shape
            ``(3 * hidden_size, hidden_size)``.
        hidden_size: number of rows per projection. When ``None`` it is
            inferred per-weight as ``weight.shape[0] // 3``.

    Returns:
        ``OrderedDict`` with the same entries in the same order, except each
        ``W_pack.weight`` is replaced in place by ``q_proj.weight``,
        ``k_proj.weight`` and ``v_proj.weight``.
    """
    new_dict = OrderedDict()
    for name, weight in state_dict.items():
        if 'self_attn.W_pack' not in name:
            new_dict[name] = weight
            continue
        size = hidden_size if hidden_size is not None else weight.shape[0] // 3
        # Keep everything up to (and including) the 'self_attn.' prefix.
        prefix = name[:name.find('W_pack.weight')]
        for i, proj in enumerate(('q_proj', 'k_proj', 'v_proj')):
            new_dict[prefix + proj + '.weight'] = weight[size * i:size * (i + 1), :]
    return new_dict


def main():
    """Load the model, reformat its state dict, and save to ``out_path``."""
    model = AutoModelForCausalLM.from_pretrained(
        model_path, trust_remote_code=True, torch_dtype=torch.bfloat16)
    tokenizer = AutoTokenizer.from_pretrained(
        model_path, use_fast=False, trust_remote_code=True)
    new_dict = split_w_pack(model.state_dict(), model.config.hidden_size)
    model.save_pretrained(out_path, state_dict=new_dict)
    tokenizer.save_pretrained(out_path)


if __name__ == '__main__':
    main()