import collections
import contextlib
import gc
import json
import shutil
import sys
from pathlib import Path

import torch

wd = Path(__file__).parent.parent.resolve()
sys.path.append(str(wd))

from lit_llama.model import LLaMA, LLaMAConfig
from lit_llama.utils import EmptyInitOnDevice, lazy_load, incremental_save


@torch.no_grad()
def convert_hf_checkpoint(
    *,
    output_dir: Path = Path("checkpoints/lit-llama/7B"),
    checkpoint_dir: Path = Path("checkpoints/hf-llama/7B"),
    model_size: str = "7B",
    dtype: str = "float32",
    verify: bool = False,
) -> None:
    """
    Perform the reverse operation of: https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/convert_llama_weights_to_hf.py
    """
    output_dir.mkdir(parents=True, exist_ok=True)
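
    # the tokenizer is shared across all model sizes, so keep a single copy in the parent directory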
    shutil.copy(checkpoint_dir / "tokenizer.model", output_dir.parent)

    dt = getattr(torch, dtype, None)
    if not isinstance(dt, torch.dtype):
        raise ValueError(f"{dtype} is not a valid dtype.")
    dtype = dt

    print("Initializing lit-llama")
    config = LLaMAConfig.from_name(model_size)
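
    # instantiate the model on the meta device: no weight memory is allocated, we only
    # need the parameter names, shapes and dtypes of the target state dict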
    with EmptyInitOnDevice(device="meta", dtype=dtype):
        model = LLaMA(config)
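
    # lit-llama fuses the query, key and value projections into a single c_attn matrix;
    # each of the three occupies one third of its rows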
    qkv_size = model.transformer.h[0].attn.c_attn.weight.shape[0] // 3
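
    # sd_meta provides the target shapes and dtypes; sd collects the converted tensors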
    sd_meta = model.state_dict()
    sd = {}
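
    # the index json maps every HF parameter name to the .bin shard that stores it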
    pytorch_bin_map_json_path = checkpoint_dir / "pytorch_model.bin.index.json"
    with open(pytorch_bin_map_json_path) as json_map:
        bin_index = json.load(json_map)
    bin_files = set(checkpoint_dir / bin for bin in bin_index["weight_map"].values())
    if not bin_files:
        raise ValueError(f"Expected {str(checkpoint_dir)!r} to contain .bin files")
    def permute(w):
        dim = config.n_embd
        w = w._load_tensor().to(dtype)
        return (
            w.view(config.n_head, 2, dim // config.n_head // 2, dim)
            .transpose(1, 2)
            .reshape(dim, dim)
        )
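
    # map HF parameter names (relative to a transformer block, where applicable) to lit-llama names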
    weight_map = {
        "self_attn.o_proj.weight": "attn.c_proj.weight",
        "self_attn.q_proj.weight": "attn.c_attn.weight",
        "self_attn.k_proj.weight": "attn.c_attn.weight",
        "self_attn.v_proj.weight": "attn.c_attn.weight",
        "mlp.gate_proj.weight": "mlp.c_fc1.weight",
        "mlp.up_proj.weight": "mlp.c_fc2.weight",
        "mlp.down_proj.weight": "mlp.c_proj.weight",
        "input_layernorm.weight": "rms_1.scale",
        "post_attention_layernorm.weight": "rms_2.scale",
        "model.embed_tokens.weight": "transformer.wte.weight",
        "model.norm.weight": "transformer.ln_f.scale",
        "lm_head.weight": "lm_head.weight",
    }

    print(f"Saving to disk at {output_dir}")
    unprocessed_weights = collections.defaultdict(dict)

    with incremental_save(output_dir / "lit-llama.pth") as saver:
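        # q/k/v projections for one block can be split across different .bin shards, so keep
        # every shard open (via ExitStack) until all of them have been consumed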
        with contextlib.ExitStack() as stack:
            for bin_file in bin_files:
                print("Processing", bin_file)
                hf_weights = stack.enter_context(lazy_load(bin_file))
                for name, param in hf_weights.items():
                    skip = False
                    if "rotary_emb.inv_freq" in name:
                        continue
                    if "model.layers" in name:
                        block_id = int(name.split(".")[2])
                        from_name = ".".join(name.split(".")[3:])
                        to_name = weight_map[from_name]
                        sd_key = f"transformer.h.{block_id}.{to_name}"
                        if "q_proj" in name:
                            unprocessed_weights[sd_key]["q_proj"] = param
                            skip = True
                        elif "k_proj" in name:
                            unprocessed_weights[sd_key]["k_proj"] = param
                            skip = True
                        elif "v_proj" in name:
                            unprocessed_weights[sd_key]["v_proj"] = param
                            skip = True
                        if skip and len(unprocessed_weights[sd_key]) == 3:
                            w = torch.empty(
                                sd_meta[sd_key].shape, dtype=sd_meta[sd_key].dtype
                            )
                            w[:qkv_size] = permute(unprocessed_weights[sd_key]["q_proj"])
                            w[qkv_size:-qkv_size] = permute(
                                unprocessed_weights[sd_key]["k_proj"]
                            )
                            w[-qkv_size:] = (
                                unprocessed_weights[sd_key]["v_proj"]
                                ._load_tensor()
                                .to(dtype)
                            )
                            sd[sd_key] = w
                            del unprocessed_weights[sd_key]
                            skip = False
                        else:
                            sd[sd_key] = param._load_tensor().to(dtype)
                    else:
                        sd_key = weight_map[name]
                        sd[sd_key] = param._load_tensor().to(dtype)
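                    # `store_early` streams the finished tensor straight into the output file,
                    # so the whole checkpoint never needs to be materialized in RAM at once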
                    if not skip:
                        sd[sd_key] = saver.store_early(sd[sd_key])
                gc.collect()
        saver.save(sd)

    assert len(unprocessed_weights) == 0, f"unexpected partial weights {list(unprocessed_weights)}"
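
    # optional sanity check: run the lit-llama model and the original HF model on the same
    # random token sample and compare the logits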
    if verify:
        try:
            from transformers import LlamaForCausalLM
        except ImportError:
            raise ImportError("verify=True requires transformers to be installed, please `pip install transformers`")
        print("Verifying...")

        token_sample = torch.randint(0, config.vocab_size, size=(1, config.block_size), dtype=torch.int64)
        out = model(token_sample)
        del model
        gc.collect()

        print("Loading original model for comparison")
        model_hf = LlamaForCausalLM.from_pretrained(checkpoint_dir)
        out_hf = model_hf(token_sample)["logits"]

        print("Comparing outputs")
        assert out.device.type == out_hf.device.type
        assert out.dtype == out_hf.dtype
        torch.testing.assert_close(out, out_hf)


if __name__ == "__main__":
    from jsonargparse import CLI

    CLI(convert_hf_checkpoint)