MogensR committed on
Commit 4fc49fc · 1 Parent(s): 7d67503

Update models/loaders/model_loader.py

Files changed (1)
  1. models/loaders/model_loader.py +23 -28
models/loaders/model_loader.py CHANGED
@@ -5,9 +5,6 @@
 (Modern version for BackgroundFX Pro – only edit this file for model loading logic)
 """
 
-# ============================================================================
-# IMPORTS AND DEPENDENCIES
-# ============================================================================
 import os
 import gc
 import sys
@@ -19,7 +16,6 @@
 
 import torch
 
-# Modular dependencies (adapt as your structure changes)
 from core.exceptions import ModelLoadingError
 from utils.hardware.device_manager import DeviceManager
 from utils.system.memory_manager import MemoryManager
@@ -30,10 +26,6 @@
 # LOADED MODEL DATA CONTAINER
 # ============================================================================
 class LoadedModel:
-    """
-    Tracks loaded model + metadata.
-    Useful for dashboards, export, analytics, etc.
-    """
     def __init__(self, model=None, model_id: str = "", load_time: float = 0.0, device: str = "", framework: str = ""):
         self.model = model
         self.model_id = model_id
@@ -57,10 +49,6 @@ def __repr__(self):
 # MODEL LOADER CLASS
 # ============================================================================
 class ModelLoader:
-    """
-    Loads and manages SAM2 and MatAnyOne models.
-    Tune all model-specific logic/settings here.
-    """
     def __init__(self, device_mgr: DeviceManager, memory_mgr: MemoryManager):
         self.device_manager = device_mgr
         self.memory_manager = memory_mgr
@@ -86,10 +74,6 @@ def __init__(self, device_mgr: DeviceManager, memory_mgr: MemoryManager):
     # MAIN LOADING FUNCTION (ORCHESTRATION)
     # ============================================================================
     def load_all_models(self, progress_callback: Optional[Callable] = None, cancel_event=None) -> Tuple[Any, Any]:
-        """
-        Loads both SAM2 and MatAnyOne models with error handling.
-        Returns: (sam2_predictor, matanyone_model)
-        """
         start_time = time.time()
         self.loading_stats['loading_attempts'] += 1
 
@@ -100,6 +84,9 @@ def load_all_models(self, progress_callback: Optional[Callable] = None, cancel_e
 
         self._cleanup_models()
 
+        # --- DIAG: Log device and model selection step
+        logger.info(f"Device for models: {self.device}")
+
         # Load SAM2 first
         logger.info("Loading SAM2 predictor...")
         if progress_callback:
@@ -158,10 +145,6 @@ def load_all_models(self, progress_callback: Optional[Callable] = None, cancel_e
     # SAM2 LOADING (OFFICIAL FROM_PRETRAINED)
    # ============================================================================
     def _load_sam2_predictor(self, progress_callback: Optional[Callable] = None):
-        """
-        Loads SAM2 using the official Hugging Face interface.
-        Returns: LoadedModel instance or None
-        """
         model_size = "large"
         try:
             if hasattr(self.device_manager, 'get_device_memory_gb'):
@@ -183,6 +166,7 @@ def _load_sam2_predictor(self, progress_callback: Optional[Callable] = None):
                 "large": "facebook/sam2.1-hiera-large"
             }
             model_id = model_map.get(model_size, model_map["large"])
+            logger.info(f"[DIAG] About to load SAM2 model_id: {model_id} on device {self.device}")
 
             if progress_callback:
                 progress_callback(0.3, f"Loading SAM2 {model_size} model...")
@@ -191,6 +175,8 @@ def _load_sam2_predictor(self, progress_callback: Optional[Callable] = None):
             from sam2.sam2_image_predictor import SAM2ImagePredictor
             t0 = time.time()
             predictor = SAM2ImagePredictor.from_pretrained(model_id)
+            logger.info(f"[DIAG] SAM2 predictor instance type: {type(predictor)}")
+            # If this fails, it's likely a missing model or bad download
             if hasattr(predictor, 'model'):
                 predictor.model = predictor.model.to(self.device)
             t1 = time.time()
@@ -202,21 +188,22 @@ def _load_sam2_predictor(self, progress_callback: Optional[Callable] = None):
                 device=str(self.device),
                 framework="sam2"
             )
+        except IndexError as e:
+            logger.error(f"SAM2 IndexError: {e}. (Did the model download fail? Wrong model_id?)")
+            logger.error(traceback.format_exc())
+            return None
         except ImportError:
             logger.error("SAM2 module not found. Install with: pip install sam2")
             return None
         except Exception as e:
             logger.error(f"SAM2 loading failed: {e}")
+            logger.error(traceback.format_exc())
             return None
 
     # ============================================================================
     # MATANYONE LOADING (OFFICIAL INFERENCECORE)
     # ============================================================================
     def _load_matanyone_model(self, progress_callback: Optional[Callable] = None):
-        """
-        Loads MatAnyOne using Hugging Face official 'matanyone' package.
-        Returns: LoadedModel instance or None
-        """
         try:
             if progress_callback:
                 progress_callback(0.7, "Loading MatAnyOne model...")
@@ -224,12 +211,14 @@ def _load_matanyone_model(self, progress_callback: Optional[Callable] = None):
             from matanyone import InferenceCore
             t0 = time.time()
             matanyone_kwargs = dict(
-                repo_id="PeiqingYang/MatAnyone", # You can change to any compatible Hugging Face repo
-                device=self.device, # Device to load on ("cuda" or "cpu")
-                dtype=torch.float32, # Or torch.float16 for fast, but only for GPUs with good fp16
-                # chunk_size=512, # Optional: for memory tuning on large videos
+                repo_id="PeiqingYang/MatAnyone",
+                device=self.device,
+                dtype=torch.float32,
+                # chunk_size=512,
             )
+            logger.info(f"[DIAG] About to load MatAnyOne from repo: {matanyone_kwargs['repo_id']} on device {self.device}")
             processor = InferenceCore(**matanyone_kwargs)
+            logger.info(f"[DIAG] MatAnyOne processor type: {type(processor)}")
             t1 = time.time()
             logger.info("MatAnyOne loaded successfully (InferenceCore)")
             return LoadedModel(
@@ -239,11 +228,16 @@ def _load_matanyone_model(self, progress_callback: Optional[Callable] = None):
                 device=str(self.device),
                 framework="matanyone"
             )
+        except IndexError as e:
+            logger.error(f"MatAnyOne IndexError: {e}. (Did the model download fail? Wrong repo_id?)")
+            logger.error(traceback.format_exc())
+            return None
         except ImportError:
             logger.error("MatAnyOne module not found. Install with: pip install matanyone")
             return None
         except Exception as e:
             logger.error(f"MatAnyOne loading failed: {e}")
+            logger.error(traceback.format_exc())
             return None
 
     # ============================================================================
@@ -335,3 +329,4 @@ def models_ready(self) -> bool:
 # ============================================================================
 # END MODEL LOADER
 # ============================================================================
+
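For reference, a minimal sketch of how this loader might be driven from application code; it is not part of the commit. It assumes DeviceManager() and MemoryManager() need no constructor arguments and that load_all_models() returns the (sam2_predictor, matanyone_model) pair described by the removed docstring, raising the imported ModelLoadingError on unrecoverable failures; neither assumption is confirmed by this diff.

# Hypothetical driver for ModelLoader (illustration only, not repository code).
import logging

from core.exceptions import ModelLoadingError
from models.loaders.model_loader import ModelLoader
from utils.hardware.device_manager import DeviceManager
from utils.system.memory_manager import MemoryManager

logging.basicConfig(level=logging.INFO)  # surfaces the [DIAG] log lines added in this commit

def report(fraction: float, message: str) -> None:
    # Same callback shape the loader uses internally, e.g. progress_callback(0.3, "Loading SAM2 large model...")
    print(f"[{fraction:4.0%}] {message}")

loader = ModelLoader(DeviceManager(), MemoryManager())
try:
    sam2_predictor, matanyone_model = loader.load_all_models(progress_callback=report)
except ModelLoadingError as exc:
    raise SystemExit(f"Model loading failed: {exc}")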
 
 
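Downstream use of the SAM2 side is not shown in this commit. A minimal sketch, assuming the standard sam2 SAM2ImagePredictor interface (from_pretrained, set_image, predict) and an invented point prompt:

# Hypothetical single-frame segmentation with the same model_id the loader uses.
import numpy as np
import torch
from sam2.sam2_image_predictor import SAM2ImagePredictor

predictor = SAM2ImagePredictor.from_pretrained("facebook/sam2.1-hiera-large")

frame = np.zeros((720, 1280, 3), dtype=np.uint8)  # stand-in for a real RGB video frame
with torch.inference_mode():
    predictor.set_image(frame)
    masks, scores, _ = predictor.predict(
        point_coords=np.array([[640, 360]]),  # one made-up foreground click
        point_labels=np.array([1]),
        multimask_output=True,
    )
best_mask = masks[int(np.argmax(scores))]  # keep the highest-scoring proposal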
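The inline comment dropped from the MatAnyOne kwargs ("Or torch.float16 for fast, but only for GPUs with good fp16") could also be made explicit. A small sketch of one way to pick the dtype, assuming the device is expressible as a string such as "cuda" or "cpu"; again, not part of the commit:

import torch

def pick_matanyone_dtype(device) -> torch.dtype:
    # Half precision pays off on CUDA GPUs with solid fp16 support
    # (compute capability >= 7.0); everything else stays in float32.
    if str(device).startswith("cuda") and torch.cuda.is_available():
        major, _minor = torch.cuda.get_device_capability()
        if major >= 7:
            return torch.float16
    return torch.float32

# e.g. inside _load_matanyone_model: dtype=pick_matanyone_dtype(self.device)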