from typing import Optional

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer


def perplexity(
    model: AutoModelForCausalLM,
    tok: AutoTokenizer,
    text: str,
    max_input_length: Optional[int] = None,
):
"""
Computes perplexity of a piece of text, measured on a reference model.
Text is truncated to max_input_length tokens.
"""
    inputs = tok(
        [text], return_tensors="pt", max_length=max_input_length, truncation=True
    ).to("cuda")
    with torch.no_grad():
        # Log-probabilities over the vocabulary at every position.
        logits = torch.nn.functional.log_softmax(model(**inputs).logits, dim=2)
    # Position i predicts token i + 1, so the first token is never scored.
    log_probs = torch.gather(logits[:, :-1, :], 2, inputs["input_ids"][:, 1:, None])[0]
    # Perplexity = exp(-1/N * log P(x_1, ..., x_n)); note the sum has N - 1
    # terms but is normalized by the full token count N.
    return torch.exp(-1 / inputs["input_ids"].size(1) * log_probs.sum()).item()
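

# A minimal usage sketch (assumptions: the "gpt2" checkpoint as an arbitrary
# reference model, and an available CUDA device, since the function above
# moves its inputs to "cuda"). A fluent sentence should score a noticeably
# lower perplexity than the same words shuffled.
if __name__ == "__main__":
    model = AutoModelForCausalLM.from_pretrained("gpt2").to("cuda")
    tok = AutoTokenizer.from_pretrained("gpt2")
    print(perplexity(model, tok, "The quick brown fox jumps over the lazy dog."))
    print(perplexity(model, tok, "dog lazy the over jumps fox brown quick The."))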