from dataclasses import dataclass, field
from pathlib import Path
import sys

import torch
from transformers import AutoTokenizer, AutoConfig, AutoModelForCausalLM, modeling_utils


@dataclass
class AWQConfig:
    ckpt: str = field(
        default=None,
        metadata={
            "help": "Load quantized model. The path to the local AWQ checkpoint."
        },
    )
    wbits: int = field(default=16, metadata={"help": "#bits to use for quantization"})
    groupsize: int = field(
        default=-1,
        metadata={"help": "Groupsize to use for quantization; default uses full row."},
    )


def load_awq_quantized(model_name, awq_config: AWQConfig, device):
    print("Loading AWQ quantized model...")

    try:
        from tinychat.utils import load_quant
        from tinychat.modules import make_quant_norm, make_quant_attn, make_fused_mlp
    except ImportError as e:
        print(f"Error: Failed to import tinychat. {e}")
        print("Please double check if you have successfully installed AWQ")
        print("See https://github.com/lm-sys/FastChat/blob/main/docs/awq.md")
        sys.exit(-1)

    config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
    tokenizer = AutoTokenizer.from_pretrained(
        model_name, use_fast=False, trust_remote_code=True
    )

    def skip(*args, **kwargs):
        pass

    # Skip random weight initialization: the real weights come from the AWQ
    # checkpoint, so initializing them here would only waste time and memory.
    torch.nn.init.kaiming_uniform_ = skip
    torch.nn.init.kaiming_normal_ = skip
    torch.nn.init.uniform_ = skip
    torch.nn.init.normal_ = skip
    modeling_utils._init_weights = False
    torch.set_default_dtype(torch.half)
    model = AutoModelForCausalLM.from_config(config, trust_remote_code=True)

    ckpt = find_awq_ckpt(awq_config)
    if any(name in ckpt for name in ["llama", "vicuna"]):
        # LLaMA-family models get the fast path with fused attention,
        # fused layer norm, and fused MLP kernels.
        model = load_quant.load_awq_llama_fast(
            model, ckpt, awq_config.wbits, awq_config.groupsize, device
        )
        make_quant_attn(model, device)
        make_quant_norm(model)
        make_fused_mlp(model)
    else:
        model = load_quant.load_awq_model(
            model, ckpt, awq_config.wbits, awq_config.groupsize, device
        )

    return model, tokenizer


def find_awq_ckpt(awq_config: AWQConfig):
    # If the configured path is a file, use it directly; otherwise treat it as
    # a directory and pick the newest matching checkpoint file inside it.
    if Path(awq_config.ckpt).is_file():
        return awq_config.ckpt

    for ext in ["*.pt", "*.safetensors"]:
        matched_result = sorted(Path(awq_config.ckpt).glob(ext))
        if len(matched_result) > 0:
            return str(matched_result[-1])

    print("Error: AWQ checkpoint not found")
    sys.exit(1)
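

# A minimal usage sketch, not part of the module's API. Assumptions: the model
# id, checkpoint path, wbits, and groupsize below are illustrative placeholders,
# and a CUDA device plus a working tinychat/AWQ install are available.
if __name__ == "__main__":
    example_config = AWQConfig(
        ckpt="/path/to/awq/checkpoint/dir",  # hypothetical local path
        wbits=4,  # 4-bit weights are a common AWQ setting; the field defaults to 16
        groupsize=128,  # a typical AWQ group size; -1 would use the full row
    )
    model, tokenizer = load_awq_quantized(
        "lmsys/vicuna-7b-v1.5", example_config, torch.device("cuda:0")
    )
    print(type(model).__name__, "loaded")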