Shirochi committed
Commit f66ccd1 · verified · 1 Parent(s): 19e6730

Upload 7 files

bubble_detector.py ADDED
@@ -0,0 +1,2031 @@
"""
bubble_detector.py - Modified version that works in frozen PyInstaller executables
Replace your bubble_detector.py with this version
"""
import os
import sys
import json
import numpy as np
import cv2
from typing import List, Tuple, Optional, Dict, Any
import logging
import traceback
import hashlib
from pathlib import Path
import threading
import time

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Check if we're running in a frozen environment
IS_FROZEN = getattr(sys, 'frozen', False)
if IS_FROZEN:
    # In frozen environment, set proper paths for ML libraries
    MEIPASS = sys._MEIPASS
    os.environ['TORCH_HOME'] = MEIPASS
    os.environ['TRANSFORMERS_CACHE'] = os.path.join(MEIPASS, 'transformers')
    os.environ['HF_HOME'] = os.path.join(MEIPASS, 'huggingface')
    logger.info(f"Running in frozen environment: {MEIPASS}")

# Modified import checks for frozen environment
YOLO_AVAILABLE = False
YOLO = None
torch = None
TORCH_AVAILABLE = False
ONNX_AVAILABLE = False
TRANSFORMERS_AVAILABLE = False
RTDetrForObjectDetection = None
RTDetrImageProcessor = None
PIL_AVAILABLE = False

# Try to import YOLO dependencies with better error handling
if IS_FROZEN:
    # In frozen environment, try harder to import
    try:
        # First try to import torch components individually
        import torch
        import torch.nn
        import torch.cuda
        TORCH_AVAILABLE = True
        logger.info("✓ PyTorch loaded in frozen environment")
    except Exception as e:
        logger.warning(f"PyTorch not available in frozen environment: {e}")
        TORCH_AVAILABLE = False
        torch = None

    # Try ultralytics after torch
    if TORCH_AVAILABLE:
        try:
            from ultralytics import YOLO
            YOLO_AVAILABLE = True
            logger.info("✓ Ultralytics YOLO loaded in frozen environment")
        except Exception as e:
            logger.warning(f"Ultralytics not available in frozen environment: {e}")
            YOLO_AVAILABLE = False

    # Try transformers
    try:
        import transformers
        # Try specific imports
        try:
            from transformers import RTDetrForObjectDetection, RTDetrImageProcessor
            TRANSFORMERS_AVAILABLE = True
            logger.info("✓ Transformers RT-DETR loaded in frozen environment")
        except ImportError:
            # Try alternative import
            try:
                from transformers import AutoModel, AutoImageProcessor
                RTDetrForObjectDetection = AutoModel
                RTDetrImageProcessor = AutoImageProcessor
                TRANSFORMERS_AVAILABLE = True
                logger.info("✓ Transformers loaded with AutoModel fallback")
            except Exception:
                TRANSFORMERS_AVAILABLE = False
                logger.warning("Transformers RT-DETR not available in frozen environment")
    except Exception as e:
        logger.warning(f"Transformers not available in frozen environment: {e}")
        TRANSFORMERS_AVAILABLE = False
else:
    # Normal environment - original import logic
    try:
        from ultralytics import YOLO
        YOLO_AVAILABLE = True
    except Exception:
        YOLO_AVAILABLE = False
        logger.warning("Ultralytics YOLO not available")

    try:
        import torch
        # Test if cuda attribute exists
        _ = torch.cuda
        TORCH_AVAILABLE = True
    except (ImportError, AttributeError):
        TORCH_AVAILABLE = False
        torch = None
        logger.warning("PyTorch not available or incomplete")

    try:
        from transformers import RTDetrForObjectDetection, RTDetrImageProcessor
        try:
            from transformers import RTDetrV2ForObjectDetection
            RTDetrForObjectDetection = RTDetrV2ForObjectDetection
        except ImportError:
            pass
        TRANSFORMERS_AVAILABLE = True
    except Exception:
        TRANSFORMERS_AVAILABLE = False
        logger.info("Transformers not available for RT-DETR")

# Configure ORT memory behavior before importing
try:
    os.environ.setdefault('ORT_DISABLE_MEMORY_ARENA', '1')
except Exception:
    pass
# ONNX Runtime - works well in frozen environments
try:
    import onnxruntime as ort
    ONNX_AVAILABLE = True
    logger.info("✓ ONNX Runtime available")
except ImportError:
    ONNX_AVAILABLE = False
    logger.warning("ONNX Runtime not available")

# PIL
try:
    from PIL import Image
    PIL_AVAILABLE = True
except ImportError:
    PIL_AVAILABLE = False
    logger.info("PIL not available")

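# Typical usage (illustrative sketch only; the file paths below are examples,
# not files shipped with this module):
#
#   detector = BubbleDetector()
#   if detector.load_model("models/bubbles_yolov8.pt"):      # hypothetical local weights
#       boxes = detector.detect_bubbles("page_001.png", confidence=0.3)
#   if detector.load_rtdetr_model():                         # pulls the default HF repo
#       regions = detector.detect_with_rtdetr(image_path="page_001.png",
#                                             return_all_bubbles=False)
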
class BubbleDetector:
    """
    Combined YOLOv8 and RT-DETR speech bubble detector for comics and manga.
    Supports multiple model formats and provides configurable detection.
    Backward compatible with existing code while adding RT-DETR support.
    """

    # Process-wide shared RT-DETR to avoid concurrent meta-device loads
    _rtdetr_init_lock = threading.Lock()
    _rtdetr_shared_model = None
    _rtdetr_shared_processor = None
    _rtdetr_loaded = False
    _rtdetr_repo_id = 'ogkalu/comic-text-and-bubble-detector'

    # Shared RT-DETR (ONNX) across process to avoid device/context storms
    _rtdetr_onnx_init_lock = threading.Lock()
    _rtdetr_onnx_shared_session = None
    _rtdetr_onnx_loaded = False
    _rtdetr_onnx_providers = None
    _rtdetr_onnx_model_path = None
    # Limit concurrent runs to avoid device hangs. Defaults to 2 for better parallelism.
    # Can be overridden via env DML_MAX_CONCURRENT or config rtdetr_max_concurrency
    try:
        _rtdetr_onnx_max_concurrent = int(os.environ.get('DML_MAX_CONCURRENT', '2'))
    except Exception:
        _rtdetr_onnx_max_concurrent = 2
    _rtdetr_onnx_sema = threading.Semaphore(max(1, _rtdetr_onnx_max_concurrent))
    _rtdetr_onnx_sema_initialized = False

    def __init__(self, config_path: str = "config.json"):
        """
        Initialize the bubble detector.

        Args:
            config_path: Path to configuration file
        """
        # Set thread limits early if environment indicates single-threaded mode
        try:
            if os.environ.get('OMP_NUM_THREADS') == '1':
                # Already in single-threaded mode, ensure it's applied to this process
                # Check if torch is available at module level before trying to use it
                if TORCH_AVAILABLE and torch is not None:
                    try:
                        torch.set_num_threads(1)
                    except (RuntimeError, AttributeError):
                        pass
                try:
                    import cv2
                    cv2.setNumThreads(1)
                except (ImportError, AttributeError):
                    pass
        except Exception:
            pass

        self.config_path = config_path
        self.config = self._load_config()

        # YOLOv8 components (original)
        self.model = None
        self.model_loaded = False
        self.model_type = None  # 'yolo', 'onnx', or 'torch'
        self.onnx_session = None

        # RT-DETR components (new)
        self.rtdetr_model = None
        self.rtdetr_processor = None
        self.rtdetr_loaded = False
        self.rtdetr_repo = 'ogkalu/comic-text-and-bubble-detector'

        # RT-DETR (ONNX) backend components
        self.rtdetr_onnx_session = None
        self.rtdetr_onnx_loaded = False
        self.rtdetr_onnx_repo = 'ogkalu/comic-text-and-bubble-detector'

        # RT-DETR class definitions
        self.CLASS_BUBBLE = 0       # Empty speech bubble
        self.CLASS_TEXT_BUBBLE = 1  # Bubble with text
        self.CLASS_TEXT_FREE = 2    # Text without bubble

        # Detection settings
        self.default_confidence = 0.3
        self.default_iou_threshold = 0.45
        # Allow override from settings
        try:
            ocr_cfg = self.config.get('manga_settings', {}).get('ocr', {}) if isinstance(self.config, dict) else {}
            self.default_max_detections = int(ocr_cfg.get('bubble_max_detections', 100))
            self.max_det_yolo = int(ocr_cfg.get('bubble_max_detections_yolo', self.default_max_detections))
            self.max_det_rtdetr = int(ocr_cfg.get('bubble_max_detections_rtdetr', self.default_max_detections))
        except Exception:
            self.default_max_detections = 100
            self.max_det_yolo = 100
            self.max_det_rtdetr = 100

        # Cache directory for ONNX conversions
        self.cache_dir = os.environ.get('BUBBLE_CACHE_DIR', 'models')
        os.makedirs(self.cache_dir, exist_ok=True)

        # RT-DETR concurrency setting from config
        try:
            rtdetr_max_conc = int(ocr_cfg.get('rtdetr_max_concurrency', 2))
            # Update class-level semaphore if not yet initialized or if value changed
            if not BubbleDetector._rtdetr_onnx_sema_initialized or rtdetr_max_conc != BubbleDetector._rtdetr_onnx_max_concurrent:
                BubbleDetector._rtdetr_onnx_max_concurrent = max(1, rtdetr_max_conc)
                BubbleDetector._rtdetr_onnx_sema = threading.Semaphore(BubbleDetector._rtdetr_onnx_max_concurrent)
                BubbleDetector._rtdetr_onnx_sema_initialized = True
                logger.info(f"RT-DETR concurrency set to: {BubbleDetector._rtdetr_onnx_max_concurrent}")
        except Exception as e:
            logger.warning(f"Failed to set RT-DETR concurrency: {e}")

        # GPU availability
        self.use_gpu = TORCH_AVAILABLE and torch.cuda.is_available()
        self.device = 'cuda' if self.use_gpu else 'cpu'

        # Quantization/precision settings
        adv_cfg = self.config.get('manga_settings', {}).get('advanced', {}) if isinstance(self.config, dict) else {}
        ocr_cfg = self.config.get('manga_settings', {}).get('ocr', {}) if isinstance(self.config, dict) else {}
        env_quant = os.environ.get('MODEL_QUANTIZE', 'false').lower() == 'true'
        self.quantize_enabled = bool(env_quant or adv_cfg.get('quantize_models', False) or ocr_cfg.get('quantize_bubble_detector', False))
        self.quantize_dtype = str(adv_cfg.get('torch_precision', os.environ.get('TORCH_PRECISION', 'auto'))).lower()
        # Prefer advanced.onnx_quantize; fall back to env or global quantize
        self.onnx_quantize_enabled = bool(adv_cfg.get('onnx_quantize', os.environ.get('ONNX_QUANTIZE', 'false').lower() == 'true' or self.quantize_enabled))

        # Stop flag support
        self.stop_flag = None
        self._stopped = False
        self.log_callback = None

        logger.info(f"🗨️ BubbleDetector initialized")
        logger.info(f"   GPU: {'Available' if self.use_gpu else 'Not available'}")
        logger.info(f"   YOLO: {'Available' if YOLO_AVAILABLE else 'Not installed'}")
        logger.info(f"   ONNX: {'Available' if ONNX_AVAILABLE else 'Not installed'}")
        logger.info(f"   RT-DETR: {'Available' if TRANSFORMERS_AVAILABLE else 'Not installed'}")
        logger.info(f"   Quantization: {'ENABLED' if self.quantize_enabled else 'disabled'} (torch_precision={self.quantize_dtype}, onnx_quantize={'on' if self.onnx_quantize_enabled else 'off'})")

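    # The settings read above live under manga_settings in config.json. A minimal
    # example (keys are the ones this class actually reads; values are illustrative
    # defaults, not a required configuration):
    #
    #   {"manga_settings": {
    #       "ocr": {"bubble_max_detections": 100, "rtdetr_max_concurrency": 2},
    #       "advanced": {"quantize_models": false, "onnx_quantize": false,
    #                    "torch_precision": "auto"}}}
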
    def _load_config(self) -> Dict[str, Any]:
        """Load configuration from file."""
        if os.path.exists(self.config_path):
            try:
                with open(self.config_path, 'r', encoding='utf-8') as f:
                    return json.load(f)
            except Exception as e:
                logger.warning(f"Failed to load config: {e}")
        return {}

    def _save_config(self):
        """Save configuration to file."""
        try:
            with open(self.config_path, 'w', encoding='utf-8') as f:
                json.dump(self.config, f, indent=2)
        except Exception as e:
            logger.error(f"Failed to save config: {e}")

    def set_stop_flag(self, stop_flag):
        """Set the stop flag for checking interruptions"""
        self.stop_flag = stop_flag
        self._stopped = False

    def set_log_callback(self, log_callback):
        """Set log callback for GUI integration"""
        self.log_callback = log_callback

    def _check_stop(self) -> bool:
        """Check if stop has been requested"""
        if self._stopped:
            return True
        if self.stop_flag and self.stop_flag.is_set():
            self._stopped = True
            return True
        # Check global manga translator cancellation
        try:
            from manga_translator import MangaTranslator
            if MangaTranslator.is_globally_cancelled():
                self._stopped = True
                return True
        except Exception:
            pass
        return False

    def _log(self, message: str, level: str = "info"):
        """Log message with stop suppression"""
        # Suppress logs when stopped (allow only essential stop confirmation messages)
        if self._check_stop():
            essential_stop_keywords = [
                "⏹️ Translation stopped by user",
                "⏹️ Bubble detection stopped",
                "cleanup", "🧹"
            ]
            if not any(keyword in message for keyword in essential_stop_keywords):
                return

        if self.log_callback:
            self.log_callback(message, level)
        else:
            logger.info(message) if level == 'info' else getattr(logger, level, logger.info)(message)

    def reset_stop_flags(self):
        """Reset stop flags when starting new processing"""
        self._stopped = False

    def load_model(self, model_path: str, force_reload: bool = False) -> bool:
        """
        Load a YOLOv8 model for bubble detection.

        Args:
            model_path: Path to model file (.pt, .onnx, or .torchscript)
            force_reload: Force reload even if model is already loaded

        Returns:
            True if model loaded successfully, False otherwise
        """
        try:
            # If given a Hugging Face repo ID (e.g., 'owner/name'), fetch detector.onnx into models/
            if model_path and (('/' in model_path) and not os.path.exists(model_path)):
                try:
                    from huggingface_hub import hf_hub_download
                    os.makedirs(self.cache_dir, exist_ok=True)
                    logger.info(f"📥 Resolving repo '{model_path}' to detector.onnx in {self.cache_dir}...")
                    resolved = hf_hub_download(repo_id=model_path, filename='detector.onnx', cache_dir=self.cache_dir, local_dir=self.cache_dir, local_dir_use_symlinks=False)
                    if resolved and os.path.exists(resolved):
                        model_path = resolved
                        logger.info(f"✅ Downloaded detector.onnx to: {model_path}")
                except Exception as repo_err:
                    logger.error(f"Failed to download from repo '{model_path}': {repo_err}")
            if not os.path.exists(model_path):
                logger.error(f"Model file not found: {model_path}")
                return False

            # Check if it's the same model already loaded
            if self.model_loaded and not force_reload:
                last_path = self.config.get('last_model_path', '')
                if last_path == model_path:
                    logger.info("Model already loaded (same path)")
                    return True
                else:
                    logger.info(f"Model path changed from {last_path} to {model_path}, reloading...")
                    force_reload = True

            # Clear previous model if force reload
            if force_reload:
                logger.info("Force reloading model...")
                self.model = None
                self.onnx_session = None
                self.model_loaded = False
                self.model_type = None

            logger.info(f"📥 Loading bubble detection model: {model_path}")

            # Determine model type by extension
            ext = Path(model_path).suffix.lower()

            if ext in ['.pt', '.pth']:
                if not YOLO_AVAILABLE:
                    logger.warning("Ultralytics package not available in this build")
                    logger.info("Bubble detection will be disabled - this is normal for lightweight builds")
                    # Don't return False immediately, try other fallbacks
                    self.model_loaded = False
                    return False

                # Load YOLOv8 model
                try:
                    self.model = YOLO(model_path)
                    self.model_type = 'yolo'

                    # Set to eval mode
                    if hasattr(self.model, 'model'):
                        self.model.model.eval()

                    # Move to GPU if available
                    if self.use_gpu and TORCH_AVAILABLE:
                        try:
                            self.model.to('cuda')
                        except Exception as gpu_error:
                            logger.warning(f"Could not move model to GPU: {gpu_error}")

                    logger.info("✅ YOLOv8 model loaded successfully")
                    # Apply optional FP16 precision to reduce VRAM if enabled
                    if self.quantize_enabled and self.use_gpu and TORCH_AVAILABLE:
                        try:
                            m = self.model.model if hasattr(self.model, 'model') else self.model
                            m.half()
                            logger.info("🔻 Applied FP16 precision to YOLO model (GPU)")
                        except Exception as _e:
                            logger.warning(f"Could not switch YOLO model to FP16: {_e}")

                except Exception as yolo_error:
                    logger.error(f"Failed to load YOLO model: {yolo_error}")
                    return False

            elif ext == '.onnx':
                if not ONNX_AVAILABLE:
                    logger.warning("ONNX Runtime not available in this build")
                    logger.info("ONNX model support disabled - this is normal for lightweight builds")
                    return False

                try:
                    # Load ONNX model
                    providers = ['CUDAExecutionProvider', 'CPUExecutionProvider'] if self.use_gpu else ['CPUExecutionProvider']
                    session_path = model_path
                    if self.quantize_enabled:
                        try:
                            from onnxruntime.quantization import quantize_dynamic, QuantType
                            quant_path = os.path.splitext(model_path)[0] + ".int8.onnx"
                            if not os.path.exists(quant_path) or os.environ.get('FORCE_ONNX_REBUILD', 'false').lower() == 'true':
                                logger.info("🔻 Quantizing ONNX model weights to INT8 (dynamic)...")
                                quantize_dynamic(model_input=model_path, model_output=quant_path, weight_type=QuantType.QInt8, op_types_to_quantize=['Conv', 'MatMul'])
                            session_path = quant_path
                            self.config['last_onnx_quantized_path'] = quant_path
                            self._save_config()
                            logger.info(f"✅ Using quantized ONNX model: {quant_path}")
                        except Exception as qe:
                            logger.warning(f"ONNX quantization not applied: {qe}")
                    # Use conservative ORT memory options to reduce RAM growth
                    so = ort.SessionOptions()
                    try:
                        so.enable_mem_pattern = False
                        so.enable_cpu_mem_arena = False
                    except Exception:
                        pass
                    self.onnx_session = ort.InferenceSession(session_path, sess_options=so, providers=providers)
                    self.model_type = 'onnx'

                    logger.info("✅ ONNX model loaded successfully")

                except Exception as onnx_error:
                    logger.error(f"Failed to load ONNX model: {onnx_error}")
                    return False

            elif ext == '.torchscript':
                if not TORCH_AVAILABLE:
                    logger.warning("PyTorch not available in this build")
                    logger.info("TorchScript model support disabled - this is normal for lightweight builds")
                    return False

                try:
                    # Add safety check for torch being None
                    if torch is None:
                        logger.error("PyTorch module is None - cannot load TorchScript model")
                        return False

                    # Load TorchScript model
                    self.model = torch.jit.load(model_path, map_location='cpu')
                    self.model.eval()
                    self.model_type = 'torch'

                    if self.use_gpu:
                        try:
                            self.model = self.model.cuda()
                        except Exception as gpu_error:
                            logger.warning(f"Could not move TorchScript model to GPU: {gpu_error}")

                    logger.info("✅ TorchScript model loaded successfully")

                    # Optional FP16 precision on GPU
                    if self.quantize_enabled and self.use_gpu and TORCH_AVAILABLE:
                        try:
                            self.model = self.model.half()
                            logger.info("🔻 Applied FP16 precision to TorchScript model (GPU)")
                        except Exception as _e:
                            logger.warning(f"Could not switch TorchScript model to FP16: {_e}")

                except Exception as torch_error:
                    logger.error(f"Failed to load TorchScript model: {torch_error}")
                    return False

            else:
                logger.error(f"Unsupported model format: {ext}")
                logger.info("Supported formats: .pt/.pth (YOLOv8), .onnx (ONNX), .torchscript (TorchScript)")
                return False

            # Only set loaded if we actually succeeded
            self.model_loaded = True
            self.config['last_model_path'] = model_path
            self.config['model_type'] = self.model_type
            self._save_config()

            return True

        except Exception as e:
            logger.error(f"Failed to load model: {e}")
            logger.error(traceback.format_exc())
            self.model_loaded = False

            # Provide helpful context for .exe users
            logger.info("Note: If running from .exe, some ML libraries may not be included")
            logger.info("This is normal for lightweight builds - bubble detection will be disabled")

            return False

    def load_rtdetr_model(self, model_path: str = None, model_id: str = None, force_reload: bool = False) -> bool:
        """
        Load RT-DETR model for advanced bubble and text detection.
        This implementation avoids the 'meta tensor' copy error by:
        - Serializing the entire load under a class lock (no concurrent loads)
        - Loading directly onto the target device (CUDA if available) via device_map='auto'
        - Avoiding .to() on a potentially-meta model; no device migration post-load

        Args:
            model_path: Optional path to local model
            model_id: Optional HuggingFace model ID (default: 'ogkalu/comic-text-and-bubble-detector')
            force_reload: Force reload even if already loaded

        Returns:
            True if successful, False otherwise
        """
        if not TRANSFORMERS_AVAILABLE:
            logger.error("Transformers library required for RT-DETR. Install with: pip install transformers")
            return False

        if not PIL_AVAILABLE:
            logger.error("PIL required for RT-DETR. Install with: pip install pillow")
            return False

        if self.rtdetr_loaded and not force_reload:
            logger.info("RT-DETR model already loaded")
            return True

        # Fast path: if shared already loaded and not forcing reload, attach
        if BubbleDetector._rtdetr_loaded and not force_reload:
            self.rtdetr_model = BubbleDetector._rtdetr_shared_model
            self.rtdetr_processor = BubbleDetector._rtdetr_shared_processor
            self.rtdetr_loaded = True
            logger.info("RT-DETR model attached from shared cache")
            return True

        # Serialize the ENTIRE loading sequence to avoid concurrent init issues
        with BubbleDetector._rtdetr_init_lock:
            try:
                # Re-check after acquiring lock
                if BubbleDetector._rtdetr_loaded and not force_reload:
                    self.rtdetr_model = BubbleDetector._rtdetr_shared_model
                    self.rtdetr_processor = BubbleDetector._rtdetr_shared_processor
                    self.rtdetr_loaded = True
                    logger.info("RT-DETR model attached from shared cache (post-lock)")
                    return True

                # Use custom model_id if provided, otherwise use default
                repo_id = model_id if model_id else self.rtdetr_repo
                logger.info(f"📥 Loading RT-DETR model from {repo_id}...")

                # Ensure TorchDynamo/compile doesn't interfere on some builds
                try:
                    os.environ.setdefault('TORCHDYNAMO_DISABLE', '1')
                except Exception:
                    pass

                # Decide device strategy
                gpu_available = bool(TORCH_AVAILABLE and hasattr(torch, 'cuda') and torch.cuda.is_available())
                device_map = 'auto' if gpu_available else None
                # Choose dtype
                dtype = None
                if TORCH_AVAILABLE:
                    try:
                        dtype = torch.float16 if gpu_available else torch.float32
                    except Exception:
                        dtype = None
                low_cpu = True if gpu_available else False

                # Load processor (once)
                self.rtdetr_processor = RTDetrImageProcessor.from_pretrained(
                    repo_id,
                    size={"width": 640, "height": 640},
                    cache_dir=self.cache_dir if not model_path else None
                )

                # Prepare kwargs for from_pretrained
                from_kwargs = {
                    'cache_dir': self.cache_dir if not model_path else None,
                    'low_cpu_mem_usage': low_cpu,
                    'device_map': device_map,
                }
                # Note: dtype is handled via torch_dtype parameter in newer transformers
                if dtype is not None:
                    from_kwargs['torch_dtype'] = dtype

                # First attempt: load directly to target (CUDA if available)
                try:
                    self.rtdetr_model = RTDetrForObjectDetection.from_pretrained(
                        model_path if model_path else repo_id,
                        **from_kwargs,
                    )
                except Exception as primary_err:
                    # Fallback to a simple CPU load (no device move) if CUDA path fails
                    logger.warning(f"RT-DETR primary load failed ({primary_err}); retrying on CPU...")
                    from_kwargs_fallback = {
                        'cache_dir': self.cache_dir if not model_path else None,
                        'low_cpu_mem_usage': False,
                        'device_map': None,
                    }
                    if TORCH_AVAILABLE:
                        from_kwargs_fallback['torch_dtype'] = torch.float32
                    self.rtdetr_model = RTDetrForObjectDetection.from_pretrained(
                        model_path if model_path else repo_id,
                        **from_kwargs_fallback,
                    )

                # Optional dynamic quantization for linear layers (CPU only)
                if self.quantize_enabled and TORCH_AVAILABLE and (not gpu_available):
                    try:
                        try:
                            import torch.ao.quantization as tq
                            quantize_dynamic = tq.quantize_dynamic  # type: ignore
                        except Exception:
                            import torch.quantization as tq  # type: ignore
                            quantize_dynamic = tq.quantize_dynamic  # type: ignore
                        self.rtdetr_model = quantize_dynamic(self.rtdetr_model, {torch.nn.Linear}, dtype=torch.qint8)
                        logger.info("🔻 Applied dynamic INT8 quantization to RT-DETR linear layers (CPU)")
                    except Exception as qe:
                        logger.warning(f"RT-DETR dynamic quantization skipped: {qe}")

                # Finalize
                self.rtdetr_model.eval()

                # Sanity check: ensure no parameter is left on 'meta' device
                try:
                    for n, p in self.rtdetr_model.named_parameters():
                        dev = getattr(p, 'device', None)
                        if dev is not None and getattr(dev, 'type', '') == 'meta':
                            raise RuntimeError(f"Parameter {n} is on 'meta' device after load")
                except Exception as e:
                    logger.error(f"RT-DETR load sanity check failed: {e}")
                    self.rtdetr_loaded = False
                    return False

                # Publish shared cache
                BubbleDetector._rtdetr_shared_model = self.rtdetr_model
                BubbleDetector._rtdetr_shared_processor = self.rtdetr_processor
                BubbleDetector._rtdetr_loaded = True
                BubbleDetector._rtdetr_repo_id = repo_id

                self.rtdetr_loaded = True

                # Save the model ID that was used
                self.config['rtdetr_loaded'] = True
                self.config['rtdetr_model_id'] = repo_id
                self._save_config()

                loc = 'CUDA' if gpu_available else 'CPU'
                logger.info(f"✅ RT-DETR model loaded successfully ({loc})")
                logger.info("   Classes: Empty bubbles, Text bubbles, Free text")

                # Auto-convert to ONNX for RT-DETR only if explicitly enabled
                if os.environ.get('AUTO_CONVERT_RTDETR_ONNX', 'false').lower() == 'true':
                    onnx_path = os.path.join(self.cache_dir, 'rtdetr_comic.onnx')
                    if self.convert_to_onnx('rtdetr', onnx_path):
                        logger.info("🚀 RT-DETR converted to ONNX for faster inference")
                        # Store ONNX path for later use
                        self.config['rtdetr_onnx_path'] = onnx_path
                        self._save_config()
                        # Optionally quantize ONNX for reduced RAM
                        if self.onnx_quantize_enabled:
                            try:
                                from onnxruntime.quantization import quantize_dynamic, QuantType
                                quant_path = os.path.splitext(onnx_path)[0] + ".int8.onnx"
                                if not os.path.exists(quant_path) or os.environ.get('FORCE_ONNX_REBUILD', 'false').lower() == 'true':
                                    logger.info("🔻 Quantizing RT-DETR ONNX to INT8 (dynamic)...")
                                    quantize_dynamic(model_input=onnx_path, model_output=quant_path, weight_type=QuantType.QInt8, op_types_to_quantize=['Conv', 'MatMul'])
                                self.config['rtdetr_onnx_quantized_path'] = quant_path
                                self._save_config()
                                logger.info(f"✅ Quantized RT-DETR ONNX saved to: {quant_path}")
                            except Exception as qe:
                                logger.warning(f"ONNX quantization for RT-DETR skipped: {qe}")
                else:
                    logger.info("ℹ️ Skipping RT-DETR ONNX export (converter not supported in current environment)")

                return True
            except Exception as e:
                logger.error(f"❌ Failed to load RT-DETR: {e}")
                self.rtdetr_loaded = False
                return False

    def check_rtdetr_available(self, model_id: str = None) -> bool:
        """
        Check if RT-DETR model is available (cached).

        Args:
            model_id: Optional HuggingFace model ID

        Returns:
            True if model is cached and available
        """
        try:
            from pathlib import Path

            # Use provided model_id or default
            repo_id = model_id if model_id else self.rtdetr_repo

            # Check HuggingFace cache
            cache_dir = Path.home() / ".cache" / "huggingface" / "hub"
            model_id_formatted = repo_id.replace("/", "--")

            # Look for model folder
            model_folders = list(cache_dir.glob(f"models--{model_id_formatted}*"))

            if model_folders:
                for folder in model_folders:
                    if (folder / "snapshots").exists():
                        snapshots = list((folder / "snapshots").iterdir())
                        if snapshots:
                            return True

            return False

        except Exception:
            return False

    def detect_bubbles(self,
                       image_path: str,
                       confidence: float = None,
                       iou_threshold: float = None,
                       max_detections: int = None,
                       use_rtdetr: bool = None) -> List[Tuple[int, int, int, int]]:
        """
        Detect speech bubbles in an image (backward compatible method).

        Args:
            image_path: Path to image file
            confidence: Minimum confidence threshold (0-1)
            iou_threshold: IOU threshold for NMS (0-1)
            max_detections: Maximum number of detections to return
            use_rtdetr: If True, use RT-DETR instead of YOLOv8 (if available)

        Returns:
            List of bubble bounding boxes as (x, y, width, height) tuples
        """
        # Check for stop at start
        if self._check_stop():
            self._log("⏹️ Bubble detection stopped by user", "warning")
            return []

        # Decide which model to use
        if use_rtdetr is None:
            # Auto-select: prefer RT-DETR if available
            use_rtdetr = self.rtdetr_loaded

        if use_rtdetr:
            # Prefer ONNX backend if available, else PyTorch
            if getattr(self, 'rtdetr_onnx_loaded', False):
                results = self.detect_with_rtdetr_onnx(
                    image_path=image_path,
                    confidence=confidence,
                    return_all_bubbles=True
                )
                return results
            if self.rtdetr_loaded:
                results = self.detect_with_rtdetr(
                    image_path=image_path,
                    confidence=confidence,
                    return_all_bubbles=True
                )
                return results

        # Original YOLOv8 detection
        if not self.model_loaded:
            logger.error("No model loaded. Call load_model() first.")
            return []

        # Use defaults if not specified
        confidence = confidence or self.default_confidence
        iou_threshold = iou_threshold or self.default_iou_threshold
        max_detections = max_detections or self.default_max_detections

        try:
            # Load image
            image = cv2.imread(image_path)
            if image is None:
                logger.error(f"Failed to load image: {image_path}")
                return []

            h, w = image.shape[:2]
            self._log(f"🔍 Detecting bubbles in {w}x{h} image")

            # Check for stop before inference
            if self._check_stop():
                self._log("⏹️ Bubble detection inference stopped by user", "warning")
                return []

            if self.model_type == 'yolo':
                # YOLOv8 inference
                results = self.model(
                    image_path,
                    conf=confidence,
                    iou=iou_threshold,
                    max_det=min(max_detections, getattr(self, 'max_det_yolo', max_detections)),
                    verbose=False
                )

                bubbles = []
                for r in results:
                    if r.boxes is not None:
                        for box in r.boxes:
                            # Get box coordinates
                            x1, y1, x2, y2 = box.xyxy[0].cpu().numpy()
                            x, y = int(x1), int(y1)
                            width = int(x2 - x1)
                            height = int(y2 - y1)

                            # Get confidence
                            conf = float(box.conf[0])

                            # Add to list
                            if len(bubbles) < max_detections:
                                bubbles.append((x, y, width, height))

                            logger.debug(f"   Bubble: ({x},{y}) {width}x{height} conf={conf:.2f}")

            elif self.model_type == 'onnx':
                # ONNX inference
                bubbles = self._detect_with_onnx(image, confidence, iou_threshold, max_detections)

            elif self.model_type == 'torch':
                # TorchScript inference
                bubbles = self._detect_with_torchscript(image, confidence, iou_threshold, max_detections)

            else:
                logger.error(f"Unknown model type: {self.model_type}")
                return []

            logger.info(f"✅ Detected {len(bubbles)} speech bubbles")
            time.sleep(0.1)  # Brief pause for stability
            logger.debug("💤 Bubble detection pausing briefly for stability")
            return bubbles

        except Exception as e:
            logger.error(f"Detection failed: {e}")
            logger.error(traceback.format_exc())
            return []

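    # Note: detect_bubbles() returns (x, y, width, height) tuples regardless of the
    # backend used. A caller that needs corner coordinates can convert with, e.g.:
    #   x1, y1, x2, y2 = x, y, x + w, y + h
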
    def detect_with_rtdetr(self,
                           image_path: str = None,
                           image: np.ndarray = None,
                           confidence: float = None,
                           return_all_bubbles: bool = False) -> Any:
        """
        Detect using RT-DETR model with 3-class detection (PyTorch backend).

        Args:
            image_path: Path to image file
            image: Image array (BGR format)
            confidence: Confidence threshold
            return_all_bubbles: If True, return list of bubble boxes (for compatibility)
                                If False, return dict with all classes

        Returns:
            List of bubbles if return_all_bubbles=True, else dict with classes
        """
        # Check for stop at start
        if self._check_stop():
            self._log("⏹️ RT-DETR detection stopped by user", "warning")
            if return_all_bubbles:
                return []
            return {'bubbles': [], 'text_bubbles': [], 'text_free': []}

        if not self.rtdetr_loaded:
            self._log("RT-DETR not loaded. Call load_rtdetr_model() first.", "warning")
            if return_all_bubbles:
                return []
            return {'bubbles': [], 'text_bubbles': [], 'text_free': []}

        confidence = confidence or self.default_confidence

        try:
            # Load image
            if image_path:
                image = cv2.imread(image_path)
            elif image is None:
                logger.error("No image provided")
                if return_all_bubbles:
                    return []
                return {'bubbles': [], 'text_bubbles': [], 'text_free': []}

            # Convert BGR to RGB for PIL
            image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
            pil_image = Image.fromarray(image_rgb)

            # Prepare image for model
            inputs = self.rtdetr_processor(images=pil_image, return_tensors="pt")

            # Move inputs to the same device as the model and match model dtype for floating tensors
            model_device = next(self.rtdetr_model.parameters()).device if self.rtdetr_model is not None else (torch.device('cpu') if TORCH_AVAILABLE else 'cpu')
            model_dtype = None
            if TORCH_AVAILABLE and self.rtdetr_model is not None:
                try:
                    model_dtype = next(self.rtdetr_model.parameters()).dtype
                except Exception:
                    model_dtype = None

            if TORCH_AVAILABLE:
                new_inputs = {}
                for k, v in inputs.items():
                    if isinstance(v, torch.Tensor):
                        v = v.to(model_device)
                        if model_dtype is not None and torch.is_floating_point(v):
                            v = v.to(model_dtype)
                    new_inputs[k] = v
                inputs = new_inputs

            # Run inference with autocast when model is half/bfloat16 on CUDA
            use_amp = TORCH_AVAILABLE and hasattr(model_device, 'type') and model_device.type == 'cuda' and (model_dtype in (torch.float16, torch.bfloat16))
            autocast_dtype = model_dtype if model_dtype in (torch.float16, torch.bfloat16) else None

            with torch.no_grad():
                if use_amp and autocast_dtype is not None:
                    with torch.autocast('cuda', dtype=autocast_dtype):
                        outputs = self.rtdetr_model(**inputs)
                else:
                    outputs = self.rtdetr_model(**inputs)

            # Brief pause for stability after inference
            time.sleep(0.1)
            logger.debug("💤 RT-DETR inference pausing briefly for stability")

            # Post-process results
            target_sizes = torch.tensor([pil_image.size[::-1]]) if TORCH_AVAILABLE else None
            if TORCH_AVAILABLE and hasattr(model_device, 'type') and model_device.type == "cuda":
                target_sizes = target_sizes.to(model_device)

            results = self.rtdetr_processor.post_process_object_detection(
                outputs,
                target_sizes=target_sizes,
                threshold=confidence
            )[0]

            # Apply per-detector cap if configured
            cap = getattr(self, 'max_det_rtdetr', self.default_max_detections)
            if cap and len(results['boxes']) > cap:
                # Keep top-scoring first
                scores = results['scores']
                top_idx = scores.topk(k=cap).indices if hasattr(scores, 'topk') else range(cap)
                results = {
                    'boxes': [results['boxes'][i] for i in top_idx],
                    'scores': [results['scores'][i] for i in top_idx],
                    'labels': [results['labels'][i] for i in top_idx]
                }

            logger.info(f"📊 RT-DETR found {len(results['boxes'])} detections above {confidence:.2f} confidence")

            # Apply NMS to remove duplicate detections
            # Group detections by class
            class_detections = {self.CLASS_BUBBLE: [], self.CLASS_TEXT_BUBBLE: [], self.CLASS_TEXT_FREE: []}

            for box, score, label in zip(results['boxes'], results['scores'], results['labels']):
                x1, y1, x2, y2 = map(float, box.tolist())
                label_id = label.item()
                if label_id in class_detections:
                    class_detections[label_id].append((x1, y1, x2, y2, float(score.item())))

            # Apply NMS per class to remove duplicates
            def compute_iou(box1, box2):
                """Compute IoU between two boxes (x1, y1, x2, y2)"""
                x1_1, y1_1, x2_1, y2_1 = box1[:4]
                x1_2, y1_2, x2_2, y2_2 = box2[:4]

                # Intersection
                x_left = max(x1_1, x1_2)
                y_top = max(y1_1, y1_2)
                x_right = min(x2_1, x2_2)
                y_bottom = min(y2_1, y2_2)

                if x_right < x_left or y_bottom < y_top:
                    return 0.0

                intersection = (x_right - x_left) * (y_bottom - y_top)

                # Union
                area1 = (x2_1 - x1_1) * (y2_1 - y1_1)
                area2 = (x2_2 - x1_2) * (y2_2 - y1_2)
                union = area1 + area2 - intersection

                return intersection / union if union > 0 else 0.0

            def apply_nms(boxes_with_scores, iou_threshold=0.45):
                """Apply Non-Maximum Suppression"""
                if not boxes_with_scores:
                    return []

                # Sort by score (descending)
                sorted_boxes = sorted(boxes_with_scores, key=lambda x: x[4], reverse=True)
                keep = []

                while sorted_boxes:
                    # Keep the box with highest score
                    current = sorted_boxes.pop(0)
                    keep.append(current)

                    # Remove boxes with high IoU
                    sorted_boxes = [box for box in sorted_boxes if compute_iou(current, box) < iou_threshold]

                return keep

            # Apply NMS and organize by class
            detections = {
                'bubbles': [],       # Empty speech bubbles
                'text_bubbles': [],  # Bubbles with text
                'text_free': []      # Text without bubbles
            }

            for class_id, boxes_list in class_detections.items():
                nms_boxes = apply_nms(boxes_list, iou_threshold=self.default_iou_threshold)

                for x1, y1, x2, y2, scr in nms_boxes:
                    width = int(x2 - x1)
                    height = int(y2 - y1)
                    # Store as (x, y, width, height) to match YOLOv8 format
                    bbox = (int(x1), int(y1), width, height)

                    if class_id == self.CLASS_BUBBLE:
                        detections['bubbles'].append(bbox)
                    elif class_id == self.CLASS_TEXT_BUBBLE:
                        detections['text_bubbles'].append(bbox)
                    elif class_id == self.CLASS_TEXT_FREE:
                        detections['text_free'].append(bbox)

                # Stop early if we hit the configured cap across all classes
                total_count = len(detections['bubbles']) + len(detections['text_bubbles']) + len(detections['text_free'])
                if total_count >= (self.config.get('manga_settings', {}).get('ocr', {}).get('bubble_max_detections', self.default_max_detections) if isinstance(self.config, dict) else self.default_max_detections):
                    break

            # Log results
            total = len(detections['bubbles']) + len(detections['text_bubbles']) + len(detections['text_free'])
            logger.info(f"✅ RT-DETR detected {total} objects:")
            logger.info(f"   - Empty bubbles: {len(detections['bubbles'])}")
            logger.info(f"   - Text bubbles: {len(detections['text_bubbles'])}")
            logger.info(f"   - Free text: {len(detections['text_free'])}")

            # Return format based on compatibility mode
            if return_all_bubbles:
                # Return all bubbles (empty + with text) for backward compatibility
                all_bubbles = detections['bubbles'] + detections['text_bubbles']
                return all_bubbles
            else:
                return detections

        except Exception as e:
            logger.error(f"RT-DETR detection failed: {e}")
            logger.error(traceback.format_exc())
            if return_all_bubbles:
                return []
            return {'bubbles': [], 'text_bubbles': [], 'text_free': []}

    def detect_all_text_regions(self, image_path: str = None, image: np.ndarray = None) -> List[Tuple[int, int, int, int]]:
        """
        Detect all text regions using RT-DETR (both in bubbles and free text).

        Returns:
            List of bounding boxes for all text regions
        """
        if not self.rtdetr_loaded:
            logger.warning("RT-DETR required for text detection")
            return []

        detections = self.detect_with_rtdetr(image_path=image_path, image=image, return_all_bubbles=False)

        # Combine text bubbles and free text
        all_text = detections['text_bubbles'] + detections['text_free']

        logger.info(f"📝 Found {len(all_text)} text regions total")
        return all_text

    def _detect_with_onnx(self, image: np.ndarray, confidence: float,
                          iou_threshold: float, max_detections: int) -> List[Tuple[int, int, int, int]]:
        """Run detection using ONNX model."""
        # Preprocess image
        img_size = 640  # Standard YOLOv8 input size
        img_resized = cv2.resize(image, (img_size, img_size))
        img_norm = img_resized.astype(np.float32) / 255.0
        img_transposed = np.transpose(img_norm, (2, 0, 1))
        img_batch = np.expand_dims(img_transposed, axis=0)

        # Run inference
        input_name = self.onnx_session.get_inputs()[0].name
        outputs = self.onnx_session.run(None, {input_name: img_batch})

        # Process outputs (YOLOv8 format)
        predictions = outputs[0][0]  # Remove batch dimension

        # Filter by confidence and apply NMS
        bubbles = []
        boxes = []
        scores = []

        for pred in predictions.T:  # Transpose to get predictions per detection
            if len(pred) >= 5:
                x_center, y_center, width, height, obj_conf = pred[:5]

                if obj_conf >= confidence:
                    # Convert to corner coordinates
                    x1 = x_center - width / 2
                    y1 = y_center - height / 2

                    # Scale to original image size
                    h, w = image.shape[:2]
                    x1 = int(x1 * w / img_size)
                    y1 = int(y1 * h / img_size)
                    width = int(width * w / img_size)
                    height = int(height * h / img_size)

                    boxes.append([x1, y1, x1 + width, y1 + height])
                    scores.append(float(obj_conf))

        # Apply NMS
        if boxes:
            indices = cv2.dnn.NMSBoxes(boxes, scores, confidence, iou_threshold)
            if len(indices) > 0:
                indices = indices.flatten()[:max_detections]
                for i in indices:
                    x1, y1, x2, y2 = boxes[i]
                    bubbles.append((x1, y1, x2 - x1, y2 - y1))

        return bubbles

    def _detect_with_torchscript(self, image: np.ndarray, confidence: float,
                                 iou_threshold: float, max_detections: int) -> List[Tuple[int, int, int, int]]:
        """Run detection using TorchScript model."""
        # Similar to ONNX but using PyTorch tensors
        img_size = 640
        img_resized = cv2.resize(image, (img_size, img_size))
        img_norm = img_resized.astype(np.float32) / 255.0
        img_tensor = torch.from_numpy(img_norm).permute(2, 0, 1).unsqueeze(0)

        if self.use_gpu:
            img_tensor = img_tensor.cuda()

        with torch.no_grad():
            outputs = self.model(img_tensor)

        # Process outputs similar to ONNX
        # Implementation depends on exact model output format
        # This is a placeholder - adjust based on your model
        return []

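    # A minimal decode sketch for the placeholder above, assuming the TorchScript
    # export produces the same channel-first prediction layout handled in
    # _detect_with_onnx (this is an assumption, not guaranteed by every export):
    #
    #   preds = outputs[0] if isinstance(outputs, (list, tuple)) else outputs
    #   preds = preds.detach().cpu().numpy()[0]
    #   # then reuse the confidence filter and cv2.dnn.NMSBoxes post-processing
    #   # from _detect_with_onnx on `preds.T`
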
    def visualize_detections(self, image_path: str, bubbles: List[Tuple[int, int, int, int]] = None,
                             output_path: str = None, use_rtdetr: bool = False) -> np.ndarray:
        """
        Visualize detected bubbles on the image.

        Args:
            image_path: Path to original image
            bubbles: List of bubble bounding boxes (if None, will detect)
            output_path: Optional path to save visualization
            use_rtdetr: Use RT-DETR for visualization with class colors

        Returns:
            Image with drawn bounding boxes
        """
        image = cv2.imread(image_path)
        if image is None:
            logger.error(f"Failed to load image: {image_path}")
            return None

        vis_image = image.copy()

        if use_rtdetr and self.rtdetr_loaded:
            # RT-DETR visualization with different colors per class
            detections = self.detect_with_rtdetr(image_path=image_path, return_all_bubbles=False)

            # Colors for each class
            colors = {
                'bubbles': (0, 255, 0),       # Green for empty bubbles
                'text_bubbles': (255, 0, 0),  # Blue for text bubbles
                'text_free': (0, 0, 255)      # Red for free text
            }

            # Draw detections
            for class_name, bboxes in detections.items():
                color = colors[class_name]

                for i, (x, y, w, h) in enumerate(bboxes):
                    # Draw rectangle
                    cv2.rectangle(vis_image, (x, y), (x + w, y + h), color, 2)

                    # Add label
                    label = f"{class_name.replace('_', ' ').title()} {i+1}"
                    label_size, _ = cv2.getTextSize(label, cv2.FONT_HERSHEY_SIMPLEX, 0.5, 1)
                    cv2.rectangle(vis_image, (x, y - label_size[1] - 4),
                                  (x + label_size[0], y), color, -1)
                    cv2.putText(vis_image, label, (x, y - 2),
                                cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255), 1)
        else:
            # Original YOLOv8 visualization
            if bubbles is None:
                bubbles = self.detect_bubbles(image_path)

            # Draw bounding boxes
            for i, (x, y, w, h) in enumerate(bubbles):
                # Draw rectangle
                color = (0, 255, 0)  # Green
                thickness = 2
                cv2.rectangle(vis_image, (x, y), (x + w, y + h), color, thickness)

                # Add label
                label = f"Bubble {i+1}"
                label_size, _ = cv2.getTextSize(label, cv2.FONT_HERSHEY_SIMPLEX, 0.5, 1)
                cv2.rectangle(vis_image, (x, y - label_size[1] - 4), (x + label_size[0], y), color, -1)
                cv2.putText(vis_image, label, (x, y - 2), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255), 1)

        # Save if output path provided
        if output_path:
            cv2.imwrite(output_path, vis_image)
            logger.info(f"💾 Visualization saved to: {output_path}")

        return vis_image

    def convert_to_onnx(self, model_path: str, output_path: str = None) -> bool:
        """
        Convert a YOLOv8 or RT-DETR model to ONNX format.

        Args:
            model_path: Path to model file or 'rtdetr' for loaded RT-DETR
            output_path: Path for ONNX output (auto-generated if None)

        Returns:
            True if conversion successful, False otherwise
        """
        try:
            logger.info(f"🔄 Converting {model_path} to ONNX...")

            # Generate output path if not provided
            if output_path is None:
                if model_path == 'rtdetr' and self.rtdetr_loaded:
                    base_name = 'rtdetr_comic'
                else:
                    base_name = Path(model_path).stem
                output_path = os.path.join(self.cache_dir, f"{base_name}.onnx")

            # Check if already exists
            if os.path.exists(output_path) and not os.environ.get('FORCE_ONNX_REBUILD', 'false').lower() == 'true':
                logger.info(f"✅ ONNX model already exists: {output_path}")
                return True

            # Handle RT-DETR conversion
            if model_path == 'rtdetr' and self.rtdetr_loaded:
                if not TORCH_AVAILABLE:
                    logger.error("PyTorch required for RT-DETR ONNX conversion")
                    return False

                # RT-DETR specific conversion
                self.rtdetr_model.eval()

                # Create dummy input (pixel values): BxCxHxW
                dummy_input = torch.randn(1, 3, 640, 640)
                if self.device == 'cuda':
                    dummy_input = dummy_input.to('cuda')

                # Wrap the model to return only tensors (logits, pred_boxes)
                class _RTDetrExportWrapper(torch.nn.Module):
                    def __init__(self, mdl):
                        super().__init__()
                        self.mdl = mdl

                    def forward(self, images):
                        out = self.mdl(pixel_values=images)
                        # Handle dict/ModelOutput/tuple outputs
                        logits = None
                        boxes = None
                        try:
                            if isinstance(out, dict):
                                logits = out.get('logits', None)
                                boxes = out.get('pred_boxes', out.get('boxes', None))
                            else:
                                logits = getattr(out, 'logits', None)
                                boxes = getattr(out, 'pred_boxes', getattr(out, 'boxes', None))
                        except Exception:
                            pass
                        if (logits is None or boxes is None) and isinstance(out, (tuple, list)) and len(out) >= 2:
                            logits, boxes = out[0], out[1]
                        return logits, boxes

                wrapper = _RTDetrExportWrapper(self.rtdetr_model)
                if self.device == 'cuda':
                    wrapper = wrapper.to('cuda')

                # Try PyTorch 2.x dynamo_export first (more tolerant of newer aten ops)
                try:
                    success = False
                    try:
                        from torch.onnx import dynamo_export
                        try:
                            exp = dynamo_export(wrapper, dummy_input)
                        except TypeError:
                            # Older PyTorch dynamo_export may not support this calling convention
                            exp = dynamo_export(wrapper, dummy_input)
                        # exp may have save(); otherwise, it may expose model_proto
                        try:
                            exp.save(output_path)  # type: ignore
                            success = True
                        except Exception:
                            try:
                                import onnx as _onnx
                                _onnx.save(exp.model_proto, output_path)  # type: ignore
                                success = True
                            except Exception as _se:
                                logger.warning(f"dynamo_export produced model but could not save: {_se}")
                    except Exception as de:
                        logger.warning(f"dynamo_export failed; falling back to legacy exporter: {de}")
                    if success:
                        logger.info(f"✅ RT-DETR ONNX saved to: {output_path} (dynamo_export)")
                        return True
                except Exception as de2:
                    logger.warning(f"dynamo_export path error: {de2}")

                # Legacy exporter with opset fallback
                last_err = None
                for opset in [19, 18, 17, 16, 15, 14, 13]:
                    try:
                        torch.onnx.export(
                            wrapper,
                            dummy_input,
                            output_path,
                            export_params=True,
                            opset_version=opset,
                            do_constant_folding=True,
                            input_names=['pixel_values'],
                            output_names=['logits', 'boxes'],
                            dynamic_axes={
                                'pixel_values': {0: 'batch', 2: 'height', 3: 'width'},
                                'logits': {0: 'batch'},
                                'boxes': {0: 'batch'}
                            }
                        )
                        logger.info(f"✅ RT-DETR ONNX saved to: {output_path} (opset {opset})")
                        return True
                    except Exception as _e:
                        last_err = _e
                        try:
                            msg = str(_e)
                        except Exception:
                            msg = ''
                        logger.warning(f"RT-DETR ONNX export failed at opset {opset}: {msg}")
                        continue

                logger.error(f"All RT-DETR ONNX export attempts failed. Last error: {last_err}")
                return False

            # Handle YOLOv8 conversion - FIXED
            elif YOLO_AVAILABLE and os.path.exists(model_path):
                logger.info(f"Loading YOLOv8 model from: {model_path}")

                # Load model
                model = YOLO(model_path)

                # Export to ONNX - this returns the path to the exported model
                logger.info("Exporting to ONNX format...")
                exported_path = model.export(format='onnx', imgsz=640, simplify=True)

                # exported_path could be a string or Path object
                exported_path = str(exported_path) if exported_path else None

                if exported_path and os.path.exists(exported_path):
                    # Move to desired location if different
                    if exported_path != output_path:
                        import shutil
                        logger.info(f"Moving ONNX from {exported_path} to {output_path}")
                        shutil.move(exported_path, output_path)

                    logger.info(f"✅ YOLOv8 ONNX saved to: {output_path}")
                    return True
                else:
                    # Fallback: check if it was created with expected name
                    expected_onnx = model_path.replace('.pt', '.onnx')
                    if os.path.exists(expected_onnx):
                        if expected_onnx != output_path:
                            import shutil
                            shutil.move(expected_onnx, output_path)
                        logger.info(f"✅ YOLOv8 ONNX saved to: {output_path}")
                        return True
                    else:
                        logger.error(f"ONNX export failed - no output file found")
                        return False

            else:
                logger.error(f"Cannot convert {model_path}: Model not found or dependencies missing")
                return False

        except Exception as e:
            logger.error(f"Conversion failed: {e}")
            # Avoid noisy full stack trace in production logs; return False gracefully
            return False

1420
+ def batch_detect(self, image_paths: List[str], **kwargs) -> Dict[str, List[Tuple[int, int, int, int]]]:
1421
+ """
1422
+ Detect bubbles in multiple images.
1423
+
1424
+ Args:
1425
+ image_paths: List of image paths
1426
+ **kwargs: Detection parameters (confidence, iou_threshold, max_detections, use_rtdetr)
1427
+
1428
+ Returns:
1429
+ Dictionary mapping image paths to bubble lists
1430
+ """
1431
+ results = {}
1432
+
1433
+ for i, image_path in enumerate(image_paths):
1434
+ logger.info(f"Processing image {i+1}/{len(image_paths)}: {os.path.basename(image_path)}")
1435
+ bubbles = self.detect_bubbles(image_path, **kwargs)
1436
+ results[image_path] = bubbles
1437
+
1438
+ return results
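+ # Usage sketch (hypothetical paths; assumes a YOLOv8 model has already been loaded):
+ #   detector = BubbleDetector()
+ #   detector.load_model('models/comic-speech-bubble-detector-yolov8m.pt')
+ #   per_image = detector.batch_detect(['page_001.png', 'page_002.png'], confidence=0.5)
+ #   for path, bubbles in per_image.items():
+ #       print(path, len(bubbles), 'bubbles')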
1439
+
1440
+ def unload(self, release_shared: bool = False):
1441
+ """Release model resources held by this detector instance.
1442
+ Args:
1443
+ release_shared: If True, also clear class-level shared RT-DETR caches.
1444
+ """
1445
+ try:
1446
+ # Release instance-level models and sessions
1447
+ try:
1448
+ if getattr(self, 'onnx_session', None) is not None:
1449
+ self.onnx_session = None
1450
+ except Exception:
1451
+ pass
1452
+ try:
1453
+ if getattr(self, 'rtdetr_onnx_session', None) is not None:
1454
+ self.rtdetr_onnx_session = None
1455
+ except Exception:
1456
+ pass
1457
+ for attr in ['model', 'rtdetr_model', 'rtdetr_processor']:
1458
+ try:
1459
+ if hasattr(self, attr):
1460
+ setattr(self, attr, None)
1461
+ except Exception:
1462
+ pass
1463
+ for flag in ['model_loaded', 'rtdetr_loaded', 'rtdetr_onnx_loaded']:
1464
+ try:
1465
+ if hasattr(self, flag):
1466
+ setattr(self, flag, False)
1467
+ except Exception:
1468
+ pass
1469
+
1470
+ # Optional: release shared caches
1471
+ if release_shared:
1472
+ try:
1473
+ BubbleDetector._rtdetr_shared_model = None
1474
+ BubbleDetector._rtdetr_shared_processor = None
1475
+ BubbleDetector._rtdetr_loaded = False
1476
+ except Exception:
1477
+ pass
1478
+
1479
+ # Free CUDA cache and trigger GC
1480
+ try:
1481
+ if TORCH_AVAILABLE and torch is not None and torch.cuda.is_available():
1482
+ torch.cuda.empty_cache()
1483
+ except Exception:
1484
+ pass
1485
+ try:
1486
+ import gc
1487
+ gc.collect()
1488
+ except Exception:
1489
+ pass
1490
+ except Exception:
1491
+ # Best-effort only
1492
+ pass
1493
+
1494
+ def get_bubble_masks(self, image_path: str, bubbles: List[Tuple[int, int, int, int]]) -> np.ndarray:
1495
+ """
1496
+ Create a mask image with bubble regions.
1497
+
1498
+ Args:
1499
+ image_path: Path to original image
1500
+ bubbles: List of bubble bounding boxes
1501
+
1502
+ Returns:
1503
+ Binary mask with bubble regions as white (255)
1504
+ """
1505
+ image = cv2.imread(image_path)
1506
+ if image is None:
1507
+ return None
1508
+
1509
+ h, w = image.shape[:2]
1510
+ mask = np.zeros((h, w), dtype=np.uint8)
1511
+
1512
+ # Fill bubble regions
1513
+ for x, y, bw, bh in bubbles:
1514
+ cv2.rectangle(mask, (x, y), (x + bw, y + bh), 255, -1)
1515
+
1516
+ return mask
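+ # Usage sketch (hypothetical): the mask pairs naturally with OpenCV inpainting, e.g.
+ #   mask = detector.get_bubble_masks('page.png', bubbles)
+ #   cleaned = cv2.inpaint(cv2.imread('page.png'), mask, 3, cv2.INPAINT_TELEA)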
1517
+
1518
+ def filter_bubbles_by_size(self, bubbles: List[Tuple[int, int, int, int]],
1519
+ min_area: int = 100,
1520
+ max_area: int = None) -> List[Tuple[int, int, int, int]]:
1521
+ """
1522
+ Filter bubbles by area.
1523
+
1524
+ Args:
1525
+ bubbles: List of bubble bounding boxes
1526
+ min_area: Minimum area in pixels
1527
+ max_area: Maximum area in pixels (None for no limit)
1528
+
1529
+ Returns:
1530
+ Filtered list of bubbles
1531
+ """
1532
+ filtered = []
1533
+
1534
+ for x, y, w, h in bubbles:
1535
+ area = w * h
1536
+ if area >= min_area and (max_area is None or area <= max_area):
1537
+ filtered.append((x, y, w, h))
1538
+
1539
+ return filtered
1540
+
1541
+ def merge_overlapping_bubbles(self, bubbles: List[Tuple[int, int, int, int]],
1542
+ overlap_threshold: float = 0.1) -> List[Tuple[int, int, int, int]]:
1543
+ """
1544
+ Merge overlapping bubble detections.
1545
+
1546
+ Args:
1547
+ bubbles: List of bubble bounding boxes
1548
+ overlap_threshold: Minimum overlap ratio to merge
1549
+
1550
+ Returns:
1551
+ Merged list of bubbles
1552
+ """
1553
+ if not bubbles:
1554
+ return []
1555
+
1556
+ # Convert to numpy array for easier manipulation
1557
+ boxes = np.array([(x, y, x+w, y+h) for x, y, w, h in bubbles])
1558
+
1559
+ merged = []
1560
+ used = set()
1561
+
1562
+ for i, box1 in enumerate(boxes):
1563
+ if i in used:
1564
+ continue
1565
+
1566
+ # Start with current box
1567
+ x1, y1, x2, y2 = box1
1568
+
1569
+ # Check for overlaps with remaining boxes
1570
+ for j in range(i + 1, len(boxes)):
1571
+ if j in used:
1572
+ continue
1573
+
1574
+ box2 = boxes[j]
1575
+
1576
+ # Calculate intersection
1577
+ ix1 = max(x1, box2[0])
1578
+ iy1 = max(y1, box2[1])
1579
+ ix2 = min(x2, box2[2])
1580
+ iy2 = min(y2, box2[3])
1581
+
1582
+ if ix1 < ix2 and iy1 < iy2:
1583
+ # Calculate overlap ratio
1584
+ intersection = (ix2 - ix1) * (iy2 - iy1)
1585
+ area1 = (x2 - x1) * (y2 - y1)
1586
+ area2 = (box2[2] - box2[0]) * (box2[3] - box2[1])
1587
+ overlap = intersection / min(area1, area2)
1588
+
1589
+ if overlap >= overlap_threshold:
1590
+ # Merge boxes
1591
+ x1 = min(x1, box2[0])
1592
+ y1 = min(y1, box2[1])
1593
+ x2 = max(x2, box2[2])
1594
+ y2 = max(y2, box2[3])
1595
+ used.add(j)
1596
+
1597
+ merged.append((int(x1), int(y1), int(x2 - x1), int(y2 - y1)))
1598
+
1599
+ return merged
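+ # Worked example: (10, 10, 100, 50) and (50, 20, 100, 50) overlap by 48% of the smaller box,
+ # which exceeds the 0.1 threshold, so they merge into a single (10, 10, 140, 60) box.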
1600
+
1601
+ # ============================
1602
+ # RT-DETR (ONNX) BACKEND
1603
+ # ============================
1604
+ def load_rtdetr_onnx_model(self, model_id: str = None, force_reload: bool = False) -> bool:
1605
+ """
1606
+ Load RT-DETR ONNX model using onnxruntime. Downloads detector.onnx and config.json
1607
+ from the provided Hugging Face repo if not already cached.
1608
+ """
1609
+ if not ONNX_AVAILABLE:
1610
+ logger.error("ONNX Runtime not available for RT-DETR ONNX backend")
1611
+ return False
1612
+ try:
1613
+ # If singleton mode and already loaded, just attach shared session
1614
+ try:
1615
+ adv = (self.config or {}).get('manga_settings', {}).get('advanced', {}) if isinstance(self.config, dict) else {}
1616
+ singleton = bool(adv.get('use_singleton_models', True))
1617
+ except Exception:
1618
+ singleton = True
1619
+ if singleton and BubbleDetector._rtdetr_onnx_loaded and not force_reload and BubbleDetector._rtdetr_onnx_shared_session is not None:
1620
+ self.rtdetr_onnx_session = BubbleDetector._rtdetr_onnx_shared_session
1621
+ self.rtdetr_onnx_loaded = True
1622
+ return True
1623
+
1624
+ repo = model_id or self.rtdetr_onnx_repo
1625
+ try:
1626
+ from huggingface_hub import hf_hub_download
1627
+ except Exception as e:
1628
+ logger.error(f"huggingface-hub required to fetch RT-DETR ONNX: {e}")
1629
+ return False
1630
+
1631
+ # Ensure local models dir (use configured cache_dir directly: e.g., 'models')
1632
+ cache_dir = self.cache_dir
1633
+ os.makedirs(cache_dir, exist_ok=True)
1634
+
1635
+ # Download files into models/ and avoid symlinks so the file is visible there
1636
+ try:
1637
+ _ = hf_hub_download(repo_id=repo, filename='config.json', cache_dir=cache_dir, local_dir=cache_dir, local_dir_use_symlinks=False)
1638
+ except Exception:
1639
+ pass
1640
+ onnx_fp = hf_hub_download(repo_id=repo, filename='detector.onnx', cache_dir=cache_dir, local_dir=cache_dir, local_dir_use_symlinks=False)
1641
+ BubbleDetector._rtdetr_onnx_model_path = onnx_fp
1642
+
1643
+ # Pick providers: prefer CUDA if available; otherwise CPU. Do NOT use DML.
1644
+ providers = ['CPUExecutionProvider']
1645
+ try:
1646
+ avail = ort.get_available_providers() if ONNX_AVAILABLE else []
1647
+ if 'CUDAExecutionProvider' in avail:
1648
+ providers = ['CUDAExecutionProvider', 'CPUExecutionProvider']
1649
+ except Exception:
1650
+ pass
1651
+
1652
+ # Session options with reduced memory arena and optional thread limiting in singleton mode
1653
+ so = ort.SessionOptions()
1654
+ try:
1655
+ so.enable_mem_pattern = False
1656
+ so.enable_cpu_mem_arena = False
1657
+ except Exception:
1658
+ pass
1659
+ # If singleton models mode is enabled in config, limit ORT threading to reduce CPU spikes
1660
+ try:
1661
+ adv = (self.config or {}).get('manga_settings', {}).get('advanced', {}) if isinstance(self.config, dict) else {}
1662
+ if bool(adv.get('use_singleton_models', True)):
1663
+ so.intra_op_num_threads = 1
1664
+ so.inter_op_num_threads = 1
1665
+ try:
1666
+ so.execution_mode = ort.ExecutionMode.ORT_SEQUENTIAL
1667
+ except Exception:
1668
+ pass
1669
+ try:
1670
+ so.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_BASIC
1671
+ except Exception:
1672
+ pass
1673
+ except Exception:
1674
+ pass
1675
+
1676
+ # Create session (serialize creation in singleton mode to avoid device storms)
1677
+ if singleton:
1678
+ with BubbleDetector._rtdetr_onnx_init_lock:
1679
+ # Re-check after acquiring lock
1680
+ if BubbleDetector._rtdetr_onnx_loaded and BubbleDetector._rtdetr_onnx_shared_session is not None and not force_reload:
1681
+ self.rtdetr_onnx_session = BubbleDetector._rtdetr_onnx_shared_session
1682
+ self.rtdetr_onnx_loaded = True
1683
+ return True
1684
+ sess = ort.InferenceSession(onnx_fp, providers=providers, sess_options=so)
1685
+ BubbleDetector._rtdetr_onnx_shared_session = sess
1686
+ BubbleDetector._rtdetr_onnx_loaded = True
1687
+ BubbleDetector._rtdetr_onnx_providers = providers
1688
+ self.rtdetr_onnx_session = sess
1689
+ self.rtdetr_onnx_loaded = True
1690
+ else:
1691
+ self.rtdetr_onnx_session = ort.InferenceSession(onnx_fp, providers=providers, sess_options=so)
1692
+ self.rtdetr_onnx_loaded = True
1693
+ logger.info("✅ RT-DETR (ONNX) model ready")
1694
+ return True
1695
+ except Exception as e:
1696
+ logger.error(f"Failed to load RT-DETR ONNX: {e}")
1697
+ self.rtdetr_onnx_session = None
1698
+ self.rtdetr_onnx_loaded = False
1699
+ return False
1700
+
1701
+ def detect_with_rtdetr_onnx(self,
1702
+ image_path: str = None,
1703
+ image: np.ndarray = None,
1704
+ confidence: float = 0.3,
1705
+ return_all_bubbles: bool = False) -> Any:
1706
+ """Detect using RT-DETR ONNX backend.
1707
+ Returns bubbles list if return_all_bubbles else dict by classes similar to PyTorch path.
1708
+ """
1709
+ if not self.rtdetr_onnx_loaded or self.rtdetr_onnx_session is None:
1710
+ logger.warning("RT-DETR ONNX not loaded")
1711
+ return [] if return_all_bubbles else {'bubbles': [], 'text_bubbles': [], 'text_free': []}
1712
+ try:
1713
+ # Acquire image
1714
+ if image_path is not None:
1715
+ import cv2
1716
+ image = cv2.imread(image_path)
1717
+ if image is None:
1718
+ raise RuntimeError(f"Failed to read image: {image_path}")
1719
+ image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
1720
+ else:
1721
+ if image is None:
1722
+ raise RuntimeError("No image provided")
1723
+ # Assume image is BGR np.ndarray if from OpenCV
1724
+ try:
1725
+ import cv2
1726
+ image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
1727
+ except Exception:
1728
+ image_rgb = image
1729
+
1730
+ # To PIL then resize 640x640 as in reference
1731
+ from PIL import Image as _PILImage
1732
+ pil_image = _PILImage.fromarray(image_rgb)
1733
+ im_resized = pil_image.resize((640, 640))
1734
+ arr = np.asarray(im_resized, dtype=np.float32) / 255.0
1735
+ arr = np.transpose(arr, (2, 0, 1)) # (3,H,W)
1736
+ im_data = arr[np.newaxis, ...]
1737
+
1738
+ w, h = pil_image.size
1739
+ orig_size = np.array([[w, h]], dtype=np.int64)
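+ # Shape sketch: a 1000x1500 (w x h) page becomes im_data of shape (1, 3, 640, 640) scaled to [0, 1],
+ # while orig_size stays [[1000, 1500]] so the exported model can return boxes in original-image coordinates.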
1740
+
1741
+ # Run with a concurrency guard to prevent device hangs and limit memory usage
1742
+ # Apply semaphore for ALL providers (not just DML) to control concurrency
1743
+ providers = BubbleDetector._rtdetr_onnx_providers or []
1744
+ def _do_run(session):
1745
+ return session.run(None, {
1746
+ 'images': im_data,
1747
+ 'orig_target_sizes': orig_size
1748
+ })
1749
+
1750
+ # Always use semaphore to limit concurrent RT-DETR calls
1751
+ acquired = False
1752
+ try:
1753
+ BubbleDetector._rtdetr_onnx_sema.acquire()
1754
+ acquired = True
1755
+
1756
+ # Special DML error handling
1757
+ if 'DmlExecutionProvider' in providers:
1758
+ try:
1759
+ outputs = _do_run(self.rtdetr_onnx_session)
1760
+ except Exception as dml_err:
1761
+ msg = str(dml_err)
1762
+ if '887A0005' in msg or '887A0006' in msg or 'Dml' in msg:
1763
+ # Rebuild CPU session and retry once
1764
+ try:
1765
+ base_path = BubbleDetector._rtdetr_onnx_model_path
1766
+ if base_path:
1767
+ so = ort.SessionOptions()
1768
+ so.enable_mem_pattern = False
1769
+ so.enable_cpu_mem_arena = False
1770
+ cpu_providers = ['CPUExecutionProvider']
1771
+ # Serialize rebuild
1772
+ with BubbleDetector._rtdetr_onnx_init_lock:
1773
+ sess = ort.InferenceSession(base_path, providers=cpu_providers, sess_options=so)
1774
+ BubbleDetector._rtdetr_onnx_shared_session = sess
1775
+ BubbleDetector._rtdetr_onnx_providers = cpu_providers
1776
+ self.rtdetr_onnx_session = sess
1777
+ outputs = _do_run(self.rtdetr_onnx_session)
1778
+ else:
1779
+ raise
1780
+ except Exception:
1781
+ raise
1782
+ else:
1783
+ raise
1784
+ else:
1785
+ # Non-DML providers - just run directly
1786
+ outputs = _do_run(self.rtdetr_onnx_session)
1787
+ finally:
1788
+ if acquired:
1789
+ try:
1790
+ BubbleDetector._rtdetr_onnx_sema.release()
1791
+ except Exception:
1792
+ pass
1793
+
1794
+ # outputs expected: labels, boxes, scores
1795
+ labels, boxes, scores = outputs[:3]
1796
+ if labels.ndim == 2 and labels.shape[0] == 1:
1797
+ labels = labels[0]
1798
+ if scores.ndim == 2 and scores.shape[0] == 1:
1799
+ scores = scores[0]
1800
+ if boxes.ndim == 3 and boxes.shape[0] == 1:
1801
+ boxes = boxes[0]
1802
+
1803
+ # Apply NMS to remove duplicate detections
1804
+ # Group detections by class and apply NMS per class
1805
+ class_detections = {self.CLASS_BUBBLE: [], self.CLASS_TEXT_BUBBLE: [], self.CLASS_TEXT_FREE: []}
1806
+
1807
+ for lab, box, scr in zip(labels, boxes, scores):
1808
+ if float(scr) < float(confidence):
1809
+ continue
1810
+ label_id = int(lab)
1811
+ if label_id in class_detections:
1812
+ x1, y1, x2, y2 = map(float, box)
1813
+ class_detections[label_id].append((x1, y1, x2, y2, float(scr)))
1814
+
1815
+ # Apply NMS per class to remove duplicates
1816
+ def compute_iou(box1, box2):
1817
+ """Compute IoU between two boxes (x1, y1, x2, y2)"""
1818
+ x1_1, y1_1, x2_1, y2_1 = box1[:4]
1819
+ x1_2, y1_2, x2_2, y2_2 = box2[:4]
1820
+
1821
+ # Intersection
1822
+ x_left = max(x1_1, x1_2)
1823
+ y_top = max(y1_1, y1_2)
1824
+ x_right = min(x2_1, x2_2)
1825
+ y_bottom = min(y2_1, y2_2)
1826
+
1827
+ if x_right < x_left or y_bottom < y_top:
1828
+ return 0.0
1829
+
1830
+ intersection = (x_right - x_left) * (y_bottom - y_top)
1831
+
1832
+ # Union
1833
+ area1 = (x2_1 - x1_1) * (y2_1 - y1_1)
1834
+ area2 = (x2_2 - x1_2) * (y2_2 - y1_2)
1835
+ union = area1 + area2 - intersection
1836
+
1837
+ return intersection / union if union > 0 else 0.0
1838
+
1839
+ def apply_nms(boxes_with_scores, iou_threshold=0.45):
1840
+ """Apply Non-Maximum Suppression"""
1841
+ if not boxes_with_scores:
1842
+ return []
1843
+
1844
+ # Sort by score (descending)
1845
+ sorted_boxes = sorted(boxes_with_scores, key=lambda x: x[4], reverse=True)
1846
+ keep = []
1847
+
1848
+ while sorted_boxes:
1849
+ # Keep the box with highest score
1850
+ current = sorted_boxes.pop(0)
1851
+ keep.append(current)
1852
+
1853
+ # Remove boxes with high IoU
1854
+ sorted_boxes = [box for box in sorted_boxes if compute_iou(current, box) < iou_threshold]
1855
+
1856
+ return keep
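+ # Worked example (illustrative): boxes (0, 0, 10, 10, 0.9) and (1, 1, 11, 11, 0.8) have
+ # IoU = 81/119 ≈ 0.68 > 0.45, so apply_nms keeps only the higher-scoring first box.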
1857
+
1858
+ # Apply NMS and build final detections
1859
+ detections = {'bubbles': [], 'text_bubbles': [], 'text_free': []}
1860
+ bubbles_all = []
1861
+
1862
+ for class_id, boxes_list in class_detections.items():
1863
+ nms_boxes = apply_nms(boxes_list, iou_threshold=self.default_iou_threshold)
1864
+
1865
+ for x1, y1, x2, y2, scr in nms_boxes:
1866
+ bbox = (int(x1), int(y1), int(x2 - x1), int(y2 - y1))
1867
+
1868
+ if class_id == self.CLASS_BUBBLE:
1869
+ detections['bubbles'].append(bbox)
1870
+ bubbles_all.append(bbox)
1871
+ elif class_id == self.CLASS_TEXT_BUBBLE:
1872
+ detections['text_bubbles'].append(bbox)
1873
+ bubbles_all.append(bbox)
1874
+ elif class_id == self.CLASS_TEXT_FREE:
1875
+ detections['text_free'].append(bbox)
1876
+
1877
+ return bubbles_all if return_all_bubbles else detections
1878
+ except Exception as e:
1879
+ logger.error(f"RT-DETR ONNX detection failed: {e}")
1880
+ return [] if return_all_bubbles else {'bubbles': [], 'text_bubbles': [], 'text_free': []}
1881
+
1882
+
1883
+ # Standalone utility functions
1884
+ def download_model_from_huggingface(repo_id: str = "ogkalu/comic-speech-bubble-detector-yolov8m",
1885
+ filename: str = "comic-speech-bubble-detector-yolov8m.pt",
1886
+ cache_dir: str = "models") -> str:
1887
+ """
1888
+ Download model from Hugging Face Hub.
1889
+
1890
+ Args:
1891
+ repo_id: Hugging Face repository ID
1892
+ filename: Model filename in the repository
1893
+ cache_dir: Local directory to cache the model
1894
+
1895
+ Returns:
1896
+ Path to downloaded model file
1897
+ """
1898
+ try:
1899
+ from huggingface_hub import hf_hub_download
1900
+
1901
+ os.makedirs(cache_dir, exist_ok=True)
1902
+
1903
+ logger.info(f"📥 Downloading {filename} from {repo_id}...")
1904
+
1905
+ model_path = hf_hub_download(
1906
+ repo_id=repo_id,
1907
+ filename=filename,
1908
+ cache_dir=cache_dir,
1909
+ local_dir=cache_dir
1910
+ )
1911
+
1912
+ logger.info(f"✅ Model downloaded to: {model_path}")
1913
+ return model_path
1914
+
1915
+ except ImportError:
1916
+ logger.error("huggingface-hub package required. Install with: pip install huggingface-hub")
1917
+ return None
1918
+ except Exception as e:
1919
+ logger.error(f"Download failed: {e}")
1920
+ return None
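+ # Usage sketch (defaults come from the signature above):
+ #   path = download_model_from_huggingface(cache_dir='models')
+ #   if path:
+ #       detector = BubbleDetector()
+ #       detector.load_model(path)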
1921
+
1922
+
1923
+ def download_rtdetr_model(cache_dir: str = "models") -> bool:
1924
+ """
1925
+ Download RT-DETR model for advanced detection.
1926
+
1927
+ Args:
1928
+ cache_dir: Directory to cache the model
1929
+
1930
+ Returns:
1931
+ True if successful
1932
+ """
1933
+ if not TRANSFORMERS_AVAILABLE:
1934
+ logger.error("Transformers required. Install with: pip install transformers")
1935
+ return False
1936
+
1937
+ try:
1938
+ logger.info("📥 Downloading RT-DETR model...")
1939
+ from transformers import RTDetrForObjectDetection, RTDetrImageProcessor
1940
+
1941
+ # This will download and cache the model
1942
+ processor = RTDetrImageProcessor.from_pretrained(
1943
+ "ogkalu/comic-text-and-bubble-detector",
1944
+ cache_dir=cache_dir
1945
+ )
1946
+ model = RTDetrForObjectDetection.from_pretrained(
1947
+ "ogkalu/comic-text-and-bubble-detector",
1948
+ cache_dir=cache_dir
1949
+ )
1950
+
1951
+ logger.info("✅ RT-DETR model downloaded successfully")
1952
+ return True
1953
+
1954
+ except Exception as e:
1955
+ logger.error(f"Download failed: {e}")
1956
+ return False
1957
+
1958
+
1959
+ # Example usage and testing
1960
+ if __name__ == "__main__":
1961
+ import sys
1962
+
1963
+ # Create detector
1964
+ detector = BubbleDetector()
1965
+
1966
+ if len(sys.argv) > 1:
1967
+ if sys.argv[1] == "download":
1968
+ # Download model from Hugging Face
1969
+ model_path = download_model_from_huggingface()
1970
+ if model_path:
1971
+ print(f"YOLOv8 model downloaded to: {model_path}")
1972
+
1973
+ # Also download RT-DETR
1974
+ if download_rtdetr_model():
1975
+ print("RT-DETR model downloaded")
1976
+
1977
+ elif sys.argv[1] == "detect" and len(sys.argv) > 3:
1978
+ # Detect bubbles in an image
1979
+ model_path = sys.argv[2]
1980
+ image_path = sys.argv[3]
1981
+
1982
+ # Load appropriate model
1983
+ if 'rtdetr' in model_path.lower():
1984
+ if detector.load_rtdetr_model():
1985
+ # Use RT-DETR
1986
+ results = detector.detect_with_rtdetr(image_path)
1987
+ print(f"RT-DETR Detection:")
1988
+ print(f" Empty bubbles: {len(results['bubbles'])}")
1989
+ print(f" Text bubbles: {len(results['text_bubbles'])}")
1990
+ print(f" Free text: {len(results['text_free'])}")
1991
+ else:
1992
+ if detector.load_model(model_path):
1993
+ bubbles = detector.detect_bubbles(image_path, confidence=0.5)
1994
+ print(f"YOLOv8 detected {len(bubbles)} bubbles:")
1995
+ for i, (x, y, w, h) in enumerate(bubbles):
1996
+ print(f" Bubble {i+1}: position=({x},{y}) size=({w}x{h})")
1997
+
1998
+ # Optionally visualize
1999
+ if len(sys.argv) > 4:
2000
+ output_path = sys.argv[4]
2001
+ detector.visualize_detections(image_path, output_path=output_path,
2002
+ use_rtdetr='rtdetr' in model_path.lower())
2003
+
2004
+ elif sys.argv[1] == "test-both" and len(sys.argv) > 2:
2005
+ # Test both models
2006
+ image_path = sys.argv[2]
2007
+
2008
+ # Load YOLOv8
2009
+ yolo_path = "models/comic-speech-bubble-detector-yolov8m.pt"
2010
+ if os.path.exists(yolo_path):
2011
+ detector.load_model(yolo_path)
2012
+ yolo_bubbles = detector.detect_bubbles(image_path, use_rtdetr=False)
2013
+ print(f"YOLOv8: {len(yolo_bubbles)} bubbles")
2014
+
2015
+ # Load RT-DETR
2016
+ if detector.load_rtdetr_model():
2017
+ rtdetr_bubbles = detector.detect_bubbles(image_path, use_rtdetr=True)
2018
+ print(f"RT-DETR: {len(rtdetr_bubbles)} bubbles")
2019
+
2020
+ else:
2021
+ print("Usage:")
2022
+ print(" python bubble_detector.py download")
2023
+ print(" python bubble_detector.py detect <model_path> <image_path> [output_path]")
2024
+ print(" python bubble_detector.py test-both <image_path>")
2025
+
2026
+ else:
2027
+ print("Bubble Detector Module (YOLOv8 + RT-DETR)")
2028
+ print("Usage:")
2029
+ print(" python bubble_detector.py download")
2030
+ print(" python bubble_detector.py detect <model_path> <image_path> [output_path]")
2031
+ print(" python bubble_detector.py test-both <image_path>")
local_inpainter.py ADDED
The diff for this file is too large to render. See raw diff
 
manga_settings_dialog.py ADDED
The diff for this file is too large to render. See raw diff
 
manga_translator.py ADDED
The diff for this file is too large to render. See raw diff
 
ocr_manager.py ADDED
@@ -0,0 +1,1970 @@
1
+ # ocr_manager.py
2
+ """
3
+ OCR Manager for handling multiple OCR providers
4
+ Handles installation, model downloading, and OCR processing
5
+ Updated with HuggingFace donut model and proper bubble detection integration
6
+ """
7
+ import os
8
+ import sys
9
+ import cv2
10
+ import json
11
+ import subprocess
12
+ import threading
13
+ import traceback
14
+ from typing import List, Dict, Optional, Tuple, Any
15
+ import numpy as np
16
+ from dataclasses import dataclass
17
+ from PIL import Image
18
+ import logging
19
+ import time
20
+ import random
21
+ import base64
22
+ import io
23
+ import requests
24
+
25
+ try:
26
+ import gptqmodel
27
+ HAS_GPTQ = True
28
+ except ImportError:
29
+ try:
30
+ import auto_gptq
31
+ HAS_GPTQ = True
32
+ except ImportError:
33
+ HAS_GPTQ = False
34
+
35
+ try:
36
+ import optimum
37
+ HAS_OPTIMUM = True
38
+ except ImportError:
39
+ HAS_OPTIMUM = False
40
+
41
+ try:
42
+ import accelerate
43
+ HAS_ACCELERATE = True
44
+ except ImportError:
45
+ HAS_ACCELERATE = False
46
+
47
+ logger = logging.getLogger(__name__)
48
+
49
+ @dataclass
50
+ class OCRResult:
51
+ """Unified OCR result format with built-in sanitization to prevent data corruption."""
52
+ text: str
53
+ bbox: Tuple[int, int, int, int] # x, y, w, h
54
+ confidence: float
55
+ vertices: Optional[List[Tuple[int, int]]] = None
56
+
57
+ def __post_init__(self):
58
+ """
59
+ This special method is called automatically after the object is created.
60
+ It acts as a final safeguard to ensure the 'text' attribute is ALWAYS a clean string.
61
+ """
62
+ # --- THIS IS THE DEFINITIVE FIX ---
63
+ # If the text we received is a tuple, we extract the first element.
64
+ # This makes it impossible for a tuple to exist in a finished object.
65
+ if isinstance(self.text, tuple):
66
+ # Log that we are fixing a critical data error.
67
+ print(f"CRITICAL WARNING: Corrupted tuple detected in OCRResult. Sanitizing '{self.text}' to '{self.text[0]}'.")
68
+ self.text = self.text[0]
69
+
70
+ # Ensure the final result is always a stripped string.
71
+ self.text = str(self.text).strip()
72
+
73
+ class OCRProvider:
74
+ """Base class for OCR providers"""
75
+
76
+ def __init__(self, log_callback=None):
77
+ # Set thread limits early if environment indicates single-threaded mode
78
+ try:
79
+ if os.environ.get('OMP_NUM_THREADS') == '1':
80
+ # Already in single-threaded mode, ensure it's applied to this process
81
+ try:
82
+ import sys
83
+ if 'torch' in sys.modules:
84
+ import torch
85
+ torch.set_num_threads(1)
86
+ except (ImportError, RuntimeError, AttributeError):
87
+ pass
88
+ try:
89
+ import cv2
90
+ cv2.setNumThreads(1)
91
+ except (ImportError, AttributeError):
92
+ pass
93
+ except Exception:
94
+ pass
95
+
96
+ self.log_callback = log_callback
97
+ self.is_installed = False
98
+ self.is_loaded = False
99
+ self.model = None
100
+ self.stop_flag = None
101
+ self._stopped = False
102
+
103
+ def _log(self, message: str, level: str = "info"):
104
+ """Log message with stop suppression"""
105
+ # Suppress logs when stopped (allow only essential stop confirmation messages)
106
+ if self._check_stop():
107
+ essential_stop_keywords = [
108
+ "⏹️ Translation stopped by user",
109
+ "⏹️ OCR processing stopped",
110
+ "cleanup", "🧹"
111
+ ]
112
+ if not any(keyword in message for keyword in essential_stop_keywords):
113
+ return
114
+
115
+ if self.log_callback:
116
+ self.log_callback(message, level)
117
+ else:
118
+ print(f"[{level.upper()}] {message}")
119
+
120
+ def set_stop_flag(self, stop_flag):
121
+ """Set the stop flag for checking interruptions"""
122
+ self.stop_flag = stop_flag
123
+ self._stopped = False
124
+
125
+ def _check_stop(self) -> bool:
126
+ """Check if stop has been requested"""
127
+ if self._stopped:
128
+ return True
129
+ if self.stop_flag and self.stop_flag.is_set():
130
+ self._stopped = True
131
+ return True
132
+ # Check global manga translator cancellation
133
+ try:
134
+ from manga_translator import MangaTranslator
135
+ if MangaTranslator.is_globally_cancelled():
136
+ self._stopped = True
137
+ return True
138
+ except Exception:
139
+ pass
140
+ return False
141
+
142
+ def reset_stop_flags(self):
143
+ """Reset stop flags when starting new processing"""
144
+ self._stopped = False
145
+
146
+ def check_installation(self) -> bool:
147
+ """Check if provider is installed"""
148
+ raise NotImplementedError
149
+
150
+ def install(self, progress_callback=None) -> bool:
151
+ """Install the provider"""
152
+ raise NotImplementedError
153
+
154
+ def load_model(self, **kwargs) -> bool:
155
+ """Load the OCR model"""
156
+ raise NotImplementedError
157
+
158
+ def detect_text(self, image: np.ndarray, **kwargs) -> List[OCRResult]:
159
+ """Detect text in image"""
160
+ raise NotImplementedError
161
+
162
+ class CustomAPIProvider(OCRProvider):
163
+ """Custom API OCR provider that uses existing GUI variables"""
164
+
165
+ def __init__(self, log_callback=None):
166
+ super().__init__(log_callback)
167
+
168
+ # Use EXISTING environment variables from TranslatorGUI
169
+ self.api_url = os.environ.get('OPENAI_CUSTOM_BASE_URL', '')
170
+ self.api_key = os.environ.get('API_KEY', '') or os.environ.get('OPENAI_API_KEY', '')
171
+ self.model_name = os.environ.get('MODEL', 'gpt-4o-mini')
172
+
173
+ # OCR prompt - use system prompt or a dedicated OCR prompt variable
174
+ self.ocr_prompt = os.environ.get('OCR_SYSTEM_PROMPT',
175
+ os.environ.get('SYSTEM_PROMPT',
176
+ "YOU ARE A TEXT EXTRACTION MACHINE. EXTRACT EXACTLY WHAT YOU SEE.\n\n"
177
+ "ABSOLUTE RULES:\n"
178
+ "1. OUTPUT ONLY THE VISIBLE TEXT/SYMBOLS - NOTHING ELSE\n"
179
+ "2. NEVER TRANSLATE OR MODIFY\n"
180
+ "3. NEVER EXPLAIN, DESCRIBE, OR COMMENT\n"
181
+ "4. NEVER SAY \"I can't\" or \"I cannot\" or \"no text\" or \"blank image\"\n"
182
+ "5. IF YOU SEE DOTS, OUTPUT THE DOTS: .\n"
183
+ "6. IF YOU SEE PUNCTUATION, OUTPUT THE PUNCTUATION\n"
184
+ "7. IF YOU SEE A SINGLE CHARACTER, OUTPUT THAT CHARACTER\n"
185
+ "8. IF YOU SEE NOTHING, OUTPUT NOTHING (empty response)\n\n"
186
+ "LANGUAGE PRESERVATION:\n"
187
+ "- Korean text → Output in Korean\n"
188
+ "- Japanese text → Output in Japanese\n"
189
+ "- Chinese text → Output in Chinese\n"
190
+ "- English text → Output in English\n"
191
+ "- CJK quotation marks (「」『』【】《》〈〉) → Preserve exactly as shown\n\n"
192
+ "FORMATTING:\n"
193
+ "- OUTPUT ALL TEXT ON A SINGLE LINE WITH NO LINE BREAKS\n"
194
+ "- NEVER use \\n or line breaks in your output\n\n"
195
+ "FORBIDDEN RESPONSES:\n"
196
+ "- \"I can see this appears to be...\"\n"
197
+ "- \"I cannot make out any clear text...\"\n"
198
+ "- \"This appears to be blank...\"\n"
199
+ "- \"If there is text present...\"\n"
200
+ "- ANY explanatory text\n\n"
201
+ "YOUR ONLY OUTPUT: The exact visible text. Nothing more. Nothing less.\n"
202
+ "If image has a dot → Output: .\n"
203
+ "If image has two dots → Output: . .\n"
204
+ "If image has text → Output: [that text]\n"
205
+ "If image is truly blank → Output: [empty/no response]"
206
+ ))
207
+
208
+ # Use existing temperature and token settings
209
+ self.temperature = float(os.environ.get('TRANSLATION_TEMPERATURE', '0.01'))
210
+ # NOTE: max_tokens is NOT cached here - it's read fresh from environment in detect_text()
211
+ # to ensure we always get the latest value from the GUI
212
+
213
+ # Image settings from existing compression variables
214
+ self.image_format = 'jpeg' if os.environ.get('IMAGE_COMPRESSION_FORMAT', 'auto') != 'png' else 'png'
215
+ self.image_quality = int(os.environ.get('JPEG_QUALITY', '100'))
216
+
217
+ # Simple defaults
218
+ self.api_format = 'openai' # Most custom endpoints are OpenAI-compatible
219
+ self.timeout = int(os.environ.get('CHUNK_TIMEOUT', '30'))
220
+ self.api_headers = {} # Additional custom headers
221
+
222
+ # Retry configuration for Custom API OCR calls
223
+ self.max_retries = int(os.environ.get('CUSTOM_OCR_MAX_RETRIES', '3'))
224
+ self.retry_initial_delay = float(os.environ.get('CUSTOM_OCR_RETRY_INITIAL_DELAY', '0.8'))
225
+ self.retry_backoff = float(os.environ.get('CUSTOM_OCR_RETRY_BACKOFF', '1.8'))
226
+ self.retry_jitter = float(os.environ.get('CUSTOM_OCR_RETRY_JITTER', '0.4'))
227
+ self.retry_on_empty = os.environ.get('CUSTOM_OCR_RETRY_ON_EMPTY', '1') == '1'
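+ # With these defaults the wait between attempts follows
+ # delay = retry_initial_delay * retry_backoff**(attempt - 1) + uniform(0, retry_jitter),
+ # i.e. roughly 0.8s, then ~1.44s (plus jitter) across the three attempts.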
228
+
229
+ def check_installation(self) -> bool:
230
+ """Always installed - uses UnifiedClient"""
231
+ self.is_installed = True
232
+ return True
233
+
234
+ def install(self, progress_callback=None) -> bool:
235
+ """No installation needed for API-based provider"""
236
+ return self.check_installation()
237
+
238
+ def load_model(self, **kwargs) -> bool:
239
+ """Initialize UnifiedClient with current settings"""
240
+ try:
241
+ from unified_api_client import UnifiedClient
242
+
243
+ # Support passing API key from GUI if available
244
+ if 'api_key' in kwargs:
245
+ api_key = kwargs['api_key']
246
+ else:
247
+ api_key = os.environ.get('API_KEY', '') or os.environ.get('OPENAI_API_KEY', '')
248
+
249
+ if 'model' in kwargs:
250
+ model = kwargs['model']
251
+ else:
252
+ model = os.environ.get('MODEL', 'gpt-4o-mini')
253
+
254
+ if not api_key:
255
+ self._log("❌ No API key configured", "error")
256
+ return False
257
+
258
+ # Create UnifiedClient just like translations do
259
+ self.client = UnifiedClient(model=model, api_key=api_key)
260
+
261
+ #self._log(f"✅ Using {model} for OCR via UnifiedClient")
262
+ self.is_loaded = True
263
+ return True
264
+
265
+ except Exception as e:
266
+ self._log(f"❌ Failed to initialize UnifiedClient: {str(e)}", "error")
267
+ return False
268
+
269
+ def _test_connection(self) -> bool:
270
+ """Test API connection with a simple request"""
271
+ try:
272
+ # Create a small test image
273
+ test_image = np.ones((100, 100, 3), dtype=np.uint8) * 255
274
+ cv2.putText(test_image, "TEST", (10, 50), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 0, 0), 2)
275
+
276
+ # Encode image
277
+ image_base64 = self._encode_image(test_image)
278
+
279
+ # Prepare test request based on API format
280
+ if self.api_format == 'openai':
281
+ test_payload = {
282
+ "model": self.model_name,
283
+ "messages": [
284
+ {
285
+ "role": "user",
286
+ "content": [
287
+ {"type": "text", "text": "What text do you see?"},
288
+ {"type": "image_url", "image_url": {"url": f"data:image/{self.image_format};base64,{image_base64}"}}
289
+ ]
290
+ }
291
+ ],
292
+ "max_tokens": 50
293
+ }
294
+ else:
295
+ # For other formats, just try a basic health check
296
+ return True
297
+
298
+ headers = self._prepare_headers()
299
+ response = requests.post(
300
+ self.api_url,
301
+ headers=headers,
302
+ json=test_payload,
303
+ timeout=10
304
+ )
305
+
306
+ return response.status_code == 200
307
+
308
+ except Exception:
309
+ return False
310
+
311
+ def _encode_image(self, image: np.ndarray) -> str:
312
+ """Encode numpy array to base64 string"""
313
+ # Convert BGR to RGB if needed
314
+ if len(image.shape) == 3 and image.shape[2] == 3:
315
+ image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
316
+ else:
317
+ image_rgb = image
318
+
319
+ # Convert to PIL Image
320
+ pil_image = Image.fromarray(image_rgb)
321
+
322
+ # Save to bytes buffer
323
+ buffer = io.BytesIO()
324
+ if self.image_format.lower() == 'png':
325
+ pil_image.save(buffer, format='PNG')
326
+ else:
327
+ pil_image.save(buffer, format='JPEG', quality=self.image_quality)
328
+
329
+ # Encode to base64
330
+ buffer.seek(0)
331
+ image_base64 = base64.b64encode(buffer.read()).decode('utf-8')
332
+
333
+ return image_base64
334
+
335
+ def _prepare_headers(self) -> dict:
336
+ """Prepare request headers"""
337
+ headers = {
338
+ "Content-Type": "application/json"
339
+ }
340
+
341
+ # Add API key if configured
342
+ if self.api_key:
343
+ if self.api_format == 'anthropic':
344
+ headers["x-api-key"] = self.api_key
345
+ else:
346
+ headers["Authorization"] = f"Bearer {self.api_key}"
347
+
348
+ # Add any custom headers
349
+ headers.update(self.api_headers)
350
+
351
+ return headers
352
+
353
+ def _prepare_request_payload(self, image_base64: str) -> dict:
354
+ """Prepare request payload based on API format"""
355
+ if self.api_format == 'openai':
356
+ return {
357
+ "model": self.model_name,
358
+ "messages": [
359
+ {
360
+ "role": "user",
361
+ "content": [
362
+ {"type": "text", "text": self.ocr_prompt},
363
+ {
364
+ "type": "image_url",
365
+ "image_url": {
366
+ "url": f"data:image/{self.image_format};base64,{image_base64}"
367
+ }
368
+ }
369
+ ]
370
+ }
371
+ ],
372
+ "max_tokens": self.max_tokens,
373
+ "temperature": self.temperature
374
+ }
375
+
376
+ elif self.api_format == 'anthropic':
377
+ return {
378
+ "model": self.model_name,
379
+ "max_tokens": self.max_tokens,
380
+ "temperature": self.temperature,
381
+ "messages": [
382
+ {
383
+ "role": "user",
384
+ "content": [
385
+ {
386
+ "type": "text",
387
+ "text": self.ocr_prompt
388
+ },
389
+ {
390
+ "type": "image",
391
+ "source": {
392
+ "type": "base64",
393
+ "media_type": f"image/{self.image_format}",
394
+ "data": image_base64
395
+ }
396
+ }
397
+ ]
398
+ }
399
+ ]
400
+ }
401
+
402
+ else:
403
+ # Custom format - use environment variable for template
404
+ template = os.environ.get('CUSTOM_OCR_REQUEST_TEMPLATE', '{}')
405
+ payload = json.loads(template)
406
+
407
+ # Replace placeholders
408
+ payload_str = json.dumps(payload)
409
+ payload_str = payload_str.replace('{{IMAGE_BASE64}}', image_base64)
410
+ payload_str = payload_str.replace('{{PROMPT}}', self.ocr_prompt)
411
+ payload_str = payload_str.replace('{{MODEL}}', self.model_name)
412
+ payload_str = payload_str.replace('{{MAX_TOKENS}}', str(max_tokens))
413
+ payload_str = payload_str.replace('{{TEMPERATURE}}', str(self.temperature))
414
+
415
+ return json.loads(payload_str)
416
+
417
+ def _extract_text_from_response(self, response_data: dict) -> str:
418
+ """Extract text from API response based on format"""
419
+ try:
420
+ if self.api_format == 'openai':
421
+ # OpenAI format: response.choices[0].message.content
422
+ return response_data.get('choices', [{}])[0].get('message', {}).get('content', '')
423
+
424
+ elif self.api_format == 'anthropic':
425
+ # Anthropic format: response.content[0].text
426
+ content = response_data.get('content', [])
427
+ if content and isinstance(content, list):
428
+ return content[0].get('text', '')
429
+ return ''
430
+
431
+ else:
432
+ # Custom format - use environment variable for path
433
+ response_path = os.environ.get('CUSTOM_OCR_RESPONSE_PATH', 'text')
434
+
435
+ # Navigate through the response using the path
436
+ result = response_data
437
+ for key in response_path.split('.'):
438
+ if isinstance(result, dict):
439
+ result = result.get(key, '')
440
+ elif isinstance(result, list) and key.isdigit():
441
+ idx = int(key)
442
+ result = result[idx] if idx < len(result) else ''
443
+ else:
444
+ result = ''
445
+ break
446
+
447
+ return str(result)
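+ # Example (illustrative path): CUSTOM_OCR_RESPONSE_PATH="choices.0.message.content" walks
+ # response_data['choices'][0]['message']['content'] for an OpenAI-style reply.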
448
+
449
+ except Exception as e:
450
+ self._log(f"Failed to extract text from response: {e}", "error")
451
+ return ''
452
+
453
+ def detect_text(self, image: np.ndarray, **kwargs) -> List[OCRResult]:
454
+ """Process image using UnifiedClient.send_image()"""
455
+ results = []
456
+
457
+ try:
458
+ # CRITICAL: Reload OCR prompt from environment before each detection
459
+ # This ensures we use the latest prompt set by manga_integration.py
460
+ self.ocr_prompt = os.environ.get('OCR_SYSTEM_PROMPT', self.ocr_prompt)
461
+
462
+ # Get fresh max_tokens from environment - GUI will have set this
463
+ max_tokens = int(os.environ.get('MAX_OUTPUT_TOKENS', '8192'))
464
+ if not self.is_loaded:
465
+ if not self.load_model():
466
+ return results
467
+
468
+ import cv2
469
+ from PIL import Image
470
+ import base64
471
+ import io
472
+
473
+ # Validate and resize image if too small (consistent with Google/Azure logic)
474
+ h, w = image.shape[:2]
475
+ MIN_SIZE = 50 # Minimum dimension for good OCR quality
476
+ MIN_AREA = 2500 # Minimum area (50x50)
477
+
478
+ # Skip completely invalid/corrupted images (0 or negative dimensions)
479
+ if h <= 0 or w <= 0:
480
+ self._log(f"⚠️ Invalid image dimensions ({w}x{h}px), skipping", "warning")
481
+ return results
482
+
483
+ if h < MIN_SIZE or w < MIN_SIZE or h * w < MIN_AREA:
484
+ # Image too small - resize it
485
+ scale_w = MIN_SIZE / w if w < MIN_SIZE else 1.0
486
+ scale_h = MIN_SIZE / h if h < MIN_SIZE else 1.0
487
+ scale = max(scale_w, scale_h)
488
+
489
+ if scale > 1.0:
490
+ new_w = int(w * scale)
491
+ new_h = int(h * scale)
492
+ image = cv2.resize(image, (new_w, new_h), interpolation=cv2.INTER_CUBIC)
493
+ self._log(f"🔍 Image resized from {w}x{h}px to {new_w}x{new_h}px for Custom API OCR", "debug")
494
+ h, w = new_h, new_w
495
+
496
+ # Convert numpy array to PIL Image
497
+ image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
498
+ pil_image = Image.fromarray(image_rgb)
499
+
500
+ # Convert PIL Image to base64 string
501
+ buffer = io.BytesIO()
502
+
503
+ # Use the image format from settings
504
+ if self.image_format.lower() == 'png':
505
+ pil_image.save(buffer, format='PNG')
506
+ else:
507
+ pil_image.save(buffer, format='JPEG', quality=self.image_quality)
508
+
509
+ buffer.seek(0)
510
+ image_base64 = base64.b64encode(buffer.read()).decode('utf-8')
511
+
512
+ # For OpenAI vision models, we need BOTH:
513
+ # 1. System prompt with instructions
514
+ # 2. User message that includes the image
515
+ messages = [
516
+ {
517
+ "role": "system",
518
+ "content": self.ocr_prompt # The OCR instruction as system prompt
519
+ },
520
+ {
521
+ "role": "user",
522
+ "content": [
523
+ {
524
+ "type": "text",
525
+ "text": "Image:" # Minimal text, just to have something
526
+ },
527
+ {
528
+ "type": "image_url",
529
+ "image_url": {
530
+ "url": f"data:image/jpeg;base64,{image_base64}"
531
+ }
532
+ }
533
+ ]
534
+ }
535
+ ]
536
+
537
+ # Now send this properly formatted message
538
+ # The UnifiedClient should handle this correctly
539
+ # But we're NOT using send_image, we're using regular send
540
+
541
+ # Retry-aware call
542
+ from unified_api_client import UnifiedClientError # local import to avoid hard dependency at module import time
543
+ max_attempts = max(1, self.max_retries)
544
+ attempt = 0
545
+ last_error = None
546
+
547
+ # Common refusal/error phrases that indicate a non-OCR response (expanded list)
548
+ refusal_phrases = [
549
+ "I can't extract", "I cannot extract",
550
+ "I'm sorry", "I am sorry",
551
+ "I'm unable", "I am unable",
552
+ "cannot process images",
553
+ "I can't help with that",
554
+ "cannot view images",
555
+ "no text in the image",
556
+ "I can see this appears",
557
+ "I cannot make out",
558
+ "appears to be blank",
559
+ "appears to be a mostly blank",
560
+ "mostly blank or white image",
561
+ "If there is text present",
562
+ "too small, faint, or unclear",
563
+ "cannot accurately extract",
564
+ "may be too",
565
+ "However, I cannot",
566
+ "I don't see any",
567
+ "no clear text",
568
+ "no visible text",
569
+ "does not contain",
570
+ "doesn't contain",
571
+ "I do not see"
572
+ ]
573
+
574
+ while attempt < max_attempts:
575
+ # Check for stop before each attempt
576
+ if self._check_stop():
577
+ self._log("⏹️ OCR processing stopped by user", "warning")
578
+ return results
579
+
580
+ try:
581
+ response = self.client.send(
582
+ messages=messages,
583
+ temperature=self.temperature,
584
+ max_tokens=max_tokens
585
+ )
586
+
587
+ # Unpack content and finish_reason from the response tuple
588
+ content, finish_reason = response
589
+
590
+ # Validate content
591
+ has_content = bool(content and str(content).strip())
592
+ refused = False
593
+ if has_content:
594
+ # Filter out explicit failure markers
595
+ if "[" in content and "FAILED]" in content:
596
+ refused = True
597
+ elif any(phrase.lower() in content.lower() for phrase in refusal_phrases):
598
+ refused = True
599
+
600
+ # Decide success or retry
601
+ if has_content and not refused:
602
+ text = str(content).strip()
603
+ results.append(OCRResult(
604
+ text=text,
605
+ bbox=(0, 0, w, h),
606
+ confidence=kwargs.get('confidence', 0.85),
607
+ vertices=[(0, 0), (w, 0), (w, h), (0, h)]
608
+ ))
609
+ self._log(f"✅ Detected: {text[:50]}...")
610
+ break # success
611
+ else:
612
+ reason = "empty result" if not has_content else "refusal/non-OCR response"
613
+ last_error = f"{reason} (finish_reason: {finish_reason})"
614
+ # Check if we should retry on empty or refusal
615
+ should_retry = (not has_content and self.retry_on_empty) or refused
616
+ attempt += 1
617
+ if attempt >= max_attempts or not should_retry:
618
+ # No more retries or shouldn't retry
619
+ if not has_content:
620
+ self._log(f"⚠️ No text detected (finish_reason: {finish_reason})")
621
+ else:
622
+ self._log(f"❌ Model returned non-OCR response: {str(content)[:120]}", "warning")
623
+ break
624
+ # Backoff before retrying
625
+ delay = self.retry_initial_delay * (self.retry_backoff ** (attempt - 1)) + random.uniform(0, self.retry_jitter)
626
+ self._log(f"🔄 Retry {attempt}/{max_attempts - 1} after {delay:.1f}s due to {reason}...", "warning")
627
+ time.sleep(delay)
628
+ time.sleep(0.1) # Brief pause for stability
629
+ self._log("💤 OCR retry pausing briefly for stability", "debug")
630
+ continue
631
+
632
+ except UnifiedClientError as ue:
633
+ msg = str(ue)
634
+ last_error = msg
635
+ # Do not retry on explicit user cancellation
636
+ if 'cancelled' in msg.lower() or 'stopped by user' in msg.lower():
637
+ self._log(f"❌ OCR cancelled: {msg}", "error")
638
+ break
639
+ attempt += 1
640
+ if attempt >= max_attempts:
641
+ self._log(f"❌ OCR failed after {attempt} attempts: {msg}", "error")
642
+ break
643
+ delay = self.retry_initial_delay * (self.retry_backoff ** (attempt - 1)) + random.uniform(0, self.retry_jitter)
644
+ self._log(f"🔄 API error, retry {attempt}/{max_attempts - 1} after {delay:.1f}s: {msg}", "warning")
645
+ time.sleep(delay)
646
+ time.sleep(0.1) # Brief pause for stability
647
+ self._log("💤 OCR API error retry pausing briefly for stability", "debug")
648
+ continue
649
+ except Exception as e_inner:
650
+ last_error = str(e_inner)
651
+ attempt += 1
652
+ if attempt >= max_attempts:
653
+ self._log(f"❌ OCR exception after {attempt} attempts: {last_error}", "error")
654
+ break
655
+ delay = self.retry_initial_delay * (self.retry_backoff ** (attempt - 1)) + random.uniform(0, self.retry_jitter)
656
+ self._log(f"🔄 Exception, retry {attempt}/{max_attempts - 1} after {delay:.1f}s: {last_error}", "warning")
657
+ time.sleep(delay)
658
+ time.sleep(0.1) # Brief pause for stability
659
+ self._log("💤 OCR exception retry pausing briefly for stability", "debug")
660
+ continue
661
+
662
+ except Exception as e:
663
+ self._log(f"❌ Error: {str(e)}", "error")
664
+ import traceback
665
+ self._log(traceback.format_exc(), "debug")
666
+
667
+ return results
668
+
669
+ class MangaOCRProvider(OCRProvider):
670
+ """Manga OCR provider using HuggingFace model directly"""
671
+
672
+ def __init__(self, log_callback=None):
673
+ super().__init__(log_callback)
674
+ self.processor = None
675
+ self.model = None
676
+ self.tokenizer = None
677
+
678
+ def check_installation(self) -> bool:
679
+ """Check if transformers is installed"""
680
+ try:
681
+ import transformers
682
+ import torch
683
+ self.is_installed = True
684
+ return True
685
+ except ImportError:
686
+ return False
687
+
688
+ def install(self, progress_callback=None) -> bool:
689
+ """Install transformers and torch"""
690
+ # Installation is handled by the application's dependency setup; just report current status
+ return self.check_installation()
691
+
692
+ def _is_valid_local_model_dir(self, path: str) -> bool:
693
+ """Check that a local HF model directory has required files."""
694
+ try:
695
+ if not path or not os.path.isdir(path):
696
+ return False
697
+ needed_any_weights = any(
698
+ os.path.exists(os.path.join(path, name)) for name in (
699
+ 'pytorch_model.bin',
700
+ 'model.safetensors'
701
+ )
702
+ )
703
+ has_config = os.path.exists(os.path.join(path, 'config.json'))
704
+ has_processor = (
705
+ os.path.exists(os.path.join(path, 'preprocessor_config.json')) or
706
+ os.path.exists(os.path.join(path, 'processor_config.json'))
707
+ )
708
+ has_tokenizer = (
709
+ os.path.exists(os.path.join(path, 'tokenizer.json')) or
710
+ os.path.exists(os.path.join(path, 'tokenizer_config.json'))
711
+ )
712
+ return has_config and needed_any_weights and has_processor and has_tokenizer
713
+ except Exception:
714
+ return False
715
+
716
+ def load_model(self, **kwargs) -> bool:
717
+ """Load the manga-ocr model, preferring a local directory to avoid re-downloading"""
718
+ print("\n>>> MangaOCRProvider.load_model() called")
719
+ try:
720
+ if not self.is_installed and not self.check_installation():
721
+ print("ERROR: Transformers not installed")
722
+ self._log("❌ Transformers not installed", "error")
723
+ return False
724
+
725
+ # Always disable progress bars to avoid tqdm issues in some environments
726
+ import os
727
+ os.environ.setdefault("HF_HUB_DISABLE_PROGRESS_BARS", "1")
728
+
729
+ from transformers import VisionEncoderDecoderModel, AutoTokenizer, AutoImageProcessor
730
+ import torch
731
+
732
+ # Prefer a local model directory if present to avoid any Hub access
733
+ candidates = []
734
+ env_local = os.environ.get("MANGA_OCR_LOCAL_DIR")
735
+ if env_local:
736
+ candidates.append(env_local)
737
+
738
+ # Project root one level up from this file
739
+ root_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))
740
+ candidates.append(os.path.join(root_dir, 'models', 'manga-ocr-base'))
741
+ candidates.append(os.path.join(root_dir, 'models', 'kha-white', 'manga-ocr-base'))
742
+
743
+ model_source = None
744
+ local_only = False
745
+ # Find a valid local dir
746
+ for cand in candidates:
747
+ if self._is_valid_local_model_dir(cand):
748
+ model_source = cand
749
+ local_only = True
750
+ break
751
+
752
+ # If no valid local dir, use Hub
753
+ if not model_source:
754
+ model_source = "kha-white/manga-ocr-base"
755
+ # Make sure we are not forcing offline mode
756
+ if os.environ.get("HF_HUB_OFFLINE") == "1":
757
+ try:
758
+ del os.environ["HF_HUB_OFFLINE"]
759
+ except Exception:
760
+ pass
761
+ self._log("🔥 Loading manga-ocr model from Hugging Face Hub")
762
+ self._log(f" Repo: {model_source}")
763
+ else:
764
+ # Only set offline when local dir is fully valid
765
+ os.environ.setdefault("HF_HUB_OFFLINE", "1")
766
+ self._log("🔥 Loading manga-ocr model from local directory")
767
+ self._log(f" Local path: {model_source}")
768
+
769
+ # Decide target device once; we will move after full CPU load to avoid meta tensors
770
+ use_cuda = torch.cuda.is_available()
771
+
772
+ # Try loading components, falling back to Hub if local-only fails
773
+ def _load_components(source: str, local_flag: bool):
774
+ self._log(" Loading tokenizer...")
775
+ tok = AutoTokenizer.from_pretrained(source, local_files_only=local_flag)
776
+
777
+ self._log(" Loading image processor...")
778
+ try:
779
+ from transformers import AutoProcessor
780
+ except Exception:
781
+ AutoProcessor = None
782
+ try:
783
+ proc = AutoImageProcessor.from_pretrained(source, local_files_only=local_flag)
784
+ except Exception as e_proc:
785
+ if AutoProcessor is not None:
786
+ self._log(f" ⚠️ AutoImageProcessor failed: {e_proc}. Trying AutoProcessor...", "warning")
787
+ proc = AutoProcessor.from_pretrained(source, local_files_only=local_flag)
788
+ else:
789
+ raise
790
+
791
+ self._log(" Loading model...")
792
+ # Prevent meta tensors by forcing full materialization on CPU at load time
793
+ os.environ.setdefault('TORCHDYNAMO_DISABLE', '1')
794
+ mdl = VisionEncoderDecoderModel.from_pretrained(
795
+ source,
796
+ local_files_only=local_flag,
797
+ low_cpu_mem_usage=False,
798
+ device_map=None,
799
+ torch_dtype=torch.float32 # Use torch_dtype instead of dtype
800
+ )
801
+ return tok, proc, mdl
802
+
803
+ try:
804
+ self.tokenizer, self.processor, self.model = _load_components(model_source, local_only)
805
+ except Exception as e_local:
806
+ if local_only:
807
+ # Fallback to Hub once if local fails
808
+ self._log(f" ⚠️ Local model load failed: {e_local}", "warning")
809
+ try:
810
+ if os.environ.get("HF_HUB_OFFLINE") == "1":
811
+ del os.environ["HF_HUB_OFFLINE"]
812
+ except Exception:
813
+ pass
814
+ model_source = "kha-white/manga-ocr-base"
815
+ local_only = False
816
+ self._log(" Retrying from Hugging Face Hub...")
817
+ self.tokenizer, self.processor, self.model = _load_components(model_source, local_only)
818
+ else:
819
+ raise
820
+
821
+ # Move to CUDA only after full CPU materialization
822
+ target_device = 'cpu'
823
+ if use_cuda:
824
+ try:
825
+ self.model = self.model.to('cuda')
826
+ target_device = 'cuda'
827
+ except Exception as move_err:
828
+ self._log(f" ⚠️ Could not move model to CUDA: {move_err}", "warning")
829
+ target_device = 'cpu'
830
+
831
+ # Finalize eval mode
832
+ self.model.eval()
833
+
834
+ # Sanity-check: ensure no parameter remains on 'meta' device
835
+ try:
836
+ for n, p in self.model.named_parameters():
837
+ dev = getattr(p, 'device', None)
838
+ if dev is not None and getattr(dev, 'type', '') == 'meta':
839
+ raise RuntimeError(f"Parameter {n} is on 'meta' after load")
840
+ except Exception as sanity_err:
841
+ self._log(f"❌ Manga-OCR model load sanity check failed: {sanity_err}", "error")
842
+ return False
843
+
844
+ print(f"SUCCESS: Model loaded on {target_device.upper()}")
845
+ self._log(f" ✅ Model loaded on {target_device.upper()}")
846
+ self.is_loaded = True
847
+ self._log("✅ Manga OCR model ready")
848
+ print(">>> Returning True from load_model()")
849
+ return True
850
+
851
+ except Exception as e:
852
+ print(f"\nEXCEPTION in load_model: {e}")
853
+ import traceback
854
+ print(traceback.format_exc())
855
+ self._log(f"❌ Failed to load manga-ocr model: {str(e)}", "error")
856
+ self._log(traceback.format_exc(), "error")
857
+ try:
858
+ if 'local_only' in locals() and local_only:
859
+ self._log("Hint: Local load failed. Ensure your models/manga-ocr-base contains required files (config.json, preprocessor_config.json, tokenizer.json or tokenizer_config.json, and model weights).", "warning")
860
+ except Exception:
861
+ pass
862
+ return False
863
+
864
+ def _run_ocr(self, pil_image):
865
+ """Run OCR on a PIL image using the HuggingFace model"""
866
+ import torch
867
+
868
+ # Process image (keyword arg for broader compatibility across transformers versions)
869
+ inputs = self.processor(images=pil_image, return_tensors="pt")
870
+ pixel_values = inputs["pixel_values"]
871
+
872
+ # Move to same device as model
873
+ try:
874
+ model_device = next(self.model.parameters()).device
875
+ except StopIteration:
876
+ model_device = torch.device('cpu')
877
+ pixel_values = pixel_values.to(model_device)
878
+
879
+ # Generate text
880
+ with torch.no_grad():
881
+ generated_ids = self.model.generate(pixel_values)
882
+
883
+ # Decode
884
+ generated_text = self.tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
885
+
886
+ return generated_text
887
+
888
+ def detect_text(self, image: np.ndarray, **kwargs) -> List[OCRResult]:
889
+ """
890
+ Process the image region passed to it.
891
+ This could be a bubble region or the full image.
892
+ """
893
+ results = []
894
+
895
+ # Check for stop at start
896
+ if self._check_stop():
897
+ self._log("⏹️ Manga-OCR processing stopped by user", "warning")
898
+ return results
899
+
900
+ try:
901
+ if not self.is_loaded:
902
+ if not self.load_model():
903
+ return results
904
+
905
+ import cv2
906
+ from PIL import Image
907
+
908
+ # Get confidence from kwargs
909
+ confidence = kwargs.get('confidence', 0.7)
910
+
911
+ # Convert numpy array to PIL
912
+ image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
913
+ pil_image = Image.fromarray(image_rgb)
914
+ h, w = image.shape[:2]
915
+
916
+ self._log("🔍 Processing region with manga-ocr...")
917
+
918
+ # Check for stop before inference
919
+ if self._check_stop():
920
+ self._log("⏹️ Manga-OCR inference stopped by user", "warning")
921
+ return results
922
+
923
+ # Run OCR on the image region
924
+ text = self._run_ocr(pil_image)
925
+
926
+ if text and text.strip():
927
+ # Return result for this region with its actual bbox
928
+ results.append(OCRResult(
929
+ text=text.strip(),
930
+ bbox=(0, 0, w, h), # Relative to the region passed in
931
+ confidence=confidence,
932
+ vertices=[(0, 0), (w, 0), (w, h), (0, h)]
933
+ ))
934
+ self._log(f"✅ Detected text: {text[:50]}...")
935
+
936
+ except Exception as e:
937
+ self._log(f"❌ Error in manga-ocr: {str(e)}", "error")
938
+
939
+ return results
940
+
941
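A minimal usage sketch for the MangaOCRProvider above, assuming this module's classes are importable and cv2 is available; the image path and bubble rectangle are placeholders:

# Hypothetical usage: crop one speech-bubble region and OCR it.
import cv2

provider = MangaOCRProvider(log_callback=print)
page = cv2.imread("page_001.png")            # BGR image, as detect_text expects
x, y, w, h = 120, 80, 240, 160               # placeholder bubble rectangle
bubble = page[y:y + h, x:x + w]

for result in provider.detect_text(bubble, confidence=0.7):
    # bbox is relative to the cropped region, so offset it back to page coordinates
    bx, by, bw, bh = result.bbox
    print(result.text, (x + bx, y + by, bw, bh), result.confidence)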
+ class Qwen2VL(OCRProvider):
942
+ """OCR using Qwen2-VL - Vision Language Model that can read Korean text"""
943
+
944
+ def __init__(self, log_callback=None):
945
+ super().__init__(log_callback)
946
+ self.processor = None
947
+ self.model = None
948
+ self.tokenizer = None
949
+
950
+ # Get OCR prompt from environment or use default (UPDATED: Improved prompt)
951
+ self.ocr_prompt = os.environ.get('OCR_SYSTEM_PROMPT',
952
+ "YOU ARE A TEXT EXTRACTION MACHINE. EXTRACT EXACTLY WHAT YOU SEE.\n\n"
953
+ "ABSOLUTE RULES:\n"
954
+ "1. OUTPUT ONLY THE VISIBLE TEXT/SYMBOLS - NOTHING ELSE\n"
955
+ "2. NEVER TRANSLATE OR MODIFY\n"
956
+ "3. NEVER EXPLAIN, DESCRIBE, OR COMMENT\n"
957
+ "4. NEVER SAY \"I can't\" or \"I cannot\" or \"no text\" or \"blank image\"\n"
958
+ "5. IF YOU SEE DOTS, OUTPUT THE DOTS: .\n"
959
+ "6. IF YOU SEE PUNCTUATION, OUTPUT THE PUNCTUATION\n"
960
+ "7. IF YOU SEE A SINGLE CHARACTER, OUTPUT THAT CHARACTER\n"
961
+ "8. IF YOU SEE NOTHING, OUTPUT NOTHING (empty response)\n\n"
962
+ "LANGUAGE PRESERVATION:\n"
963
+ "- Korean text → Output in Korean\n"
964
+ "- Japanese text → Output in Japanese\n"
965
+ "- Chinese text → Output in Chinese\n"
966
+ "- English text → Output in English\n"
967
+ "- CJK quotation marks (「」『』【】《》〈〉) → Preserve exactly as shown\n\n"
968
+ "FORMATTING:\n"
969
+ "- OUTPUT ALL TEXT ON A SINGLE LINE WITH NO LINE BREAKS\n"
970
+ "- NEVER use \\n or line breaks in your output\n\n"
971
+ "FORBIDDEN RESPONSES:\n"
972
+ "- \"I can see this appears to be...\"\n"
973
+ "- \"I cannot make out any clear text...\"\n"
974
+ "- \"This appears to be blank...\"\n"
975
+ "- \"If there is text present...\"\n"
976
+ "- ANY explanatory text\n\n"
977
+ "YOUR ONLY OUTPUT: The exact visible text. Nothing more. Nothing less.\n"
978
+ "If image has a dot → Output: .\n"
979
+ "If image has two dots → Output: . .\n"
980
+ "If image has text → Output: [that text]\n"
981
+ "If image is truly blank → Output: [empty/no response]"
982
+ )
983
+
984
+ def set_ocr_prompt(self, prompt: str):
985
+ """Allow setting the OCR prompt dynamically"""
986
+ self.ocr_prompt = prompt
987
+
988
+ def check_installation(self) -> bool:
989
+ """Check if required packages are installed"""
990
+ try:
991
+ import transformers
992
+ import torch
993
+ self.is_installed = True
994
+ return True
995
+ except ImportError:
996
+ return False
997
+
998
+ def install(self, progress_callback=None) -> bool:
999
+ """Install requirements for Qwen2-VL"""
1000
+ pass
1001
+
1002
+ def load_model(self, model_size=None, **kwargs) -> bool:
1003
+ """Load Qwen2-VL model with size selection"""
1004
+ self._log(f"DEBUG: load_model called with model_size={model_size}")
1005
+
1006
+ try:
1007
+ if not self.is_installed and not self.check_installation():
1008
+ self._log("❌ Not installed", "error")
1009
+ return False
1010
+
1011
+ self._log("🔥 Loading Qwen2-VL for Advanced OCR...")
1012
+
1015
+ from transformers import AutoProcessor, AutoTokenizer
1016
+ import torch
1017
+
1018
+ # Model options
1019
+ model_options = {
1020
+ "1": "Qwen/Qwen2-VL-2B-Instruct",
1021
+ "2": "Qwen/Qwen2-VL-7B-Instruct",
1022
+ "3": "Qwen/Qwen2-VL-72B-Instruct",
1023
+ "4": "custom"
1024
+ }
1025
+ # Default to the 2B model unless a saved preference overrides it
1026
+ # Check for saved preference first
1027
+ if model_size is None:
1028
+ # Try to get from environment or config
1029
+ import os
1030
+ model_size = os.environ.get('QWEN2VL_MODEL_SIZE', '1')
1031
+
1032
+ # Determine which model to load
1033
+ if model_size and str(model_size).startswith("custom:"):
1034
+ # Custom model passed with ID
1035
+ model_id = str(model_size).replace("custom:", "")
1036
+ self.loaded_model_size = "Custom"
1037
+ self.model_id = model_id
1038
+ self._log(f"Loading custom model: {model_id}")
1039
+ elif model_size == "4":
1040
+ # Custom option selected but no ID - shouldn't happen
1041
+ self._log("❌ Custom model selected but no ID provided", "error")
1042
+ return False
1043
+ elif model_size and str(model_size) in model_options:
1044
+ # Standard model option
1045
+ option = model_options[str(model_size)]
1046
+ if option == "custom":
1047
+ self._log("❌ Custom model needs an ID", "error")
1048
+ return False
1049
+ model_id = option
1050
+ # Set loaded_model_size for status display
1051
+ if model_size == "1":
1052
+ self.loaded_model_size = "2B"
1053
+ elif model_size == "2":
1054
+ self.loaded_model_size = "7B"
1055
+ elif model_size == "3":
1056
+ self.loaded_model_size = "72B"
1057
+ else:
1058
+ # Fall back to the default 2B model (option "1") when no valid size is given
+ model_id = model_options["1"]
+ self.loaded_model_size = "2B"
+ self._log("No model size specified, defaulting to 2B")
1062
+
1063
+ self._log(f"Loading model: {model_id}")
1064
+
1065
+ # Load processor and tokenizer
1066
+ self.processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True)
1067
+ self.tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
1068
+
1069
+ # Load the model - let it figure out the class dynamically
1070
+ if torch.cuda.is_available():
1071
+ self._log(f"GPU: {torch.cuda.get_device_name(0)}")
1072
+ # Use auto model class
1073
+ from transformers import AutoModelForVision2Seq
1074
+ self.model = AutoModelForVision2Seq.from_pretrained(
1075
+ model_id,
1076
+ dtype=torch.float16,
1077
+ device_map="auto",
1078
+ trust_remote_code=True
1079
+ )
1080
+ self._log("✅ Model loaded on GPU")
1081
+ else:
1082
+ self._log("Loading on CPU...")
1083
+ from transformers import AutoModelForVision2Seq
1084
+ self.model = AutoModelForVision2Seq.from_pretrained(
1085
+ model_id,
1086
+ dtype=torch.float32,
1087
+ trust_remote_code=True
1088
+ )
1089
+ self._log("✅ Model loaded on CPU")
1090
+
1091
+ self.model.eval()
1092
+ self.is_loaded = True
1093
+ self._log("✅ Qwen2-VL ready for Advanced OCR!")
1094
+ return True
1095
+
1096
+ except Exception as e:
1097
+ self._log(f"❌ Failed to load: {str(e)}", "error")
1098
+ import traceback
1099
+ self._log(traceback.format_exc(), "debug")
1100
+ return False
1101
+
1102
+ def detect_text(self, image: np.ndarray, **kwargs) -> List[OCRResult]:
1103
+ """Process image with Qwen2-VL for Korean text extraction"""
1104
+ results = []
1105
+ if hasattr(self, 'model_id'):
1106
+ self._log(f"DEBUG: Using model: {self.model_id}", "debug")
1107
+
1108
+ # Check if OCR prompt was passed in kwargs (for dynamic updates)
1109
+ if 'ocr_prompt' in kwargs:
1110
+ self.ocr_prompt = kwargs['ocr_prompt']
1111
+
1112
+ try:
1113
+ if not self.is_loaded:
1114
+ if not self.load_model():
1115
+ return results
1116
+
1117
+ import cv2
1118
+ from PIL import Image
1119
+ import torch
1120
+
1121
+ # Convert to PIL
1122
+ image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
1123
+ pil_image = Image.fromarray(image_rgb)
1124
+ h, w = image.shape[:2]
1125
+
1126
+ self._log(f"🔍 Processing with Qwen2-VL ({w}x{h} pixels)...")
1127
+
1128
+ # Use the configurable OCR prompt
1129
+ messages = [
1130
+ {
1131
+ "role": "user",
1132
+ "content": [
1133
+ {
1134
+ "type": "image",
1135
+ "image": pil_image,
1136
+ },
1137
+ {
1138
+ "type": "text",
1139
+ "text": self.ocr_prompt # Use the configurable prompt
1140
+ }
1141
+ ]
1142
+ }
1143
+ ]
1144
+
1145
+ # Alternative simpler prompt if the above still causes issues:
1146
+ # "text": "OCR: Extract text as-is"
1147
+
1148
+ # Process with Qwen2-VL
1149
+ text = self.processor.apply_chat_template(
1150
+ messages,
1151
+ tokenize=False,
1152
+ add_generation_prompt=True
1153
+ )
1154
+
1155
+ inputs = self.processor(
1156
+ text=[text],
1157
+ images=[pil_image],
1158
+ padding=True,
1159
+ return_tensors="pt"
1160
+ )
1161
+
1162
+ # Get the device and dtype the model is currently on
1163
+ model_device = next(self.model.parameters()).device
1164
+ model_dtype = next(self.model.parameters()).dtype
1165
+
1166
+ # Move inputs to the same device as the model and cast float tensors to model dtype
1167
+ try:
1168
+ # Move first
1169
+ inputs = inputs.to(model_device)
1170
+ # Then align dtypes only for floating tensors (e.g., pixel_values)
1171
+ for k, v in inputs.items():
1172
+ if isinstance(v, torch.Tensor) and torch.is_floating_point(v):
1173
+ inputs[k] = v.to(model_dtype)
1174
+ except Exception:
1175
+ # Fallback: ensure at least pixel_values is correct if present
1176
+ try:
1177
+ if isinstance(inputs, dict) and "pixel_values" in inputs:
1178
+ pv = inputs["pixel_values"].to(model_device)
1179
+ if torch.is_floating_point(pv):
1180
+ inputs["pixel_values"] = pv.to(model_dtype)
1181
+ except Exception:
1182
+ pass
1183
+
1184
+ # Ensure pixel_values explicitly matches model dtype if present
1185
+ try:
1186
+ if isinstance(inputs, dict) and "pixel_values" in inputs:
1187
+ inputs["pixel_values"] = inputs["pixel_values"].to(device=model_device, dtype=model_dtype)
1188
+ except Exception:
1189
+ pass
1190
+
1191
+ # Generate text with stricter parameters to avoid creative responses
1192
+ use_amp = (hasattr(torch, 'cuda') and model_device.type == 'cuda' and model_dtype in (torch.float16, torch.bfloat16))
1193
+ autocast_dev = 'cuda' if model_device.type == 'cuda' else 'cpu'
1194
+ autocast_dtype = model_dtype if model_dtype in (torch.float16, torch.bfloat16) else None
1195
+
1196
+ with torch.no_grad():
1197
+ if use_amp and autocast_dtype is not None:
1198
+ with torch.autocast(autocast_dev, dtype=autocast_dtype):
1199
+ generated_ids = self.model.generate(
1200
+ **inputs,
1201
+ max_new_tokens=128, # Reduced from 512 - manga bubbles are typically short
1202
+ do_sample=False, # Keep deterministic
1203
+ temperature=0.01, # Keep your very low temperature
1204
+ top_p=1.0, # Keep no nucleus sampling
1205
+ repetition_penalty=1.0, # Keep no repetition penalty
1206
+ num_beams=1, # Ensure greedy decoding (faster than beam search)
1207
+ use_cache=True, # Enable KV cache for speed
1208
+ early_stopping=True, # Stop at EOS token
1209
+ pad_token_id=self.tokenizer.pad_token_id, # Proper padding
1210
+ eos_token_id=self.tokenizer.eos_token_id, # Proper stopping
1211
+ )
1212
+ else:
1213
+ generated_ids = self.model.generate(
1214
+ **inputs,
1215
+ max_new_tokens=128, # Reduced from 512 - manga bubbles are typically short
1216
+ do_sample=False, # Keep deterministic
1217
+ temperature=0.01, # Keep your very low temperature
1218
+ top_p=1.0, # Keep no nucleus sampling
1219
+ repetition_penalty=1.0, # Keep no repetition penalty
1220
+ num_beams=1, # Ensure greedy decoding (faster than beam search)
1221
+ use_cache=True, # Enable KV cache for speed
1222
+ early_stopping=True, # Stop at EOS token
1223
+ pad_token_id=self.tokenizer.pad_token_id, # Proper padding
1224
+ eos_token_id=self.tokenizer.eos_token_id, # Proper stopping
1225
+ )
1226
+
1227
+ # Decode the output
1228
+ generated_ids_trimmed = [
1229
+ out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
1230
+ ]
1231
+ output_text = self.processor.batch_decode(
1232
+ generated_ids_trimmed,
1233
+ skip_special_tokens=True,
1234
+ clean_up_tokenization_spaces=False
1235
+ )[0]
1236
+
1237
+ if output_text and output_text.strip():
1238
+ text = output_text.strip()
1239
+
1240
+ # ADDED: Filter out any response that looks like an explanation or apology
1241
+ # Common patterns that indicate the model is being "helpful" instead of just extracting
1242
+ unwanted_patterns = [
1243
+ "죄송합니다", # "I apologize"
1244
+ "sorry",
1245
+ "apologize",
1246
+ "이미지에는", # "in this image"
1247
+ "텍스트가 없습니다", # "there is no text"
1248
+ "I cannot",
1249
+ "I don't see",
1250
+ "There is no",
1251
+ "질문이 있으시면", # "if you have questions"
1252
+ ]
1253
+
1254
+ # Check if response contains unwanted patterns
1255
+ text_lower = text.lower()
1256
+ is_explanation = any(pattern.lower() in text_lower for pattern in unwanted_patterns)
1257
+
1258
+ # Also check if the response is suspiciously long for a bubble
1259
+ # Most manga bubbles are short; over 100 chars with sentence punctuation likely means an explanation
1260
+ is_too_long = len(text) > 100 and ('.' in text or ',' in text or '!' in text)
1261
+
1262
+ if is_explanation or is_too_long:
1263
+ self._log(f"⚠️ Model returned explanation instead of text, ignoring", "warning")
1264
+ # Return empty result or just skip this region
1265
+ return results
1266
+
1267
+ # Check language
1268
+ has_korean = any('\uAC00' <= c <= '\uD7AF' for c in text)
1269
+ has_japanese = any('\u3040' <= c <= '\u309F' or '\u30A0' <= c <= '\u30FF' for c in text)
1270
+ has_chinese = any('\u4E00' <= c <= '\u9FFF' for c in text)
1271
+
1272
+ if has_korean:
1273
+ self._log(f"✅ Korean detected: {text[:50]}...")
1274
+ elif has_japanese:
1275
+ self._log(f"✅ Japanese detected: {text[:50]}...")
1276
+ elif has_chinese:
1277
+ self._log(f"✅ Chinese detected: {text[:50]}...")
1278
+ else:
1279
+ self._log(f"✅ Text: {text[:50]}...")
1280
+
1281
+ results.append(OCRResult(
1282
+ text=text,
1283
+ bbox=(0, 0, w, h),
1284
+ confidence=0.9,
1285
+ vertices=[(0, 0), (w, 0), (w, h), (0, h)]
1286
+ ))
1287
+ else:
1288
+ self._log("⚠️ No text detected", "warning")
1289
+
1290
+ except Exception as e:
1291
+ self._log(f"❌ Error: {str(e)}", "error")
1292
+ import traceback
1293
+ self._log(traceback.format_exc(), "debug")
1294
+
1295
+ return results
1296
+
1297
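A sketch of driving the Qwen2VL provider above; the QWEN2VL_MODEL_SIZE codes and the ocr_prompt kwarg come from load_model and detect_text, while the file name is a placeholder:

# Hypothetical usage: select the 7B checkpoint and pass a simpler OCR prompt per call.
import os
import cv2

os.environ["QWEN2VL_MODEL_SIZE"] = "2"       # "1"=2B, "2"=7B, "3"=72B, or pass "custom:<repo id>" to load_model

ocr = Qwen2VL(log_callback=print)
if ocr.load_model():
    region = cv2.imread("bubble_crop.png")   # BGR crop of a single bubble
    for r in ocr.detect_text(region, ocr_prompt="OCR: Extract text as-is"):
        print(r.text, r.confidence)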
+ class EasyOCRProvider(OCRProvider):
1298
+ """EasyOCR provider for multiple languages"""
1299
+
1300
+ def __init__(self, log_callback=None, languages=None):
1301
+ super().__init__(log_callback)
1302
+ # Default to safe language combination
1303
+ self.languages = languages or ['ja', 'en'] # Safe default
1304
+ self._validate_language_combination()
1305
+
1306
+ def _validate_language_combination(self):
1307
+ """Validate and fix EasyOCR language combinations"""
1308
+ # EasyOCR language compatibility rules
1309
+ incompatible_pairs = [
1310
+ (['ja', 'ko'], 'Japanese and Korean cannot be used together'),
1311
+ (['ja', 'zh'], 'Japanese and Chinese cannot be used together'),
1312
+ (['ko', 'zh'], 'Korean and Chinese cannot be used together')
1313
+ ]
1314
+
1315
+ for incompatible, reason in incompatible_pairs:
1316
+ if all(lang in self.languages for lang in incompatible):
1317
+ self._log(f"⚠️ EasyOCR: {reason}", "warning")
1318
+ # Keep first language + English
1319
+ self.languages = [self.languages[0], 'en']
1320
+ self._log(f"🔧 Auto-adjusted to: {self.languages}", "info")
1321
+ break
1322
+
1323
+ def check_installation(self) -> bool:
1324
+ """Check if easyocr is installed"""
1325
+ try:
1326
+ import easyocr
1327
+ self.is_installed = True
1328
+ return True
1329
+ except ImportError:
1330
+ return False
1331
+
1332
+ def install(self, progress_callback=None) -> bool:
1333
+ """Install easyocr"""
1334
+ pass
1335
+
1336
+ def load_model(self, **kwargs) -> bool:
1337
+ """Load easyocr model"""
1338
+ try:
1339
+ if not self.is_installed and not self.check_installation():
1340
+ self._log("❌ easyocr not installed", "error")
1341
+ return False
1342
+
1343
+ self._log(f"🔥 Loading easyocr model for languages: {self.languages}...")
1344
+ import easyocr
1345
+
1346
+ # This will download models on first run
1347
+ self.model = easyocr.Reader(self.languages, gpu=True)
1348
+ self.is_loaded = True
1349
+
1350
+ self._log("✅ easyocr model loaded successfully")
1351
+ return True
1352
+
1353
+ except Exception as e:
1354
+ self._log(f"❌ Failed to load easyocr: {str(e)}", "error")
1355
+ # Try CPU mode if GPU fails
1356
+ try:
1357
+ import easyocr
1358
+ self.model = easyocr.Reader(self.languages, gpu=False)
1359
+ self.is_loaded = True
1360
+ self._log("✅ easyocr loaded in CPU mode")
1361
+ return True
1362
+ except:
1363
+ return False
1364
+
1365
+ def detect_text(self, image: np.ndarray, **kwargs) -> List[OCRResult]:
1366
+ """Detect text using easyocr"""
1367
+ results = []
1368
+
1369
+ try:
1370
+ if not self.is_loaded:
1371
+ if not self.load_model():
1372
+ return results
1373
+
1374
+ # EasyOCR can work directly with numpy arrays
1375
+ ocr_results = self.model.readtext(image, detail=1)
1376
+
1377
+ # Parse results
1378
+ for (bbox, text, confidence) in ocr_results:
1379
+ # bbox is a list of 4 points
1380
+ xs = [point[0] for point in bbox]
1381
+ ys = [point[1] for point in bbox]
1382
+ x_min, x_max = min(xs), max(xs)
1383
+ y_min, y_max = min(ys), max(ys)
1384
+
1385
+ results.append(OCRResult(
1386
+ text=text,
1387
+ bbox=(int(x_min), int(y_min), int(x_max - x_min), int(y_max - y_min)),
1388
+ confidence=confidence,
1389
+ vertices=[(int(p[0]), int(p[1])) for p in bbox]
1390
+ ))
1391
+
1392
+ self._log(f"✅ Detected {len(results)} text regions")
1393
+
1394
+ except Exception as e:
1395
+ self._log(f"❌ Error in easyocr detection: {str(e)}", "error")
1396
+
1397
+ return results
1398
+
1399
+
1400
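A sketch showing how the language auto-adjustment above behaves; the language codes are EasyOCR's own, and the image path is a placeholder:

# Hypothetical usage: an incompatible pair like ['ja', 'ko'] is trimmed to ['ja', 'en'].
import cv2

reader = EasyOCRProvider(log_callback=print, languages=["ja", "ko"])
print(reader.languages)                      # -> ['ja', 'en'] after _validate_language_combination

page = cv2.imread("page_001.png")
for r in reader.detect_text(page):
    print(r.text, r.bbox, round(r.confidence, 2))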
+ class PaddleOCRProvider(OCRProvider):
1401
+ """PaddleOCR provider with memory safety measures"""
1402
+
1403
+ def check_installation(self) -> bool:
1404
+ """Check if paddleocr is installed"""
1405
+ try:
1406
+ from paddleocr import PaddleOCR
1407
+ self.is_installed = True
1408
+ return True
1409
+ except ImportError:
1410
+ return False
1411
+
1412
+ def install(self, progress_callback=None) -> bool:
1413
+ """Install paddleocr"""
1414
+ pass
1415
+
1416
+ def load_model(self, **kwargs) -> bool:
1417
+ """Load paddleocr model with memory-safe configurations"""
1418
+ try:
1419
+ if not self.is_installed and not self.check_installation():
1420
+ self._log("❌ paddleocr not installed", "error")
1421
+ return False
1422
+
1423
+ self._log("🔥 Loading PaddleOCR model...")
1424
+
1425
+ # Set memory-safe environment variables BEFORE importing
1426
+ import os
1427
+ os.environ['OMP_NUM_THREADS'] = '1' # Prevent OpenMP conflicts
1428
+ os.environ['MKL_NUM_THREADS'] = '1' # Prevent MKL conflicts
1429
+ os.environ['OPENBLAS_NUM_THREADS'] = '1' # Prevent OpenBLAS conflicts
1430
+ os.environ['FLAGS_use_mkldnn'] = '0' # Disable MKL-DNN
1431
+
1432
+ from paddleocr import PaddleOCR
1433
+
1434
+ # Try memory-safe configurations
1435
+ configs_to_try = [
1436
+ # Config 1: Most memory-safe configuration
1437
+ {
1438
+ 'use_angle_cls': False, # Disable angle to save memory
1439
+ 'lang': 'ch',
1440
+ 'rec_batch_num': 1, # Process one at a time
1441
+ 'max_text_length': 100, # Limit text length
1442
+ 'drop_score': 0.5, # Higher threshold to reduce detections
1443
+ 'cpu_threads': 1, # Single thread to avoid conflicts
1444
+ },
1445
+ # Config 2: Minimal memory footprint
1446
+ {
1447
+ 'lang': 'ch',
1448
+ 'rec_batch_num': 1,
1449
+ 'cpu_threads': 1,
1450
+ },
1451
+ # Config 3: Absolute minimal
1452
+ {
1453
+ 'lang': 'ch'
1454
+ },
1455
+ # Config 4: Empty config
1456
+ {}
1457
+ ]
1458
+
1459
+ for i, config in enumerate(configs_to_try):
1460
+ try:
1461
+ self._log(f" Trying configuration {i+1}/{len(configs_to_try)}: {config}")
1462
+
1463
+ # Force garbage collection before loading
1464
+ import gc
1465
+ gc.collect()
1466
+
1467
+ self.model = PaddleOCR(**config)
1468
+ self.is_loaded = True
1469
+ self.current_config = config
1470
+ self._log(f"✅ PaddleOCR loaded successfully with config: {config}")
1471
+ return True
1472
+ except Exception as e:
1473
+ error_str = str(e)
1474
+ self._log(f" Config {i+1} failed: {error_str}", "debug")
1475
+
1476
+ # Clean up on failure
1477
+ if hasattr(self, 'model'):
1478
+ del self.model
1479
+ gc.collect()
1480
+ continue
1481
+
1482
+ self._log(f"❌ PaddleOCR failed to load with any configuration", "error")
1483
+ return False
1484
+
1485
+ except Exception as e:
1486
+ self._log(f"❌ Failed to load paddleocr: {str(e)}", "error")
1487
+ import traceback
1488
+ self._log(traceback.format_exc(), "debug")
1489
+ return False
1490
+
1491
+ def detect_text(self, image: np.ndarray, **kwargs) -> List[OCRResult]:
1492
+ """Detect text with memory safety measures"""
1493
+ results = []
1494
+
1495
+ try:
1496
+ if not self.is_loaded:
1497
+ if not self.load_model():
1498
+ return results
1499
+
1500
+ import cv2
1501
+ import numpy as np
1502
+ import gc
1503
+
1504
+ # Memory safety: Ensure image isn't too large
1505
+ h, w = image.shape[:2] if len(image.shape) >= 2 else (0, 0)
1506
+
1507
+ # Limit image size to prevent memory issues
1508
+ MAX_DIMENSION = 1500
1509
+ if h > MAX_DIMENSION or w > MAX_DIMENSION:
1510
+ scale = min(MAX_DIMENSION/h, MAX_DIMENSION/w)
1511
+ new_h, new_w = int(h*scale), int(w*scale)
1512
+ self._log(f"⚠️ Resizing large image from {w}x{h} to {new_w}x{new_h} for memory safety", "warning")
1513
+ image = cv2.resize(image, (new_w, new_h), interpolation=cv2.INTER_AREA)
1514
+ scale_factor = 1/scale
1515
+ else:
1516
+ scale_factor = 1.0
1517
+
1518
+ # Ensure correct format
1519
+ if len(image.shape) == 2: # Grayscale
1520
+ image = cv2.cvtColor(image, cv2.COLOR_GRAY2BGR)
1521
+ elif len(image.shape) == 4: # Batch
1522
+ image = image[0]
1523
+
1524
+ # Ensure uint8 type
1525
+ if image.dtype != np.uint8:
1526
+ if image.max() <= 1.0:
1527
+ image = (image * 255).astype(np.uint8)
1528
+ else:
1529
+ image = image.astype(np.uint8)
1530
+
1531
+ # Make a copy to avoid memory corruption
1532
+ image_copy = image.copy()
1533
+
1534
+ # Force garbage collection before OCR
1535
+ gc.collect()
1536
+
1537
+ # Process with timeout protection (run OCR in a worker thread and join with a timeout)
+ import threading
1540
+
1541
+ ocr_results = None
1542
+ ocr_error = None
1543
+
1544
+ def run_ocr():
1545
+ nonlocal ocr_results, ocr_error
1546
+ try:
1547
+ ocr_results = self.model.ocr(image_copy)
1548
+ except Exception as e:
1549
+ ocr_error = e
1550
+
1551
+ # Run OCR in a separate thread with timeout
1552
+ ocr_thread = threading.Thread(target=run_ocr)
1553
+ ocr_thread.daemon = True
1554
+ ocr_thread.start()
1555
+ ocr_thread.join(timeout=30) # 30 second timeout
1556
+
1557
+ if ocr_thread.is_alive():
1558
+ self._log("❌ PaddleOCR timeout - taking too long", "error")
1559
+ return results
1560
+
1561
+ if ocr_error:
1562
+ raise ocr_error
1563
+
1564
+ # Parse results
1565
+ results = self._parse_ocr_results(ocr_results)
1566
+
1567
+ # Scale coordinates back if image was resized
1568
+ if scale_factor != 1.0 and results:
1569
+ for r in results:
1570
+ x, y, width, height = r.bbox
1571
+ r.bbox = (int(x*scale_factor), int(y*scale_factor),
1572
+ int(width*scale_factor), int(height*scale_factor))
1573
+ r.vertices = [(int(v[0]*scale_factor), int(v[1]*scale_factor))
1574
+ for v in r.vertices]
1575
+
1576
+ if results:
1577
+ self._log(f"✅ Detected {len(results)} text regions", "info")
1578
+ else:
1579
+ self._log("No text regions found", "debug")
1580
+
1581
+ # Clean up
1582
+ del image_copy
1583
+ gc.collect()
1584
+
1585
+ except Exception as e:
1586
+ error_msg = str(e) if str(e) else type(e).__name__
1587
+
1588
+ if "memory" in error_msg.lower() or "0x" in error_msg:
1589
+ self._log("❌ Memory access violation in PaddleOCR", "error")
1590
+ self._log(" This is a known Windows issue with PaddleOCR", "info")
1591
+ self._log(" Please switch to EasyOCR or manga-ocr instead", "warning")
1592
+ elif "trace_order.size()" in error_msg:
1593
+ self._log("❌ PaddleOCR internal error", "error")
1594
+ self._log(" Please switch to EasyOCR or manga-ocr", "warning")
1595
+ else:
1596
+ self._log(f"❌ Error in paddleocr detection: {error_msg}", "error")
1597
+
1598
+ import traceback
1599
+ self._log(traceback.format_exc(), "debug")
1600
+
1601
+ return results
1602
+
1603
+ def _parse_ocr_results(self, ocr_results) -> List[OCRResult]:
1604
+ """Parse OCR results safely"""
1605
+ results = []
1606
+
1607
+ if isinstance(ocr_results, bool) and ocr_results == False:
1608
+ return results
1609
+
1610
+ if ocr_results is None or not isinstance(ocr_results, list):
1611
+ return results
1612
+
1613
+ if len(ocr_results) == 0:
1614
+ return results
1615
+
1616
+ # Handle batch format
1617
+ if isinstance(ocr_results[0], list) and len(ocr_results[0]) > 0:
1618
+ first_item = ocr_results[0][0]
1619
+ if isinstance(first_item, list) and len(first_item) > 0:
1620
+ if isinstance(first_item[0], (list, tuple)) and len(first_item[0]) == 2:
1621
+ ocr_results = ocr_results[0]
1622
+
1623
+ # Parse detections
1624
+ for detection in ocr_results:
1625
+ if not detection or isinstance(detection, bool):
1626
+ continue
1627
+
1628
+ if not isinstance(detection, (list, tuple)) or len(detection) < 2:
1629
+ continue
1630
+
1631
+ try:
1632
+ bbox_points = detection[0]
1633
+ text_data = detection[1]
1634
+
1635
+ if not isinstance(bbox_points, (list, tuple)) or len(bbox_points) != 4:
1636
+ continue
1637
+
1638
+ if not isinstance(text_data, (tuple, list)) or len(text_data) < 2:
1639
+ continue
1640
+
1641
+ text = str(text_data[0]).strip()
1642
+ confidence = float(text_data[1])
1643
+
1644
+ if not text or confidence < 0.3:
1645
+ continue
1646
+
1647
+ xs = [float(p[0]) for p in bbox_points]
1648
+ ys = [float(p[1]) for p in bbox_points]
1649
+ x_min, x_max = min(xs), max(xs)
1650
+ y_min, y_max = min(ys), max(ys)
1651
+
1652
+ if (x_max - x_min) < 5 or (y_max - y_min) < 5:
1653
+ continue
1654
+
1655
+ results.append(OCRResult(
1656
+ text=text,
1657
+ bbox=(int(x_min), int(y_min), int(x_max - x_min), int(y_max - y_min)),
1658
+ confidence=confidence,
1659
+ vertices=[(int(p[0]), int(p[1])) for p in bbox_points]
1660
+ ))
1661
+
1662
+ except Exception:
1663
+ continue
1664
+
1665
+ return results
1666
+
1667
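The detect_text method above wraps the PaddleOCR call in a worker thread joined with a timeout; a generic sketch of that pattern, with the function and timeout as placeholders rather than part of this module:

# Hypothetical helper mirroring the timeout guard used in PaddleOCRProvider.detect_text.
import threading

def run_with_timeout(fn, timeout_sec, *args, **kwargs):
    """Run fn(*args, **kwargs) in a daemon thread; return (result, error, timed_out)."""
    result, error = None, None

    def _target():
        nonlocal result, error
        try:
            result = fn(*args, **kwargs)
        except Exception as e:
            error = e

    worker = threading.Thread(target=_target, daemon=True)
    worker.start()
    worker.join(timeout=timeout_sec)
    return result, error, worker.is_alive()

# e.g. res, err, timed_out = run_with_timeout(model.ocr, 30, image_copy)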
+ class DocTROCRProvider(OCRProvider):
1668
+ """DocTR OCR provider"""
1669
+
1670
+ def check_installation(self) -> bool:
1671
+ """Check if doctr is installed"""
1672
+ try:
1673
+ from doctr.models import ocr_predictor
1674
+ self.is_installed = True
1675
+ return True
1676
+ except ImportError:
1677
+ return False
1678
+
1679
+ def install(self, progress_callback=None) -> bool:
1680
+ """Install doctr"""
1681
+ pass
1682
+
1683
+ def load_model(self, **kwargs) -> bool:
1684
+ """Load doctr model"""
1685
+ try:
1686
+ if not self.is_installed and not self.check_installation():
1687
+ self._log("❌ doctr not installed", "error")
1688
+ return False
1689
+
1690
+ self._log("🔥 Loading DocTR model...")
1691
+ from doctr.models import ocr_predictor
1692
+
1693
+ # Load pretrained model
1694
+ self.model = ocr_predictor(pretrained=True)
1695
+ self.is_loaded = True
1696
+
1697
+ self._log("✅ DocTR model loaded successfully")
1698
+ return True
1699
+
1700
+ except Exception as e:
1701
+ self._log(f"❌ Failed to load doctr: {str(e)}", "error")
1702
+ return False
1703
+
1704
+ def detect_text(self, image: np.ndarray, **kwargs) -> List[OCRResult]:
1705
+ """Detect text using doctr"""
1706
+ results = []
1707
+
1708
+ try:
1709
+ if not self.is_loaded:
1710
+ if not self.load_model():
1711
+ return results
1712
+
1713
+ from doctr.io import DocumentFile
1714
+
1715
+ # DocTR expects document format
1716
+ # Convert numpy array to PIL and save temporarily
1717
+ import tempfile
1718
+ import cv2
1719
+
1720
+ with tempfile.NamedTemporaryFile(suffix='.png', delete=False) as tmp:
1721
+ cv2.imwrite(tmp.name, image)
1722
+ doc = DocumentFile.from_images(tmp.name)
1723
+
1724
+ # Run OCR
1725
+ result = self.model(doc)
1726
+
1727
+ # Parse results
1728
+ h, w = image.shape[:2]
1729
+ for page in result.pages:
1730
+ for block in page.blocks:
1731
+ for line in block.lines:
1732
+ for word in line.words:
1733
+ # Handle different geometry formats
1734
+ geometry = word.geometry
1735
+
1736
+ if len(geometry) == 4:
1737
+ # Standard format: (x1, y1, x2, y2)
1738
+ x1, y1, x2, y2 = geometry
1739
+ elif len(geometry) == 2:
1740
+ # Alternative format: ((x1, y1), (x2, y2))
1741
+ (x1, y1), (x2, y2) = geometry
1742
+ else:
1743
+ self._log(f"Unexpected geometry format: {geometry}", "warning")
1744
+ continue
1745
+
1746
+ # Convert relative coordinates to absolute
1747
+ x1, x2 = int(x1 * w), int(x2 * w)
1748
+ y1, y2 = int(y1 * h), int(y2 * h)
1749
+
1750
+ results.append(OCRResult(
1751
+ text=word.value,
1752
+ bbox=(x1, y1, x2 - x1, y2 - y1),
1753
+ confidence=word.confidence,
1754
+ vertices=[(x1, y1), (x2, y1), (x2, y2), (x1, y2)]
1755
+ ))
1756
+
1757
+ # Clean up temp file
1758
+ try:
1759
+ os.unlink(tmp.name)
1760
+ except:
1761
+ pass
1762
+
1763
+ self._log(f"DocTR detected {len(results)} text regions")
1764
+
1765
+ except Exception as e:
1766
+ self._log(f"Error in doctr detection: {str(e)}", "error")
1767
+ import traceback
1768
+ self._log(traceback.format_exc(), "error")
1769
+
1770
+ return results
1771
+
1772
+
1773
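DocTR reports word geometry as relative coordinates in [0, 1]; a small worked example of the conversion performed above, with the geometry values and page size made up for illustration:

# Hypothetical word box ((0.10, 0.20), (0.35, 0.28)) on a 1000x1500 (w x h) page.
w, h = 1000, 1500
(x1_rel, y1_rel), (x2_rel, y2_rel) = ((0.10, 0.20), (0.35, 0.28))

x1, x2 = int(x1_rel * w), int(x2_rel * w)    # 100, 350
y1, y2 = int(y1_rel * h), int(y2_rel * h)    # 300, 420

bbox = (x1, y1, x2 - x1, y2 - y1)            # (100, 300, 250, 120) as (x, y, width, height)
print(bbox)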
+ class RapidOCRProvider(OCRProvider):
1774
+ """RapidOCR provider for fast local OCR"""
1775
+
1776
+ def check_installation(self) -> bool:
1777
+ """Check if rapidocr is installed"""
1778
+ try:
1779
+ import rapidocr_onnxruntime
1780
+ self.is_installed = True
1781
+ return True
1782
+ except ImportError:
1783
+ return False
1784
+
1785
+ def install(self, progress_callback=None) -> bool:
1786
+ """Install rapidocr (requires manual pip install)"""
1787
+ # RapidOCR requires manual installation
1788
+ if progress_callback:
1789
+ progress_callback("RapidOCR requires manual pip installation")
1790
+ self._log("Run: pip install rapidocr-onnxruntime", "info")
1791
+ return False # Always return False since we can't auto-install
1792
+
1793
+ def load_model(self, **kwargs) -> bool:
1794
+ """Load RapidOCR model"""
1795
+ try:
1796
+ if not self.is_installed and not self.check_installation():
1797
+ self._log("RapidOCR not installed", "error")
1798
+ return False
1799
+
1800
+ self._log("Loading RapidOCR...")
1801
+ from rapidocr_onnxruntime import RapidOCR
1802
+
1803
+ self.model = RapidOCR()
1804
+ self.is_loaded = True
1805
+
1806
+ self._log("RapidOCR model loaded successfully")
1807
+ return True
1808
+
1809
+ except Exception as e:
1810
+ self._log(f"Failed to load RapidOCR: {str(e)}", "error")
1811
+ return False
1812
+
1813
+ def detect_text(self, image: np.ndarray, **kwargs) -> List[OCRResult]:
1814
+ """Detect text using RapidOCR"""
1815
+ if not self.is_loaded:
1816
+ self._log("RapidOCR model not loaded", "error")
1817
+ return []
1818
+
1819
+ results = []
1820
+
1821
+ try:
+ import cv2
+
+ # Convert BGR numpy array to RGB for RapidOCR
1823
+ if len(image.shape) == 3:
1824
+ # BGR to RGB
1825
+ image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
1826
+ else:
1827
+ image_rgb = image
1828
+
1829
+ # RapidOCR expects PIL Image or numpy array
1830
+ ocr_results, _ = self.model(image_rgb)
1831
+
1832
+ if ocr_results:
1833
+ for result in ocr_results:
1834
+ # RapidOCR returns [bbox, text, confidence]
1835
+ bbox_points = result[0] # 4 corner points
1836
+ text = result[1]
1837
+ confidence = float(result[2])
1838
+
1839
+ if not text or not text.strip():
1840
+ continue
1841
+
1842
+ # Convert 4-point bbox to x,y,w,h format
1843
+ xs = [point[0] for point in bbox_points]
1844
+ ys = [point[1] for point in bbox_points]
1845
+ x_min, x_max = min(xs), max(xs)
1846
+ y_min, y_max = min(ys), max(ys)
1847
+
1848
+ results.append(OCRResult(
1849
+ text=text.strip(),
1850
+ bbox=(int(x_min), int(y_min), int(x_max - x_min), int(y_max - y_min)),
1851
+ confidence=confidence,
1852
+ vertices=[(int(p[0]), int(p[1])) for p in bbox_points]
1853
+ ))
1854
+
1855
+ self._log(f"Detected {len(results)} text regions")
1856
+
1857
+ except Exception as e:
1858
+ self._log(f"Error in RapidOCR detection: {str(e)}", "error")
1859
+
1860
+ return results
1861
+
1862
+ class OCRManager:
1863
+ """Manager for multiple OCR providers"""
1864
+
1865
+ def __init__(self, log_callback=None):
1866
+ self.log_callback = log_callback
1867
+ self.providers = {
1868
+ 'custom-api': CustomAPIProvider(log_callback),
1869
+ 'manga-ocr': MangaOCRProvider(log_callback),
1870
+ 'easyocr': EasyOCRProvider(log_callback),
1871
+ 'paddleocr': PaddleOCRProvider(log_callback),
1872
+ 'doctr': DocTROCRProvider(log_callback),
1873
+ 'rapidocr': RapidOCRProvider(log_callback),
1874
+ 'Qwen2-VL': Qwen2VL(log_callback)
1875
+ }
1876
+ self.current_provider = None
1877
+ self.stop_flag = None
1878
+
1879
+ def get_provider(self, name: str) -> Optional[OCRProvider]:
1880
+ """Get OCR provider by name"""
1881
+ return self.providers.get(name)
1882
+
1883
+ def set_current_provider(self, name: str):
1884
+ """Set current active provider"""
1885
+ if name in self.providers:
1886
+ self.current_provider = name
1887
+ return True
1888
+ return False
1889
+
1890
+ def check_provider_status(self, name: str) -> Dict[str, bool]:
1891
+ """Check installation and loading status of provider"""
1892
+ provider = self.providers.get(name)
1893
+ if not provider:
1894
+ return {'installed': False, 'loaded': False}
1895
+
1896
+ result = {
1897
+ 'installed': provider.check_installation(),
1898
+ 'loaded': provider.is_loaded
1899
+ }
1900
+ if self.log_callback:
1901
+ self.log_callback(f"DEBUG: check_provider_status({name}) returning loaded={result['loaded']}", "debug")
1902
+ return result
1903
+
1904
+ def install_provider(self, name: str, progress_callback=None) -> bool:
1905
+ """Install a provider"""
1906
+ provider = self.providers.get(name)
1907
+ if not provider:
1908
+ return False
1909
+
1910
+ return provider.install(progress_callback)
1911
+
1912
+ def load_provider(self, name: str, **kwargs) -> bool:
1913
+ """Load a provider's model with optional parameters"""
1914
+ provider = self.providers.get(name)
1915
+ if not provider:
1916
+ return False
1917
+
1918
+ return provider.load_model(**kwargs) # <-- Passes model_size and any other kwargs
1919
+
1920
+ def shutdown(self):
1921
+ """Release models/processors/tokenizers for all providers and clear caches."""
1922
+ try:
1923
+ import gc
1924
+ for name, provider in list(self.providers.items()):
1925
+ try:
1926
+ if hasattr(provider, 'model'):
1927
+ provider.model = None
1928
+ if hasattr(provider, 'processor'):
1929
+ provider.processor = None
1930
+ if hasattr(provider, 'tokenizer'):
1931
+ provider.tokenizer = None
1932
+ if hasattr(provider, 'reader'):
1933
+ provider.reader = None
1934
+ if hasattr(provider, 'is_loaded'):
1935
+ provider.is_loaded = False
1936
+ except Exception:
1937
+ pass
1938
+ gc.collect()
1939
+ try:
1940
+ import torch
1941
+ torch.cuda.empty_cache()
1942
+ except Exception:
1943
+ pass
1944
+ except Exception:
1945
+ pass
1946
+
1947
+ def detect_text(self, image: np.ndarray, provider_name: str = None, **kwargs) -> List[OCRResult]:
1948
+ """Detect text using specified or current provider"""
1949
+ provider_name = provider_name or self.current_provider
1950
+ if not provider_name:
1951
+ return []
1952
+
1953
+ provider = self.providers.get(provider_name)
1954
+ if not provider:
1955
+ return []
1956
+
1957
+ return provider.detect_text(image, **kwargs)
1958
+
1959
+ def set_stop_flag(self, stop_flag):
1960
+ """Set stop flag for all providers"""
1961
+ self.stop_flag = stop_flag
1962
+ for provider in self.providers.values():
1963
+ if hasattr(provider, 'set_stop_flag'):
1964
+ provider.set_stop_flag(stop_flag)
1965
+
1966
+ def reset_stop_flags(self):
1967
+ """Reset stop flags for all providers"""
1968
+ for provider in self.providers.values():
1969
+ if hasattr(provider, 'reset_stop_flags'):
1970
+ provider.reset_stop_flags()
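An end-to-end sketch of the OCRManager above; the provider key matches the providers dict, while the image path and the threading.Event stop flag are assumptions for illustration:

# Hypothetical pipeline: pick a provider, load it, OCR a page, then release everything.
import threading
import cv2

manager = OCRManager(log_callback=print)
manager.set_current_provider("manga-ocr")    # any key from manager.providers
manager.set_stop_flag(threading.Event())     # assumed flag type; providers only need something checkable

if manager.load_provider("manga-ocr"):
    page = cv2.imread("page_001.png")
    for r in manager.detect_text(page, confidence=0.7):
        print(r.text, r.bbox)

manager.shutdown()                           # drop models/processors and clear the CUDA cache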
translator_gui.py ADDED
The diff for this file is too large to render. See raw diff
 
unified_api_client.py ADDED
The diff for this file is too large to render. See raw diff