chargoddard
committed on
Add script for weight conversion
Browse files- convert_weights.py +100 -0
convert_weights.py
ADDED
@@ -0,0 +1,100 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python3
# 1/17/2024
# Charles O. Goddard
"""Convert internlm2 weights to Llama format.

The script reads every tensor of ``MODEL_IN``, renames it to the matching
Llama parameter name, splits the packed ``attention.wqkv`` matrices into
separate ``q_proj`` / ``k_proj`` / ``v_proj`` tensors, and writes the result
to ``OUT_PATH`` together with a Llama-style ``config.json`` and tokenizer.
"""

import json
import os

import einops
import tqdm
from mergekit.io import LazyTensorLoader, TensorWriter
from mergekit.common import ModelReference
from transformers import LlamaTokenizer

MODEL_IN = "internlm/internlm2-20b"
OUT_PATH = "./internlm2-20b-llama"

model_ref = ModelReference.parse(MODEL_IN)
cfg = model_ref.config(trust_remote_code=True)
head_dim = cfg.hidden_size // cfg.num_attention_heads
# Number of query heads that share a single key/value head.
num_key_value_groups = cfg.num_attention_heads // cfg.num_key_value_heads
loader = LazyTensorLoader(model_ref.tensor_index(), lazy_unpickle=True)
writer = TensorWriter(OUT_PATH)

# Substring renames (internlm2 name -> Llama name) applied to every tensor
# that is not a packed wqkv matrix.
SIMPLE_REPLACEMENTS = {
    "feed_forward.w1": "mlp.gate_proj",
    "feed_forward.w2": "mlp.down_proj",
    "feed_forward.w3": "mlp.up_proj",
    "attention.wo": "self_attn.o_proj",
    "ffn_norm": "post_attention_layernorm",
    "attention_norm": "input_layernorm",
    "tok_embeddings": "embed_tokens",
    "output.weight": "lm_head.weight",
}

for tensor_name in tqdm.tqdm(loader.index.tensor_paths):
    tensor = loader.get_tensor(tensor_name)

    if "attention.wqkv" in tensor_name:
        # make me think about tensor shapes will you >:(
        #
        # wqkv has shape
        #   ((num_attention_heads + 2 * num_key_value_heads) * head_dim, hidden_size)
        # and the reference forward pass splits its output with
        #   rearrange(qkv_states, "b q (h gs d) -> b q h gs d",
        #             gs=2 + num_key_value_groups, d=head_dim)
        # i.e. per key/value head `h` there are `num_key_value_groups` query
        # slots followed by one key slot and one value slot. Apply the same
        # grouping to the rows of the weight matrix itself.
        qkv_vecs = einops.rearrange(
            tensor,
            "(h gs d) z -> h gs d z",
            gs=2 + num_key_value_groups,
            d=head_dim,
        )
        q_proj = (
            qkv_vecs[:, :num_key_value_groups, ...]
            .reshape(-1, cfg.hidden_size)
            .contiguous()
        )
        k_proj = qkv_vecs[:, -2, ...].reshape(-1, cfg.hidden_size).contiguous()
        v_proj = qkv_vecs[:, -1, ...].reshape(-1, cfg.hidden_size).contiguous()
        # Sanity check: key and value projections must have matching shapes.
        assert k_proj.shape == v_proj.shape

        # clone=True is passed for the split tensors as in the original
        # script; presumably it detaches them from the packed tensor's
        # storage -- confirm against TensorWriter.save_tensor.
        for proj_name, proj in (
            ("q_proj", q_proj),
            ("k_proj", k_proj),
            ("v_proj", v_proj),
        ):
            writer.save_tensor(
                tensor_name.replace("attention.wqkv", f"self_attn.{proj_name}"),
                proj,
                clone=True,
            )
        continue

    out_name = tensor_name
    for pattern, sub in SIMPLE_REPLACEMENTS.items():
        # str.replace is a no-op when the pattern is absent.
        out_name = out_name.replace(pattern, sub)
    writer.save_tensor(out_name, tensor)
writer.finalize()

cfg_dict = json.loads(cfg.to_json_string())
# Drop the remote-code mapping if present; the converted model uses the
# stock Llama classes. pop() avoids a KeyError when the key is missing.
cfg_dict.pop("auto_map", None)
# `architectures` is a list of class names in Hugging Face configs.
cfg_dict["architectures"] = ["LlamaForCausalLM"]
cfg_dict["model_type"] = "llama"
# A scaling factor of 1.0 is a no-op, so drop the entry entirely. Guard
# against a null / non-dict value instead of crashing on subscripting it.
rope_scaling = cfg_dict.get("rope_scaling")
if isinstance(rope_scaling, dict) and rope_scaling.get("factor") == 1.0:
    del cfg_dict["rope_scaling"]
with open(os.path.join(OUT_PATH, "config.json"), "w", encoding="utf-8") as fp:
    json.dump(cfg_dict, fp, indent=2)

# InternLMTokenizer differences:
# 1. clean_up_tokenization() hardcoded to always be called
# 2. might prepend a space to some tokens that LlamaTokenizer doesn't if they're the first token
# 1 is easy to fix, 2... is not important
tok = LlamaTokenizer.from_pretrained(MODEL_IN, trust_remote_code=False, legacy=True)
tok.clean_up_tokenization_spaces = True
tok.save_pretrained(OUT_PATH)
|