from typing import * |
import torch |
from transformers import AutoTokenizer, GPT2LMHeadModel |
from transformers import GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP |
from transformers.tokenization_utils import BatchEncoding |
from .abc.transformers import TransformersLMScorer |
class GPT2LMScorer(TransformersLMScorer): |
def _build(self, model_name: str, options: Dict[str, Any]) -> None:
    """Load the tokenizer and GPT-2 language-model head for ``model_name``.

    After delegating to the base-class ``_build``, a fast tokenizer is loaded
    and given a dedicated ``<|pad|>`` special token, which also becomes its
    pad token, so that batches of sentences can be padded. The model's token
    embeddings are then resized to the enlarged vocabulary, the model is put
    in eval mode, and it is moved to ``options["device"]`` when that key is
    present.

    Args:
        model_name: Name or path handed to both ``from_pretrained`` calls.
        options: Scorer options. They are also forwarded to the base class;
            this method itself only reads the optional ``"device"`` key.
    """
    super()._build(model_name, options)

    pad = "<|pad|>"
    self.tokenizer = tokenizer = AutoTokenizer.from_pretrained(
        model_name, use_fast=True, add_special_tokens=False
    )
    # The stock GPT-2 vocabulary has no padding token, so register one.
    tokenizer.add_special_tokens({"additional_special_tokens": [pad]})
    tokenizer.pad_token = pad

    self.model = model = GPT2LMHeadModel.from_pretrained(model_name)
    # Grow the embedding matrix so the freshly added token has a row.
    model.resize_token_embeddings(len(tokenizer))
    model.eval()
    if "device" in options:
        model.to(options["device"])
def _add_special_tokens(self, text: str) -> str: |
return self.tokenizer.bos_token + text + self.tokenizer.eos_token |
def _tokens_log_prob_for_batch( |
self, text: List[str] |
) -> List[Tuple[torch.DoubleTensor, torch.LongTensor, List[str]]]: |
outputs: List[Tuple[torch.DoubleTensor, torch.LongTensor, List[str]]] = [] |
if len(text) == 0: |
return outputs |
text = list(map(self._add_special_tokens, text)) |
encoding: BatchEncoding = self.tokenizer.batch_encode_plus( |
text, return_tensors="pt", |
) |
with torch.no_grad(): |
ids = encoding["input_ids"].to(self.model.device) |
attention_mask = encoding["attention_mask"].to(self.model.device) |
nopad_mask = ids != self.tokenizer.pad_token_id |
logits: torch.Tensor = self.model(ids, attention_mask=attention_mask)[0] |
for sent_index in range(len(text)): |
sent_nopad_mask = nopad_mask[sent_index] |
sent_tokens = [ |
tok |
for i, tok in enumerate(encoding.tokens(sent_index)) |
if sent_nopad_mask[i] and i != 0 |
] |
sent_ids = ids[sent_index, sent_nopad_mask][1:] |
sent_logits = logits[sent_index, sent_nopad_mask][:-1, :] |
sent_logits[:, self.tokenizer.pad_token_id] = float("-inf") |
sent_ids_scores = sent_logits.gather(1, sent_ids.unsqueeze(1)).squeeze(1) |
sent_log_probs = sent_ids_scores - sent_logits.logsumexp(1) |
sent_log_probs = cast(torch.DoubleTensor, sent_log_probs) |
sent_ids = cast(torch.LongTensor, sent_ids) |
output = (sent_log_probs, sent_ids, sent_tokens) |
outputs.append(output) |
return outputs |
@classmethod |
def _supported_model_names(cls) -> Iterable[str]: |