Spaces:
Runtime error
Runtime error
from tqdm import tqdm | |
import torch | |
from torch.nn import CrossEntropyLoss | |
def evaluate_model(model, tokenizer, dl):
    """Return the mean next-token cross-entropy loss of every sample in ``dl``.

    Each batch yielded by ``dl`` is tokenized (padded to the longest sample,
    truncated to 150 tokens), fed through ``model``, and scored with a
    per-token cross-entropy between the logits at position ``t`` and the
    token at position ``t + 1``. Padding positions are excluded from both
    the sum and the token count.

    Args:
        model: A module whose forward call accepts
            ``(input_ids, attention_mask=...)`` and returns an object with a
            ``logits`` attribute indexed as (batch, sequence, vocabulary).
            It is moved to CUDA when available, otherwise to CPU.
        tokenizer: Callable invoked as ``tokenizer(batch, padding=True,
            return_tensors='pt', truncation=True, max_length=150)`` that
            returns a mapping with ``'input_ids'`` and ``'attention_mask'``.
        dl: Iterable of batches in whatever form ``tokenizer`` accepts.

    Returns:
        A 1-D tensor (on the evaluation device) with one loss per sample, in
        iteration order. Empty when ``dl`` yields no batches. A sample with
        fewer than two non-padding tokens has no prediction target and
        yields ``nan`` (0 / 0).
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = model.to(device)

    # Evaluate deterministically (dropout off etc.), but hand the model back
    # in whatever mode the caller had it in.
    was_training = model.training
    model.eval()

    # reduction='none' keeps one loss per token; the default
    # ignore_index=-100 zeroes out the padded positions.
    loss_fct = CrossEntropyLoss(reduction='none')
    losses = []
    try:
        with torch.no_grad():
            for texts in dl:
                encoded = tokenizer(
                    texts,
                    padding=True,
                    return_tensors='pt',
                    truncation=True,
                    max_length=150,
                )
                input_ids = encoded['input_ids'].to(device)
                attention_mask = encoded['attention_mask'].to(device)

                # Labels are the input ids with padding replaced by -100.
                labels = input_ids.masked_fill(attention_mask == 0, -100)

                # The loss is computed below, so the model is not asked to
                # compute (and discard) its own internal loss.
                outputs = model(input_ids, attention_mask=attention_mask)

                # Position t predicts token t + 1.
                shift_logits = outputs.logits[..., :-1, :].contiguous()
                shift_labels = labels[..., 1:].contiguous()

                # CrossEntropyLoss expects the class dim second: (N, C, L).
                token_loss = loss_fct(shift_logits.transpose(1, 2), shift_labels)
                num_tokens = torch.sum(shift_labels != -100, dim=1)
                losses.append(torch.sum(token_loss, dim=1) / num_tokens)
    finally:
        model.train(was_training)

    if not losses:
        # torch.cat raises on an empty list; an empty loader means no losses.
        return torch.empty(0, device=device)
    return torch.cat(losses)