Nischay103 committed
Commit f3e0ff3
Parent(s): cde31bc
Upload tokenizer.py
Files changed: tokenizer.py (+108 -0)
tokenizer.py
ADDED
@@ -0,0 +1,108 @@
import torch
from abc import ABC, abstractmethod
from typing import List, Optional, Tuple
from torch import Tensor
from torch.nn.utils.rnn import pad_sequence


class BaseTokenizer(ABC):

    def __init__(self, charset: str, specials_first: tuple = (), specials_last: tuple = ()) -> None:
        self._itos = specials_first + tuple(charset + '[UNK]') + specials_last
        self._stoi = {s: i for i, s in enumerate(self._itos)}

    def __len__(self):
        return len(self._itos)

    def _tok2ids(self, tokens: str) -> List[int]:
        return [self._stoi[s] for s in tokens]

    def _ids2tok(self, token_ids: List[int], join: bool = True) -> str:
        tokens = [self._itos[i] for i in token_ids]
        return ''.join(tokens) if join else tokens

    @abstractmethod
    def encode(self, labels: List[str], device: Optional[torch.device] = None) -> Tensor:
        raise NotImplementedError

    @abstractmethod
    def _filter(self, probs: Tensor, ids: Tensor) -> Tuple[Tensor, List[int]]:
        """Internal method which performs the necessary filtering prior to decoding."""
        raise NotImplementedError

    def decode(self, token_dists: Tensor, beam_width: int = 1, raw: bool = False) -> Tuple[List[str], List[Tensor]]:
        if beam_width > 1:
            return self.beam_search_decode(token_dists, beam_width, raw)
        else:
            return self.greedy_decode(token_dists, raw)

    def greedy_decode(self, token_dists: Tensor, raw: bool = False) -> Tuple[List[str], List[Tensor]]:
        batch_tokens = []
        batch_probs = []
        for dist in token_dists:
            probs, ids = dist.max(-1)  # greedy selection
            if not raw:
                probs, ids = self._filter(probs, ids)
            tokens = self._ids2tok(ids, not raw)
            batch_tokens.append(tokens)
            batch_probs.append(probs)
        return batch_tokens, batch_probs

    def beam_search_decode(self, token_dists: Tensor, beam_width: int, raw: bool) -> Tuple[List[str], List[Tensor]]:
        batch_tokens = []
        batch_probs = []

        for dist in token_dists:
            sequences = [([], 1.0)]
            for step_dist in dist:
                all_candidates = []
                for seq, score in sequences:
                    top_probs, top_ids = step_dist.topk(beam_width)
                    for i in range(beam_width):
                        candidate = (seq + [top_ids[i].item()],
                                     score * top_probs[i].item())
                        all_candidates.append(candidate)
                ordered = sorted(all_candidates, key=lambda x: x[1], reverse=True)
                sequences = ordered[:beam_width]

            best_sequence, best_score = sequences[0]
            if not raw:
                best_score_tensor = torch.tensor([best_score])
                best_sequence_tensor = torch.tensor(best_sequence)
                best_score_tensor, best_sequence = self._filter(
                    best_score_tensor, best_sequence_tensor)
                best_score = best_score_tensor.item()
            tokens = self._ids2tok(best_sequence, not raw)
            batch_tokens.append(tokens)
            batch_probs.append(best_score)

        return batch_tokens, batch_probs


class Tokenizer(BaseTokenizer):
    BOS = '[B]'
    EOS = '[E]'
    PAD = '[P]'

    def __init__(self, charset: str) -> None:
        specials_first = (self.EOS,)
        specials_last = (self.BOS, self.PAD)
        super().__init__(charset, specials_first, specials_last)
        self.eos_id, self.bos_id, self.pad_id = [
            self._stoi[s] for s in specials_first + specials_last]

    def encode(self, labels: List[str], device: Optional[torch.device] = None) -> Tensor:
        batch = [torch.as_tensor([self.bos_id] + self._tok2ids(y) + [self.eos_id], dtype=torch.long, device=device)
                 for y in labels]
        return pad_sequence(batch, batch_first=True, padding_value=self.pad_id)

    def _filter(self, probs: Tensor, ids: Tensor) -> Tuple[Tensor, List[int]]:
        ids = ids.tolist()
        try:
            eos_idx = ids.index(self.eos_id)
        except ValueError:
            eos_idx = len(ids)
        # Truncate after EOS
        ids = ids[:eos_idx]
        probs = probs[:eos_idx + 1]  # but include prob. for EOS (if it exists)
        return probs, ids
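
For illustration only (not part of the uploaded file): a minimal usage sketch of the Tokenizer above. The charset string, the example labels, and the random probability tensor are assumptions for demonstration; in practice the charset and the per-step distributions come from the model's configuration and output.

# usage_sketch.py -- illustrative only; charset, labels, and dummy_dists are assumptions
import torch

from tokenizer import Tokenizer  # the file uploaded in this commit

charset = '0123456789abcdefghijklmnopqrstuvwxyz'  # assumed character set
tokenizer = Tokenizer(charset)

# Encode a batch of labels into padded id tensors: [B] <chars> [E] [P] ...
targets = tokenizer.encode(['hello', 'world42'])
print(targets.shape)  # (2, 9): longest label has 7 chars, plus BOS and EOS

# Decode per-step probability distributions of shape (batch, seq_len, vocab).
# Random values stand in for real model output here.
dummy_dists = torch.rand(2, 10, len(tokenizer)).softmax(-1)
texts, probs = tokenizer.decode(dummy_dists)                      # greedy (beam_width=1)
beam_texts, beam_scores = tokenizer.decode(dummy_dists, beam_width=3)  # beam search
print(texts, beam_texts)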