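"""Convert a quantized checkpoint into a Hugging Face ("hfized") model.

Reads the per-layer tensors written by the quantizer, copies them into a
fused llama/mistral model, saves the result with save_pretrained, then
reloads the saved model and generates a short sample as a sanity check.
"""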
import argparse
import os
import time

import glog
import torch
from transformers import AutoTokenizer

from lib import codebook
from lib.utils.unsafe_import import model_from_hf_path
from model.llama import LlamaForCausalLM as llama_fuse
from model.mistral import MistralForCausalLM
from model.version import MODEL_VERSION

torch.set_grad_enabled(False)

parser = argparse.ArgumentParser()
parser.add_argument('--quantized_path',
                    type=str,
                    help='directory containing the quantized checkpoint '
                    '(config.pt plus per-layer .pt files)')
parser.add_argument('--hf_output_path',
                    type=str,
                    help='directory to write the hfized model to')
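
# Example invocation (script name and paths are illustrative):
#   python hfize_llama.py --quantized_path <quantized_ckpt_dir> \
#       --hf_output_path <hf_out_dir>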


def unpack_quip(module, saved_layer, codebook_id, codesz):
    # Copy the quantized codebook indices; codebooks in cache_permute_set
    # expect Qidxs permuted into the layout used by the fused kernels.
    (m, n) = saved_layer['Qidxs'].shape
    if codebook_id in codebook.cache_permute_set:
        module.Qidxs.copy_(saved_layer['Qidxs'].view(
            m, n // codesz, codesz).permute(1, 0, 2).reshape(m, n).contiguous())
    else:
        module.Qidxs.copy_(saved_layer['Qidxs'])

    # Low-rank factors, present when the layer was quantized with rank > 0.
    if module.rank > 0:
        module.A.copy_(saved_layer['A'])
        module.B.copy_(saved_layer['B'])

    # Incoherence-processing sign vectors and optional rescaling factors.
    module.SU.copy_(saved_layer['SU'])
    module.SV.copy_(saved_layer['SV'])
    if module.rescale_WH:
        module.scaleWH.copy_(saved_layer['scaleWH'])

    module.codebook_id.copy_(codebook_id)


def main(args):
    assert os.path.exists(args.quantized_path)
    saved_config = torch.load(os.path.join(args.quantized_path, 'config.pt'))
    model_config = saved_config['model_config']

    codebook_id = codebook.get_id(model_config.quip_params['codebook'])
    codesz = model_config.quip_params['codesz']

    tokenizer = AutoTokenizer.from_pretrained(model_config._name_or_path)

    model_type = model_config.model_type
    model_config.quip_params['model_version'] = MODEL_VERSION

    if model_type == 'llama':
        model_cls = llama_fuse
    elif model_type == 'mistral':
        model_cls = MistralForCausalLM
    else:
        raise ValueError(f'unsupported model type: {model_type}')

    model = model_cls.from_pretrained(model_config._name_or_path,
                                      torch_dtype='auto',
                                      low_cpu_mem_usage=True,
                                      config=model_config).half()

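    # Copy each decoder layer's quantized tensors into the fused modules.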
    cpu = torch.device('cpu')
    for ii in range(len(model.model.layers)):
        glog.info(f'updating layer {ii}')
        layer = model.model.layers[ii]

        # Fused q/k/v projection: one fuse scale per source matrix, plus the
        # shared Wscale.
        glog.info(f'loading layer {ii} qkv')
        saved_layer = torch.load(f'{args.quantized_path}/{ii}_qkv.pt',
                                 map_location=cpu)
        layer.self_attn.qkv_proj.fuse_scales[0].copy_(saved_layer['W_q_scale'])
        layer.self_attn.qkv_proj.fuse_scales[1].copy_(saved_layer['W_k_scale'])
        layer.self_attn.qkv_proj.fuse_scales[2].copy_(saved_layer['W_v_scale'])
        layer.self_attn.qkv_proj.Wscale.copy_(saved_layer['Wscale'])
        unpack_quip(layer.self_attn.qkv_proj, saved_layer, codebook_id, codesz)

        # Fused up/gate projection.
        glog.info(f'loading layer {ii} up')
        saved_layer = torch.load(f'{args.quantized_path}/{ii}_up.pt',
                                 map_location=cpu)
        layer.mlp.upgate_proj.fuse_scales[0].copy_(saved_layer['W_up_scale'])
        layer.mlp.upgate_proj.fuse_scales[1].copy_(saved_layer['W_gate_scale'])
        layer.mlp.upgate_proj.Wscale.copy_(saved_layer['Wscale'])
        unpack_quip(layer.mlp.upgate_proj, saved_layer, codebook_id, codesz)

        # Output projection: fold the per-matrix scale into Wscale.
        glog.info(f'loading layer {ii} o')
        saved_layer = torch.load(f'{args.quantized_path}/{ii}_o.pt',
                                 map_location=cpu)
        layer.self_attn.o_proj.Wscale.copy_(saved_layer['W_o_scale'] *
                                            saved_layer['Wscale'])
        unpack_quip(layer.self_attn.o_proj, saved_layer, codebook_id, codesz)

        # Down projection, plus the outlier channel duplication indices when
        # outlier channel splitting was used.
        glog.info(f'loading layer {ii} down')
        saved_layer = torch.load(f'{args.quantized_path}/{ii}_down.pt',
                                 map_location=cpu)
        layer.mlp.down_proj.Wscale.copy_(saved_layer['W_down_scale'] *
                                         saved_layer['Wscale'])
        if model_config.quip_params['outlier_channel_split']:
            layer.mlp.down_proj.ocs_dupe_inds.copy_(
                torch.tensor(saved_layer['ocs_dupe_inds']))
        unpack_quip(layer.mlp.down_proj, saved_layer, codebook_id, codesz)

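    # Save the converted model, then reload it through the inference path and
    # generate a short sample as a smoke test.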
    glog.info('saving model...')
    model.save_pretrained(args.hf_output_path, safe_serialization=True)

    del model

    model, _ = model_from_hf_path(args.hf_output_path, use_cuda_graph=False)
    glog.info('successfully loaded hfized model')

    glog.info('generating some text...')
    start = time.time()
    prompt = 'It is a truth universally acknowledged that'
    inputs = tokenizer(prompt, return_tensors='pt')
    outputs = model.generate(input_ids=inputs['input_ids'].cuda(),
                             attention_mask=inputs['attention_mask'].cuda(),
                             max_new_tokens=64,
                             return_dict_in_generate=True)
    token = outputs.sequences[0, :]
    output_str = tokenizer.decode(token)
    glog.info(output_str)
    glog.info(f'elapsed: {time.time() - start}')


if __name__ == '__main__':
    torch.set_grad_enabled(False)
    torch.manual_seed(0)
    args = parser.parse_args()
    main(args)