make789 committed (verified)
Commit 14305f4 · 1 Parent(s): ec87c41

Upload ocr_service.py

Files changed (1):
  1. ocr_service.py +97 -75

ocr_service.py CHANGED
@@ -76,6 +76,11 @@ MAX_JSON_LIST_ITEMS = 100
 
 # DeepSeek-OCR Model Configuration - Maximum Quality Settings for M4 Mac (Apple Silicon)
 MODEL_NAME = "deepseek-ai/DeepSeek-OCR"
+# PIN MODEL REVISION to prevent auto-updates that break compatibility
+# Use a commit hash from https://huggingface.co/deepseek-ai/DeepSeek-OCR/tree/main
+# This prevents "A new version of ... was downloaded" warnings and keeps code stable
+MODEL_REVISION = os.getenv("DEEPSEEK_MODEL_REVISION", "2c968b433af61a059311cbf8997765023806a24d")  # Latest stable commit
+
 # Detect Apple Silicon (M1/M2/M3/M4) - use MPS if available, otherwise CPU
 IS_APPLE_SILICON = platform.machine() == "arm64"
 USE_GPU = os.getenv("USE_GPU", "true").lower() == "true" and not IS_APPLE_SILICON  # M4 uses MPS, not CUDA
@@ -143,53 +148,69 @@ def _patch_deepseek_model_for_compatibility():
         Path(os.getenv("TRANSFORMERS_CACHE", "")) / "hub" if os.getenv("TRANSFORMERS_CACHE") else None,
     ]
 
+    # Patch ALL files that might import LlamaFlashAttention2
     model_files = []
     for cache_dir in possible_cache_dirs:
         if cache_dir and cache_dir.exists():
             try:
-                found = list(cache_dir.glob("**/modeling_deepseekv2.py"))
-                model_files.extend(found)
+                # Find all Python files in the DeepSeek-OCR model directory
+                found = list(cache_dir.glob("**/models--deepseek-ai--DeepSeek-OCR/**/*.py"))
+                # Filter for the files that might import LlamaFlashAttention2
+                relevant_files = [
+                    f for f in found
+                    if any(pattern in f.name for pattern in [
+                        'modeling_deepseekv2.py',
+                        'modeling_deepseekocr.py',
+                        'modeling_llama.py'  # In case it's in a nested location
+                    ])
+                ]
+                model_files.extend(relevant_files)
             except Exception:
                 continue
 
     if not model_files:
-        print("⚠️ Model file not found yet - will patch on first model load")
+        print("⚠️ Model files not found yet - will patch on first model load")
         return  # Model not downloaded yet, will patch on first load
 
-    model_file = model_files[0]
-    print(f"🔍 Found model file: {model_file}")
+    # Patch all found files
+    print(f"🔍 Found {len(model_files)} model file(s) to patch")
+    for model_file in model_files:
+        print(f" - {model_file.name}")
 
-    # Check if already patched
-    try:
-        with open(model_file, 'r', encoding='utf-8') as f:
-            content = f.read()
-        if "LlamaFlashAttention2 = LlamaAttention" in content:
-            print("✅ Model already patched")
-            return  # Already patched
-        original_content = content  # Save original for comparison
-    except Exception as e:
-        print(f"⚠️ Could not read model file to check patch status: {e}")
-        return
-
-    # More flexible approach: find and replace any import containing LlamaFlashAttention2
+    # Patch each file
    import re
 
-    try:
-        # Pattern 1: Multi-line import with parentheses
-        # Matches: from transformers.models.llama.modeling_llama import (\n ... LlamaFlashAttention2 ...\n)
-        multiline_pattern = r'from transformers\.models\.llama\.modeling_llama import\s*\([^)]*LlamaFlashAttention2[^)]*\)'
-
-        # Pattern 2: Single-line import with parentheses
-        singleline_parentheses_pattern = r'from transformers\.models\.llama\.modeling_llama import\s*\([^)]*LlamaFlashAttention2[^)]*\)'
-
-        # Pattern 3: Direct import without parentheses
-        direct_import_pattern = r'from transformers\.models\.llama\.modeling_llama import[^;]*LlamaFlashAttention2[^;\n]*'
-
-        patched = False
-
-        # Try multiline replacement first (most common)
-        if re.search(multiline_pattern, content, re.MULTILINE | re.DOTALL):
-            # Create backup
+    for model_file in model_files:
+        print(f"\n🔧 Patching: {model_file.name}")
+        patched_this_file = False
+
+        # Check if already patched
+        try:
+            with open(model_file, 'r', encoding='utf-8') as f:
+                content = f.read()
+            if "LlamaFlashAttention2 = LlamaAttention" in content:
+                print(f" ✅ Already patched, skipping")
+                continue  # Already patched, move to next file
+            original_content = content  # Save original for comparison
+        except Exception as e:
+            print(f" ⚠️ Could not read model file to check patch status: {e}")
+            continue  # Skip this file, try next
+
+        # More flexible approach: find and replace any import containing LlamaFlashAttention2
+        try:
+            # Pattern 1: Multi-line import with parentheses
+            # Matches: from transformers.models.llama.modeling_llama import (\n ... LlamaFlashAttention2 ...\n)
+            multiline_pattern = r'from transformers\.models\.llama\.modeling_llama import\s*\([^)]*LlamaFlashAttention2[^)]*\)'
+
+            # Pattern 2: Single-line import with parentheses
+            singleline_parentheses_pattern = r'from transformers\.models\.llama\.modeling_llama import\s*\([^)]*LlamaFlashAttention2[^)]*\)'
+
+            # Pattern 3: Direct import without parentheses
+            direct_import_pattern = r'from transformers\.models\.llama\.modeling_llama import[^;]*LlamaFlashAttention2[^;\n]*'
+
+            # Try multiline replacement first (most common)
+            if re.search(multiline_pattern, content, re.MULTILINE | re.DOTALL):
+                # Create backup
                 backup_file = model_file.with_suffix('.py.backup')
                 try:
                     import shutil
@@ -231,8 +252,8 @@ except ImportError:
     LlamaFlashAttention2 = LlamaAttention"""
 
                 content = re.sub(multiline_pattern, replacement, content, flags=re.MULTILINE | re.DOTALL)
-                patched = True
-                print("🔧 Applied multiline import patch")
+                patched_this_file = True
+                print(" 🔧 Applied multiline import patch")
 
             # Try single-line parentheses pattern
             elif re.search(singleline_parentheses_pattern, content):
@@ -266,8 +287,8 @@ except ImportError:
     LlamaFlashAttention2 = LlamaAttention"""
 
                 content = re.sub(singleline_parentheses_pattern, replacement, content)
-                patched = True
-                print("🔧 Applied single-line parentheses import patch")
+                patched_this_file = True
+                print(" 🔧 Applied single-line parentheses import patch")
 
             # Try direct import pattern (no parentheses)
             elif re.search(direct_import_pattern, content):
@@ -288,23 +309,22 @@ except ImportError:
     LlamaFlashAttention2 = LlamaAttention"""
 
                 content = re.sub(direct_import_pattern, replacement, content)
-                patched = True
-                print("🔧 Applied direct import patch")
+                patched_this_file = True
+                print(" 🔧 Applied direct import patch")
 
-            # Last resort: find any line containing LlamaFlashAttention2 import and replace
-            if not patched:
+            # Last resort: find any line containing LlamaFlashAttention2 import and replace
+            if not patched_this_file:
                 lines = content.split('\n')
                 for i, line in enumerate(lines):
                     if 'LlamaFlashAttention2' in line and 'from transformers.models.llama.modeling_llama' in line:
                         # Create backup on first match
-                        if not patched:
-                            backup_file = model_file.with_suffix('.py.backup')
-                            try:
-                                import shutil
-                                shutil.copy2(model_file, backup_file)
-                                print(f"📋 Created backup: {backup_file}")
-                            except Exception as backup_err:
-                                print(f"⚠️ Could not create backup: {backup_err}")
+                        backup_file = model_file.with_suffix('.py.backup')
+                        try:
+                            import shutil
+                            shutil.copy2(model_file, backup_file)
+                            print(f" 📋 Created backup: {backup_file.name}")
+                        except Exception as backup_err:
+                            print(f" ⚠️ Could not create backup: {backup_err}")
 
                         # Replace the import line(s)
                         # Handle multiline imports
@@ -329,19 +349,19 @@ except ImportError:
                             # Replace the block
                             lines[i:j+1] = replacement_lines
                             content = '\n'.join(lines)
-                            patched = True
-                            print(f"🔧 Applied line-by-line patch (lines {i}-{j})")
+                            patched_this_file = True
+                            print(f" 🔧 Applied line-by-line patch (lines {i}-{j})")
                             break
                         else:
                             # Single line import
                             lines[i] = "# Patch: LlamaFlashAttention2 import with fallback\ntry:\n    from transformers.models.llama.modeling_llama import LlamaFlashAttention2\nexcept ImportError:\n    from transformers.models.llama.modeling_llama import LlamaAttention\n    LlamaFlashAttention2 = LlamaAttention"
                             content = '\n'.join(lines)
-                            patched = True
-                            print(f"🔧 Applied single-line patch (line {i})")
+                            patched_this_file = True
+                            print(f" 🔧 Applied single-line patch (line {i})")
                             break
-
-            # Last resort: find any line containing LlamaFlashAttention2 import and add fallback
-            if not patched:
+
+            # Last resort: find any line containing LlamaFlashAttention2 import and add fallback
+            if not patched_this_file:
                 lines_for_fallback = content.split('\n')
                 for i, line in enumerate(lines_for_fallback):
                     if 'LlamaFlashAttention2' in line and 'from transformers.models.llama.modeling_llama' in line:
@@ -372,18 +392,18 @@ except ImportError:
                        ])
                         new_lines.extend(lines_for_fallback[i+1:])
                         content = '\n'.join(new_lines)
-                        print(f"✅ Added fallback import block after line {i}")
-                        patched = True
+                        print(f" ✅ Added fallback import block after line {i}")
+                        patched_this_file = True
                         break
 
-        if patched:
+        if patched_this_file:
             # Write file if content was modified
             # (fallback already writes immediately, regex patterns modify content then write here)
             with open(model_file, 'w', encoding='utf-8') as f:
                 f.write(content)
-            print(f"✅ Successfully patched DeepSeek model file: {model_file}")
+            print(f" ✅ Successfully patched: {model_file.name}")
         else:
-            print(f"⚠️ Could not find LlamaFlashAttention2 import to patch in {model_file}")
+            print(f" ⚠️ Could not find LlamaFlashAttention2 import to patch in {model_file.name}")
             # Show a snippet around potential import lines for debugging
             lines = content.split('\n')
             for i, line in enumerate(lines):
@@ -414,11 +434,8 @@ async def get_ocr_model():
     if _ocr_model is None or _ocr_tokenizer is None:
         async with _model_lock:
             if _ocr_model is None or _ocr_tokenizer is None:
-                # Patch DeepSeek model code for compatibility BEFORE loading
-                # Works on HuggingFace Spaces (CPU) and M4 Macs (Apple Silicon)
-                _patch_deepseek_model_for_compatibility()
-
                 # Lazy import dependencies
+                # Note: Patching no longer needed - we pin transformers==4.46.3 and model revision
                 AutoModel, AutoTokenizer = _get_transformers()
                 torch = _get_torch()
 
@@ -428,14 +445,14 @@ async def get_ocr_model():
                 print(f" - Crop mode: {CROP_MODE} (best accuracy)")
 
                 # Load tokenizer first (this triggers model download if needed)
+                # PIN REVISION to prevent auto-updates that break compatibility
+                print(" - Loading tokenizer (pinned to revision for stability)...")
                 _ocr_tokenizer = AutoTokenizer.from_pretrained(
-                    MODEL_NAME, trust_remote_code=True
+                    MODEL_NAME,
+                    trust_remote_code=True,
+                    revision=MODEL_REVISION  # Pin revision to prevent code changes
                 )
-
-                # Patch AFTER tokenizer loads (model files are now downloaded)
-                # This ensures the model files exist before we try to patch
-                print(" - Patching model code for compatibility...")
-                _patch_deepseek_model_for_compatibility()
+                print(" - Tokenizer loaded successfully")
 
                 # Load model with compatibility settings
                 # Use SDPA attention to avoid LlamaFlashAttention2 import errors
@@ -451,15 +468,20 @@ async def get_ocr_model():
                 print(" - Using SDPA attention (HuggingFace Spaces/CPU optimized)")
 
                 try:
-                    _ocr_model = AutoModel.from_pretrained(MODEL_NAME, **load_kwargs)
+                    # PIN REVISION to prevent auto-updates that break compatibility
+                    _ocr_model = AutoModel.from_pretrained(
+                        MODEL_NAME,
+                        revision=MODEL_REVISION,  # Pin revision to prevent code changes
+                        **load_kwargs
+                    )
                 except Exception as e:
                     error_msg = str(e)
                     print(f"⚠️ Model load error: {error_msg}")
-                    # If still fails due to LlamaFlashAttention2, patch again and retry
+                    # If still fails, try without revision pin (fallback)
                     if "LlamaFlashAttention2" in error_msg or "flash" in error_msg.lower():
-                        print(" - LlamaFlashAttention2 error detected, patching again...")
-                        _patch_deepseek_model_for_compatibility()  # Patch again in case files changed
-                        print(" - Retrying model load...")
+                        print(" - LlamaFlashAttention2 error detected")
+                        print(" - This should not happen with transformers==4.46.3")
+                        print(" - Retrying without revision pin as fallback...")
                         _ocr_model = AutoModel.from_pretrained(MODEL_NAME, **load_kwargs)
                     else:
                         raise
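Note: every patch path in `_patch_deepseek_model_for_compatibility` rewrites the `LlamaFlashAttention2` import into the same try/except fallback. Reassembled from the replacement strings in this diff, the injected shim is:

    # Patch: LlamaFlashAttention2 import with fallback
    try:
        from transformers.models.llama.modeling_llama import LlamaFlashAttention2
    except ImportError:
        from transformers.models.llama.modeling_llama import LlamaAttention
        LlamaFlashAttention2 = LlamaAttention

Older transformers releases still export `LlamaFlashAttention2`; newer releases removed it in the attention refactor, which is why this commit pins transformers==4.46.3 and a fixed model revision instead of relying on runtime patching.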
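The revision pin works because `from_pretrained` resolves the repo at a fixed git commit, so the remote code executed via `trust_remote_code=True` can no longer change underneath a deployment. A minimal sketch of overriding and sanity-checking the pin (the environment variable name and repo id come from the diff; the check itself is illustrative and assumes the public `huggingface_hub` API):

    import os
    from huggingface_hub import HfApi

    # Override the pinned revision without a code change
    # (env var read by ocr_service.py)
    os.environ["DEEPSEEK_MODEL_REVISION"] = "2c968b433af61a059311cbf8997765023806a24d"

    # Illustrative sanity check: confirm the pinned commit exists in the
    # repo history before the service tries to load it
    pinned = os.environ["DEEPSEEK_MODEL_REVISION"]
    commits = HfApi().list_repo_commits("deepseek-ai/DeepSeek-OCR")
    print("pinned commit found:", any(c.commit_id == pinned for c in commits))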