"""Densify a Mixtral-style MoE checkpoint into a plain Mistral one.

Keeps expert 0 of each block_sparse_moe layer as the dense MLP, drops the
router gate, and rewrites config.json and model.safetensors.index.json to match.
"""
import glob
import json

from safetensors import safe_open
from safetensors.torch import save_file
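# The script is assumed to be run from inside the checkpoint directory: the
# sharded model-*-of-*.safetensors files, config.json, and
# model.safetensors.index.json are all rewritten in place, with .old backups
# kept for the two JSON files.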

def main():
    print("densifying...")
    files = glob.glob("model-*-of-*.safetensors")

    for mf in sorted(files):
        print(f"processing {mf}")
        densified = {}

        with safe_open(mf, framework="pt") as st:
            for key in st.keys():
                tensor = st.get_tensor(key)
                # Rename expert 0's weights to the standard dense MLP names.
                nk = key.replace("block_sparse_moe.experts.0.w2.weight", "mlp.down_proj.weight")
                nk = nk.replace("block_sparse_moe.experts.0.w1.weight", "mlp.gate_proj.weight")
                nk = nk.replace("block_sparse_moe.experts.0.w3.weight", "mlp.up_proj.weight")

                # Keep only tensors that are dense after renaming; the router
                # gate (and any expert other than 0) is dropped.
                if "block_sparse_moe" not in nk:
                    densified[nk] = tensor
                    print(f"+ {key} -> {nk}")
                else:
                    print(f"- {key}")

        # Overwrite the shard in place with the renamed tensors.
        save_file(densified, mf)

    with open("config.json", "r") as read_conf_file:
        conf = json.load(read_conf_file)
    with open("config.json.old", "w") as conf_file:
        json.dump(conf, conf_file, indent=4)

    # Remove MoE-specific fields that a dense Mistral config does not use.
    for moe_key in (
        'num_local_experts',
        'num_experts_per_tok',
        'output_router_logits',
        'router_aux_loss_coef',
        'router_jitter_noise',
    ):
        conf.pop(moe_key, None)

    # Present the result as a plain dense Mistral model.
    conf['architectures'] = ['MistralForCausalLM']
    conf['model_type'] = 'mistral'
    conf['_name_or_path'] = 'lodrick-the-lafted/Densefin-291-Mistral-22B'

    with open("config.json", "w") as write_conf_file:
        json.dump(conf, write_conf_file, indent=4)

    with open("model.safetensors.index.json", "r") as index_file:
        index = json.load(index_file)

    new_index = {}
    new_index['metadata'] = index['metadata']

    wm = index['weight_map']
    nwm = {}

    for weight in wm.keys():
        nk = weight.replace("block_sparse_moe.experts.0.w2.weight", "mlp.down_proj.weight")
        nk = nk.replace("block_sparse_moe.experts.0.w1.weight", "mlp.gate_proj.weight")
        nk = nk.replace("block_sparse_moe.experts.0.w3.weight", "mlp.up_proj.weight")
        #print(f"{weight} -> {nk}")

        if "block_sparse_moe.gate.weight" not in weight:
            nwm[nk] = wm[weight]

    new_index['weight_map'] = nwm
 
    with open("model.safetensors.index.json.old", "w") as backup:
        json.dump(index, backup, indent=4)
    with open("model.safetensors.index.json", "w") as new_st_index:
        json.dump(new_index, new_st_index, indent=4)

if __name__ == "__main__":
    main()