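# Source: orca_mini_3b_gptq_badtest / quantize_alpaca.py
# Author: SinanAkkoyun
# Commit: 56224bb ("test")
# File size: 6.87 kB
# (Web file-viewer navigation text -- "raw", "history blame contribute delete",
# "No virus" -- was captured together with the script and is not part of the code.)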
import json
import random
import time
from argparse import ArgumentParser
import torch
from auto_gptq import AutoGPTQForCausalLM, BaseQuantizeConfig
from datasets import Dataset
from transformers import AutoTokenizer, TextGenerationPipeline
def load_data(data_path, tokenizer, n_samples):
    """Load an Alpaca-style JSON file and tokenize a random subset of it.

    Args:
        data_path: Path to a UTF-8 JSON file containing a list of records,
            each with "instruction", "input" and "output" fields.
        tokenizer: Callable tokenizer. ``tokenizer(text)`` must return a
            mapping with "input_ids" and "attention_mask", and the object
            must expose ``model_max_length``.
        n_samples: Maximum number of records to draw at random (capped at
            the number of records in the file).

    Returns:
        A list of dicts with keys "input_ids" and "attention_mask" (both
        ``torch.LongTensor``), "prompt" and "output". Records whose prompt
        alone reaches ``tokenizer.model_max_length`` tokens are dropped.
    """
    with open(data_path, "r", encoding="utf-8") as f:
        raw_data = json.load(f)

    raw_data = random.sample(raw_data, k=min(n_samples, len(raw_data)))

    def dummy_gen():
        # Hands the already-sampled records to Dataset.from_generator.
        return raw_data

    def tokenize(examples):
        """Build prompts for one batch and tokenize prompt + golden output."""
        max_len = tokenizer.model_max_length

        prompts = []
        kept_outputs = []
        input_ids = []
        attention_mask = []
        for instruction, extra_input, output in zip(
            examples["instruction"], examples["input"], examples["output"]
        ):
            if extra_input:
                prompt = f"### User:\n{instruction}\n\n### Input:\n{extra_input}\n\nResponse:\n"
            else:
                prompt = f"### User:\n{instruction}\n\nResponse:\n"

            # Drop records whose prompt already fills the whole context window.
            if len(tokenizer(prompt)["input_ids"]) >= max_len:
                continue

            tokenized_data = tokenizer(prompt + output)

            input_ids.append(tokenized_data["input_ids"][:max_len])
            attention_mask.append(tokenized_data["attention_mask"][:max_len])
            prompts.append(prompt)
            kept_outputs.append(output)

        # "output" is returned as well so that it overwrites the original
        # column: when a record is skipped above, every column of the batch
        # must shrink together or the mapped batch has mismatched lengths.
        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "prompt": prompts,
            "output": kept_outputs,
        }

    dataset = Dataset.from_generator(dummy_gen)
    dataset = dataset.map(
        tokenize,
        batched=True,
        batch_size=len(dataset),
        num_proc=1,
        keep_in_memory=True,
        load_from_cache_file=False,
        remove_columns=["instruction", "input"]
    )

    dataset = dataset.to_list()

    # Convert the token lists to tensors for the quantization step.
    for sample in dataset:
        sample["input_ids"] = torch.LongTensor(sample["input_ids"])
        sample["attention_mask"] = torch.LongTensor(sample["attention_mask"])

    return dataset
def _parse_args():
    """Define the command-line options of the quantization script and parse them."""
    parser = ArgumentParser()
    # Required: every later step loads from this directory, so report a
    # missing value as a usage error instead of failing inside the loaders.
    parser.add_argument("--pretrained_model_dir", type=str, required=True)
    parser.add_argument("--quantized_model_dir", type=str, default=None)
    parser.add_argument("--bits", type=int, default=4, choices=[2, 3, 4, 8])
    parser.add_argument("--group_size", type=int, default=128, help="group size, -1 means no grouping or full rank")
    parser.add_argument("--desc_act", action="store_true", help="whether to quantize with desc_act")
    parser.add_argument("--num_samples", type=int, default=128, help="how many samples will be used to quantize model")
    parser.add_argument("--save_and_reload", action="store_true", help="whether save quantized model to disk and reload back")
    parser.add_argument("--fast_tokenizer", action="store_true", help="whether use fast tokenizer")
    parser.add_argument("--use_triton", action="store_true", help="whether use triton to speedup at inference")
    parser.add_argument("--per_gpu_max_memory", type=int, default=None, help="max memory used to load model per gpu")
    parser.add_argument("--cpu_max_memory", type=int, default=None, help="max memory used to offload model to cpu")
    parser.add_argument("--quant_batch_size", type=int, default=1, help="examples batch size for quantization")
    parser.add_argument("--trust_remote_code", action="store_true", help="whether to trust remote code when loading model")
    parser.add_argument(
        "--data_path",
        type=str,
        default="dataset/alpaca_data_cleaned.json",
        help="path to the alpaca-style json file used as quantization examples",
    )
    return parser.parse_args()


def _build_max_memory(per_gpu_max_memory, cpu_max_memory):
    """Return the per-device memory budget mapping, or None when no GPU budget applies."""
    max_memory = {}
    if per_gpu_max_memory is not None and per_gpu_max_memory > 0 and torch.cuda.is_available():
        max_memory.update(
            {i: f"{per_gpu_max_memory}GIB" for i in range(torch.cuda.device_count())}
        )
    # The CPU budget is only recorded alongside a GPU budget.
    if cpu_max_memory is not None and cpu_max_memory > 0 and max_memory:
        max_memory["cpu"] = f"{cpu_max_memory}GIB"
    return max_memory or None


def main():
    """Quantize a pretrained model with GPTQ on Alpaca-style examples and sample from it.

    Steps: parse the command line, load tokenizer and model, build the
    quantization examples, quantize (timed), optionally save and reload the
    quantized model, then generate from up to four random examples and print
    the prompt, the golden output, the generated text and the throughput.

    Raises:
        ValueError: if no usable example is left after loading the data.
    """
    args = _parse_args()
    max_memory = _build_max_memory(args.per_gpu_max_memory, args.cpu_max_memory)

    tokenizer = AutoTokenizer.from_pretrained(
        args.pretrained_model_dir,
        use_fast=args.fast_tokenizer,
        trust_remote_code=args.trust_remote_code
    )
    model = AutoGPTQForCausalLM.from_pretrained(
        args.pretrained_model_dir,
        quantize_config=BaseQuantizeConfig(bits=args.bits, group_size=args.group_size, desc_act=args.desc_act),
        max_memory=max_memory,
        trust_remote_code=args.trust_remote_code
    )

    examples = load_data(args.data_path, tokenizer, args.num_samples)
    if not examples:
        # Quantizing without any calibration example cannot give a meaningful result.
        raise ValueError(f"no usable quantization examples were loaded from {args.data_path!r}")
    examples_for_quant = [
        {"input_ids": example["input_ids"], "attention_mask": example["attention_mask"]}
        for example in examples
    ]

    start = time.time()
    model.quantize(
        examples_for_quant,
        batch_size=args.quant_batch_size,
        use_triton=args.use_triton,
        autotune_warmup_after_quantized=args.use_triton,
    )
    end = time.time()
    print(f"quantization took: {end - start: .4f}s")

    # Without an explicit output directory the pretrained directory is reused.
    if not args.quantized_model_dir:
        args.quantized_model_dir = args.pretrained_model_dir

    if args.save_and_reload:
        model.save_quantized(args.quantized_model_dir, use_safetensors=True)
        # Release the in-memory model before loading the saved copy.
        del model
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

        model = AutoGPTQForCausalLM.from_quantized(
            args.quantized_model_dir,
            device="cuda:0",
            use_triton=args.use_triton,
            max_memory=max_memory,
            inject_fused_mlp=True,
            inject_fused_attention=True,
            trust_remote_code=args.trust_remote_code
        )

    pipeline_init_kwargs = {"model": model, "tokenizer": tokenizer}
    if not max_memory:
        pipeline_init_kwargs["device"] = "cuda:0"
    pipeline = TextGenerationPipeline(**pipeline_init_kwargs)

    for example in random.sample(examples, k=min(4, len(examples))):
        print(f"prompt: {example['prompt']}")
        print("-" * 42)
        print(f"golden: {example['output']}")
        print("-" * 42)
        start = time.time()
        # NOTE: example["input_ids"] holds the tokens of prompt + golden output,
        # so the length budget below is that length plus 128.
        generated_text = pipeline(
            example['prompt'],
            return_full_text=False,
            num_beams=1,
            max_length=len(example["input_ids"]) + 128  # use this instead of max_new_token to disable UserWarning when integrate with logging
        )[0]['generated_text']
        end = time.time()
        print(f"quant: {generated_text}")
        num_new_tokens = len(tokenizer(generated_text)["input_ids"])
        print(f"generate {num_new_tokens} tokens using {end-start: .4f}s, {num_new_tokens / (end - start)} tokens/s.")
        print("=" * 42)
if __name__ == "__main__":
    import logging

    # Set up root logging (INFO level, timestamped records) before running the script.
    log_settings = {
        "format": "%(asctime)s %(levelname)s [%(name)s] %(message)s",
        "level": logging.INFO,
        "datefmt": "%Y-%m-%d %H:%M:%S",
    }
    logging.basicConfig(**log_settings)

    main()