import argparse
import logging
import os
import time

import numpy as np
import torch
from transformers import AutoTokenizer
from auto_gptq import AutoGPTQForCausalLM, BaseQuantizeConfig

logger = logging.getLogger(__name__)


def get_wikitext2(nsamples, seed, seqlen, tokenizer):
    import random

    from datasets import load_dataset

    wikidata = load_dataset('wikitext', 'wikitext-2-raw-v1', split='test')
    wikilist = [' \n' if s == '' else s for s in wikidata['text']]
    text = ''.join(wikilist)

    # Tokenising the full corpus is slow, so cache the token tensor to disk.
    saved_tokens = "./wikitext.trainenc.pth"
    if not os.path.isfile(saved_tokens):
        logger.info("Tokenising wikitext2")
        trainenc = tokenizer(text, return_tensors='pt')
        logger.info("Saving wikitext2 tokens for later re-use")
        torch.save(trainenc, saved_tokens)
    else:
        logger.info("Loading saved tokens for wikitext2")
        trainenc = torch.load(saved_tokens)

    random.seed(seed)
    np.random.seed(seed)
    torch.random.manual_seed(seed)

    # Draw nsamples random windows of seqlen tokens as calibration samples.
    traindataset = []
    for _ in range(nsamples):
        i = random.randint(0, trainenc.input_ids.shape[1] - seqlen - 1)
        j = i + seqlen
        inp = trainenc.input_ids[:, i:j]
        attention_mask = torch.ones_like(inp)
        traindataset.append({'input_ids': inp, 'attention_mask': attention_mask})
    return traindataset


def get_c4(nsamples, seed, seqlen, tokenizer):
    import random

    from datasets import load_dataset

    traindata = load_dataset(
        'allenai/c4',
        'allenai--c4',
        data_files={'train': 'en/c4-train.00000-of-01024.json.gz'},
        split='train',
        use_auth_token=False
    )

    random.seed(seed)

    trainloader = []
    for _ in range(nsamples):
        # Keep sampling documents until we find one longer than seqlen tokens.
        while True:
            i = random.randint(0, len(traindata) - 1)
            trainenc = tokenizer(traindata[i]['text'], return_tensors='pt')
            if trainenc.input_ids.shape[1] > seqlen:
                break
        i = random.randint(0, trainenc.input_ids.shape[1] - seqlen - 1)
        j = i + seqlen
        inp = trainenc.input_ids[:, i:j]
        attention_mask = torch.ones_like(inp)
        trainloader.append({'input_ids': inp, 'attention_mask': attention_mask})
    return trainloader


def quantize(model_dir, output_dir, traindataset, bits, group_size, desc_act, damp,
             batch_size=1, use_triton=False):
    quantize_config = BaseQuantizeConfig(
        bits=bits,
        group_size=group_size,
        desc_act=desc_act,
        damp_percent=damp
    )

    logger.info(f"Loading model from {model_dir}")
    model = AutoGPTQForCausalLM.from_pretrained(
        model_dir, quantize_config=quantize_config, low_cpu_mem_usage=True
    )

    logger.info(f"Starting quantization to {output_dir} with use_triton={use_triton}")
    start_time = time.time()
    model.quantize(traindataset, use_triton=use_triton, batch_size=batch_size)
    logger.info(f"Time to quantize model at {output_dir} with use_triton={use_triton}: "
                f"{time.time() - start_time:.2f}s")

    logger.info(f"Saving quantized model to {output_dir}")
    model.save_quantized(output_dir, use_safetensors=True)
    logger.info("Done.")


if __name__ == "__main__":
    logging.basicConfig(
        format="%(asctime)s %(levelname)s [%(name)s] %(message)s",
        level=logging.INFO,
        datefmt="%Y-%m-%d %H:%M:%S"
    )

    parser = argparse.ArgumentParser(description='quantise')
    parser.add_argument('pretrained_model_dir', type=str, help='Repo name')
    parser.add_argument('output_dir_base', type=str, help='Output base folder')
    parser.add_argument('dataset', type=str, help='Calibration dataset: wikitext or c4')
    parser.add_argument('--use_triton', action="store_true", help='Use Triton for quantization')
    parser.add_argument('--bits', type=int, nargs='+', default=[4], help='Quantize bit(s)')
    parser.add_argument('--group_size', type=int, nargs='+', default=[32, 128, 1024, -1],
                        help='Quantize group size(s)')
    parser.add_argument('--damp', type=float, nargs='+', default=[0.01],
                        help='Quantize damp_percent(s)')
    parser.add_argument('--desc_act', type=int, nargs='+', default=[0, 1],
                        help='Quantize desc_act(s) - 1 = True, 0 = False')
    parser.add_argument('--batch_size', type=int, default=1,
                        help='Quantize batch size for processing dataset samples')
    parser.add_argument('--stop_file', type=str,
                        help='Filename to look for to stop quantization, specific to this instance')
    args = parser.parse_args()

    stop_file = args.stop_file or ""

    tokenizer = AutoTokenizer.from_pretrained(args.pretrained_model_dir, use_fast=True)

    if args.dataset == 'wikitext':
        traindataset = get_wikitext2(128, 0, 2048, tokenizer)
    elif args.dataset == 'c4':
        traindataset = get_c4(128, 0, 2048, tokenizer)
    else:
        logger.error(f"Unsupported dataset: {args.dataset}")
        raise ValueError(f"Unsupported dataset: {args.dataset}")

    abort = False

    # Build the full grid of quantization configurations to run.
    iterations = []
    for bits in args.bits:
        for group_size in args.group_size:
            for desc_act in args.desc_act:
                for damp in args.damp:
                    iterations.append({
                        "bits": bits,
                        "group_size": group_size,
                        "desc_act": desc_act == 1,
                        "damp": damp
                    })

    num_iters = len(iterations)
    logger.info(f"Starting {num_iters} quantizations.")
    count = 1
    for iteration in iterations:
        if not os.path.isfile("/workspace/gptq-ppl-test/STOP") and not os.path.isfile(stop_file) and not abort:
            bits = iteration['bits']
            group_size = iteration['group_size']
            desc_act = iteration['desc_act']
            damp = iteration['damp']

            model_name = f"{bits}bit-{group_size}g-desc_{desc_act}-damp-{damp}"
            output_base = args.output_dir_base + f"/{args.dataset}/"
            output_dir = output_base + model_name

            try:
                os.makedirs(output_dir, exist_ok=False)

                # The log file has the same name as the directory plus .quantize.log, and
                # is placed alongside the model directory, not inside it. This means
                # output_dir can be deleted on error or abort without losing the log file,
                # so the existence of output_dir is a reliable indicator of whether a
                # model has been started.
                log_file = output_dir + ".quantize.log"
                fh = logging.FileHandler(log_file)
                fh.setLevel(logging.INFO)
                formatter = logging.Formatter("%(asctime)s %(levelname)s [%(name)s] %(message)s",
                                              datefmt="%Y-%m-%d %H:%M:%S")
                fh.setFormatter(formatter)
                logger.addHandler(fh)

                logger.info(f"[{count} / {num_iters}] Quantizing: bits = {bits} - "
                            f"group_size = {group_size} - desc_act = {desc_act} - "
                            f"damp_percent = {damp} to {output_dir}")
                try:
                    quantize(args.pretrained_model_dir, output_dir, traindataset, bits,
                             group_size, desc_act, damp, args.batch_size, args.use_triton)
                except KeyboardInterrupt:
                    logger.error(f"Aborted. Will delete {output_dir}")
                    os.rmdir(output_dir)
                    abort = True
                except Exception:
                    logger.error(f"Quantize failed due to exception. Will delete {output_dir}")
                    os.rmdir(output_dir)
                    # Rename the log so a later run doesn't mistake this configuration
                    # for a completed one.
                    move_log = log_file.replace(".log", ".log.failed")
                    if os.path.isfile(move_log):
                        os.remove(move_log)
                    os.rename(log_file, move_log)
                    raise
                finally:
                    logger.removeHandler(fh)
            except FileExistsError:
                # The model directory already exists, so this configuration has either
                # already been done or is in progress on another node.
                logger.warning(f"Already exists: {output_dir}. Skipping.")
            finally:
                count += 1
        else:
            logger.error("Aborting - told to stop!")
            break
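
# Example invocation (a sketch, assuming this file is saved as quantize.py; the
# model repo and output path below are hypothetical placeholders, not values
# taken from this script):
#
#   python quantize.py huggyllama/llama-7b ./quantized wikitext \
#       --bits 4 --group_size 128 --desc_act 1 --damp 0.01
#
# With those arguments a single configuration is run, writing the safetensors
# model to ./quantized/wikitext/4bit-128g-desc_True-damp-0.01/ and the run log
# to ./quantized/wikitext/4bit-128g-desc_True-damp-0.01.quantize.log alongside it.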