File size: 3,112 Bytes
8695694
d4ec07a
8695694
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
# code adapted from https://huggingface.co/fahadh4ilyas
import argparse, json
from safetensors import safe_open
from safetensors.torch import save_file
from pathlib import Path

# Command-line interface: where to read the original checkpoint and where to
# write the converted one. Both flags are mandatory.
parser = argparse.ArgumentParser(
    description="Convert original dbrx model into quantizable model"
)
parser.add_argument(
    "--model-dir",
    type=str,
    required=True,
    help="directory to the original dbrx model",
)
parser.add_argument(
    "--output-dir",
    type=str,
    required=True,
    help="directory for the converted dbrx model",
)
args = parser.parse_args()

model_dir, output_dir = Path(args.model_dir), Path(args.output_dir)
# Create the destination (and any missing parents); an existing directory is reused.
output_dir.mkdir(exist_ok=True, parents=True)

# Hard-coded sizes used to slice the fused checkpoint tensors. They are not read
# from the checkpoint, so they must match the model being converted:
#   - a fused Wqkv weight is split along dim 0 into HIDDEN_SIZE rows followed by
#     two blocks of NUM_KV_HEAD * HEAD_DIM rows;
#   - a fused expert tensor is reshaped to
#     (NUM_EXPERTS, FFN_HIDDEN_SIZE, HIDDEN_SIZE).
NUM_EXPERTS = 16
HIDDEN_SIZE = 6144
HEAD_DIM = 128
NUM_KV_HEAD = 8
FFN_HIDDEN_SIZE = 10752

def change_tensor_attn(tensor):
    """Split a fused Wqkv weight into its query, key and value parts.

    The tensor is cut along dim 0 into ``HIDDEN_SIZE`` rows, followed by two
    blocks of ``NUM_KV_HEAD * HEAD_DIM`` rows each.

    Returns:
        A list ``[q, k, v]`` of contiguous tensors, in that order.
    """
    kv_rows = NUM_KV_HEAD * HEAD_DIM
    query, key, value = tensor.split([HIDDEN_SIZE, kv_rows, kv_rows])
    return [part.contiguous() for part in (query, key, value)]

def change_attn(tensors):
    """Replace each fused ``<prefix>.Wqkv.<param>`` entry with q/k/v entries.

    ``<param>`` must be ``weight`` or ``bias``. Each matching tensor is removed
    from ``tensors`` and split by ``change_tensor_attn``. The pieces are stored as
    ``<prefix>.q_proj.<param>``, ``<prefix>.k_proj.<param>`` and
    ``<prefix>.v_proj.<param>``. All other entries are left untouched.

    Args:
        tensors: mapping from parameter name to tensor; modified in place.

    Returns:
        The same ``tensors`` mapping, so calls can be chained.
    """
    # Snapshot the keys: the dict is mutated while it is being walked.
    for key in list(tensors):
        # Match the exact ".Wqkv.<param>" tail rather than a substring, so names
        # that merely contain "Wqkv" are never split into wrongly-named pieces.
        prefix, sep, param = key.rpartition('.Wqkv.')
        if not sep or param not in ('weight', 'bias'):
            continue
        fused = tensors.pop(key)
        pieces = change_tensor_attn(fused)
        for proj, piece in zip(('q_proj', 'k_proj', 'v_proj'), pieces):
            tensors[f'{prefix}.{proj}.{param}'] = piece

    return tensors

def change_tensor_mlp(tensor, reverse=False):
    """Break a fused expert weight into one tensor per expert.

    The input is reshaped to ``(NUM_EXPERTS, FFN_HIDDEN_SIZE, HIDDEN_SIZE)`` and
    sliced along the first axis. When ``reverse`` is truthy, every slice is
    transposed to ``(HIDDEN_SIZE, FFN_HIDDEN_SIZE)``.

    Returns:
        A list of ``NUM_EXPERTS`` contiguous 2-D tensors.
    """
    per_expert = tensor.reshape(NUM_EXPERTS, FFN_HIDDEN_SIZE, HIDDEN_SIZE)
    if reverse:
        return [expert.t().contiguous() for expert in per_expert]
    return [expert.contiguous() for expert in per_expert]

def change_mlp(tensors):
    """Replace each fused expert entry ``<prefix>.<name>`` with per-expert entries.

    ``<name>`` must be exactly ``w1``, ``v1`` or ``w2``. Each matching tensor is
    removed from ``tensors`` and split by ``change_tensor_mlp``. Slice ``i`` is
    stored as ``<prefix>.<i>.<name>.weight``. ``w2`` slices are transposed
    (``reverse=True``). All other entries are left untouched.

    Args:
        tensors: mapping from parameter name to tensor; modified in place.

    Returns:
        The same ``tensors`` mapping, so calls can be chained.
    """
    # Snapshot the keys: the dict is mutated while it is being walked.
    for key in list(tensors):
        # Compare the final dotted component exactly. A substring test would also
        # catch names like "...mlp.w1.weight" or "...norm_w2" and mangle them.
        prefix, sep, name = key.rpartition('.')
        if not sep or name not in ('w1', 'v1', 'w2'):
            continue
        fused = tensors.pop(key)
        experts = change_tensor_mlp(fused, name == 'w2')
        for i, expert in enumerate(experts):
            tensors[f'{prefix}.{i}.{name}.weight'] = expert

    return tensors

# Convert every shard independently, in sorted file-name order:
# 1. load all of its tensors;
# 2. split the fused attention weights, then the fused expert weights;
# 3. write a shard with the same file name into the output directory, carrying
#    over the original shard metadata.
for shard_path in sorted(model_dir.glob('*.safetensors')):
    print(shard_path)
    with safe_open(shard_path, 'pt') as reader:
        shard_metadata = reader.metadata()
        shard = {name: reader.get_tensor(name) for name in reader.keys()}
    shard = change_mlp(change_attn(shard))
    save_file(shard, (output_dir / shard_path.name).as_posix(), shard_metadata)

# Rewrite the shard index so it lists the split parameter names. Each new name
# points at the same shard file as the fused tensor it came from.
with open(model_dir / 'model.safetensors.index.json', encoding='utf-8') as f:
    weight_map = json.load(f)

entries = weight_map['weight_map']
# Snapshot the keys: entries are popped and added while iterating.
for k in list(entries):
    # Match exact name components, not substrings, so keys that merely contain
    # "w1"/"v1"/"w2"/"Wqkv" are never renamed.
    prefix, sep, name = k.rpartition('.')
    if sep and name in ('w1', 'v1', 'w2'):
        shard = entries.pop(k)
        for i in range(NUM_EXPERTS):
            entries[f'{prefix}.{i}.{name}.weight'] = shard
        continue

    attn_prefix, sep, param = k.rpartition('.Wqkv.')
    if sep and param in ('weight', 'bias'):
        shard = entries.pop(k)
        for proj in ('q_proj', 'k_proj', 'v_proj'):
            entries[f'{attn_prefix}.{proj}.{param}'] = shard

# Store the entries sorted by parameter name.
weight_map['weight_map'] = dict(sorted(entries.items()))

with open(output_dir / 'model.safetensors.index.json', 'w', encoding='utf-8') as f:
    json.dump(weight_map, f, indent=4)