"""Evaluate the perplexity of a (quantized) causal LM on wikitext2 and c4.

Loads a model from --hf_path, streams fixed-length token windows through it,
and reports exp(average next-token cross-entropy) per dataset.
"""
import argparse
import random

import glog
import torch
from tqdm import tqdm

from lib.utils import gptq_data_utils
from lib.utils.unsafe_import import model_from_hf_path

# Inference only: gradients are never needed for perplexity evaluation.
torch.set_grad_enabled(False)

parser = argparse.ArgumentParser()
parser.add_argument('--seed', default=0, type=int)
parser.add_argument('--hf_path', default='hfized/quantized_hada_70b', type=str)
parser.add_argument('--seqlen', default=4096, type=int)
parser.add_argument('--no_use_cuda_graph', action='store_true')
parser.add_argument('--no_use_flash_attn', action='store_true')


def main(args):
    datasets = ['wikitext2', 'c4']
    model, model_str = model_from_hf_path(
        args.hf_path,
        use_cuda_graph=not args.no_use_cuda_graph,
        use_flash_attn=not args.no_use_flash_attn)
    for dataset in datasets:
        input_tok = gptq_data_utils.get_test_tokens(dataset,
                                                    seed=args.seed,
                                                    seqlen=args.seqlen,
                                                    model=model_str)
        # Split the test stream into non-overlapping windows of seqlen tokens,
        # discarding any trailing remainder.
        nsamples = input_tok.numel() // args.seqlen
        input_tok = input_tok[0, :(args.seqlen * nsamples)].view(
            nsamples, args.seqlen)

        if not args.no_use_cuda_graph:
            # Reset captured CUDA graph state before evaluating each dataset.
            model.reset()

        loss_fct = torch.nn.CrossEntropyLoss().cuda()
        acc_loss = 0.0
        progress = tqdm(range(nsamples))
        for ii in progress:
            input_ids = input_tok[ii, :].cuda().view(1, -1)
            output = model(input_ids,
                           use_cache=False,
                           output_hidden_states=False,
                           output_attentions=False)[0]
            # Standard causal-LM shift: the logits at position t predict
            # the token at position t+1.
            shift_logits = output[:, :-1, :].contiguous()
            shift_labels = input_ids[:, 1:]
            loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)),
                            shift_labels.view(-1))
            acc_loss += loss.item()
            progress.set_description(f"avg_loss = {acc_loss/(ii+1)}")

        # Perplexity is exp of the mean per-window cross-entropy.
        avg_loss = acc_loss / nsamples
        ppl = torch.exp(torch.tensor(avg_loss)).item()
        glog.info(f'{dataset} perplexity: {ppl}')


if __name__ == '__main__':
    args = parser.parse_args()
    random.seed(args.seed)
    torch.random.manual_seed(args.seed)
    main(args)