import gc
import shutil
from pathlib import Path
from typing import Dict

import torch
from tqdm import tqdm

"""
Sample usage:

```bash
python -m scripts.convert_checkpoint -h

python -m scripts.convert_checkpoint converted
```
"""

def convert_state_dict(state_dict: Dict[str, torch.Tensor], dtype: torch.dtype = torch.float32) -> Dict[str, torch.Tensor]:
    converted = {}
    converted["transformer.wte.weight"] = state_dict["tok_embeddings.weight"].to(dtype)
    converted["lm_head.weight"] = state_dict["output.weight"].to(dtype)
    converted["transformer.ln_f.scale"] = state_dict["norm.weight"].to(dtype)

    for layer_idx in sorted(set([k.split(".")[1] for k in state_dict if k.startswith("layers")])):
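        # attention: the separate wq, wk and wv projections are stacked into a single c_attn weight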
        converted[f"transformer.h.{layer_idx}.attn.c_attn.weight"] = torch.cat(
            (
                state_dict[f"layers.{layer_idx}.attention.wq.weight"].to(dtype),
                state_dict[f"layers.{layer_idx}.attention.wk.weight"].to(dtype),
                state_dict[f"layers.{layer_idx}.attention.wv.weight"].to(dtype),
            )
        )
        converted[f"transformer.h.{layer_idx}.attn.c_proj.weight"] = state_dict[
            f"layers.{layer_idx}.attention.wo.weight"
        ].to(dtype)

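        # feed-forward: w1 -> c_fc1, w2 -> c_proj, w3 -> c_fc2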
        converted[f"transformer.h.{layer_idx}.mlp.c_fc1.weight"] = state_dict[
            f"layers.{layer_idx}.feed_forward.w1.weight"
        ].to(dtype)
        converted[f"transformer.h.{layer_idx}.mlp.c_proj.weight"] = state_dict[
            f"layers.{layer_idx}.feed_forward.w2.weight"
        ].to(dtype)
        converted[f"transformer.h.{layer_idx}.mlp.c_fc2.weight"] = state_dict[
            f"layers.{layer_idx}.feed_forward.w3.weight"
        ].to(dtype)

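        # RMSNorm scales for the attention and feed-forward sub-blocks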
        converted[f"transformer.h.{layer_idx}.rms_1.scale"] = state_dict[f"layers.{layer_idx}.attention_norm.weight"].to(dtype)
        converted[f"transformer.h.{layer_idx}.rms_2.scale"] = state_dict[f"layers.{layer_idx}.ffn_norm.weight"].to(dtype)
    return converted


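# Dimension along which each sharded weight is concatenated when merging the
# model-parallel checkpoints; keys are matched as substrings of parameter names.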
shard_dims = {
    "lm_head.weight": 0,
    "wte.weight": 1,
    "attn.c_attn.weight": 0,
    "attn.c_proj.weight": 1,
    "mlp.c_fc1.weight": 0,
    "mlp.c_fc2.weight": 0,
    "mlp.c_proj.weight": 1
}


def meta_weights_for_nano_model(
    *,
    output_dir: Path = Path("checkpoints/lit-llama"),
    checkpoint_dir: Path = Path("checkpoints/llama/"),
    model_size: str = "7B",
    dtype: str = "float32",
) -> None:
    output_dir = output_dir / model_size
    checkpoint_dir = checkpoint_dir / model_size
    output_dir.mkdir(parents=True, exist_ok=True)

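    # tokenizer.model is shared by the different model sizes, so it is copied into the parent output directory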
    shutil.copy(checkpoint_dir.parent / "tokenizer.model", output_dir.parent)

    dt = getattr(torch, dtype, None)
    if not isinstance(dt, torch.dtype):
        raise ValueError(f"{dtype} is not a valid dtype.")
    dtype = dt

    checkpoint_files = sorted(checkpoint_dir.glob("*.pth"))
    n_checkpoints = len(checkpoint_files)

    if n_checkpoints == 0:
        raise RuntimeError(f"No checkpoints were found at checkpoint_dir {checkpoint_dir}. `consolidated.0*.pth` files expected at that location.")

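    # the bigger models are stored as several model-parallel shards; convert each
    # shard and concatenate the sharded weights along their shard dimension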
    combined = None
    for file in tqdm(checkpoint_files, total=n_checkpoints):
        checkpoint = torch.load(file, map_location="cpu")
        converted = convert_state_dict(checkpoint, dtype=dtype)
        if combined is None:
            combined = converted
            continue
        for name, param in converted.items():
            dim = None
            for k, d in shard_dims.items():
                if k in name:
                    dim = d
                    break
            if dim is None:
                # unsharded parameters (e.g. norm scales) are replicated across shards; keep the first copy
                continue
            combined[name] = torch.cat((combined[name], param), dim=dim)

        del checkpoint
        del converted
        gc.collect()

    for name, param in combined.items():
        if "c_attn" not in name:
            continue

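        # after concatenating the shards the rows are ordered [Q_1, K_1, V_1, ..., Q_n, K_n, V_n]
        # (one Q/K/V block per shard); reorder them in place to
        # [Q_1, ..., Q_n, K_1, ..., K_n, V_1, ..., V_n] so c_attn looks like a single
        # stacked query/key/value projection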
        src_chunk_len = param.shape[0] // n_checkpoints
        mat_len = src_chunk_len // 3
        dst_chunk_len = mat_len * n_checkpoints
        attn = torch.clone(param)
        for i in range(n_checkpoints):
            for j in range(3):
                param[j * dst_chunk_len + i * mat_len: j * dst_chunk_len + (i + 1) * mat_len] = \
                    attn[i * src_chunk_len + j * mat_len: i * src_chunk_len + (j + 1) * mat_len]

        del attn
        gc.collect()

    torch.save(combined, output_dir / "lit-llama.pth")


if __name__ == "__main__":
    from jsonargparse import CLI

    CLI(meta_weights_for_nano_model)
|