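"""Multi-model OCR demo for Hugging Face Spaces.

Loads up to three vision-language OCR models (Chandra-OCR, Nanonets-OCR2-3B,
olmOCR-2-7B-1025) and serves a Gradio UI that streams extracted text as it
is generated.
"""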
import os
import sys
import time
import torch
import spaces
import warnings
from threading import Thread
from PIL import Image
from transformers import (
    AutoProcessor,
    Qwen2_5_VLForConditionalGeneration,
    TextIteratorStreamer
)

# Suppress the warning about uninitialized weights
warnings.filterwarnings('ignore', message='Some weights.*were not initialized')

# Try importing Qwen3VL if available
try:
    from transformers import Qwen3VLForConditionalGeneration
except ImportError:
    Qwen3VLForConditionalGeneration = None
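
# Tunable generation limits and cache location; the environment variables let a
# deployment override these without code changes (e.g. MAX_INPUT_TOKEN_LENGTH=8192).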
MAX_MAX_NEW_TOKENS = 4096
DEFAULT_MAX_NEW_TOKENS = 2048
MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))
CACHE_DIR = os.getenv("HF_CACHE_DIR", "./models")

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

print(f"Initial Device: {device}")
print(f"CUDA Available: {torch.cuda.is_available()}")

# Load Chandra-OCR
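# (All three models load in float16 with device_map="auto"; if a load fails,
# the model is left as None and the UI omits it from the dropdown.)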
try:
    MODEL_ID_V = "datalab-to/chandra"
    processor_v = AutoProcessor.from_pretrained(MODEL_ID_V, trust_remote_code=True)
    if Qwen3VLForConditionalGeneration:
        model_v = Qwen3VLForConditionalGeneration.from_pretrained(
            MODEL_ID_V,
            trust_remote_code=True,
            torch_dtype=torch.float16,
            device_map="auto"
        ).eval()
        print("✓ Chandra-OCR loaded")
    else:
        model_v = None
        print("✗ Chandra-OCR: Qwen3VL not available")
except Exception as e:
    model_v = None
    processor_v = None
    print(f"✗ Chandra-OCR: Failed to load - {str(e)}")

# Load Nanonets-OCR2-3B
try:
    MODEL_ID_X = "nanonets/Nanonets-OCR2-3B"
    processor_x = AutoProcessor.from_pretrained(MODEL_ID_X, trust_remote_code=True)
    model_x = Qwen2_5_VLForConditionalGeneration.from_pretrained(
        MODEL_ID_X,
        trust_remote_code=True,
        torch_dtype=torch.float16,
        device_map="auto"
    ).eval()
    print("✓ Nanonets-OCR2-3B loaded")
except Exception as e:
    model_x = None
    processor_x = None
    print(f"✗ Nanonets-OCR2-3B: Failed to load - {str(e)}")

# Load olmOCR-2-7B-1025
try:
    MODEL_ID_M = "allenai/olmOCR-2-7B-1025"
    processor_m = AutoProcessor.from_pretrained(MODEL_ID_M, trust_remote_code=True)
    model_m = Qwen2_5_VLForConditionalGeneration.from_pretrained(
        MODEL_ID_M,
        trust_remote_code=True,
        torch_dtype=torch.float16,
        device_map="auto"
    ).eval()
    print("✓ olmOCR-2-7B-1025 loaded")
except Exception as e:
    model_m = None
    processor_m = None
    print(f"✗ olmOCR-2-7B-1025: Failed to load - {str(e)}")


@spaces.GPU
def generate_image(model_name: str, text: str, image: Image.Image,
                   max_new_tokens: int, temperature: float, top_p: float,
                   top_k: int, repetition_penalty: float):
    """
    Generates responses using the selected model for image input.
    Yields raw text and Markdown-formatted text.
    This function is decorated with @spaces.GPU to ensure it runs on GPU
    when available in Hugging Face Spaces.
    Args:
        model_name: Name of the OCR model to use
        text: Prompt text for the model
        image: PIL Image object to process
        max_new_tokens: Maximum number of tokens to generate
        temperature: Sampling temperature
        top_p: Nucleus sampling parameter
        top_k: Top-k sampling parameter
        repetition_penalty: Penalty for repeating tokens
    Yields:
        tuple: (raw_text, markdown_text), the same cumulative buffer for
        the plain-text and Markdown output components
    """
    # Device will be cuda when @spaces.GPU decorator activates
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")


    # Select model and processor based on model_name
    if model_name == "olmOCR-2-7B-1025":
        if model_m is None:
            yield "olmOCR-2-7B-1025 is not available.", "olmOCR-2-7B-1025 is not available."
            return
        processor = processor_m
        model = model_m
    elif model_name == "Nanonets-OCR2-3B":
        if model_x is None:
            yield "Nanonets-OCR2-3B is not available.", "Nanonets-OCR2-3B is not available."
            return
        processor = processor_x
        model = model_x
    elif model_name == "Chandra-OCR":
        if model_v is None:
            yield "Chandra-OCR is not available.", "Chandra-OCR is not available."
            return
        processor = processor_v
        model = model_v
    else:
        yield "Invalid model selected.", "Invalid model selected."
        return

    if image is None:
        yield "Please upload an image.", "Please upload an image."
        return


    try:
        # Prepare messages in chat format
        messages = [{
            "role": "user",
            "content": [
                {"type": "image"},
                {"type": "text", "text": text},
            ]
        }]


        # Apply chat template with fallback
        try:
            prompt_full = processor.apply_chat_template(
                messages, 
                tokenize=False, 
                add_generation_prompt=True
            )
        except Exception as template_error:
            # Fallback: create a simple prompt without chat template
            print(f"Chat template error: {template_error}. Using fallback prompt.")
            prompt_full = text

        # Process inputs
        inputs = processor(
            text=[prompt_full],
            images=[image],
            return_tensors="pt",
            padding=True
        ).to(device)

        # Setup streaming generation
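        # skip_prompt drops the echoed input tokens; skip_special_tokens strips
        # markers like <|im_end|> from the decoded stream.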
        streamer = TextIteratorStreamer(
            processor.tokenizer if hasattr(processor, 'tokenizer') else processor, 
            skip_prompt=True, 
            skip_special_tokens=True
        )


        generation_kwargs = {
            **inputs,
            "streamer": streamer,
            "max_new_tokens": max_new_tokens,
            "do_sample": True,
            "temperature": temperature,
            "top_p": top_p,
            "top_k": top_k,
            "repetition_penalty": repetition_penalty,
        }


        # Start generation in separate thread
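        # model.generate() blocks until completion, so it runs in a worker
        # thread while this thread drains the streamer below.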
        thread = Thread(target=model.generate, kwargs=generation_kwargs)
        thread.start()


        # Stream the results
        buffer = ""
        for new_text in streamer:
            buffer += new_text
            buffer = buffer.replace("<|im_end|>", "")
            time.sleep(0.01)
            yield buffer, buffer


        # Ensure thread completes
        thread.join()


    except Exception as e:
        error_msg = f"Error during generation: {str(e)}"
        print(f"Full error: {e}")
        import traceback
        traceback.print_exc()
        yield error_msg, error_msg
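
# Programmatic usage sketch (hypothetical image path; bypasses the UI):
#
#     img = Image.open("sample.png")
#     out = ""
#     for raw, _md in generate_image("Nanonets-OCR2-3B",
#                                    "Extract all text from this image.",
#                                    img, 512, 0.7, 0.9, 50, 1.1):
#         out = raw  # each yield is the cumulative transcription so far
#     print(out)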


# Example usage for Gradio interface
if __name__ == "__main__":
    import gradio as gr


    # Determine available models
    available_models = []
    if model_m is not None:
        available_models.append("olmOCR-2-7B-1025")
        print("  Added: olmOCR-2-7B-1025")
    if model_x is not None:
        available_models.append("Nanonets-OCR2-3B")
        print("  Added: Nanonets-OCR2-3B")
    if model_v is not None:
        available_models.append("Chandra-OCR")
        print("  Added: Chandra-OCR")
    if not available_models:
        print("ERROR: No models were loaded successfully!")
        sys.exit(1)


    print(f"\n✓ Available models for dropdown: {', '.join(available_models)}")


    with gr.Blocks(title="Multi-Model OCR") as demo:
        gr.Markdown("# 🔍 Multi-Model OCR Application")
        gr.Markdown("Upload an image and select a model to extract text. Models run on GPU via Hugging Face Spaces.")


        with gr.Row():
            with gr.Column():
                model_selector = gr.Dropdown(
                    choices=available_models,
                    value=available_models[0] if available_models else None,
                    label="Select OCR Model"
                )
                image_input = gr.Image(type="pil", label="Upload Image")
                text_input = gr.Textbox(
                    value="Extract all text from this image.",
                    label="Prompt",
                    lines=2
                )


                with gr.Accordion("Advanced Settings", open=False):
                    max_tokens = gr.Slider(
                        minimum=1,
                        maximum=MAX_MAX_NEW_TOKENS,
                        value=DEFAULT_MAX_NEW_TOKENS,
                        step=1,
                        label="Max New Tokens"
                    )
                    temperature = gr.Slider(
                        minimum=0.1,
                        maximum=2.0,
                        value=0.7,
                        step=0.1,
                        label="Temperature"
                    )
                    top_p = gr.Slider(
                        minimum=0.0,
                        maximum=1.0,
                        value=0.9,
                        step=0.05,
                        label="Top P"
                    )
                    top_k = gr.Slider(
                        minimum=1,
                        maximum=100,
                        value=50,
                        step=1,
                        label="Top K"
                    )
                    repetition_penalty = gr.Slider(
                        minimum=1.0,
                        maximum=2.0,
                        value=1.1,
                        step=0.1,
                        label="Repetition Penalty"
                    )


                submit_btn = gr.Button("Extract Text", variant="primary")


            with gr.Column():
                output_text = gr.Textbox(label="Extracted Text", lines=20)
                output_markdown = gr.Markdown(label="Formatted Output")


        gr.Markdown("""
        ### Available Models:
        - **olmOCR-2-7B-1025**: Allen AI's OCR model
        - **Nanonets-OCR2-3B**: Nanonets OCR model
        - **Chandra-OCR**: Datalab OCR model
        """)


        submit_btn.click(
            fn=generate_image,
            inputs=[
                model_selector,
                text_input,
                image_input,
                max_tokens,
                temperature,
                top_p,
                top_k,
                repetition_penalty
            ],
            outputs=[output_text, output_markdown]
        )
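        # generate_image is a generator, so Gradio streams each yielded
        # (raw, markdown) pair into the outputs as tokens arrive.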


    # share=True exposes a public tunnel for local runs; Hugging Face Spaces
    # serves the app directly and does not need it.
    demo.launch(share=True)