"""Convert LLaMA weights to OpenLM-compatible weights.

Usage: `python convert_llama_to_openlm.py <llama_weight_path> <openlm_weight_path>`
"""

import sys

import torch

USAGE = "Usage: `python convert_llama_to_openlm.py <llama_weight_path> <openlm_weight_path>`"

# Keys that are copied unchanged from the LLaMA state dict.
_TOP_LEVEL_KEYS = ("tok_embeddings.weight", "norm.weight", "output.weight")


def convert(llama_state_dict: dict) -> dict:
    """Return a new state dict with LLaMA weights renamed/fused for OpenLM.

    Per layer ``i`` the mapping is:

    * ``attention.wq/wk/wv.weight`` -> concatenated along dim 0 into
      ``attention.in_proj.weight``
    * ``attention.wo.weight`` -> ``attention.out_proj.weight``
    * ``feed_forward.w1/w3.weight`` -> concatenated along dim 0 into
      ``feed_forward.w12.weight``
    * ``feed_forward.w2.weight`` -> ``feed_forward.w3.weight``
    * ``attention_norm.weight`` and ``ffn_norm.weight`` are copied as-is.

    ``tok_embeddings.weight``, ``norm.weight`` and ``output.weight`` are
    copied unchanged. Any other key in the input is ignored.

    Args:
        llama_state_dict: Mapping from LLaMA parameter names to tensors.

    Returns:
        A new dict keyed by OpenLM parameter names. The input is not mutated.

    Raises:
        KeyError: If an expected LLaMA key is missing from the input.
    """
    openlm_state_dict = {}

    # The layer count is the number of distinct second path components among
    # keys that mention "layers." (e.g. "layers.7.attention.wq.weight" -> "7").
    n_layer = len({key.split(".")[1] for key in llama_state_dict if "layers." in key})
    print(f"n_layer: {n_layer}")

    for key in _TOP_LEVEL_KEYS:
        openlm_state_dict[key] = llama_state_dict[key]

    for i in range(n_layer):
        attn = f"layers.{i}.attention"
        ffn = f"layers.{i}.feed_forward"

        # Fuse the separate q/k/v projections into one input projection.
        openlm_state_dict[f"{attn}.in_proj.weight"] = torch.cat(
            [
                llama_state_dict[f"{attn}.wq.weight"],
                llama_state_dict[f"{attn}.wk.weight"],
                llama_state_dict[f"{attn}.wv.weight"],
            ],
            dim=0,
        )
        openlm_state_dict[f"{attn}.out_proj.weight"] = llama_state_dict[f"{attn}.wo.weight"]

        # LLaMA's w1 and w3 are fused into OpenLM's w12 ...
        openlm_state_dict[f"{ffn}.w12.weight"] = torch.cat(
            [
                llama_state_dict[f"{ffn}.w1.weight"],
                llama_state_dict[f"{ffn}.w3.weight"],
            ],
            dim=0,
        )
        # ... and LLaMA's w2 is stored under OpenLM's w3 name.
        openlm_state_dict[f"{ffn}.w3.weight"] = llama_state_dict[f"{ffn}.w2.weight"]

        # Norm weights keep the same name on both sides.
        for norm_key in (f"layers.{i}.attention_norm.weight", f"layers.{i}.ffn_norm.weight"):
            openlm_state_dict[norm_key] = llama_state_dict[norm_key]

    return openlm_state_dict


def main(argv=None) -> int:
    """Command-line entry point; returns the process exit status.

    Args:
        argv: Argument vector including the program name; defaults to
            ``sys.argv``.

    Returns:
        ``1`` if the argument count is wrong (after printing the usage
        message), ``0`` after a successful conversion.
    """
    args = sys.argv if argv is None else argv
    if len(args) != 3:
        print(USAGE)
        return 1

    src_path, dst_path = args[1], args[2]

    # NOTE(security): torch.load unpickles the file; only run this on
    # checkpoints from a trusted source.
    # map_location="cpu" lets checkpoints saved from GPU tensors be converted
    # on machines without that device.
    llama_state_dict = torch.load(src_path, map_location="cpu")
    torch.save({"state_dict": convert(llama_state_dict)}, dst_path)
    return 0


if __name__ == "__main__":
    sys.exit(main())