|
import math |
|
import argparse |
|
import torch |
|
import random |
|
|
|
from eval_utils import get_test_dataset |
|
from .modeling_bitnet import BitnetForCausalLM |
|
from .tokenization_bitnet import BitnetTokenizer |
|
|
|
from tqdm import tqdm |
|
# Evaluation-only script: turn autograd off globally so no graphs are built.
torch.set_grad_enabled(False)

parser = argparse.ArgumentParser(
    description="Perplexity evaluation for BitNet b1.58 causal language models."
)
parser.add_argument('--seed', default=0, type=int,
                    help='random seed for the python and torch RNGs')
parser.add_argument('--hf_path', default='1bitLLM/bitnet_b1_58-3B', type=str,
                    help='HuggingFace model id or local checkpoint path')
parser.add_argument('--seqlen', default=2048, type=int,
                    help='tokens per evaluation sequence chunk')
|
|
|
|
|
def calulate_loss(model, input, loss_fct):
    """Return the next-token cross-entropy loss for one tokenized sequence.

    The model is run without caching or auxiliary outputs; its logits at
    every position except the last are scored against the token that
    follows that position.

    Args:
        model: causal LM whose call returns logits as the first element.
        input: LongTensor of token ids, shape ``(batch, seqlen)``.
        loss_fct: loss callable, e.g. ``CrossEntropyLoss(reduction="sum")``.

    Returns:
        Scalar tensor — the loss over the ``seqlen - 1`` shifted targets.
    """
    logits = model(
        input,
        use_cache=False,
        output_hidden_states=False,
        output_attentions=False,
    )[0]
    # Position t predicts token t + 1: drop the final logit row and the
    # first label so predictions and targets line up.
    preds = logits[:, :-1, :].contiguous()
    targets = input[:, 1:]
    return loss_fct(preds.view(-1, preds.size(-1)), targets.view(-1))
|
|
|
|
|
def main(args):
    """Evaluate base-2 perplexity of a BitNet model on c4 and wikitext2.

    Loads the checkpoint in fp16 across available devices, accumulates the
    summed next-token cross-entropy over fixed-length sequences, and prints
    per-dataset perplexity plus the average over datasets.

    Args:
        args: parsed CLI namespace with ``hf_path`` and ``seqlen``.
    """
    datasets = ['c4', 'wikitext2']
    # torch_dtype=torch.float16 already loads the weights in fp16, so the
    # extra ``.half()`` the original chained on was redundant and is dropped.
    model = BitnetForCausalLM.from_pretrained(
        args.hf_path,
        device_map='auto',
        low_cpu_mem_usage=True,
        use_flash_attention_2=True,
        torch_dtype=torch.float16,
    )
    tokenizer = BitnetTokenizer.from_pretrained(args.hf_path, use_fast=False)
    # reduction="sum" so we can normalize by the exact target count below.
    loss_fct = torch.nn.CrossEntropyLoss(reduction="sum").cuda()

    ppl = []
    for dataset in datasets:
        testdata = get_test_dataset(dataset, tokenizer, seqlen=args.seqlen)
        acc_loss, count = 0.0, 0
        progress = tqdm(range(len(testdata)))
        for ii in progress:
            input = torch.Tensor(testdata[ii]).long().cuda().view(1, -1)
            loss = calulate_loss(model, input, loss_fct)
            # Each sequence yields seqlen - 1 next-token predictions.
            count += (input.size(-1) - 1)
            acc_loss += loss.item()
            progress.set_description(f"avg_loss = {acc_loss / count / math.log(2)}")

        # Summed nats -> average bits per token -> base-2 perplexity.
        avg_loss = acc_loss / count / math.log(2)
        ppl.append(2 ** avg_loss)
        print("{} PPL: {}".format(dataset, ppl[-1]))

    print(ppl)
    print("Avg PPL:", sum(ppl) / len(ppl))
|
|
|
|
|
if __name__ == '__main__':
    # Autograd is already disabled at module import time, so the duplicate
    # torch.set_grad_enabled(False) that used to live here was removed.
    args = parser.parse_args()
    # Seed both RNG sources so dataset sampling and any torch randomness
    # are reproducible for a given --seed.
    random.seed(args.seed)
    torch.random.manual_seed(args.seed)
    main(args)