Shirochi commited on
Commit
0bf8c19
·
verified ·
1 Parent(s): 463ae26

Delete bubble_detector.py

Browse files
Files changed (1) hide show
  1. bubble_detector.py +0 -1920
bubble_detector.py DELETED
@@ -1,1920 +0,0 @@
1
- """
2
- bubble_detector.py - Modified version that works in frozen PyInstaller executables
3
- Replace your bubble_detector.py with this version
4
- """
5
- import os
6
- import sys
7
- import json
8
- import numpy as np
9
- import cv2
10
- from typing import List, Tuple, Optional, Dict, Any
11
- import logging
12
- import traceback
13
- import hashlib
14
- from pathlib import Path
15
- import threading
16
- import time
17
-
18
- logging.basicConfig(level=logging.INFO)
19
- logger = logging.getLogger(__name__)
20
-
21
- # Check if we're running in a frozen environment
22
- IS_FROZEN = getattr(sys, 'frozen', False)
23
- if IS_FROZEN:
24
- # In frozen environment, set proper paths for ML libraries
25
- MEIPASS = sys._MEIPASS
26
- os.environ['TORCH_HOME'] = MEIPASS
27
- os.environ['TRANSFORMERS_CACHE'] = os.path.join(MEIPASS, 'transformers')
28
- os.environ['HF_HOME'] = os.path.join(MEIPASS, 'huggingface')
29
- logger.info(f"Running in frozen environment: {MEIPASS}")
30
-
31
- # Modified import checks for frozen environment
32
- YOLO_AVAILABLE = False
33
- YOLO = None
34
- torch = None
35
- TORCH_AVAILABLE = False
36
- ONNX_AVAILABLE = False
37
- TRANSFORMERS_AVAILABLE = False
38
- RTDetrForObjectDetection = None
39
- RTDetrImageProcessor = None
40
- PIL_AVAILABLE = False
41
-
42
- # Try to import YOLO dependencies with better error handling
43
- if IS_FROZEN:
44
- # In frozen environment, try harder to import
45
- try:
46
- # First try to import torch components individually
47
- import torch
48
- import torch.nn
49
- import torch.cuda
50
- TORCH_AVAILABLE = True
51
- logger.info("✓ PyTorch loaded in frozen environment")
52
- except Exception as e:
53
- logger.warning(f"PyTorch not available in frozen environment: {e}")
54
- TORCH_AVAILABLE = False
55
- torch = None
56
-
57
- # Try ultralytics after torch
58
- if TORCH_AVAILABLE:
59
- try:
60
- from ultralytics import YOLO
61
- YOLO_AVAILABLE = True
62
- logger.info("✓ Ultralytics YOLO loaded in frozen environment")
63
- except Exception as e:
64
- logger.warning(f"Ultralytics not available in frozen environment: {e}")
65
- YOLO_AVAILABLE = False
66
-
67
- # Try transformers
68
- try:
69
- import transformers
70
- # Try specific imports
71
- try:
72
- from transformers import RTDetrForObjectDetection, RTDetrImageProcessor
73
- TRANSFORMERS_AVAILABLE = True
74
- logger.info("✓ Transformers RT-DETR loaded in frozen environment")
75
- except ImportError:
76
- # Try alternative import
77
- try:
78
- from transformers import AutoModel, AutoImageProcessor
79
- RTDetrForObjectDetection = AutoModel
80
- RTDetrImageProcessor = AutoImageProcessor
81
- TRANSFORMERS_AVAILABLE = True
82
- logger.info("✓ Transformers loaded with AutoModel fallback")
83
- except:
84
- TRANSFORMERS_AVAILABLE = False
85
- logger.warning("Transformers RT-DETR not available in frozen environment")
86
- except Exception as e:
87
- logger.warning(f"Transformers not available in frozen environment: {e}")
88
- TRANSFORMERS_AVAILABLE = False
89
- else:
90
- # Normal environment - original import logic
91
- try:
92
- from ultralytics import YOLO
93
- YOLO_AVAILABLE = True
94
- except:
95
- YOLO_AVAILABLE = False
96
- logger.warning("Ultralytics YOLO not available")
97
-
98
- try:
99
- import torch
100
- # Test if cuda attribute exists
101
- _ = torch.cuda
102
- TORCH_AVAILABLE = True
103
- except (ImportError, AttributeError):
104
- TORCH_AVAILABLE = False
105
- torch = None
106
- logger.warning("PyTorch not available or incomplete")
107
-
108
- try:
109
- from transformers import RTDetrForObjectDetection, RTDetrImageProcessor
110
- try:
111
- from transformers import RTDetrV2ForObjectDetection
112
- RTDetrForObjectDetection = RTDetrV2ForObjectDetection
113
- except ImportError:
114
- pass
115
- TRANSFORMERS_AVAILABLE = True
116
- except:
117
- TRANSFORMERS_AVAILABLE = False
118
- logger.info("Transformers not available for RT-DETR")
119
-
120
- # Configure ORT memory behavior before importing
121
- try:
122
- os.environ.setdefault('ORT_DISABLE_MEMORY_ARENA', '1')
123
- except Exception:
124
- pass
125
- # ONNX Runtime - works well in frozen environments
126
- try:
127
- import onnxruntime as ort
128
- ONNX_AVAILABLE = True
129
- logger.info("✓ ONNX Runtime available")
130
- except ImportError:
131
- ONNX_AVAILABLE = False
132
- logger.warning("ONNX Runtime not available")
133
-
134
- # PIL
135
- try:
136
- from PIL import Image
137
- PIL_AVAILABLE = True
138
- except ImportError:
139
- PIL_AVAILABLE = False
140
- logger.info("PIL not available")
141
-
142
-
143
- class BubbleDetector:
144
- """
145
- Combined YOLOv8 and RT-DETR speech bubble detector for comics and manga.
146
- Supports multiple model formats and provides configurable detection.
147
- Backward compatible with existing code while adding RT-DETR support.
148
- """
149
-
150
- # Process-wide shared RT-DETR to avoid concurrent meta-device loads
151
- _rtdetr_init_lock = threading.Lock()
152
- _rtdetr_shared_model = None
153
- _rtdetr_shared_processor = None
154
- _rtdetr_loaded = False
155
- _rtdetr_repo_id = 'ogkalu/comic-text-and-bubble-detector'
156
-
157
- # Shared RT-DETR (ONNX) across process to avoid device/context storms
158
- _rtdetr_onnx_init_lock = threading.Lock()
159
- _rtdetr_onnx_shared_session = None
160
- _rtdetr_onnx_loaded = False
161
- _rtdetr_onnx_providers = None
162
- _rtdetr_onnx_model_path = None
163
- # Limit concurrent runs to avoid device hangs. Defaults to 2 for better parallelism.
164
- # Can be overridden via env DML_MAX_CONCURRENT or config rtdetr_max_concurrency
165
- try:
166
- _rtdetr_onnx_max_concurrent = int(os.environ.get('DML_MAX_CONCURRENT', '2'))
167
- except Exception:
168
- _rtdetr_onnx_max_concurrent = 2
169
- _rtdetr_onnx_sema = threading.Semaphore(max(1, _rtdetr_onnx_max_concurrent))
170
- _rtdetr_onnx_sema_initialized = False
171
-
172
- def __init__(self, config_path: str = "config.json"):
173
- """
174
- Initialize the bubble detector.
175
-
176
- Args:
177
- config_path: Path to configuration file
178
- """
179
- # Set thread limits early if environment indicates single-threaded mode
180
- try:
181
- if os.environ.get('OMP_NUM_THREADS') == '1':
182
- # Already in single-threaded mode, ensure it's applied to this process
183
- # Check if torch is available at module level before trying to use it
184
- if TORCH_AVAILABLE and torch is not None:
185
- try:
186
- torch.set_num_threads(1)
187
- except (RuntimeError, AttributeError):
188
- pass
189
- try:
190
- import cv2
191
- cv2.setNumThreads(1)
192
- except (ImportError, AttributeError):
193
- pass
194
- except Exception:
195
- pass
196
-
197
- self.config_path = config_path
198
- self.config = self._load_config()
199
-
200
- # YOLOv8 components (original)
201
- self.model = None
202
- self.model_loaded = False
203
- self.model_type = None # 'yolo', 'onnx', or 'torch'
204
- self.onnx_session = None
205
-
206
- # RT-DETR components (new)
207
- self.rtdetr_model = None
208
- self.rtdetr_processor = None
209
- self.rtdetr_loaded = False
210
- self.rtdetr_repo = 'ogkalu/comic-text-and-bubble-detector'
211
-
212
- # RT-DETR (ONNX) backend components
213
- self.rtdetr_onnx_session = None
214
- self.rtdetr_onnx_loaded = False
215
- self.rtdetr_onnx_repo = 'ogkalu/comic-text-and-bubble-detector'
216
-
217
- # RT-DETR class definitions
218
- self.CLASS_BUBBLE = 0 # Empty speech bubble
219
- self.CLASS_TEXT_BUBBLE = 1 # Bubble with text
220
- self.CLASS_TEXT_FREE = 2 # Text without bubble
221
-
222
- # Detection settings
223
- self.default_confidence = 0.3
224
- self.default_iou_threshold = 0.45
225
- # Allow override from settings
226
- try:
227
- ocr_cfg = self.config.get('manga_settings', {}).get('ocr', {}) if isinstance(self.config, dict) else {}
228
- self.default_max_detections = int(ocr_cfg.get('bubble_max_detections', 100))
229
- self.max_det_yolo = int(ocr_cfg.get('bubble_max_detections_yolo', self.default_max_detections))
230
- self.max_det_rtdetr = int(ocr_cfg.get('bubble_max_detections_rtdetr', self.default_max_detections))
231
- except Exception:
232
- self.default_max_detections = 100
233
- self.max_det_yolo = 100
234
- self.max_det_rtdetr = 100
235
-
236
- # Cache directory for ONNX conversions
237
- self.cache_dir = os.environ.get('BUBBLE_CACHE_DIR', 'models')
238
- os.makedirs(self.cache_dir, exist_ok=True)
239
-
240
- # RT-DETR concurrency setting from config
241
- try:
242
- rtdetr_max_conc = int(ocr_cfg.get('rtdetr_max_concurrency', 2))
243
- # Update class-level semaphore if not yet initialized or if value changed
244
- if not BubbleDetector._rtdetr_onnx_sema_initialized or rtdetr_max_conc != BubbleDetector._rtdetr_onnx_max_concurrent:
245
- BubbleDetector._rtdetr_onnx_max_concurrent = max(1, rtdetr_max_conc)
246
- BubbleDetector._rtdetr_onnx_sema = threading.Semaphore(BubbleDetector._rtdetr_onnx_max_concurrent)
247
- BubbleDetector._rtdetr_onnx_sema_initialized = True
248
- logger.info(f"RT-DETR concurrency set to: {BubbleDetector._rtdetr_onnx_max_concurrent}")
249
- except Exception as e:
250
- logger.warning(f"Failed to set RT-DETR concurrency: {e}")
251
-
252
- # GPU availability
253
- self.use_gpu = TORCH_AVAILABLE and torch.cuda.is_available()
254
- self.device = 'cuda' if self.use_gpu else 'cpu'
255
-
256
- # Quantization/precision settings
257
- adv_cfg = self.config.get('manga_settings', {}).get('advanced', {}) if isinstance(self.config, dict) else {}
258
- ocr_cfg = self.config.get('manga_settings', {}).get('ocr', {}) if isinstance(self.config, dict) else {}
259
- env_quant = os.environ.get('MODEL_QUANTIZE', 'false').lower() == 'true'
260
- self.quantize_enabled = bool(env_quant or adv_cfg.get('quantize_models', False) or ocr_cfg.get('quantize_bubble_detector', False))
261
- self.quantize_dtype = str(adv_cfg.get('torch_precision', os.environ.get('TORCH_PRECISION', 'auto'))).lower()
262
- # Prefer advanced.onnx_quantize; fall back to env or global quantize
263
- self.onnx_quantize_enabled = bool(adv_cfg.get('onnx_quantize', os.environ.get('ONNX_QUANTIZE', 'false').lower() == 'true' or self.quantize_enabled))
264
-
265
- # Stop flag support
266
- self.stop_flag = None
267
- self._stopped = False
268
- self.log_callback = None
269
-
270
- logger.info(f"🗨️ BubbleDetector initialized")
271
- logger.info(f" GPU: {'Available' if self.use_gpu else 'Not available'}")
272
- logger.info(f" YOLO: {'Available' if YOLO_AVAILABLE else 'Not installed'}")
273
- logger.info(f" ONNX: {'Available' if ONNX_AVAILABLE else 'Not installed'}")
274
- logger.info(f" RT-DETR: {'Available' if TRANSFORMERS_AVAILABLE else 'Not installed'}")
275
- logger.info(f" Quantization: {'ENABLED' if self.quantize_enabled else 'disabled'} (torch_precision={self.quantize_dtype}, onnx_quantize={'on' if self.onnx_quantize_enabled else 'off'})" )
276
-
277
- def _load_config(self) -> Dict[str, Any]:
278
- """Load configuration from file."""
279
- if os.path.exists(self.config_path):
280
- try:
281
- with open(self.config_path, 'r', encoding='utf-8') as f:
282
- return json.load(f)
283
- except Exception as e:
284
- logger.warning(f"Failed to load config: {e}")
285
- return {}
286
-
287
- def _save_config(self):
288
- """Save configuration to file."""
289
- try:
290
- with open(self.config_path, 'w', encoding='utf-8') as f:
291
- json.dump(self.config, f, indent=2)
292
- except Exception as e:
293
- logger.error(f"Failed to save config: {e}")
294
-
295
- def set_stop_flag(self, stop_flag):
296
- """Set the stop flag for checking interruptions"""
297
- self.stop_flag = stop_flag
298
- self._stopped = False
299
-
300
- def set_log_callback(self, log_callback):
301
- """Set log callback for GUI integration"""
302
- self.log_callback = log_callback
303
-
304
- def _check_stop(self) -> bool:
305
- """Check if stop has been requested"""
306
- if self._stopped:
307
- return True
308
- if self.stop_flag and self.stop_flag.is_set():
309
- self._stopped = True
310
- return True
311
- # Check global manga translator cancellation
312
- try:
313
- from manga_translator import MangaTranslator
314
- if MangaTranslator.is_globally_cancelled():
315
- self._stopped = True
316
- return True
317
- except Exception:
318
- pass
319
- return False
320
-
321
- def _log(self, message: str, level: str = "info"):
322
- """Log message with stop suppression"""
323
- # Suppress logs when stopped (allow only essential stop confirmation messages)
324
- if self._check_stop():
325
- essential_stop_keywords = [
326
- "⏹️ Translation stopped by user",
327
- "⏹️ Bubble detection stopped",
328
- "cleanup", "🧹"
329
- ]
330
- if not any(keyword in message for keyword in essential_stop_keywords):
331
- return
332
-
333
- if self.log_callback:
334
- self.log_callback(message, level)
335
- else:
336
- logger.info(message) if level == 'info' else getattr(logger, level, logger.info)(message)
337
-
338
- def reset_stop_flags(self):
339
- """Reset stop flags when starting new processing"""
340
- self._stopped = False
341
-
342
- def load_model(self, model_path: str, force_reload: bool = False) -> bool:
343
- """
344
- Load a YOLOv8 model for bubble detection.
345
-
346
- Args:
347
- model_path: Path to model file (.pt, .onnx, or .torchscript)
348
- force_reload: Force reload even if model is already loaded
349
-
350
- Returns:
351
- True if model loaded successfully, False otherwise
352
- """
353
- try:
354
- # If given a Hugging Face repo ID (e.g., 'owner/name'), fetch detector.onnx into models/
355
- if model_path and (('/' in model_path) and not os.path.exists(model_path)):
356
- try:
357
- from huggingface_hub import hf_hub_download
358
- os.makedirs(self.cache_dir, exist_ok=True)
359
- logger.info(f"📥 Resolving repo '{model_path}' to detector.onnx in {self.cache_dir}...")
360
- resolved = hf_hub_download(repo_id=model_path, filename='detector.onnx', cache_dir=self.cache_dir, local_dir=self.cache_dir, local_dir_use_symlinks=False)
361
- if resolved and os.path.exists(resolved):
362
- model_path = resolved
363
- logger.info(f"✅ Downloaded detector.onnx to: {model_path}")
364
- except Exception as repo_err:
365
- logger.error(f"Failed to download from repo '{model_path}': {repo_err}")
366
- if not os.path.exists(model_path):
367
- logger.error(f"Model file not found: {model_path}")
368
- return False
369
-
370
- # Check if it's the same model already loaded
371
- if self.model_loaded and not force_reload:
372
- last_path = self.config.get('last_model_path', '')
373
- if last_path == model_path:
374
- logger.info("Model already loaded (same path)")
375
- return True
376
- else:
377
- logger.info(f"Model path changed from {last_path} to {model_path}, reloading...")
378
- force_reload = True
379
-
380
- # Clear previous model if force reload
381
- if force_reload:
382
- logger.info("Force reloading model...")
383
- self.model = None
384
- self.onnx_session = None
385
- self.model_loaded = False
386
- self.model_type = None
387
-
388
- logger.info(f"📥 Loading bubble detection model: {model_path}")
389
-
390
- # Determine model type by extension
391
- ext = Path(model_path).suffix.lower()
392
-
393
- if ext in ['.pt', '.pth']:
394
- if not YOLO_AVAILABLE:
395
- logger.warning("Ultralytics package not available in this build")
396
- logger.info("Bubble detection will be disabled - this is normal for lightweight builds")
397
- # Don't return False immediately, try other fallbacks
398
- self.model_loaded = False
399
- return False
400
-
401
- # Load YOLOv8 model
402
- try:
403
- self.model = YOLO(model_path)
404
- self.model_type = 'yolo'
405
-
406
- # Set to eval mode
407
- if hasattr(self.model, 'model'):
408
- self.model.model.eval()
409
-
410
- # Move to GPU if available
411
- if self.use_gpu and TORCH_AVAILABLE:
412
- try:
413
- self.model.to('cuda')
414
- except Exception as gpu_error:
415
- logger.warning(f"Could not move model to GPU: {gpu_error}")
416
-
417
- logger.info("✅ YOLOv8 model loaded successfully")
418
- # Apply optional FP16 precision to reduce VRAM if enabled
419
- if self.quantize_enabled and self.use_gpu and TORCH_AVAILABLE:
420
- try:
421
- m = self.model.model if hasattr(self.model, 'model') else self.model
422
- m.half()
423
- logger.info("🔻 Applied FP16 precision to YOLO model (GPU)")
424
- except Exception as _e:
425
- logger.warning(f"Could not switch YOLO model to FP16: {_e}")
426
-
427
- except Exception as yolo_error:
428
- logger.error(f"Failed to load YOLO model: {yolo_error}")
429
- return False
430
-
431
- elif ext == '.onnx':
432
- if not ONNX_AVAILABLE:
433
- logger.warning("ONNX Runtime not available in this build")
434
- logger.info("ONNX model support disabled - this is normal for lightweight builds")
435
- return False
436
-
437
- try:
438
- # Load ONNX model
439
- providers = ['CUDAExecutionProvider', 'CPUExecutionProvider'] if self.use_gpu else ['CPUExecutionProvider']
440
- session_path = model_path
441
- if self.quantize_enabled:
442
- try:
443
- from onnxruntime.quantization import quantize_dynamic, QuantType
444
- quant_path = os.path.splitext(model_path)[0] + ".int8.onnx"
445
- if not os.path.exists(quant_path) or os.environ.get('FORCE_ONNX_REBUILD', 'false').lower() == 'true':
446
- logger.info("🔻 Quantizing ONNX model weights to INT8 (dynamic)...")
447
- quantize_dynamic(model_input=model_path, model_output=quant_path, weight_type=QuantType.QInt8, op_types_to_quantize=['Conv', 'MatMul'])
448
- session_path = quant_path
449
- self.config['last_onnx_quantized_path'] = quant_path
450
- self._save_config()
451
- logger.info(f"✅ Using quantized ONNX model: {quant_path}")
452
- except Exception as qe:
453
- logger.warning(f"ONNX quantization not applied: {qe}")
454
- # Use conservative ORT memory options to reduce RAM growth
455
- so = ort.SessionOptions()
456
- try:
457
- so.enable_mem_pattern = False
458
- so.enable_cpu_mem_arena = False
459
- except Exception:
460
- pass
461
- self.onnx_session = ort.InferenceSession(session_path, sess_options=so, providers=providers)
462
- self.model_type = 'onnx'
463
-
464
- logger.info("✅ ONNX model loaded successfully")
465
-
466
- except Exception as onnx_error:
467
- logger.error(f"Failed to load ONNX model: {onnx_error}")
468
- return False
469
-
470
- elif ext == '.torchscript':
471
- if not TORCH_AVAILABLE:
472
- logger.warning("PyTorch not available in this build")
473
- logger.info("TorchScript model support disabled - this is normal for lightweight builds")
474
- return False
475
-
476
- try:
477
- # Add safety check for torch being None
478
- if torch is None:
479
- logger.error("PyTorch module is None - cannot load TorchScript model")
480
- return False
481
-
482
- # Load TorchScript model
483
- self.model = torch.jit.load(model_path, map_location='cpu')
484
- self.model.eval()
485
- self.model_type = 'torch'
486
-
487
- if self.use_gpu:
488
- try:
489
- self.model = self.model.cuda()
490
- except Exception as gpu_error:
491
- logger.warning(f"Could not move TorchScript model to GPU: {gpu_error}")
492
-
493
- logger.info("✅ TorchScript model loaded successfully")
494
-
495
- # Optional FP16 precision on GPU
496
- if self.quantize_enabled and self.use_gpu and TORCH_AVAILABLE:
497
- try:
498
- self.model = self.model.half()
499
- logger.info("🔻 Applied FP16 precision to TorchScript model (GPU)")
500
- except Exception as _e:
501
- logger.warning(f"Could not switch TorchScript model to FP16: {_e}")
502
-
503
- except Exception as torch_error:
504
- logger.error(f"Failed to load TorchScript model: {torch_error}")
505
- return False
506
-
507
- else:
508
- logger.error(f"Unsupported model format: {ext}")
509
- logger.info("Supported formats: .pt/.pth (YOLOv8), .onnx (ONNX), .torchscript (TorchScript)")
510
- return False
511
-
512
- # Only set loaded if we actually succeeded
513
- self.model_loaded = True
514
- self.config['last_model_path'] = model_path
515
- self.config['model_type'] = self.model_type
516
- self._save_config()
517
-
518
- return True
519
-
520
- except Exception as e:
521
- logger.error(f"Failed to load model: {e}")
522
- logger.error(traceback.format_exc())
523
- self.model_loaded = False
524
-
525
- # Provide helpful context for .exe users
526
- logger.info("Note: If running from .exe, some ML libraries may not be included")
527
- logger.info("This is normal for lightweight builds - bubble detection will be disabled")
528
-
529
- return False
530
-
531
- def load_rtdetr_model(self, model_path: str = None, model_id: str = None, force_reload: bool = False) -> bool:
532
- """
533
- Load RT-DETR model for advanced bubble and text detection.
534
- This implementation avoids the 'meta tensor' copy error by:
535
- - Serializing the entire load under a class lock (no concurrent loads)
536
- - Loading directly onto the target device (CUDA if available) via device_map='auto'
537
- - Avoiding .to() on a potentially-meta model; no device migration post-load
538
-
539
- Args:
540
- model_path: Optional path to local model
541
- model_id: Optional HuggingFace model ID (default: 'ogkalu/comic-text-and-bubble-detector')
542
- force_reload: Force reload even if already loaded
543
-
544
- Returns:
545
- True if successful, False otherwise
546
- """
547
- if not TRANSFORMERS_AVAILABLE:
548
- logger.error("Transformers library required for RT-DETR. Install with: pip install transformers")
549
- return False
550
-
551
- if not PIL_AVAILABLE:
552
- logger.error("PIL required for RT-DETR. Install with: pip install pillow")
553
- return False
554
-
555
- if self.rtdetr_loaded and not force_reload:
556
- logger.info("RT-DETR model already loaded")
557
- return True
558
-
559
- # Fast path: if shared already loaded and not forcing reload, attach
560
- if BubbleDetector._rtdetr_loaded and not force_reload:
561
- self.rtdetr_model = BubbleDetector._rtdetr_shared_model
562
- self.rtdetr_processor = BubbleDetector._rtdetr_shared_processor
563
- self.rtdetr_loaded = True
564
- logger.info("RT-DETR model attached from shared cache")
565
- return True
566
-
567
- # Serialize the ENTIRE loading sequence to avoid concurrent init issues
568
- with BubbleDetector._rtdetr_init_lock:
569
- try:
570
- # Re-check after acquiring lock
571
- if BubbleDetector._rtdetr_loaded and not force_reload:
572
- self.rtdetr_model = BubbleDetector._rtdetr_shared_model
573
- self.rtdetr_processor = BubbleDetector._rtdetr_shared_processor
574
- self.rtdetr_loaded = True
575
- logger.info("RT-DETR model attached from shared cache (post-lock)")
576
- return True
577
-
578
- # Use custom model_id if provided, otherwise use default
579
- repo_id = model_id if model_id else self.rtdetr_repo
580
- logger.info(f"📥 Loading RT-DETR model from {repo_id}...")
581
-
582
- # Ensure TorchDynamo/compile doesn't interfere on some builds
583
- try:
584
- os.environ.setdefault('TORCHDYNAMO_DISABLE', '1')
585
- except Exception:
586
- pass
587
-
588
- # Decide device strategy
589
- gpu_available = bool(TORCH_AVAILABLE and hasattr(torch, 'cuda') and torch.cuda.is_available())
590
- device_map = 'auto' if gpu_available else None
591
- # Choose dtype
592
- dtype = None
593
- if TORCH_AVAILABLE:
594
- try:
595
- dtype = torch.float16 if gpu_available else torch.float32
596
- except Exception:
597
- dtype = None
598
- low_cpu = True if gpu_available else False
599
-
600
- # Load processor (once)
601
- self.rtdetr_processor = RTDetrImageProcessor.from_pretrained(
602
- repo_id,
603
- size={"width": 640, "height": 640},
604
- cache_dir=self.cache_dir if not model_path else None
605
- )
606
-
607
- # Prepare kwargs for from_pretrained
608
- from_kwargs = {
609
- 'cache_dir': self.cache_dir if not model_path else None,
610
- 'low_cpu_mem_usage': low_cpu,
611
- 'device_map': device_map,
612
- }
613
- if dtype is not None:
614
- from_kwargs['dtype'] = dtype
615
-
616
- # First attempt: load directly to target (CUDA if available)
617
- try:
618
- self.rtdetr_model = RTDetrForObjectDetection.from_pretrained(
619
- model_path if model_path else repo_id,
620
- **from_kwargs,
621
- )
622
- except Exception as primary_err:
623
- # Fallback to a simple CPU load (no device move) if CUDA path fails
624
- logger.warning(f"RT-DETR primary load failed ({primary_err}); retrying on CPU...")
625
- from_kwargs_fallback = {
626
- 'cache_dir': self.cache_dir if not model_path else None,
627
- 'low_cpu_mem_usage': False,
628
- 'device_map': None,
629
- }
630
- if TORCH_AVAILABLE:
631
- from_kwargs_fallback['dtype'] = torch.float32
632
- self.rtdetr_model = RTDetrForObjectDetection.from_pretrained(
633
- model_path if model_path else repo_id,
634
- **from_kwargs_fallback,
635
- )
636
-
637
- # Optional dynamic quantization for linear layers (CPU only)
638
- if self.quantize_enabled and TORCH_AVAILABLE and (not gpu_available):
639
- try:
640
- try:
641
- import torch.ao.quantization as tq
642
- quantize_dynamic = tq.quantize_dynamic # type: ignore
643
- except Exception:
644
- import torch.quantization as tq # type: ignore
645
- quantize_dynamic = tq.quantize_dynamic # type: ignore
646
- self.rtdetr_model = quantize_dynamic(self.rtdetr_model, {torch.nn.Linear}, dtype=torch.qint8)
647
- logger.info("🔻 Applied dynamic INT8 quantization to RT-DETR linear layers (CPU)")
648
- except Exception as qe:
649
- logger.warning(f"RT-DETR dynamic quantization skipped: {qe}")
650
-
651
- # Finalize
652
- self.rtdetr_model.eval()
653
-
654
- # Sanity check: ensure no parameter is left on 'meta' device
655
- try:
656
- for n, p in self.rtdetr_model.named_parameters():
657
- dev = getattr(p, 'device', None)
658
- if dev is not None and getattr(dev, 'type', '') == 'meta':
659
- raise RuntimeError(f"Parameter {n} is on 'meta' device after load")
660
- except Exception as e:
661
- logger.error(f"RT-DETR load sanity check failed: {e}")
662
- self.rtdetr_loaded = False
663
- return False
664
-
665
- # Publish shared cache
666
- BubbleDetector._rtdetr_shared_model = self.rtdetr_model
667
- BubbleDetector._rtdetr_shared_processor = self.rtdetr_processor
668
- BubbleDetector._rtdetr_loaded = True
669
- BubbleDetector._rtdetr_repo_id = repo_id
670
-
671
- self.rtdetr_loaded = True
672
-
673
- # Save the model ID that was used
674
- self.config['rtdetr_loaded'] = True
675
- self.config['rtdetr_model_id'] = repo_id
676
- self._save_config()
677
-
678
- loc = 'CUDA' if gpu_available else 'CPU'
679
- logger.info(f"✅ RT-DETR model loaded successfully ({loc})")
680
- logger.info(" Classes: Empty bubbles, Text bubbles, Free text")
681
-
682
- # Auto-convert to ONNX for RT-DETR only if explicitly enabled
683
- if os.environ.get('AUTO_CONVERT_RTDETR_ONNX', 'false').lower() == 'true':
684
- onnx_path = os.path.join(self.cache_dir, 'rtdetr_comic.onnx')
685
- if self.convert_to_onnx('rtdetr', onnx_path):
686
- logger.info("🚀 RT-DETR converted to ONNX for faster inference")
687
- # Store ONNX path for later use
688
- self.config['rtdetr_onnx_path'] = onnx_path
689
- self._save_config()
690
- # Optionally quantize ONNX for reduced RAM
691
- if self.onnx_quantize_enabled:
692
- try:
693
- from onnxruntime.quantization import quantize_dynamic, QuantType
694
- quant_path = os.path.splitext(onnx_path)[0] + ".int8.onnx"
695
- if not os.path.exists(quant_path) or os.environ.get('FORCE_ONNX_REBUILD', 'false').lower() == 'true':
696
- logger.info("🔻 Quantizing RT-DETR ONNX to INT8 (dynamic)...")
697
- quantize_dynamic(model_input=onnx_path, model_output=quant_path, weight_type=QuantType.QInt8, op_types_to_quantize=['Conv', 'MatMul'])
698
- self.config['rtdetr_onnx_quantized_path'] = quant_path
699
- self._save_config()
700
- logger.info(f"✅ Quantized RT-DETR ONNX saved to: {quant_path}")
701
- except Exception as qe:
702
- logger.warning(f"ONNX quantization for RT-DETR skipped: {qe}")
703
- else:
704
- logger.info("ℹ️ Skipping RT-DETR ONNX export (converter not supported in current environment)")
705
-
706
- return True
707
- except Exception as e:
708
- logger.error(f"❌ Failed to load RT-DETR: {e}")
709
- self.rtdetr_loaded = False
710
- return False
711
-
712
- def check_rtdetr_available(self, model_id: str = None) -> bool:
713
- """
714
- Check if RT-DETR model is available (cached).
715
-
716
- Args:
717
- model_id: Optional HuggingFace model ID
718
-
719
- Returns:
720
- True if model is cached and available
721
- """
722
- try:
723
- from pathlib import Path
724
-
725
- # Use provided model_id or default
726
- repo_id = model_id if model_id else self.rtdetr_repo
727
-
728
- # Check HuggingFace cache
729
- cache_dir = Path.home() / ".cache" / "huggingface" / "hub"
730
- model_id_formatted = repo_id.replace("/", "--")
731
-
732
- # Look for model folder
733
- model_folders = list(cache_dir.glob(f"models--{model_id_formatted}*"))
734
-
735
- if model_folders:
736
- for folder in model_folders:
737
- if (folder / "snapshots").exists():
738
- snapshots = list((folder / "snapshots").iterdir())
739
- if snapshots:
740
- return True
741
-
742
- return False
743
-
744
- except Exception:
745
- return False
746
-
747
- def detect_bubbles(self,
748
- image_path: str,
749
- confidence: float = None,
750
- iou_threshold: float = None,
751
- max_detections: int = None,
752
- use_rtdetr: bool = None) -> List[Tuple[int, int, int, int]]:
753
- """
754
- Detect speech bubbles in an image (backward compatible method).
755
-
756
- Args:
757
- image_path: Path to image file
758
- confidence: Minimum confidence threshold (0-1)
759
- iou_threshold: IOU threshold for NMS (0-1)
760
- max_detections: Maximum number of detections to return
761
- use_rtdetr: If True, use RT-DETR instead of YOLOv8 (if available)
762
-
763
- Returns:
764
- List of bubble bounding boxes as (x, y, width, height) tuples
765
- """
766
- # Check for stop at start
767
- if self._check_stop():
768
- self._log("⏹️ Bubble detection stopped by user", "warning")
769
- return []
770
-
771
- # Decide which model to use
772
- if use_rtdetr is None:
773
- # Auto-select: prefer RT-DETR if available
774
- use_rtdetr = self.rtdetr_loaded
775
-
776
- if use_rtdetr:
777
- # Prefer ONNX backend if available, else PyTorch
778
- if getattr(self, 'rtdetr_onnx_loaded', False):
779
- results = self.detect_with_rtdetr_onnx(
780
- image_path=image_path,
781
- confidence=confidence,
782
- return_all_bubbles=True
783
- )
784
- return results
785
- if self.rtdetr_loaded:
786
- results = self.detect_with_rtdetr(
787
- image_path=image_path,
788
- confidence=confidence,
789
- return_all_bubbles=True
790
- )
791
- return results
792
-
793
- # Original YOLOv8 detection
794
- if not self.model_loaded:
795
- logger.error("No model loaded. Call load_model() first.")
796
- return []
797
-
798
- # Use defaults if not specified
799
- confidence = confidence or self.default_confidence
800
- iou_threshold = iou_threshold or self.default_iou_threshold
801
- max_detections = max_detections or self.default_max_detections
802
-
803
- try:
804
- # Load image
805
- image = cv2.imread(image_path)
806
- if image is None:
807
- logger.error(f"Failed to load image: {image_path}")
808
- return []
809
-
810
- h, w = image.shape[:2]
811
- self._log(f"🔍 Detecting bubbles in {w}x{h} image")
812
-
813
- # Check for stop before inference
814
- if self._check_stop():
815
- self._log("⏹️ Bubble detection inference stopped by user", "warning")
816
- return []
817
-
818
- if self.model_type == 'yolo':
819
- # YOLOv8 inference
820
- results = self.model(
821
- image_path,
822
- conf=confidence,
823
- iou=iou_threshold,
824
- max_det=min(max_detections, getattr(self, 'max_det_yolo', max_detections)),
825
- verbose=False
826
- )
827
-
828
- bubbles = []
829
- for r in results:
830
- if r.boxes is not None:
831
- for box in r.boxes:
832
- # Get box coordinates
833
- x1, y1, x2, y2 = box.xyxy[0].cpu().numpy()
834
- x, y = int(x1), int(y1)
835
- width = int(x2 - x1)
836
- height = int(y2 - y1)
837
-
838
- # Get confidence
839
- conf = float(box.conf[0])
840
-
841
- # Add to list
842
- if len(bubbles) < max_detections:
843
- bubbles.append((x, y, width, height))
844
-
845
- logger.debug(f" Bubble: ({x},{y}) {width}x{height} conf={conf:.2f}")
846
-
847
- elif self.model_type == 'onnx':
848
- # ONNX inference
849
- bubbles = self._detect_with_onnx(image, confidence, iou_threshold, max_detections)
850
-
851
- elif self.model_type == 'torch':
852
- # TorchScript inference
853
- bubbles = self._detect_with_torchscript(image, confidence, iou_threshold, max_detections)
854
-
855
- else:
856
- logger.error(f"Unknown model type: {self.model_type}")
857
- return []
858
-
859
- logger.info(f"✅ Detected {len(bubbles)} speech bubbles")
860
- time.sleep(0.1) # Brief pause for stability
861
- logger.debug("💤 Bubble detection pausing briefly for stability")
862
- return bubbles
863
-
864
- except Exception as e:
865
- logger.error(f"Detection failed: {e}")
866
- logger.error(traceback.format_exc())
867
- return []
868
-
869
- def detect_with_rtdetr(self,
870
- image_path: str = None,
871
- image: np.ndarray = None,
872
- confidence: float = None,
873
- return_all_bubbles: bool = False) -> Any:
874
- """
875
- Detect using RT-DETR model with 3-class detection (PyTorch backend).
876
-
877
- Args:
878
- image_path: Path to image file
879
- image: Image array (BGR format)
880
- confidence: Confidence threshold
881
- return_all_bubbles: If True, return list of bubble boxes (for compatibility)
882
- If False, return dict with all classes
883
-
884
- Returns:
885
- List of bubbles if return_all_bubbles=True, else dict with classes
886
- """
887
- # Check for stop at start
888
- if self._check_stop():
889
- self._log("⏹️ RT-DETR detection stopped by user", "warning")
890
- if return_all_bubbles:
891
- return []
892
- return {'bubbles': [], 'text_bubbles': [], 'text_free': []}
893
-
894
- if not self.rtdetr_loaded:
895
- self._log("RT-DETR not loaded. Call load_rtdetr_model() first.", "warning")
896
- if return_all_bubbles:
897
- return []
898
- return {'bubbles': [], 'text_bubbles': [], 'text_free': []}
899
-
900
- confidence = confidence or self.default_confidence
901
-
902
- try:
903
- # Load image
904
- if image_path:
905
- image = cv2.imread(image_path)
906
- elif image is None:
907
- logger.error("No image provided")
908
- if return_all_bubbles:
909
- return []
910
- return {'bubbles': [], 'text_bubbles': [], 'text_free': []}
911
-
912
- # Convert BGR to RGB for PIL
913
- image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
914
- pil_image = Image.fromarray(image_rgb)
915
-
916
- # Prepare image for model
917
- inputs = self.rtdetr_processor(images=pil_image, return_tensors="pt")
918
-
919
- # Move inputs to the same device as the model and match model dtype for floating tensors
920
- model_device = next(self.rtdetr_model.parameters()).device if self.rtdetr_model is not None else (torch.device('cpu') if TORCH_AVAILABLE else 'cpu')
921
- model_dtype = None
922
- if TORCH_AVAILABLE and self.rtdetr_model is not None:
923
- try:
924
- model_dtype = next(self.rtdetr_model.parameters()).dtype
925
- except Exception:
926
- model_dtype = None
927
-
928
- if TORCH_AVAILABLE:
929
- new_inputs = {}
930
- for k, v in inputs.items():
931
- if isinstance(v, torch.Tensor):
932
- v = v.to(model_device)
933
- if model_dtype is not None and torch.is_floating_point(v):
934
- v = v.to(model_dtype)
935
- new_inputs[k] = v
936
- inputs = new_inputs
937
-
938
- # Run inference with autocast when model is half/bfloat16 on CUDA
939
- use_amp = TORCH_AVAILABLE and hasattr(model_device, 'type') and model_device.type == 'cuda' and (model_dtype in (torch.float16, torch.bfloat16))
940
- autocast_dtype = model_dtype if model_dtype in (torch.float16, torch.bfloat16) else None
941
-
942
- with torch.no_grad():
943
- if use_amp and autocast_dtype is not None:
944
- with torch.autocast('cuda', dtype=autocast_dtype):
945
- outputs = self.rtdetr_model(**inputs)
946
- else:
947
- outputs = self.rtdetr_model(**inputs)
948
-
949
- # Brief pause for stability after inference
950
- time.sleep(0.1)
951
- logger.debug("💤 RT-DETR inference pausing briefly for stability")
952
-
953
- # Post-process results
954
- target_sizes = torch.tensor([pil_image.size[::-1]]) if TORCH_AVAILABLE else None
955
- if TORCH_AVAILABLE and hasattr(model_device, 'type') and model_device.type == "cuda":
956
- target_sizes = target_sizes.to(model_device)
957
-
958
- results = self.rtdetr_processor.post_process_object_detection(
959
- outputs,
960
- target_sizes=target_sizes,
961
- threshold=confidence
962
- )[0]
963
-
964
- # Apply per-detector cap if configured
965
- cap = getattr(self, 'max_det_rtdetr', self.default_max_detections)
966
- if cap and len(results['boxes']) > cap:
967
- # Keep top-scoring first
968
- scores = results['scores']
969
- top_idx = scores.topk(k=cap).indices if hasattr(scores, 'topk') else range(cap)
970
- results = {
971
- 'boxes': [results['boxes'][i] for i in top_idx],
972
- 'scores': [results['scores'][i] for i in top_idx],
973
- 'labels': [results['labels'][i] for i in top_idx]
974
- }
975
-
976
- logger.info(f"📊 RT-DETR found {len(results['boxes'])} detections above {confidence:.2f} confidence")
977
-
978
- # Organize detections by class
979
- detections = {
980
- 'bubbles': [], # Empty speech bubbles
981
- 'text_bubbles': [], # Bubbles with text
982
- 'text_free': [] # Text without bubbles
983
- }
984
-
985
- for box, score, label in zip(results['boxes'], results['scores'], results['labels']):
986
- x1, y1, x2, y2 = map(int, box.tolist())
987
- width = x2 - x1
988
- height = y2 - y1
989
-
990
- # Store as (x, y, width, height) to match YOLOv8 format
991
- bbox = (x1, y1, width, height)
992
-
993
- label_id = label.item()
994
- if label_id == self.CLASS_BUBBLE:
995
- detections['bubbles'].append(bbox)
996
- elif label_id == self.CLASS_TEXT_BUBBLE:
997
- detections['text_bubbles'].append(bbox)
998
- elif label_id == self.CLASS_TEXT_FREE:
999
- detections['text_free'].append(bbox)
1000
-
1001
- # Stop early if we hit the configured cap across all classes
1002
- total_count = len(detections['bubbles']) + len(detections['text_bubbles']) + len(detections['text_free'])
1003
- if total_count >= (self.config.get('manga_settings', {}).get('ocr', {}).get('bubble_max_detections', self.default_max_detections) if isinstance(self.config, dict) else self.default_max_detections):
1004
- break
1005
-
1006
- # Log results
1007
- total = len(detections['bubbles']) + len(detections['text_bubbles']) + len(detections['text_free'])
1008
- logger.info(f"✅ RT-DETR detected {total} objects:")
1009
- logger.info(f" - Empty bubbles: {len(detections['bubbles'])}")
1010
- logger.info(f" - Text bubbles: {len(detections['text_bubbles'])}")
1011
- logger.info(f" - Free text: {len(detections['text_free'])}")
1012
-
1013
- # Return format based on compatibility mode
1014
- if return_all_bubbles:
1015
- # Return all bubbles (empty + with text) for backward compatibility
1016
- all_bubbles = detections['bubbles'] + detections['text_bubbles']
1017
- return all_bubbles
1018
- else:
1019
- return detections
1020
-
1021
- except Exception as e:
1022
- logger.error(f"RT-DETR detection failed: {e}")
1023
- logger.error(traceback.format_exc())
1024
- if return_all_bubbles:
1025
- return []
1026
- return {'bubbles': [], 'text_bubbles': [], 'text_free': []}
1027
-
1028
- def detect_all_text_regions(self, image_path: str = None, image: np.ndarray = None) -> List[Tuple[int, int, int, int]]:
1029
- """
1030
- Detect all text regions using RT-DETR (both in bubbles and free text).
1031
-
1032
- Returns:
1033
- List of bounding boxes for all text regions
1034
- """
1035
- if not self.rtdetr_loaded:
1036
- logger.warning("RT-DETR required for text detection")
1037
- return []
1038
-
1039
- detections = self.detect_with_rtdetr(image_path=image_path, image=image, return_all_bubbles=False)
1040
-
1041
- # Combine text bubbles and free text
1042
- all_text = detections['text_bubbles'] + detections['text_free']
1043
-
1044
- logger.info(f"📝 Found {len(all_text)} text regions total")
1045
- return all_text
1046
-
1047
- def _detect_with_onnx(self, image: np.ndarray, confidence: float,
1048
- iou_threshold: float, max_detections: int) -> List[Tuple[int, int, int, int]]:
1049
- """Run detection using ONNX model."""
1050
- # Preprocess image
1051
- img_size = 640 # Standard YOLOv8 input size
1052
- img_resized = cv2.resize(image, (img_size, img_size))
1053
- img_norm = img_resized.astype(np.float32) / 255.0
1054
- img_transposed = np.transpose(img_norm, (2, 0, 1))
1055
- img_batch = np.expand_dims(img_transposed, axis=0)
1056
-
1057
- # Run inference
1058
- input_name = self.onnx_session.get_inputs()[0].name
1059
- outputs = self.onnx_session.run(None, {input_name: img_batch})
1060
-
1061
- # Process outputs (YOLOv8 format)
1062
- predictions = outputs[0][0] # Remove batch dimension
1063
-
1064
- # Filter by confidence and apply NMS
1065
- bubbles = []
1066
- boxes = []
1067
- scores = []
1068
-
1069
- for pred in predictions.T: # Transpose to get predictions per detection
1070
- if len(pred) >= 5:
1071
- x_center, y_center, width, height, obj_conf = pred[:5]
1072
-
1073
- if obj_conf >= confidence:
1074
- # Convert to corner coordinates
1075
- x1 = x_center - width / 2
1076
- y1 = y_center - height / 2
1077
-
1078
- # Scale to original image size
1079
- h, w = image.shape[:2]
1080
- x1 = int(x1 * w / img_size)
1081
- y1 = int(y1 * h / img_size)
1082
- width = int(width * w / img_size)
1083
- height = int(height * h / img_size)
1084
-
1085
- boxes.append([x1, y1, x1 + width, y1 + height])
1086
- scores.append(float(obj_conf))
1087
-
1088
- # Apply NMS
1089
- if boxes:
1090
- indices = cv2.dnn.NMSBoxes(boxes, scores, confidence, iou_threshold)
1091
- if len(indices) > 0:
1092
- indices = indices.flatten()[:max_detections]
1093
- for i in indices:
1094
- x1, y1, x2, y2 = boxes[i]
1095
- bubbles.append((x1, y1, x2 - x1, y2 - y1))
1096
-
1097
- return bubbles
1098
-
1099
- def _detect_with_torchscript(self, image: np.ndarray, confidence: float,
1100
- iou_threshold: float, max_detections: int) -> List[Tuple[int, int, int, int]]:
1101
- """Run detection using TorchScript model."""
1102
- # Similar to ONNX but using PyTorch tensors
1103
- img_size = 640
1104
- img_resized = cv2.resize(image, (img_size, img_size))
1105
- img_norm = img_resized.astype(np.float32) / 255.0
1106
- img_tensor = torch.from_numpy(img_norm).permute(2, 0, 1).unsqueeze(0)
1107
-
1108
- if self.use_gpu:
1109
- img_tensor = img_tensor.cuda()
1110
-
1111
- with torch.no_grad():
1112
- outputs = self.model(img_tensor)
1113
-
1114
- # Process outputs similar to ONNX
1115
- # Implementation depends on exact model output format
1116
- # This is a placeholder - adjust based on your model
1117
- return []
1118
-
1119
- def visualize_detections(self, image_path: str, bubbles: List[Tuple[int, int, int, int]] = None,
1120
- output_path: str = None, use_rtdetr: bool = False) -> np.ndarray:
1121
- """
1122
- Visualize detected bubbles on the image.
1123
-
1124
- Args:
1125
- image_path: Path to original image
1126
- bubbles: List of bubble bounding boxes (if None, will detect)
1127
- output_path: Optional path to save visualization
1128
- use_rtdetr: Use RT-DETR for visualization with class colors
1129
-
1130
- Returns:
1131
- Image with drawn bounding boxes
1132
- """
1133
- image = cv2.imread(image_path)
1134
- if image is None:
1135
- logger.error(f"Failed to load image: {image_path}")
1136
- return None
1137
-
1138
- vis_image = image.copy()
1139
-
1140
- if use_rtdetr and self.rtdetr_loaded:
1141
- # RT-DETR visualization with different colors per class
1142
- detections = self.detect_with_rtdetr(image_path=image_path, return_all_bubbles=False)
1143
-
1144
- # Colors for each class
1145
- colors = {
1146
- 'bubbles': (0, 255, 0), # Green for empty bubbles
1147
- 'text_bubbles': (255, 0, 0), # Blue for text bubbles
1148
- 'text_free': (0, 0, 255) # Red for free text
1149
- }
1150
-
1151
- # Draw detections
1152
- for class_name, bboxes in detections.items():
1153
- color = colors[class_name]
1154
-
1155
- for i, (x, y, w, h) in enumerate(bboxes):
1156
- # Draw rectangle
1157
- cv2.rectangle(vis_image, (x, y), (x + w, y + h), color, 2)
1158
-
1159
- # Add label
1160
- label = f"{class_name.replace('_', ' ').title()} {i+1}"
1161
- label_size, _ = cv2.getTextSize(label, cv2.FONT_HERSHEY_SIMPLEX, 0.5, 1)
1162
- cv2.rectangle(vis_image, (x, y - label_size[1] - 4),
1163
- (x + label_size[0], y), color, -1)
1164
- cv2.putText(vis_image, label, (x, y - 2),
1165
- cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255), 1)
1166
- else:
1167
- # Original YOLOv8 visualization
1168
- if bubbles is None:
1169
- bubbles = self.detect_bubbles(image_path)
1170
-
1171
- # Draw bounding boxes
1172
- for i, (x, y, w, h) in enumerate(bubbles):
1173
- # Draw rectangle
1174
- color = (0, 255, 0) # Green
1175
- thickness = 2
1176
- cv2.rectangle(vis_image, (x, y), (x + w, y + h), color, thickness)
1177
-
1178
- # Add label
1179
- label = f"Bubble {i+1}"
1180
- label_size, _ = cv2.getTextSize(label, cv2.FONT_HERSHEY_SIMPLEX, 0.5, 1)
1181
- cv2.rectangle(vis_image, (x, y - label_size[1] - 4), (x + label_size[0], y), color, -1)
1182
- cv2.putText(vis_image, label, (x, y - 2), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255), 1)
1183
-
1184
- # Save if output path provided
1185
- if output_path:
1186
- cv2.imwrite(output_path, vis_image)
1187
- logger.info(f"💾 Visualization saved to: {output_path}")
1188
-
1189
- return vis_image
1190
-
1191
- def convert_to_onnx(self, model_path: str, output_path: str = None) -> bool:
1192
- """
1193
- Convert a YOLOv8 or RT-DETR model to ONNX format.
1194
-
1195
- Args:
1196
- model_path: Path to model file or 'rtdetr' for loaded RT-DETR
1197
- output_path: Path for ONNX output (auto-generated if None)
1198
-
1199
- Returns:
1200
- True if conversion successful, False otherwise
1201
- """
1202
- try:
1203
- logger.info(f"🔄 Converting {model_path} to ONNX...")
1204
-
1205
- # Generate output path if not provided
1206
- if output_path is None:
1207
- if model_path == 'rtdetr' and self.rtdetr_loaded:
1208
- base_name = 'rtdetr_comic'
1209
- else:
1210
- base_name = Path(model_path).stem
1211
- output_path = os.path.join(self.cache_dir, f"{base_name}.onnx")
1212
-
1213
- # Check if already exists
1214
- if os.path.exists(output_path) and not os.environ.get('FORCE_ONNX_REBUILD', 'false').lower() == 'true':
1215
- logger.info(f"✅ ONNX model already exists: {output_path}")
1216
- return True
1217
-
1218
- # Handle RT-DETR conversion
1219
- if model_path == 'rtdetr' and self.rtdetr_loaded:
1220
- if not TORCH_AVAILABLE:
1221
- logger.error("PyTorch required for RT-DETR ONNX conversion")
1222
- return False
1223
-
1224
- # RT-DETR specific conversion
1225
- self.rtdetr_model.eval()
1226
-
1227
- # Create dummy input (pixel values): BxCxHxW
1228
- dummy_input = torch.randn(1, 3, 640, 640)
1229
- if self.device == 'cuda':
1230
- dummy_input = dummy_input.to('cuda')
1231
-
1232
- # Wrap the model to return only tensors (logits, pred_boxes)
1233
- class _RTDetrExportWrapper(torch.nn.Module):
1234
- def __init__(self, mdl):
1235
- super().__init__()
1236
- self.mdl = mdl
1237
- def forward(self, images):
1238
- out = self.mdl(pixel_values=images)
1239
- # Handle dict/ModelOutput/tuple outputs
1240
- logits = None
1241
- boxes = None
1242
- try:
1243
- if isinstance(out, dict):
1244
- logits = out.get('logits', None)
1245
- boxes = out.get('pred_boxes', out.get('boxes', None))
1246
- else:
1247
- logits = getattr(out, 'logits', None)
1248
- boxes = getattr(out, 'pred_boxes', getattr(out, 'boxes', None))
1249
- except Exception:
1250
- pass
1251
- if (logits is None or boxes is None) and isinstance(out, (tuple, list)) and len(out) >= 2:
1252
- logits, boxes = out[0], out[1]
1253
- return logits, boxes
1254
-
1255
- wrapper = _RTDetrExportWrapper(self.rtdetr_model)
1256
- if self.device == 'cuda':
1257
- wrapper = wrapper.to('cuda')
1258
-
1259
- # Try PyTorch 2.x dynamo_export first (more tolerant of newer aten ops)
1260
- try:
1261
- success = False
1262
- try:
1263
- from torch.onnx import dynamo_export
1264
- try:
1265
- exp = dynamo_export(wrapper, dummy_input)
1266
- except TypeError:
1267
- # Older PyTorch dynamo_export may not support this calling convention
1268
- exp = dynamo_export(wrapper, dummy_input)
1269
- # exp may have save(); otherwise, it may expose model_proto
1270
- try:
1271
- exp.save(output_path) # type: ignore
1272
- success = True
1273
- except Exception:
1274
- try:
1275
- import onnx as _onnx
1276
- _onnx.save(exp.model_proto, output_path) # type: ignore
1277
- success = True
1278
- except Exception as _se:
1279
- logger.warning(f"dynamo_export produced model but could not save: {_se}")
1280
- except Exception as de:
1281
- logger.warning(f"dynamo_export failed; falling back to legacy exporter: {de}")
1282
- if success:
1283
- logger.info(f"✅ RT-DETR ONNX saved to: {output_path} (dynamo_export)")
1284
- return True
1285
- except Exception as de2:
1286
- logger.warning(f"dynamo_export path error: {de2}")
1287
-
1288
- # Legacy exporter with opset fallback
1289
- last_err = None
1290
- for opset in [19, 18, 17, 16, 15, 14, 13]:
1291
- try:
1292
- torch.onnx.export(
1293
- wrapper,
1294
- dummy_input,
1295
- output_path,
1296
- export_params=True,
1297
- opset_version=opset,
1298
- do_constant_folding=True,
1299
- input_names=['pixel_values'],
1300
- output_names=['logits', 'boxes'],
1301
- dynamic_axes={
1302
- 'pixel_values': {0: 'batch', 2: 'height', 3: 'width'},
1303
- 'logits': {0: 'batch'},
1304
- 'boxes': {0: 'batch'}
1305
- }
1306
- )
1307
- logger.info(f"✅ RT-DETR ONNX saved to: {output_path} (opset {opset})")
1308
- return True
1309
- except Exception as _e:
1310
- last_err = _e
1311
- try:
1312
- msg = str(_e)
1313
- except Exception:
1314
- msg = ''
1315
- logger.warning(f"RT-DETR ONNX export failed at opset {opset}: {msg}")
1316
- continue
1317
-
1318
- logger.error(f"All RT-DETR ONNX export attempts failed. Last error: {last_err}")
1319
- return False
1320
-
1321
- # Handle YOLOv8 conversion - FIXED
1322
- elif YOLO_AVAILABLE and os.path.exists(model_path):
1323
- logger.info(f"Loading YOLOv8 model from: {model_path}")
1324
-
1325
- # Load model
1326
- model = YOLO(model_path)
1327
-
1328
- # Export to ONNX - this returns the path to the exported model
1329
- logger.info("Exporting to ONNX format...")
1330
- exported_path = model.export(format='onnx', imgsz=640, simplify=True)
1331
-
1332
- # exported_path could be a string or Path object
1333
- exported_path = str(exported_path) if exported_path else None
1334
-
1335
- if exported_path and os.path.exists(exported_path):
1336
- # Move to desired location if different
1337
- if exported_path != output_path:
1338
- import shutil
1339
- logger.info(f"Moving ONNX from {exported_path} to {output_path}")
1340
- shutil.move(exported_path, output_path)
1341
-
1342
- logger.info(f"✅ YOLOv8 ONNX saved to: {output_path}")
1343
- return True
1344
- else:
1345
- # Fallback: check if it was created with expected name
1346
- expected_onnx = model_path.replace('.pt', '.onnx')
1347
- if os.path.exists(expected_onnx):
1348
- if expected_onnx != output_path:
1349
- import shutil
1350
- shutil.move(expected_onnx, output_path)
1351
- logger.info(f"✅ YOLOv8 ONNX saved to: {output_path}")
1352
- return True
1353
- else:
1354
- logger.error(f"ONNX export failed - no output file found")
1355
- return False
1356
-
1357
- else:
1358
- logger.error(f"Cannot convert {model_path}: Model not found or dependencies missing")
1359
- return False
1360
-
1361
- except Exception as e:
1362
- logger.error(f"Conversion failed: {e}")
1363
- # Avoid noisy full stack trace in production logs; return False gracefully
1364
- return False
1365
-
1366
- def batch_detect(self, image_paths: List[str], **kwargs) -> Dict[str, List[Tuple[int, int, int, int]]]:
1367
- """
1368
- Detect bubbles in multiple images.
1369
-
1370
- Args:
1371
- image_paths: List of image paths
1372
- **kwargs: Detection parameters (confidence, iou_threshold, max_detections, use_rtdetr)
1373
-
1374
- Returns:
1375
- Dictionary mapping image paths to bubble lists
1376
- """
1377
- results = {}
1378
-
1379
- for i, image_path in enumerate(image_paths):
1380
- logger.info(f"Processing image {i+1}/{len(image_paths)}: {os.path.basename(image_path)}")
1381
- bubbles = self.detect_bubbles(image_path, **kwargs)
1382
- results[image_path] = bubbles
1383
-
1384
- return results
1385
-
1386
- def unload(self, release_shared: bool = False):
1387
- """Release model resources held by this detector instance.
1388
- Args:
1389
- release_shared: If True, also clear class-level shared RT-DETR caches.
1390
- """
1391
- try:
1392
- # Release instance-level models and sessions
1393
- try:
1394
- if getattr(self, 'onnx_session', None) is not None:
1395
- self.onnx_session = None
1396
- except Exception:
1397
- pass
1398
- try:
1399
- if getattr(self, 'rtdetr_onnx_session', None) is not None:
1400
- self.rtdetr_onnx_session = None
1401
- except Exception:
1402
- pass
1403
- for attr in ['model', 'rtdetr_model', 'rtdetr_processor']:
1404
- try:
1405
- if hasattr(self, attr):
1406
- setattr(self, attr, None)
1407
- except Exception:
1408
- pass
1409
- for flag in ['model_loaded', 'rtdetr_loaded', 'rtdetr_onnx_loaded']:
1410
- try:
1411
- if hasattr(self, flag):
1412
- setattr(self, flag, False)
1413
- except Exception:
1414
- pass
1415
-
1416
- # Optional: release shared caches
1417
- if release_shared:
1418
- try:
1419
- BubbleDetector._rtdetr_shared_model = None
1420
- BubbleDetector._rtdetr_shared_processor = None
1421
- BubbleDetector._rtdetr_loaded = False
1422
- except Exception:
1423
- pass
1424
-
1425
- # Free CUDA cache and trigger GC
1426
- try:
1427
- if TORCH_AVAILABLE and torch is not None and torch.cuda.is_available():
1428
- torch.cuda.empty_cache()
1429
- except Exception:
1430
- pass
1431
- try:
1432
- import gc
1433
- gc.collect()
1434
- except Exception:
1435
- pass
1436
- except Exception:
1437
- # Best-effort only
1438
- pass
1439
-
1440
- def get_bubble_masks(self, image_path: str, bubbles: List[Tuple[int, int, int, int]]) -> np.ndarray:
1441
- """
1442
- Create a mask image with bubble regions.
1443
-
1444
- Args:
1445
- image_path: Path to original image
1446
- bubbles: List of bubble bounding boxes
1447
-
1448
- Returns:
1449
- Binary mask with bubble regions as white (255)
1450
- """
1451
- image = cv2.imread(image_path)
1452
- if image is None:
1453
- return None
1454
-
1455
- h, w = image.shape[:2]
1456
- mask = np.zeros((h, w), dtype=np.uint8)
1457
-
1458
- # Fill bubble regions
1459
- for x, y, bw, bh in bubbles:
1460
- cv2.rectangle(mask, (x, y), (x + bw, y + bh), 255, -1)
1461
-
1462
- return mask
1463
-
1464
- def filter_bubbles_by_size(self, bubbles: List[Tuple[int, int, int, int]],
1465
- min_area: int = 100,
1466
- max_area: int = None) -> List[Tuple[int, int, int, int]]:
1467
- """
1468
- Filter bubbles by area.
1469
-
1470
- Args:
1471
- bubbles: List of bubble bounding boxes
1472
- min_area: Minimum area in pixels
1473
- max_area: Maximum area in pixels (None for no limit)
1474
-
1475
- Returns:
1476
- Filtered list of bubbles
1477
- """
1478
- filtered = []
1479
-
1480
- for x, y, w, h in bubbles:
1481
- area = w * h
1482
- if area >= min_area and (max_area is None or area <= max_area):
1483
- filtered.append((x, y, w, h))
1484
-
1485
- return filtered
1486
-
1487
- def merge_overlapping_bubbles(self, bubbles: List[Tuple[int, int, int, int]],
1488
- overlap_threshold: float = 0.1) -> List[Tuple[int, int, int, int]]:
1489
- """
1490
- Merge overlapping bubble detections.
1491
-
1492
- Args:
1493
- bubbles: List of bubble bounding boxes
1494
- overlap_threshold: Minimum overlap ratio to merge
1495
-
1496
- Returns:
1497
- Merged list of bubbles
1498
- """
1499
- if not bubbles:
1500
- return []
1501
-
1502
- # Convert to numpy array for easier manipulation
1503
- boxes = np.array([(x, y, x+w, y+h) for x, y, w, h in bubbles])
1504
-
1505
- merged = []
1506
- used = set()
1507
-
1508
- for i, box1 in enumerate(boxes):
1509
- if i in used:
1510
- continue
1511
-
1512
- # Start with current box
1513
- x1, y1, x2, y2 = box1
1514
-
1515
- # Check for overlaps with remaining boxes
1516
- for j in range(i + 1, len(boxes)):
1517
- if j in used:
1518
- continue
1519
-
1520
- box2 = boxes[j]
1521
-
1522
- # Calculate intersection
1523
- ix1 = max(x1, box2[0])
1524
- iy1 = max(y1, box2[1])
1525
- ix2 = min(x2, box2[2])
1526
- iy2 = min(y2, box2[3])
1527
-
1528
- if ix1 < ix2 and iy1 < iy2:
1529
- # Calculate overlap ratio
1530
- intersection = (ix2 - ix1) * (iy2 - iy1)
1531
- area1 = (x2 - x1) * (y2 - y1)
1532
- area2 = (box2[2] - box2[0]) * (box2[3] - box2[1])
1533
- overlap = intersection / min(area1, area2)
1534
-
1535
- if overlap >= overlap_threshold:
1536
- # Merge boxes
1537
- x1 = min(x1, box2[0])
1538
- y1 = min(y1, box2[1])
1539
- x2 = max(x2, box2[2])
1540
- y2 = max(y2, box2[3])
1541
- used.add(j)
1542
-
1543
- merged.append((int(x1), int(y1), int(x2 - x1), int(y2 - y1)))
1544
-
1545
- return merged
1546
-
1547
- # ============================
1548
- # RT-DETR (ONNX) BACKEND
1549
- # ============================
1550
- def load_rtdetr_onnx_model(self, model_id: str = None, force_reload: bool = False) -> bool:
1551
- """
1552
- Load RT-DETR ONNX model using onnxruntime. Downloads detector.onnx and config.json
1553
- from the provided Hugging Face repo if not already cached.
1554
- """
1555
- if not ONNX_AVAILABLE:
1556
- logger.error("ONNX Runtime not available for RT-DETR ONNX backend")
1557
- return False
1558
- try:
1559
- # If singleton mode and already loaded, just attach shared session
1560
- try:
1561
- adv = (self.config or {}).get('manga_settings', {}).get('advanced', {}) if isinstance(self.config, dict) else {}
1562
- singleton = bool(adv.get('use_singleton_models', True))
1563
- except Exception:
1564
- singleton = True
1565
- if singleton and BubbleDetector._rtdetr_onnx_loaded and not force_reload and BubbleDetector._rtdetr_onnx_shared_session is not None:
1566
- self.rtdetr_onnx_session = BubbleDetector._rtdetr_onnx_shared_session
1567
- self.rtdetr_onnx_loaded = True
1568
- return True
1569
-
1570
- repo = model_id or self.rtdetr_onnx_repo
1571
- try:
1572
- from huggingface_hub import hf_hub_download
1573
- except Exception as e:
1574
- logger.error(f"huggingface-hub required to fetch RT-DETR ONNX: {e}")
1575
- return False
1576
-
1577
- # Ensure local models dir (use configured cache_dir directly: e.g., 'models')
1578
- cache_dir = self.cache_dir
1579
- os.makedirs(cache_dir, exist_ok=True)
1580
-
1581
- # Download files into models/ and avoid symlinks so the file is visible there
1582
- try:
1583
- _ = hf_hub_download(repo_id=repo, filename='config.json', cache_dir=cache_dir, local_dir=cache_dir, local_dir_use_symlinks=False)
1584
- except Exception:
1585
- pass
1586
- onnx_fp = hf_hub_download(repo_id=repo, filename='detector.onnx', cache_dir=cache_dir, local_dir=cache_dir, local_dir_use_symlinks=False)
1587
- BubbleDetector._rtdetr_onnx_model_path = onnx_fp
1588
-
1589
- # Pick providers: prefer CUDA if available; otherwise CPU. Do NOT use DML.
1590
- providers = ['CPUExecutionProvider']
1591
- try:
1592
- avail = ort.get_available_providers() if ONNX_AVAILABLE else []
1593
- if 'CUDAExecutionProvider' in avail:
1594
- providers = ['CUDAExecutionProvider', 'CPUExecutionProvider']
1595
- except Exception:
1596
- pass
1597
-
1598
- # Session options with reduced memory arena and optional thread limiting in singleton mode
1599
- so = ort.SessionOptions()
1600
- try:
1601
- so.enable_mem_pattern = False
1602
- so.enable_cpu_mem_arena = False
1603
- except Exception:
1604
- pass
1605
- # If singleton models mode is enabled in config, limit ORT threading to reduce CPU spikes
1606
- try:
1607
- adv = (self.config or {}).get('manga_settings', {}).get('advanced', {}) if isinstance(self.config, dict) else {}
1608
- if bool(adv.get('use_singleton_models', True)):
1609
- so.intra_op_num_threads = 1
1610
- so.inter_op_num_threads = 1
1611
- try:
1612
- so.execution_mode = ort.ExecutionMode.ORT_SEQUENTIAL
1613
- except Exception:
1614
- pass
1615
- try:
1616
- so.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_BASIC
1617
- except Exception:
1618
- pass
1619
- except Exception:
1620
- pass
1621
-
1622
- # Create session (serialize creation in singleton mode to avoid device storms)
1623
- if singleton:
1624
- with BubbleDetector._rtdetr_onnx_init_lock:
1625
- # Re-check after acquiring lock
1626
- if BubbleDetector._rtdetr_onnx_loaded and BubbleDetector._rtdetr_onnx_shared_session is not None and not force_reload:
1627
- self.rtdetr_onnx_session = BubbleDetector._rtdetr_onnx_shared_session
1628
- self.rtdetr_onnx_loaded = True
1629
- return True
1630
- sess = ort.InferenceSession(onnx_fp, providers=providers, sess_options=so)
1631
- BubbleDetector._rtdetr_onnx_shared_session = sess
1632
- BubbleDetector._rtdetr_onnx_loaded = True
1633
- BubbleDetector._rtdetr_onnx_providers = providers
1634
- self.rtdetr_onnx_session = sess
1635
- self.rtdetr_onnx_loaded = True
1636
- else:
1637
- self.rtdetr_onnx_session = ort.InferenceSession(onnx_fp, providers=providers, sess_options=so)
1638
- self.rtdetr_onnx_loaded = True
1639
- logger.info("✅ RT-DETR (ONNX) model ready")
1640
- return True
1641
- except Exception as e:
1642
- logger.error(f"Failed to load RT-DETR ONNX: {e}")
1643
- self.rtdetr_onnx_session = None
1644
- self.rtdetr_onnx_loaded = False
1645
- return False
1646
-
1647
- def detect_with_rtdetr_onnx(self,
1648
- image_path: str = None,
1649
- image: np.ndarray = None,
1650
- confidence: float = 0.3,
1651
- return_all_bubbles: bool = False) -> Any:
1652
- """Detect using RT-DETR ONNX backend.
1653
- Returns bubbles list if return_all_bubbles else dict by classes similar to PyTorch path.
1654
- """
1655
- if not self.rtdetr_onnx_loaded or self.rtdetr_onnx_session is None:
1656
- logger.warning("RT-DETR ONNX not loaded")
1657
- return [] if return_all_bubbles else {'bubbles': [], 'text_bubbles': [], 'text_free': []}
1658
- try:
1659
- # Acquire image
1660
- if image_path is not None:
1661
- import cv2
1662
- image = cv2.imread(image_path)
1663
- if image is None:
1664
- raise RuntimeError(f"Failed to read image: {image_path}")
1665
- image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
1666
- else:
1667
- if image is None:
1668
- raise RuntimeError("No image provided")
1669
- # Assume image is BGR np.ndarray if from OpenCV
1670
- try:
1671
- import cv2
1672
- image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
1673
- except Exception:
1674
- image_rgb = image
1675
-
1676
- # To PIL then resize 640x640 as in reference
1677
- from PIL import Image as _PILImage
1678
- pil_image = _PILImage.fromarray(image_rgb)
1679
- im_resized = pil_image.resize((640, 640))
1680
- arr = np.asarray(im_resized, dtype=np.float32) / 255.0
1681
- arr = np.transpose(arr, (2, 0, 1)) # (3,H,W)
1682
- im_data = arr[np.newaxis, ...]
1683
-
1684
- w, h = pil_image.size
1685
- orig_size = np.array([[w, h]], dtype=np.int64)
1686
-
1687
- # Run with a concurrency guard to prevent device hangs and limit memory usage
1688
- # Apply semaphore for ALL providers (not just DML) to control concurrency
1689
- providers = BubbleDetector._rtdetr_onnx_providers or []
1690
- def _do_run(session):
1691
- return session.run(None, {
1692
- 'images': im_data,
1693
- 'orig_target_sizes': orig_size
1694
- })
1695
-
1696
- # Always use semaphore to limit concurrent RT-DETR calls
1697
- acquired = False
1698
- try:
1699
- BubbleDetector._rtdetr_onnx_sema.acquire()
1700
- acquired = True
1701
-
1702
- # Special DML error handling
1703
- if 'DmlExecutionProvider' in providers:
1704
- try:
1705
- outputs = _do_run(self.rtdetr_onnx_session)
1706
- except Exception as dml_err:
1707
- msg = str(dml_err)
1708
- if '887A0005' in msg or '887A0006' in msg or 'Dml' in msg:
1709
- # Rebuild CPU session and retry once
1710
- try:
1711
- base_path = BubbleDetector._rtdetr_onnx_model_path
1712
- if base_path:
1713
- so = ort.SessionOptions()
1714
- so.enable_mem_pattern = False
1715
- so.enable_cpu_mem_arena = False
1716
- cpu_providers = ['CPUExecutionProvider']
1717
- # Serialize rebuild
1718
- with BubbleDetector._rtdetr_onnx_init_lock:
1719
- sess = ort.InferenceSession(base_path, providers=cpu_providers, sess_options=so)
1720
- BubbleDetector._rtdetr_onnx_shared_session = sess
1721
- BubbleDetector._rtdetr_onnx_providers = cpu_providers
1722
- self.rtdetr_onnx_session = sess
1723
- outputs = _do_run(self.rtdetr_onnx_session)
1724
- else:
1725
- raise
1726
- except Exception:
1727
- raise
1728
- else:
1729
- raise
1730
- else:
1731
- # Non-DML providers - just run directly
1732
- outputs = _do_run(self.rtdetr_onnx_session)
1733
- finally:
1734
- if acquired:
1735
- try:
1736
- BubbleDetector._rtdetr_onnx_sema.release()
1737
- except Exception:
1738
- pass
1739
-
1740
- # outputs expected: labels, boxes, scores
1741
- labels, boxes, scores = outputs[:3]
1742
- if labels.ndim == 2 and labels.shape[0] == 1:
1743
- labels = labels[0]
1744
- if scores.ndim == 2 and scores.shape[0] == 1:
1745
- scores = scores[0]
1746
- if boxes.ndim == 3 and boxes.shape[0] == 1:
1747
- boxes = boxes[0]
1748
-
1749
- detections = {'bubbles': [], 'text_bubbles': [], 'text_free': []}
1750
- bubbles_all = []
1751
- for lab, box, scr in zip(labels, boxes, scores):
1752
- if float(scr) < float(confidence):
1753
- continue
1754
- x1, y1, x2, y2 = map(int, box)
1755
- bbox = (x1, y1, x2 - x1, y2 - y1)
1756
- label_id = int(lab)
1757
- if label_id == self.CLASS_BUBBLE:
1758
- detections['bubbles'].append(bbox)
1759
- bubbles_all.append(bbox)
1760
- elif label_id == self.CLASS_TEXT_BUBBLE:
1761
- detections['text_bubbles'].append(bbox)
1762
- bubbles_all.append(bbox)
1763
- elif label_id == self.CLASS_TEXT_FREE:
1764
- detections['text_free'].append(bbox)
1765
-
1766
- return bubbles_all if return_all_bubbles else detections
1767
- except Exception as e:
1768
- logger.error(f"RT-DETR ONNX detection failed: {e}")
1769
- return [] if return_all_bubbles else {'bubbles': [], 'text_bubbles': [], 'text_free': []}
1770
-
1771
-
1772
- # Standalone utility functions
1773
- def download_model_from_huggingface(repo_id: str = "ogkalu/comic-speech-bubble-detector-yolov8m",
1774
- filename: str = "comic-speech-bubble-detector-yolov8m.pt",
1775
- cache_dir: str = "models") -> str:
1776
- """
1777
- Download model from Hugging Face Hub.
1778
-
1779
- Args:
1780
- repo_id: Hugging Face repository ID
1781
- filename: Model filename in the repository
1782
- cache_dir: Local directory to cache the model
1783
-
1784
- Returns:
1785
- Path to downloaded model file
1786
- """
1787
- try:
1788
- from huggingface_hub import hf_hub_download
1789
-
1790
- os.makedirs(cache_dir, exist_ok=True)
1791
-
1792
- logger.info(f"📥 Downloading {filename} from {repo_id}...")
1793
-
1794
- model_path = hf_hub_download(
1795
- repo_id=repo_id,
1796
- filename=filename,
1797
- cache_dir=cache_dir,
1798
- local_dir=cache_dir
1799
- )
1800
-
1801
- logger.info(f"✅ Model downloaded to: {model_path}")
1802
- return model_path
1803
-
1804
- except ImportError:
1805
- logger.error("huggingface-hub package required. Install with: pip install huggingface-hub")
1806
- return None
1807
- except Exception as e:
1808
- logger.error(f"Download failed: {e}")
1809
- return None
1810
-
1811
-
1812
- def download_rtdetr_model(cache_dir: str = "models") -> bool:
1813
- """
1814
- Download RT-DETR model for advanced detection.
1815
-
1816
- Args:
1817
- cache_dir: Directory to cache the model
1818
-
1819
- Returns:
1820
- True if successful
1821
- """
1822
- if not TRANSFORMERS_AVAILABLE:
1823
- logger.error("Transformers required. Install with: pip install transformers")
1824
- return False
1825
-
1826
- try:
1827
- logger.info("📥 Downloading RT-DETR model...")
1828
- from transformers import RTDetrForObjectDetection, RTDetrImageProcessor
1829
-
1830
- # This will download and cache the model
1831
- processor = RTDetrImageProcessor.from_pretrained(
1832
- "ogkalu/comic-text-and-bubble-detector",
1833
- cache_dir=cache_dir
1834
- )
1835
- model = RTDetrForObjectDetection.from_pretrained(
1836
- "ogkalu/comic-text-and-bubble-detector",
1837
- cache_dir=cache_dir
1838
- )
1839
-
1840
- logger.info("✅ RT-DETR model downloaded successfully")
1841
- return True
1842
-
1843
- except Exception as e:
1844
- logger.error(f"Download failed: {e}")
1845
- return False
1846
-
1847
-
1848
- # Example usage and testing
1849
- if __name__ == "__main__":
1850
- import sys
1851
-
1852
- # Create detector
1853
- detector = BubbleDetector()
1854
-
1855
- if len(sys.argv) > 1:
1856
- if sys.argv[1] == "download":
1857
- # Download model from Hugging Face
1858
- model_path = download_model_from_huggingface()
1859
- if model_path:
1860
- print(f"YOLOv8 model downloaded to: {model_path}")
1861
-
1862
- # Also download RT-DETR
1863
- if download_rtdetr_model():
1864
- print("RT-DETR model downloaded")
1865
-
1866
- elif sys.argv[1] == "detect" and len(sys.argv) > 3:
1867
- # Detect bubbles in an image
1868
- model_path = sys.argv[2]
1869
- image_path = sys.argv[3]
1870
-
1871
- # Load appropriate model
1872
- if 'rtdetr' in model_path.lower():
1873
- if detector.load_rtdetr_model():
1874
- # Use RT-DETR
1875
- results = detector.detect_with_rtdetr(image_path)
1876
- print(f"RT-DETR Detection:")
1877
- print(f" Empty bubbles: {len(results['bubbles'])}")
1878
- print(f" Text bubbles: {len(results['text_bubbles'])}")
1879
- print(f" Free text: {len(results['text_free'])}")
1880
- else:
1881
- if detector.load_model(model_path):
1882
- bubbles = detector.detect_bubbles(image_path, confidence=0.5)
1883
- print(f"YOLOv8 detected {len(bubbles)} bubbles:")
1884
- for i, (x, y, w, h) in enumerate(bubbles):
1885
- print(f" Bubble {i+1}: position=({x},{y}) size=({w}x{h})")
1886
-
1887
- # Optionally visualize
1888
- if len(sys.argv) > 4:
1889
- output_path = sys.argv[4]
1890
- detector.visualize_detections(image_path, output_path=output_path,
1891
- use_rtdetr='rtdetr' in model_path.lower())
1892
-
1893
- elif sys.argv[1] == "test-both" and len(sys.argv) > 2:
1894
- # Test both models
1895
- image_path = sys.argv[2]
1896
-
1897
- # Load YOLOv8
1898
- yolo_path = "models/comic-speech-bubble-detector-yolov8m.pt"
1899
- if os.path.exists(yolo_path):
1900
- detector.load_model(yolo_path)
1901
- yolo_bubbles = detector.detect_bubbles(image_path, use_rtdetr=False)
1902
- print(f"YOLOv8: {len(yolo_bubbles)} bubbles")
1903
-
1904
- # Load RT-DETR
1905
- if detector.load_rtdetr_model():
1906
- rtdetr_bubbles = detector.detect_bubbles(image_path, use_rtdetr=True)
1907
- print(f"RT-DETR: {len(rtdetr_bubbles)} bubbles")
1908
-
1909
- else:
1910
- print("Usage:")
1911
- print(" python bubble_detector.py download")
1912
- print(" python bubble_detector.py detect <model_path> <image_path> [output_path]")
1913
- print(" python bubble_detector.py test-both <image_path>")
1914
-
1915
- else:
1916
- print("Bubble Detector Module (YOLOv8 + RT-DETR)")
1917
- print("Usage:")
1918
- print(" python bubble_detector.py download")
1919
- print(" python bubble_detector.py detect <model_path> <image_path> [output_path]")
1920
- print(" python bubble_detector.py test-both <image_path>")