MogensR committed
Commit 462ff09 · Parent(s): 0bec751

Update model_loader.py

Files changed (1):
  1. model_loader.py +81 -56
model_loader.py CHANGED
@@ -3,6 +3,10 @@
 Handles loading and validation of SAM2 and MatAnyone AI models
 """

+# ============================================================================ #
+# IMPORTS AND DEPENDENCIES
+# ============================================================================ #
+
 import os
 import gc
 import time
@@ -13,7 +17,6 @@
 from pathlib import Path

 import torch
-import hydra
 import gradio as gr
 from omegaconf import DictConfig, OmegaConf

@@ -24,9 +27,14 @@

 logger = logging.getLogger(__name__)

+# ============================================================================ #
+# MODEL LOADER CLASS - MAIN INTERFACE
+# ============================================================================ #
+
 class ModelLoader:
     """
     Comprehensive model loading and management for SAM2 and MatAnyone
+    Handles automatic config detection, multiple fallback strategies, and memory management
     """

     def __init__(self, device_mgr: device_manager.DeviceManager, memory_mgr: memory_manager.MemoryManager):
@@ -40,7 +48,6 @@ def __init__(self, device_mgr: device_manager.DeviceManager, memory_mgr: memory_
         self.matanyone_core = None

         # Configuration paths
-        self.configs_dir = os.path.abspath("Configs")
         self.checkpoints_dir = "./checkpoints"
         os.makedirs(self.checkpoints_dir, exist_ok=True)

@@ -55,6 +62,10 @@ def __init__(self, device_mgr: device_manager.DeviceManager, memory_mgr: memory_

         logger.info(f"ModelLoader initialized for device: {self.device}")
         self._apply_gradio_patch()
+
+    # ============================================================================ #
+    # INITIALIZATION AND SETUP
+    # ============================================================================ #

     def _apply_gradio_patch(self):
         """Apply Gradio schema monkey patch to prevent validation errors"""
@@ -75,7 +86,11 @@ def patched_get_config(self):

         except (ImportError, AttributeError) as e:
             logger.warning(f"Could not apply Gradio monkey patch: {e}")
-
+
+    # ============================================================================ #
+    # MAIN MODEL LOADING ORCHESTRATION
+    # ============================================================================ #
+
     def load_all_models(self, progress_callback: Optional[callable] = None, cancel_event=None) -> Tuple[Any, Any]:
         """
         Load both SAM2 and MatAnyone models with comprehensive error handling
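Editor's note: the `_apply_gradio_patch` hunk above wraps a Gradio method so that schema-validation failures are logged instead of raised. The commit does not show which method is wrapped; the sketch below only illustrates the general monkey-patch pattern, and the patched attribute (`gr.Blocks.get_config_file`) is a purely hypothetical target, not taken from the commit.

    # Hedged sketch of the monkey-patch pattern behind _apply_gradio_patch.
    # The patched attribute below is an illustrative assumption, not from the commit.
    import logging
    import gradio as gr

    logger = logging.getLogger(__name__)

    def apply_gradio_patch():
        try:
            original = gr.Blocks.get_config_file  # hypothetical patch target

            def patched_get_config(self):
                try:
                    return original(self)
                except Exception as e:  # swallow schema/validation errors
                    logger.warning(f"Gradio config generation failed: {e}")
                    return {}

            gr.Blocks.get_config_file = patched_get_config
        except (ImportError, AttributeError) as e:
            logger.warning(f"Could not apply Gradio monkey patch: {e}")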
@@ -152,83 +167,69 @@ def load_all_models(self, progress_callback: Optional[callable] = None, cancel_e
                 progress_callback(1.0, f"Error: {error_msg}")

             return None, None
-
+
+    # ============================================================================ #
+    # SAM2 MODEL LOADING - AUTOMATIC CONFIG DETECTION
+    # ============================================================================ #
+
     def _load_sam2_predictor(self, progress_callback: Optional[callable] = None):
         """
-        Load SAM2 predictor with multiple fallback strategies
+        Load SAM2 predictor with automatic config detection - no manual config files needed
+        Uses build_sam2_video_predictor for automatic configuration based on checkpoint filename

         Args:
             progress_callback: Progress update callback

         Returns:
-            SAM2ImagePredictor or None
+            SAM2VideoPredictor or None
         """
-        if not os.path.isdir(self.configs_dir):
-            logger.warning(f"SAM2 Configs directory not found at '{self.configs_dir}', trying fallback loading")
-
-        def try_load_sam2(config_name_with_yaml: str, checkpoint_name: str):
-            """Attempt to load SAM2 with given config and checkpoint"""
+        def try_load_sam2_auto(repo_id: str, filename: str, model_name: str):
+            """Attempt to load SAM2 with automatic config detection"""
             try:
-                checkpoint_path = os.path.join(self.checkpoints_dir, checkpoint_name)
+                checkpoint_path = os.path.join(self.checkpoints_dir, filename)
                 logger.info(f"Attempting SAM2 checkpoint: {checkpoint_path}")

                 # Download checkpoint if needed
                 if not os.path.exists(checkpoint_path):
-                    logger.info(f"Downloading {checkpoint_name} from Hugging Face Hub...")
+                    logger.info(f"Downloading {filename} from Hugging Face Hub...")
                     if progress_callback:
-                        progress_callback(0.2, f"Downloading {checkpoint_name}...")
+                        progress_callback(0.2, f"Downloading {filename}...")

                     try:
                         from huggingface_hub import hf_hub_download
-                        repo = f"facebook/{config_name_with_yaml.replace('.yaml','')}"
                         checkpoint_path = hf_hub_download(
-                            repo_id=repo,
-                            filename=checkpoint_name,
+                            repo_id=repo_id,
+                            filename=filename,
                             cache_dir=self.checkpoints_dir,
                             local_dir_use_symlinks=False
                         )
                         logger.info(f"Download complete: {checkpoint_path}")
                     except Exception as download_error:
-                        logger.warning(f"Failed to download {checkpoint_name}: {download_error}")
+                        logger.warning(f"Failed to download {filename}: {download_error}")
                         return None

-                # Reset and initialize Hydra if configs directory exists
-                if os.path.isdir(self.configs_dir):
-                    if hydra.core.global_hydra.GlobalHydra.instance().is_initialized():
-                        hydra.core.global_hydra.GlobalHydra.instance().clear()
-
-                    hydra.initialize(
-                        version_base=None,
-                        config_path=os.path.relpath(self.configs_dir),
-                        job_name=f"sam2_load_{int(time.time())}"
-                    )
-
-                # Build SAM2 model
-                config_name = config_name_with_yaml.replace(".yaml", "")
                 if progress_callback:
-                    progress_callback(0.4, f"Building {config_name}...")
+                    progress_callback(0.4, f"Building SAM2 {model_name}...")

-                from sam2.build_sam import build_sam2
-                from sam2.sam2_image_predictor import SAM2ImagePredictor
+                # Use automatic config detection - NO manual config needed!
+                from sam2.build_sam import build_sam2_video_predictor

-                sam2_model = build_sam2(config_name, checkpoint_path)
-                sam2_model.to(self.device)
-                predictor = SAM2ImagePredictor(sam2_model)
+                predictor = build_sam2_video_predictor(checkpoint_path, device=self.device)

-                logger.info(f"SAM2 {config_name} loaded successfully on {self.device}")
+                logger.info(f"SAM2 {model_name} loaded successfully on {self.device}")
                 return predictor

             except Exception as e:
-                error_msg = f"Failed to load SAM2 {config_name_with_yaml}: {e}"
+                error_msg = f"Failed to load SAM2 {model_name}: {e}"
                 logger.warning(error_msg)
                 return None

-        # Try different SAM2 model sizes based on device capabilities
+        # Try different SAM2 models with automatic config detection
         model_attempts = [
-            ("sam2_hiera_large.yaml", "sam2_hiera_large.pt"),
-            ("sam2_hiera_base_plus.yaml", "sam2_hiera_base_plus.pt"),
-            ("sam2_hiera_small.yaml", "sam2_hiera_small.pt"),
-            ("sam2_hiera_tiny.yaml", "sam2_hiera_tiny.pt")
+            ("facebook/sam2-hiera-large", "sam2_hiera_large.pt", "hiera_large"),
+            ("facebook/sam2-hiera-base-plus", "sam2_hiera_base_plus.pt", "hiera_base_plus"),
+            ("facebook/sam2-hiera-small", "sam2_hiera_small.pt", "hiera_small"),
+            ("facebook/sam2-hiera-tiny", "sam2_hiera_tiny.pt", "hiera_tiny")
         ]

         # Prioritize model size based on device memory
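Editor's note: a minimal, standalone sketch of the download-then-build flow this hunk introduces. The `hf_hub_download` call and the repo/filename pair are taken from the diff's own `model_attempts` list; the `build_sam2_video_predictor(checkpoint_path, device=...)` call simply mirrors how the commit invokes it and its signature may vary across sam2 releases, so treat it as illustrative rather than canonical.

    # Sketch: fetch the smallest SAM2 checkpoint, then build a predictor the way this commit does.
    import os
    import torch
    from huggingface_hub import hf_hub_download

    checkpoints_dir = "./checkpoints"
    os.makedirs(checkpoints_dir, exist_ok=True)

    ckpt = hf_hub_download(
        repo_id="facebook/sam2-hiera-tiny",   # repo/filename pair from model_attempts
        filename="sam2_hiera_tiny.pt",
        cache_dir=checkpoints_dir,
    )

    device = "cuda" if torch.cuda.is_available() else "cpu"

    from sam2.build_sam import build_sam2_video_predictor
    predictor = build_sam2_video_predictor(ckpt, device=device)  # mirrors the commit's call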
@@ -242,14 +243,18 @@ def try_load_sam2(config_name_with_yaml: str, checkpoint_name: str):
         except Exception as e:
             logger.warning(f"Could not determine device memory: {e}")

-        for config_yaml, checkpoint_pt in model_attempts:
-            predictor = try_load_sam2(config_yaml, checkpoint_pt)
+        for repo_id, filename, model_name in model_attempts:
+            predictor = try_load_sam2_auto(repo_id, filename, model_name)
             if predictor is not None:
                 return predictor

         logger.error("All SAM2 model loading attempts failed")
         return None
-
+
+    # ============================================================================ #
+    # MATANYONE MODEL LOADING - MULTIPLE STRATEGIES
+    # ============================================================================ #
+
     def _load_matanyone_model(self, progress_callback: Optional[callable] = None):
         """
         Load MatAnyone model with multiple import strategies
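Editor's note: the "prioritize model size based on device memory" step is referenced but not shown in this hunk. Below is a hedged sketch of one way to do it with standard torch calls; the 4 GB / 8 GB thresholds are illustrative assumptions, not values from the commit.

    # Sketch: trim or reorder model_attempts based on available GPU memory.
    import torch

    model_attempts = [
        ("facebook/sam2-hiera-large", "sam2_hiera_large.pt", "hiera_large"),
        ("facebook/sam2-hiera-base-plus", "sam2_hiera_base_plus.pt", "hiera_base_plus"),
        ("facebook/sam2-hiera-small", "sam2_hiera_small.pt", "hiera_small"),
        ("facebook/sam2-hiera-tiny", "sam2_hiera_tiny.pt", "hiera_tiny"),
    ]

    if torch.cuda.is_available():
        total_gb = torch.cuda.get_device_properties(0).total_memory / 1024**3
        if total_gb < 4:
            model_attempts = model_attempts[3:]   # tiny only (illustrative threshold)
        elif total_gb < 8:
            model_attempts = model_attempts[2:]   # small and tiny (illustrative threshold)
    else:
        model_attempts = list(reversed(model_attempts))  # on CPU, try the smallest first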
@@ -284,7 +289,11 @@ def _load_matanyone_model(self, progress_callback: Optional[callable] = None):

         logger.error("All MatAnyone loading strategies failed")
         return None, None
-
+
+    # ============================================================================ #
+    # MATANYONE LOADING STRATEGIES
+    # ============================================================================ #
+
     def _load_matanyone_strategy_1(self):
         """MatAnyone loading strategy 1: Direct model import"""
         from matanyone.model.matanyone import MatAnyOne
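Editor's note: `_load_matanyone_model` works through several loading strategies until one succeeds. A minimal sketch of that fallback pattern follows; the `strategies` entries are placeholders standing in for the commit's `_load_matanyone_strategy_1` through `_load_matanyone_strategy_4` methods.

    # Sketch: try each (name, callable) strategy in order; first success wins.
    import logging

    logger = logging.getLogger(__name__)

    def load_with_fallbacks(strategies):
        for name, strategy in strategies:
            try:
                model, core = strategy()
                if model is not None:
                    logger.info(f"MatAnyone loaded via {name}")
                    return model, core
            except Exception as e:  # ImportError, missing weights, API drift, ...
                logger.warning(f"MatAnyone strategy {name} failed: {e}")
        logger.error("All MatAnyone loading strategies failed")
        return None, None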
@@ -351,7 +360,11 @@ def _load_matanyone_strategy_4(self):
         model = load_model_from_hub(model_path, device=self.device)

         return model, model  # Return same object for both
-
+
+    # ============================================================================ #
+    # MODEL MANAGEMENT AND CLEANUP
+    # ============================================================================ #
+
     def _cleanup_models(self):
         """Clean up loaded models and free memory"""
         if self.sam2_predictor is not None:
@@ -372,6 +385,15 @@ def _cleanup_models(self):

         logger.debug("Model cleanup completed")

+    def cleanup(self):
+        """Clean up all resources"""
+        self._cleanup_models()
+        logger.info("ModelLoader cleanup completed")
+
+    # ============================================================================ #
+    # MODEL INFORMATION AND STATUS
+    # ============================================================================ #
+
     def get_model_info(self) -> Dict[str, Any]:
         """
         Get information about loaded models
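Editor's note: `_cleanup_models` and the relocated `cleanup` method drop model references and free GPU memory. A minimal sketch of the torch/gc idiom involved; `sam2_predictor` and `matanyone_core` are attribute names from this diff, while `matanyone_model` is assumed from the loader's (model, core) return pair.

    # Sketch: drop references, collect Python garbage, release cached CUDA memory.
    import gc
    import torch

    def release_models(loader):
        loader.sam2_predictor = None
        loader.matanyone_model = None   # attribute name assumed, not shown in this diff
        loader.matanyone_core = None
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()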
@@ -420,7 +442,11 @@ def get_load_summary(self) -> str:
         summary += f"Device: {self.device}"

         return summary
-
+
+    # ============================================================================ #
+    # MODEL VALIDATION AND TESTING
+    # ============================================================================ #
+
     def validate_models(self) -> bool:
         """
         Validate that models are properly loaded and functional
@@ -444,7 +470,11 @@ def validate_models(self) -> bool:
         except Exception as e:
             logger.error(f"Model validation failed: {e}")
             return False
-
+
+    # ============================================================================ #
+    # UTILITY METHODS
+    # ============================================================================ #
+
     def reload_models(self, progress_callback: Optional[callable] = None) -> Tuple[Any, Any]:
         """
         Reload all models (useful for error recovery)
@@ -461,11 +491,6 @@ def reload_models(self, progress_callback: Optional[callable] = None) -> Tuple[A

         return self.load_all_models(progress_callback)

-    def cleanup(self):
-        """Clean up all resources"""
-        self._cleanup_models()
-        logger.info("ModelLoader cleanup completed")
-
     @property
     def models_ready(self) -> bool:
         """Check if all models are loaded and ready"""