Upload handler.py with huggingface_hub
handler.py CHANGED (+28 -31)
@@ -15,6 +15,30 @@ import numpy as np
 from PIL import Image
 import cv2
 
+# CRITICAL: Patch torch.autocast BEFORE any SAM3 imports
+# SAM3 uses @torch.autocast decorators that get applied at import time
+# We must patch torch.autocast before the decorators are evaluated
+class Float32Autocast:
+    """No-op autocast that forces float32."""
+    def __init__(self, device_type, dtype=None, enabled=True):
+        self.device_type = device_type
+        self.dtype = torch.float32
+        self.enabled = False
+
+    def __enter__(self):
+        return self
+
+    def __exit__(self, *args):
+        pass
+
+# Store original and replace globally
+_ORIGINAL_AUTOCAST = torch.autocast
+torch.autocast = Float32Autocast
+if hasattr(torch.cuda, 'amp'):
+    torch.cuda.amp.autocast = Float32Autocast
+if hasattr(torch, 'amp'):
+    torch.amp.autocast = Float32Autocast
+
 # Configure logging
 logging.basicConfig(
     level=logging.INFO,
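At runtime the replacement behaves as an inert context manager: entering it never enables mixed precision, so tensor ops inside a `with torch.autocast(...)` block stay in float32. A minimal, self-contained sketch of that behavior (CPU-only illustration, not taken from handler.py):

import torch

class Float32Autocast:
    """No-op stand-in mirroring the class added above."""
    def __init__(self, device_type, dtype=None, enabled=True):
        self.device_type = device_type
        self.dtype = torch.float32
        self.enabled = False
    def __enter__(self):
        return self
    def __exit__(self, *args):
        pass

torch.autocast = Float32Autocast

a, b = torch.randn(4, 4), torch.randn(4, 4)
with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
    out = a @ b
print(out.dtype)  # torch.float32 - the real autocast would give torch.bfloat16 here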
@@ -24,7 +48,10 @@ logging.basicConfig(
 )
 logger = logging.getLogger(__name__)
 
+logger.info("✓ Patched torch.autocast globally before SAM3 import")
+
 # SAM3 imports - using local sam3 package in repository
+# This will now use our patched autocast for all @torch.autocast decorators
 from sam3.model_builder import build_sam3_video_predictor
 
 # HuggingFace Hub for uploads
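Why the patch has to sit above the sam3 import: a `@torch.autocast(...)` decorator is evaluated when the decorated function is defined, i.e. while the module is being imported, so a patch applied afterwards never reaches it. A toy sketch of that ordering (the function below only stands in for SAM3 code; nothing here is from the sam3 package):

import torch

def import_fake_sam3():
    # Stands in for `import sam3...`: the decorator is applied right here,
    # capturing whatever torch.autocast refers to at this moment.
    @torch.autocast(device_type="cpu", dtype=torch.bfloat16)
    def forward(x):
        return x @ x
    return forward

forward = import_fake_sam3()           # "import" happens first ...
torch.autocast = lambda *a, **k: None  # ... so a later patch changes nothing
print(forward(torch.randn(3, 3)).dtype)  # torch.bfloat16 - still mixed precision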
@@ -67,39 +94,9 @@ class EndpointHandler:
         logger.info(f"GPU Device: {torch.cuda.get_device_name(0)}")
         logger.info(f"CUDA Version: {torch.version.cuda}")
         logger.info(f"Total GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")
-
-        # CRITICAL FIX: Patch torch.autocast BEFORE building the predictor
-        # SAM3 has @torch.autocast decorators hardcoded to use BFloat16
-        # We need to override the autocast context manager to be a no-op
-        logger.info("Patching torch.autocast to disable BFloat16 (before model loading)...")
-
-        # Store the original autocast
-        self._original_autocast = torch.autocast
-
-        # Create a no-op autocast that always disables mixed precision
-        class Float32Autocast:
-            def __init__(self, device_type, dtype=None, enabled=True):
-                # Completely disable autocast
-                self.device_type = device_type
-                self.dtype = torch.float32
-                self.enabled = False
-
-            def __enter__(self):
-                return self
-
-            def __exit__(self, *args):
-                pass
-
-        # Monkey-patch torch.autocast globally BEFORE importing/building
-        torch.autocast = Float32Autocast
-        if hasattr(torch.cuda.amp, 'autocast'):
-            torch.cuda.amp.autocast = Float32Autocast
-        if hasattr(torch.amp, 'autocast'):
-            torch.amp.autocast = Float32Autocast
-
-        logger.info("✓ Patched torch.autocast to be a no-op (forces float32)")
 
         # Build SAM3 video predictor
+        # Note: torch.autocast was already patched at module import time
        try:
             logger.info("Building SAM3 video predictor...")
             start_time = time.time()
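The diff keeps a reference to the real implementation in `_ORIGINAL_AUTOCAST` but never uses it. A hypothetical follow-up (not part of this commit) could undo the patch once the predictor is built, in case other code on the endpoint still wants genuine mixed precision; the helper name below is made up for illustration:

import torch

def restore_real_autocast():
    """Undo the module-level patch using the reference saved in handler.py."""
    torch.autocast = _ORIGINAL_AUTOCAST
    torch.amp.autocast = _ORIGINAL_AUTOCAST  # handler.py patched this entry point too
    # torch.cuda.amp.autocast was patched as well; its original signature differs,
    # so a clean restore would need its own saved reference, which handler.py
    # does not currently keep.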