from typing import List
import torch
from PIL import Image

from surya.input.processing import convert_if_not_rgb
from surya.postprocessing.math.latex import fix_math, contains_math
from surya.postprocessing.text import truncate_repetitions
from surya.settings import settings
from tqdm import tqdm
import numpy as np
import torch.nn.functional as F


def get_batch_size():
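    """Return the recognition batch size: the settings override if set, otherwise a per-device default."""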
    batch_size = settings.RECOGNITION_BATCH_SIZE
    if batch_size is None:
        batch_size = 32
        if settings.TORCH_DEVICE_MODEL == "mps":
            batch_size = 64 # 12GB RAM max
        if settings.TORCH_DEVICE_MODEL == "cuda":
            batch_size = 256
    return batch_size


def batch_recognition(images: List, languages: List[List[str]], model, processor, batch_size=None):
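    """Run batched OCR over a list of PIL images.

    Args:
        images: images to recognize text in.
        languages: per-image list of language codes, each at most settings.RECOGNITION_MAX_LANGS long.
        model: the recognition encoder-decoder model.
        processor: the matching processor (image preprocessing + tokenizer).
        batch_size: optional override; defaults to get_batch_size().

    Returns:
        A tuple (output_text, confidences) with one decoded string and one mean
        token confidence per input image.
    """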
    assert all([isinstance(image, Image.Image) for image in images])
    assert len(images) == len(languages)

    for langs in languages:
        assert len(langs) <= settings.RECOGNITION_MAX_LANGS, f"OCR only supports up to {settings.RECOGNITION_MAX_LANGS} languages per image, you passed {langs}."

    images = [image.convert("RGB") for image in images] # also copies the images
    if batch_size is None:
        batch_size = get_batch_size()

    output_text = []
    confidences = []

    dec_config = model.config.decoder
    layer_count = dec_config.decoder_layers
    kv_heads = dec_config.kv_heads
    head_dim = int(dec_config.d_model / dec_config.decoder_attention_heads)
    min_val = torch.finfo(model.dtype).min

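    # Masks here are additive: min_val effectively blocks attention to a position, 0 allows it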
    if settings.RECOGNITION_STATIC_CACHE:
        # We'll re-use these for all batches to avoid recopying
        kv_mask = torch.full((batch_size, 1, 1, settings.RECOGNITION_MAX_TOKENS + 1), min_val, dtype=model.dtype, device=model.device)
        kv_mask[:, :, :, -1] = 0  # the last slot holds the current token, which must attend to itself (mirrors the reset inside the loop)
        # The +1 accounts for the start token
        initial_attn_mask = torch.full((batch_size, 1, settings.RECOGNITION_MAX_LANGS + 1, settings.RECOGNITION_MAX_LANGS + 1), min_val, dtype=model.dtype, device=model.device)

        # Decoder self-attention KV cache:
        # layer_count x [2 (k/v), batch_size, kv_heads, max tokens, head_dim]
        decoder_cache = [torch.zeros((2, batch_size, kv_heads, settings.RECOGNITION_MAX_TOKENS, head_dim), dtype=model.dtype, device=model.device) for _ in range(layer_count)]

        # Prefill
        decoder_input = torch.zeros((batch_size, settings.RECOGNITION_MAX_LANGS + 1), dtype=torch.long, device=model.device)
    else:
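        # Dynamic cache path: masks start as all-zero (fully attendable) seeds and grow with the sequence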
        initial_kv_mask = torch.zeros((batch_size, 1, 1, 1), dtype=model.dtype, device=model.device)
        initial_attn_mask = torch.zeros((batch_size, 1, settings.RECOGNITION_MAX_LANGS + 1, settings.RECOGNITION_MAX_LANGS + 1), dtype=model.dtype, device=model.device)

    processed_batches = processor(text=[""] * len(images), images=images, lang=languages)

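    # Preprocessing ran once above; each iteration below just slices per-batch views out of processed_batches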
    for i in tqdm(range(0, len(images), batch_size), desc="Recognizing Text"):
        batch_langs = languages[i:i+batch_size]
        has_math = ["_math" in lang for lang in batch_langs]

        batch_pixel_values = processed_batches["pixel_values"][i:i+batch_size]
        batch_langs = processed_batches["langs"][i:i+batch_size]
        max_lang_len = max([len(lang) for lang in batch_langs])

        # Pad languages to max length if needed, to ensure we can convert to a tensor
        for lang_idx in range(len(batch_langs)):
            lang_len = len(batch_langs[lang_idx])
            if lang_len < max_lang_len:
                batch_langs[lang_idx] = [processor.tokenizer.pad_id] * (max_lang_len - lang_len) + batch_langs[lang_idx]

        batch_decoder_input = [[model.config.decoder_start_token_id] + lang for lang in batch_langs]
        current_batch_size = len(batch_pixel_values)

        batch_langs = torch.tensor(np.stack(batch_langs, axis=0), dtype=torch.long, device=model.device)
        batch_pixel_values = torch.tensor(np.stack(batch_pixel_values, axis=0), dtype=model.dtype, device=model.device)
        batch_decoder_input = torch.tensor(np.stack(batch_decoder_input, axis=0), dtype=torch.long, device=model.device)

        token_count = 0
        inference_token_count = batch_decoder_input.shape[-1]
        batch_predictions = [[] for _ in range(current_batch_size)]

        decoder_input_pad = torch.zeros((batch_size - current_batch_size, 1), dtype=torch.long, device=model.device)

        if settings.RECOGNITION_STATIC_CACHE:
            # Reset shared tensors
            if i > 0:
                # Decoder cache
                for layer_cache in decoder_cache:
                    layer_cache.fill_(0)

                # KV mask
                kv_mask.fill_(min_val)
                kv_mask[:, :, :, -1] = 0
                kv_mask[:, :, :, :inference_token_count] = 0

                # Attention mask
                initial_attn_mask.fill_(min_val)

                # Prefill
                decoder_input.fill_(0)

            # Prefill attention mask
            attention_mask = initial_attn_mask
            attention_mask[:, :, -inference_token_count:, -inference_token_count:] = 0

            # Prefill input
            decoder_input[:current_batch_size, -inference_token_count:] = batch_decoder_input
            batch_decoder_input = decoder_input

            # Pad to max batch size
            batch_langs = torch.cat([batch_langs, torch.zeros((batch_size - current_batch_size, batch_langs.shape[-1]), dtype=torch.long, device=model.device)], dim=0)
            batch_pixel_values = torch.cat([batch_pixel_values, torch.zeros((batch_size - current_batch_size,) + batch_pixel_values.shape[1:], dtype=model.dtype, device=model.device)], dim=0)
        else:
            # Select seed attention mask
            kv_mask = initial_kv_mask[:current_batch_size]
            kv_mask.fill_(0)

            # Select prefill attention mask
            attention_mask = initial_attn_mask[:current_batch_size, :, :inference_token_count, :inference_token_count]

            decoder_cache = [None] * layer_count

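        # The encoder runs only once per batch, inside the prefill step; its hidden states and
        # cross-attention KV cache are reused by every later decode step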
        encoder_outputs = None
        sequence_scores = None
        encoder_cache = [None] * layer_count
        all_done = torch.zeros(current_batch_size, dtype=torch.bool, device=model.device)

        with torch.no_grad(): # inference_mode doesn't work with torch.compile
            # Generation loop: the first iteration is the prefill (start + language tokens), later ones decode one token per step
            while token_count < settings.RECOGNITION_MAX_TOKENS:
                is_prefill = token_count == 0
                return_dict = model(
                    decoder_input_ids=batch_decoder_input,
                    decoder_attention_mask=attention_mask,
                    decoder_self_kv_cache=None if is_prefill else decoder_cache,
                    decoder_cross_kv_cache=None if is_prefill else encoder_cache,
                    decoder_past_token_count=token_count,
                    decoder_langs=batch_langs,
                    pixel_values=batch_pixel_values,
                    encoder_outputs=encoder_outputs,
                    return_dict=True,
                )

                logits = return_dict["logits"][:current_batch_size] # Ignore batch padding
                preds = torch.argmax(logits[:, -1], dim=-1)
                scores = torch.max(F.softmax(logits, dim=-1), dim=-1).values
                done = (preds == processor.tokenizer.eos_id) | (preds == processor.tokenizer.pad_id)
                all_done = all_done | done

                # Zero out scores for finished sequences so they don't inflate the confidence mean
                scores[all_done] = 0

                if is_prefill:
                    sequence_scores = scores
                    encoder_outputs = (return_dict["encoder_last_hidden_state"],)
                else:
                    sequence_scores = torch.cat([sequence_scores, scores], dim=1)

                if all_done.all():
                    break

                past_key_values = return_dict["past_key_values"]
                token_range = torch.arange(token_count, token_count + inference_token_count, device=model.device)

                for layer_idx, layer in enumerate(past_key_values):
                    if is_prefill:
                        encoder_cache[layer_idx] = layer[1]

                    if settings.RECOGNITION_STATIC_CACHE:
                        # Fill in entries in static kv cache
                        decoder_cache[layer_idx][:, :, :, token_range, :] = layer[0][:, :, :, -inference_token_count:, :]
                    else:
                        # Cat to generate new kv cache including current tokens
                        if is_prefill:
                            decoder_cache[layer_idx] = layer[0]
                        else:
                            decoder_cache[layer_idx] = torch.cat([decoder_cache[layer_idx], layer[0]], dim=3)

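                # Each subsequent step feeds only the newly predicted token; earlier context lives in the KV cache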
                batch_decoder_input = preds.unsqueeze(1)
                if settings.RECOGNITION_STATIC_CACHE:
                    # Setup new attention mask and input token
                    kv_mask[:, :, :, token_count:(token_count + inference_token_count)] = 0
                    batch_decoder_input = torch.cat([batch_decoder_input, decoder_input_pad], dim=0) # Pad to full batch
                else:
                    kv_mask = torch.cat([kv_mask, torch.zeros((current_batch_size, 1, 1, inference_token_count), dtype=model.dtype, device=model.device)], dim=-1)

                attention_mask = kv_mask

                for j, (pred, status) in enumerate(zip(preds, all_done)):
                    if not status:
                        batch_predictions[j].append(int(pred))

                token_count += inference_token_count
                inference_token_count = batch_decoder_input.shape[-1]

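        # Per-image confidence: mean of the max softmax probabilities over generated tokens, excluding zeroed (finished) slots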
        sequence_scores = torch.sum(sequence_scores, dim=-1) / torch.sum(sequence_scores != 0, dim=-1)
        detected_text = processor.tokenizer.batch_decode(batch_predictions)
        detected_text = [truncate_repetitions(dt) for dt in detected_text]

        # Postprocess to fix LaTeX output (add $$ signs, etc)
        detected_text = [fix_math(text) if math and contains_math(text) else text for text, math in zip(detected_text, has_math)]
        output_text.extend(detected_text)
        confidences.extend(sequence_scores.tolist())

    return output_text, confidences