# code adapted from https://huggingface.co/fahadh4ilyas
import argparse, json
from safetensors import safe_open
from safetensors.torch import save_file
from pathlib import Path
# --- CLI setup -------------------------------------------------------------
parser = argparse.ArgumentParser(
    description="Convert original dbrx model into quantizable model"
)
for flag, helptext in [
    ("--model-dir", "directory to the original dbrx model"),
    ("--output-dir", "directory for the converted dbrx model"),
]:
    parser.add_argument(flag, type=str, required=True, help=helptext)
args = parser.parse_args()

model_dir = Path(args.model_dir)
output_dir = Path(args.output_dir)
output_dir.mkdir(parents=True, exist_ok=True)

# dbrx architecture constants used when splitting fused weights
# (values match the stock dbrx checkpoint this script was written for)
NUM_EXPERTS = 16
HIDDEN_SIZE = 6144
HEAD_DIM = 128
NUM_KV_HEAD = 8
FFN_HIDDEN_SIZE = 10752
def change_tensor_attn(tensor, hidden_size=None, kv_dim=None):
    """Split a fused Wqkv weight into [q, k, v] projection weights.

    The fused tensor is split along dim 0 into three contiguous pieces of
    ``hidden_size``, ``kv_dim`` and ``kv_dim`` rows (q is full width; k and v
    are sized for the grouped KV heads).

    Args:
        tensor: fused attention weight; first dim must equal
            ``hidden_size + 2 * kv_dim``.
        hidden_size: rows of the q projection. Defaults to the module-level
            ``HIDDEN_SIZE`` (generalized from the original hard-coded value).
        kv_dim: rows of each of the k and v projections. Defaults to
            ``NUM_KV_HEAD * HEAD_DIM``.

    Returns:
        List of three contiguous tensors ``[q, k, v]``.
    """
    if hidden_size is None:
        hidden_size = HIDDEN_SIZE
    if kv_dim is None:
        kv_dim = NUM_KV_HEAD * HEAD_DIM
    # .contiguous() so each piece owns its memory instead of viewing the fused tensor
    return [x.contiguous() for x in tensor.split([hidden_size, kv_dim, kv_dim])]
def change_attn(tensors):
    """Replace every fused ``.Wqkv.weight`` entry in *tensors* with separate
    q_proj / k_proj / v_proj weight entries. Mutates and returns *tensors*."""
    # Snapshot the fused keys first so we can mutate the dict while processing.
    fused_keys = [key for key in tensors if 'Wqkv' in key]
    for key in fused_keys:
        base = key.removesuffix('.Wqkv.weight')
        q_w, k_w, v_w = change_tensor_attn(tensors.pop(key))
        tensors[f'{base}.q_proj.weight'] = q_w
        tensors[f'{base}.k_proj.weight'] = k_w
        tensors[f'{base}.v_proj.weight'] = v_w
    return tensors
def change_tensor_mlp(tensor, reverse=False, num_experts=None,
                      ffn_hidden_size=None, hidden_size=None):
    """Split a stacked expert MLP weight into one 2-D weight per expert.

    The stacked tensor is reshaped to ``(num_experts, ffn_hidden_size,
    hidden_size)`` and each expert slice is returned as its own contiguous
    tensor.

    Args:
        tensor: stacked expert weight with
            ``num_experts * ffn_hidden_size * hidden_size`` elements.
        reverse: when True (used for the ``w2`` down-projection) each expert
            slice is transposed before being made contiguous.
        num_experts / ffn_hidden_size / hidden_size: split dimensions.
            Default to the module-level constants (generalized from the
            original hard-coded values).

    Returns:
        List of ``num_experts`` contiguous 2-D tensors.
    """
    if num_experts is None:
        num_experts = NUM_EXPERTS
    if ffn_hidden_size is None:
        ffn_hidden_size = FFN_HIDDEN_SIZE
    if hidden_size is None:
        hidden_size = HIDDEN_SIZE
    experts = tensor.reshape(num_experts, ffn_hidden_size, hidden_size)
    # Branch hoisted out of the loop: the transpose decision is per-call, not per-expert.
    if reverse:
        return [x.t().contiguous() for x in experts]
    return [x.contiguous() for x in experts]
def change_mlp(tensors):
    """Expand each stacked expert weight (keys containing w1/v1/w2) into one
    entry per expert. Mutates and returns *tensors*."""
    # Collect matching keys up front; the dict is mutated below.
    stacked_keys = [key for key in tensors
                    if any(tag in key for tag in ('w1', 'v1', 'w2'))]
    for key in stacked_keys:
        base, kind = key.rsplit('.', 1)
        # w2 is the down-projection and needs each expert slice transposed.
        per_expert = change_tensor_mlp(tensors.pop(key), kind == 'w2')
        for idx, weight in enumerate(per_expert):
            tensors[f'{base}.{idx}.{kind}.weight'] = weight
    return tensors
# --- Convert each safetensors shard ----------------------------------------
# Load every tensor from the shard, split the fused attention and stacked
# expert weights, and write the result under the same file name.
for shard_path in sorted(model_dir.glob('*.safetensors')):
    print(shard_path)
    tensors = {}
    with safe_open(shard_path, 'pt') as f:
        metadata = f.metadata()
        for name in f.keys():
            tensors[name] = f.get_tensor(name)
    tensors = change_attn(tensors)
    tensors = change_mlp(tensors)
    save_file(tensors, (output_dir / shard_path.name).as_posix(), metadata=metadata)

# --- Rewrite the weight-map index to match the renamed tensors -------------
# Each split tensor stays in the same shard file, so the original file name
# is reused for every derived key.
with open(model_dir / 'model.safetensors.index.json') as f:
    weight_map = json.load(f)

for key in list(weight_map['weight_map']):
    if any(tag in key for tag in ('w1', 'v1', 'w2')):
        base, kind = key.rsplit('.', 1)
        shard_name = weight_map['weight_map'].pop(key)
        for idx in range(NUM_EXPERTS):
            weight_map['weight_map'][f'{base}.{idx}.{kind}.weight'] = shard_name
    elif 'Wqkv' in key:
        base = key.removesuffix('.Wqkv.weight')
        shard_name = weight_map['weight_map'].pop(key)
        for proj in ('q_proj', 'k_proj', 'v_proj'):
            weight_map['weight_map'][f'{base}.{proj}.weight'] = shard_name

# Keep the index deterministic / diff-friendly.
weight_map['weight_map'] = dict(sorted(weight_map['weight_map'].items()))

with open(output_dir / 'model.safetensors.index.json', 'w') as f:
    json.dump(weight_map, f, indent=4)