File size: 5,048 Bytes
7d52396 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 |
import gc
import shutil
from pathlib import Path
from typing import Dict
import torch
from tqdm import tqdm
"""
Sample usage:
```bash
python -m scripts.convert_checkpoint -h
python -m scripts.convert_checkpoint converted
```
"""
def convert_state_dict(state_dict: Dict[str, torch.Tensor], dtype: torch.dtype = torch.float32) -> Dict[str, torch.Tensor]:
    """Rename the tensors of one original checkpoint shard to this project's parameter names.

    Args:
        state_dict: Mapping from the original parameter names (``tok_embeddings.weight``,
            ``layers.<i>.attention.wq.weight``, ...) to tensors.
        dtype: Every tensor is cast to this dtype before it is stored in the result.

    Returns:
        A new dict keyed by the converted names. For each layer the ``wq``, ``wk`` and ``wv``
        weights are concatenated, in that order along dim 0, into ``attn.c_attn.weight``.

    Raises:
        KeyError: If one of the expected source keys is missing from ``state_dict``.
    """

    def cast(source_key: str) -> torch.Tensor:
        return state_dict[source_key].to(dtype)

    renamed = {
        "transformer.wte.weight": cast("tok_embeddings.weight"),
        "lm_head.weight": cast("output.weight"),
        "transformer.ln_f.scale": cast("norm.weight"),
    }

    # (target suffix, source suffix) for every per-layer tensor that is a plain rename,
    # listed in the order the entries are written to the result.
    per_layer_renames = (
        ("attn.c_proj.weight", "attention.wo.weight"),
        ("mlp.c_fc1.weight", "feed_forward.w1.weight"),
        ("mlp.c_proj.weight", "feed_forward.w2.weight"),
        ("mlp.c_fc2.weight", "feed_forward.w3.weight"),
        ("rms_1.scale", "attention_norm.weight"),
        ("rms_2.scale", "ffn_norm.weight"),
    )

    # Layer indices are kept as strings, so they are visited in string sort order.
    layer_ids = sorted({key.split(".")[1] for key in state_dict if key.startswith("layers")})
    for layer_id in layer_ids:
        src = f"layers.{layer_id}"
        dst = f"transformer.h.{layer_id}"

        # the wq, wk, wv from the FB model are stacked in our model as c_attn
        qkv = [cast(f"{src}.attention.{proj}.weight") for proj in ("wq", "wk", "wv")]
        renamed[f"{dst}.attn.c_attn.weight"] = torch.cat(qkv)

        for dst_suffix, src_suffix in per_layer_renames:
            renamed[f"{dst}.{dst_suffix}"] = cast(f"{src}.{src_suffix}")
    return renamed
# Maps a substring of a converted parameter name to the dimension along which the
# per-checkpoint pieces of that parameter are joined with ``torch.cat``.
# Lookup is by substring match in insertion order; a parameter whose name contains
# none of these keys is not concatenated.
shard_dims = {
    "lm_head.weight": 0,
    "wte.weight": 1,
    "attn.c_attn.weight": 0,
    "attn.c_proj.weight": 1,
    "mlp.c_fc1.weight": 0,
    "mlp.c_fc2.weight": 0,
    "mlp.c_proj.weight": 1,
}
def meta_weights_for_nano_model(
    *,
    output_dir: Path = Path("checkpoints/lit-llama"),
    checkpoint_dir: Path = Path("checkpoints/llama/"),
    model_size: str = "7B",
    dtype: str = "float32",
) -> None:
    """Merge the ``*.pth`` checkpoint shards of one model size into a single converted file.

    Every ``*.pth`` file in ``checkpoint_dir / model_size`` is loaded on the CPU (in sorted
    filename order), converted with ``convert_state_dict`` and merged: parameters listed in
    ``shard_dims`` are concatenated along the given dimension, all others are taken from the
    first file. The result is saved to ``output_dir / model_size / "lit-llama.pth"`` and
    ``checkpoint_dir / "tokenizer.model"`` is copied into ``output_dir``.

    Args:
        output_dir: Root directory for converted checkpoints.
        checkpoint_dir: Root directory holding the original checkpoints and ``tokenizer.model``.
        model_size: Name of the sub-directory (below both roots) to convert.
        dtype: Name of a ``torch`` dtype attribute (e.g. ``"float32"``) to cast all weights to.

    Raises:
        ValueError: If ``dtype`` does not name a ``torch.dtype``.
        RuntimeError: If no ``*.pth`` file exists in ``checkpoint_dir / model_size``.
    """
    # Validate every input before touching the filesystem, so that a bad argument neither
    # leaves a partially initialised output directory behind nor is masked by an unrelated
    # I/O error.
    torch_dtype = getattr(torch, dtype, None)
    if not isinstance(torch_dtype, torch.dtype):
        raise ValueError(f"{dtype} is not a valid dtype.")

    # Path(...) also accepts plain strings from programmatic callers.
    output_dir = Path(output_dir) / model_size
    checkpoint_dir = Path(checkpoint_dir) / model_size

    checkpoint_files = sorted(checkpoint_dir.glob("*.pth"))
    n_checkpoints = len(checkpoint_files)
    if n_checkpoints == 0:
        raise RuntimeError(f"No checkpoints were found at checkpoint_dir {checkpoint_dir}. `consolidated.0*.pth` files expected at that location.")

    output_dir.mkdir(parents=True, exist_ok=True)
    # the tokenizer is the same for all model sizes, so we store it in the parent dir
    shutil.copy(checkpoint_dir.parent / "tokenizer.model", output_dir.parent)

    # for the bigger models, there are multiple model-parallel checkpoints
    # and we combine them into one single file
    combined = None
    for file in tqdm(checkpoint_files, total=n_checkpoints):
        # NOTE(security): torch.load unpickles the file; only convert checkpoints from a trusted source.
        checkpoint = torch.load(file, map_location="cpu")
        converted = convert_state_dict(checkpoint, dtype=torch_dtype)
        # Drop the reference to the raw checkpoint as soon as it has been converted.
        del checkpoint

        if combined is None:
            combined = converted
            continue

        for name, param in converted.items():
            dim = next((d for key, d in shard_dims.items() if key in name), None)
            if dim is None:
                # Not sharded: keep the copy taken from the first checkpoint.
                # (The copies are assumed to be equal; this is not verified.)
                continue
            combined[name] = torch.cat((combined[name], param), dim=dim)

        del converted
        gc.collect()

    # With a single checkpoint the regrouping below would copy every block onto itself,
    # so it is skipped to avoid cloning each attention weight for nothing.
    if n_checkpoints > 1:
        for name, param in combined.items():
            if "c_attn" not in name:
                continue

            # Turn [Q1, K1, V1, Q2, K2, V2, ...] into [Q1, Q2, ..., K1, K2, .., V1, V2, ...]
            src_chunk_len = param.shape[0] // n_checkpoints  # rows contributed by one checkpoint
            mat_len = src_chunk_len // 3  # rows of one Q, K or V block within a chunk
            dst_chunk_len = mat_len * n_checkpoints  # rows of the full Q (or K, or V) matrix

            # Rows are rewritten in place, so read from an unmodified copy.
            attn = torch.clone(param)
            for i in range(n_checkpoints):
                for j in range(3):
                    dst_start = j * dst_chunk_len + i * mat_len
                    src_start = i * src_chunk_len + j * mat_len
                    param[dst_start:dst_start + mat_len] = attn[src_start:src_start + mat_len]

            del attn
            gc.collect()

    torch.save(combined, output_dir / "lit-llama.pth")
if __name__ == "__main__":
    # Imported here so that jsonargparse is only required when the file is run as a script.
    from jsonargparse import CLI

    # Expose meta_weights_for_nano_model as the command-line entry point.
    CLI(meta_weights_for_nano_model)
|