# code adapted from https://huggingface.co/fahadh4ilyas
"""Convert an original dbrx checkpoint into a quantizable layout.

Two rewrites are applied to every ``*.safetensors`` shard in ``--model-dir``:

* each fused ``<prefix>.Wqkv.weight`` tensor is split along its first
  dimension into ``<prefix>.q_proj.weight``, ``<prefix>.k_proj.weight`` and
  ``<prefix>.v_proj.weight``;
* each stacked expert tensor ``<prefix>.w1`` / ``<prefix>.v1`` /
  ``<prefix>.w2`` is split into ``NUM_EXPERTS`` tensors named
  ``<prefix>.<i>.<name>.weight`` (the ``w2`` slices are transposed).

``model.safetensors.index.json`` is rewritten with the same renames so that
every new tensor name points at the shard its source tensor lived in.
"""
import argparse
import json
from pathlib import Path
from typing import TYPE_CHECKING, Optional, Sequence

if TYPE_CHECKING:  # only needed for annotations; keeps the import of this module light
    import torch

NUM_EXPERTS = 16
HIDDEN_SIZE = 6144
HEAD_DIM = 128
NUM_KV_HEAD = 8
FFN_HIDDEN_SIZE = 10752

_QKV_SUFFIX = '.Wqkv.weight'
_QKV_PROJ_NAMES = ('q_proj', 'k_proj', 'v_proj')
_EXPERT_WEIGHT_NAMES = ('w1', 'v1', 'w2')
_INDEX_FILE_NAME = 'model.safetensors.index.json'


def _is_qkv_key(key: str) -> bool:
    """Return True when *key* names a fused QKV weight (``<prefix>.Wqkv.weight``)."""
    return key.endswith(_QKV_SUFFIX)


def _split_expert_key(key: str) -> Optional[tuple]:
    """Return ``(prefix, name)`` when the last dotted component of *key* is w1/v1/w2.

    Matching the final component exactly (instead of a substring test) keeps
    unrelated tensors whose names merely contain ``w1``/``v1``/``w2`` untouched.
    Returns ``None`` for every other key.
    """
    prefix, sep, name = key.rpartition('.')
    if sep and prefix and name in _EXPERT_WEIGHT_NAMES:
        return prefix, name
    return None


def change_tensor_attn(tensor: "torch.Tensor") -> "list[torch.Tensor]":
    """Split a fused QKV weight into contiguous ``[q, k, v]`` tensors.

    The split is taken along the first dimension with sizes ``HIDDEN_SIZE``,
    ``NUM_KV_HEAD * HEAD_DIM`` and ``NUM_KV_HEAD * HEAD_DIM``; the tensor's
    first dimension must therefore equal their sum.
    """
    kv_size = NUM_KV_HEAD * HEAD_DIM
    return [part.contiguous() for part in tensor.split([HIDDEN_SIZE, kv_size, kv_size])]


def change_attn(tensors: "dict[str, torch.Tensor]") -> "dict[str, torch.Tensor]":
    """Replace every ``<prefix>.Wqkv.weight`` entry with q/k/v projection entries.

    The mapping is modified in place and also returned.
    """
    # Snapshot the keys: the dict is mutated while we walk it.
    for key in list(tensors):
        if not _is_qkv_key(key):
            continue
        prefix = key.removesuffix(_QKV_SUFFIX)
        parts = change_tensor_attn(tensors.pop(key))
        for proj_name, part in zip(_QKV_PROJ_NAMES, parts):
            tensors[f'{prefix}.{proj_name}.weight'] = part
    return tensors


def change_tensor_mlp(tensor: "torch.Tensor", reverse: bool = False) -> "list[torch.Tensor]":
    """Split a stacked expert weight into ``NUM_EXPERTS`` contiguous tensors.

    The input is reshaped to ``(NUM_EXPERTS, FFN_HIDDEN_SIZE, HIDDEN_SIZE)``
    and one tensor is returned per expert. With ``reverse=True`` each slice is
    transposed, giving ``(HIDDEN_SIZE, FFN_HIDDEN_SIZE)`` tensors.
    """
    experts = tensor.reshape(NUM_EXPERTS, FFN_HIDDEN_SIZE, HIDDEN_SIZE)
    if reverse:
        return [expert.t().contiguous() for expert in experts]
    return [expert.contiguous() for expert in experts]


def change_mlp(tensors: "dict[str, torch.Tensor]") -> "dict[str, torch.Tensor]":
    """Replace every stacked ``w1``/``v1``/``w2`` entry with per-expert entries.

    ``<prefix>.<name>`` becomes ``<prefix>.<i>.<name>.weight`` for each expert
    index ``i``; only ``w2`` is transposed. The mapping is modified in place
    and also returned.
    """
    for key in list(tensors):
        parsed = _split_expert_key(key)
        if parsed is None:
            continue
        prefix, weight_name = parsed
        per_expert = change_tensor_mlp(tensors.pop(key), weight_name == 'w2')
        for index, expert_tensor in enumerate(per_expert):
            tensors[f'{prefix}.{index}.{weight_name}.weight'] = expert_tensor
    return tensors


def convert_weight_map(weight_map: "dict[str, str]") -> "dict[str, str]":
    """Return a key-sorted copy of an index ``weight_map`` with converted names.

    Applies the same renames as :func:`change_mlp` and :func:`change_attn`;
    every new name keeps the shard file of the tensor it was derived from.
    """
    converted = {}
    for key, shard_name in weight_map.items():
        parsed = _split_expert_key(key)
        if parsed is not None:
            prefix, weight_name = parsed
            for index in range(NUM_EXPERTS):
                converted[f'{prefix}.{index}.{weight_name}.weight'] = shard_name
        elif _is_qkv_key(key):
            prefix = key.removesuffix(_QKV_SUFFIX)
            for proj_name in _QKV_PROJ_NAMES:
                converted[f'{prefix}.{proj_name}.weight'] = shard_name
        else:
            converted[key] = shard_name
    return dict(sorted(converted.items()))


def convert_shard(shard_path: Path, output_dir: Path) -> None:
    """Convert one safetensors shard and save it under the same name in *output_dir*."""
    # Imported here so the pure renaming helpers above stay importable
    # without safetensors installed.
    from safetensors import safe_open
    from safetensors.torch import save_file

    with safe_open(shard_path, 'pt') as f:
        metadata = f.metadata()
        tensors = {key: f.get_tensor(key) for key in f.keys()}

    tensors = change_attn(tensors)
    tensors = change_mlp(tensors)
    save_file(tensors, (output_dir / shard_path.name).as_posix(), metadata)


def main(argv: Optional[Sequence[str]] = None) -> None:
    """Parse the command line and convert every shard plus the weight index."""
    parser = argparse.ArgumentParser(description="Convert original dbrx model into quantizable model")
    parser.add_argument("--model-dir", type=str, required=True, help="directory to the original dbrx model")
    parser.add_argument("--output-dir", type=str, required=True, help="directory for the converted dbrx model")
    args = parser.parse_args(argv)

    model_dir = Path(args.model_dir)
    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    for shard_path in sorted(model_dir.glob('*.safetensors')):
        print(shard_path)
        convert_shard(shard_path, output_dir)

    with open(model_dir / _INDEX_FILE_NAME, encoding='utf-8') as f:
        index = json.load(f)

    index['weight_map'] = convert_weight_map(index['weight_map'])

    with open(output_dir / _INDEX_FILE_NAME, 'w', encoding='utf-8') as f:
        json.dump(index, f, indent=4)


if __name__ == "__main__":
    main()