Qubitium committed on
Commit
8695694
1 Parent(s): 14aa5ff

Create convert_v2.py

Browse files
Files changed (1) hide show
  1. convert_v2.py +90 -0
convert_v2.py ADDED
@@ -0,0 +1,90 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
# code adapted from https://huggingface.co/fahadh4ilyas
# NOTE(fix): original line read "mport argparse, json" (SyntaxError); restored
# the missing "i" and regrouped imports per PEP 8 (stdlib first, third-party after).
import argparse
import json
from pathlib import Path

from safetensors import safe_open
from safetensors.torch import save_file
6
+
7
# Command-line interface: where the original dbrx checkpoint lives and where
# the converted (quantization-friendly) copy should be written.
parser = argparse.ArgumentParser(description="Convert original dbrx model into quantizable model")

parser.add_argument("--model-dir", type=str, required=True, help="directory to the original dbrx model")
parser.add_argument("--output-dir", type=str, required=True, help="directory for the converted dbrx model")
args = parser.parse_args()

model_dir = Path(args.model_dir)
output_dir = Path(args.output_dir)
# Create the destination directory up front; existing dirs are reused.
output_dir.mkdir(parents=True, exist_ok=True)

# dbrx checkpoint geometry — these are fixed properties of the original model
# and drive how the fused tensors below are split.
# NOTE(review): values presumably match the dbrx config.json — confirm before
# reusing this script on a different checkpoint.
NUM_EXPERTS = 16
HIDDEN_SIZE = 6144
HEAD_DIM = 128
NUM_KV_HEAD = 8
FFN_HIDDEN_SIZE = 10752
22
+
23
def change_tensor_attn(tensor, sizes=None):
    """Split a fused Wqkv attention weight into separate q/k/v weights.

    The fused tensor is split along dim 0 into three slices and each slice is
    made contiguous so it can be saved/quantized independently.

    Args:
        tensor: fused attention projection weight; its first dimension must
            equal the sum of ``sizes``.
        sizes: optional row counts for the (q, k, v) slices. Defaults to the
            dbrx geometry ``[HIDDEN_SIZE, NUM_KV_HEAD*HEAD_DIM,
            NUM_KV_HEAD*HEAD_DIM]`` — the parameter generalizes the function
            to other model geometries without changing existing callers.

    Returns:
        A list of three contiguous tensors: [q_proj, k_proj, v_proj].
    """
    if sizes is None:
        # dbrx default: q has HIDDEN_SIZE rows; k and v each have
        # NUM_KV_HEAD * HEAD_DIM rows (grouped-query attention).
        sizes = [HIDDEN_SIZE, NUM_KV_HEAD * HEAD_DIM, NUM_KV_HEAD * HEAD_DIM]
    return [x.contiguous() for x in tensor.split(sizes)]
26
+
27
def change_attn(tensors):
    """Replace every fused 'Wqkv' weight in *tensors* with q/k/v entries.

    Mutates the dict in place (fused keys are popped, three per-projection
    keys are inserted) and returns it for convenience.
    """
    for key in list(tensors):
        if 'Wqkv' not in key:
            continue
        base = key.removesuffix('.Wqkv.weight')
        fused = tensors.pop(key)
        q, k, v = change_tensor_attn(fused)
        tensors[f'{base}.q_proj.weight'] = q
        tensors[f'{base}.k_proj.weight'] = k
        tensors[f'{base}.v_proj.weight'] = v
    return tensors
39
+
40
def change_tensor_mlp(tensor, reverse=False, num_experts=None,
                      ffn_hidden_size=None, hidden_size=None):
    """Split a fused MoE expert weight into one tensor per expert.

    The fused tensor is viewed as ``(num_experts, ffn_hidden_size,
    hidden_size)`` and split along the expert axis.

    Args:
        tensor: fused expert weight (any shape reshapeable to the above).
        reverse: when True (used for the 'w2' down-projection), each expert
            slice is transposed to ``(hidden_size, ffn_hidden_size)``.
        num_experts, ffn_hidden_size, hidden_size: optional geometry
            overrides; default to the module-level dbrx constants, so
            existing callers are unaffected.

    Returns:
        A list of ``num_experts`` contiguous per-expert tensors.
    """
    if num_experts is None:
        num_experts = NUM_EXPERTS
    if ffn_hidden_size is None:
        ffn_hidden_size = FFN_HIDDEN_SIZE
    if hidden_size is None:
        hidden_size = HIDDEN_SIZE
    experts = tensor.reshape(num_experts, ffn_hidden_size, hidden_size)
    return [x.contiguous() if not reverse else x.t().contiguous() for x in experts]
45
+
46
def change_mlp(tensors):
    """Expand each fused MoE weight ('w1', 'v1', 'w2') into per-expert keys.

    Mutates the dict in place: every matching fused key is popped and
    replaced by ``{prefix}.{expert_index}.{kind}.weight`` entries. The 'w2'
    slices are transposed by change_tensor_mlp. Returns the same dict.
    """
    for key in list(tensors):
        # Substring match (not suffix) — mirrors the index-json rewrite below.
        if not any(tag in key for tag in ('w1', 'v1', 'w2')):
            continue
        base, kind = key.rsplit('.', 1)
        fused = tensors.pop(key)
        per_expert = change_tensor_mlp(fused, kind == 'w2')
        for idx, chunk in enumerate(per_expert):
            tensors[f'{base}.{idx}.{kind}.weight'] = chunk
    return tensors
58
+
59
# Convert every checkpoint shard: load all tensors, split the fused attention
# and MoE expert weights, then write the shard under the same file name in
# output_dir, preserving the shard's original metadata.
for file in sorted(list(model_dir.glob('*.safetensors'))):
    print(file)
    tensors = {}
    with safe_open(file, 'pt') as f:
        # Keep the original metadata (e.g. the "format" tag) for save_file.
        metadata = f.metadata()
        for k in f.keys():
            tensors[k] = f.get_tensor(k)
    tensors = change_attn(tensors)
    tensors = change_mlp(tensors)
    save_file(tensors, (output_dir / file.name).as_posix(), metadata)
69
+
70
# Rewrite the safetensors index so its weight_map matches the split tensor
# names produced above; each new entry points at the same shard file as the
# fused tensor it replaces.
with open(model_dir / 'model.safetensors.index.json') as f:
    weight_map = json.load(f)

weight_keys = list(weight_map['weight_map'])
for k in weight_keys:
    if any([x in k for x in ['w1', 'v1', 'w2']]):
        # Fused MoE weight -> one index entry per expert (same shard file).
        prefix,dtype = k.rsplit('.', 1)
        value = weight_map['weight_map'].pop(k)
        for i in range(NUM_EXPERTS):
            weight_map['weight_map'][f'{prefix}.{i}.{dtype}.weight'] = value
    elif 'Wqkv' in k:
        # Fused attention weight -> q/k/v index entries (same shard file).
        prefix = k.removesuffix('.Wqkv.weight')
        value = weight_map['weight_map'].pop(k)
        for dtype in ['q_proj', 'k_proj', 'v_proj']:
            weight_map['weight_map'][f'{prefix}.{dtype}.weight'] = value

# Sort entries by tensor name so the emitted index is deterministic.
sorted_map = sorted(weight_map['weight_map'].items())
weight_map['weight_map'] = dict(sorted_map)

with open(output_dir / 'model.safetensors.index.json', 'w') as f:
    json.dump(weight_map, f, indent=4)