import gc
import json
import shutil
import sys
from pathlib import Path

import torch

# support running without installing as a package
wd = Path(__file__).parent.parent.resolve()
sys.path.append(str(wd))

from lit_llama.model import LLaMA, LLaMAConfig
from lit_llama.utils import EmptyInitOnDevice


@torch.no_grad()
def convert_hf_checkpoint(
    *,
    output_dir: Path = Path("checkpoints/lit-llama/7B"),
    checkpoint_dir: Path = Path("checkpoints/hf-llama/7B"),
    model_size: str = "7B",
    dtype: str = "float32",
    verify: bool = False,
) -> None:
    """Convert a Hugging Face transformers LLaMA checkpoint into the lit-llama format.

    This performs the reverse operation of:
    https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/convert_llama_weights_to_hf.py
    """
    output_dir.mkdir(parents=True, exist_ok=True)

    # the tokenizer is the same for all model sizes, so we store it in the parent dir
    shutil.copy(checkpoint_dir / "tokenizer.model", output_dir.parent)
    dt = getattr(torch, dtype, None)
    if not isinstance(dt, torch.dtype):
        raise ValueError(f"{dtype} is not a valid dtype.")
    dtype = dt

    print("Initializing lit-llama")
    config = LLaMAConfig.from_name(model_size)

    # EmptyInitOnDevice creates the parameters without initializing their values,
    # so building the model is fast; the actual weights are copied in from the
    # HF checkpoint below.
    with EmptyInitOnDevice(device="cpu", dtype=dtype):
        model = LLaMA(config)

    qkv_size = model.transformer.h[0].attn.c_attn.weight.shape[0] // 3
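    # c_attn is the fused query/key/value projection, so its output dimension is
    # three times the embedding size; dividing by 3 gives the size of each q/k/v
    # slice that the separate HF tensors are copied into below.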
    # initialize a new empty state dict to hold our new weights
    sd = model.state_dict()

    # load the json file containing the weight mapping
    pytorch_bin_map_json_path = checkpoint_dir / "pytorch_model.bin.index.json"
    with open(pytorch_bin_map_json_path) as json_map:
        bin_index = json.load(json_map)
    bin_files = set(bin_index["weight_map"].values())
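    # pytorch_model.bin.index.json maps each parameter name to the shard (.bin) file
    # that stores it; only the set of shard filenames is needed here, because every
    # tensor in each shard gets converted anyway.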
    def permute(w):
        dim = config.n_embd
        return (
            w.view(config.n_head, 2, dim // config.n_head // 2, dim)
            .transpose(1, 2)
            .reshape(dim, dim)
        )
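    # Note: the HF conversion script permutes the query/key projection rows to match
    # the "rotate half" RoPE layout used by transformers; permute() above applies the
    # inverse reshuffle so the rows match the interleaved rotary layout that lit-llama
    # expects (interpretation based on the referenced HF script).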
    weight_map = {
        "self_attn.o_proj.weight": "attn.c_proj.weight",
        "self_attn.q_proj.weight": "attn.c_attn.weight",
        "self_attn.k_proj.weight": "attn.c_attn.weight",
        "self_attn.v_proj.weight": "attn.c_attn.weight",
        "mlp.gate_proj.weight": "mlp.c_fc1.weight",
        "mlp.up_proj.weight": "mlp.c_fc2.weight",
        "mlp.down_proj.weight": "mlp.c_proj.weight",
        "input_layernorm.weight": "rms_1.scale",
        "post_attention_layernorm.weight": "rms_2.scale",
        "model.embed_tokens.weight": "transformer.wte.weight",
        "model.norm.weight": "transformer.ln_f.scale",
        "lm_head.weight": "lm_head.weight",
    }
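    # q_proj, k_proj and v_proj all map onto the single fused c_attn weight; in the
    # loop below they are written into disjoint row slices of it instead of being
    # copied whole.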
    for bin_file in bin_files:
        print("Processing", bin_file)

        hf_weights = torch.load(checkpoint_dir / bin_file, map_location="cpu")

        for name, param in hf_weights.items():
            param = param.to(dtype=dtype)
            if "rotary_emb.inv_freq" in name:
                # rotary inv_freq buffers have no counterpart in the lit-llama
                # state dict, so they are skipped
                continue
            if "model.layers" in name:
                block_id = int(name.split(".")[2])
                from_name = ".".join(name.split(".")[3:])
                to_name = weight_map[from_name]

                if "q_proj" in name:
                    sd[f"transformer.h.{block_id}.{to_name}"][:qkv_size] = permute(param)
                elif "k_proj" in name:
                    sd[f"transformer.h.{block_id}.{to_name}"][qkv_size:-qkv_size] = permute(param)
                elif "v_proj" in name:
                    sd[f"transformer.h.{block_id}.{to_name}"][-qkv_size:] = param
                else:
                    sd[f"transformer.h.{block_id}.{to_name}"].copy_(param)
            else:
                sd[weight_map[name]].copy_(param)

        # free the current HF shard before loading the next one to keep peak memory down
        del hf_weights
        gc.collect()
    print(f"Saving to disk at {output_dir}")
    torch.save(model.state_dict(), output_dir / "lit-llama.pth")

    if verify:
        try:
            from transformers import LlamaForCausalLM
        except ImportError:
            raise ImportError("verify=True requires transformers to be installed, please `pip install transformers`")
        print("Verifying...")

        token_sample = torch.randint(0, config.vocab_size, size=(1, config.block_size), dtype=torch.int64)
        out = model(token_sample)
        del model
        gc.collect()

        print("Loading original model for comparison")
        model_hf = LlamaForCausalLM.from_pretrained(checkpoint_dir)
        out_hf = model_hf(token_sample)["logits"]

        print("Comparing outputs")
        assert out.device.type == out_hf.device.type
        assert out.dtype == out_hf.dtype
        # assert_close raises on mismatch and returns None, so it is called directly
        # rather than wrapped in an assert (which would always fail)
        torch.testing.assert_close(out, out_hf)

if __name__ == "__main__":
    from jsonargparse import CLI

    CLI(convert_hf_checkpoint)
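
# Example invocation (illustrative only; the script path and checkpoint locations
# depend on how the repo and downloads are laid out - jsonargparse exposes each
# keyword argument of convert_hf_checkpoint as a --flag):
#
#   python scripts/convert_hf_checkpoint.py \
#       --checkpoint_dir checkpoints/hf-llama/7B \
#       --output_dir checkpoints/lit-llama/7B \
#       --model_size 7B \
#       --dtype float32 \
#       --verify true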