# kousw's picture
# Upload 21 files
# 29964ce verified
import argparse
import math
import random
import torch
from tqdm import tqdm
from eval_utils import get_test_dataset
from modeling_bitnet import BitnetForCausalLM
from tokenization_bitnet import BitnetTokenizer
# This script only scores a model (no backward pass appears anywhere in it),
# so autograd is turned off globally as soon as the module is loaded.
torch.set_grad_enabled(False)

# Command-line interface. Each option is declared as (flag, default, type)
# and registered in order.
parser = argparse.ArgumentParser()
for _flag, _default, _kind in (
    ("--seed", 0, int),
    ("--hf_path", "./", str),  # 1bitLLM/bitnet_b1_58-3B
    ("--seqlen", 2048, int),
    ("--max_dataset_size", 100000, int),
):
    parser.add_argument(_flag, default=_default, type=_kind)
# Do not leave the loop temporaries behind as module attributes.
del _flag, _default, _kind
def calulate_loss(model, input, loss_fct):
    """Return the next-token prediction loss of ``model`` on ``input``.

    Args:
        model: Callable invoked as ``model(input, use_cache=False, ...)``;
            element ``[0]`` of its result is used as the logits. The slicing
            below assumes logits shaped ``(batch, seq_len, vocab)``.
        input: Integer token ids, assumed shaped ``(batch, seq_len)``. The
            name shadows the builtin but is kept so keyword callers still
            work.
        loss_fct: Callable taking ``(logits_2d, labels_1d)`` and returning
            the loss, e.g. ``torch.nn.CrossEntropyLoss(reduction="sum")``.

    Returns:
        The value produced by ``loss_fct`` for the shifted logits/labels.
    """
    logits = model(
        input, use_cache=False, output_hidden_states=False, output_attentions=False
    )[0]
    # Position t predicts token t + 1: drop the last logit and the first label.
    # Upcast to float32 so a half-precision model does not round (or overflow)
    # the loss, which matters most with a summed reduction.
    shift_logits = logits[:, :-1, :].float()
    # Keep labels on the logits' device (a no-op when they already match).
    shift_labels = input[:, 1:].to(shift_logits.device)
    # ``reshape`` rather than ``view``: the label slice is non-contiguous as
    # soon as batch > 1, and ``view`` raises on it.
    return loss_fct(
        shift_logits.reshape(-1, shift_logits.size(-1)),
        shift_labels.reshape(-1),
    )
def main(args):
    """Evaluate and print the perplexity of a checkpoint on each test dataset.

    Loads the model and tokenizer from ``args.hf_path``, scores at most
    ``args.max_dataset_size`` samples of every dataset in the hard-coded
    list, and prints the per-dataset perplexity, then the list of
    perplexities and their mean. Inputs are moved with ``.cuda()``, so a
    CUDA device is required.

    Args:
        args: Parsed CLI namespace providing ``hf_path``, ``seqlen`` and
            ``max_dataset_size``.

    Raises:
        ValueError: If no target tokens were scored for a dataset (empty
            dataset, non-positive ``max_dataset_size``, or only
            single-token samples), since perplexity is then undefined.
    """
    datasets = ["wikitext2"]  # ['c4', 'wikitext2']
    # NOTE(review): ``.half()`` looks redundant next to
    # ``torch_dtype=torch.float16``; kept exactly as in the original call.
    model = BitnetForCausalLM.from_pretrained(
        args.hf_path,
        device_map="auto",
        low_cpu_mem_usage=True,
        use_flash_attention_2=True,
        torch_dtype=torch.float16,
    ).half()
    tokenizer = BitnetTokenizer.from_pretrained(args.hf_path, use_fast=False)
    # Summed (not averaged) loss so that totals can be accumulated across
    # samples and divided by the overall token count at the end.
    loss_fct = torch.nn.CrossEntropyLoss(reduction="sum").cuda()
    ln2 = math.log(2)  # converts nats to bits

    ppl = []
    for dataset in datasets:
        testdata = get_test_dataset(dataset, tokenizer, seqlen=args.seqlen)
        acc_loss, count = 0.0, 0
        num_samples = min(len(testdata), args.max_dataset_size)
        progress = tqdm(range(num_samples))
        for ii in progress:
            # Build the ids directly as int64 instead of going through the
            # float32 ``torch.Tensor`` constructor and casting afterwards.
            token_ids = (
                torch.as_tensor(testdata[ii], dtype=torch.long).cuda().reshape(1, -1)
            )
            loss = calulate_loss(model, token_ids, loss_fct)
            # Every token except the first is a prediction target.
            count += token_ids.size(-1) - 1
            acc_loss += loss.item()
            if count:
                progress.set_description(f"avg_loss = {acc_loss/ count / ln2}")

        if count == 0:
            raise ValueError(
                f"No tokens were evaluated for dataset {dataset!r}; "
                "cannot compute perplexity."
            )
        # Average loss in bits per token; perplexity is 2 ** bits.
        avg_loss = acc_loss / count / ln2
        ppl.append(2**avg_loss)
        print("{} PPL: {}".format(dataset, ppl[-1]))

    print(ppl)
    print("Avg PPL:", sum(ppl) / len(ppl))
if __name__ == "__main__":
    # Script entry point: read the command line, make sure autograd is off,
    # seed both random number generators from --seed, then run the evaluation.
    args = parser.parse_args()
    torch.set_grad_enabled(False)
    torch.random.manual_seed(args.seed)
    random.seed(args.seed)
    main(args)