olm-chat-7b / open_lm /utils /convert_llama.py
henhenhahi111112's picture
Upload folder using huggingface_hub
af6e330 verified
raw
history blame
2.66 kB
"""
This script converts the weights from LLAMA to OpenLM compatible weights.
Usage: `python convert_llama.py <llama_weight_path> <openlm_weight_path>`
"""
import torch
import sys
def convert(llama_state_dict: dict) -> dict:
    """Re-key a LLaMA state dict into the layout OpenLM expects.

    The mapping applied is:

    * ``tok_embeddings.weight``, ``norm.weight`` and ``output.weight`` are
      carried over under the same names.
    * Per layer, ``attention.wq``/``wk``/``wv`` weights are concatenated along
      dim 0 (in that order) into ``attention.in_proj.weight``.
    * Per layer, ``attention.wo.weight`` becomes ``attention.out_proj.weight``.
    * Per layer, ``feed_forward.w1`` and ``feed_forward.w3`` weights are
      concatenated along dim 0 (in that order) into ``feed_forward.w12.weight``.
    * Per layer, ``feed_forward.w2.weight`` becomes ``feed_forward.w3.weight``.
    * Per layer, ``attention_norm.weight`` and ``ffn_norm.weight`` keep their
      names.

    Any other key of the input is not copied. Tensors that are only renamed
    are stored as the same objects (no clone is made).

    Args:
        llama_state_dict: Mapping from LLaMA parameter names to tensors.

    Returns:
        A new dict holding the OpenLM-named parameters.

    Raises:
        KeyError: If one of the parameter names listed above is missing.
    """
    converted = {}

    def put(name, tensor):
        # Each target name is expected to be written exactly once.
        assert name not in converted
        converted[name] = tensor

    # The layer count is the number of distinct values in the second dotted
    # component of every key that contains "layers.".
    layer_ids = {name.split(".")[1] for name in llama_state_dict if "layers." in name}
    n_layer = len(layer_ids)
    print(f"n_layer: {n_layer}")

    # Parameters outside the transformer blocks are copied as they are.
    for name in ("tok_embeddings.weight", "norm.weight", "output.weight"):
        put(name, llama_state_dict[name])

    for idx in range(n_layer):
        prefix = f"layers.{idx}"

        # Fuse the three attention projections into a single matrix.
        qkv = [llama_state_dict[f"{prefix}.attention.{proj}.weight"] for proj in ("wq", "wk", "wv")]
        put(f"{prefix}.attention.in_proj.weight", torch.cat(qkv, dim=0))

        put(f"{prefix}.attention.out_proj.weight", llama_state_dict[f"{prefix}.attention.wo.weight"])

        # Fuse w1 and w3 of the feed-forward block into w12.
        w1_w3 = [llama_state_dict[f"{prefix}.feed_forward.{proj}.weight"] for proj in ("w1", "w3")]
        put(f"{prefix}.feed_forward.w12.weight", torch.cat(w1_w3, dim=0))

        # The source w2 is stored under the target name w3.
        put(f"{prefix}.feed_forward.w3.weight", llama_state_dict[f"{prefix}.feed_forward.w2.weight"])

        # Norm weights keep the same name on both sides.
        for norm_name in ("attention_norm", "ffn_norm"):
            name = f"{prefix}.{norm_name}.weight"
            put(name, llama_state_dict[name])

    return converted
if __name__ == "__main__":
    # Command-line entry point: read a LLaMA checkpoint from the first
    # argument, convert it, and write the result to the second argument
    # wrapped under a top-level "state_dict" key.
    if len(sys.argv) != 3:
        print("Usage: `python convert_llama.py <llama_weight_path> <openlm_weight_path>`")
        sys.exit(1)
    # Load onto CPU so the conversion also works on machines that lack the
    # device the checkpoint tensors were saved from.
    # NOTE(security): torch.load unpickles the file; only run this on
    # checkpoints from a trusted source.
    llama_state_dict = torch.load(sys.argv[1], map_location="cpu")
    openlm_state_dict = {"state_dict": convert(llama_state_dict)}
    torch.save(openlm_state_dict, sys.argv[2])