import argparse
import os
import time

import glog
import torch
from transformers import AutoTokenizer

from lib import codebook
from lib.utils.unsafe_import import model_from_hf_path
from model.llama import LlamaForCausalLM as llama_fuse
from model.mistral import MistralForCausalLM
from model.version import MODEL_VERSION

torch.set_grad_enabled(False)

parser = argparse.ArgumentParser()
parser.add_argument('--quantized_path', type=str)
parser.add_argument('--hf_output_path', type=str)


def unpack_quip(module, saved_layer, codebook_id, codesz):
    """Copy a saved quantized layer's tensors into a fused quantized module."""
    (m, n) = saved_layer['Qidxs'].shape
    if codebook_id in codebook.cache_permute_set:
        # Codebooks in the cache-permute set expect the codeword indices in a
        # permuted, contiguous layout; rearrange before copying.
        module.Qidxs.copy_(saved_layer['Qidxs'].view(
            m, n // codesz, codesz).permute(1, 0, 2).reshape(m, n).contiguous())
    else:
        module.Qidxs.copy_(saved_layer['Qidxs'])
    if module.rank > 0:
        # Low-rank correction factors.
        module.A.copy_(saved_layer['A'])
        module.B.copy_(saved_layer['B'])
    # Sign vectors applied to the input (SU) and output (SV) dimensions.
    module.SU.copy_(saved_layer['SU'])
    module.SV.copy_(saved_layer['SV'])
    if module.rescale_WH:
        module.scaleWH.copy_(saved_layer['scaleWH'])
    module.codebook_id.copy_(codebook_id)


def main(args):
    assert os.path.exists(args.quantized_path)
    saved_config = torch.load(os.path.join(args.quantized_path, 'config.pt'))
    model_config = saved_config['model_config']
    codebook_id = codebook.get_id(model_config.quip_params['codebook'])
    codesz = model_config.quip_params['codesz']

    tokenizer = AutoTokenizer.from_pretrained(model_config._name_or_path)

    model_type = model_config.model_type
    model_config.quip_params['model_version'] = MODEL_VERSION
    if model_type == 'llama':
        model_cls = llama_fuse
    elif model_type == 'mistral':
        model_cls = MistralForCausalLM
    else:
        raise ValueError(f'unsupported model type: {model_type}')

    # Instantiate the fused model skeleton from the original checkpoint, then
    # overwrite its weights layer by layer with the quantized tensors below.
    model = model_cls.from_pretrained(model_config._name_or_path,
                                      torch_dtype='auto',
                                      low_cpu_mem_usage=True,
                                      config=model_config).half()

    cpu = torch.device('cpu')
    for ii in range(len(model.model.layers)):
        glog.info(f'updating layer {ii}')
        layer = model.model.layers[ii]

        # Fused q/k/v projection: per-matrix fuse scales plus a shared Wscale.
        glog.info(f'loading layer {ii} qkv')
        saved_layer = torch.load(f'{args.quantized_path}/{ii}_qkv.pt',
                                 map_location=cpu)
        layer.self_attn.qkv_proj.fuse_scales[0].copy_(saved_layer['W_q_scale'])
        layer.self_attn.qkv_proj.fuse_scales[1].copy_(saved_layer['W_k_scale'])
        layer.self_attn.qkv_proj.fuse_scales[2].copy_(saved_layer['W_v_scale'])
        layer.self_attn.qkv_proj.Wscale.copy_(saved_layer['Wscale'])
        unpack_quip(layer.self_attn.qkv_proj, saved_layer, codebook_id, codesz)

        # Fused up/gate projection.
        glog.info(f'loading layer {ii} up')
        saved_layer = torch.load(f'{args.quantized_path}/{ii}_up.pt',
                                 map_location=cpu)
        layer.mlp.upgate_proj.fuse_scales[0].copy_(saved_layer['W_up_scale'])
        layer.mlp.upgate_proj.fuse_scales[1].copy_(saved_layer['W_gate_scale'])
        layer.mlp.upgate_proj.Wscale.copy_(saved_layer['Wscale'])
        unpack_quip(layer.mlp.upgate_proj, saved_layer, codebook_id, codesz)

        glog.info(f'loading layer {ii} o')
        saved_layer = torch.load(f'{args.quantized_path}/{ii}_o.pt',
                                 map_location=cpu)
        layer.self_attn.o_proj.Wscale.copy_(saved_layer['W_o_scale'] *
                                            saved_layer['Wscale'])
        unpack_quip(layer.self_attn.o_proj, saved_layer, codebook_id, codesz)

        glog.info(f'loading layer {ii} down')
        saved_layer = torch.load(f'{args.quantized_path}/{ii}_down.pt',
                                 map_location=cpu)
        layer.mlp.down_proj.Wscale.copy_(saved_layer['W_down_scale'] *
                                         saved_layer['Wscale'])
        if model_config.quip_params['outlier_channel_split']:
            # Indices of input channels duplicated by outlier channel splitting.
            layer.mlp.down_proj.ocs_dupe_inds.copy_(
                torch.tensor(saved_layer['ocs_dupe_inds']))
        unpack_quip(layer.mlp.down_proj, saved_layer, codebook_id, codesz)

    glog.info('saving model...')
    model.save_pretrained(args.hf_output_path, safe_serialization=True)

    del model

    # Reload the exported checkpoint to verify it round-trips, then run a
    # short generation as a smoke test.
    model, _ = model_from_hf_path(args.hf_output_path, use_cuda_graph=False)
    glog.info('successfully loaded hfized model')
    glog.info('generating some text...')

    start = time.time()
    prompt = 'It is a truth universally acknowledged that'
    inputs = tokenizer(prompt, return_tensors='pt')
    outputs = model.generate(input_ids=inputs['input_ids'].cuda(),
                             attention_mask=inputs['attention_mask'].cuda(),
                             max_new_tokens=64,
                             return_dict_in_generate=True)
    token = outputs.sequences[0, :]
    output_str = tokenizer.decode(token)
    glog.info(output_str)
    glog.info(f'elapsed: {time.time() - start}')


if __name__ == '__main__':
    torch.set_grad_enabled(False)
    torch.manual_seed(0)
    args = parser.parse_args()
    main(args)
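# A minimal invocation sketch. The script filename and both paths are
# illustrative assumptions, not taken from the source: --quantized_path must
# point at a directory containing config.pt and the per-layer {ii}_qkv.pt,
# {ii}_up.pt, {ii}_o.pt, and {ii}_down.pt files produced by quantization, and
# --hf_output_path is where the converted HF-format checkpoint is written.
#
#   python hfize_llama.py \
#       --quantized_path /path/to/quantized_ckpt \
#       --hf_output_path /path/to/hf_output_dir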