kinensake's picture
Modify: requirements.txt
2ea9ced
raw
history blame
1.15 kB
# pylint: disable=abstract-method
from typing import * # pylint: disable=wildcard-import,unused-wildcard-import
from abc import abstractmethod
import torch
from .base import LMScorer
class BatchedLMScorer(LMScorer):
# @overrides
def _build(self, model_name: str, options: Dict[str, Any]) -> None:
super()._build(model_name, options)
batch_size = options.get("batch_size", 1)
if batch_size < 1:
raise ValueError("The batch_size option must be positive")
# pylint: disable=attribute-defined-outside-init
self.batch_size = batch_size
# @overrides
def _tokens_log_prob(
self, text: List[str]
) -> List[Tuple[torch.DoubleTensor, torch.LongTensor, List[str]]]:
outputs = []
for i in range(0, len(text), self.batch_size):
batch = text[i : i + self.batch_size]
outputs.extend(self._tokens_log_prob_for_batch(batch))
return outputs
@abstractmethod
def _tokens_log_prob_for_batch(
self, text: List[str]
) -> List[Tuple[torch.DoubleTensor, torch.LongTensor, List[str]]]:
... # pragma: no cover