File size: 3,693 Bytes
7d774e5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
67517f8
7d774e5
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
#!/usr/bin/env python3
# 1/17/2024
# Charles O. Goddard
"""Convert internlm2 weights to Llama format."""

import json
import os
import einops
import tqdm
from mergekit.io import LazyTensorLoader, TensorWriter
from mergekit.common import ModelReference
from transformers import LlamaTokenizer

# Source checkpoint (Hugging Face repo id) and destination directory.
MODEL_IN = "internlm/internlm2-20b"
OUT_PATH = "./internlm2-20b-llama"

# The InternLM2 config is loaded with trust_remote_code=True because it relies
# on custom (non-transformers) config code.
model_ref = ModelReference.parse(MODEL_IN)
cfg = model_ref.config(trust_remote_code=True)

# How many query heads share each key/value head, and the width of one head.
num_key_value_groups = cfg.num_attention_heads // cfg.num_key_value_heads
head_dim = cfg.hidden_size // cfg.num_attention_heads

# Lazy reader over the source tensors and writer for the converted ones.
loader = LazyTensorLoader(model_ref.tensor_index(), lazy_unpickle=True)
writer = TensorWriter(OUT_PATH)

# Tensor-name fragments that map one-to-one from InternLM2 to Llama naming.
# The conversion loop applies them as plain substring replacements, in this
# (insertion) order, to every tensor name except the fused attention.wqkv.
SIMPLE_REPLACEMENTS = dict(
    [
        # MLP projections
        ("feed_forward.w1", "mlp.gate_proj"),
        ("feed_forward.w2", "mlp.down_proj"),
        ("feed_forward.w3", "mlp.up_proj"),
        # attention output projection
        ("attention.wo", "self_attn.o_proj"),
        # per-layer norms
        ("ffn_norm", "post_attention_layernorm"),
        ("attention_norm", "input_layernorm"),
        # token embeddings and output head
        ("tok_embeddings", "embed_tokens"),
        ("output.weight", "lm_head.weight"),
    ]
)

def _split_wqkv(wqkv, n_groups, dim_per_head):
    """Split a fused InternLM2 ``wqkv`` tensor into separate q/k/v tensors.

    Dim 0 of *wqkv* is laid out as ``(kv_heads, n_groups + 2, dim_per_head)``,
    mirroring the ``"b q (h gs d) -> b q h gs d"`` rearrange in InternLM2's
    attention. Within each key/value head group, the first ``n_groups`` slots
    are query heads, followed by one key head and then one value head.

    Any trailing dimensions are carried through unchanged. For a weight matrix
    that is the input-feature dim; a 1-D bias vector has none.

    Returns ``(q, k, v)``, each with the same trailing dimensions as *wqkv*.
    """
    rest = tuple(wqkv.shape[1:])
    grouped = einops.rearrange(
        wqkv, "(h gs d) ... -> h gs d ...", gs=2 + n_groups, d=dim_per_head
    )
    # The group-major query slots flatten to "(h gs) d", which is the query-head
    # order InternLM2 uses after its own rearrange.
    q = grouped[:, :n_groups, ...].reshape(-1, *rest).contiguous()
    k = grouped[:, -2, ...].reshape(-1, *rest).contiguous()
    v = grouped[:, -1, ...].reshape(-1, *rest).contiguous()
    return q, k, v


# Stream every source tensor to the writer under its Llama name. The fused
# attention.wqkv projection is split into separate q/k/v projections.
for tensor_name in tqdm.tqdm(loader.index.tensor_paths):
    tensor = loader.get_tensor(tensor_name)

    if "attention.wqkv" in tensor_name:
        q_proj, k_proj, v_proj = _split_wqkv(tensor, num_key_value_groups, head_dim)
        # Cross-check the split against the config. The split only lines up
        # when the query heads divide evenly into the key/value heads.
        assert k_proj.shape == v_proj.shape
        assert q_proj.shape[0] == cfg.num_attention_heads * head_dim
        assert k_proj.shape[0] == cfg.num_key_value_heads * head_dim

        for llama_name, part in (
            ("self_attn.q_proj", q_proj),
            ("self_attn.k_proj", k_proj),
            ("self_attn.v_proj", v_proj),
        ):
            # clone=True is kept from the original conversion. It presumably
            # detaches each slice from the fused tensor's storage before
            # writing; confirm against TensorWriter.
            writer.save_tensor(
                tensor_name.replace("attention.wqkv", llama_name),
                part,
                clone=True,
            )
        continue

    # Everything else is a pure rename. str.replace is a no-op when the pattern
    # is absent, so no membership test is needed.
    out_name = tensor_name
    for pattern, sub in SIMPLE_REPLACEMENTS.items():
        out_name = out_name.replace(pattern, sub)
    writer.save_tensor(out_name, tensor)
writer.finalize()

# Rewrite the InternLM2 config so it loads as a built-in Llama model.
cfg_dict = json.loads(cfg.to_json_string())
# auto_map points at remote-code classes, which Llama doesn't need. Drop it,
# tolerating its absence.
cfg_dict.pop("auto_map", None)
cfg_dict["architectures"] = ["LlamaForCausalLM"]
cfg_dict["model_type"] = "llama"
# InternLM2 exposes its projection-bias switch as `bias`; LlamaConfig calls the
# equivalent switch (q/k/v/o biases) `attention_bias`.
# NOTE(review): assumes InternLM2's `bias` covers exactly wqkv and wo; verify
# against its modeling code for the checkpoint being converted.
if "bias" in cfg_dict:
    cfg_dict["attention_bias"] = cfg_dict.pop("bias")
# rope_scaling may be missing, null, or an identity scaling (factor 1.0). In
# all of those cases remove it so LlamaConfig falls back to its default (None).
# A null value previously crashed the original `["factor"]` lookup.
rope_scaling = cfg_dict.get("rope_scaling")
if not rope_scaling or rope_scaling.get("factor") == 1.0:
    cfg_dict.pop("rope_scaling", None)
with open(os.path.join(OUT_PATH, "config.json"), "w", encoding="utf-8") as fp:
    json.dump(cfg_dict, fp, indent=2)

# The InternLM2 tokenizer differs from LlamaTokenizer in two ways:
#   1. it always runs clean_up_tokenization() when decoding;
#   2. it may prepend a space to some tokens when they come first, which
#      LlamaTokenizer does not do.
# Difference 1 is reproduced by enabling clean_up_tokenization_spaces below.
# Difference 2 is deliberately left unhandled as unimportant.
tok = LlamaTokenizer.from_pretrained(
    MODEL_IN,
    trust_remote_code=False,
    legacy=True,
)
tok.clean_up_tokenization_spaces = True
tok.save_pretrained(OUT_PATH)