from typing import List, Union
from transformers import ByT5Tokenizer
import numpy as np
import torch
from surya.model.recognition.config import LANGUAGE_MAP, TOTAL_TOKENS, TOKEN_OFFSET


def text_to_utf16_numbers(text):
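    """Encode text as UTF-16LE and pack each byte pair into one integer code unit (0-65535)."""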
    utf16_bytes = text.encode('utf-16le')  # Little-endian to simplify byte order handling

    numbers = []

    # Iterate through each pair of bytes and combine them into a single number
    for i in range(0, len(utf16_bytes), 2):
        # Combine two adjacent bytes into a single number
        number = utf16_bytes[i] + (utf16_bytes[i + 1] << 8)
        numbers.append(number)

    return numbers


def utf16_numbers_to_text(numbers):
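    """Reverse of text_to_utf16_numbers: rebuild the UTF-16LE byte stream and decode it, ignoring invalid sequences."""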
    byte_array = bytearray()
    for number in numbers:
        # Extract the two bytes from the number and add them to the byte array
        byte_array.append(number & 0xFF)         # Lower byte
        byte_array.append((number >> 8) & 0xFF)  # Upper byte

    text = byte_array.decode('utf-16le', errors="ignore")
    return text


def _tokenize(text: str, langs: List[str], eos_token_id: int = 1, add_eos: bool = True, add_bos: bool = True):
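    """Build token ids for one text: [bos] + language tokens + offset UTF-16 code units + [eos].

    The bos position reuses ``eos_token_id``. Returns the full token list and the language tokens.
    """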
    tokens = text_to_utf16_numbers(text)
    tokens = [t + TOKEN_OFFSET for t in tokens]  # Shift past the special tokens (pad, eos, etc.)

    lang_list = []
    for lang in langs:
        code = LANGUAGE_MAP[lang]
        lang_list.append(code + TOKEN_OFFSET + TOTAL_TOKENS)

    tokens = lang_list + tokens

    if add_eos:
        tokens.append(eos_token_id)
    if add_bos:
        tokens.insert(0, eos_token_id)

    return tokens, lang_list


class Byt5LangTokenizer(ByT5Tokenizer):
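    """Tokenizer that maps text to UTF-16 code units (shifted by TOKEN_OFFSET), with language tokens placed above TOTAL_TOKENS."""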
    def __init__(self,
        eos_token="</s>",
        unk_token="<unk>",
        pad_token="<pad>",
        model_max_length=None,
        **kwargs,
    ):
        self.pad_token = pad_token
        self.eos_token = eos_token
        self.unk_token = unk_token
        self.bos_token = eos_token
        self.offset = TOKEN_OFFSET

        self.pad_id = 0
        self.eos_id = 1
        self.unk_id = 2

        self.model_max_length = model_max_length
        self.special_token_start = TOKEN_OFFSET + TOTAL_TOKENS

        super().__init__()

    def __call__(self, texts: Union[List[str], str], langs: Union[List[List[str]], List[str]], pad_token_id: int = 0, **kwargs):
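        """Tokenize one or more texts, each with its list of language codes.

        Returns a dict with "input_ids" and "langs" as plain Python lists (nested if the
        input was a list of texts, flat for a single string).
        """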
        tokenized = []
        all_langs = []

        is_list = True
        # Convert to list of lists format
        if isinstance(texts, str):
            texts = [texts]
            is_list = False

        if isinstance(langs[0], str):
            langs = [langs]

        # One language input per text input
        assert len(langs) == len(texts)

        for text, lang in zip(texts, langs):
            tokens, lang_list = _tokenize(text, lang)
            tokenized.append(tokens)
            all_langs.append(lang_list)

        # Convert back to flat format
        if not is_list:
            tokenized = tokenized[0]
            all_langs = all_langs[0]

        return {"input_ids": tokenized, "langs": all_langs}

    def decode(
        self,
        token_ids: Union[int, List[int], "np.ndarray", "torch.Tensor"],
        skip_special_tokens: bool = False,
        clean_up_tokenization_spaces: bool = None,
        **kwargs,
    ) -> str:
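        """Drop special and language tokens, remove the offset, and decode the remaining UTF-16 code units to text."""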
        if isinstance(token_ids, (np.ndarray, torch.Tensor)):
            token_ids = token_ids.tolist()

        token_ids = [t for t in token_ids if TOKEN_OFFSET <= t < self.special_token_start]
        token_ids = [t - TOKEN_OFFSET for t in token_ids]
        text = utf16_numbers_to_text(token_ids)
        return text
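

# Minimal usage sketch (not part of the library): assumes LANGUAGE_MAP in
# surya.model.recognition.config contains an "en" entry; adjust to the actual keys.
if __name__ == "__main__":
    tokenizer = Byt5LangTokenizer()

    # Encode a single string with one language code; returns flat Python lists.
    encoded = tokenizer("Hello", langs=["en"])
    print(encoded["input_ids"])  # [1, <lang token>, ord('H') + TOKEN_OFFSET, ..., 1]

    # decode() strips the bos/eos and language tokens and undoes the offset.
    print(tokenizer.decode(encoded["input_ids"]))  # "Hello"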