import torch
import numpy as np
from .utils import *


def is_acc_error(model, tokens):
    # True where the model's argmax prediction for a batch element is wrong
    labels = tokens["labels"]
    logits = model(**tokens).logits
    probs = torch.softmax(logits, -1).squeeze()
    argmaxs = torch.argmax(probs, dim=-1).squeeze()
    return labels != argmaxs


def Accuracy(model, tokens):
    # Top-1 accuracy over the batch; labels are stripped before the forward pass
    labels = tokens["labels"]
    new_tokens = {k: v for k, v in tokens.items() if k != "labels"}
    logits = model(**new_tokens).logits
    probs = torch.softmax(logits, -1).squeeze()
    argmaxs = torch.argmax(probs, dim=-1).squeeze()
    return (labels == argmaxs).float().mean()
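

# A minimal usage sketch for Accuracy, assuming a Hugging Face
# sequence-classification model; the checkpoint name and the batch built
# here are illustrative, not part of this module.
def _accuracy_usage_example():
    from transformers import AutoModelForSequenceClassification, AutoTokenizer
    model = AutoModelForSequenceClassification.from_pretrained("distilbert-base-uncased")
    tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")
    tokens = tokenizer(["an example sentence"], return_tensors="pt")
    tokens["labels"] = torch.tensor([1])
    return Accuracy(model, tokens)  # scalar tensor in [0, 1]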


def is_qa_error(model, tokens):
    # Exact-match check: flag an error if the generated ids differ from the gold ids
    preds = model.generate(tokens["input_ids"], max_length=20).squeeze()  # Run model to get its predictions
    labels = tokens["labels"]  # Note: -100 padding positions are not masked out here
    if (len(preds) != len(labels)) or ((preds == labels).sum() != len(preds)):
        return True
    else:
        return False
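

# A minimal sketch of calling is_qa_error with a seq2seq model; the
# checkpoint and prompt are illustrative, and real label construction
# depends on the dataset's preprocessing.
def _qa_error_usage_example():
    from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
    model = AutoModelForSeq2SeqLM.from_pretrained("t5-small")
    tokenizer = AutoTokenizer.from_pretrained("t5-small")
    tokens = tokenizer("question: who wrote Hamlet?", return_tensors="pt")
    tokens["labels"] = tokenizer("Shakespeare", return_tensors="pt")["input_ids"].squeeze()
    return is_qa_error(model, tokens)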


def PPL(model, batch):
    # Perplexity over (at most) the first 1024 tokens: exp of the mean
    # token-level negative log-likelihood computed by the model
    input_ids = batch["input_ids"][:, :1024]
    if "labels" not in batch:
        target_ids = batch["input_ids"][:, :1024].clone()
    else:
        target_ids = batch["labels"][:, :1024].clone()
    with torch.no_grad():
        outputs = model(input_ids=input_ids, labels=target_ids)
    nll = outputs.loss
    ppl = torch.exp(nll)
    return ppl
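

# A minimal sketch of PPL on a causal LM (illustrative checkpoint). Because
# the Hugging Face loss is already the mean negative log-likelihood per
# token, exp(loss) is the perplexity of the (up to) 1024-token window.
def _ppl_usage_example():
    from transformers import AutoModelForCausalLM, AutoTokenizer
    model = AutoModelForCausalLM.from_pretrained("gpt2")
    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    batch = tokenizer("The quick brown fox jumps over the lazy dog.", return_tensors="pt")
    return PPL(model, batch)  # scalar tensor, lower is better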


def F1(model, batch):
    # Token-overlap F1 between the generated ids and the gold label ids
    try:
        preds = model.generate(batch["input_ids"], max_length=20).squeeze()
        if len(preds) > 1:
            # Drop padding (assumes the model object exposes its tokenizer)
            preds = preds[preds != model.tokenizer.pad_token_id]
        gold_toks = batch["labels"][batch["labels"] != -100].cpu().squeeze()  # -100 is the ignored-label index
        num_same = len(np.intersect1d(preds.cpu().squeeze(), gold_toks))
        if (num_same == 0) or (len(preds.squeeze()) == 0):
            return 0
        precision = num_same / len(preds.squeeze())
        recall = num_same / len(gold_toks)
        f1 = (2 * precision * recall) / (precision + recall)
        return f1
    except Exception:
        # Every once in a while, the model returns only the stop token and the
        # tensor comparisons above fail
        return 0
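

# A worked example of the token-overlap F1 above, with illustrative counts:
# 8 predicted tokens, 10 gold tokens, and 6 shared token ids.
def _f1_worked_example():
    precision = 6 / 8                                     # 0.75
    recall = 6 / 10                                       # 0.60
    return 2 * precision * recall / (precision + recall)  # ~0.667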