MogensR committed on
Commit
9685fa7
·
1 Parent(s): d03832f

Update models/loaders/model_loader.py

Browse files
Files changed (1) hide show
  1. models/loaders/model_loader.py +278 -346
models/loaders/model_loader.py CHANGED
@@ -1,12 +1,15 @@
1
  #!/usr/bin/env python3
2
  """
3
- FIXED Model Loading Module for HuggingFace Spaces
4
- Handles the list index out of range error
 
 
5
  """
6
 
 
 
7
  import os
8
  import gc
9
- import sys
10
  import time
11
  import logging
12
  import traceback
@@ -21,6 +24,10 @@
21
 
22
  logger = logging.getLogger(__name__)
23
 
 
 
 
 
24
  class LoadedModel:
25
  def __init__(self, model=None, model_id: str = "", load_time: float = 0.0, device: str = "", framework: str = ""):
26
  self.model = model
@@ -29,40 +36,51 @@ def __init__(self, model=None, model_id: str = "", load_time: float = 0.0, devic
29
  self.device = device
30
  self.framework = framework
31
 
32
- def to_dict(self):
33
  return {
34
  "model_id": self.model_id,
35
  "framework": self.framework,
36
  "device": self.device,
37
  "load_time": self.load_time,
38
- "loaded": self.model is not None
39
  }
40
 
 
 
 
 
41
  class ModelLoader:
42
  def __init__(self, device_mgr: DeviceManager, memory_mgr: MemoryManager):
43
  self.device_manager = device_mgr
44
  self.memory_manager = memory_mgr
45
- self.device = self.device_manager.get_optimal_device()
46
 
47
- self.sam2_predictor = None
48
- self.matanyone_model = None
49
 
50
  self.checkpoints_dir = "./checkpoints"
51
  os.makedirs(self.checkpoints_dir, exist_ok=True)
52
 
53
  self.loading_stats = {
54
- 'sam2_load_time': 0.0,
55
- 'matanyone_load_time': 0.0,
56
- 'total_load_time': 0.0,
57
- 'models_loaded': False,
58
- 'loading_attempts': 0
59
  }
60
 
61
  logger.info(f"ModelLoader initialized for device: {self.device}")
62
 
63
- def load_all_models(self, progress_callback: Optional[Callable] = None, cancel_event=None) -> Tuple[Any, Any]:
 
 
 
 
 
 
 
64
  start_time = time.time()
65
- self.loading_stats['loading_attempts'] += 1
66
 
67
  try:
68
  logger.info("Starting model loading process...")
@@ -71,66 +89,141 @@ def load_all_models(self, progress_callback: Optional[Callable] = None, cancel_e
71
 
72
  self._cleanup_models()
73
 
74
- # Load SAM2 with better error handling
75
  logger.info("Loading SAM2 predictor...")
76
  if progress_callback:
77
  progress_callback(0.1, "Loading SAM2 predictor...")
78
- sam2_loaded = self._load_sam2_predictor_safe(progress_callback)
79
 
80
  if sam2_loaded is None:
81
- logger.warning("SAM2 loading failed - will use fallback segmentation")
82
  else:
83
  self.sam2_predictor = sam2_loaded
84
- sam2_time = self.sam2_predictor.load_time
85
- self.loading_stats['sam2_load_time'] = sam2_time
86
- logger.info(f"SAM2 loaded in {sam2_time:.2f}s")
 
 
 
 
 
87
 
88
- # Load MatAnyOne with better error handling
89
  logger.info("Loading MatAnyOne model...")
90
  if progress_callback:
91
  progress_callback(0.6, "Loading MatAnyOne model...")
92
-
93
- matanyone_loaded = self._load_matanyone_model_safe(progress_callback)
94
 
95
  if matanyone_loaded is None:
96
- logger.warning("MatAnyOne loading failed - will use OpenCV refinement")
97
  else:
98
  self.matanyone_model = matanyone_loaded
99
- matanyone_time = self.matanyone_model.load_time
100
- self.loading_stats['matanyone_load_time'] = matanyone_time
101
- logger.info(f"MatAnyOne loaded in {matanyone_time:.1f}s")
102
 
103
- # Final status
104
  total_time = time.time() - start_time
105
- self.loading_stats['total_load_time'] = total_time
106
- self.loading_stats['models_loaded'] = bool(self.sam2_predictor or self.matanyone_model)
107
 
108
  if progress_callback:
109
- if self.sam2_predictor or self.matanyone_model:
110
- progress_callback(1.0, "Models loaded (with fallbacks available)")
111
  else:
112
  progress_callback(1.0, "Using fallback methods (models failed to load)")
113
 
114
  logger.info(f"Model loading completed in {total_time:.2f}s")
115
-
116
- return (self.sam2_predictor, self.matanyone_model)
117
 
118
  except Exception as e:
119
  error_msg = f"Model loading failed: {str(e)}"
120
  logger.error(f"{error_msg}\n{traceback.format_exc()}")
121
  self._cleanup_models()
122
- self.loading_stats['models_loaded'] = False
123
  if progress_callback:
124
  progress_callback(1.0, f"Error: {error_msg}")
125
  return None, None
126
 
127
- def _load_sam2_predictor_safe(self, progress_callback: Optional[Callable] = None):
128
- """Load SAM2 with comprehensive error handling for HuggingFace Spaces"""
129
-
130
- # Determine model size based on available memory
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
131
  model_size = "large"
132
  try:
133
- if hasattr(self.device_manager, 'get_device_memory_gb'):
134
  memory_gb = self.device_manager.get_device_memory_gb()
135
  if memory_gb < 4:
136
  model_size = "tiny"
@@ -138,290 +231,203 @@ def _load_sam2_predictor_safe(self, progress_callback: Optional[Callable] = None
138
  model_size = "small"
139
  elif memory_gb < 12:
140
  model_size = "base"
141
- logger.info(f"Selected SAM2 {model_size} based on {memory_gb}GB memory")
142
  except Exception as e:
143
  logger.warning(f"Could not determine device memory: {e}")
144
- model_size = "tiny" # Default to tiny for Spaces
145
 
146
  model_map = {
147
  "tiny": "facebook/sam2.1-hiera-tiny",
148
- "small": "facebook/sam2.1-hiera-small",
149
  "base": "facebook/sam2.1-hiera-base-plus",
150
- "large": "facebook/sam2.1-hiera-large"
151
  }
152
  model_id = model_map.get(model_size, model_map["tiny"])
153
-
154
- logger.info(f"[DIAG] Loading SAM2 model_id: {model_id} on device {self.device}")
155
 
156
  if progress_callback:
157
- progress_callback(0.3, f"Loading SAM2 {model_size} model...")
158
 
159
- # Try multiple loading strategies
160
- loading_methods = [
161
  ("official", self._try_load_sam2_official, model_id),
162
  ("direct", self._try_load_sam2_direct, model_id),
163
  ("manual", self._try_load_sam2_manual, model_id),
164
  ]
165
 
166
- for method_name, method_func, model_id in loading_methods:
167
  try:
168
- logger.info(f"Attempting SAM2 load via {method_name} method...")
169
- result = method_func(model_id)
170
  if result is not None:
171
- logger.info(f"SAM2 loaded successfully via {method_name} method")
172
  return result
173
- except IndexError as e:
174
- logger.error(f"SAM2 {method_name} method - IndexError: {e}")
175
- logger.debug(f"Full traceback:\n{traceback.format_exc()}")
176
- continue
177
  except Exception as e:
178
- logger.error(f"SAM2 {method_name} method failed: {e}")
 
179
  continue
180
 
181
  logger.error("All SAM2 loading methods failed")
182
  return None
183
 
184
- def _try_load_sam2_official(self, model_id: str):
185
- """Try the official from_pretrained method"""
186
- try:
187
- from sam2.sam2_image_predictor import SAM2ImagePredictor
188
-
189
- # Set environment variables that might help in Spaces
190
- os.environ['HF_HUB_DISABLE_SYMLINKS'] = '1'
191
- os.environ['HF_HUB_ENABLE_HF_TRANSFER'] = '0'
192
-
193
- t0 = time.time()
194
-
195
- # Try with explicit cache directory
196
- cache_dir = os.path.join(self.checkpoints_dir, "sam2_cache")
197
- os.makedirs(cache_dir, exist_ok=True)
198
-
199
- # Log what we're about to do
200
- logger.debug(f"Calling SAM2ImagePredictor.from_pretrained('{model_id}')")
201
-
202
- # This is where the IndexError likely happens
203
- predictor = SAM2ImagePredictor.from_pretrained(
204
- model_id,
205
- cache_dir=cache_dir,
206
- local_files_only=False,
207
- trust_remote_code=True
208
- )
209
-
210
- if hasattr(predictor, 'model'):
211
- predictor.model = predictor.model.to(self.device)
212
-
213
- t1 = time.time()
214
-
215
- return LoadedModel(
216
- model=predictor,
217
- model_id=model_id,
218
- load_time=t1-t0,
219
- device=str(self.device),
220
- framework="sam2"
221
- )
222
- except Exception as e:
223
- logger.error(f"Official SAM2 loading failed: {e}")
224
- raise
225
-
226
- def _try_load_sam2_direct(self, model_id: str):
227
- """Try loading SAM2 using transformers AutoModel"""
228
- try:
229
- from transformers import AutoModel, AutoProcessor
230
-
231
- t0 = time.time()
232
-
233
- # Try loading as a standard transformers model
234
- model = AutoModel.from_pretrained(
235
- model_id,
236
- trust_remote_code=True,
237
- torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32
238
- ).to(self.device)
239
-
240
- # Try to get processor
241
- try:
242
- processor = AutoProcessor.from_pretrained(model_id)
243
- except:
244
- processor = None
245
-
246
- t1 = time.time()
247
-
248
- # Wrap in a compatible interface
249
- class SAM2Wrapper:
250
- def __init__(self, model, processor=None):
251
- self.model = model
252
- self.processor = processor
253
-
254
- def set_image(self, image):
255
- self.current_image = image
256
-
257
- def predict(self, *args, **kwargs):
258
- # Basic prediction interface
259
- return self.model(*args, **kwargs)
260
-
261
- wrapped = SAM2Wrapper(model, processor)
262
-
263
- return LoadedModel(
264
- model=wrapped,
265
- model_id=model_id,
266
- load_time=t1-t0,
267
- device=str(self.device),
268
- framework="sam2-transformers"
269
- )
270
- except Exception as e:
271
- logger.error(f"Direct SAM2 loading failed: {e}")
272
- raise
273
-
274
- def _try_load_sam2_manual(self, model_id: str):
275
- """Try manual model construction as last resort"""
276
- try:
277
- # This is a fallback - create a dummy model that at least won't crash
278
- logger.warning("Using manual SAM2 construction (limited functionality)")
279
-
280
- class DummySAM2:
281
- def __init__(self, device):
282
- self.device = device
283
- self.model = None
284
-
285
- def set_image(self, image):
286
- self.current_image = image
287
-
288
- def predict(self, point_coords=None, point_labels=None, box=None, **kwargs):
289
- # Return a dummy mask
290
- import numpy as np
291
- if hasattr(self, 'current_image'):
292
- h, w = self.current_image.shape[:2]
293
- else:
294
- h, w = 512, 512
295
- return {
296
- 'masks': np.ones((1, h, w), dtype=np.float32),
297
- 'scores': np.array([0.5]),
298
- 'logits': np.ones((1, h, w), dtype=np.float32)
299
- }
300
-
301
- dummy = DummySAM2(self.device)
302
-
303
- return LoadedModel(
304
- model=dummy,
305
- model_id=f"{model_id}-fallback",
306
- load_time=0.1,
307
- device=str(self.device),
308
- framework="sam2-fallback"
309
- )
310
- except Exception as e:
311
- logger.error(f"Manual SAM2 construction failed: {e}")
312
- raise
313
-
314
- def _load_matanyone_model_safe(self, progress_callback: Optional[Callable] = None):
315
- """Load MatAnyOne with comprehensive error handling"""
316
-
317
- loading_methods = [
318
- ("official", self._try_load_matanyone_official),
319
- ("alternative", self._try_load_matanyone_alternative),
320
- ("fallback", self._try_load_matanyone_fallback),
321
- ]
322
-
323
- for method_name, method_func in loading_methods:
324
- try:
325
- logger.info(f"Attempting MatAnyOne load via {method_name} method...")
326
- result = method_func(progress_callback)
327
- if result is not None:
328
- logger.info(f"MatAnyOne loaded successfully via {method_name} method")
329
- return result
330
- except IndexError as e:
331
- logger.error(f"MatAnyOne {method_name} method - IndexError: {e}")
332
- logger.debug(f"Full traceback:\n{traceback.format_exc()}")
333
- continue
334
- except Exception as e:
335
- logger.error(f"MatAnyOne {method_name} method failed: {e}")
336
- continue
337
 
338
- logger.error("All MatAnyOne loading methods failed")
339
- return None
 
340
 
341
- def _try_load_matanyone_official(self, progress_callback):
342
- """Try the official MatAnyOne loading method"""
343
- if progress_callback:
344
- progress_callback(0.7, "Loading MatAnyOne model (official)...")
345
 
346
- from matanyone import InferenceCore
347
  t0 = time.time()
348
-
349
- # Set cache directory
350
- cache_dir = os.path.join(self.checkpoints_dir, "matanyone_cache")
351
- os.makedirs(cache_dir, exist_ok=True)
352
-
353
- processor = InferenceCore(
354
- repo_id="PeiqingYang/MatAnyone",
355
- device=self.device,
356
- dtype=torch.float32,
357
- cache_dir=cache_dir
358
  )
359
-
 
360
  t1 = time.time()
361
-
362
  return LoadedModel(
363
- model=processor,
364
- model_id="PeiqingYang/MatAnyone",
365
- load_time=t1-t0,
366
- device=str(self.device),
367
- framework="matanyone"
368
  )
369
 
370
- def _try_load_matanyone_alternative(self, progress_callback):
371
- """Try alternative loading for MatAnyOne"""
372
- if progress_callback:
373
- progress_callback(0.7, "Loading MatAnyOne model (alternative)...")
374
-
375
- # Try loading via transformers
376
- from transformers import AutoModel
377
-
378
  t0 = time.time()
379
  model = AutoModel.from_pretrained(
380
- "PeiqingYang/MatAnyone",
381
  trust_remote_code=True,
382
- torch_dtype=torch.float32
383
  ).to(self.device)
 
 
 
 
 
 
384
  t1 = time.time()
385
-
386
- # Wrap for compatibility
387
- class MatAnyoneWrapper:
388
- def __init__(self, model):
389
  self.model = model
390
-
391
- def process(self, image, mask):
392
- return self.model(image, mask)
393
-
 
 
 
 
 
 
394
  return LoadedModel(
395
- model=MatAnyoneWrapper(model),
396
- model_id="PeiqingYang/MatAnyone-alt",
397
- load_time=t1-t0,
398
  device=str(self.device),
399
- framework="matanyone-transformers"
400
  )
401
 
402
- def _try_load_matanyone_fallback(self, progress_callback):
403
- """Create a fallback MatAnyOne that won't crash"""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
404
  if progress_callback:
405
- progress_callback(0.7, "Using MatAnyOne fallback...")
406
-
407
- logger.warning("Using fallback MatAnyOne (limited functionality)")
408
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
409
  class FallbackMatAnyone:
410
  def __init__(self, device):
411
  self.device = device
412
-
413
  def process(self, image, mask):
414
- # Just return the mask unchanged
415
  return mask
416
-
 
 
 
 
 
417
  return LoadedModel(
418
- model=FallbackMatAnyone(self.device),
419
- model_id="MatAnyone-fallback",
420
- load_time=0.1,
421
- device=str(self.device),
422
- framework="matanyone-fallback"
423
  )
424
 
 
 
425
  def _cleanup_models(self):
426
  if self.sam2_predictor is not None:
427
  del self.sam2_predictor
@@ -433,77 +439,3 @@ def _cleanup_models(self):
433
  torch.cuda.empty_cache()
434
  gc.collect()
435
  logger.debug("Model cleanup completed")
436
-
437
- def cleanup(self):
438
- self._cleanup_models()
439
- logger.info("ModelLoader cleanup completed")
440
-
441
- def get_model_info(self) -> Dict[str, Any]:
442
- info = {
443
- 'models_loaded': self.loading_stats['models_loaded'],
444
- 'sam2_loaded': self.sam2_predictor is not None,
445
- 'matanyone_loaded': self.matanyone_model is not None,
446
- 'device': str(self.device),
447
- 'loading_stats': self.loading_stats.copy()
448
- }
449
- if self.sam2_predictor is not None:
450
- info['sam2_model_type'] = type(self.sam2_predictor.model).__name__
451
- info['sam2_metadata'] = self.sam2_predictor.to_dict()
452
- if self.matanyone_model is not None:
453
- info['matanyone_model_type'] = type(self.matanyone_model.model).__name__
454
- info['matanyone_metadata'] = self.matanyone_model.to_dict()
455
- return info
456
-
457
- def get_load_summary(self) -> str:
458
- if not self.loading_stats['models_loaded']:
459
- return "Models not loaded"
460
- sam2_time = self.loading_stats['sam2_load_time']
461
- matanyone_time = self.loading_stats['matanyone_load_time']
462
- total_time = self.loading_stats['total_load_time']
463
- summary = f"Models loaded in {total_time:.1f}s\n"
464
- if self.sam2_predictor:
465
- summary += f"✓ SAM2: {sam2_time:.1f}s (ID: {self.sam2_predictor.model_id})\n"
466
- else:
467
- summary += f"✗ SAM2: Failed (using fallback)\n"
468
- if self.matanyone_model:
469
- summary += f"✓ MatAnyOne: {matanyone_time:.1f}s (ID: {self.matanyone_model.model_id})\n"
470
- else:
471
- summary += f"✗ MatAnyOne: Failed (using OpenCV)\n"
472
- summary += f"Device: {self.device}"
473
- return summary
474
-
475
- def get_matanyone(self):
476
- # Return the actual model from inside the LoadedModel wrapper
477
- if self.matanyone_model is not None:
478
- return self.matanyone_model.model if hasattr(self.matanyone_model, 'model') else None
479
- return None
480
-
481
- def get_sam2(self):
482
- # Return the actual model from inside the LoadedModel wrapper
483
- if self.sam2_predictor is not None:
484
- return self.sam2_predictor.model if hasattr(self.sam2_predictor, 'model') else None
485
- return None
486
-
487
- def validate_models(self) -> bool:
488
- try:
489
- has_valid_model = False
490
- if self.sam2_predictor is not None:
491
- model = self.sam2_predictor.model
492
- if hasattr(model, 'set_image') or hasattr(model, 'predict'):
493
- has_valid_model = True
494
- if self.matanyone_model is not None:
495
- has_valid_model = True
496
- return has_valid_model
497
- except Exception as e:
498
- logger.error(f"Model validation failed: {e}")
499
- return False
500
-
501
- def reload_models(self, progress_callback: Optional[Callable] = None) -> Tuple[Any, Any]:
502
- logger.info("Reloading models...")
503
- self._cleanup_models()
504
- self.loading_stats['models_loaded'] = False
505
- return self.load_all_models(progress_callback)
506
-
507
- @property
508
- def models_ready(self) -> bool:
509
- return self.sam2_predictor is not None or self.matanyone_model is not None
 
1
  #!/usr/bin/env python3
2
  """
3
+ Model Loader for Hugging Face Spaces
4
+ - Robust SAM2 loader with multiple strategies
5
+ - Correct MatAnyOne loader via official InferenceCore (no transformers)
6
+ - Clean progress reporting, cleanup, and diagnostics
7
  """
8
 
9
+ from __future__ import annotations
10
+
11
  import os
12
  import gc
 
13
  import time
14
  import logging
15
  import traceback
 
24
 
25
  logger = logging.getLogger(__name__)
26
 
27
+
28
+ # ------------------------------
29
+ # Data wrapper
30
+ # ------------------------------
31
  class LoadedModel:
32
  def __init__(self, model=None, model_id: str = "", load_time: float = 0.0, device: str = "", framework: str = ""):
33
  self.model = model
 
36
  self.device = device
37
  self.framework = framework
38
 
39
def to_dict(self) -> Dict[str, Any]:
    """Return a plain-dict summary of this loaded-model record.

    Returns:
        A dict with ``model_id``, ``framework``, ``device`` and
        ``load_time`` copied from the instance, plus ``loaded``, which is
        True when ``self.model`` is not None.
    """
    return {
        "model_id": self.model_id,
        "framework": self.framework,
        "device": self.device,
        "load_time": self.load_time,
        "loaded": self.model is not None,
    }
47
 
48
+
49
+ # ------------------------------
50
+ # Loader
51
+ # ------------------------------
52
  class ModelLoader:
53
def __init__(self, device_mgr: "DeviceManager", memory_mgr: "MemoryManager"):
    """Set up loader state; no model is loaded here.

    Args:
        device_mgr: Object providing ``get_optimal_device()``; the value it
            returns is stored as ``self.device``.
        memory_mgr: Stored as ``self.memory_manager``; not otherwise used
            in this constructor.

    Side effects:
        Creates the ``./checkpoints`` directory (relative to the current
        working directory) if it does not exist, and logs the chosen device.
    """
    self.device_manager = device_mgr
    self.memory_manager = memory_mgr
    self.device = self.device_manager.get_optimal_device()  # e.g., cuda:0 or cpu

    # Both slots hold LoadedModel wrappers once loading succeeds.
    self.sam2_predictor: Optional[LoadedModel] = None
    self.matanyone_model: Optional[LoadedModel] = None

    self.checkpoints_dir = "./checkpoints"
    os.makedirs(self.checkpoints_dir, exist_ok=True)

    # Timing and status counters.
    self.loading_stats = {
        "sam2_load_time": 0.0,
        "matanyone_load_time": 0.0,
        "total_load_time": 0.0,
        "models_loaded": False,
        "loading_attempts": 0,
    }

    logger.info(f"ModelLoader initialized for device: {self.device}")
73
 
74
+ # ---------- Public API ----------
75
+
76
+ def load_all_models(
77
+ self, progress_callback: Optional[Callable[[float, str], None]] = None, cancel_event=None
78
+ ) -> Tuple[Optional[LoadedModel], Optional[LoadedModel]]:
79
+ """
80
+ Loads SAM2 + MatAnyOne. Returns (LoadedModel|None, LoadedModel|None).
81
+ """
82
  start_time = time.time()
83
+ self.loading_stats["loading_attempts"] += 1
84
 
85
  try:
86
  logger.info("Starting model loading process...")
 
89
 
90
  self._cleanup_models()
91
 
92
+ # ---- SAM2 ----
93
  logger.info("Loading SAM2 predictor...")
94
  if progress_callback:
95
  progress_callback(0.1, "Loading SAM2 predictor...")
96
+ sam2_loaded = self._load_sam2_predictor(progress_callback)
97
 
98
  if sam2_loaded is None:
99
+ logger.warning("SAM2 loading failed - a limited fallback will be used at runtime if needed.")
100
  else:
101
  self.sam2_predictor = sam2_loaded
102
+ self.loading_stats["sam2_load_time"] = self.sam2_predictor.load_time
103
+ logger.info(f"SAM2 loaded in {self.loading_stats['sam2_load_time']:.2f}s")
104
+
105
+ # Early exit if cancelled
106
+ if cancel_event is not None and getattr(cancel_event, "is_set", lambda: False)():
107
+ if progress_callback:
108
+ progress_callback(1.0, "Model loading cancelled")
109
+ return self.sam2_predictor, None
110
 
111
+ # ---- MatAnyOne ----
112
  logger.info("Loading MatAnyOne model...")
113
  if progress_callback:
114
  progress_callback(0.6, "Loading MatAnyOne model...")
115
+ matanyone_loaded = self._load_matanyone(progress_callback)
 
116
 
117
  if matanyone_loaded is None:
118
+ logger.warning("MatAnyOne loading failed - will use simple refinement fallbacks.")
119
  else:
120
  self.matanyone_model = matanyone_loaded
121
+ self.loading_stats["matanyone_load_time"] = self.matanyone_model.load_time
122
+ logger.info(f"MatAnyOne loaded in {self.loading_stats['matanyone_load_time']:.2f}s")
 
123
 
124
+ # ---- Final status ----
125
  total_time = time.time() - start_time
126
+ self.loading_stats["total_load_time"] = total_time
127
+ self.loading_stats["models_loaded"] = bool(self.sam2_predictor or self.matanyone_model)
128
 
129
  if progress_callback:
130
+ if self.loading_stats["models_loaded"]:
131
+ progress_callback(1.0, "Models loaded (fallbacks available if any model failed)")
132
  else:
133
  progress_callback(1.0, "Using fallback methods (models failed to load)")
134
 
135
  logger.info(f"Model loading completed in {total_time:.2f}s")
136
+ return self.sam2_predictor, self.matanyone_model
 
137
 
138
  except Exception as e:
139
  error_msg = f"Model loading failed: {str(e)}"
140
  logger.error(f"{error_msg}\n{traceback.format_exc()}")
141
  self._cleanup_models()
142
+ self.loading_stats["models_loaded"] = False
143
  if progress_callback:
144
  progress_callback(1.0, f"Error: {error_msg}")
145
  return None, None
146
 
147
def reload_models(
    self, progress_callback: Optional[Callable[[float, str], None]] = None
) -> Tuple[Optional["LoadedModel"], Optional["LoadedModel"]]:
    """Drop any currently held models and run the full load again.

    Args:
        progress_callback: Optional callable forwarded unchanged to
            ``load_all_models``.

    Returns:
        Whatever ``load_all_models`` returns.
    """
    logger.info("Reloading models...")
    self._cleanup_models()
    # Mark as not loaded before the new attempt starts.
    self.loading_stats["models_loaded"] = False
    return self.load_all_models(progress_callback)
154
+
155
@property
def models_ready(self) -> bool:
    """True when at least one of the SAM2 / MatAnyOne slots is populated."""
    return self.sam2_predictor is not None or self.matanyone_model is not None
158
+
159
def get_sam2(self):
    """Return the object stored in the SAM2 wrapper's ``model`` attribute.

    Returns:
        ``self.sam2_predictor.model``, or None when no SAM2 wrapper is held.
    """
    if self.sam2_predictor is None:
        return None
    return self.sam2_predictor.model
161
+
162
def get_matanyone(self):
    """Return the object stored in the MatAnyOne wrapper's ``model`` attribute.

    Returns:
        ``self.matanyone_model.model``, or None when no MatAnyOne wrapper
        is held.
    """
    if self.matanyone_model is None:
        return None
    return self.matanyone_model.model
164
+
165
def validate_models(self) -> bool:
    """Report whether at least one usable model is currently held.

    SAM2 counts as usable when its wrapped object exposes ``set_image`` or
    ``predict``. MatAnyOne counts as usable whenever its wrapper is present;
    the wrapped object is not inspected.

    Returns:
        True if either check passes; False otherwise, including when an
        exception is raised during the checks (the error is logged).
    """
    try:
        ok = False
        if self.sam2_predictor is not None:
            model = self.sam2_predictor.model
            if hasattr(model, "set_image") or hasattr(model, "predict"):
                ok = True
        if self.matanyone_model is not None:
            ok = True
        return ok
    except Exception as e:
        logger.error(f"Model validation failed: {e}")
        return False
178
+
179
def get_model_info(self) -> Dict[str, Any]:
    """Collect a diagnostic snapshot of loader state.

    Returns:
        A dict that always contains ``models_loaded``, ``sam2_loaded``,
        ``matanyone_loaded``, ``device`` (stringified) and ``loading_stats``
        (a shallow copy). For each model that is present, the class name of
        the wrapped object and the wrapper's ``to_dict()`` output are added
        under ``<name>_model_type`` / ``<name>_metadata``.
    """
    info = {
        "models_loaded": self.loading_stats["models_loaded"],
        "sam2_loaded": self.sam2_predictor is not None,
        "matanyone_loaded": self.matanyone_model is not None,
        "device": str(self.device),
        # Copy so callers cannot mutate the loader's own counters.
        "loading_stats": self.loading_stats.copy(),
    }
    if self.sam2_predictor is not None:
        info["sam2_model_type"] = type(self.sam2_predictor.model).__name__
        info["sam2_metadata"] = self.sam2_predictor.to_dict()
    if self.matanyone_model is not None:
        info["matanyone_model_type"] = type(self.matanyone_model.model).__name__
        info["matanyone_metadata"] = self.matanyone_model.to_dict()
    return info
194
+
195
def get_load_summary(self) -> str:
    """Build a short multi-line, human-readable load report.

    Returns:
        ``"Models not loaded"`` when ``loading_stats["models_loaded"]`` is
        falsy. Otherwise a header with the total time, one line per model
        (timing and model id on success, a failure note otherwise) and a
        final device line.
    """
    if not self.loading_stats["models_loaded"]:
        return "Models not loaded"
    sam2_time = self.loading_stats["sam2_load_time"]
    matanyone_time = self.loading_stats["matanyone_load_time"]
    total_time = self.loading_stats["total_load_time"]
    summary = f"Models loaded in {total_time:.1f}s\n"
    if self.sam2_predictor:
        summary += f"✓ SAM2: {sam2_time:.1f}s (ID: {self.sam2_predictor.model_id})\n"
    else:
        summary += "✗ SAM2: Failed (using fallback)\n"
    if self.matanyone_model:
        summary += f"✓ MatAnyOne: {matanyone_time:.1f}s (ID: {self.matanyone_model.model_id})\n"
    else:
        summary += "✗ MatAnyOne: Failed (using simple refinement)\n"
    summary += f"Device: {self.device}"
    return summary
212
+
213
def cleanup(self):
    """Public teardown hook: delegate to ``_cleanup_models`` and log completion."""
    self._cleanup_models()
    logger.info("ModelLoader cleanup completed")
216
+
217
+ # ---------- Internal: SAM2 ----------
218
+
219
+ def _load_sam2_predictor(self, progress_callback: Optional[Callable[[float, str], None]] = None) -> Optional[LoadedModel]:
220
+ """
221
+ Try multiple SAM2 loading strategies: official -> transformers -> dummy fallback.
222
+ """
223
+ # Choose model size heuristically
224
  model_size = "large"
225
  try:
226
+ if hasattr(self.device_manager, "get_device_memory_gb"):
227
  memory_gb = self.device_manager.get_device_memory_gb()
228
  if memory_gb < 4:
229
  model_size = "tiny"
 
231
  model_size = "small"
232
  elif memory_gb < 12:
233
  model_size = "base"
234
+ logger.info(f"Selected SAM2 {model_size} based on {memory_gb}GB VRAM")
235
  except Exception as e:
236
  logger.warning(f"Could not determine device memory: {e}")
237
+ model_size = "tiny"
238
 
239
  model_map = {
240
  "tiny": "facebook/sam2.1-hiera-tiny",
241
+ "small": "facebook/sam2.1-hiera-small",
242
  "base": "facebook/sam2.1-hiera-base-plus",
243
+ "large": "facebook/sam2.1-hiera-large",
244
  }
245
  model_id = model_map.get(model_size, model_map["tiny"])
 
 
246
 
247
  if progress_callback:
248
+ progress_callback(0.3, f"Loading SAM2 ({model_size})...")
249
 
250
+ methods = [
 
251
  ("official", self._try_load_sam2_official, model_id),
252
  ("direct", self._try_load_sam2_direct, model_id),
253
  ("manual", self._try_load_sam2_manual, model_id),
254
  ]
255
 
256
+ for name, fn, mid in methods:
257
  try:
258
+ logger.info(f"Attempting SAM2 load via {name} method ({mid})...")
259
+ result = fn(mid)
260
  if result is not None:
261
+ logger.info(f"SAM2 loaded successfully via {name} method")
262
  return result
 
 
 
 
263
  except Exception as e:
264
+ logger.error(f"SAM2 {name} method failed: {e}")
265
+ logger.debug(traceback.format_exc())
266
  continue
267
 
268
  logger.error("All SAM2 loading methods failed")
269
  return None
270
 
271
def _try_load_sam2_official(self, model_id: str) -> Optional["LoadedModel"]:
    """Load SAM2 through ``SAM2ImagePredictor.from_pretrained``.

    Args:
        model_id: Hub model identifier passed to ``from_pretrained``.

    Returns:
        A ``LoadedModel`` wrapping the predictor, tagged ``framework="sam2"``,
        with ``load_time`` covering the ``from_pretrained`` call and the
        device move.

    Raises:
        Any exception from the ``sam2`` import or from ``from_pretrained``
        propagates to the caller.
    """
    # Imported lazily so a missing ``sam2`` package only fails this strategy.
    from sam2.sam2_image_predictor import SAM2ImagePredictor

    # Space-specific hub flags.
    # NOTE(review): these are set process-wide and only after other imports may
    # already have read them - confirm they have the intended effect.
    os.environ["HF_HUB_DISABLE_SYMLINKS"] = "1"
    os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "0"

    cache_dir = os.path.join(self.checkpoints_dir, "sam2_cache")
    os.makedirs(cache_dir, exist_ok=True)

    t0 = time.time()
    # NOTE(review): whether cache_dir / local_files_only / trust_remote_code
    # are honoured depends on the installed sam2 version - confirm.
    predictor = SAM2ImagePredictor.from_pretrained(
        model_id,
        cache_dir=cache_dir,
        local_files_only=False,
        trust_remote_code=True,
    )
    if hasattr(predictor, "model"):
        predictor.model = predictor.model.to(self.device)
    t1 = time.time()

    return LoadedModel(
        model=predictor,
        model_id=model_id,
        load_time=t1 - t0,
        device=str(self.device),
        framework="sam2",
    )
298
 
299
def _try_load_sam2_direct(self, model_id: str) -> Optional["LoadedModel"]:
    """Load SAM2 via ``transformers.AutoModel`` (best-effort; API may vary).

    Args:
        model_id: Hub model identifier passed to ``AutoModel.from_pretrained``.

    Returns:
        A ``LoadedModel`` whose ``model`` is a small wrapper exposing
        ``set_image`` and ``predict``, tagged
        ``framework="sam2-transformers"``.

    Raises:
        Any exception from the ``transformers`` import or from
        ``AutoModel.from_pretrained`` propagates to the caller. A failure to
        load the processor is tolerated and leaves ``processor`` as None.
    """
    from transformers import AutoModel, AutoProcessor

    t0 = time.time()
    model = AutoModel.from_pretrained(
        model_id,
        trust_remote_code=True,
        # Half precision only when CUDA is available.
        torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
    ).to(self.device)

    try:
        processor = AutoProcessor.from_pretrained(model_id)
    except Exception:
        # The processor is optional for this path.
        processor = None

    t1 = time.time()

    class SAM2Wrapper:
        """Adapter giving the transformers model a predictor-like surface."""

        def __init__(self, model, processor=None):
            self.model = model
            self.processor = processor

        def set_image(self, image):
            # Only stores the image; predict() does not read it.
            self.current_image = image

        def predict(self, *args, **kwargs):
            # Forward straight to the underlying model's __call__.
            return self.model(*args, **kwargs)

    wrapped = SAM2Wrapper(model, processor)

    return LoadedModel(
        model=wrapped,
        model_id=model_id,
        load_time=t1 - t0,
        device=str(self.device),
        framework="sam2-transformers",
    )
339
 
340
def _try_load_sam2_manual(self, model_id: str) -> Optional["LoadedModel"]:
    """Build a dummy SAM2 stand-in that won't crash the app.

    No weights are loaded. The stand-in's ``predict`` returns an all-ones
    mask sized to the last image given to ``set_image`` (512x512 if none).

    Args:
        model_id: Used only to derive the recorded id ``"<model_id>-fallback"``.

    Returns:
        A ``LoadedModel`` tagged ``framework="sam2-fallback"``.
    """

    class DummySAM2:
        """Placeholder predictor producing a full-frame mask."""

        def __init__(self, device):
            self.device = device
            self.model = None

        def set_image(self, image):
            self.current_image = image

        def predict(self, point_coords=None, point_labels=None, box=None, **kwargs):
            # NOTE(review): returns a dict; confirm callers do not expect the
            # tuple shape a real predictor may return.
            import numpy as np

            if hasattr(self, "current_image"):
                h, w = self.current_image.shape[:2]
            else:
                h, w = 512, 512
            return {
                "masks": np.ones((1, h, w), dtype=np.float32),
                "scores": np.array([0.5]),
                "logits": np.ones((1, h, w), dtype=np.float32),
            }

    t0 = time.time()
    dummy = DummySAM2(self.device)
    t1 = time.time()

    logger.warning("Using manual SAM2 fallback (limited functionality)")
    return LoadedModel(
        model=dummy,
        model_id=f"{model_id}-fallback",
        load_time=t1 - t0,
        device=str(self.device),
        framework="sam2-fallback",
    )
372
+
373
+ # ---------- Internal: MatAnyOne ----------
374
+
375
def _load_matanyone(
    self, progress_callback: Optional[Callable[[float, str], None]] = None
) -> Optional["LoadedModel"]:
    """Load MatAnyOne via the official loader, else a pass-through placeholder.

    Args:
        progress_callback: Optional callable; if given, invoked once with
            ``(0.7, "Loading MatAnyOne (InferenceCore)...")`` before loading.

    Returns:
        The result of ``_try_load_matanyone_official``; if that raises, the
        error is logged and the result of ``_try_load_matanyone_fallback``
        is returned instead.
    """
    if progress_callback:
        progress_callback(0.7, "Loading MatAnyOne (InferenceCore)...")
    try:
        return self._try_load_matanyone_official()
    except Exception as e:
        logger.error(f"MatAnyOne official loader failed: {e}")
        logger.debug(traceback.format_exc())
        logger.warning("Falling back to simple MatAnyOne placeholder.")
        return self._try_load_matanyone_fallback()
388
+
389
def _try_load_matanyone_official(self) -> Optional["LoadedModel"]:
    """Load MatAnyOne via the package's ``InferenceCore``.

    IMPORTANT: pass model id POSITIONALLY; do NOT use repo_id= or transformers.

    Returns:
        A ``LoadedModel`` wrapping the ``InferenceCore`` instance, tagged
        ``framework="matanyone"``.

    Raises:
        Any exception from the ``matanyone`` import or from constructing
        ``InferenceCore`` propagates to the caller.
    """
    from matanyone import InferenceCore

    t0 = time.time()
    processor = InferenceCore("PeiqingYang/MatAnyone")
    t1 = time.time()

    # NOTE(review): the recorded device is self.device, but nothing here moves
    # the processor to it - confirm InferenceCore picks its own device.
    return LoadedModel(
        model=processor,
        model_id="PeiqingYang/MatAnyone",
        load_time=t1 - t0,
        device=str(self.device),
        framework="matanyone",
    )
407
+
408
def _try_load_matanyone_fallback(self) -> Optional["LoadedModel"]:
    """Build a minimal placeholder that passes masks through unchanged.

    Returns:
        A ``LoadedModel`` tagged ``framework="matanyone-fallback"`` whose
        ``model.process(image, mask)`` returns ``mask`` as given.
    """

    class FallbackMatAnyone:
        """Identity refiner used when the real model is unavailable."""

        def __init__(self, device):
            self.device = device

        def process(self, image, mask):
            # Identity pass-through (keeps pipeline alive)
            return mask

    t0 = time.time()
    model = FallbackMatAnyone(self.device)
    t1 = time.time()

    logger.warning("Using MatAnyOne fallback (limited functionality)")
    return LoadedModel(
        model=model,
        model_id="MatAnyone-fallback",
        load_time=t1 - t0,
        device=str(self.device),
        framework="matanyone-fallback",
    )
428
 
429
+ # ---------- Internal: cleanup ----------
430
+
431
  def _cleanup_models(self):
432
  if self.sam2_predictor is not None:
433
  del self.sam2_predictor
 
439
  torch.cuda.empty_cache()
440
  gc.collect()
441
  logger.debug("Model cleanup completed")