kinensake's picture
Modify: requirements.txt
2ea9ced
raw
history blame
3.5 kB
from typing import * # pylint: disable=wildcard-import,unused-wildcard-import
import torch
from transformers import AutoTokenizer, GPT2LMHeadModel
from transformers import GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP
from transformers.tokenization_utils import BatchEncoding
from .abc.transformers import TransformersLMScorer
class GPT2LMScorer(TransformersLMScorer):
# @overrides
def _build(self, model_name: str, options: Dict[str, Any]) -> None:
super()._build(model_name, options)
# pylint: disable=attribute-defined-outside-init
self.tokenizer = AutoTokenizer.from_pretrained(
model_name, use_fast=True, add_special_tokens=False
)
# Add the pad token to GPT2 dictionary.
# len(tokenizer) = vocab_size + 1
self.tokenizer.add_special_tokens({"additional_special_tokens": ["<|pad|>"]})
self.tokenizer.pad_token = "<|pad|>"
self.model = GPT2LMHeadModel.from_pretrained(model_name)
# We need to resize the embedding layer because we added the pad token.
self.model.resize_token_embeddings(len(self.tokenizer))
self.model.eval()
if "device" in options:
self.model.to(options["device"])
def _add_special_tokens(self, text: str) -> str:
return self.tokenizer.bos_token + text + self.tokenizer.eos_token
# @overrides
def _tokens_log_prob_for_batch(
self, text: List[str]
) -> List[Tuple[torch.DoubleTensor, torch.LongTensor, List[str]]]:
outputs: List[Tuple[torch.DoubleTensor, torch.LongTensor, List[str]]] = []
if len(text) == 0:
return outputs
# TODO: Handle overflowing elements for long sentences
text = list(map(self._add_special_tokens, text))
encoding: BatchEncoding = self.tokenizer.batch_encode_plus(
text, return_tensors="pt",
)
with torch.no_grad():
ids = encoding["input_ids"].to(self.model.device)
attention_mask = encoding["attention_mask"].to(self.model.device)
nopad_mask = ids != self.tokenizer.pad_token_id
logits: torch.Tensor = self.model(ids, attention_mask=attention_mask)[0]
for sent_index in range(len(text)):
sent_nopad_mask = nopad_mask[sent_index]
# len(tokens) = len(text[sent_index]) + 1
sent_tokens = [
tok
for i, tok in enumerate(encoding.tokens(sent_index))
if sent_nopad_mask[i] and i != 0
]
# sent_ids.shape = [len(text[sent_index]) + 1]
sent_ids = ids[sent_index, sent_nopad_mask][1:]
# logits.shape = [len(text[sent_index]) + 1, vocab_size]
sent_logits = logits[sent_index, sent_nopad_mask][:-1, :]
sent_logits[:, self.tokenizer.pad_token_id] = float("-inf")
# ids_scores.shape = [seq_len + 1]
sent_ids_scores = sent_logits.gather(1, sent_ids.unsqueeze(1)).squeeze(1)
# log_prob.shape = [seq_len + 1]
sent_log_probs = sent_ids_scores - sent_logits.logsumexp(1)
sent_log_probs = cast(torch.DoubleTensor, sent_log_probs)
sent_ids = cast(torch.LongTensor, sent_ids)
output = (sent_log_probs, sent_ids, sent_tokens)
outputs.append(output)
return outputs
# @overrides
@classmethod
def _supported_model_names(cls) -> Iterable[str]:
return GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP.keys()