|
from dataclasses import dataclass, field |
|
import os |
|
from os.path import isdir, isfile |
|
from pathlib import Path |
|
import sys |
|
|
|
from transformers import AutoTokenizer |
|
|
|
|
|
@dataclass |
|
class GptqConfig: |
|
ckpt: str = field( |
|
default=None, |
|
metadata={ |
|
"help": "Load quantized model. The path to the local GPTQ checkpoint." |
|
}, |
|
) |
|
wbits: int = field(default=16, metadata={"help": "#bits to use for quantization"}) |
|
groupsize: int = field( |
|
default=-1, |
|
metadata={"help": "Groupsize to use for quantization; default uses full row."}, |
|
) |
|
act_order: bool = field( |
|
default=True, |
|
metadata={"help": "Whether to apply the activation order GPTQ heuristic"}, |
|
) |
|
|
|
|
|
def load_gptq_quantized(model_name, gptq_config: GptqConfig): |
|
print("Loading GPTQ quantized model...") |
|
|
|
try: |
|
script_path = os.path.dirname(os.path.dirname(os.path.realpath(__file__))) |
|
module_path = os.path.join(script_path, "../repositories/GPTQ-for-LLaMa") |
|
|
|
sys.path.insert(0, module_path) |
|
from llama import load_quant |
|
except ImportError as e: |
|
print(f"Error: Failed to load GPTQ-for-LLaMa. {e}") |
|
print("See https://github.com/lm-sys/FastChat/blob/main/docs/gptq.md") |
|
sys.exit(-1) |
|
|
|
tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False) |
|
|
|
if gptq_config.act_order: |
|
model = load_quant( |
|
model_name, |
|
find_gptq_ckpt(gptq_config), |
|
gptq_config.wbits, |
|
gptq_config.groupsize, |
|
act_order=gptq_config.act_order, |
|
) |
|
else: |
|
|
|
model = load_quant( |
|
model_name, |
|
find_gptq_ckpt(gptq_config), |
|
gptq_config.wbits, |
|
gptq_config.groupsize, |
|
) |
|
|
|
return model, tokenizer |
|
|
|
|
|
def find_gptq_ckpt(gptq_config: GptqConfig): |
|
if Path(gptq_config.ckpt).is_file(): |
|
return gptq_config.ckpt |
|
|
|
for ext in ["*.pt", "*.safetensors"]: |
|
matched_result = sorted(Path(gptq_config.ckpt).glob(ext)) |
|
if len(matched_result) > 0: |
|
return str(matched_result[-1]) |
|
|
|
print("Error: gptq checkpoint not found") |
|
sys.exit(1) |
|
|