|
|
|
from typing import * |
|
from abc import abstractmethod |
|
|
|
import torch |
|
|
|
from .base import LMScorer |
|
|
|
|
|
class BatchedLMScorer(LMScorer): |
|
|
|
def _build(self, model_name: str, options: Dict[str, Any]) -> None: |
|
super()._build(model_name, options) |
|
|
|
batch_size = options.get("batch_size", 1) |
|
if batch_size < 1: |
|
raise ValueError("The batch_size option must be positive") |
|
|
|
self.batch_size = batch_size |
|
|
|
|
|
def _tokens_log_prob( |
|
self, text: List[str] |
|
) -> List[Tuple[torch.DoubleTensor, torch.LongTensor, List[str]]]: |
|
outputs = [] |
|
for i in range(0, len(text), self.batch_size): |
|
batch = text[i : i + self.batch_size] |
|
outputs.extend(self._tokens_log_prob_for_batch(batch)) |
|
return outputs |
|
|
|
@abstractmethod |
|
def _tokens_log_prob_for_batch( |
|
self, text: List[str] |
|
) -> List[Tuple[torch.DoubleTensor, torch.LongTensor, List[str]]]: |
|
... |
|
|