MogensR committed on
Commit
260d38d
·
1 Parent(s): 5916c53

Update models/loaders/model_loader.py

Browse files
Files changed (1) hide show
  1. models/loaders/model_loader.py +113 -379
models/loaders/model_loader.py CHANGED
@@ -19,7 +19,6 @@
19
  from pathlib import Path
20
 
21
  import torch
22
- import gradio as gr
23
  from omegaconf import DictConfig, OmegaConf
24
 
25
  # Import modular components - Updated paths for BackgroundFX Pro structure
@@ -45,179 +44,14 @@ def __init__(self, model=None, model_id: str = "", load_time: float = 0.0):
45
  def __repr__(self):
46
  return f"LoadedModel(id={self.model_id}, loaded={self.model is not None})"
47
 
48
- # ============================================================================ #
49
- # HARD CACHE CLEANER
50
- # ============================================================================ #
51
-
52
- class HardCacheCleaner:
53
- """
54
- Comprehensive cache cleaning system to resolve SAM2 loading issues
55
- Clears Python module cache, HuggingFace cache, and temp files
56
- """
57
-
58
- @staticmethod
59
- def clean_all_caches(verbose: bool = True):
60
- """Clean all caches that might interfere with SAM2 loading"""
61
-
62
- if verbose:
63
- logger.info("Starting comprehensive cache cleanup...")
64
-
65
- # 1. Clean Python module cache
66
- HardCacheCleaner._clean_python_cache(verbose)
67
-
68
- # 2. Clean HuggingFace cache
69
- HardCacheCleaner._clean_huggingface_cache(verbose)
70
-
71
- # 3. Clean PyTorch cache
72
- HardCacheCleaner._clean_pytorch_cache(verbose)
73
-
74
- # 4. Clean temp directories
75
- HardCacheCleaner._clean_temp_directories(verbose)
76
-
77
- # 5. Clear import cache
78
- HardCacheCleaner._clear_import_cache(verbose)
79
-
80
- # 6. Force garbage collection
81
- HardCacheCleaner._force_gc_cleanup(verbose)
82
-
83
- if verbose:
84
- logger.info("Cache cleanup completed")
85
-
86
- @staticmethod
87
- def _clean_python_cache(verbose: bool = True):
88
- """Clean Python bytecode cache"""
89
- try:
90
- # Clear sys.modules cache for SAM2 related modules
91
- sam2_modules = [key for key in sys.modules.keys() if 'sam2' in key.lower()]
92
- for module in sam2_modules:
93
- if verbose:
94
- logger.info(f"Removing cached module: {module}")
95
- del sys.modules[module]
96
-
97
- # Clear __pycache__ directories
98
- for root, dirs, files in os.walk("."):
99
- for dir_name in dirs[:]: # Use slice to modify list during iteration
100
- if dir_name == "__pycache__":
101
- cache_path = os.path.join(root, dir_name)
102
- if verbose:
103
- logger.info(f"Removing __pycache__: {cache_path}")
104
- shutil.rmtree(cache_path, ignore_errors=True)
105
- dirs.remove(dir_name)
106
-
107
- except Exception as e:
108
- logger.warning(f"Python cache cleanup failed: {e}")
109
-
110
- @staticmethod
111
- def _clean_huggingface_cache(verbose: bool = True):
112
- """Clean HuggingFace model cache"""
113
- try:
114
- cache_paths = [
115
- os.path.expanduser("~/.cache/huggingface/"),
116
- os.path.expanduser("~/.cache/torch/"),
117
- "./checkpoints/",
118
- "./.cache/",
119
- ]
120
-
121
- for cache_path in cache_paths:
122
- if os.path.exists(cache_path):
123
- if verbose:
124
- logger.info(f"Cleaning cache directory: {cache_path}")
125
-
126
- # Remove SAM2 specific files
127
- for root, dirs, files in os.walk(cache_path):
128
- for file in files:
129
- if any(pattern in file.lower() for pattern in ['sam2', 'segment-anything-2']):
130
- file_path = os.path.join(root, file)
131
- try:
132
- os.remove(file_path)
133
- if verbose:
134
- logger.info(f"Removed cached file: {file_path}")
135
- except:
136
- pass
137
-
138
- for dir_name in dirs[:]:
139
- if any(pattern in dir_name.lower() for pattern in ['sam2', 'segment-anything-2']):
140
- dir_path = os.path.join(root, dir_name)
141
- try:
142
- shutil.rmtree(dir_path, ignore_errors=True)
143
- if verbose:
144
- logger.info(f"Removed cached directory: {dir_path}")
145
- dirs.remove(dir_name)
146
- except:
147
- pass
148
-
149
- except Exception as e:
150
- logger.warning(f"HuggingFace cache cleanup failed: {e}")
151
-
152
- @staticmethod
153
- def _clean_pytorch_cache(verbose: bool = True):
154
- """Clean PyTorch cache"""
155
- try:
156
- import torch
157
- if torch.cuda.is_available():
158
- torch.cuda.empty_cache()
159
- if verbose:
160
- logger.info("Cleared PyTorch CUDA cache")
161
- except Exception as e:
162
- logger.warning(f"PyTorch cache cleanup failed: {e}")
163
-
164
- @staticmethod
165
- def _clean_temp_directories(verbose: bool = True):
166
- """Clean temporary directories"""
167
- try:
168
- temp_dirs = [tempfile.gettempdir(), "/tmp", "./tmp", "./temp"]
169
-
170
- for temp_dir in temp_dirs:
171
- if os.path.exists(temp_dir):
172
- for item in os.listdir(temp_dir):
173
- if 'sam2' in item.lower() or 'segment' in item.lower():
174
- item_path = os.path.join(temp_dir, item)
175
- try:
176
- if os.path.isfile(item_path):
177
- os.remove(item_path)
178
- elif os.path.isdir(item_path):
179
- shutil.rmtree(item_path, ignore_errors=True)
180
- if verbose:
181
- logger.info(f"Removed temp item: {item_path}")
182
- except:
183
- pass
184
-
185
- except Exception as e:
186
- logger.warning(f"Temp directory cleanup failed: {e}")
187
-
188
- @staticmethod
189
- def _clear_import_cache(verbose: bool = True):
190
- """Clear Python import cache"""
191
- try:
192
- import importlib
193
-
194
- # Invalidate import caches
195
- importlib.invalidate_caches()
196
-
197
- if verbose:
198
- logger.info("Cleared Python import cache")
199
-
200
- except Exception as e:
201
- logger.warning(f"Import cache cleanup failed: {e}")
202
-
203
- @staticmethod
204
- def _force_gc_cleanup(verbose: bool = True):
205
- """Force garbage collection"""
206
- try:
207
- collected = gc.collect()
208
- if verbose:
209
- logger.info(f"Garbage collection freed {collected} objects")
210
- except Exception as e:
211
- logger.warning(f"Garbage collection failed: {e}")
212
-
213
  # ============================================================================ #
214
  # MODEL LOADER CLASS - MAIN INTERFACE
215
  # ============================================================================ #
216
 
217
  class ModelLoader:
218
  """
219
- Comprehensive model loading and management for SAM2 and MatAnyone
220
- Handles automatic config detection, multiple fallback strategies, and memory management
221
  """
222
 
223
  def __init__(self, device_mgr: DeviceManager, memory_mgr: MemoryManager):
@@ -244,31 +78,6 @@ def __init__(self, device_mgr: DeviceManager, memory_mgr: MemoryManager):
244
  }
245
 
246
  logger.info(f"ModelLoader initialized for device: {self.device}")
247
- self._apply_gradio_patch()
248
-
249
- # ============================================================================ #
250
- # INITIALIZATION AND SETUP
251
- # ============================================================================ #
252
-
253
- def _apply_gradio_patch(self):
254
- """Apply Gradio schema monkey patch to prevent validation errors"""
255
- try:
256
- import gradio.components.base
257
- original_get_config = gradio.components.base.Component.get_config
258
-
259
- def patched_get_config(self):
260
- config = original_get_config(self)
261
- # Remove problematic keys that cause validation errors
262
- config.pop("show_progress_bar", None)
263
- config.pop("min_width", None)
264
- config.pop("scale", None)
265
- return config
266
-
267
- gradio.components.base.Component.get_config = patched_get_config
268
- logger.debug("Applied Gradio schema monkey patch")
269
-
270
- except (ImportError, AttributeError) as e:
271
- logger.warning(f"Could not apply Gradio monkey patch: {e}")
272
 
273
  # ============================================================================ #
274
  # MAIN MODEL LOADING ORCHESTRATION
@@ -296,7 +105,7 @@ def load_all_models(self, progress_callback: Optional[callable] = None, cancel_e
296
  # Clear any existing models
297
  self._cleanup_models()
298
 
299
- # Load SAM2 first (typically faster)
300
  logger.info("Loading SAM2 predictor...")
301
  if progress_callback:
302
  progress_callback(0.1, "Loading SAM2 predictor...")
@@ -304,11 +113,11 @@ def load_all_models(self, progress_callback: Optional[callable] = None, cancel_e
304
  self.sam2_predictor = self._load_sam2_predictor(progress_callback)
305
 
306
  if self.sam2_predictor is None:
307
- raise ModelLoadingError("SAM2", "Failed to load SAM2 predictor")
308
-
309
- sam2_time = time.time() - start_time
310
- self.loading_stats['sam2_load_time'] = sam2_time
311
- logger.info(f"SAM2 loaded in {sam2_time:.2f}s")
312
 
313
  # Load MatAnyone
314
  logger.info("Loading MatAnyone model...")
@@ -320,11 +129,11 @@ def load_all_models(self, progress_callback: Optional[callable] = None, cancel_e
320
  self.matanyone_model, self.matanyone_core = self._load_matanyone_model(progress_callback)
321
 
322
  if self.matanyone_model is None:
323
- raise ModelLoadingError("MatAnyone", "Failed to load MatAnyone model")
324
-
325
- matanyone_time = time.time() - matanyone_start
326
- self.loading_stats['matanyone_load_time'] = matanyone_time
327
- logger.info(f"MatAnyone loaded in {matanyone_time:.1f}s")
328
 
329
  # Final setup
330
  total_time = time.time() - start_time
@@ -332,9 +141,12 @@ def load_all_models(self, progress_callback: Optional[callable] = None, cancel_e
332
  self.loading_stats['models_loaded'] = True
333
 
334
  if progress_callback:
335
- progress_callback(1.0, "Models loaded successfully!")
 
 
 
336
 
337
- logger.info(f"All models loaded successfully in {total_time:.2f}s")
338
 
339
  return self.sam2_predictor, self.matanyone_model
340
 
@@ -352,29 +164,20 @@ def load_all_models(self, progress_callback: Optional[callable] = None, cancel_e
352
  return None, None
353
 
354
  # ============================================================================ #
355
- # SAM2 MODEL LOADING - HUGGINGFACE TRANSFORMERS APPROACH
356
  # ============================================================================ #
357
 
358
  def _load_sam2_predictor(self, progress_callback: Optional[callable] = None):
359
  """
360
- Load SAM2 using HuggingFace Transformers integration with cache cleanup
361
- This method works reliably on HuggingFace Spaces without config file issues
362
 
363
  Args:
364
  progress_callback: Progress update callback
365
 
366
  Returns:
367
- SAM2 model or None
368
  """
369
- logger.info("=== USING NEW HF TRANSFORMERS SAM2 LOADER ===")
370
-
371
- # Step 1: Clean caches before loading
372
- if progress_callback:
373
- progress_callback(0.15, "Cleaning caches...")
374
-
375
- HardCacheCleaner.clean_all_caches(verbose=True)
376
-
377
- # Step 2: Determine model size based on device memory
378
  model_size = "large" # default
379
  if hasattr(self.device_manager, 'get_device_memory_gb'):
380
  try:
@@ -382,12 +185,13 @@ def _load_sam2_predictor(self, progress_callback: Optional[callable] = None):
382
  if memory_gb < 4:
383
  model_size = "tiny"
384
  elif memory_gb < 8:
 
 
385
  model_size = "base"
386
  logger.info(f"Selected SAM2 {model_size} based on {memory_gb}GB memory")
387
  except Exception as e:
388
  logger.warning(f"Could not determine device memory: {e}")
389
 
390
- # Step 3: Try multiple HuggingFace approaches
391
  model_map = {
392
  "tiny": "facebook/sam2.1-hiera-tiny",
393
  "small": "facebook/sam2.1-hiera-small",
@@ -398,86 +202,57 @@ def _load_sam2_predictor(self, progress_callback: Optional[callable] = None):
398
  model_id = model_map.get(model_size, model_map["large"])
399
 
400
  if progress_callback:
401
- progress_callback(0.3, f"Loading SAM2 {model_size}...")
402
-
403
- # Method 1: HuggingFace Transformers Pipeline (most reliable)
404
- try:
405
- logger.info("Trying Transformers pipeline approach...")
406
- from transformers import pipeline
407
-
408
- sam2_pipeline = pipeline(
409
- "mask-generation",
410
- model=model_id,
411
- device=0 if str(self.device) == "cuda" else -1
412
- )
413
-
414
- logger.info("SAM2 loaded successfully via Transformers pipeline")
415
- return sam2_pipeline
416
-
417
- except Exception as e:
418
- logger.warning(f"Pipeline approach failed: {e}")
419
 
420
- # Method 2: Direct Transformers classes
421
  try:
422
- logger.info("Trying direct Transformers classes...")
423
- from transformers import Sam2Processor, Sam2Model
424
-
425
- processor = Sam2Processor.from_pretrained(model_id)
426
- model = Sam2Model.from_pretrained(model_id).to(self.device)
427
-
428
- logger.info("SAM2 loaded successfully via Transformers classes")
429
- return {"model": model, "processor": processor}
430
-
431
- except Exception as e:
432
- logger.warning(f"Direct class approach failed: {e}")
433
-
434
- # Method 3: Official SAM2 with from_pretrained
435
- try:
436
- logger.info("Trying official SAM2 from_pretrained...")
437
  from sam2.sam2_image_predictor import SAM2ImagePredictor
438
 
 
439
  predictor = SAM2ImagePredictor.from_pretrained(model_id)
440
 
 
 
 
 
441
  logger.info("SAM2 loaded successfully via official from_pretrained")
442
  return predictor
443
 
444
- except Exception as e:
445
- logger.warning(f"Official from_pretrained approach failed: {e}")
446
-
447
- # Method 4: Fallback to direct checkpoint download
448
- try:
449
- logger.info("Trying fallback checkpoint approach...")
450
- from huggingface_hub import hf_hub_download
451
- from transformers import Sam2Model
452
-
453
- # Download checkpoint directly
454
- checkpoint_path = hf_hub_download(
455
- repo_id=model_id,
456
- filename="model.safetensors" if "sam2.1" in model_id else "pytorch_model.bin"
457
- )
458
-
459
- logger.info(f"Downloaded checkpoint to: {checkpoint_path}")
460
-
461
- # Load with minimal approach
462
- model = Sam2Model.from_pretrained(model_id)
463
- model = model.to(self.device)
464
-
465
- logger.info("SAM2 loaded successfully via fallback approach")
466
- return model
467
 
468
  except Exception as e:
469
- logger.warning(f"Fallback approach failed: {e}")
470
-
471
- logger.error("All SAM2 loading methods failed")
472
- return None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
473
 
474
  # ============================================================================ #
475
- # MATANYONE MODEL LOADING - MULTIPLE STRATEGIES
476
  # ============================================================================ #
477
 
478
  def _load_matanyone_model(self, progress_callback: Optional[callable] = None):
479
  """
480
- Load MatAnyone model with multiple import strategies
481
 
482
  Args:
483
  progress_callback: Progress update callback
@@ -485,84 +260,26 @@ def _load_matanyone_model(self, progress_callback: Optional[callable] = None):
485
  Returns:
486
  Tuple[model, core] or (None, None)
487
  """
488
- import_strategies = [
489
- self._load_matanyone_strategy_1,
490
- self._load_matanyone_strategy_2,
491
- self._load_matanyone_strategy_3,
492
- self._load_matanyone_strategy_4
493
- ]
494
-
495
- for i, strategy in enumerate(import_strategies, 1):
496
- try:
497
- logger.info(f"Trying MatAnyone loading strategy {i}...")
498
- if progress_callback:
499
- progress_callback(0.7 + (i * 0.05), f"MatAnyone strategy {i}...")
500
-
501
- model, core = strategy()
502
- if model is not None and core is not None:
503
- logger.info(f"MatAnyone loaded successfully with strategy {i}")
504
- return model, core
505
-
506
- except Exception as e:
507
- logger.warning(f"MatAnyone strategy {i} failed: {e}")
508
- continue
509
-
510
- logger.error("All MatAnyone loading strategies failed")
511
- return None, None
512
-
513
- # ============================================================================ #
514
- # MATANYONE LOADING STRATEGIES
515
- # ============================================================================ #
516
-
517
- def _load_matanyone_strategy_1(self):
518
- """MatAnyone loading strategy 1: Official HuggingFace InferenceCore"""
519
- from matanyone import InferenceCore
520
-
521
- # Initialize with the official model repo
522
- processor = InferenceCore("PeiqingYang/MatAnyone")
523
- return processor, processor
524
-
525
- def _load_matanyone_strategy_2(self):
526
- """MatAnyone loading strategy 2: Alternative import paths"""
527
- from matanyone import MatAnyOne
528
- from matanyone import InferenceCore
529
-
530
- cfg = OmegaConf.create({
531
- 'model_name': 'matanyone',
532
- 'device': str(self.device)
533
- })
534
-
535
- model = MatAnyOne(cfg)
536
- core = InferenceCore(model, cfg)
537
-
538
- return model, core
539
-
540
- def _load_matanyone_strategy_3(self):
541
- """MatAnyone loading strategy 3: Repository-specific imports"""
542
  try:
543
- from matanyone.models.matanyone import MatAnyOneModel
544
- from matanyone.core import InferenceEngine
 
 
 
 
 
 
 
 
 
 
545
  except ImportError:
546
- from matanyone.src.models import MatAnyOneModel
547
- from matanyone.src.core import InferenceEngine
548
-
549
- config = {
550
- 'model_path': None, # Will use default
551
- 'device': self.device,
552
- 'precision': 'fp16' if self.device.type == 'cuda' else 'fp32'
553
- }
554
-
555
- model = MatAnyOneModel.from_pretrained(config)
556
- engine = InferenceEngine(model)
557
-
558
- return model, engine
559
-
560
- def _load_matanyone_strategy_4(self):
561
- """MatAnyone loading strategy 4: Direct model class"""
562
- from matanyone.model.matanyone import MatAnyone
563
-
564
- model = MatAnyone.from_pretrained("not-lain/matanyone")
565
- return model, model
566
 
567
  # ============================================================================ #
568
  # MODEL MANAGEMENT AND CLEANUP
@@ -583,7 +300,8 @@ def _cleanup_models(self):
583
  self.matanyone_core = None
584
 
585
  # Clear GPU cache
586
- self.memory_manager.cleanup_aggressive()
 
587
  gc.collect()
588
 
589
  logger.debug("Model cleanup completed")
@@ -615,6 +333,10 @@ def get_model_info(self) -> Dict[str, Any]:
615
  if self.sam2_predictor is not None:
616
  try:
617
  info['sam2_model_type'] = type(self.sam2_predictor).__name__
 
 
 
 
618
  except:
619
  info['sam2_model_type'] = "Unknown"
620
 
@@ -639,9 +361,18 @@ def get_load_summary(self) -> str:
639
  matanyone_time = self.loading_stats['matanyone_load_time']
640
  total_time = self.loading_stats['total_load_time']
641
 
642
- summary = f"Models loaded successfully in {total_time:.1f}s\n"
643
- summary += f"SAM2: {sam2_time:.1f}s\n"
644
- summary += f"MatAnyone: {matanyone_time:.1f}s\n"
 
 
 
 
 
 
 
 
 
645
  summary += f"Device: {self.device}"
646
 
647
  return summary
@@ -663,20 +394,27 @@ def validate_models(self) -> bool:
663
  Validate that models are properly loaded and functional
664
 
665
  Returns:
666
- bool: True if models are valid
667
  """
668
  try:
669
- # Basic validation
670
- if not self.loading_stats['models_loaded']:
671
- return False
 
 
 
 
 
 
 
 
672
 
673
- if self.sam2_predictor is None or self.matanyone_model is None:
674
- return False
 
 
675
 
676
- # Try basic model operations
677
- # This could include running a small test inference
678
- logger.info("Model validation passed")
679
- return True
680
 
681
  except Exception as e:
682
  logger.error(f"Model validation failed: {e}")
@@ -704,9 +442,5 @@ def reload_models(self, progress_callback: Optional[callable] = None) -> Tuple[A
704
 
705
  @property
706
  def models_ready(self) -> bool:
707
- """Check if all models are loaded and ready"""
708
- return (
709
- self.loading_stats['models_loaded'] and
710
- self.sam2_predictor is not None and
711
- self.matanyone_model is not None
712
- )
 
19
  from pathlib import Path
20
 
21
  import torch
 
22
  from omegaconf import DictConfig, OmegaConf
23
 
24
  # Import modular components - Updated paths for BackgroundFX Pro structure
 
44
  def __repr__(self):
45
  return f"LoadedModel(id={self.model_id}, loaded={self.model is not None})"
46
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
47
  # ============================================================================ #
48
  # MODEL LOADER CLASS - MAIN INTERFACE
49
  # ============================================================================ #
50
 
51
  class ModelLoader:
52
  """
53
+ Simplified model loading for SAM2 and MatAnyone
54
+ Uses only the working loading strategies without redundant attempts
55
  """
56
 
57
  def __init__(self, device_mgr: DeviceManager, memory_mgr: MemoryManager):
 
78
  }
79
 
80
  logger.info(f"ModelLoader initialized for device: {self.device}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
81
 
82
  # ============================================================================ #
83
  # MAIN MODEL LOADING ORCHESTRATION
 
105
  # Clear any existing models
106
  self._cleanup_models()
107
 
108
+ # Load SAM2 first
109
  logger.info("Loading SAM2 predictor...")
110
  if progress_callback:
111
  progress_callback(0.1, "Loading SAM2 predictor...")
 
113
  self.sam2_predictor = self._load_sam2_predictor(progress_callback)
114
 
115
  if self.sam2_predictor is None:
116
+ logger.warning("SAM2 loading failed - will use fallback segmentation")
117
+ else:
118
+ sam2_time = time.time() - start_time
119
+ self.loading_stats['sam2_load_time'] = sam2_time
120
+ logger.info(f"SAM2 loaded in {sam2_time:.2f}s")
121
 
122
  # Load MatAnyone
123
  logger.info("Loading MatAnyone model...")
 
129
  self.matanyone_model, self.matanyone_core = self._load_matanyone_model(progress_callback)
130
 
131
  if self.matanyone_model is None:
132
+ logger.warning("MatAnyone loading failed - will use OpenCV refinement")
133
+ else:
134
+ matanyone_time = time.time() - matanyone_start
135
+ self.loading_stats['matanyone_load_time'] = matanyone_time
136
+ logger.info(f"MatAnyone loaded in {matanyone_time:.1f}s")
137
 
138
  # Final setup
139
  total_time = time.time() - start_time
 
141
  self.loading_stats['models_loaded'] = True
142
 
143
  if progress_callback:
144
+ if self.sam2_predictor or self.matanyone_model:
145
+ progress_callback(1.0, "Models loaded (with fallbacks available)")
146
+ else:
147
+ progress_callback(1.0, "Using fallback methods (models failed to load)")
148
 
149
+ logger.info(f"Model loading completed in {total_time:.2f}s")
150
 
151
  return self.sam2_predictor, self.matanyone_model
152
 
 
164
  return None, None
165
 
166
  # ============================================================================ #
167
+ # SAM2 MODEL LOADING - DIRECT OFFICIAL APPROACH ONLY
168
  # ============================================================================ #
169
 
170
  def _load_sam2_predictor(self, progress_callback: Optional[callable] = None):
171
  """
172
+ Load SAM2 using only the official from_pretrained method that works
 
173
 
174
  Args:
175
  progress_callback: Progress update callback
176
 
177
  Returns:
178
+ SAM2 predictor or None
179
  """
180
+ # Determine model size based on device memory
 
 
 
 
 
 
 
 
181
  model_size = "large" # default
182
  if hasattr(self.device_manager, 'get_device_memory_gb'):
183
  try:
 
185
  if memory_gb < 4:
186
  model_size = "tiny"
187
  elif memory_gb < 8:
188
+ model_size = "small"
189
+ elif memory_gb < 12:
190
  model_size = "base"
191
  logger.info(f"Selected SAM2 {model_size} based on {memory_gb}GB memory")
192
  except Exception as e:
193
  logger.warning(f"Could not determine device memory: {e}")
194
 
 
195
  model_map = {
196
  "tiny": "facebook/sam2.1-hiera-tiny",
197
  "small": "facebook/sam2.1-hiera-small",
 
202
  model_id = model_map.get(model_size, model_map["large"])
203
 
204
  if progress_callback:
205
+ progress_callback(0.3, f"Loading SAM2 {model_size} model...")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
206
 
207
+ # Use ONLY the official SAM2 from_pretrained method that works
208
  try:
209
+ logger.info(f"Loading SAM2 from {model_id}...")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
210
  from sam2.sam2_image_predictor import SAM2ImagePredictor
211
 
212
+ # This is the method that successfully downloads and loads the model
213
  predictor = SAM2ImagePredictor.from_pretrained(model_id)
214
 
215
+ # Move to correct device if needed
216
+ if hasattr(predictor, 'model'):
217
+ predictor.model = predictor.model.to(self.device)
218
+
219
  logger.info("SAM2 loaded successfully via official from_pretrained")
220
  return predictor
221
 
222
+ except ImportError as e:
223
+ logger.error(f"SAM2 module not found. Install with: pip install sam2")
224
+ return None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
225
 
226
  except Exception as e:
227
+ logger.error(f"SAM2 loading failed: {e}")
228
+ # Try downloading checkpoint manually as fallback
229
+ try:
230
+ logger.info("Attempting manual checkpoint download...")
231
+ import urllib.request
232
+
233
+ checkpoint_url = f"https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2.1_hiera_{model_size}.pt"
234
+ checkpoint_path = os.path.join(self.checkpoints_dir, f"sam2.1_hiera_{model_size}.pt")
235
+
236
+ if not os.path.exists(checkpoint_path):
237
+ logger.info(f"Downloading checkpoint from {checkpoint_url}")
238
+ urllib.request.urlretrieve(checkpoint_url, checkpoint_path)
239
+
240
+ # Try loading with downloaded checkpoint
241
+ predictor = SAM2ImagePredictor.from_pretrained(model_id, checkpoint=checkpoint_path)
242
+ logger.info("SAM2 loaded successfully with manual checkpoint")
243
+ return predictor
244
+
245
+ except Exception as fallback_error:
246
+ logger.error(f"Manual checkpoint fallback also failed: {fallback_error}")
247
+ return None
248
 
249
  # ============================================================================ #
250
+ # MATANYONE MODEL LOADING
251
  # ============================================================================ #
252
 
253
  def _load_matanyone_model(self, progress_callback: Optional[callable] = None):
254
  """
255
+ Load MatAnyone model - try official method only
256
 
257
  Args:
258
  progress_callback: Progress update callback
 
260
  Returns:
261
  Tuple[model, core] or (None, None)
262
  """
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
263
  try:
264
+ logger.info("Loading MatAnyone from HuggingFace...")
265
+ if progress_callback:
266
+ progress_callback(0.7, "Loading MatAnyone model...")
267
+
268
+ from matanyone import InferenceCore
269
+
270
+ # Initialize with the official model repo
271
+ processor = InferenceCore("PeiqingYang/MatAnyone")
272
+
273
+ logger.info("MatAnyone loaded successfully")
274
+ return processor, processor
275
+
276
  except ImportError:
277
+ logger.error("MatAnyone module not found. Install with: pip install matanyone")
278
+ return None, None
279
+
280
+ except Exception as e:
281
+ logger.error(f"MatAnyone loading failed: {e}")
282
+ return None, None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
283
 
284
  # ============================================================================ #
285
  # MODEL MANAGEMENT AND CLEANUP
 
300
  self.matanyone_core = None
301
 
302
  # Clear GPU cache
303
+ if torch.cuda.is_available():
304
+ torch.cuda.empty_cache()
305
  gc.collect()
306
 
307
  logger.debug("Model cleanup completed")
 
333
  if self.sam2_predictor is not None:
334
  try:
335
  info['sam2_model_type'] = type(self.sam2_predictor).__name__
336
+ if hasattr(self.sam2_predictor, 'model'):
337
+ info['sam2_has_model'] = True
338
+ if hasattr(self.sam2_predictor, 'predictor'):
339
+ info['sam2_has_predictor'] = True
340
  except:
341
  info['sam2_model_type'] = "Unknown"
342
 
 
361
  matanyone_time = self.loading_stats['matanyone_load_time']
362
  total_time = self.loading_stats['total_load_time']
363
 
364
+ summary = f"Models loaded in {total_time:.1f}s\n"
365
+
366
+ if self.sam2_predictor:
367
+ summary += f"✓ SAM2: {sam2_time:.1f}s\n"
368
+ else:
369
+ summary += f"✗ SAM2: Failed (using fallback)\n"
370
+
371
+ if self.matanyone_model:
372
+ summary += f"✓ MatAnyone: {matanyone_time:.1f}s\n"
373
+ else:
374
+ summary += f"✗ MatAnyone: Failed (using OpenCV)\n"
375
+
376
  summary += f"Device: {self.device}"
377
 
378
  return summary
 
394
  Validate that models are properly loaded and functional
395
 
396
  Returns:
397
+ bool: True if at least one model is valid
398
  """
399
  try:
400
+ has_valid_model = False
401
+
402
+ # Check SAM2
403
+ if self.sam2_predictor is not None:
404
+ # Check for required methods/attributes
405
+ if hasattr(self.sam2_predictor, 'set_image') or hasattr(self.sam2_predictor, 'predict'):
406
+ has_valid_model = True
407
+ logger.info("SAM2 validation passed")
408
+ elif hasattr(self.sam2_predictor, 'model'):
409
+ has_valid_model = True
410
+ logger.info("SAM2 model found")
411
 
412
+ # Check MatAnyone
413
+ if self.matanyone_model is not None:
414
+ has_valid_model = True
415
+ logger.info("MatAnyone validation passed")
416
 
417
+ return has_valid_model
 
 
 
418
 
419
  except Exception as e:
420
  logger.error(f"Model validation failed: {e}")
 
442
 
443
  @property
444
  def models_ready(self) -> bool:
445
+ """Check if at least one model is loaded and ready"""
446
+ return self.sam2_predictor is not None or self.matanyone_model is not None