from itertools import chain
from typing import List, Union

import numpy as np
import torch
from transformers import ByT5Tokenizer

from surya.model.recognition.config import LANGUAGE_MAP, TOTAL_TOKENS, TOKEN_OFFSET
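# Overview (descriptive note): this tokenizer maps text directly to UTF-16 code units,
# shifted by TOKEN_OFFSET to leave room for the pad/eos/unk ids, and encodes each
# language as an extra token placed above TOKEN_OFFSET + TOTAL_TOKENS.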


def text_to_utf16_numbers(text):
    utf16_bytes = text.encode('utf-16le')  # Little-endian to simplify byte order handling
    numbers = []
    # Iterate through each pair of bytes and combine them into a single number
    for i in range(0, len(utf16_bytes), 2):
        # Combine two adjacent bytes into a single number
        number = utf16_bytes[i] + (utf16_bytes[i + 1] << 8)
        numbers.append(number)
    return numbers


def utf16_numbers_to_text(numbers):
    byte_array = bytearray()
    for number in numbers:
        # Extract the two bytes from the number and add them to the byte array
        byte_array.append(number & 0xFF)  # Lower byte
        byte_array.append((number >> 8) & 0xFF)  # Upper byte

    text = byte_array.decode('utf-16le', errors="ignore")
    return text
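# Example round trip (illustrative): "Hé" encodes to UTF-16 code units [72, 233], and
# utf16_numbers_to_text([72, 233]) returns "Hé" again. Characters outside the Basic
# Multilingual Plane take two code units (a surrogate pair), so a single character can
# map to two numbers here.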


def _tokenize(text: str, langs: List[str], eos_token_id: int = 1, add_eos: bool = True, add_bos: bool = True):
    tokens = text_to_utf16_numbers(text)
    tokens = [t + TOKEN_OFFSET for t in tokens]  # Account for the special pad/eos/unk tokens

    lang_list = []
    for lang in langs:
        code = LANGUAGE_MAP[lang]
        lang_list.append(code + TOKEN_OFFSET + TOTAL_TOKENS)

    tokens = lang_list + tokens

    if add_eos:
        tokens.append(eos_token_id)
    if add_bos:
        tokens.insert(0, eos_token_id)

    return tokens, lang_list
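# Resulting layout (sketch; exact ids depend on TOKEN_OFFSET and LANGUAGE_MAP in the
# surya config): [bos (reuses the eos id), language token(s), shifted UTF-16 code units, eos].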


class Byt5LangTokenizer(ByT5Tokenizer):
    def __init__(self,
                 eos_token="</s>",
                 unk_token="<unk>",
                 pad_token="<pad>",
                 model_max_length=None,
                 **kwargs,
                 ):
        self.pad_token = pad_token
        self.eos_token = eos_token
        self.unk_token = unk_token
        self.bos_token = eos_token
        self.offset = TOKEN_OFFSET

        self.pad_id = 0
        self.eos_id = 1
        self.unk_id = 2

        self.model_max_length = model_max_length
        self.special_token_start = TOKEN_OFFSET + TOTAL_TOKENS

        super().__init__()

    def __call__(self, texts: Union[List[str], str], langs: Union[List[List[str]], List[str]], pad_token_id: int = 0, **kwargs):
        tokenized = []
        all_langs = []

        is_list = True
        # Convert to list of lists format
        if isinstance(texts, str):
            texts = [texts]
            is_list = False

        if isinstance(langs[0], str):
            langs = [langs]

        # One language input per text input
        assert len(langs) == len(texts)

        for text, lang in zip(texts, langs):
            tokens, lang_list = _tokenize(text, lang)
            tokenized.append(tokens)
            all_langs.append(lang_list)

        # Convert back to flat format
        if not is_list:
            tokenized = tokenized[0]
            all_langs = all_langs[0]

        return {"input_ids": tokenized, "langs": all_langs}

    def decode(
        self,
        token_ids: Union[int, List[int], "np.ndarray", "torch.Tensor", "tf.Tensor"],
        skip_special_tokens: bool = False,
        clean_up_tokenization_spaces: bool = None,
        **kwargs,
    ) -> str:
        if isinstance(token_ids, (np.ndarray, torch.Tensor)):
            token_ids = token_ids.tolist()

        # Drop special and language tokens, then undo the offset before decoding
        token_ids = [t for t in token_ids if TOKEN_OFFSET <= t < self.special_token_start]
        token_ids = [t - TOKEN_OFFSET for t in token_ids]
        text = utf16_numbers_to_text(token_ids)
        return text
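

if __name__ == "__main__":
    # Minimal usage sketch (illustrative only; assumes "en" is a key in LANGUAGE_MAP).
    tokenizer = Byt5LangTokenizer()
    encoded = tokenizer("Hello world", langs=["en"])
    print(encoded["input_ids"])                    # [bos, language token, shifted UTF-16 units..., eos]
    print(tokenizer.decode(encoded["input_ids"]))  # "Hello world"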