Justin331 committed
Commit d920677 · verified · 1 Parent(s): af5bb79

Upload handler.py with huggingface_hub

Files changed (1)
  1. handler.py +28 -31
handler.py CHANGED
@@ -15,6 +15,30 @@ import numpy as np
 from PIL import Image
 import cv2
 
+# CRITICAL: Patch torch.autocast BEFORE any SAM3 imports
+# SAM3 uses @torch.autocast decorators that get applied at import time
+# We must patch torch.autocast before the decorators are evaluated
+class Float32Autocast:
+    """No-op autocast that forces float32."""
+    def __init__(self, device_type, dtype=None, enabled=True):
+        self.device_type = device_type
+        self.dtype = torch.float32
+        self.enabled = False
+
+    def __enter__(self):
+        return self
+
+    def __exit__(self, *args):
+        pass
+
+# Store original and replace globally
+_ORIGINAL_AUTOCAST = torch.autocast
+torch.autocast = Float32Autocast
+if hasattr(torch.cuda, 'amp'):
+    torch.cuda.amp.autocast = Float32Autocast
+if hasattr(torch, 'amp'):
+    torch.amp.autocast = Float32Autocast
+
 # Configure logging
 logging.basicConfig(
     level=logging.INFO,
@@ -24,7 +48,10 @@ logging.basicConfig(
 )
 logger = logging.getLogger(__name__)
 
+logger.info("✓ Patched torch.autocast globally before SAM3 import")
+
 # SAM3 imports - using local sam3 package in repository
+# This will now use our patched autocast for all @torch.autocast decorators
 from sam3.model_builder import build_sam3_video_predictor
 
 # HuggingFace Hub for uploads
@@ -67,39 +94,9 @@ class EndpointHandler:
         logger.info(f"GPU Device: {torch.cuda.get_device_name(0)}")
         logger.info(f"CUDA Version: {torch.version.cuda}")
         logger.info(f"Total GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")
-
-        # CRITICAL FIX: Patch torch.autocast BEFORE building the predictor
-        # SAM3 has @torch.autocast decorators hardcoded to use BFloat16
-        # We need to override the autocast context manager to be a no-op
-        logger.info("Patching torch.autocast to disable BFloat16 (before model loading)...")
-
-        # Store the original autocast
-        self._original_autocast = torch.autocast
-
-        # Create a no-op autocast that always disables mixed precision
-        class Float32Autocast:
-            def __init__(self, device_type, dtype=None, enabled=True):
-                # Completely disable autocast
-                self.device_type = device_type
-                self.dtype = torch.float32
-                self.enabled = False
-
-            def __enter__(self):
-                return self
-
-            def __exit__(self, *args):
-                pass
-
-        # Monkey-patch torch.autocast globally BEFORE importing/building
-        torch.autocast = Float32Autocast
-        if hasattr(torch.cuda.amp, 'autocast'):
-            torch.cuda.amp.autocast = Float32Autocast
-        if hasattr(torch.amp, 'autocast'):
-            torch.amp.autocast = Float32Autocast
-
-        logger.info("✓ Patched torch.autocast to be a no-op (forces float32)")
 
         # Build SAM3 video predictor
+        # Note: torch.autocast was already patched at module import time
         try:
             logger.info("Building SAM3 video predictor...")
             start_time = time.time()
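
The commit moves the patch to module level because torch.autocast can act either as a context manager or as a decorator, and decorator forms like @torch.autocast(...) are evaluated when the decorated module is imported, so a patch applied later (e.g. inside the handler's __init__) never reaches them. Below is a minimal, hypothetical sketch of the same pattern outside the handler; the class name NoOpAutocast and the comments are illustrative rather than part of the commit, and unlike the class in the diff this sketch also implements __call__ so the decorator form keeps working after the patch.

import functools

import torch


class NoOpAutocast:
    """Hypothetical no-op stand-in for torch.autocast: usable both as a
    context manager and as a decorator, always leaving computation in float32."""

    def __init__(self, device_type, dtype=None, enabled=True, cache_enabled=None):
        self.device_type = device_type
        self.dtype = torch.float32   # ignore any requested dtype (e.g. bfloat16)
        self.enabled = False         # mixed precision stays off

    def __enter__(self):
        return self

    def __exit__(self, *exc):
        return False

    def __call__(self, func):
        # Support the @torch.autocast(...) decorator form as well.
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            with self:
                return func(*args, **kwargs)
        return wrapper


# Install the replacement BEFORE importing any module whose functions are
# decorated with @torch.autocast: decorators run at import time, so a patch
# applied afterwards (e.g. inside a handler's __init__) never reaches them.
_ORIGINAL_AUTOCAST = torch.autocast
torch.autocast = NoOpAutocast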
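
A quick way to confirm the patch is in effect, assuming a CUDA device is available; the tensors and shapes here are illustrative only.

import torch

# Assumes NoOpAutocast (above) has already replaced torch.autocast.
x = torch.randn(4, 4, device="cuda")
w = torch.randn(4, 4, device="cuda")

with torch.autocast("cuda", dtype=torch.bfloat16):
    y = x @ w

# Under the real autocast this matmul would come back as bfloat16;
# with the no-op replacement it stays in float32.
print(y.dtype)  # torch.float32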