captcha_recognition / tokenizer.py
Nischay103's picture
Update tokenizer.py
ca86d15 verified
raw
history blame
4.18 kB
from abc import ABC, abstractmethod
from typing import List, Optional, Tuple, Union

import torch
from torch import Tensor
from torch.nn.utils.rnn import pad_sequence
class BaseTokenizer(ABC):
    """Abstract character-level tokenizer.

    Builds the id <-> token lookup tables and implements greedy and
    beam-search decoding of per-step score distributions. Subclasses
    provide ``encode`` and ``_filter``.
    """

    def __init__(self, charset: str, specials_first: tuple = (), specials_last: tuple = ()) -> None:
        """Build the vocabulary tables.

        Args:
            charset: Characters to recognise; each character becomes one token.
            specials_first: Special tokens placed before the charset.
            specials_last: Special tokens placed after the charset.
        """
        # NOTE(review): ``charset + '[UNK]'`` is a plain string, so ``tuple()``
        # splits '[UNK]' into five one-character tokens ('[', 'U', 'N', 'K', ']')
        # rather than one '[UNK]' token. A charset character that also appears
        # there is resolved by ``_stoi`` to the later, appended index. Left as-is
        # because ``len(self)`` (and presumably the trained model's output size)
        # depends on this exact layout -- confirm before changing.
        self._itos = specials_first + tuple(charset + '[UNK]') + specials_last
        # Later duplicates overwrite earlier ones: lookups give the last index.
        self._stoi = {s: i for i, s in enumerate(self._itos)}

    def __len__(self):
        """Return the vocabulary size (number of entries in ``_itos``)."""
        return len(self._itos)

    def _tok2ids(self, tokens: str) -> List[int]:
        """Map each character of ``tokens`` to its id (KeyError if unknown)."""
        return [self._stoi[s] for s in tokens]

    def _ids2tok(self, token_ids: List[int], join: bool = True) -> Union[str, List[str]]:
        """Map ids back to tokens.

        Returns the tokens concatenated into one string when ``join`` is true,
        otherwise the list of individual tokens. Integer tensor elements are
        accepted as ids too (they index the tuple via ``__index__``).
        """
        tokens = [self._itos[i] for i in token_ids]
        return ''.join(tokens) if join else tokens

    @abstractmethod
    def encode(self, labels: List[str], device: Optional[torch.device] = None) -> Tensor:
        """Encode a batch of label strings into a tensor of token ids."""
        raise NotImplementedError

    @abstractmethod
    def _filter(self, probs: Tensor, ids: Tensor) -> Tuple[Tensor, List[int]]:
        """Internal method which performs the necessary filtering prior to decoding.

        Returns the (possibly trimmed) probabilities and the ids as a list.
        """
        raise NotImplementedError

    def decode(self, token_dists: Tensor, beam_width: int = 1, raw: bool = False) -> Tuple[List[str], List[Tensor]]:
        """Decode a batch of per-step token distributions.

        Args:
            token_dists: Iterated as (batch, steps, vocab); the last dimension
                presumably has ``len(self)`` entries -- confirm with the caller.
            beam_width: Values > 1 select beam search, otherwise greedy decoding.
            raw: If true, skip ``_filter`` and return token lists, not strings.

        Returns:
            ``(tokens, probs)`` per sample, as produced by ``greedy_decode`` or
            ``beam_search_decode``.
        """
        if beam_width > 1:
            return self.beam_search_decode(token_dists, beam_width, raw)
        return self.greedy_decode(token_dists, raw)

    def greedy_decode(self, token_dists: Tensor, raw: bool = False) -> Tuple[List[str], List[Tensor]]:
        """Pick the highest-scoring token at every step of every sample.

        Returns, per sample, the decoded tokens (a string, or a token list when
        ``raw``) and the tensor of per-step maximum scores. Unless ``raw``,
        both are passed through ``_filter`` first.
        """
        batch_tokens = []
        batch_probs = []
        for dist in token_dists:
            probs, ids = dist.max(-1)
            if not raw:
                probs, ids = self._filter(probs, ids)
            batch_tokens.append(self._ids2tok(ids, not raw))
            batch_probs.append(probs)
        return batch_tokens, batch_probs

    def beam_search_decode(self, token_dists: Tensor, beam_width: int, raw: bool) -> Tuple[List[str], List[Tensor]]:
        """Beam-search decode each sample in ``token_dists``.

        A hypothesis is scored by the product of its chosen per-step scores;
        the best ``beam_width`` hypotheses survive each step.

        Returns:
            ``(tokens, scores)`` per sample. Each score is a Python float
            (not a Tensor, despite the annotation) that multiplies every step
            of the sample.

        NOTE(review): ``step_dist`` does not depend on the prefix being
        extended, so the top beam equals the per-step argmax (up to ties and
        float rounding). The score also includes steps after any EOS, whereas
        ``greedy_decode`` passes its per-step scores through ``_filter``.
        """
        batch_tokens = []
        batch_probs = []
        for dist in token_dists:
            beams = [([], 1.0)]
            for step_dist in dist:
                # The top-k choices are the same for every beam at this step, so
                # compute them once; clamp k so topk cannot exceed the vocabulary.
                k = min(beam_width, step_dist.size(-1))
                top_probs, top_ids = step_dist.topk(k)
                step_ids = top_ids.tolist()
                step_probs = top_probs.tolist()
                candidates = [
                    (seq + [tok_id], score * prob)
                    for seq, score in beams
                    for tok_id, prob in zip(step_ids, step_probs)
                ]
                # sorted() is stable, so ties keep candidate generation order.
                beams = sorted(candidates, key=lambda c: c[1], reverse=True)[:beam_width]
            best_sequence, best_score = beams[0]
            if not raw:
                # _filter expects tensors; the score is wrapped as a 1-element tensor.
                score_tensor, best_sequence = self._filter(
                    torch.tensor([best_score]), torch.tensor(best_sequence))
                best_score = score_tensor.item()
            batch_tokens.append(self._ids2tok(best_sequence, not raw))
            batch_probs.append(best_score)
        return batch_tokens, batch_probs
class Tokenizer(BaseTokenizer):
    """Character tokenizer with begin ([B]), end ([E]) and padding ([P]) markers.

    Vocabulary order: ``[E]`` first, then the entries ``BaseTokenizer``
    builds from the charset, then ``[B]`` and ``[P]``.
    """

    BOS = '[B]'
    EOS = '[E]'
    PAD = '[P]'

    def __init__(self, charset: str) -> None:
        """Create the vocabulary for ``charset`` and cache the special-token ids."""
        super().__init__(charset, (self.EOS,), (self.BOS, self.PAD))
        lookup = self._stoi
        self.eos_id = lookup[self.EOS]
        self.bos_id = lookup[self.BOS]
        self.pad_id = lookup[self.PAD]

    def encode(self, labels: List[str], device: Optional[torch.device] = None) -> Tensor:
        """Turn label strings into a padded batch of token ids.

        Each label becomes ``[B] + its characters + [E]`` as a long tensor;
        the rows are stacked batch-first and padded with ``pad_id``.
        Characters missing from the vocabulary raise KeyError.
        """
        rows = []
        for label in labels:
            ids = [self.bos_id, *self._tok2ids(label), self.eos_id]
            rows.append(torch.as_tensor(ids, dtype=torch.long, device=device))
        return pad_sequence(rows, batch_first=True, padding_value=self.pad_id)

    def _filter(self, probs: Tensor, ids: Tensor) -> Tuple[Tensor, List[int]]:
        """Cut a decoded sequence at its first end-of-sequence token.

        Returns the probabilities up to and including the first ``eos_id``
        position, and the ids strictly before it (everything if no EOS).
        """
        token_ids = ids.tolist()
        try:
            stop = token_ids.index(self.eos_id)
        except ValueError:
            # No EOS predicted: keep the whole sequence.
            stop = len(token_ids)
        # The EOS probability is kept, presumably so a confidence computed
        # from ``probs`` includes the end step.
        return probs[:stop + 1], token_ids[:stop]