|
import torch |
|
from abc import ABC, abstractmethod |
|
from typing import List, Optional, Tuple |
|
from torch import Tensor |
|
from torch.nn.utils.rnn import pad_sequence |
|
|
|
|
|
class BaseTokenizer(ABC):
    """Convert between label strings and token ids, and decode model outputs.

    The vocabulary is laid out as
    ``specials_first + characters of charset + ('[UNK]',) + specials_last``
    and a token's id is its position in that sequence. Subclasses must
    implement `encode` and `_filter`.
    """

    # Token that stands in for any character missing from the charset.
    UNK = '[UNK]'

    def __init__(self, charset: str, specials_first: tuple = (), specials_last: tuple = ()) -> None:
        """Build the id -> token tuple and the token -> id dict.

        Args:
            charset: String whose individual characters become tokens.
            specials_first: Special tokens placed before the charset.
            specials_last: Special tokens placed after the UNK token.
        """
        # UNK has to be appended as ONE token. Concatenating it to `charset`
        # before calling tuple() would split it into '[', 'U', 'N', 'K', ']'.
        self._itos = specials_first + tuple(charset) + (self.UNK,) + specials_last
        self._stoi = {s: i for i, s in enumerate(self._itos)}

    def __len__(self):
        """Return the vocabulary size (special tokens included)."""
        return len(self._itos)

    def _tok2ids(self, tokens: str) -> List[int]:
        """Map each character of `tokens` to its id.

        Characters that are not in the vocabulary map to the UNK id instead
        of raising `KeyError`.
        """
        unk_id = self._stoi[self.UNK]
        return [self._stoi.get(s, unk_id) for s in tokens]

    def _ids2tok(self, token_ids: List[int], join: bool = True) -> str:
        """Map ids back to tokens.

        Returns:
            The tokens concatenated into one string when `join` is true,
            otherwise the list of token strings.
        """
        tokens = [self._itos[i] for i in token_ids]
        return ''.join(tokens) if join else tokens

    @abstractmethod
    def encode(self, labels: List[str], device: Optional[torch.device] = None) -> Tensor:
        """Encode a batch of label strings into a tensor of token ids."""
        raise NotImplementedError

    @abstractmethod
    def _filter(self, probs: Tensor, ids: Tensor) -> Tuple[Tensor, List[int]]:
        """Internal method which performs the necessary filtering prior to decoding.

        Receives the per-step probabilities and ids of one sequence and
        returns the filtered probabilities together with the filtered ids as
        a plain list.
        """
        raise NotImplementedError

    def decode(self, token_dists: Tensor, beam_width: int = 1, raw: bool = False) -> Tuple[List[str], List[Tensor]]:
        """Decode a batch of token distributions.

        Dispatches to `beam_search_decode` when `beam_width > 1`, otherwise
        to `greedy_decode`. Note that the second element of the result
        differs between the two: greedy decoding yields one probability
        tensor per sequence, beam search yields one float score per sequence.
        """
        if beam_width > 1:
            return self.beam_search_decode(token_dists, beam_width, raw)
        return self.greedy_decode(token_dists, raw)

    def greedy_decode(self, token_dists: Tensor, raw: bool = False) -> Tuple[List[str], List[Tensor]]:
        """Pick the most probable token at every step.

        Args:
            token_dists: Iterated along its first dimension; the maximum is
                taken over the last dimension of each item (presumably
                shaped (batch, steps, vocab) - confirm against callers).
            raw: When true, `_filter` is skipped and each sequence is
                returned as a list of tokens instead of a joined string.

        Returns:
            A pair (decoded sequences, per-sequence probability tensors).
        """
        batch_tokens = []
        batch_probs = []
        for dist in token_dists:
            probs, ids = dist.max(-1)
            if not raw:
                probs, ids = self._filter(probs, ids)
            tokens = self._ids2tok(ids, not raw)
            batch_tokens.append(tokens)
            batch_probs.append(probs)
        return batch_tokens, batch_probs

    def beam_search_decode(self, token_dists: Tensor, beam_width: int, raw: bool) -> Tuple[List[str], List[float]]:
        """Decode each sequence by beam search over its per-step distributions.

        Args:
            token_dists: Iterated along its first dimension (sequences) and
                then along the next one (steps); `topk` is taken over the
                last dimension of each step.
            beam_width: Number of hypotheses kept after every step. Values
                larger than the vocabulary size are clamped for `topk`.
            raw: When true, `_filter` is skipped and each sequence is
                returned as a list of tokens instead of a joined string.

        Returns:
            A pair (decoded sequences, scores). Each score is a Python float:
            the product of the per-step probabilities of the best hypothesis.
            Unless `raw` is set, only the probabilities kept by `_filter`
            enter the product, so steps it discards do not dilute the score.
        """
        batch_tokens = []
        batch_probs = []

        for dist in token_dists:
            # Every hypothesis is (token ids, per-step probabilities, score).
            beams = [([], [], 1.0)]
            for step_dist in dist:
                # The top-k of a step does not depend on the hypothesis being
                # extended, so compute it once per step. Clamp k so that a
                # beam wider than the vocabulary does not make topk raise.
                k = min(beam_width, step_dist.size(-1))
                top_probs, top_ids = step_dist.topk(k)
                top_probs = top_probs.tolist()
                top_ids = top_ids.tolist()

                candidates = [
                    (ids + [tok], probs + [p], score * p)
                    for ids, probs, score in beams
                    for tok, p in zip(top_ids, top_probs)
                ]
                # list.sort is stable, so ties keep their generation order.
                candidates.sort(key=lambda c: c[2], reverse=True)
                beams = candidates[:beam_width]

            best_ids, best_probs, best_score = beams[0]
            if not raw:
                # Filter the per-step probabilities together with the ids so
                # the score only covers what `_filter` keeps. float64 keeps
                # the precision of the Python floats collected above.
                kept_probs, best_ids = self._filter(
                    torch.tensor(best_probs, dtype=torch.float64),
                    torch.tensor(best_ids, dtype=torch.long))
                best_score = kept_probs.prod().item()
            tokens = self._ids2tok(best_ids, not raw)
            batch_tokens.append(tokens)
            batch_probs.append(best_score)

        return batch_tokens, batch_probs
|
|
|
|
|
class Tokenizer(BaseTokenizer):
    """Tokenizer that frames every label with BOS/EOS and pads batches with PAD.

    EOS is registered ahead of the character set; BOS and PAD are registered
    after it. The resulting ids are exposed as `eos_id`, `bos_id`, `pad_id`.
    """

    BOS = '[B]'
    EOS = '[E]'
    PAD = '[P]'

    def __init__(self, charset: str) -> None:
        """Create the vocabulary for `charset` and cache the special-token ids."""
        leading = (self.EOS,)
        trailing = (self.BOS, self.PAD)
        super().__init__(charset, leading, trailing)
        self.eos_id = self._stoi[self.EOS]
        self.bos_id = self._stoi[self.BOS]
        self.pad_id = self._stoi[self.PAD]

    def encode(self, labels: List[str], device: Optional[torch.device] = None) -> Tensor:
        """Encode labels as ``[BOS] + ids + [EOS]`` rows padded with PAD.

        Returns:
            A batch-first long tensor; shorter rows are right-padded with
            `pad_id` by `pad_sequence`.
        """
        rows = []
        for label in labels:
            framed = [self.bos_id, *self._tok2ids(label), self.eos_id]
            rows.append(torch.as_tensor(framed, dtype=torch.long, device=device))
        return pad_sequence(rows, batch_first=True, padding_value=self.pad_id)

    def _filter(self, probs: Tensor, ids: Tensor) -> Tuple[Tensor, List[int]]:
        """Cut a decoded sequence at its first EOS.

        The returned ids stop before EOS, while the returned probabilities
        keep one extra entry so the EOS probability (if any) is included.
        Without an EOS nothing is truncated.
        """
        id_list = ids.tolist()
        if self.eos_id in id_list:
            cut = id_list.index(self.eos_id)
        else:
            cut = len(id_list)
        return probs[:cut + 1], id_list[:cut]
|
|