# dbrx-base-converted-v2 / convert_v2.py
# Qubitium's picture
# Fix import typo (#2)
# 3babc68 verified
# raw history blame
# No virus
# 3.11 kB
# code adapted from https://huggingface.co/fahadh4ilyas
import argparse, json
from safetensors import safe_open
from safetensors.torch import save_file
from pathlib import Path
# Command-line interface: both the source and destination directories are
# mandatory.
parser = argparse.ArgumentParser(
    description="Convert original dbrx model into quantizable model",
)
parser.add_argument(
    "--model-dir",
    required=True,
    type=str,
    help="directory to the original dbrx model",
)
parser.add_argument(
    "--output-dir",
    required=True,
    type=str,
    help="directory for the converted dbrx model",
)
args = parser.parse_args()

# Normalise both locations to Path objects and create the destination up
# front (including any missing parents) so later writes into it succeed.
model_dir, output_dir = Path(args.model_dir), Path(args.output_dir)
output_dir.mkdir(parents=True, exist_ok=True)
# Hard-coded DBRX dimensions used by the split helpers below.
# NOTE(review): presumably these mirror the checkpoint's config.json;
# verify before converting any other variant.
NUM_EXPERTS = 16          # number of experts stacked in each w1/v1/w2 tensor
HIDDEN_SIZE = 6_144       # model width; also the row count of the Q slice of Wqkv
HEAD_DIM = 128            # per-head dimension
NUM_KV_HEAD = 8           # K and V slices of Wqkv are NUM_KV_HEAD * HEAD_DIM rows each
FFN_HIDDEN_SIZE = 10_752  # inner size of each expert's MLP
def change_tensor_attn(tensor, *, q_size=None, kv_size=None):
    """Split a fused Wqkv weight into separate q, k and v weights.

    The tensor is split along dim 0 into three consecutive row blocks of
    ``q_size``, ``kv_size`` and ``kv_size`` rows.

    Args:
        tensor: fused Wqkv weight; its dim 0 must equal
            ``q_size + 2 * kv_size`` or torch raises an error.
        q_size: rows in the query block. Defaults to ``HIDDEN_SIZE``.
        kv_size: rows in each of the key and value blocks. Defaults to
            ``NUM_KV_HEAD * HEAD_DIM``.

    Returns:
        A list ``[q, k, v]`` of contiguous tensors.
    """
    # Resolve defaults at call time so the module constants stay the single
    # source of truth while callers may override them for other geometries.
    if q_size is None:
        q_size = HIDDEN_SIZE
    if kv_size is None:
        kv_size = NUM_KV_HEAD * HEAD_DIM
    # contiguous() returns the tensor itself when it is already contiguous,
    # so these row slices may still share storage with ``tensor``.
    return [x.contiguous() for x in tensor.split([q_size, kv_size, kv_size])]
def change_attn(tensors):
    """Replace every fused ``<prefix>.Wqkv.weight`` entry with q/k/v entries.

    Each matching tensor is removed from ``tensors`` and split by
    ``change_tensor_attn`` into ``<prefix>.q_proj.weight``,
    ``<prefix>.k_proj.weight`` and ``<prefix>.v_proj.weight``. All other
    entries are left untouched.

    Args:
        tensors: dict mapping parameter name -> tensor; modified in place.

    Returns:
        The same ``tensors`` dict.
    """
    suffix = '.Wqkv.weight'
    # Walk a snapshot of the keys because the dict is mutated in the loop.
    for k in list(tensors):
        # Require the exact suffix. A bare substring test would also catch
        # keys that merely contain "Wqkv" (e.g. "...Wqkv.bias"), split them
        # and rename them to "...Wqkv.bias.q_proj.weight".
        if not k.endswith(suffix):
            continue
        prefix = k.removesuffix(suffix)
        output_tensor = change_tensor_attn(tensors.pop(k))
        for proj, t in zip(['q_proj', 'k_proj', 'v_proj'], output_tensor):
            tensors[f'{prefix}.{proj}.weight'] = t
    return tensors
def change_tensor_mlp(tensor, reverse=False, *, num_experts=None,
                      ffn_hidden_size=None, hidden_size=None):
    """Split a stacked expert MLP weight into one tensor per expert.

    ``tensor`` is reshaped to ``(num_experts, ffn_hidden_size, hidden_size)``
    and each expert slice along dim 0 is returned as its own tensor.

    Args:
        tensor: stacked expert weight holding
            ``num_experts * ffn_hidden_size * hidden_size`` elements.
        reverse: when true, each expert slice is transposed to
            ``(hidden_size, ffn_hidden_size)``. ``change_mlp`` passes true
            for ``w2`` weights.
        num_experts: defaults to ``NUM_EXPERTS``.
        ffn_hidden_size: defaults to ``FFN_HIDDEN_SIZE``.
        hidden_size: defaults to ``HIDDEN_SIZE``.

    Returns:
        A list of ``num_experts`` contiguous tensors.
    """
    if num_experts is None:
        num_experts = NUM_EXPERTS
    if ffn_hidden_size is None:
        ffn_hidden_size = FFN_HIDDEN_SIZE
    if hidden_size is None:
        hidden_size = HIDDEN_SIZE
    experts = tensor.reshape(num_experts, ffn_hidden_size, hidden_size)
    if reverse:
        # .t() yields a non-contiguous view; .contiguous() materialises a copy.
        return [x.t().contiguous() for x in experts]
    # Slices of a contiguous tensor are already contiguous, so contiguous()
    # may return them as-is (sharing storage with ``tensor``).
    return [x.contiguous() for x in experts]
def change_mlp(tensors):
    """Split stacked expert MLP weights ``w1``, ``v1`` and ``w2`` per expert.

    Every key whose last dotted component is exactly ``w1``, ``v1`` or ``w2``
    is removed from ``tensors``. It is replaced by one entry per expert named
    ``<prefix>.<i>.<name>.weight``, produced by ``change_tensor_mlp``. ``w2``
    slices are transposed via its ``reverse`` flag.

    Args:
        tensors: dict mapping parameter name -> tensor; modified in place.

    Returns:
        The same ``tensors`` dict.
    """
    stacked_names = ('w1', 'v1', 'w2')
    # Walk a snapshot of the keys because the dict is mutated in the loop.
    for k in list(tensors):
        prefix, sep, dtype = k.rpartition('.')
        # Compare the final component exactly. A substring test would also
        # match already-converted keys such as "...mlp.0.w1.weight" (then
        # fail to reshape them), and a dotless key would break the unpack.
        if not sep or dtype not in stacked_names:
            continue
        output_tensor = change_tensor_mlp(tensors.pop(k), dtype == 'w2')
        for i, t in enumerate(output_tensor):
            tensors[f'{prefix}.{i}.{dtype}.weight'] = t
    return tensors
# Rewrite each shard in name order: fused attention weights are split first,
# then stacked expert weights. The shard's file name and safetensors
# metadata are carried over unchanged.
for shard_path in sorted(model_dir.glob('*.safetensors')):
    print(shard_path)
    with safe_open(shard_path, 'pt') as reader:
        shard_metadata = reader.metadata()
        tensors = {name: reader.get_tensor(name) for name in reader.keys()}
    tensors = change_mlp(change_attn(tensors))
    destination = (output_dir / shard_path.name).as_posix()
    save_file(tensors, destination, shard_metadata)
# Rewrite model.safetensors.index.json so that its "weight_map" lists the
# renamed tensors. Each new name points at the shard file that held the
# tensor it was split from; all other top-level fields are written back as-is.
with open(model_dir / 'model.safetensors.index.json', encoding='utf-8') as f:
    weight_map = json.load(f)
mapping = weight_map['weight_map']
for k in list(mapping):
    prefix, sep, dtype = k.rpartition('.')
    # Exact matching on the final component and suffix, so keys that merely
    # contain "w1"/"v1"/"w2"/"Wqkv" (e.g. already-split "...mlp.0.w1.weight")
    # are left alone instead of being renamed into garbage.
    if sep and dtype in ('w1', 'v1', 'w2'):
        shard = mapping.pop(k)
        for i in range(NUM_EXPERTS):
            mapping[f'{prefix}.{i}.{dtype}.weight'] = shard
    elif k.endswith('.Wqkv.weight'):
        shard = mapping.pop(k)
        prefix = k.removesuffix('.Wqkv.weight')
        for proj in ('q_proj', 'k_proj', 'v_proj'):
            mapping[f'{prefix}.{proj}.weight'] = shard
weight_map['weight_map'] = dict(sorted(mapping.items()))
with open(output_dir / 'model.safetensors.index.json', 'w', encoding='utf-8') as f:
    json.dump(weight_map, f, indent=4)