MogensR committed on
Commit
2cd1a25
·
1 Parent(s): d51cab4

Update models/loaders/sam2_loader.py

Browse files
Files changed (1) hide show
  1. models/loaders/sam2_loader.py +68 -17
models/loaders/sam2_loader.py CHANGED
@@ -3,7 +3,7 @@
3
 
4
  """
5
  SAM2 Loader + Guarded Predictor Adapter (VRAM-friendly, shape-safe, thread-safe, PyTorch 2.x)
6
- - Official HF load: SAM2ImagePredictor.from_pretrained(...)
7
  - Never assigns predictor.device (read-only) — moves .model to device instead
8
  - Accepts RGB/BGR, float/uint8; strips alpha; optional BGR→RGB via env
9
  - Downscale ladder on set_image(); upsample masks back to original H,W
@@ -244,6 +244,12 @@ def __init__(self, device: str = "cuda", cache_dir: str = "./checkpoints/sam2_ca
244
  self.load_time = 0.0
245
 
246
  def _determine_optimal_size(self) -> str:
 
 
 
 
 
 
247
  try:
248
  if torch.cuda.is_available():
249
  props = torch.cuda.get_device_properties(0)
@@ -260,11 +266,12 @@ def load(self, model_size: str = "auto") -> Optional[_SAM2Adapter]:
260
  if model_size == "auto":
261
  model_size = self._determine_optimal_size()
262
 
 
263
  model_map = {
264
- "tiny": "facebook/sam2.1-hiera-tiny",
265
- "small": "facebook/sam2.1-hiera-small",
266
- "base": "facebook/sam2.1-hiera-base-plus",
267
- "large": "facebook/sam2.1-hiera-large",
268
  }
269
  self.model_id = model_map.get(model_size, model_map["tiny"])
270
  logger.info(f"Loading SAM2 model: {self.model_id} (device={self.device})")
@@ -288,17 +295,61 @@ def load(self, model_size: str = "auto") -> Optional[_SAM2Adapter]:
288
  return None
289
 
290
  def _load_official(self):
291
- from sam2.sam2_image_predictor import SAM2ImagePredictor
292
- predictor = SAM2ImagePredictor.from_pretrained(
293
- self.model_id,
294
- cache_dir=self.cache_dir,
295
- local_files_only=False,
296
- trust_remote_code=True,
297
- )
298
- if hasattr(predictor, "model"):
299
- predictor.model = predictor.model.to(self.device)
300
- predictor.model.eval()
301
- return predictor
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
302
 
303
  def _load_fallback(self):
304
  class FallbackSAM2:
@@ -360,4 +411,4 @@ def get_info(self) -> Dict[str, Any]:
360
  m = out["masks"]
361
  print("Masks:", m.shape, m.dtype, m.min(), m.max())
362
  cv2.imwrite("sam2_mask0.png", (np.clip(m[0], 0, 1) * 255).astype(np.uint8))
363
- print("Wrote sam2_mask0.png")
 
3
 
4
  """
5
  SAM2 Loader + Guarded Predictor Adapter (VRAM-friendly, shape-safe, thread-safe, PyTorch 2.x)
6
+ - Uses traditional build_sam2 method with HF hub downloads for the original SAM 2 (non-2.1) weights
7
  - Never assigns predictor.device (read-only) — moves .model to device instead
8
  - Accepts RGB/BGR, float/uint8; strips alpha; optional BGR→RGB via env
9
  - Downscale ladder on set_image(); upsample masks back to original H,W
 
244
  self.load_time = 0.0
245
 
246
  def _determine_optimal_size(self) -> str:
247
+ # Check environment variable first
248
+ env_size = os.environ.get("USE_SAM2", "").lower()
249
+ if env_size in ["tiny", "small", "base", "large"]:
250
+ logger.info(f"Using SAM2 size from environment: {env_size}")
251
+ return env_size
252
+
253
  try:
254
  if torch.cuda.is_available():
255
  props = torch.cuda.get_device_properties(0)
 
266
  if model_size == "auto":
267
  model_size = self._determine_optimal_size()
268
 
269
+ # Use original SAM2 model names (without .1) for compatibility
270
  model_map = {
271
+ "tiny": "facebook/sam2-hiera-tiny",
272
+ "small": "facebook/sam2-hiera-small",
273
+ "base": "facebook/sam2-hiera-base-plus",
274
+ "large": "facebook/sam2-hiera-large",
275
  }
276
  self.model_id = model_map.get(model_size, model_map["tiny"])
277
  logger.info(f"Loading SAM2 model: {self.model_id} (device={self.device})")
 
295
  return None
296
 
297
  def _load_official(self):
298
+ try:
299
+ from huggingface_hub import hf_hub_download
300
+ from sam2.build_sam import build_sam2
301
+ from sam2.sam2_image_predictor import SAM2ImagePredictor
302
+ except ImportError as e:
303
+ logger.error(f"Failed to import SAM2 components: {e}")
304
+ return None
305
+
306
+ # Map model IDs to config files and checkpoint names
307
+ config_map = {
308
+ "facebook/sam2-hiera-tiny": ("sam2_hiera_t.yaml", "sam2_hiera_tiny.pt"),
309
+ "facebook/sam2-hiera-small": ("sam2_hiera_s.yaml", "sam2_hiera_small.pt"),
310
+ "facebook/sam2-hiera-base-plus": ("sam2_hiera_b+.yaml", "sam2_hiera_base_plus.pt"),
311
+ "facebook/sam2-hiera-large": ("sam2_hiera_l.yaml", "sam2_hiera_large.pt"),
312
+ }
313
+
314
+ config_file, checkpoint_file = config_map.get(self.model_id, (None, None))
315
+ if not config_file:
316
+ raise ValueError(f"Unknown model: {self.model_id}")
317
+
318
+ try:
319
+ # Download the checkpoint from HuggingFace
320
+ logger.info(f"Downloading checkpoint: {checkpoint_file}")
321
+ checkpoint_path = hf_hub_download(
322
+ repo_id=self.model_id,
323
+ filename=checkpoint_file,
324
+ cache_dir=self.cache_dir,
325
+ local_files_only=False
326
+ )
327
+ logger.info(f"Checkpoint downloaded to: {checkpoint_path}")
328
+
329
+ # Also download the config file if needed
330
+ config_path = hf_hub_download(
331
+ repo_id=self.model_id,
332
+ filename=config_file,
333
+ cache_dir=self.cache_dir,
334
+ local_files_only=False
335
+ )
336
+ logger.info(f"Config downloaded to: {config_path}")
337
+
338
+ # Build the model using the traditional method
339
+ sam2_model = build_sam2(config_path, checkpoint_path, device=self.device)
340
+ predictor = SAM2ImagePredictor(sam2_model)
341
+
342
+ # Ensure model is on the correct device and in eval mode
343
+ if hasattr(predictor, "model"):
344
+ predictor.model = predictor.model.to(self.device)
345
+ predictor.model.eval()
346
+
347
+ return predictor
348
+
349
+ except Exception as e:
350
+ logger.error(f"Error loading SAM2 model: {e}")
351
+ logger.debug(traceback.format_exc())
352
+ return None
353
 
354
  def _load_fallback(self):
355
  class FallbackSAM2:
 
411
  m = out["masks"]
412
  print("Masks:", m.shape, m.dtype, m.min(), m.max())
413
  cv2.imwrite("sam2_mask0.png", (np.clip(m[0], 0, 1) * 255).astype(np.uint8))
414
+ print("Wrote sam2_mask0.png")