File size: 3,762 Bytes
65f8153
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
#!/usr/bin/env python3
# 1/17/2024
# Charles O. Goddard
# https://huggingface.co/chargoddard/internlm2-7b-llama/raw/main/convert_weights.py
"""Convert internlm2 weights to Llama format."""

import json
import os
import einops
import tqdm
from mergekit.io import LazyTensorLoader, TensorWriter
from mergekit.common import ModelReference
from transformers import LlamaTokenizer

# Source checkpoint (InternLM2) and destination directory for the Llama-format
# export. These are placeholder values; set them before running.
MODEL_IN = "raw weights"
OUT_PATH = "llamafied weights"

model_ref = ModelReference.parse(MODEL_IN)
# The InternLM2 config class lives in the checkpoint's own code, hence
# trust_remote_code=True here.
cfg = model_ref.config(trust_remote_code=True)

# Width of a single attention head, and how many query heads share each
# key/value head (grouped-query attention).
head_dim = cfg.hidden_size // cfg.num_attention_heads
num_key_value_groups = cfg.num_attention_heads // cfg.num_key_value_heads

loader = LazyTensorLoader(model_ref.tensor_index(), lazy_unpickle=True)
writer = TensorWriter(OUT_PATH)

# Substring renames (InternLM2 name -> Llama name), applied in this order to
# every tensor name that is not a fused attention.wqkv tensor.
SIMPLE_REPLACEMENTS = dict(
    [
        ("feed_forward.w1", "mlp.gate_proj"),
        ("feed_forward.w2", "mlp.down_proj"),
        ("feed_forward.w3", "mlp.up_proj"),
        ("attention.wo", "self_attn.o_proj"),
        ("ffn_norm", "post_attention_layernorm"),
        ("attention_norm", "input_layernorm"),
        ("tok_embeddings", "embed_tokens"),
        ("output.weight", "lm_head.weight"),
    ]
)

def _split_wqkv(fused):
    """Split a fused InternLM2 ``wqkv`` tensor into separate q, k and v tensors.

    Following the upstream InternLM2 attention code, the leading (output)
    dimension of ``wqkv`` is laid out as ``(h gs d)``:

    - ``h`` is the number of key/value heads;
    - ``gs = num_key_value_groups + 2`` holds the query heads sharing that
      key/value head, then one key head, then one value head;
    - ``d`` is ``head_dim``.

    Any trailing dimensions, such as the input dimension of a weight matrix,
    are carried through unchanged. A 1-D tensor (presumably a bias, if a
    checkpoint has one) therefore splits the same way.

    Returns:
        ``(q_proj, k_proj, v_proj)``, each contiguous with its leading
        dimension flattened as ``(head, head_dim)``.

    Raises:
        ValueError: if the split sizes disagree with the head counts in ``cfg``.
    """
    qkv_vecs = einops.rearrange(
        fused, "(h gs d) ... -> h gs d ...", gs=2 + num_key_value_groups, d=head_dim
    )
    rest = tuple(fused.shape[1:])
    # Query heads for key/value head i occupy group slots [0, num_key_value_groups),
    # so flattening (h, group, d) yields query-head index i * groups + group.
    q_proj = qkv_vecs[:, :num_key_value_groups, ...].reshape(-1, *rest).contiguous()
    k_proj = qkv_vecs[:, -2, ...].reshape(-1, *rest).contiguous()
    v_proj = qkv_vecs[:, -1, ...].reshape(-1, *rest).contiguous()

    # Explicit check instead of `assert` so it survives `python -O`; it also
    # verifies the head counts, which k/v shape equality alone never could.
    expected_q = cfg.num_attention_heads * head_dim
    expected_kv = cfg.num_key_value_heads * head_dim
    if (
        q_proj.shape[0] != expected_q
        or k_proj.shape[0] != expected_kv
        or k_proj.shape != v_proj.shape
    ):
        raise ValueError(
            f"unexpected wqkv split: q={tuple(q_proj.shape)} "
            f"k={tuple(k_proj.shape)} v={tuple(v_proj.shape)}; "
            f"expected {expected_q} query rows and {expected_kv} key/value rows"
        )
    return q_proj, k_proj, v_proj


for tensor_name in tqdm.tqdm(loader.index.tensor_paths):
    tensor = loader.get_tensor(tensor_name)
    if "attention.wqkv" in tensor_name:
        # Llama stores q/k/v as three separate projections; write them in
        # q, k, v order under the corresponding self_attn names.
        q_proj, k_proj, v_proj = _split_wqkv(tensor)
        for proj_name, proj in (
            ("q_proj", q_proj),
            ("k_proj", k_proj),
            ("v_proj", v_proj),
        ):
            writer.save_tensor(
                tensor_name.replace("attention.wqkv", f"self_attn.{proj_name}"),
                proj,
                clone=True,
            )
        continue

    out_name = tensor_name
    for pattern, sub in SIMPLE_REPLACEMENTS.items():
        # str.replace is already a no-op when the pattern is absent.
        out_name = out_name.replace(pattern, sub)
    writer.save_tensor(out_name, tensor)
writer.finalize()

# Rewrite the InternLM2 config as a stock Llama config.
cfg_dict = json.loads(cfg.to_json_string())
# The remote-code class mapping must go so the export loads with built-in
# classes; pop() tolerates configs that never had an "auto_map" entry.
cfg_dict.pop("auto_map", None)
cfg_dict["architectures"] = ["LlamaForCausalLM"]
cfg_dict["model_type"] = "llama"
# Drop rope_scaling when it is null or a no-op (factor 1.0). The previous
# check indexed into it unconditionally and crashed on a null value.
rope_scaling = cfg_dict.get("rope_scaling")
if rope_scaling is None or rope_scaling.get("factor") == 1.0:
    cfg_dict.pop("rope_scaling", None)
with open(os.path.join(OUT_PATH, "config.json"), "w", encoding="utf-8") as fp:
    json.dump(cfg_dict, fp, indent=2)

# Export the tokenizer as a plain LlamaTokenizer. Compared with
# InternLMTokenizer there are two known differences:
#   1. InternLMTokenizer always calls clean_up_tokenization(); enabling
#      clean_up_tokenization_spaces below reproduces that.
#   2. InternLMTokenizer may prepend a space to some tokens when they come
#      first, where LlamaTokenizer does not. This is judged unimportant and
#      left unaddressed.
tok = LlamaTokenizer.from_pretrained(
    MODEL_IN,
    trust_remote_code=False,
    legacy=True,
)
tok.clean_up_tokenization_spaces = True
tok.save_pretrained(OUT_PATH)