|
import os |
|
import math |
|
import json |
|
import argparse |
|
import torch |
|
import datasets |
|
from lib.utils import gptq_data_utils |
|
from lib.utils.unsafe_import import model_from_hf_path |
|
import random |
|
import glog |
|
|
|
from tqdm import tqdm |
|
|
|
# Inference-only script: disable autograd globally to save memory and compute.
torch.set_grad_enabled(False)

# Command-line options for the perplexity evaluation.
parser = argparse.ArgumentParser()

# RNG seed used both for data sampling and torch initialization.
parser.add_argument('--seed', default=0, type=int)

# Path (local dir or HF hub id) of the quantized model to evaluate.
parser.add_argument('--hf_path', default='hfized/quantized_hada_70b', type=str)

# Context length of each evaluation sample.
parser.add_argument('--seqlen', default=4096, type=int)

# Disable CUDA-graph capture (enabled by default).
parser.add_argument('--no_use_cuda_graph', action='store_true')

# Disable flash attention (enabled by default).
parser.add_argument('--no_use_flash_attn', action='store_true')
|
|
|
|
|
def main(args):
    """Evaluate perplexity of a quantized HF model on wikitext2 and c4.

    For each dataset, tokenizes the test split into ``args.seqlen``-long
    rows, runs the model sample by sample, and logs
    ``exp(mean cross-entropy)`` as perplexity.

    Args:
        args: argparse.Namespace with ``seed``, ``hf_path``, ``seqlen``,
            ``no_use_cuda_graph`` and ``no_use_flash_attn``.
    """
    # Renamed from `datasets` — the original name shadowed the imported
    # `datasets` module (L12).
    eval_datasets = ['wikitext2', 'c4']
    model, model_str = model_from_hf_path(
        args.hf_path,
        use_cuda_graph=not args.no_use_cuda_graph,
        use_flash_attn=not args.no_use_flash_attn)

    # Loss is identical for every dataset; build it once (was rebuilt per
    # dataset in the original loop).
    loss_fct = torch.nn.CrossEntropyLoss().cuda()

    for dataset in eval_datasets:
        input_tok = gptq_data_utils.get_test_tokens(dataset,
                                                    seed=args.seed,
                                                    seqlen=args.seqlen,
                                                    model=model_str)
        # Truncate to a whole number of seqlen-sized rows, one per sample.
        nsamples = input_tok.numel() // args.seqlen
        input_tok = input_tok[0, :(args.seqlen * nsamples)].view(
            nsamples, args.seqlen)

        if not args.no_use_cuda_graph:
            # Reset captured CUDA-graph state between datasets.
            model.reset()

        acc_loss = 0.0
        progress = tqdm(range(nsamples))
        for ii in progress:
            # `input_ids` avoids shadowing the `input` builtin.
            input_ids = input_tok[ii, :].cuda().view(1, -1)
            output = model(input_ids,
                           use_cache=False,
                           output_hidden_states=False,
                           output_attentions=False)[0]
            # Standard next-token LM shift: logits at position t predict
            # the token at position t+1.
            shift_logits = output[:, :-1, :].contiguous()
            shift_labels = input_ids[:, 1:]
            loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)),
                            shift_labels.view(-1))
            acc_loss += loss.item()
            progress.set_description(f"avg_loss = {acc_loss/(ii+1)}")

        avg_loss = acc_loss / nsamples
        # Perplexity is exp of the mean per-token cross-entropy.
        ppl = torch.exp(torch.tensor(avg_loss)).item()
        glog.info(f'{dataset} perplexity: {ppl}')
|
|
|
|
|
if __name__ == '__main__':
    # Redundant with the module-level torch.set_grad_enabled(False) at the
    # top of the file, but harmless; kept as-is.
    torch.set_grad_enabled(False)

    args = parser.parse_args()

    # Seed both Python's and torch's RNGs so token sampling and any model
    # nondeterminism are reproducible for a given --seed.
    random.seed(args.seed)
    torch.random.manual_seed(args.seed)

    main(args)
|
|