import warnings

import torch
class PerplexityEvaluator(object):
    """Score text under a language model, optionally conditioned on a context.

    The calls made here match Hugging Face ``transformers`` tokenizers and models:
    ``tokenizer(text, return_tensors='pt')['input_ids']``, ``model.device`` and
    ``model(input_ids=..., labels=...).loss``. Any objects exposing the same calls
    can be used.
    """

    def __init__(self, model, tokenizer, ignore_index=-100):
        """
        :param model: model called as ``model(input_ids=..., labels=...)``,
            returning an object with a ``.loss`` tensor; must expose ``.device``.
        :param tokenizer: callable that, given ``return_tensors='pt'``, returns
            a mapping with an ``'input_ids'`` tensor.
        :param ignore_index: label value written over context positions so the
            model's loss skips them. Defaults to -100, the label value Hugging
            Face models (and ``torch.nn.CrossEntropyLoss`` by default) ignore;
            other values such as -1 are treated as real targets by those losses.
        """
        self.model = model
        self.tokenizer = tokenizer
        self.ignore_index = ignore_index

    def __call__(self, text, context=None):
        """Shorthand for :meth:`log_perplexity`."""
        return self.log_perplexity(text, context)

    def log_perplexity(self, text, context=None):
        """
        Evaluate log perplexity of ``text`` with respect to the language model,
        optionally conditioned on a preceding ``context``.

        :param text: string whose tokens are scored.
        :param context: optional string placed before ``text``. Its tokens are
            fed to the model but labelled with ``self.ignore_index`` so they do
            not contribute to the loss. A falsy value (``None`` or ``""``) means
            "no context".
        :return: the model's ``.loss`` converted to a NumPy array (0-d when the
            loss is a scalar).
        """
        device = self.model.device
        text_ids = self.tokenizer(text, return_tensors='pt')['input_ids']
        if context:
            # NOTE(review): context and text are tokenized separately, so any
            # special tokens the tokenizer adds (e.g. a leading BOS) also appear
            # at the start of the text segment -- confirm this is acceptable for
            # the tokenizer in use.
            context_ids = self.tokenizer(context, return_tensors='pt')['input_ids']
            input_ids = torch.cat([context_ids, text_ids], dim=1)
            # Mask context positions so only the text tokens are scored.
            labels = torch.cat(
                [torch.full_like(context_ids, self.ignore_index), text_ids], dim=1
            )
            warnings.warn("Warning, need to remove context length when reporting lppx")
        else:
            input_ids = text_ids
            labels = input_ids
        # Evaluation only: no gradients are needed, so skip building the graph.
        with torch.no_grad():
            loss = self.model(input_ids=input_ids.to(device), labels=labels.to(device)).loss
        return loss.detach().cpu().numpy()