# coding=utf-8
# author: xusong
# time: 2022/8/22 12:06
import numpy as np
import torch
from transformers import FillMaskPipeline


class PerplexityPipeline(FillMaskPipeline):
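    """Perplexity of a sentence under a masked language model.

    Each token is masked in turn and scored by the model; the per-token
    probabilities are then combined into a single perplexity value.
    """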
    def create_sequential_mask(self, input_data):
        """Expand a single encoded sentence into a batch in which each row
        masks exactly one token (excluding [CLS] and [SEP])."""
        _, seq_length = input_data["input_ids"].shape
        mask_count = seq_length - 2  # number of real tokens, excluding [CLS] and [SEP]
        input_ids = input_data["input_ids"]
        # Replicate the encoded inputs once per token to be masked.
        new_input_ids = torch.repeat_interleave(input_ids, repeats=mask_count, dim=0)
        token_type_ids = torch.repeat_interleave(input_data["token_type_ids"], repeats=mask_count, dim=0)
        attention_mask = torch.repeat_interleave(input_data["attention_mask"], repeats=mask_count, dim=0)
        masked_lm_labels = []
        masked_lm_positions = list(range(1, mask_count + 1))
        # Row i-1 masks position i; keep the original token id as its label.
        for i in masked_lm_positions:
            new_input_ids[i - 1][i] = self.tokenizer.mask_token_id
            masked_lm_labels.append(input_ids[0][i].item())
        new_data = {"input_ids": new_input_ids, "token_type_ids": token_type_ids, "attention_mask": attention_mask}
        return new_data, masked_lm_positions, masked_lm_labels
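
    # Illustration of the expanded batch (assuming a sentence that encodes to
    # [CLS] I like it [SEP]; the actual tokenization depends on the model):
    #   row 0: [CLS] [MASK] like   it     [SEP]  -> label "I"
    #   row 1: [CLS] I      [MASK] it     [SEP]  -> label "like"
    #   row 2: [CLS] I      like   [MASK] [SEP]  -> label "it"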

    def __call__(self, input_text, *args, **kwargs):
        """
        Compute the perplexity of a given sentence.
        """
        if not isinstance(input_text, str):
            return None
        # 1. Create the sequential mask: one masked copy of the sentence per token.
        model_inputs = self.tokenizer(input_text, return_tensors="pt")
        new_data, masked_lm_positions, masked_lm_labels = self.create_sequential_mask(model_inputs.data)
        model_inputs.data = new_data
        labels = torch.tensor(masked_lm_labels)
        # 2. Predict all masked positions in a single forward pass.
        model_outputs = self.model(**model_inputs)
        # 3. Compute perplexity from the probability of the original token
        #    at each masked position.
        sentence = {}
        tokens = []
        for i in range(len(labels)):
            model_outputs_i = {}
            model_outputs_i["input_ids"] = model_inputs["input_ids"][i:i + 1]
            model_outputs_i["logits"] = model_outputs["logits"][i:i + 1]
            outputs = self.postprocess(model_outputs_i, target_ids=labels[i:i + 1])
            tokens.append({"token": outputs[0]["token_str"],
                           "prob": outputs[0]["score"]})
        sentence["tokens"] = tokens
        # ppl = exp(-(1/N) * sum_i(log p(token_i | rest of the sentence)))
        sentence["ppl"] = float(np.exp(-sum(np.log(token["prob"]) for token in tokens) / len(tokens)))
        return sentence
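

# A minimal usage sketch, not part of the original module: "bert-base-uncased"
# is only an assumed example checkpoint; any BERT-style fill-mask model works.
if __name__ == "__main__":
    from transformers import AutoModelForMaskedLM, AutoTokenizer

    model_name = "bert-base-uncased"  # hypothetical choice of checkpoint
    pipeline = PerplexityPipeline(
        model=AutoModelForMaskedLM.from_pretrained(model_name),
        tokenizer=AutoTokenizer.from_pretrained(model_name),
    )
    result = pipeline("The quick brown fox jumps over the lazy dog.")
    print(result["ppl"])  # scalar perplexity for the whole sentence
    for token in result["tokens"]:
        print(token["token"], token["prob"])  # per-token probability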