# Spaces:
# Runtime error
# Runtime error
import re
from abc import ABC, abstractmethod
from itertools import groupby
from typing import List, Optional, Tuple, Union

import torch
from torch import Tensor
from torch.nn.utils.rnn import pad_sequence
class CharsetAdapter:
    """Transforms labels according to the target charset.

    Only letter case is normalized. The pattern that would match characters
    outside the charset is left disabled in ``__init__``, so such characters
    are returned unchanged.
    """

    def __init__(self, target_charset: str) -> None:
        """Record the charset and work out whether it is single-case.

        Args:
            target_charset: String containing every character of the target charset.
        """
        super().__init__()
        self.charset = target_charset
        # A charset with no cased characters (e.g. digits only) satisfies both
        # tests; __call__ checks `lowercase_only` first in that situation.
        self.lowercase_only = target_charset == target_charset.lower()
        self.uppercase_only = target_charset == target_charset.upper()
        # Disabled: pattern for characters outside the charset.
        # self.unsupported = f'[^{re.escape(target_charset)}]'

    def __call__(self, label: str) -> str:
        """Return `label` with its case folded to match the charset.

        Args:
            label: Text label to adapt.

        Returns:
            The lowercased label for a lowercase-only charset, the uppercased
            label for an uppercase-only charset, otherwise the label unchanged.
        """
        if self.lowercase_only:
            return label.lower()
        if self.uppercase_only:
            return label.upper()
        return label
class BaseTokenizer(ABC):
    """Base class holding the vocabulary tables and the greedy decoding loop.

    Subclasses supply `encode` and `_filter`; both raise `NotImplementedError` here.
    """

    def __init__(self, charset: str, specials_first: tuple = (), specials_last: tuple = ()) -> None:
        """Build the index-to-string and string-to-index tables.

        Args:
            charset: Characters of the vocabulary; each character becomes one entry.
            specials_first: Special tokens placed before the charset entries.
            specials_last: Special tokens placed after the charset entries.
        """
        # NOTE(review): tuple() over a str yields single characters, so '[UNK]'
        # adds the five entries '[', 'U', 'N', 'K', ']' rather than one token.
        # If `charset` already contains any of them, `_itos` holds duplicates
        # (still counted by __len__) and `_stoi` keeps the last index. Confirm
        # whether a single '[UNK]' token was intended; changing this would
        # alter the vocabulary size and token indices.
        self._itos = specials_first + tuple(charset + '[UNK]') + specials_last
        self._stoi = {s: i for i, s in enumerate(self._itos)}

    def __len__(self) -> int:
        """Return the number of vocabulary entries."""
        return len(self._itos)

    def _tok2ids(self, tokens: str) -> List[int]:
        """Map each character of `tokens` to its id.

        Raises:
            KeyError: If a character is not in the vocabulary.
        """
        return [self._stoi[s] for s in tokens]

    def _ids2tok(self, token_ids: List[int], join: bool = True) -> Union[str, List[str]]:
        """Map ids back to tokens.

        Args:
            token_ids: Ids to look up in the vocabulary.
            join: When True, concatenate the tokens into one string.

        Returns:
            A single string when `join` is True, otherwise the list of tokens.
        """
        tokens = [self._itos[i] for i in token_ids]
        return ''.join(tokens) if join else tokens

    def encode(self, labels: List[str], device: Optional[torch.device] = None) -> Tensor:
        """Encode a batch of labels to a representation suitable for the model.

        Args:
            labels: List of labels. Each can be of arbitrary length.
            device: Create tensor on this device.

        Returns:
            Batched tensor representation padded to the max label length. Shape: N, L
        """
        raise NotImplementedError

    def _filter(self, probs: Tensor, ids: Tensor) -> Tuple[Tensor, List[int]]:
        """Internal method which performs the necessary filtering prior to decoding."""
        raise NotImplementedError

    def decode(self, token_dists: Tensor, raw: bool = False) -> Tuple[List[Union[str, List[str]]], List[Tensor]]:
        """Decode a batch of token distributions.

        Args:
            token_dists: softmax probabilities over the token distribution. Shape: N, L, C
            raw: return unprocessed labels (will return list of list of strings)

        Returns:
            list of string labels (arbitrary length) and
            their corresponding sequence probabilities as a list of Tensors
        """
        batch_tokens = []
        batch_probs = []
        for dist in token_dists:
            probs, ids = dist.max(-1)  # greedy selection
            if raw:
                # Plain ints, so the vocabulary lookup does not index with tensors.
                ids = ids.tolist()
            else:
                probs, ids = self._filter(probs, ids)
            batch_tokens.append(self._ids2tok(ids, join=not raw))
            batch_probs.append(probs)
        return batch_tokens, batch_probs
class Tokenizer(BaseTokenizer):
    """Tokenizer that wraps each label in BOS/EOS and pads batches with PAD.

    EOS is placed before the charset entries; BOS and PAD are placed after them.
    """

    BOS = '[B]'
    EOS = '[E]'
    PAD = '[P]'

    def __init__(self, charset: str) -> None:
        """Build the vocabulary and cache the ids of the special tokens.

        Args:
            charset: Characters of the vocabulary.
        """
        specials_first = (self.EOS,)
        specials_last = (self.BOS, self.PAD)
        super().__init__(charset, specials_first, specials_last)
        self.eos_id = self._stoi[self.EOS]
        self.bos_id = self._stoi[self.BOS]
        self.pad_id = self._stoi[self.PAD]

    def encode(self, labels: List[str], device: Optional[torch.device] = None) -> Tensor:
        """Encode labels as `[BOS] + ids + [EOS]` rows padded with the PAD id.

        Args:
            labels: List of labels. Each can be of arbitrary length.
            device: Create tensors on this device.

        Returns:
            Long tensor padded (batch first) to the longest encoded label.
        """
        batch = [
            torch.as_tensor([self.bos_id] + self._tok2ids(y) + [self.eos_id], dtype=torch.long, device=device)
            for y in labels
        ]
        return pad_sequence(batch, batch_first=True, padding_value=self.pad_id)

    def _filter(self, probs: Tensor, ids: Tensor) -> Tuple[Tensor, List[int]]:
        """Cut the prediction at the first EOS.

        Args:
            probs: Per-position probabilities of the selected ids.
            ids: Selected token id per position.

        Returns:
            The probabilities up to and including the EOS position (all of
            them when no EOS is present) and the ids before the EOS as a list.
        """
        ids = ids.tolist()
        try:
            eos_idx = ids.index(self.eos_id)
        except ValueError:
            eos_idx = len(ids)  # Nothing to truncate.
        # Truncate after EOS
        ids = ids[:eos_idx]
        probs = probs[:eos_idx + 1]  # but include prob. for EOS (if it exists)
        return probs, ids
class CTCTokenizer(BaseTokenizer):
    """Tokenizer for CTC: a BLANK token precedes the charset entries."""

    BLANK = '[B]'

    def __init__(self, charset: str) -> None:
        """Build the vocabulary and cache the id of the BLANK token.

        Args:
            charset: Characters of the vocabulary.
        """
        # BLANK uses index == 0 by default
        super().__init__(charset, specials_first=(self.BLANK,))
        self.blank_id = self._stoi[self.BLANK]

    def encode(self, labels: List[str], device: Optional[torch.device] = None) -> Tensor:
        """Encode labels as id rows padded with the BLANK id.

        Args:
            labels: List of labels. Each can be of arbitrary length.
            device: Create tensors on this device.

        Returns:
            Long tensor padded (batch first) to the longest label.
        """
        # We use a padded representation since we don't want to use CUDNN's CTC implementation
        batch = [torch.as_tensor(self._tok2ids(y), dtype=torch.long, device=device) for y in labels]
        return pad_sequence(batch, batch_first=True, padding_value=self.blank_id)

    def _filter(self, probs: Tensor, ids: Tensor) -> Tuple[Tensor, List[int]]:
        """Best path decoding: collapse repeated ids, then drop BLANKs.

        Args:
            probs: Per-position probabilities of the selected ids.
            ids: Selected token id per position.

        Returns:
            `probs` unchanged and the collapsed, BLANK-free ids as a list
            (empty when `ids` is empty).
        """
        # Taking the groupby keys directly also handles an empty sequence,
        # where unpacking through zip(*...) and indexing [0] would fail.
        collapsed = [token_id for token_id, _ in groupby(ids.tolist())]  # Remove duplicate tokens
        ids = [x for x in collapsed if x != self.blank_id]  # Remove BLANKs
        # `probs` is just pass-through since all positions are considered part of the path
        return probs, ids