tianleliphoebe's picture
Upload folder using huggingface_hub
ec0c335 verified
raw
history blame contribute delete
No virus
2.64 kB
from dataclasses import dataclass, field
from pathlib import Path
import sys
import torch
from transformers import AutoTokenizer, AutoConfig, AutoModelForCausalLM, modeling_utils
@dataclass
class AWQConfig:
    """Configuration for loading an AWQ-quantized model checkpoint."""

    # Path to the local AWQ checkpoint (a file, or a directory that is
    # globbed for *.pt / *.safetensors by find_awq_ckpt).
    # NOTE(review): annotated `str` but defaults to None — effectively
    # Optional[str]; callers must set it before loading.
    ckpt: str = field(
        default=None,
        metadata={
            "help": "Load quantized model. The path to the local AWQ checkpoint."
        },
    )
    # Quantization bit width passed through to the tinychat loaders.
    wbits: int = field(default=16, metadata={"help": "#bits to use for quantization"})
    # Quantization group size; -1 uses the full row (per the help text).
    groupsize: int = field(
        default=-1,
        metadata={"help": "Groupsize to use for quantization; default uses full row."},
    )
def load_awq_quantized(model_name, awq_config: AWQConfig, device):
    """Load an AWQ-quantized causal LM plus its tokenizer.

    Args:
        model_name: HF model name or local path used to load config/tokenizer.
        awq_config: AWQ settings (checkpoint path, bit width, group size).
        device: target device, passed through to the tinychat loaders.

    Returns:
        Tuple of ``(model, tokenizer)``.

    Side effects (process-wide and NOT restored, matching the original code):
        - ``torch.nn.init`` initializers are monkeypatched to no-ops and
          transformers' ``_init_weights`` is disabled, so the skeleton model
          skips random weight init (the checkpoint overwrites all weights).
        - The torch default dtype is switched to half precision.

    Exits the process if tinychat is not importable or the checkpoint
    cannot be located.
    """
    print("Loading AWQ quantized model...")
    try:
        from tinychat.utils import load_quant
        from tinychat.modules import make_quant_norm, make_quant_attn, make_fused_mlp
    except ImportError as e:
        print(f"Error: Failed to import tinychat. {e}")
        print("Please double check if you have successfully installed AWQ")
        print("See https://github.com/lm-sys/FastChat/blob/main/docs/awq.md")
        sys.exit(-1)

    config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
    tokenizer = AutoTokenizer.from_pretrained(
        model_name, use_fast=False, trust_remote_code=True
    )

    # No-op initializer: weight init is wasted work since the AWQ checkpoint
    # replaces every tensor immediately after construction.
    def skip(*args, **kwargs):
        pass

    torch.nn.init.kaiming_uniform_ = skip
    torch.nn.init.kaiming_normal_ = skip
    torch.nn.init.uniform_ = skip
    torch.nn.init.normal_ = skip
    modeling_utils._init_weights = False
    torch.set_default_dtype(torch.half)
    model = AutoModelForCausalLM.from_config(config, trust_remote_code=True)

    # Fix: resolve the checkpoint exactly once. The original called
    # find_awq_ckpt() up to three times, re-globbing the filesystem (and
    # potentially printing/exiting) on each call.
    ckpt_path = find_awq_ckpt(awq_config)
    if any(name in ckpt_path for name in ["llama", "vicuna"]):
        # Llama-family checkpoints use the fused fast path with extra
        # quantized attention/norm/MLP module rewrites.
        model = load_quant.load_awq_llama_fast(
            model,
            ckpt_path,
            awq_config.wbits,
            awq_config.groupsize,
            device,
        )
        make_quant_attn(model, device)
        make_quant_norm(model)
        make_fused_mlp(model)
    else:
        model = load_quant.load_awq_model(
            model,
            ckpt_path,
            awq_config.wbits,
            awq_config.groupsize,
            device,
        )
    return model, tokenizer
def find_awq_ckpt(awq_config: AWQConfig):
if Path(awq_config.ckpt).is_file():
return awq_config.ckpt
for ext in ["*.pt", "*.safetensors"]:
matched_result = sorted(Path(awq_config.ckpt).glob(ext))
if len(matched_result) > 0:
return str(matched_result[-1])
print("Error: AWQ checkpoint not found")
sys.exit(1)