from typing import Optional

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer


def perplexity(
    model: AutoModelForCausalLM,
    tok: AutoTokenizer,
    text: str,
    max_input_length: Optional[int] = None,
):
    """
    Computes the perplexity of a piece of text under a reference model.
    The text is truncated to max_input_length tokens.
    """

    # Move the tokenized batch to the model's device rather than assuming CUDA.
    inputs = tok(
        [text], return_tensors="pt", max_length=max_input_length, truncation=True
    ).to(model.device)
    # Log-probabilities over the vocabulary at each position; gradients are
    # not needed for evaluation.
    with torch.no_grad():
        logits = torch.nn.functional.log_softmax(model(**inputs).logits, dim=2)

    # At each position t, gather the log-probability the model assigned to the
    # token that actually appears at position t + 1. Shape: (seq_len - 1, 1).
    log_probs = torch.gather(logits[:, :-1, :], 2, inputs["input_ids"][:, 1:, None])[0]
    # Perplexity is the exponential of the mean negative log-likelihood, taken
    # over the tokens actually scored (every token except the first), so
    # normalize by the number of predictions rather than the full length.
    return torch.exp(-log_probs.sum() / log_probs.size(0)).item()
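
# Example usage: a minimal sketch. It assumes the publicly available "gpt2"
# checkpoint; any causal LM with a matching tokenizer works the same way.
if __name__ == "__main__":
    model = AutoModelForCausalLM.from_pretrained("gpt2")
    tok = AutoTokenizer.from_pretrained("gpt2")
    ppl = perplexity(model, tok, "The quick brown fox jumps over the lazy dog.")
    print(f"Perplexity: {ppl:.2f}")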