import torch
from abc import ABC, abstractmethod
from typing import List, Optional, Tuple
from torch import Tensor
from torch.nn.utils.rnn import pad_sequence
class BaseTokenizer(ABC):
    def __init__(self, charset: str, specials_first: tuple = (), specials_last: tuple = ()) -> None:
        # Keep '[UNK]' as a single vocabulary entry; tuple(charset + '[UNK]') would
        # wrongly split it into the individual characters '[', 'U', 'N', 'K', ']'.
        self._itos = specials_first + tuple(charset) + ('[UNK]',) + specials_last
        self._stoi = {s: i for i, s in enumerate(self._itos)}

    def __len__(self):
        return len(self._itos)

    def _tok2ids(self, tokens: str) -> List[int]:
        # Characters outside the charset fall back to the '[UNK]' token.
        unk_id = self._stoi['[UNK]']
        return [self._stoi.get(s, unk_id) for s in tokens]

    def _ids2tok(self, token_ids: List[int], join: bool = True) -> str:
        tokens = [self._itos[i] for i in token_ids]
        return ''.join(tokens) if join else tokens

    @abstractmethod
    def encode(self, labels: List[str], device: Optional[torch.device] = None) -> Tensor:
        """Encode a batch of labels into a padded tensor of token ids."""
        raise NotImplementedError

    @abstractmethod
    def _filter(self, probs: Tensor, ids: Tensor) -> Tuple[Tensor, List[int]]:
        """Internal method which performs the necessary filtering prior to decoding."""
        raise NotImplementedError

    def decode(self, token_dists: Tensor, beam_width: int = 1, raw: bool = False) -> Tuple[List[str], List[Tensor]]:
        """Decode a batch of token distributions, using beam search when beam_width > 1."""
        if beam_width > 1:
            return self.beam_search_decode(token_dists, beam_width, raw)
        return self.greedy_decode(token_dists, raw)

    def greedy_decode(self, token_dists: Tensor, raw: bool = False) -> Tuple[List[str], List[Tensor]]:
        batch_tokens = []
        batch_probs = []
        for dist in token_dists:  # dist: (length, vocab_size) distributions for one sample
            probs, ids = dist.max(-1)  # greedy selection at every position
            if not raw:
                probs, ids = self._filter(probs, ids)
            tokens = self._ids2tok(ids, not raw)
            batch_tokens.append(tokens)
            batch_probs.append(probs)
        return batch_tokens, batch_probs

    def beam_search_decode(self, token_dists: Tensor, beam_width: int, raw: bool) -> Tuple[List[str], List[Tensor]]:
        batch_tokens = []
        batch_probs = []
        for dist in token_dists:
            sequences = [([], 1.0)]  # (token ids, cumulative probability)
            for step_dist in dist:
                # The top-k candidates of this step are the same for every beam,
                # so compute them once per step.
                top_probs, top_ids = step_dist.topk(beam_width)
                all_candidates = []
                for seq, score in sequences:
                    for i in range(beam_width):
                        candidate = (seq + [top_ids[i].item()],
                                     score * top_probs[i].item())
                        all_candidates.append(candidate)
                # Keep the beam_width highest-scoring sequences.
                ordered = sorted(all_candidates, key=lambda x: x[1], reverse=True)
                sequences = ordered[:beam_width]
            best_sequence, best_score = sequences[0]
            if not raw:
                # Reuse _filter by wrapping the cumulative score in a 1-element tensor.
                best_score_tensor = torch.tensor([best_score])
                best_sequence_tensor = torch.tensor(best_sequence)
                best_score_tensor, best_sequence = self._filter(
                    best_score_tensor, best_sequence_tensor)
                best_score = best_score_tensor.item()
            tokens = self._ids2tok(best_sequence, not raw)
            batch_tokens.append(tokens)
            # Return the score as a tensor for consistency with greedy_decode.
            batch_probs.append(torch.tensor(best_score))
        return batch_tokens, batch_probs


class Tokenizer(BaseTokenizer):
    BOS = '[B]'
    EOS = '[E]'
    PAD = '[P]'

    def __init__(self, charset: str) -> None:
        specials_first = (self.EOS,)
        specials_last = (self.BOS, self.PAD)
        super().__init__(charset, specials_first, specials_last)
        self.eos_id, self.bos_id, self.pad_id = [
            self._stoi[s] for s in specials_first + specials_last]

    def encode(self, labels: List[str], device: Optional[torch.device] = None) -> Tensor:
        # Wrap each label with BOS/EOS and pad the batch to a common length.
        batch = [torch.as_tensor([self.bos_id] + self._tok2ids(y) + [self.eos_id], dtype=torch.long, device=device)
                 for y in labels]
        return pad_sequence(batch, batch_first=True, padding_value=self.pad_id)

    def _filter(self, probs: Tensor, ids: Tensor) -> Tuple[Tensor, List[int]]:
        ids = ids.tolist()
        try:
            eos_idx = ids.index(self.eos_id)
        except ValueError:
            eos_idx = len(ids)  # no EOS found: keep the full sequence
        ids = ids[:eos_idx]
        probs = probs[:eos_idx + 1]  # keep the EOS probability in the confidence scores
        return probs, ids
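

if __name__ == '__main__':
    # Minimal usage sketch, not part of the original module. The charset and the
    # random distributions below are illustrative assumptions; real models define
    # their own character set and produce the per-step distributions themselves.
    tokenizer = Tokenizer('abcdefghijklmnopqrstuvwxyz0123456789')

    # encode() wraps each label with BOS/EOS and pads the batch: a (B, max_len) long tensor.
    targets = tokenizer.encode(['hello', 'world42'])
    print(targets.shape)

    # decode() expects probability distributions of shape (B, T, len(tokenizer)).
    dummy_dists = torch.rand(2, 10, len(tokenizer)).softmax(-1)
    labels, confidences = tokenizer.decode(dummy_dists)              # greedy decoding
    beam_labels, _ = tokenizer.decode(dummy_dists, beam_width=3)     # beam search decoding
    print(labels, beam_labels)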