IFMedTechdemo commited on
Commit
962d22d
·
verified ·
1 Parent(s): ccbebae

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +198 -112
app.py CHANGED
@@ -1,6 +1,12 @@
 
 
 
 
 
1
  import os
2
  import time
3
  import torch
 
4
  from threading import Thread
5
  from PIL import Image
6
  from transformers import (
@@ -23,55 +29,89 @@ MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))
23
 
24
  device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
25
 
 
 
 
26
  # Load Chandra-OCR
27
- MODEL_ID_V = "datalab-to/chandra"
28
- processor_v = AutoProcessor.from_pretrained(MODEL_ID_V, trust_remote_code=True)
29
- if Qwen3VLForConditionalGeneration:
30
- model_v = Qwen3VLForConditionalGeneration.from_pretrained(
31
- MODEL_ID_V,
32
- trust_remote_code=True,
33
- torch_dtype=torch.float16
34
- ).to(device).eval()
35
- else:
 
 
 
 
 
36
  model_v = None
 
 
37
 
38
  # Load Nanonets-OCR2-3B
39
- MODEL_ID_X = "nanonets/Nanonets-OCR2-3B"
40
- processor_x = AutoProcessor.from_pretrained(MODEL_ID_X, trust_remote_code=True)
41
- model_x = Qwen2_5_VLForConditionalGeneration.from_pretrained(
42
- MODEL_ID_X,
43
- trust_remote_code=True,
44
- torch_dtype=torch.float16
45
- ).to(device).eval()
 
 
 
 
 
 
46
 
47
- # Load Dots.OCR from the local, patched directory
48
- MODEL_PATH_D = "strangervisionhf/dots.ocr-base-fix"
49
- processor_d = AutoProcessor.from_pretrained(MODEL_PATH_D, trust_remote_code=True)
50
- model_d = AutoModelForCausalLM.from_pretrained(
51
- MODEL_PATH_D,
52
- attn_implementation="flash_attention_2",
53
- torch_dtype=torch.bfloat16,
54
- device_map="auto",
55
- trust_remote_code=True
56
- ).eval()
 
 
 
 
 
57
 
58
  # Load olmOCR-2-7B-1025
59
- MODEL_ID_M = "allenai/olmOCR-2-7B-1025"
60
- processor_m = AutoProcessor.from_pretrained(MODEL_ID_M, trust_remote_code=True)
61
- model_m = Qwen2_5_VLForConditionalGeneration.from_pretrained(
62
- MODEL_ID_M,
63
- trust_remote_code=True,
64
- torch_dtype=torch.float16
65
- ).to(device).eval()
 
 
 
 
 
 
66
 
67
  # Load DeepSeek-OCR
68
- MODEL_ID_DS = "deepseek-ai/deepseek-ocr"
69
- processor_ds = AutoProcessor.from_pretrained(MODEL_ID_DS, trust_remote_code=True)
70
- model_ds = Qwen2_5_VLForConditionalGeneration.from_pretrained(
71
- MODEL_ID_DS,
72
- trust_remote_code=True,
73
- torch_dtype=torch.float16
74
- ).to(device).eval()
 
 
 
 
 
 
 
75
 
76
  @spaces.GPU
77
  def generate_image(model_name: str, text: str, image: Image.Image,
@@ -81,6 +121,9 @@ def generate_image(model_name: str, text: str, image: Image.Image,
81
  Generates responses using the selected model for image input.
82
  Yields raw text and Markdown-formatted text.
83
 
 
 
 
84
  Args:
85
  model_name: Name of the OCR model to use
86
  text: Prompt text for the model
@@ -94,25 +137,40 @@ def generate_image(model_name: str, text: str, image: Image.Image,
94
  Yields:
95
  tuple: (raw_text, markdown_text)
96
  """
 
 
 
97
  # Select model and processor based on model_name
98
  if model_name == "olmOCR-2-7B-1025":
 
 
 
99
  processor = processor_m
100
- model = model_m
101
  elif model_name == "Nanonets-OCR2-3B":
 
 
 
102
  processor = processor_x
103
- model = model_x
104
  elif model_name == "Chandra-OCR":
105
  if model_v is None:
106
- yield "Chandra-OCR model not available.", "Chandra-OCR model not available."
107
  return
108
  processor = processor_v
109
- model = model_v
110
  elif model_name == "Dots.OCR":
 
 
 
111
  processor = processor_d
112
- model = model_d
113
  elif model_name == "DeepSeek-OCR":
 
 
 
114
  processor = processor_ds
115
- model = model_ds
116
  else:
117
  yield "Invalid model selected.", "Invalid model selected."
118
  return
@@ -121,89 +179,108 @@ def generate_image(model_name: str, text: str, image: Image.Image,
121
  yield "Please upload an image.", "Please upload an image."
122
  return
123
 
124
- # Prepare messages in chat format
125
- messages = [{
126
- "role": "user",
127
- "content": [
128
- {"type": "image"},
129
- {"type": "text", "text": text},
130
- ]
131
- }]
132
-
133
- # Apply chat template
134
- prompt_full = processor.apply_chat_template(
135
- messages,
136
- tokenize=False,
137
- add_generation_prompt=True
138
- )
 
139
 
140
- # Process inputs
141
- inputs = processor(
142
- text=[prompt_full],
143
- images=[image],
144
- return_tensors="pt",
145
- padding=True
146
- ).to(device)
147
 
148
- # Setup streaming generation
149
- streamer = TextIteratorStreamer(
150
- processor,
151
- skip_prompt=True,
152
- skip_special_tokens=True
153
- )
154
-
155
- generation_kwargs = {
156
- **inputs,
157
- "streamer": streamer,
158
- "max_new_tokens": max_new_tokens,
159
- "do_sample": True,
160
- "temperature": temperature,
161
- "top_p": top_p,
162
- "top_k": top_k,
163
- "repetition_penalty": repetition_penalty,
164
- }
165
-
166
- # Start generation in separate thread
167
- thread = Thread(target=model.generate, kwargs=generation_kwargs)
168
- thread.start()
169
-
170
- # Stream the results
171
- buffer = ""
172
- for new_text in streamer:
173
- buffer += new_text
174
- buffer = buffer.replace("<|im_end|>", "")
175
- time.sleep(0.01)
176
- yield buffer, buffer
177
-
178
- # Ensure thread completes
179
- thread.join()
 
 
 
 
180
 
181
 
182
  # Example usage for Gradio interface
183
  if __name__ == "__main__":
184
  import gradio as gr
185
 
186
- with gr.Blocks() as demo:
187
- gr.Markdown("# Multi-Model OCR Application")
188
- gr.Markdown("Upload an image and select a model to extract text")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
189
 
190
  with gr.Row():
191
  with gr.Column():
192
  model_selector = gr.Dropdown(
193
- choices=[
194
- "olmOCR-2-7B-1025",
195
- "Nanonets-OCR2-3B",
196
- "Chandra-OCR",
197
- "Dots.OCR",
198
- "DeepSeek-OCR"
199
- ],
200
- value="DeepSeek-OCR",
201
  label="Select OCR Model"
202
  )
203
  image_input = gr.Image(type="pil", label="Upload Image")
204
  text_input = gr.Textbox(
205
  value="Extract all text from this image.",
206
- label="Prompt"
 
207
  )
208
 
209
  with gr.Accordion("Advanced Settings", open=False):
@@ -249,6 +326,15 @@ if __name__ == "__main__":
249
  output_text = gr.Textbox(label="Extracted Text", lines=20)
250
  output_markdown = gr.Markdown(label="Formatted Output")
251
 
 
 
 
 
 
 
 
 
 
252
  submit_btn.click(
253
  fn=generate_image,
254
  inputs=[
 
1
+ """
2
+ OCR Application with Multiple Models including DeepSeek OCR
3
+ Fixed version with @spaces.GPU decorator for Hugging Face Spaces
4
+ """
5
+
6
  import os
7
  import time
8
  import torch
9
+ import spaces
10
  from threading import Thread
11
  from PIL import Image
12
  from transformers import (
 
29
 
30
  device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
31
 
32
+ print(f"Initial Device: {device}")
33
+ print(f"CUDA Available: {torch.cuda.is_available()}")
34
+
35
  # Load Chandra-OCR
36
+ try:
37
+ MODEL_ID_V = "datalab-to/chandra"
38
+ processor_v = AutoProcessor.from_pretrained(MODEL_ID_V, trust_remote_code=True)
39
+ if Qwen3VLForConditionalGeneration:
40
+ model_v = Qwen3VLForConditionalGeneration.from_pretrained(
41
+ MODEL_ID_V,
42
+ trust_remote_code=True,
43
+ torch_dtype=torch.float16
44
+ ).eval()
45
+ print("✓ Chandra-OCR loaded")
46
+ else:
47
+ model_v = None
48
+ print("✗ Chandra-OCR: Qwen3VL not available")
49
+ except Exception as e:
50
  model_v = None
51
+ processor_v = None
52
+ print(f"✗ Chandra-OCR: Failed to load - {str(e)}")
53
 
54
  # Load Nanonets-OCR2-3B
55
+ try:
56
+ MODEL_ID_X = "nanonets/Nanonets-OCR2-3B"
57
+ processor_x = AutoProcessor.from_pretrained(MODEL_ID_X, trust_remote_code=True)
58
+ model_x = Qwen2_5_VLForConditionalGeneration.from_pretrained(
59
+ MODEL_ID_X,
60
+ trust_remote_code=True,
61
+ torch_dtype=torch.float16
62
+ ).eval()
63
+ print("✓ Nanonets-OCR2-3B loaded")
64
+ except Exception as e:
65
+ model_x = None
66
+ processor_x = None
67
+ print(f"✗ Nanonets-OCR2-3B: Failed to load - {str(e)}")
68
 
69
+ # Load Dots.OCR - will be moved to GPU when needed
70
+ try:
71
+ MODEL_PATH_D = "strangervisionhf/dots.ocr-base-fix"
72
+ processor_d = AutoProcessor.from_pretrained(MODEL_PATH_D, trust_remote_code=True)
73
+ model_d = AutoModelForCausalLM.from_pretrained(
74
+ MODEL_PATH_D,
75
+ attn_implementation="flash_attention_2",
76
+ torch_dtype=torch.bfloat16,
77
+ trust_remote_code=True
78
+ ).eval()
79
+ print("✓ Dots.OCR loaded")
80
+ except Exception as e:
81
+ model_d = None
82
+ processor_d = None
83
+ print(f"✗ Dots.OCR: Failed to load - {str(e)}")
84
 
85
  # Load olmOCR-2-7B-1025
86
+ try:
87
+ MODEL_ID_M = "allenai/olmOCR-2-7B-1025"
88
+ processor_m = AutoProcessor.from_pretrained(MODEL_ID_M, trust_remote_code=True)
89
+ model_m = Qwen2_5_VLForConditionalGeneration.from_pretrained(
90
+ MODEL_ID_M,
91
+ trust_remote_code=True,
92
+ torch_dtype=torch.float16
93
+ ).eval()
94
+ print("✓ olmOCR-2-7B-1025 loaded")
95
+ except Exception as e:
96
+ model_m = None
97
+ processor_m = None
98
+ print(f"✗ olmOCR-2-7B-1025: Failed to load - {str(e)}")
99
 
100
  # Load DeepSeek-OCR
101
+ try:
102
+ MODEL_ID_DS = "deepseek-ai/deepseek-ocr"
103
+ processor_ds = AutoProcessor.from_pretrained(MODEL_ID_DS, trust_remote_code=True)
104
+ model_ds = Qwen2_5_VLForConditionalGeneration.from_pretrained(
105
+ MODEL_ID_DS,
106
+ trust_remote_code=True,
107
+ torch_dtype=torch.float16
108
+ ).eval()
109
+ print("✓ DeepSeek-OCR loaded")
110
+ except Exception as e:
111
+ model_ds = None
112
+ processor_ds = None
113
+ print(f"✗ DeepSeek-OCR: Failed to load - {str(e)}")
114
+
115
 
116
  @spaces.GPU
117
  def generate_image(model_name: str, text: str, image: Image.Image,
 
121
  Generates responses using the selected model for image input.
122
  Yields raw text and Markdown-formatted text.
123
 
124
+ This function is decorated with @spaces.GPU to ensure it runs on GPU
125
+ when available in Hugging Face Spaces.
126
+
127
  Args:
128
  model_name: Name of the OCR model to use
129
  text: Prompt text for the model
 
137
  Yields:
138
  tuple: (raw_text, markdown_text)
139
  """
140
+ # Device will be cuda when @spaces.GPU decorator activates
141
+ device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
142
+
143
  # Select model and processor based on model_name
144
  if model_name == "olmOCR-2-7B-1025":
145
+ if model_m is None:
146
+ yield "olmOCR-2-7B-1025 is not available.", "olmOCR-2-7B-1025 is not available."
147
+ return
148
  processor = processor_m
149
+ model = model_m.to(device)
150
  elif model_name == "Nanonets-OCR2-3B":
151
+ if model_x is None:
152
+ yield "Nanonets-OCR2-3B is not available.", "Nanonets-OCR2-3B is not available."
153
+ return
154
  processor = processor_x
155
+ model = model_x.to(device)
156
  elif model_name == "Chandra-OCR":
157
  if model_v is None:
158
+ yield "Chandra-OCR is not available.", "Chandra-OCR is not available."
159
  return
160
  processor = processor_v
161
+ model = model_v.to(device)
162
  elif model_name == "Dots.OCR":
163
+ if model_d is None:
164
+ yield "Dots.OCR is not available.", "Dots.OCR is not available."
165
+ return
166
  processor = processor_d
167
+ model = model_d.to(device)
168
  elif model_name == "DeepSeek-OCR":
169
+ if model_ds is None:
170
+ yield "DeepSeek-OCR is not available.", "DeepSeek-OCR is not available."
171
+ return
172
  processor = processor_ds
173
+ model = model_ds.to(device)
174
  else:
175
  yield "Invalid model selected.", "Invalid model selected."
176
  return
 
179
  yield "Please upload an image.", "Please upload an image."
180
  return
181
 
182
+ try:
183
+ # Prepare messages in chat format
184
+ messages = [{
185
+ "role": "user",
186
+ "content": [
187
+ {"type": "image"},
188
+ {"type": "text", "text": text},
189
+ ]
190
+ }]
191
+
192
+ # Apply chat template
193
+ prompt_full = processor.apply_chat_template(
194
+ messages,
195
+ tokenize=False,
196
+ add_generation_prompt=True
197
+ )
198
 
199
+ # Process inputs
200
+ inputs = processor(
201
+ text=[prompt_full],
202
+ images=[image],
203
+ return_tensors="pt",
204
+ padding=True
205
+ ).to(device)
206
 
207
+ # Setup streaming generation
208
+ streamer = TextIteratorStreamer(
209
+ processor,
210
+ skip_prompt=True,
211
+ skip_special_tokens=True
212
+ )
213
+
214
+ generation_kwargs = {
215
+ **inputs,
216
+ "streamer": streamer,
217
+ "max_new_tokens": max_new_tokens,
218
+ "do_sample": True,
219
+ "temperature": temperature,
220
+ "top_p": top_p,
221
+ "top_k": top_k,
222
+ "repetition_penalty": repetition_penalty,
223
+ }
224
+
225
+ # Start generation in separate thread
226
+ thread = Thread(target=model.generate, kwargs=generation_kwargs)
227
+ thread.start()
228
+
229
+ # Stream the results
230
+ buffer = ""
231
+ for new_text in streamer:
232
+ buffer += new_text
233
+ buffer = buffer.replace("<|im_end|>", "")
234
+ time.sleep(0.01)
235
+ yield buffer, buffer
236
+
237
+ # Ensure thread completes
238
+ thread.join()
239
+
240
+ except Exception as e:
241
+ error_msg = f"Error during generation: {str(e)}"
242
+ yield error_msg, error_msg
243
 
244
 
245
  # Example usage for Gradio interface
246
  if __name__ == "__main__":
247
  import gradio as gr
248
 
249
+ # Determine available models
250
+ available_models = []
251
+ if model_m is not None:
252
+ available_models.append("olmOCR-2-7B-1025")
253
+ if model_x is not None:
254
+ available_models.append("Nanonets-OCR2-3B")
255
+ if model_v is not None:
256
+ available_models.append("Chandra-OCR")
257
+ if model_d is not None:
258
+ available_models.append("Dots.OCR")
259
+ if model_ds is not None:
260
+ available_models.append("DeepSeek-OCR")
261
+
262
+ if not available_models:
263
+ print("ERROR: No models were loaded successfully!")
264
+ exit(1)
265
+
266
+ print(f"\n✓ Available models: {', '.join(available_models)}")
267
+
268
+ with gr.Blocks(title="Multi-Model OCR") as demo:
269
+ gr.Markdown("# 🔍 Multi-Model OCR Application")
270
+ gr.Markdown("Upload an image and select a model to extract text. Models run on GPU via Hugging Face Spaces.")
271
 
272
  with gr.Row():
273
  with gr.Column():
274
  model_selector = gr.Dropdown(
275
+ choices=available_models,
276
+ value=available_models[0] if available_models else None,
 
 
 
 
 
 
277
  label="Select OCR Model"
278
  )
279
  image_input = gr.Image(type="pil", label="Upload Image")
280
  text_input = gr.Textbox(
281
  value="Extract all text from this image.",
282
+ label="Prompt",
283
+ lines=2
284
  )
285
 
286
  with gr.Accordion("Advanced Settings", open=False):
 
326
  output_text = gr.Textbox(label="Extracted Text", lines=20)
327
  output_markdown = gr.Markdown(label="Formatted Output")
328
 
329
+ gr.Markdown("""
330
+ ### Available Models:
331
+ - **olmOCR-2-7B-1025**: Allen AI's OCR model
332
+ - **Nanonets-OCR2-3B**: Nanonets OCR model
333
+ - **Chandra-OCR**: Datalab OCR model
334
+ - **Dots.OCR**: Stranger Vision OCR model
335
+ - **DeepSeek-OCR**: DeepSeek AI's OCR model
336
+ """)
337
+
338
  submit_btn.click(
339
  fn=generate_image,
340
  inputs=[