Update models/loaders/sam2_loader.py
Browse files- models/loaders/sam2_loader.py +68 -17
 
    	
        models/loaders/sam2_loader.py
    CHANGED
    
    | 
         @@ -3,7 +3,7 @@ 
     | 
|
| 3 | 
         | 
| 4 | 
         
             
            """
         
     | 
| 5 | 
         
             
            SAM2 Loader + Guarded Predictor Adapter (VRAM-friendly, shape-safe, thread-safe, PyTorch 2.x)
         
     | 
| 6 | 
         
            -
            -  
     | 
| 7 | 
         
             
            - Never assigns predictor.device (read-only) — moves .model to device instead
         
     | 
| 8 | 
         
             
            - Accepts RGB/BGR, float/uint8; strips alpha; optional BGR→RGB via env
         
     | 
| 9 | 
         
             
            - Downscale ladder on set_image(); upsample masks back to original H,W
         
     | 
| 
         @@ -244,6 +244,12 @@ def __init__(self, device: str = "cuda", cache_dir: str = "./checkpoints/sam2_ca 
     | 
|
| 244 | 
         
             
                    self.load_time = 0.0
         
     | 
| 245 | 
         | 
| 246 | 
         
             
                def _determine_optimal_size(self) -> str:
         
     | 
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 247 | 
         
             
                    try:
         
     | 
| 248 | 
         
             
                        if torch.cuda.is_available():
         
     | 
| 249 | 
         
             
                            props = torch.cuda.get_device_properties(0)
         
     | 
| 
         @@ -260,11 +266,12 @@ def load(self, model_size: str = "auto") -> Optional[_SAM2Adapter]: 
     | 
|
| 260 | 
         
             
                    if model_size == "auto":
         
     | 
| 261 | 
         
             
                        model_size = self._determine_optimal_size()
         
     | 
| 262 | 
         | 
| 
         | 
|
| 263 | 
         
             
                    model_map = {
         
     | 
| 264 | 
         
            -
                        "tiny":  "facebook/sam2 
     | 
| 265 | 
         
            -
                        "small": "facebook/sam2 
     | 
| 266 | 
         
            -
                        "base":  "facebook/sam2 
     | 
| 267 | 
         
            -
                        "large": "facebook/sam2 
     | 
| 268 | 
         
             
                    }
         
     | 
| 269 | 
         
             
                    self.model_id = model_map.get(model_size, model_map["tiny"])
         
     | 
| 270 | 
         
             
                    logger.info(f"Loading SAM2 model: {self.model_id} (device={self.device})")
         
     | 
| 
         @@ -288,17 +295,61 @@ def load(self, model_size: str = "auto") -> Optional[_SAM2Adapter]: 
     | 
|
| 288 | 
         
             
                    return None
         
     | 
| 289 | 
         | 
| 290 | 
         
             
                def _load_official(self):
         
     | 
| 291 | 
         
            -
                     
     | 
| 292 | 
         
            -
             
     | 
| 293 | 
         
            -
                         
     | 
| 294 | 
         
            -
                         
     | 
| 295 | 
         
            -
             
     | 
| 296 | 
         
            -
                         
     | 
| 297 | 
         
            -
             
     | 
| 298 | 
         
            -
                     
     | 
| 299 | 
         
            -
             
     | 
| 300 | 
         
            -
             
     | 
| 301 | 
         
            -
             
     | 
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 302 | 
         | 
| 303 | 
         
             
                def _load_fallback(self):
         
     | 
| 304 | 
         
             
                    class FallbackSAM2:
         
     | 
| 
         @@ -360,4 +411,4 @@ def get_info(self) -> Dict[str, Any]: 
     | 
|
| 360 | 
         
             
                m = out["masks"]
         
     | 
| 361 | 
         
             
                print("Masks:", m.shape, m.dtype, m.min(), m.max())
         
     | 
| 362 | 
         
             
                cv2.imwrite("sam2_mask0.png", (np.clip(m[0], 0, 1) * 255).astype(np.uint8))
         
     | 
| 363 | 
         
            -
                print("Wrote sam2_mask0.png")
         
     | 
| 
         | 
|
| 3 | 
         | 
| 4 | 
         
             
            """
         
     | 
| 5 | 
         
             
            SAM2 Loader + Guarded Predictor Adapter (VRAM-friendly, shape-safe, thread-safe, PyTorch 2.x)
         
     | 
| 6 | 
         
            +
            - Uses traditional build_sam2 method with HF hub downloads for SAM 2.1 weights
         
     | 
| 7 | 
         
             
            - Never assigns predictor.device (read-only) — moves .model to device instead
         
     | 
| 8 | 
         
             
            - Accepts RGB/BGR, float/uint8; strips alpha; optional BGR→RGB via env
         
     | 
| 9 | 
         
             
            - Downscale ladder on set_image(); upsample masks back to original H,W
         
     | 
| 
         | 
|
| 244 | 
         
             
                    self.load_time = 0.0
         
     | 
| 245 | 
         | 
| 246 | 
         
             
                def _determine_optimal_size(self) -> str:
         
     | 
| 247 | 
         
            +
                    # Check environment variable first
         
     | 
| 248 | 
         
            +
                    env_size = os.environ.get("USE_SAM2", "").lower()
         
     | 
| 249 | 
         
            +
                    if env_size in ["tiny", "small", "base", "large"]:
         
     | 
| 250 | 
         
            +
                        logger.info(f"Using SAM2 size from environment: {env_size}")
         
     | 
| 251 | 
         
            +
                        return env_size
         
     | 
| 252 | 
         
            +
                        
         
     | 
| 253 | 
         
             
                    try:
         
     | 
| 254 | 
         
             
                        if torch.cuda.is_available():
         
     | 
| 255 | 
         
             
                            props = torch.cuda.get_device_properties(0)
         
     | 
| 
         | 
|
| 266 | 
         
             
                    if model_size == "auto":
         
     | 
| 267 | 
         
             
                        model_size = self._determine_optimal_size()
         
     | 
| 268 | 
         | 
| 269 | 
         
            +
                    # Use original SAM2 model names (without .1) for compatibility
         
     | 
| 270 | 
         
             
                    model_map = {
         
     | 
| 271 | 
         
            +
                        "tiny":  "facebook/sam2-hiera-tiny",
         
     | 
| 272 | 
         
            +
                        "small": "facebook/sam2-hiera-small",
         
     | 
| 273 | 
         
            +
                        "base":  "facebook/sam2-hiera-base-plus",
         
     | 
| 274 | 
         
            +
                        "large": "facebook/sam2-hiera-large",
         
     | 
| 275 | 
         
             
                    }
         
     | 
| 276 | 
         
             
                    self.model_id = model_map.get(model_size, model_map["tiny"])
         
     | 
| 277 | 
         
             
                    logger.info(f"Loading SAM2 model: {self.model_id} (device={self.device})")
         
     | 
| 
         | 
|
| 295 | 
         
             
                    return None
         
     | 
| 296 | 
         | 
| 297 | 
         
             
                def _load_official(self):
         
     | 
| 298 | 
         
            +
                    try:
         
     | 
| 299 | 
         
            +
                        from huggingface_hub import hf_hub_download
         
     | 
| 300 | 
         
            +
                        from sam2.build_sam import build_sam2
         
     | 
| 301 | 
         
            +
                        from sam2.sam2_image_predictor import SAM2ImagePredictor
         
     | 
| 302 | 
         
            +
                    except ImportError as e:
         
     | 
| 303 | 
         
            +
                        logger.error(f"Failed to import SAM2 components: {e}")
         
     | 
| 304 | 
         
            +
                        return None
         
     | 
| 305 | 
         
            +
                    
         
     | 
| 306 | 
         
            +
                    # Map model IDs to config files and checkpoint names
         
     | 
| 307 | 
         
            +
                    config_map = {
         
     | 
| 308 | 
         
            +
                        "facebook/sam2-hiera-tiny": ("sam2_hiera_t.yaml", "sam2_hiera_tiny.pt"),
         
     | 
| 309 | 
         
            +
                        "facebook/sam2-hiera-small": ("sam2_hiera_s.yaml", "sam2_hiera_small.pt"),
         
     | 
| 310 | 
         
            +
                        "facebook/sam2-hiera-base-plus": ("sam2_hiera_b+.yaml", "sam2_hiera_base_plus.pt"),
         
     | 
| 311 | 
         
            +
                        "facebook/sam2-hiera-large": ("sam2_hiera_l.yaml", "sam2_hiera_large.pt"),
         
     | 
| 312 | 
         
            +
                    }
         
     | 
| 313 | 
         
            +
                    
         
     | 
| 314 | 
         
            +
                    config_file, checkpoint_file = config_map.get(self.model_id, (None, None))
         
     | 
| 315 | 
         
            +
                    if not config_file:
         
     | 
| 316 | 
         
            +
                        raise ValueError(f"Unknown model: {self.model_id}")
         
     | 
| 317 | 
         
            +
                    
         
     | 
| 318 | 
         
            +
                    try:
         
     | 
| 319 | 
         
            +
                        # Download the checkpoint from HuggingFace
         
     | 
| 320 | 
         
            +
                        logger.info(f"Downloading checkpoint: {checkpoint_file}")
         
     | 
| 321 | 
         
            +
                        checkpoint_path = hf_hub_download(
         
     | 
| 322 | 
         
            +
                            repo_id=self.model_id,
         
     | 
| 323 | 
         
            +
                            filename=checkpoint_file,
         
     | 
| 324 | 
         
            +
                            cache_dir=self.cache_dir,
         
     | 
| 325 | 
         
            +
                            local_files_only=False
         
     | 
| 326 | 
         
            +
                        )
         
     | 
| 327 | 
         
            +
                        logger.info(f"Checkpoint downloaded to: {checkpoint_path}")
         
     | 
| 328 | 
         
            +
                        
         
     | 
| 329 | 
         
            +
                        # Also download the config file if needed
         
     | 
| 330 | 
         
            +
                        config_path = hf_hub_download(
         
     | 
| 331 | 
         
            +
                            repo_id=self.model_id,
         
     | 
| 332 | 
         
            +
                            filename=config_file,
         
     | 
| 333 | 
         
            +
                            cache_dir=self.cache_dir,
         
     | 
| 334 | 
         
            +
                            local_files_only=False
         
     | 
| 335 | 
         
            +
                        )
         
     | 
| 336 | 
         
            +
                        logger.info(f"Config downloaded to: {config_path}")
         
     | 
| 337 | 
         
            +
                        
         
     | 
| 338 | 
         
            +
                        # Build the model using the traditional method
         
     | 
| 339 | 
         
            +
                        sam2_model = build_sam2(config_path, checkpoint_path, device=self.device)
         
     | 
| 340 | 
         
            +
                        predictor = SAM2ImagePredictor(sam2_model)
         
     | 
| 341 | 
         
            +
                        
         
     | 
| 342 | 
         
            +
                        # Ensure model is on the correct device and in eval mode
         
     | 
| 343 | 
         
            +
                        if hasattr(predictor, "model"):
         
     | 
| 344 | 
         
            +
                            predictor.model = predictor.model.to(self.device)
         
     | 
| 345 | 
         
            +
                            predictor.model.eval()
         
     | 
| 346 | 
         
            +
                        
         
     | 
| 347 | 
         
            +
                        return predictor
         
     | 
| 348 | 
         
            +
                        
         
     | 
| 349 | 
         
            +
                    except Exception as e:
         
     | 
| 350 | 
         
            +
                        logger.error(f"Error loading SAM2 model: {e}")
         
     | 
| 351 | 
         
            +
                        logger.debug(traceback.format_exc())
         
     | 
| 352 | 
         
            +
                        return None
         
     | 
| 353 | 
         | 
| 354 | 
         
             
                def _load_fallback(self):
         
     | 
| 355 | 
         
             
                    class FallbackSAM2:
         
     | 
| 
         | 
|
| 411 | 
         
             
                m = out["masks"]
         
     | 
| 412 | 
         
             
                print("Masks:", m.shape, m.dtype, m.min(), m.max())
         
     | 
| 413 | 
         
             
                cv2.imwrite("sam2_mask0.png", (np.clip(m[0], 0, 1) * 255).astype(np.uint8))
         
     | 
| 414 | 
         
            +
                print("Wrote sam2_mask0.png")
         
     |