# convert_baichuan_to_llama.py
# Source: TheBloke/baichuan-llama-7B-GPTQ ("Initial GPTQ model commit", 5382d51)
# Converts a Baichuan checkpoint's fused attention weights into LLaMA layout.
from collections import OrderedDict
import torch
def convert_baichuan_to_llama(baichuan):
    """Convert a Baichuan state dict to the LLaMA parameter layout.

    Baichuan fuses the attention query/key/value projections into a single
    ``W_pack`` tensor whose first dimension is ``3 * hidden_size`` (q, k, v
    stacked in that order).  LLaMA keeps them as separate ``q_proj`` /
    ``k_proj`` / ``v_proj`` tensors, so each ``W_pack`` entry is split into
    three equal slices along dim 0.  Every other tensor is copied through
    under its original key.

    Args:
        baichuan: mapping of parameter name -> tensor, as produced by
            ``torch.load`` on a Baichuan ``pytorch_model.bin``.

    Returns:
        ``OrderedDict`` with LLaMA-style keys, insertion order preserved.
    """
    llama = OrderedDict()
    for key, tensor in baichuan.items():
        if 'W_pack' in key:
            # Derive the hidden size from the tensor itself instead of
            # hard-coding 4096, so non-7B checkpoints (e.g. 13B, hidden
            # size 5120) are split correctly too.  For 7B this yields the
            # same 4096 slicing as before.
            hidden = tensor.shape[0] // 3
            llama[key.replace('W_pack', 'q_proj')] = tensor[:hidden]
            llama[key.replace('W_pack', 'k_proj')] = tensor[hidden:2 * hidden]
            llama[key.replace('W_pack', 'v_proj')] = tensor[2 * hidden:]
        else:
            llama[key] = tensor
    return llama


if __name__ == "__main__":
    # NOTE(review): this overwrites the input checkpoint in place — the
    # original Baichuan weights are gone after a successful run. Consider
    # writing to a different filename if the original should be kept.
    state_dict = torch.load("pytorch_model.bin")
    torch.save(convert_baichuan_to_llama(state_dict), "pytorch_model.bin")