IFMedTechdemo committed on
Commit 0eb08d6 · verified · 1 Parent(s): 5594430

Update app.py

Files changed (1)
  1. app.py +183 -31
app.py CHANGED
@@ -1,6 +1,6 @@
 """
 OCR Application with Multiple Models including DeepSeek OCR
-Final fixed version with proper tokenizer handling
+Merged version with working DeepSeek implementation
 """
 
 import os
@@ -8,11 +8,17 @@ import time
 import torch
 import spaces
 import warnings
+import tempfile
+import sys
+from io import StringIO
+from contextlib import contextmanager
 from threading import Thread
 from PIL import Image
 from transformers import (
     AutoProcessor,
     AutoModelForCausalLM,
+    AutoModel,
+    AutoTokenizer,
     Qwen2_5_VLForConditionalGeneration,
     TextIteratorStreamer
 )
@@ -101,41 +107,172 @@ except Exception as e:
     processor_m = None
     print(f"✗ olmOCR-2-7B-1025: Failed to load - {str(e)}")
 
-# Load DeepSeek-OCR with proper tokenizer handling
+# Load DeepSeek-OCR using the working implementation
 try:
     MODEL_ID_DS = "deepseek-ai/deepseek-ocr"
-    processor_ds = AutoProcessor.from_pretrained(MODEL_ID_DS, trust_remote_code=True)
-    model_ds = Qwen2_5_VLForConditionalGeneration.from_pretrained(
+    tokenizer_ds = AutoTokenizer.from_pretrained(MODEL_ID_DS, trust_remote_code=True)
+    model_ds = AutoModel.from_pretrained(
         MODEL_ID_DS,
+        _attn_implementation="flash_attention_2",
         trust_remote_code=True,
-        torch_dtype=torch.float16
+        use_safetensors=True,
     ).eval()
+    print("✓ DeepSeek-OCR loaded")
+except Exception as e:
+    model_ds = None
+    tokenizer_ds = None
+    print(f"✗ DeepSeek-OCR: Failed to load - {str(e)}")
+
+
+@contextmanager
+def capture_stdout():
+    """Capture stdout to get printed output from model"""
+    old_stdout = sys.stdout
+    sys.stdout = StringIO()
+    try:
+        yield sys.stdout
+    finally:
+        sys.stdout = old_stdout
+
+
+@spaces.GPU
+def generate_image_deepseek(text: str, image: Image.Image,
+                            preset: str = "gundam"):
+    """
+    Special generation function for DeepSeek-OCR using its native infer method.
+
+    Args:
+        text: Prompt text (used to determine task type)
+        image: PIL Image object to process
+        preset: Model preset configuration
+
+    Yields:
+        tuple: (raw_text, markdown_text)
+    """
+    if model_ds is None:
+        yield "DeepSeek-OCR is not available.", "DeepSeek-OCR is not available."
+        return
+
+    if image is None:
+        yield "Please upload an image.", "Please upload an image."
+        return
 
-    # Fix tokenizer chat template - access the correct tokenizer attribute
     try:
-        # The tokenizer might be nested under processor_ds.tokenizer
-        tokenizer = processor_ds.tokenizer if hasattr(processor_ds, 'tokenizer') else processor_ds
+        # Move model to GPU
+        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+        model_ds.to(device).to(torch.bfloat16)
 
-        if not hasattr(tokenizer, 'chat_template') or tokenizer.chat_template is None:
-            # Use a standard Qwen-style chat template
-            tokenizer.chat_template = "{% for message in messages %}{% if loop.first and message['role'] != 'system' %}<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n{% endif %}<|im_start|>{{ message['role'] }}\n{% if message['content'] is string %}{{ message['content'] }}{% else %}{% for content in message['content'] %}{% if content['type'] == 'image' %}<|vision_start|><|image_pad|><|vision_end|>{% elif content['type'] == 'text' %}{{ content['text'] }}{% endif %}{% endfor %}{% endif %}<|im_end|>\n{% endfor %}<|im_start|>assistant\n"
-            print("✓ DeepSeek-OCR loaded (with custom chat template)")
-        else:
-            print("✓ DeepSeek-OCR loaded")
-    except Exception as tokenizer_error:
-        print(f" Warning: Could not set chat template - {tokenizer_error}")
-        print(" Model loaded but may need fallback prompting")
+        # Create temp directory for this session
+        with tempfile.TemporaryDirectory() as temp_dir:
+            # Save image with proper format
+            temp_image_path = os.path.join(temp_dir, "input_image.jpg")
+            # Convert RGBA to RGB if necessary
+            if image.mode in ('RGBA', 'LA', 'P'):
+                rgb_image = Image.new('RGB', image.size, (255, 255, 255))
+                if image.mode == 'RGBA':
+                    rgb_image.paste(image, mask=image.split()[3])
+                else:
+                    rgb_image.paste(image)
+                rgb_image.save(temp_image_path, 'JPEG', quality=95)
+            else:
+                image.save(temp_image_path, 'JPEG', quality=95)
+
+            # Set parameters based on preset
+            presets = {
+                "tiny": {"base_size": 512, "image_size": 512, "crop_mode": False},
+                "small": {"base_size": 640, "image_size": 640, "crop_mode": False},
+                "base": {"base_size": 1024, "image_size": 1024, "crop_mode": False},
+                "large": {"base_size": 1280, "image_size": 1280, "crop_mode": False},
+                "gundam": {"base_size": 1024, "image_size": 640, "crop_mode": True},
+            }
+
+            config = presets[preset]
+
+            # Determine task type from prompt
+            if "markdown" in text.lower() or "convert" in text.lower():
+                prompt = "<image>\n<|grounding|>Convert the document to markdown. "
+            else:
+                prompt = "<image>\nFree OCR. "
+
+            # Capture stdout while running inference
+            captured_output = ""
+            with capture_stdout() as output:
+                result = model_ds.infer(
+                    tokenizer_ds,
+                    prompt=prompt,
+                    image_file=temp_image_path,
+                    output_path=temp_dir,
+                    base_size=config["base_size"],
+                    image_size=config["image_size"],
+                    crop_mode=config["crop_mode"],
+                    save_results=True,
+                    test_compress=True,
+                )
+            captured_output = output.getvalue()
+
+            # Extract the text from captured output
+            extracted_text = ""
+
+            # Look for the actual OCR result in the captured output
+            lines = captured_output.split('\n')
+            capture_text = False
+            text_lines = []
+
+            for line in lines:
+                # Start capturing after seeing certain patterns
+                if "# " in line or line.strip().startswith("**"):
+                    capture_text = True
+
+                if capture_text:
+                    # Stop at the separator lines
+                    if line.startswith("====") or (line.startswith("---") and len(line) > 10):
+                        if text_lines:  # Only stop if we've captured something
+                            break
+                    # Add non-empty lines that aren't debug output
+                    elif line.strip() and not line.startswith("image size:") and not line.startswith("valid image") and not line.startswith("output texts") and not line.startswith("compression"):
+                        text_lines.append(line)
+
+            if text_lines:
+                extracted_text = '\n'.join(text_lines)
+
+            # If we didn't get text from stdout, check if result contains text
+            if not extracted_text and result is not None:
+                if isinstance(result, str):
+                    extracted_text = result
+                elif isinstance(result, (list, tuple)) and len(result) > 0:
+                    if isinstance(result[0], str):
+                        extracted_text = result[0]
+                    elif hasattr(result[0], 'text'):
+                        extracted_text = result[0].text
+
+            # Clean up any remaining markers from the text
+            if extracted_text:
+                clean_lines = []
+                for line in extracted_text.split('\n'):
+                    if not any(pattern in line.lower() for pattern in ['image size:', 'valid image', 'compression ratio', 'save results:', 'output texts']):
+                        clean_lines.append(line)
+                extracted_text = '\n'.join(clean_lines).strip()
+
+        # Move model back to CPU to free GPU memory
+        model_ds.to("cpu")
+        torch.cuda.empty_cache()
+
+        # Return the extracted text
+        final_text = extracted_text if extracted_text else "No text could be extracted from the image."
+        yield final_text, final_text
 
-except Exception as e:
-    model_ds = None
-    processor_ds = None
-    print(f"✗ DeepSeek-OCR: Failed to load - {str(e)}")
+    except Exception as e:
+        error_msg = f"Error during DeepSeek generation: {str(e)}"
+        print(f"Full error: {e}")
+        import traceback
+        traceback.print_exc()
+        yield error_msg, error_msg
 
 
 @spaces.GPU
 def generate_image(model_name: str, text: str, image: Image.Image,
                    max_new_tokens: int, temperature: float, top_p: float,
-                   top_k: int, repetition_penalty: float):
+                   top_k: int, repetition_penalty: float, deepseek_preset: str = "gundam"):
     """
     Generates responses using the selected model for image input.
     Yields raw text and Markdown-formatted text.
@@ -152,10 +289,16 @@ def generate_image(model_name: str, text: str, image: Image.Image,
         top_p: Nucleus sampling parameter
         top_k: Top-k sampling parameter
         repetition_penalty: Penalty for repeating tokens
+        deepseek_preset: Preset for DeepSeek model
 
     Yields:
         tuple: (raw_text, markdown_text)
     """
+    # Special handling for DeepSeek-OCR
+    if model_name == "DeepSeek-OCR":
+        yield from generate_image_deepseek(text, image, deepseek_preset)
+        return
+
     # Device will be cuda when @spaces.GPU decorator activates
     device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
 
@@ -184,12 +327,6 @@ def generate_image(model_name: str, text: str, image: Image.Image,
             return
         processor = processor_d
         model = model_d.to(device)
-    elif model_name == "DeepSeek-OCR":
-        if model_ds is None:
-            yield "DeepSeek-OCR is not available.", "DeepSeek-OCR is not available."
-            return
-        processor = processor_ds
-        model = model_ds.to(device)
     else:
         yield "Invalid model selected.", "Invalid model selected."
         return
@@ -218,7 +355,6 @@
     except Exception as template_error:
         # Fallback: create a simple prompt without chat template
         print(f"Chat template error: {template_error}. Using fallback prompt.")
-        # Simple format that most models understand
        prompt_full = f"{text}"
 
     # Process inputs
@@ -347,6 +483,14 @@ if __name__ == "__main__":
                    step=0.1,
                    label="Repetition Penalty"
                )
+
+                gr.Markdown("### DeepSeek-OCR Specific Settings")
+                deepseek_preset = gr.Radio(
+                    choices=["gundam", "base", "large", "small", "tiny"],
+                    value="gundam",
+                    label="DeepSeek Preset",
+                    info="Only applies when DeepSeek-OCR is selected"
+                )
 
                submit_btn = gr.Button("Extract Text", variant="primary")
 
@@ -360,7 +504,14 @@
            - **Nanonets-OCR2-3B**: Nanonets OCR model
            - **Chandra-OCR**: Datalab OCR model
            - **Dots.OCR**: Stranger Vision OCR model
-            - **DeepSeek-OCR**: DeepSeek AI's OCR model (experimental)
+            - **DeepSeek-OCR**: DeepSeek AI's OCR model (uses native inference method)
+
+            ### DeepSeek-OCR Presets:
+            - **Gundam** (Recommended): Balanced performance with crop mode
+            - **Base**: Standard quality without cropping
+            - **Large**: Highest quality for complex documents
+            - **Small**: Faster processing, good for simple text
+            - **Tiny**: Fastest, suitable for clear printed text
            """)
 
            submit_btn.click(
@@ -373,7 +524,8 @@
                    temperature,
                    top_p,
                    top_k,
-                    repetition_penalty
+                    repetition_penalty,
+                    deepseek_preset
                ],
                outputs=[output_text, output_markdown]
            )
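For reviewers who want to exercise the new DeepSeek-OCR path outside the Gradio app, here is a minimal sketch distilled from the diff above. The model ID, `AutoModel`/`AutoTokenizer` loading flags, and the custom `infer` call (exposed via `trust_remote_code`) come straight from the commit; the input file name, output directory, and the availability of a CUDA GPU with flash-attn installed are assumptions.

```python
# Standalone sketch of the DeepSeek-OCR flow this commit adopts.
# Assumes: CUDA GPU, flash-attn installed, "sample.jpg" on disk (hypothetical path).
import torch
from transformers import AutoModel, AutoTokenizer

MODEL_ID = "deepseek-ai/deepseek-ocr"
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
model = AutoModel.from_pretrained(
    MODEL_ID,
    _attn_implementation="flash_attention_2",
    trust_remote_code=True,
    use_safetensors=True,
).eval().to("cuda").to(torch.bfloat16)

# "gundam" preset from the app: 1024 base size, 640 tile size, crop mode on.
result = model.infer(
    tokenizer,
    prompt="<image>\nFree OCR. ",
    image_file="sample.jpg",   # hypothetical input path
    output_path="./out",       # hypothetical output directory
    base_size=1024,
    image_size=640,
    crop_mode=True,
    save_results=True,
    test_compress=True,
)
```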
 
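A note on the design: the commit recovers the OCR result by parsing what the model prints rather than relying solely on `infer`'s return value, so the `capture_stdout` helper is load-bearing. Below is a tiny self-contained demo of that same pattern, with no model required. Because the context manager yields the `StringIO` handle itself, the buffer stays readable after the block exits, which is why `output.getvalue()` can be called afterwards in `app.py`.

```python
# Minimal demo of the stdout-capture pattern used by generate_image_deepseek.
import sys
from io import StringIO
from contextlib import contextmanager

@contextmanager
def capture_stdout():
    old_stdout = sys.stdout
    sys.stdout = StringIO()   # redirect prints into an in-memory buffer
    try:
        yield sys.stdout
    finally:
        sys.stdout = old_stdout  # always restore, even if the body raises

with capture_stdout() as output:
    print("text the model would have printed")

# The buffer remains readable after stdout has been restored.
assert output.getvalue() == "text the model would have printed\n"
```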