#!/usr/bin/env python3
# 1/17/2024
# Charles O. Goddard
"""Convert internlm2 weights to Llama format."""
import json
import os
import einops
import tqdm
from mergekit.io import LazyTensorLoader, TensorWriter
from mergekit.common import ModelReference
from transformers import LlamaTokenizer

MODEL_IN = "internlm/internlm2-20b"
OUT_PATH = "./internlm2-20b-llama"

model_ref = ModelReference.parse(MODEL_IN)
cfg = model_ref.config(trust_remote_code=True)
head_dim = cfg.hidden_size // cfg.num_attention_heads
num_key_value_groups = cfg.num_attention_heads // cfg.num_key_value_heads
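# Grouped-query attention: each key/value head is shared by num_key_value_groups query heads.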
loader = LazyTensorLoader(model_ref.tensor_index(), lazy_unpickle=True)
writer = TensorWriter(OUT_PATH)

SIMPLE_REPLACEMENTS = {
"feed_forward.w1": "mlp.gate_proj",
"feed_forward.w2": "mlp.down_proj",
"feed_forward.w3": "mlp.up_proj",
"attention.wo": "self_attn.o_proj",
"ffn_norm": "post_attention_layernorm",
"attention_norm": "input_layernorm",
"tok_embeddings": "embed_tokens",
"output.weight": "lm_head.weight",
}
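# The fused attention.wqkv weight cannot simply be renamed; it is split into
# separate q/k/v projections in the loop below.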
for tensor_name in tqdm.tqdm(loader.index.tensor_paths):
    tensor = loader.get_tensor(tensor_name)
    if "attention.wqkv" in tensor_name:
        # make me think about tensor shapes will you >:(
        # ((cfg.num_attention_heads + 2 * cfg.num_key_value_heads) * head_dim, cfg.hidden_size) x (batch_sz, sq_len, cfg.hidden_size)
        # -> (batch_sz, sq_len, (cfg.num_attention_heads + 2 * cfg.num_key_value_heads) * head_dim)
        # qkv_states = rearrange(
        #     qkv_states,
        #     "b q (h gs d) -> b q h gs d",
        #     gs=2 + self.num_key_value_groups,
        #     d=self.head_dim,
        # )
        # -> (batch_sz, sq_len, h, 2 + self.num_key_value_groups, head_dim)
        qkv_vecs = einops.rearrange(
            tensor, "(h gs d) z -> h gs d z", gs=2 + num_key_value_groups, d=head_dim
        )
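        # Along the gs axis, each KV head carries num_key_value_groups query
        # slices followed by one key slice and one value slice, mirroring the
        # indexing in the reference forward pass quoted above.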
        q_proj = (
            qkv_vecs[:, :num_key_value_groups, ...]
            .reshape(-1, cfg.hidden_size)
            .contiguous()
        )
        k_proj = qkv_vecs[:, -2, ...].reshape(-1, cfg.hidden_size).contiguous()
        v_proj = qkv_vecs[:, -1, ...].reshape(-1, cfg.hidden_size).contiguous()
        assert k_proj.shape == v_proj.shape

        writer.save_tensor(
            tensor_name.replace("attention.wqkv", "self_attn.q_proj"),
            q_proj,
            clone=True,
        )
        writer.save_tensor(
            tensor_name.replace("attention.wqkv", "self_attn.k_proj"),
            k_proj,
            clone=True,
        )
        writer.save_tensor(
            tensor_name.replace("attention.wqkv", "self_attn.v_proj"),
            v_proj,
            clone=True,
        )
        continue

    out_name = tensor_name
    for pattern, sub in SIMPLE_REPLACEMENTS.items():
        if pattern in out_name:
            out_name = out_name.replace(pattern, sub)
    writer.save_tensor(out_name, tensor)
writer.finalize()

cfg_dict = json.loads(cfg.to_json_string())
del cfg_dict["auto_map"]
cfg_dict["architectures"] = ["LlamaForCausalLM"]
cfg_dict["model_type"] = "llama"
if "rope_scaling" in cfg_dict and cfg_dict["rope_scaling"]["factor"] == 1.0:
    del cfg_dict["rope_scaling"]
with open(os.path.join(OUT_PATH, "config.json"), "w", encoding="utf-8") as fp:
    json.dump(cfg_dict, fp, indent=2)

# InternLMTokenizer differences:
# 1. clean_up_tokenization() hardcoded to always be called
# 2. may prepend a space to some tokens when they appear first, which LlamaTokenizer doesn't do
# 1 is easy to fix, 2... is not important
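# (clean_up_tokenization collapses spaces before punctuation and contractions,
# e.g. "Hello , world !" -> "Hello, world!")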
tok = LlamaTokenizer.from_pretrained(MODEL_IN, trust_remote_code=False, legacy=True)
tok.clean_up_tokenization_spaces = True
tok.save_pretrained(OUT_PATH)
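
# Optional sanity check, a sketch not in the original script: confirm that the
# exported config and tokenizer load cleanly with the stock (non-remote-code)
# Auto classes and that the config now reports the Llama architecture.
from transformers import AutoConfig, AutoTokenizer

check_cfg = AutoConfig.from_pretrained(OUT_PATH)
assert check_cfg.model_type == "llama"
check_tok = AutoTokenizer.from_pretrained(OUT_PATH)
print(check_tok.tokenize("Convert internlm2 weights to Llama format."))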