MogensR committed on
Commit
7d67503
·
1 Parent(s): aa315a3

Update models/loaders/model_loader.py

Browse files
Files changed (1) hide show
  1. models/loaders/model_loader.py +68 -34
models/loaders/model_loader.py CHANGED
@@ -26,6 +26,33 @@
26
 
27
  logger = logging.getLogger(__name__)
28
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
29
  # ============================================================================
30
  # MODEL LOADER CLASS
31
  # ============================================================================
@@ -34,14 +61,13 @@ class ModelLoader:
34
  Loads and manages SAM2 and MatAnyOne models.
35
  Tune all model-specific logic/settings here.
36
  """
37
-
38
  def __init__(self, device_mgr: DeviceManager, memory_mgr: MemoryManager):
39
  self.device_manager = device_mgr
40
  self.memory_manager = memory_mgr
41
  self.device = self.device_manager.get_optimal_device()
42
 
43
- self.sam2_predictor = None
44
- self.matanyone_model = None # This is usually InferenceCore
45
 
46
  self.checkpoints_dir = "./checkpoints"
47
  os.makedirs(self.checkpoints_dir, exist_ok=True)
@@ -78,12 +104,13 @@ def load_all_models(self, progress_callback: Optional[Callable] = None, cancel_e
78
  logger.info("Loading SAM2 predictor...")
79
  if progress_callback:
80
  progress_callback(0.1, "Loading SAM2 predictor...")
81
- self.sam2_predictor = self._load_sam2_predictor(progress_callback)
82
 
83
- if self.sam2_predictor is None:
84
  logger.warning("SAM2 loading failed - will use fallback segmentation")
85
  else:
86
- sam2_time = time.time() - start_time
 
87
  self.loading_stats['sam2_load_time'] = sam2_time
88
  logger.info(f"SAM2 loaded in {sam2_time:.2f}s")
89
 
@@ -93,19 +120,20 @@ def load_all_models(self, progress_callback: Optional[Callable] = None, cancel_e
93
  progress_callback(0.6, "Loading MatAnyOne model...")
94
  matanyone_start = time.time()
95
 
96
- self.matanyone_model = self._load_matanyone_model(progress_callback)
97
 
98
- if self.matanyone_model is None:
99
  logger.warning("MatAnyOne loading failed - will use OpenCV refinement")
100
  else:
101
- matanyone_time = time.time() - matanyone_start
 
102
  self.loading_stats['matanyone_load_time'] = matanyone_time
103
  logger.info(f"MatAnyOne loaded in {matanyone_time:.1f}s")
104
 
105
  # Final status
106
  total_time = time.time() - start_time
107
  self.loading_stats['total_load_time'] = total_time
108
- self.loading_stats['models_loaded'] = True
109
 
110
  if progress_callback:
111
  if self.sam2_predictor or self.matanyone_model:
@@ -115,7 +143,7 @@ def load_all_models(self, progress_callback: Optional[Callable] = None, cancel_e
115
 
116
  logger.info(f"Model loading completed in {total_time:.2f}s")
117
 
118
- return self.sam2_predictor, self.matanyone_model
119
 
120
  except Exception as e:
121
  error_msg = f"Model loading failed: {str(e)}"
@@ -132,7 +160,7 @@ def load_all_models(self, progress_callback: Optional[Callable] = None, cancel_e
132
  def _load_sam2_predictor(self, progress_callback: Optional[Callable] = None):
133
  """
134
  Loads SAM2 using the official Hugging Face interface.
135
- Returns: SAM2 predictor object or None
136
  """
137
  model_size = "large"
138
  try:
@@ -161,11 +189,19 @@ def _load_sam2_predictor(self, progress_callback: Optional[Callable] = None):
161
 
162
  try:
163
  from sam2.sam2_image_predictor import SAM2ImagePredictor
 
164
  predictor = SAM2ImagePredictor.from_pretrained(model_id)
165
  if hasattr(predictor, 'model'):
166
  predictor.model = predictor.model.to(self.device)
 
167
  logger.info("SAM2 loaded successfully via official from_pretrained")
168
- return predictor
 
 
 
 
 
 
169
  except ImportError:
170
  logger.error("SAM2 module not found. Install with: pip install sam2")
171
  return None
@@ -179,34 +215,30 @@ def _load_sam2_predictor(self, progress_callback: Optional[Callable] = None):
179
  def _load_matanyone_model(self, progress_callback: Optional[Callable] = None):
180
  """
181
  Loads MatAnyOne using Hugging Face official 'matanyone' package.
182
- Returns: InferenceCore object or None
183
-
184
- ---------- MATANYONE TUNING SECTION ----------
185
- To adjust MatAnyOne settings, change arguments to InferenceCore below!
186
- (e.g., for precision, model variant, device, chunk size, etc.)
187
- ---------------------------------------------
188
  """
189
  try:
190
  if progress_callback:
191
  progress_callback(0.7, "Loading MatAnyOne model...")
192
 
193
- # --- HIGHLIGHT: SET ANY MatAnyOne SETTINGS HERE ---
194
  from matanyone import InferenceCore
195
-
196
- # Example: To set chunk size or custom model repo, add kwargs here.
197
- # See: https://huggingface.co/PeiqingYang/MatAnyone for config options
198
-
199
  matanyone_kwargs = dict(
200
  repo_id="PeiqingYang/MatAnyone", # You can change to any compatible Hugging Face repo
201
  device=self.device, # Device to load on ("cuda" or "cpu")
202
- dtype=torch.float32, # Change to torch.float16 for faster inference on good GPUs
203
  # chunk_size=512, # Optional: for memory tuning on large videos
204
  )
205
-
206
  processor = InferenceCore(**matanyone_kwargs)
 
207
  logger.info("MatAnyOne loaded successfully (InferenceCore)")
208
- return processor
209
-
 
 
 
 
 
210
  except ImportError:
211
  logger.error("MatAnyOne module not found. Install with: pip install matanyone")
212
  return None
@@ -245,9 +277,11 @@ def get_model_info(self) -> Dict[str, Any]:
245
  'loading_stats': self.loading_stats.copy()
246
  }
247
  if self.sam2_predictor is not None:
248
- info['sam2_model_type'] = type(self.sam2_predictor).__name__
 
249
  if self.matanyone_model is not None:
250
- info['matanyone_model_type'] = type(self.matanyone_model).__name__
 
251
  return info
252
 
253
  def get_load_summary(self) -> str:
@@ -258,11 +292,11 @@ def get_load_summary(self) -> str:
258
  total_time = self.loading_stats['total_load_time']
259
  summary = f"Models loaded in {total_time:.1f}s\n"
260
  if self.sam2_predictor:
261
- summary += f"✓ SAM2: {sam2_time:.1f}s\n"
262
  else:
263
  summary += f"✗ SAM2: Failed (using fallback)\n"
264
  if self.matanyone_model:
265
- summary += f"✓ MatAnyOne: {matanyone_time:.1f}s\n"
266
  else:
267
  summary += f"✗ MatAnyOne: Failed (using OpenCV)\n"
268
  summary += f"Device: {self.device}"
@@ -278,7 +312,8 @@ def validate_models(self) -> bool:
278
  try:
279
  has_valid_model = False
280
  if self.sam2_predictor is not None:
281
- if hasattr(self.sam2_predictor, 'set_image') or hasattr(self.sam2_predictor, 'predict'):
 
282
  has_valid_model = True
283
  if self.matanyone_model is not None:
284
  has_valid_model = True
@@ -300,4 +335,3 @@ def models_ready(self) -> bool:
300
  # ============================================================================
301
  # END MODEL LOADER
302
  # ============================================================================
303
-
 
26
 
27
  logger = logging.getLogger(__name__)
28
 
29
+ # ============================================================================
30
+ # LOADED MODEL DATA CONTAINER
31
+ # ============================================================================
32
class LoadedModel:
    """
    Container pairing a loaded model object with descriptive metadata.

    Useful for dashboards, export, analytics, etc.

    Attributes:
        model: The wrapped model object, or None when nothing was loaded.
        model_id: Identifier the model was loaded from (e.g. a repo id).
        load_time: Wall-clock seconds spent loading the model.
        device: Device the model was placed on.
        framework: Short tag naming the model family/framework.
    """

    def __init__(self, model: Any = None, model_id: str = "", load_time: float = 0.0,
                 device: str = "", framework: str = ""):
        self.model = model
        self.model_id = model_id
        self.load_time = load_time
        self.device = device
        self.framework = framework

    @property
    def is_loaded(self) -> bool:
        """Return True when a model object is actually attached."""
        return self.model is not None

    def to_dict(self) -> Dict[str, Any]:
        """
        Return the metadata (never the model itself) as a plain dict.

        The device is stringified so the result stays serializable even when
        a non-string device object was supplied to the constructor.
        """
        return {
            "model_id": self.model_id,
            "framework": self.framework,
            "device": str(self.device),
            "load_time": self.load_time,
            "loaded": self.is_loaded,
        }

    def __repr__(self) -> str:
        # A repr must never raise: fall back to the raw value when load_time
        # is not something the ".2f" format spec accepts (e.g. None).
        try:
            load_time_text = f"{self.load_time:.2f}s"
        except (TypeError, ValueError):
            load_time_text = repr(self.load_time)
        return (
            f"LoadedModel(id={self.model_id}, loaded={self.is_loaded}, "
            f"device={self.device}, framework={self.framework}, "
            f"load_time={load_time_text})"
        )
55
+
56
  # ============================================================================
57
  # MODEL LOADER CLASS
58
  # ============================================================================
 
61
  Loads and manages SAM2 and MatAnyOne models.
62
  Tune all model-specific logic/settings here.
63
  """
 
64
  def __init__(self, device_mgr: DeviceManager, memory_mgr: MemoryManager):
65
  self.device_manager = device_mgr
66
  self.memory_manager = memory_mgr
67
  self.device = self.device_manager.get_optimal_device()
68
 
69
+ self.sam2_predictor = None # LoadedModel instance or None
70
+ self.matanyone_model = None # LoadedModel instance or None
71
 
72
  self.checkpoints_dir = "./checkpoints"
73
  os.makedirs(self.checkpoints_dir, exist_ok=True)
 
104
  logger.info("Loading SAM2 predictor...")
105
  if progress_callback:
106
  progress_callback(0.1, "Loading SAM2 predictor...")
107
+ sam2_loaded = self._load_sam2_predictor(progress_callback)
108
 
109
+ if sam2_loaded is None:
110
  logger.warning("SAM2 loading failed - will use fallback segmentation")
111
  else:
112
+ self.sam2_predictor = sam2_loaded
113
+ sam2_time = self.sam2_predictor.load_time
114
  self.loading_stats['sam2_load_time'] = sam2_time
115
  logger.info(f"SAM2 loaded in {sam2_time:.2f}s")
116
 
 
120
  progress_callback(0.6, "Loading MatAnyOne model...")
121
  matanyone_start = time.time()
122
 
123
+ matanyone_loaded = self._load_matanyone_model(progress_callback)
124
 
125
+ if matanyone_loaded is None:
126
  logger.warning("MatAnyOne loading failed - will use OpenCV refinement")
127
  else:
128
+ self.matanyone_model = matanyone_loaded
129
+ matanyone_time = self.matanyone_model.load_time
130
  self.loading_stats['matanyone_load_time'] = matanyone_time
131
  logger.info(f"MatAnyOne loaded in {matanyone_time:.1f}s")
132
 
133
  # Final status
134
  total_time = time.time() - start_time
135
  self.loading_stats['total_load_time'] = total_time
136
+ self.loading_stats['models_loaded'] = bool(self.sam2_predictor or self.matanyone_model)
137
 
138
  if progress_callback:
139
  if self.sam2_predictor or self.matanyone_model:
 
143
 
144
  logger.info(f"Model loading completed in {total_time:.2f}s")
145
 
146
+ return (self.sam2_predictor, self.matanyone_model)
147
 
148
  except Exception as e:
149
  error_msg = f"Model loading failed: {str(e)}"
 
160
  def _load_sam2_predictor(self, progress_callback: Optional[Callable] = None):
161
  """
162
  Loads SAM2 using the official Hugging Face interface.
163
+ Returns: LoadedModel instance or None
164
  """
165
  model_size = "large"
166
  try:
 
189
 
190
  try:
191
  from sam2.sam2_image_predictor import SAM2ImagePredictor
192
+ t0 = time.time()
193
  predictor = SAM2ImagePredictor.from_pretrained(model_id)
194
  if hasattr(predictor, 'model'):
195
  predictor.model = predictor.model.to(self.device)
196
+ t1 = time.time()
197
  logger.info("SAM2 loaded successfully via official from_pretrained")
198
+ return LoadedModel(
199
+ model=predictor,
200
+ model_id=model_id,
201
+ load_time=t1-t0,
202
+ device=str(self.device),
203
+ framework="sam2"
204
+ )
205
  except ImportError:
206
  logger.error("SAM2 module not found. Install with: pip install sam2")
207
  return None
 
215
  def _load_matanyone_model(self, progress_callback: Optional[Callable] = None):
216
  """
217
  Loads MatAnyOne using Hugging Face official 'matanyone' package.
218
+ Returns: LoadedModel instance or None
 
 
 
 
 
219
  """
220
  try:
221
  if progress_callback:
222
  progress_callback(0.7, "Loading MatAnyOne model...")
223
 
 
224
  from matanyone import InferenceCore
225
+ t0 = time.time()
 
 
 
226
  matanyone_kwargs = dict(
227
  repo_id="PeiqingYang/MatAnyone", # You can change to any compatible Hugging Face repo
228
  device=self.device, # Device to load on ("cuda" or "cpu")
229
+ dtype=torch.float32, # Or torch.float16 for fast, but only for GPUs with good fp16
230
  # chunk_size=512, # Optional: for memory tuning on large videos
231
  )
 
232
  processor = InferenceCore(**matanyone_kwargs)
233
+ t1 = time.time()
234
  logger.info("MatAnyOne loaded successfully (InferenceCore)")
235
+ return LoadedModel(
236
+ model=processor,
237
+ model_id=matanyone_kwargs["repo_id"],
238
+ load_time=t1-t0,
239
+ device=str(self.device),
240
+ framework="matanyone"
241
+ )
242
  except ImportError:
243
  logger.error("MatAnyOne module not found. Install with: pip install matanyone")
244
  return None
 
277
  'loading_stats': self.loading_stats.copy()
278
  }
279
  if self.sam2_predictor is not None:
280
+ info['sam2_model_type'] = type(self.sam2_predictor.model).__name__
281
+ info['sam2_metadata'] = self.sam2_predictor.to_dict()
282
  if self.matanyone_model is not None:
283
+ info['matanyone_model_type'] = type(self.matanyone_model.model).__name__
284
+ info['matanyone_metadata'] = self.matanyone_model.to_dict()
285
  return info
286
 
287
  def get_load_summary(self) -> str:
 
292
  total_time = self.loading_stats['total_load_time']
293
  summary = f"Models loaded in {total_time:.1f}s\n"
294
  if self.sam2_predictor:
295
+ summary += f"✓ SAM2: {sam2_time:.1f}s (ID: {self.sam2_predictor.model_id})\n"
296
  else:
297
  summary += f"✗ SAM2: Failed (using fallback)\n"
298
  if self.matanyone_model:
299
+ summary += f"✓ MatAnyOne: {matanyone_time:.1f}s (ID: {self.matanyone_model.model_id})\n"
300
  else:
301
  summary += f"✗ MatAnyOne: Failed (using OpenCV)\n"
302
  summary += f"Device: {self.device}"
 
312
  try:
313
  has_valid_model = False
314
  if self.sam2_predictor is not None:
315
+ model = self.sam2_predictor.model
316
+ if hasattr(model, 'set_image') or hasattr(model, 'predict'):
317
  has_valid_model = True
318
  if self.matanyone_model is not None:
319
  has_valid_model = True
 
335
  # ============================================================================
336
  # END MODEL LOADER
337
  # ============================================================================