Phu92kt committed
Commit c139768 · verified · 1 Parent(s): 9e3ad3d

Create tokenizer_base.py

Files changed (1): tokenizer_base.py (+128, -0)

tokenizer_base.py (new file):
import re
from abc import ABC, abstractmethod
from itertools import groupby
from typing import List, Optional, Tuple

import torch
from torch import Tensor
from torch.nn.utils.rnn import pad_sequence


class CharsetAdapter:
    """Transforms labels according to the target charset."""

    def __init__(self, target_charset) -> None:
        super().__init__()
        self.charset = target_charset
        self.lowercase_only = target_charset == target_charset.lower()
        self.uppercase_only = target_charset == target_charset.upper()
        # Matches any character outside the target charset.
        self.unsupported = re.compile(f'[^{re.escape(target_charset)}]')

    def __call__(self, label):
        if self.lowercase_only:
            label = label.lower()
        elif self.uppercase_only:
            label = label.upper()
        # Strip characters that are not part of the target charset.
        label = self.unsupported.sub('', label)
        return label
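
# Usage sketch (illustrative, not part of the original file; the charset below
# is an assumption). A lowercase-only charset lowercases the label, and the
# `unsupported` pattern strips out-of-charset characters:
#
#     >>> adapter = CharsetAdapter('0123456789abcdefghijklmnopqrstuvwxyz')
#     >>> adapter('Hello-World 42')
#     'helloworld42'
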

class BaseTokenizer(ABC):

    def __init__(self, charset: str, specials_first: tuple = (), specials_last: tuple = ()) -> None:
        # '[UNK]' is a single fallback token; appending it as a one-element
        # tuple keeps it from being split into the characters '[UNK]'.
        self._itos = specials_first + tuple(charset) + ('[UNK]',) + specials_last
        self._stoi = {s: i for i, s in enumerate(self._itos)}
        self.unk_id = self._stoi['[UNK]']

    def __len__(self):
        return len(self._itos)

    def _tok2ids(self, tokens: str) -> List[int]:
        # Characters outside the charset map to the [UNK] token.
        return [self._stoi.get(s, self.unk_id) for s in tokens]

    def _ids2tok(self, token_ids: List[int], join: bool = True) -> str:
        tokens = [self._itos[i] for i in token_ids]
        return ''.join(tokens) if join else tokens

    @abstractmethod
    def encode(self, labels: List[str], device: Optional[torch.device] = None) -> Tensor:
        """Encode a batch of labels to a representation suitable for the model.
        Args:
            labels: List of labels. Each can be of arbitrary length.
            device: Create tensor on this device.
        Returns:
            Batched tensor representation padded to the max label length. Shape: N, L
        """
        raise NotImplementedError

    @abstractmethod
    def _filter(self, probs: Tensor, ids: Tensor) -> Tuple[Tensor, List[int]]:
        """Internal method which performs the necessary filtering prior to decoding."""
        raise NotImplementedError

    def decode(self, token_dists: Tensor, raw: bool = False) -> Tuple[List[str], List[Tensor]]:
        """Decode a batch of token distributions.
        Args:
            token_dists: softmax probabilities over the token distribution. Shape: N, L, C
            raw: return unprocessed labels (returns a list of token lists)
        Returns:
            list of string labels (arbitrary length) and
            their corresponding sequence probabilities as a list of Tensors
        """
        batch_tokens = []
        batch_probs = []
        for dist in token_dists:
            probs, ids = dist.max(-1)  # greedy selection
            if not raw:
                probs, ids = self._filter(probs, ids)
            tokens = self._ids2tok(ids, not raw)
            batch_tokens.append(tokens)
            batch_probs.append(probs)
        return batch_tokens, batch_probs


class Tokenizer(BaseTokenizer):
    BOS = '[B]'
    EOS = '[E]'
    PAD = '[P]'

    def __init__(self, charset: str) -> None:
        specials_first = (self.EOS,)
        specials_last = (self.BOS, self.PAD)
        super().__init__(charset, specials_first, specials_last)
        self.eos_id, self.bos_id, self.pad_id = [self._stoi[s] for s in specials_first + specials_last]

    def encode(self, labels: List[str], device: Optional[torch.device] = None) -> Tensor:
        batch = [torch.as_tensor([self.bos_id] + self._tok2ids(y) + [self.eos_id], dtype=torch.long, device=device)
                 for y in labels]
        return pad_sequence(batch, batch_first=True, padding_value=self.pad_id)

    def _filter(self, probs: Tensor, ids: Tensor) -> Tuple[Tensor, List[int]]:
        ids = ids.tolist()
        try:
            eos_idx = ids.index(self.eos_id)
        except ValueError:
            eos_idx = len(ids)  # Nothing to truncate.
        # Truncate after EOS
        ids = ids[:eos_idx]
        probs = probs[:eos_idx + 1]  # but include prob. for EOS (if it exists)
        return probs, ids
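
# Encoding sketch (illustrative; charset 'abc' is an assumption). With
# specials_first=('[E]',) and specials_last=('[B]', '[P]'), the vocabulary is
# ('[E]', 'a', 'b', 'c', '[UNK]', '[B]', '[P]'), so eos_id=0, bos_id=5, pad_id=6:
#
#     >>> tokenizer = Tokenizer('abc')
#     >>> tokenizer.encode(['ab', 'abc'])
#     tensor([[5, 1, 2, 0, 6],
#             [5, 1, 2, 3, 0]])
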

class CTCTokenizer(BaseTokenizer):
    BLANK = '[B]'

    def __init__(self, charset: str) -> None:
        # BLANK uses index == 0 by default
        super().__init__(charset, specials_first=(self.BLANK,))
        self.blank_id = self._stoi[self.BLANK]

    def encode(self, labels: List[str], device: Optional[torch.device] = None) -> Tensor:
        # We use a padded representation since we don't want to use CUDNN's CTC implementation
        batch = [torch.as_tensor(self._tok2ids(y), dtype=torch.long, device=device) for y in labels]
        return pad_sequence(batch, batch_first=True, padding_value=self.blank_id)

    def _filter(self, probs: Tensor, ids: Tensor) -> Tuple[Tensor, List[int]]:
        # Best path decoding:
        ids = [k for k, _ in groupby(ids.tolist())]  # Collapse runs of repeated tokens
        ids = [x for x in ids if x != self.blank_id]  # Remove BLANKs
        # `probs` is passed through as-is since all positions are considered part of the path
        return probs, ids
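

# Minimal usage sketch (not part of the original commit). The charset, labels,
# and frame ids below are illustrative assumptions, not values from the source.
if __name__ == '__main__':
    charset = 'abcdefghijklmnopqrstuvwxyz'

    # Autoregressive tokenizer: encode() wraps each label in BOS/EOS and pads
    # shorter sequences with PAD.
    tokenizer = Tokenizer(charset)
    targets = tokenizer.encode(['hello', 'hi'])
    print(targets.shape)  # torch.Size([2, 7]): len('hello') + BOS + EOS

    # decode() expects softmax distributions of shape (N, L, C). One-hot
    # distributions over the ids after BOS recover the original labels.
    dists = torch.nn.functional.one_hot(targets[:, 1:], num_classes=len(tokenizer)).float()
    labels, probs = tokenizer.decode(dists)
    print(labels)  # ['hello', 'hi']

    # CTC tokenizer: best-path decoding collapses repeats and drops BLANKs,
    # so the frame ids [h, h, BLANK, i] decode to 'hi'.
    ctc = CTCTokenizer(charset)
    frame_ids = torch.tensor([ctc._stoi['h'], ctc._stoi['h'], ctc.blank_id, ctc._stoi['i']])
    frame_probs = torch.full((4,), 0.9)
    _, ids = ctc._filter(frame_probs, frame_ids)
    print(ctc._ids2tok(ids))  # hi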