|
from transformers import AutoTokenizer, TextGenerationPipeline |
|
from auto_gptq import AutoGPTQForCausalLM, BaseQuantizeConfig |
|
import logging |
|
|
|
# Source checkpoint: read by both the tokenizer and the model loader below.
pretrained_model_dir: str = "models/WizardLM-7B-Uncensored"
# Destination directory handed to save_quantized() at the end of the script.
quantized_model_dir: str = "./"

# All tunables in one place:
#   "quantize_config" -> unpacked into BaseQuantizeConfig(**...)
#   "use_safetensors" -> forwarded to model.save_quantized(...)
config: dict = {
    "quantize_config": {
        "bits": 8,  # quantization bit-width
        "desc_act": True,  # GPTQ activation-order option
        "true_sequential": True,
        # NOTE(review): presumably the stem of the saved weight file name;
        # confirm against the auto_gptq version in use.
        "model_file_base_name": 'WizardLM-7B-Uncensored',
    },
    "use_safetensors": True,
}

# Configure the root logger: INFO level, one timestamped line per record.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s %(levelname)s [%(name)s] %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
)

tokenizer = AutoTokenizer.from_pretrained(pretrained_model_dir, use_fast=True)

# Calibration data for model.quantize(): one tokenized entry per text.
# NOTE(review): a single short sentence is a very small calibration set;
# more (and more representative) text would likely quantize better.
examples: list[dict[str, list[int]]] = [
    tokenizer(sample) for sample in ("It was a cold night",)
]

# Load the source model with the quantization settings attached, run
# quantization over the calibration examples, then write the result out.
model = AutoGPTQForCausalLM.from_pretrained(
    pretrained_model_dir,
    BaseQuantizeConfig(**config["quantize_config"]),
)
model.quantize(examples)
model.save_quantized(
    quantized_model_dir,
    use_safetensors=config["use_safetensors"],
)
|
|