MogensR committed on
Commit
8695f97
·
1 Parent(s): 4fc49fc

Update models/loaders/model_loader.py

Browse files
Files changed (1) hide show
  1. models/loaders/model_loader.py +262 -91
models/loaders/model_loader.py CHANGED
@@ -1,8 +1,7 @@
1
  #!/usr/bin/env python3
2
  """
3
- Model Loading Module
4
- Handles loading and validation of SAM2 and MatAnyOne AI models
5
- (Modern version for BackgroundFX Pro – only edit this file for model loading logic)
6
  """
7
 
8
  import os
@@ -22,9 +21,6 @@
22
 
23
  logger = logging.getLogger(__name__)
24
 
25
- # ============================================================================
26
- # LOADED MODEL DATA CONTAINER
27
- # ============================================================================
28
  class LoadedModel:
29
  def __init__(self, model=None, model_id: str = "", load_time: float = 0.0, device: str = "", framework: str = ""):
30
  self.model = model
@@ -42,20 +38,14 @@ def to_dict(self):
42
  "loaded": self.model is not None
43
  }
44
 
45
- def __repr__(self):
46
- return f"LoadedModel(id={self.model_id}, loaded={self.model is not None}, device={self.device}, framework={self.framework}, load_time={self.load_time:.2f}s)"
47
-
48
- # ============================================================================
49
- # MODEL LOADER CLASS
50
- # ============================================================================
51
  class ModelLoader:
52
  def __init__(self, device_mgr: DeviceManager, memory_mgr: MemoryManager):
53
  self.device_manager = device_mgr
54
  self.memory_manager = memory_mgr
55
  self.device = self.device_manager.get_optimal_device()
56
 
57
- self.sam2_predictor = None # LoadedModel instance or None
58
- self.matanyone_model = None # LoadedModel instance or None
59
 
60
  self.checkpoints_dir = "./checkpoints"
61
  os.makedirs(self.checkpoints_dir, exist_ok=True)
@@ -70,9 +60,6 @@ def __init__(self, device_mgr: DeviceManager, memory_mgr: MemoryManager):
70
 
71
  logger.info(f"ModelLoader initialized for device: {self.device}")
72
 
73
- # ============================================================================
74
- # MAIN LOADING FUNCTION (ORCHESTRATION)
75
- # ============================================================================
76
  def load_all_models(self, progress_callback: Optional[Callable] = None, cancel_event=None) -> Tuple[Any, Any]:
77
  start_time = time.time()
78
  self.loading_stats['loading_attempts'] += 1
@@ -84,14 +71,11 @@ def load_all_models(self, progress_callback: Optional[Callable] = None, cancel_e
84
 
85
  self._cleanup_models()
86
 
87
- # --- DIAG: Log device and model selection step
88
- logger.info(f"Device for models: {self.device}")
89
-
90
- # Load SAM2 first
91
  logger.info("Loading SAM2 predictor...")
92
  if progress_callback:
93
  progress_callback(0.1, "Loading SAM2 predictor...")
94
- sam2_loaded = self._load_sam2_predictor(progress_callback)
95
 
96
  if sam2_loaded is None:
97
  logger.warning("SAM2 loading failed - will use fallback segmentation")
@@ -101,13 +85,12 @@ def load_all_models(self, progress_callback: Optional[Callable] = None, cancel_e
101
  self.loading_stats['sam2_load_time'] = sam2_time
102
  logger.info(f"SAM2 loaded in {sam2_time:.2f}s")
103
 
104
- # Load MatAnyOne
105
  logger.info("Loading MatAnyOne model...")
106
  if progress_callback:
107
  progress_callback(0.6, "Loading MatAnyOne model...")
108
- matanyone_start = time.time()
109
-
110
- matanyone_loaded = self._load_matanyone_model(progress_callback)
111
 
112
  if matanyone_loaded is None:
113
  logger.warning("MatAnyOne loading failed - will use OpenCV refinement")
@@ -141,10 +124,10 @@ def load_all_models(self, progress_callback: Optional[Callable] = None, cancel_e
141
  progress_callback(1.0, f"Error: {error_msg}")
142
  return None, None
143
 
144
- # ============================================================================
145
- # SAM2 LOADING (OFFICIAL FROM_PRETRAINED)
146
- # ============================================================================
147
- def _load_sam2_predictor(self, progress_callback: Optional[Callable] = None):
148
  model_size = "large"
149
  try:
150
  if hasattr(self.device_manager, 'get_device_memory_gb'):
@@ -158,29 +141,77 @@ def _load_sam2_predictor(self, progress_callback: Optional[Callable] = None):
158
  logger.info(f"Selected SAM2 {model_size} based on {memory_gb}GB memory")
159
  except Exception as e:
160
  logger.warning(f"Could not determine device memory: {e}")
 
161
 
162
  model_map = {
163
  "tiny": "facebook/sam2.1-hiera-tiny",
164
- "small": "facebook/sam2.1-hiera-small",
165
  "base": "facebook/sam2.1-hiera-base-plus",
166
  "large": "facebook/sam2.1-hiera-large"
167
  }
168
- model_id = model_map.get(model_size, model_map["large"])
169
- logger.info(f"[DIAG] About to load SAM2 model_id: {model_id} on device {self.device}")
 
170
 
171
  if progress_callback:
172
  progress_callback(0.3, f"Loading SAM2 {model_size} model...")
173
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
174
  try:
175
  from sam2.sam2_image_predictor import SAM2ImagePredictor
 
 
 
 
 
176
  t0 = time.time()
177
- predictor = SAM2ImagePredictor.from_pretrained(model_id)
178
- logger.info(f"[DIAG] SAM2 predictor instance type: {type(predictor)}")
179
- # If this fails, it's likely a missing model or bad download
 
 
 
 
 
 
 
 
 
 
 
 
 
180
  if hasattr(predictor, 'model'):
181
  predictor.model = predictor.model.to(self.device)
 
182
  t1 = time.time()
183
- logger.info("SAM2 loaded successfully via official from_pretrained")
184
  return LoadedModel(
185
  model=predictor,
186
  model_id=model_id,
@@ -188,61 +219,209 @@ def _load_sam2_predictor(self, progress_callback: Optional[Callable] = None):
188
  device=str(self.device),
189
  framework="sam2"
190
  )
191
- except IndexError as e:
192
- logger.error(f"SAM2 IndexError: {e}. (Did the model download fail? Wrong model_id?)")
193
- logger.error(traceback.format_exc())
194
- return None
195
- except ImportError:
196
- logger.error("SAM2 module not found. Install with: pip install sam2")
197
- return None
198
  except Exception as e:
199
- logger.error(f"SAM2 loading failed: {e}")
200
- logger.error(traceback.format_exc())
201
- return None
202
-
203
- # ============================================================================
204
- # MATANYONE LOADING (OFFICIAL INFERENCECORE)
205
- # ============================================================================
206
- def _load_matanyone_model(self, progress_callback: Optional[Callable] = None):
207
- try:
208
- if progress_callback:
209
- progress_callback(0.7, "Loading MatAnyOne model...")
210
 
211
- from matanyone import InferenceCore
 
 
 
 
212
  t0 = time.time()
213
- matanyone_kwargs = dict(
214
- repo_id="PeiqingYang/MatAnyone",
215
- device=self.device,
216
- dtype=torch.float32,
217
- # chunk_size=512,
218
- )
219
- logger.info(f"[DIAG] About to load MatAnyOne from repo: {matanyone_kwargs['repo_id']} on device {self.device}")
220
- processor = InferenceCore(**matanyone_kwargs)
221
- logger.info(f"[DIAG] MatAnyOne processor type: {type(processor)}")
 
 
 
 
 
222
  t1 = time.time()
223
- logger.info("MatAnyOne loaded successfully (InferenceCore)")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
224
  return LoadedModel(
225
- model=processor,
226
- model_id=matanyone_kwargs["repo_id"],
227
  load_time=t1-t0,
228
  device=str(self.device),
229
- framework="matanyone"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
230
  )
231
- except IndexError as e:
232
- logger.error(f"MatAnyOne IndexError: {e}. (Did the model download fail? Wrong repo_id?)")
233
- logger.error(traceback.format_exc())
234
- return None
235
- except ImportError:
236
- logger.error("MatAnyOne module not found. Install with: pip install matanyone")
237
- return None
238
  except Exception as e:
239
- logger.error(f"MatAnyOne loading failed: {e}")
240
- logger.error(traceback.format_exc())
241
- return None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
242
 
243
- # ============================================================================
244
- # MODEL MANAGEMENT AND CLEANUP
245
- # ============================================================================
246
  def _cleanup_models(self):
247
  if self.sam2_predictor is not None:
248
  del self.sam2_predictor
@@ -259,9 +438,6 @@ def cleanup(self):
259
  self._cleanup_models()
260
  logger.info("ModelLoader cleanup completed")
261
 
262
- # ============================================================================
263
- # MODEL INFO AND VALIDATION
264
- # ============================================================================
265
  def get_model_info(self) -> Dict[str, Any]:
266
  info = {
267
  'models_loaded': self.loading_stats['models_loaded'],
@@ -324,9 +500,4 @@ def reload_models(self, progress_callback: Optional[Callable] = None) -> Tuple[A
324
 
325
  @property
326
  def models_ready(self) -> bool:
327
- return self.sam2_predictor is not None or self.matanyone_model is not None
328
-
329
- # ============================================================================
330
- # END MODEL LOADER
331
- # ============================================================================
332
-
 
1
  #!/usr/bin/env python3
2
  """
3
+ FIXED Model Loading Module for HuggingFace Spaces
4
+ Handles the list index out of range error
 
5
  """
6
 
7
  import os
 
21
 
22
  logger = logging.getLogger(__name__)
23
 
 
 
 
24
  class LoadedModel:
25
  def __init__(self, model=None, model_id: str = "", load_time: float = 0.0, device: str = "", framework: str = ""):
26
  self.model = model
 
38
  "loaded": self.model is not None
39
  }
40
 
 
 
 
 
 
 
41
  class ModelLoader:
42
  def __init__(self, device_mgr: DeviceManager, memory_mgr: MemoryManager):
43
  self.device_manager = device_mgr
44
  self.memory_manager = memory_mgr
45
  self.device = self.device_manager.get_optimal_device()
46
 
47
+ self.sam2_predictor = None
48
+ self.matanyone_model = None
49
 
50
  self.checkpoints_dir = "./checkpoints"
51
  os.makedirs(self.checkpoints_dir, exist_ok=True)
 
60
 
61
  logger.info(f"ModelLoader initialized for device: {self.device}")
62
 
 
 
 
63
  def load_all_models(self, progress_callback: Optional[Callable] = None, cancel_event=None) -> Tuple[Any, Any]:
64
  start_time = time.time()
65
  self.loading_stats['loading_attempts'] += 1
 
71
 
72
  self._cleanup_models()
73
 
74
+ # Load SAM2 with better error handling
 
 
 
75
  logger.info("Loading SAM2 predictor...")
76
  if progress_callback:
77
  progress_callback(0.1, "Loading SAM2 predictor...")
78
+ sam2_loaded = self._load_sam2_predictor_safe(progress_callback)
79
 
80
  if sam2_loaded is None:
81
  logger.warning("SAM2 loading failed - will use fallback segmentation")
 
85
  self.loading_stats['sam2_load_time'] = sam2_time
86
  logger.info(f"SAM2 loaded in {sam2_time:.2f}s")
87
 
88
+ # Load MatAnyOne with better error handling
89
  logger.info("Loading MatAnyOne model...")
90
  if progress_callback:
91
  progress_callback(0.6, "Loading MatAnyOne model...")
92
+
93
+ matanyone_loaded = self._load_matanyone_model_safe(progress_callback)
 
94
 
95
  if matanyone_loaded is None:
96
  logger.warning("MatAnyOne loading failed - will use OpenCV refinement")
 
124
  progress_callback(1.0, f"Error: {error_msg}")
125
  return None, None
126
 
127
+ def _load_sam2_predictor_safe(self, progress_callback: Optional[Callable] = None):
128
+ """Load SAM2 with comprehensive error handling for HuggingFace Spaces"""
129
+
130
+ # Determine model size based on available memory
131
  model_size = "large"
132
  try:
133
  if hasattr(self.device_manager, 'get_device_memory_gb'):
 
141
  logger.info(f"Selected SAM2 {model_size} based on {memory_gb}GB memory")
142
  except Exception as e:
143
  logger.warning(f"Could not determine device memory: {e}")
144
+ model_size = "tiny" # Default to tiny for Spaces
145
 
146
  model_map = {
147
  "tiny": "facebook/sam2.1-hiera-tiny",
148
+ "small": "facebook/sam2.1-hiera-small",
149
  "base": "facebook/sam2.1-hiera-base-plus",
150
  "large": "facebook/sam2.1-hiera-large"
151
  }
152
+ model_id = model_map.get(model_size, model_map["tiny"])
153
+
154
+ logger.info(f"[DIAG] Loading SAM2 model_id: {model_id} on device {self.device}")
155
 
156
  if progress_callback:
157
  progress_callback(0.3, f"Loading SAM2 {model_size} model...")
158
 
159
+ # Try multiple loading strategies
160
+ loading_methods = [
161
+ ("official", self._try_load_sam2_official, model_id),
162
+ ("direct", self._try_load_sam2_direct, model_id),
163
+ ("manual", self._try_load_sam2_manual, model_id),
164
+ ]
165
+
166
+ for method_name, method_func, model_id in loading_methods:
167
+ try:
168
+ logger.info(f"Attempting SAM2 load via {method_name} method...")
169
+ result = method_func(model_id)
170
+ if result is not None:
171
+ logger.info(f"SAM2 loaded successfully via {method_name} method")
172
+ return result
173
+ except IndexError as e:
174
+ logger.error(f"SAM2 {method_name} method - IndexError: {e}")
175
+ logger.debug(f"Full traceback:\n{traceback.format_exc()}")
176
+ continue
177
+ except Exception as e:
178
+ logger.error(f"SAM2 {method_name} method failed: {e}")
179
+ continue
180
+
181
+ logger.error("All SAM2 loading methods failed")
182
+ return None
183
+
184
+ def _try_load_sam2_official(self, model_id: str):
185
+ """Try the official from_pretrained method"""
186
  try:
187
  from sam2.sam2_image_predictor import SAM2ImagePredictor
188
+
189
+ # Set environment variables that might help in Spaces
190
+ os.environ['HF_HUB_DISABLE_SYMLINKS'] = '1'
191
+ os.environ['HF_HUB_ENABLE_HF_TRANSFER'] = '0'
192
+
193
  t0 = time.time()
194
+
195
+ # Try with explicit cache directory
196
+ cache_dir = os.path.join(self.checkpoints_dir, "sam2_cache")
197
+ os.makedirs(cache_dir, exist_ok=True)
198
+
199
+ # Log what we're about to do
200
+ logger.debug(f"Calling SAM2ImagePredictor.from_pretrained('{model_id}')")
201
+
202
+ # This is where the IndexError likely happens
203
+ predictor = SAM2ImagePredictor.from_pretrained(
204
+ model_id,
205
+ cache_dir=cache_dir,
206
+ local_files_only=False,
207
+ trust_remote_code=True
208
+ )
209
+
210
  if hasattr(predictor, 'model'):
211
  predictor.model = predictor.model.to(self.device)
212
+
213
  t1 = time.time()
214
+
215
  return LoadedModel(
216
  model=predictor,
217
  model_id=model_id,
 
219
  device=str(self.device),
220
  framework="sam2"
221
  )
 
 
 
 
 
 
 
222
  except Exception as e:
223
+ logger.error(f"Official SAM2 loading failed: {e}")
224
+ raise
 
 
 
 
 
 
 
 
 
225
 
226
def _try_load_sam2_direct(self, model_id: str):
    """Load SAM2 through the generic ``transformers`` auto classes.

    Second-choice strategy: the checkpoint is loaded with ``AutoModel`` and
    wrapped in a small adapter exposing ``set_image`` / ``predict``.

    Args:
        model_id: Hub repository id of the SAM2 checkpoint.

    Returns:
        A ``LoadedModel`` whose ``model`` is the adapter and whose
        ``framework`` is ``"sam2-transformers"``.

    Raises:
        Exception: any failure while importing ``transformers`` or loading
            the model is logged and re-raised to the caller.
    """
    try:
        from transformers import AutoModel, AutoProcessor

        started = time.time()

        # Pick half precision from the device the model is actually moved to,
        # not from whether CUDA merely exists on the host: fp16 weights on a
        # CPU target are slow and unsupported by several ops.
        on_cuda = str(self.device).startswith("cuda")

        # NOTE(review): trust_remote_code=True executes Python shipped inside
        # the Hub repository; only use it with repositories you trust.
        model = AutoModel.from_pretrained(
            model_id,
            trust_remote_code=True,
            torch_dtype=torch.float16 if on_cuda else torch.float32,
        ).to(self.device)

        # The processor is optional; record why it is missing instead of
        # silently swallowing everything (including KeyboardInterrupt).
        try:
            processor = AutoProcessor.from_pretrained(model_id)
        except Exception as proc_err:
            logger.warning(f"SAM2 processor unavailable for {model_id}: {proc_err}")
            processor = None

        finished = time.time()

        class SAM2Wrapper:
            """Minimal adapter exposing ``set_image`` and ``predict``."""

            def __init__(self, model, processor=None):
                self.model = model
                self.processor = processor
                # Defined up front so the attribute always exists.
                self.current_image = None

            def set_image(self, image):
                self.current_image = image

            def predict(self, *args, **kwargs):
                # Forwards straight to the model; the stored image is not
                # injected automatically.
                return self.model(*args, **kwargs)

        wrapped = SAM2Wrapper(model, processor)

        return LoadedModel(
            model=wrapped,
            model_id=model_id,
            load_time=finished - started,
            device=str(self.device),
            framework="sam2-transformers",
        )
    except Exception as e:
        logger.error(f"Direct SAM2 loading failed: {e}")
        raise
273
+
274
def _try_load_sam2_manual(self, model_id: str):
    """Build a placeholder SAM2 predictor as a last resort.

    No weights are loaded. The returned object only mimics the
    ``set_image`` / ``predict`` surface and answers every request with an
    all-ones mask.

    Args:
        model_id: Hub id of the model that could not be loaded; reported as
            ``"<model_id>-fallback"``.

    Returns:
        A ``LoadedModel`` wrapping the placeholder, with ``framework``
        ``"sam2-fallback"`` and a nominal ``load_time`` of 0.1.

    Raises:
        Exception: anything raised during construction is logged and
            re-raised.
    """
    try:
        logger.warning("Using manual SAM2 construction (limited functionality)")

        class DummySAM2:
            """Stand-in predictor that selects the whole frame."""

            def __init__(self, device):
                self.device = device
                self.model = None
                # Defined up front so predict() never depends on set_image()
                # having been called.
                self.current_image = None

            def set_image(self, image):
                self.current_image = image

            def predict(self, point_coords=None, point_labels=None, box=None, **kwargs):
                import numpy as np

                # Fall back to 512x512 when no image was set, or when the
                # stored object exposes no usable ``shape``.
                shape = getattr(self.current_image, "shape", None)
                if shape is not None and len(shape) >= 2:
                    h, w = shape[:2]
                else:
                    h, w = 512, 512

                # NOTE(review): this returns a dict; confirm callers do not
                # expect the (masks, scores, logits) tuple of the real
                # SAM2 image predictor.
                return {
                    'masks': np.ones((1, h, w), dtype=np.float32),
                    'scores': np.array([0.5]),
                    'logits': np.ones((1, h, w), dtype=np.float32),
                }

        dummy = DummySAM2(self.device)

        return LoadedModel(
            model=dummy,
            model_id=f"{model_id}-fallback",
            load_time=0.1,
            device=str(self.device),
            framework="sam2-fallback",
        )
    except Exception as e:
        logger.error(f"Manual SAM2 construction failed: {e}")
        raise
313
+
314
def _load_matanyone_model_safe(self, progress_callback: Optional[Callable] = None):
    """Try each MatAnyOne loading strategy in turn and return the first hit.

    Strategies run in the order official, alternative, fallback. A strategy
    that raises or returns ``None`` is skipped; every failure is logged at
    error level, and an ``IndexError`` also gets its traceback at debug level.

    Args:
        progress_callback: optional callable passed through to each strategy.

    Returns:
        The first non-``None`` strategy result, or ``None`` when all fail.
    """
    strategies = (
        ("official", self._try_load_matanyone_official),
        ("alternative", self._try_load_matanyone_alternative),
        ("fallback", self._try_load_matanyone_fallback),
    )

    for label, loader in strategies:
        try:
            logger.info(f"Attempting MatAnyOne load via {label} method...")
            loaded = loader(progress_callback)
            if loaded is None:
                # Strategy declined without raising; move on to the next one.
                continue
            logger.info(f"MatAnyOne loaded successfully via {label} method")
            return loaded
        except Exception as err:
            if isinstance(err, IndexError):
                logger.error(f"MatAnyOne {label} method - IndexError: {err}")
                logger.debug(f"Full traceback:\n{traceback.format_exc()}")
            else:
                logger.error(f"MatAnyOne {label} method failed: {err}")

    logger.error("All MatAnyOne loading methods failed")
    return None
340
+
341
def _try_load_matanyone_official(self, progress_callback):
    """Load MatAnyOne through the package's own ``InferenceCore``.

    Args:
        progress_callback: optional callable invoked once with
            ``(0.7, message)`` before loading starts.

    Returns:
        A ``LoadedModel`` wrapping the ``InferenceCore`` instance, with
        ``framework`` ``"matanyone"``.

    Exceptions (missing package, constructor errors) propagate to the caller.
    """
    if progress_callback:
        progress_callback(0.7, "Loading MatAnyOne model (official)...")

    from matanyone import InferenceCore

    began = time.time()

    # Keep downloaded weights under the loader's checkpoint directory.
    weights_dir = os.path.join(self.checkpoints_dir, "matanyone_cache")
    os.makedirs(weights_dir, exist_ok=True)

    # NOTE(review): confirm that the installed InferenceCore accepts the
    # repo_id / dtype / cache_dir keywords; a rejected keyword raises
    # TypeError and sends the caller on to the next strategy.
    core_options = {
        "repo_id": "PeiqingYang/MatAnyone",
        "device": self.device,
        "dtype": torch.float32,
        "cache_dir": weights_dir,
    }
    core = InferenceCore(**core_options)

    elapsed = time.time() - began

    return LoadedModel(
        model=core,
        model_id="PeiqingYang/MatAnyone",
        load_time=elapsed,
        device=str(self.device),
        framework="matanyone",
    )
369
+
370
def _try_load_matanyone_alternative(self, progress_callback):
    """Load MatAnyOne via ``transformers.AutoModel`` instead of its own package.

    Args:
        progress_callback: optional callable invoked once with
            ``(0.7, message)`` before loading starts.

    Returns:
        A ``LoadedModel`` whose ``model`` is a thin wrapper exposing
        ``process(image, mask)``, with ``framework``
        ``"matanyone-transformers"``.

    Exceptions from the import or the load propagate to the caller.
    """
    if progress_callback:
        progress_callback(0.7, "Loading MatAnyOne model (alternative)...")

    from transformers import AutoModel

    class MatAnyoneWrapper:
        """Adapter giving the raw model a ``process`` method."""

        def __init__(self, model):
            self.model = model

        def process(self, image, mask):
            return self.model(image, mask)

    began = time.time()
    network = AutoModel.from_pretrained(
        "PeiqingYang/MatAnyone",
        trust_remote_code=True,
        torch_dtype=torch.float32,
    )
    network = network.to(self.device)
    ended = time.time()

    return LoadedModel(
        model=MatAnyoneWrapper(network),
        model_id="PeiqingYang/MatAnyone-alt",
        load_time=ended - began,
        device=str(self.device),
        framework="matanyone-transformers",
    )
401
+
402
def _try_load_matanyone_fallback(self, progress_callback):
    """Return a pass-through stand-in for MatAnyOne.

    No model is loaded; the stand-in's ``process`` hands the incoming mask
    back untouched.

    Args:
        progress_callback: optional callable invoked once with
            ``(0.7, message)``.

    Returns:
        A ``LoadedModel`` with ``framework`` ``"matanyone-fallback"`` and a
        nominal ``load_time`` of 0.1.
    """
    if progress_callback:
        progress_callback(0.7, "Using MatAnyOne fallback...")

    logger.warning("Using fallback MatAnyOne (limited functionality)")

    class FallbackMatAnyone:
        """No-op refiner: the mask goes out exactly as it came in."""

        def __init__(self, device):
            self.device = device

        def process(self, image, mask):
            return mask

    passthrough = FallbackMatAnyone(self.device)

    return LoadedModel(
        model=passthrough,
        model_id="MatAnyone-fallback",
        load_time=0.1,
        device=str(self.device),
        framework="matanyone-fallback",
    )
424
 
 
 
 
425
  def _cleanup_models(self):
426
  if self.sam2_predictor is not None:
427
  del self.sam2_predictor
 
438
  self._cleanup_models()
439
  logger.info("ModelLoader cleanup completed")
440
 
 
 
 
441
  def get_model_info(self) -> Dict[str, Any]:
442
  info = {
443
  'models_loaded': self.loading_stats['models_loaded'],
 
500
 
501
  @property
502
  def models_ready(self) -> bool:
503
+ return self.sam2_predictor is not None or self.matanyone_model is not None