# Harshith Reddy
# Docker: Add OpenCV libs, use PYTORCH_ALLOC_CONF only with max_split_size_mb=128, create postBuild for Python Spaces fallback
# e0e8f87
import os
import sys
import subprocess
if 'OMP_NUM_THREADS' not in os.environ or not os.environ['OMP_NUM_THREADS'].isdigit():
os.environ['OMP_NUM_THREADS'] = '1'
print(f"βœ“ Set OMP_NUM_THREADS={os.environ['OMP_NUM_THREADS']}")
if 'PYTORCH_ALLOC_CONF' not in os.environ:
os.environ['PYTORCH_ALLOC_CONF'] = 'expandable_segments:True,max_split_size_mb=128'
print(f"Set PYTORCH_ALLOC_CONF={os.environ['PYTORCH_ALLOC_CONF']}")
ENABLE_TORCH_COMPILE = os.environ.get('ENABLE_TORCH_COMPILE', 'false').lower() == 'true'
if ENABLE_TORCH_COMPILE:
print("torch.compile enabled. First inference may take 30-60s for compilation.")
print(" Set ENABLE_TORCH_COMPILE=false to disable for faster first run")
print(" Set TORCH_COMPILE_MODE=max-autotune for maximum speed (slower first run)")
else:
print("torch.compile disabled by default (set ENABLE_TORCH_COMPILE=true to enable)")
ENABLE_CUDNN_BENCHMARK = os.environ.get('ENABLE_CUDNN_BENCHMARK', 'true').lower() == 'true'
INFERENCE_TIMEOUT = int(os.environ.get('INFERENCE_TIMEOUT', '1800'))
MAX_GRADIO_CONCURRENCY = int(os.environ.get('MAX_GRADIO_CONCURRENCY', '1'))
# --- Gradio import & capability probe ----------------------------------------
import gradio as gr

print(f"Gradio version: {gr.__version__}")

# mount_gradio_app only exists in sufficiently recent Gradio releases;
# record its availability so later server-setup code can branch on it.
try:
    from gradio.routes import mount_gradio_app
except ImportError:
    HAS_MOUNT_GRADIO_APP = False
    print("⚠ CRITICAL: mount_gradio_app not available. Gradio version too old. Need >= 4.44.1")
    print(f"⚠ Current Gradio version: {gr.__version__}")
    print("⚠ Please ensure requirements.txt has gradio==4.44.1")
else:
    HAS_MOUNT_GRADIO_APP = True
# --- Optional Hugging Face `spaces` module (provides the GPU decorator) ------
try:
    import spaces
except ImportError:
    HAS_SPACES = False
    print("Warning: spaces module not found. GPU decorator will not be used.")
else:
    HAS_SPACES = True
# --- Locate the SRMA-Mamba source tree and put it on sys.path ----------------
# Candidates, in priority order: next to this file, two levels up, then CWD.
srma_mamba_paths = [
    os.path.join(os.path.dirname(__file__), 'SRMA-Mamba'),
    os.path.join(os.path.dirname(__file__), '../../SRMA-Mamba'),
    'SRMA-Mamba',
]
SRMA_MAMBA_DIR = None
for path in srma_mamba_paths:
    if os.path.exists(path):
        sys.path.insert(0, path)
        SRMA_MAMBA_DIR = path
        print(f"Found model code at: {path}")
        break
else:  # no break: none of the candidate directories exist
    print("Warning: SRMA-Mamba directory not found. Model imports may fail.")

# Initialize these defaults unconditionally (not only on the not-found path)
# so downstream import/installation code can rely on the names existing.
BUILD_SRMAMAMBA_AVAILABLE = False
build_SRMAMamba = None
# --- mamba_ssm CUDA extension (fast selective-scan kernels) ------------------
try:
    import mamba_ssm
    HAS_MAMBA_SSM = True
    # __version__ is missing on some builds; report it only when available
    # (getattr replaces the original bare `except:` around the lookup).
    version = getattr(mamba_ssm, '__version__', None)
    if version is not None:
        print(f"mamba_ssm CUDA extension loaded (version: {version}) - fast path enabled")
    else:
        print("mamba_ssm CUDA extension loaded - fast path enabled")
except ImportError:
    HAS_MAMBA_SSM = False
    print("ERROR: mamba_ssm not found. This is CRITICAL for speed. Model will use slow fallback.")
    print(" To install: Run setup.sh or: pip install mamba-ssm>=2.2.2")
    # Escape hatch: deployments that must have the fast CUDA path can opt
    # into a hard failure here instead of silently running the slow fallback.
    if os.environ.get('REQUIRE_CUDA_EXTENSIONS', 'false').lower() == 'true':
        raise ImportError("mamba_ssm is required but not installed. Set REQUIRE_CUDA_EXTENSIONS=false to allow fallback.")
# --- selective_scan_cuda_oflex CUDA extension --------------------------------
try:
    import selective_scan_cuda_oflex
    HAS_SELECTIVE_SCAN_CUDA = True
    print("selective_scan_cuda_oflex CUDA extension loaded - fast path enabled")
except ImportError:
    HAS_SELECTIVE_SCAN_CUDA = False
    print("ERROR: selective_scan_cuda_oflex not found. This is CRITICAL for speed. Model will use slow fallback.")
    print(" To install: Run setup.sh or: cd SRMA-Mamba/selective_scan && pip install -e .")
    # Escape hatch: deployments that must have the fast CUDA path can opt
    # into a hard failure here instead of silently running the slow fallback.
    if os.environ.get('REQUIRE_CUDA_EXTENSIONS', 'false').lower() == 'true':
        raise ImportError("selective_scan_cuda_oflex is required but not installed. Set REQUIRE_CUDA_EXTENSIONS=false to allow fallback.")
try:
from configs.model_configs import build_SRMAMamba
BUILD_SRMAMAMBA_AVAILABLE = True
print("βœ“ Successfully imported build_SRMAMamba")
except ImportError as e:
error_str = str(e)
print(f"Import error: {error_str}")
if 'mamba_ssm' in error_str or 'mamba-ssm' in error_str:
print("⚠ mamba-ssm not found. Attempting runtime installation...")
print("This may take 5-10 minutes. Please wait...")
os.environ['FORCE_CUDA'] = '1'
if 'CUDA_HOME' not in os.environ:
os.environ['CUDA_HOME'] = '/usr/local/cuda'
try:
print("Attempting mamba-ssm installation (method 1)...")
result = subprocess.run(
[sys.executable, "-m", "pip", "install", "--no-cache-dir", "mamba-ssm>=2.2.2"],
capture_output=True,
text=True,
timeout=900
)
if result.returncode != 0:
print(f"Method 1 failed. Trying method 2 (no build isolation)...")
result = subprocess.run(
[sys.executable, "-m", "pip", "install", "--no-cache-dir", "--no-build-isolation", "mamba-ssm>=2.2.2"],
capture_output=True,
text=True,
timeout=900
)
if result.returncode == 0:
print("βœ“ mamba-ssm installed successfully")
try:
from configs.model_configs import build_SRMAMamba
BUILD_SRMAMAMBA_AVAILABLE = True
print("βœ“ Successfully imported build_SRMAMamba after installation")
except ImportError as e2:
print(f"⚠ Still cannot import after installation: {e2}")
print("⚠ App will start but model loading will fail")
BUILD_SRMAMAMBA_AVAILABLE = False
else:
print(f"⚠ Installation failed. Output: {result.stdout[:500]}")
print(f"⚠ Error: {result.stderr[:500]}")
print("⚠ App will start but model loading will fail")
BUILD_SRMAMAMBA_AVAILABLE = False
except subprocess.TimeoutExpired:
print("⚠ Installation timed out after 15 minutes")
print("⚠ App will start but model loading will fail")
BUILD_SRMAMAMBA_AVAILABLE = False
except Exception as install_error:
print(f"⚠ Installation error: {install_error}")
print("⚠ App will start but model loading will fail")
BUILD_SRMAMAMBA_AVAILABLE = False
else:
print(f"⚠ Import error (not mamba-ssm related): {e}")
print("⚠ App will start but model loading will fail")
BUILD_SRMAMAMBA_AVAILABLE = False