|
from typing import List, Optional, Union

import numpy as np
import torch
from transformers import ByT5Tokenizer

from surya.model.recognition.config import LANGUAGE_MAP, TOTAL_TOKENS, TOKEN_OFFSET
|
|
def text_to_utf16_numbers(text):
    """Encode text as UTF-16LE and return one integer per 16-bit code unit.

    Characters outside the Basic Multilingual Plane become surrogate pairs,
    so they contribute two numbers each.
    """
    utf16_bytes = text.encode('utf-16le')

    numbers = []
    for i in range(0, len(utf16_bytes), 2):
        # Combine each little-endian byte pair into one 16-bit value.
        number = utf16_bytes[i] + (utf16_bytes[i + 1] << 8)
        numbers.append(number)

    return numbers
|
|
def utf16_numbers_to_text(numbers):
    """Inverse of text_to_utf16_numbers: rebuild a string from 16-bit values."""
    byte_array = bytearray()
    for number in numbers:
        # Emit each value as a little-endian byte pair.
        byte_array.append(number & 0xFF)
        byte_array.append((number >> 8) & 0xFF)

    # Model output is not guaranteed to be valid UTF-16 (e.g. unpaired
    # surrogates), so undecodable sequences are dropped rather than raised.
    text = byte_array.decode('utf-16le', errors="ignore")
    return text
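
# Round-trip sketch (illustrative, not part of the module API): one number per
# UTF-16 code unit, so BMP characters map to a single value while anything at
# or above U+10000 becomes a surrogate pair.
#
#   >>> text_to_utf16_numbers("A")
#   [65]
#   >>> nums = text_to_utf16_numbers("\U0001F600")  # emoji, U+1F600
#   >>> nums == [0xD83D, 0xDE00] and utf16_numbers_to_text(nums) == "\U0001F600"
#   True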
|
|
def _tokenize(text: str, langs: List[str], eos_token_id: int = 1, add_eos: bool = True, add_bos: bool = True):
    # Text tokens are UTF-16 code units shifted past the special tokens.
    tokens = text_to_utf16_numbers(text)
    tokens = [t + TOKEN_OFFSET for t in tokens]

    # Language tokens live in their own range above all text tokens.
    lang_list = []
    for lang in langs:
        code = LANGUAGE_MAP[lang]
        lang_list.append(code + TOKEN_OFFSET + TOTAL_TOKENS)

    tokens = lang_list + tokens

    if add_eos:
        tokens.append(eos_token_id)
    if add_bos:
        # BOS and EOS share the same id (see Byt5LangTokenizer.bos_token).
        tokens.insert(0, eos_token_id)

    return tokens, lang_list
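
# Resulting layout (sketch; exact ids depend on LANGUAGE_MAP, TOKEN_OFFSET and
# TOTAL_TOKENS from the config, and "en" is an assumed key):
#
#   _tokenize("Hi", ["en"]) -> ([1, lang, ord("H") + TOKEN_OFFSET,
#                                ord("i") + TOKEN_OFFSET, 1], [lang])
#   where lang = LANGUAGE_MAP["en"] + TOKEN_OFFSET + TOTAL_TOKENS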
|
|
class Byt5LangTokenizer(ByT5Tokenizer):
    def __init__(self,
                 eos_token="</s>",
                 unk_token="<unk>",
                 pad_token="<pad>",
                 model_max_length=None,
                 **kwargs,
                 ):
        self.pad_token = pad_token
        self.eos_token = eos_token
        self.unk_token = unk_token
        self.bos_token = eos_token  # BOS is represented by the EOS token
        self.offset = TOKEN_OFFSET

        self.pad_id = 0
        self.eos_id = 1
        self.unk_id = 2

        self.model_max_length = model_max_length
        # Language tokens start immediately after the text-token range.
        self.special_token_start = TOKEN_OFFSET + TOTAL_TOKENS

        super().__init__()
|
    def __call__(self, texts: Union[List[str], str], langs: Union[List[List[str]], List[str]], pad_token_id: int = 0, **kwargs):
        tokenized = []
        all_langs = []

        # Normalize the inputs so a single text and its language list are
        # handled the same way as a batch.
        is_list = True
        if isinstance(texts, str):
            texts = [texts]
            is_list = False

        if isinstance(langs[0], str):
            langs = [langs]

        assert len(langs) == len(texts)

        for text, lang in zip(texts, langs):
            tokens, lang_list = _tokenize(text, lang)
            tokenized.append(tokens)
            all_langs.append(lang_list)

        # Unwrap single inputs back to unbatched outputs.
        if not is_list:
            tokenized = tokenized[0]
            all_langs = all_langs[0]

        return {"input_ids": tokenized, "langs": all_langs}
|
    def decode(
        self,
        token_ids: Union[int, List[int], "np.ndarray", "torch.Tensor"],
        skip_special_tokens: bool = False,
        clean_up_tokenization_spaces: Optional[bool] = None,
        **kwargs,
    ) -> str:
        if isinstance(token_ids, (np.ndarray, torch.Tensor)):
            token_ids = token_ids.tolist()

        # Drop special tokens (pad/bos/eos below TOKEN_OFFSET, language tokens
        # above) and shift the rest back down to raw UTF-16 code units.
        token_ids = [t for t in token_ids if TOKEN_OFFSET <= t < self.special_token_start]
        token_ids = [t - TOKEN_OFFSET for t in token_ids]
        text = utf16_numbers_to_text(token_ids)
        return text
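
# Minimal usage sketch (assumes LANGUAGE_MAP defines an "en" entry; substitute
# whatever language codes your config actually contains).
if __name__ == "__main__":
    tokenizer = Byt5LangTokenizer()
    encoded = tokenizer("Hello world", langs=["en"])
    print(encoded["input_ids"])                    # [1, <lang token>, <shifted utf16...>, 1]
    print(tokenizer.decode(encoded["input_ids"]))  # Hello world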
|
|