from typing import Any, Dict, Iterable, List, Tuple, cast

import torch
from transformers import AutoTokenizer, GPT2LMHeadModel
from transformers import GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP
from transformers.tokenization_utils import BatchEncoding

from .abc.transformers import TransformersLMScorer


class GPT2LMScorer(TransformersLMScorer):
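    """Language model scorer backed by a pretrained GPT-2 model.

    Minimal usage sketch (this assumes the ``TransformersLMScorer`` base
    class, defined elsewhere, passes its constructor arguments through to
    ``_build`` as ``model_name`` plus an options dict)::

        scorer = GPT2LMScorer("gpt2", device="cpu")
        log_probs, token_ids, tokens = scorer._tokens_log_prob_for_batch(
            ["Hello world!"]
        )[0]
        sentence_log_prob = log_probs.sum().item()
    """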
    def _build(self, model_name: str, options: Dict[str, Any]) -> None:
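        """Load the GPT-2 tokenizer and model for ``model_name``.

        All ``options`` are forwarded to ``super()._build``; in addition,
        ``options["device"]``, when present, is used to move the model to
        that torch device.
        """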
        super()._build(model_name, options)

        self.tokenizer = AutoTokenizer.from_pretrained(
            model_name, use_fast=True, add_special_tokens=False
        )
        # GPT-2 has no padding token, so register "<|pad|>" as an additional
        # special token and use it for padding (len(tokenizer) grows by one).
        self.tokenizer.add_special_tokens({"additional_special_tokens": ["<|pad|>"]})
        self.tokenizer.pad_token = "<|pad|>"

        self.model = GPT2LMHeadModel.from_pretrained(model_name)
        # Resize the embedding matrix to account for the added pad token.
        self.model.resize_token_embeddings(len(self.tokenizer))
        self.model.eval()
        if "device" in options:
            self.model.to(options["device"])

    def _add_special_tokens(self, text: str) -> str:
        # Wrap the sentence in GPT-2's BOS/EOS markers so that the first token
        # and the end of the sentence both receive a probability.
        return self.tokenizer.bos_token + text + self.tokenizer.eos_token

    def _tokens_log_prob_for_batch(
        self, text: List[str]
    ) -> List[Tuple[torch.DoubleTensor, torch.LongTensor, List[str]]]:
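        """Score every sentence in ``text`` token by token.

        For each sentence, returns a tuple of (per-token log-probabilities,
        token ids, token strings). The leading BOS token is used as context
        but not scored; the trailing EOS token is scored.
        """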
        outputs: List[Tuple[torch.DoubleTensor, torch.LongTensor, List[str]]] = []
        if len(text) == 0:
            return outputs

        text = list(map(self._add_special_tokens, text))
        # Pad to the longest sentence in the batch so the encodings can be
        # stacked into a single tensor.
        encoding: BatchEncoding = self.tokenizer.batch_encode_plus(
            text, return_tensors="pt", padding=True
        )
        with torch.no_grad():
            ids = encoding["input_ids"].to(self.model.device)
            attention_mask = encoding["attention_mask"].to(self.model.device)
            nopad_mask = ids != self.tokenizer.pad_token_id
            # logits shape: [batch_size, seq_len, vocab_size]
            logits: torch.Tensor = self.model(ids, attention_mask=attention_mask)[0]

        # The logits at position i are the model's predictions for the token at
        # position i + 1, so the scored ids/tokens drop the leading BOS token
        # and the logits drop the final position.
        for sent_index in range(len(text)):
            sent_nopad_mask = nopad_mask[sent_index]
            # Token strings for this sentence, without padding and without BOS.
            sent_tokens = [
                tok
                for i, tok in enumerate(encoding.tokens(sent_index))
                if sent_nopad_mask[i] and i != 0
            ]

            # Target ids to score: every token after the leading BOS.
            sent_ids = ids[sent_index, sent_nopad_mask][1:]
            # Matching predictions: drop the last position, whose logits would
            # predict a token beyond the EOS.
            sent_logits = logits[sent_index, sent_nopad_mask][:-1, :]
            # The pad token is never a valid prediction for a real position.
            sent_logits[:, self.tokenizer.pad_token_id] = float("-inf")
            # Pick the logit of each realized token and normalize it into a
            # log-probability (log_softmax restricted to the realized ids).
            sent_ids_scores = sent_logits.gather(1, sent_ids.unsqueeze(1)).squeeze(1)
            sent_log_probs = sent_ids_scores - sent_logits.logsumexp(1)

            sent_log_probs = cast(torch.DoubleTensor, sent_log_probs)
            sent_ids = cast(torch.LongTensor, sent_ids)

            output = (sent_log_probs, sent_ids, sent_tokens)
            outputs.append(output)

        return outputs

    @classmethod
    def _supported_model_names(cls) -> Iterable[str]:
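        """Names of the pretrained GPT-2 checkpoints this scorer supports."""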
        return GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP.keys()