Harshith Reddy
Docker: Add OpenCV libs, use PYTORCH_ALLOC_CONF only with max_split_size_mb=128, create postBuild for Python Spaces fallback
e0e8f87
import importlib
import os
import subprocess
import sys
# --- Process-wide environment defaults ---------------------------------------
# NOTE(review): presumably these must be exported before torch is imported
# further down the file - confirm against the rest of the module.

# Pin OpenMP to one thread unless the environment already supplies a plain
# non-negative integer; a missing, empty or malformed value is overwritten.
if not os.environ.get('OMP_NUM_THREADS', '').isdigit():
    os.environ['OMP_NUM_THREADS'] = '1'
    print(f"β Set OMP_NUM_THREADS={os.environ['OMP_NUM_THREADS']}")

# Supply a default allocator configuration only when the caller gave none.
# PyTorch parses this variable as comma-separated `option:value` pairs, so the
# separator after `max_split_size_mb` has to be ':' (an '=' makes the whole
# token an unrecognised option).
if 'PYTORCH_ALLOC_CONF' not in os.environ:
    os.environ['PYTORCH_ALLOC_CONF'] = 'expandable_segments:True,max_split_size_mb:128'
    print(f"Set PYTORCH_ALLOC_CONF={os.environ['PYTORCH_ALLOC_CONF']}")
def _env_int(name: str, default: int) -> int:
    """Return environment variable *name* parsed as an int.

    Falls back to *default* (with a printed warning) when the variable is
    unset, empty, or not a valid integer, so a bad deployment setting cannot
    crash the app at import time.
    """
    raw = os.environ.get(name)
    if raw is None:
        return default
    try:
        return int(raw)
    except ValueError:
        print(f"Warning: ignoring invalid {name}={raw!r}; using default {default}")
        return default


# --- Feature flags and tunables read from the environment --------------------
# Boolean flags are enabled only by the literal string "true" (any casing).
ENABLE_TORCH_COMPILE = os.environ.get('ENABLE_TORCH_COMPILE', 'false').lower() == 'true'
if ENABLE_TORCH_COMPILE:
    print("torch.compile enabled. First inference may take 30-60s for compilation.")
    print(" Set ENABLE_TORCH_COMPILE=false to disable for faster first run")
    print(" Set TORCH_COMPILE_MODE=max-autotune for maximum speed (slower first run)")
else:
    print("torch.compile disabled by default (set ENABLE_TORCH_COMPILE=true to enable)")

ENABLE_CUDNN_BENCHMARK = os.environ.get('ENABLE_CUDNN_BENCHMARK', 'true').lower() == 'true'

# Integer tunables; malformed values fall back to the defaults below.
# NOTE(review): INFERENCE_TIMEOUT is presumably in seconds - confirm at its use site.
INFERENCE_TIMEOUT = _env_int('INFERENCE_TIMEOUT', 1800)
MAX_GRADIO_CONCURRENCY = _env_int('MAX_GRADIO_CONCURRENCY', 1)
import gradio as gr

print(f"Gradio version: {gr.__version__}")

# Record whether this Gradio build exposes `gradio.routes.mount_gradio_app`.
# A missing helper is reported but does not abort start-up; the flag is left
# for later code to consult.
try:
    from gradio.routes import mount_gradio_app
except ImportError:
    HAS_MOUNT_GRADIO_APP = False
    print(
        "β CRITICAL: mount_gradio_app not available. Gradio version too old. Need >= 4.44.1",
        f"β Current Gradio version: {gr.__version__}",
        "β Please ensure requirements.txt has gradio==4.44.1",
        sep="\n",
    )
else:
    HAS_MOUNT_GRADIO_APP = True
# Optional dependency: the `spaces` package. HAS_SPACES tells later code
# whether the module was importable in this environment.
try:
    import spaces
except ImportError:
    HAS_SPACES = False
    print("Warning: spaces module not found. GPU decorator will not be used.")
else:
    HAS_SPACES = True
# Candidate locations of the SRMA-Mamba checkout, in priority order: next to
# this file, two directories above it, then relative to the working directory.
_here = os.path.dirname(__file__)
srma_mamba_paths = [
    os.path.join(_here, 'SRMA-Mamba'),
    os.path.join(_here, '../../SRMA-Mamba'),
    'SRMA-Mamba',
]

# The first candidate that exists wins and is put at the front of sys.path so
# its modules can be imported; SRMA_MAMBA_DIR stays None when nothing matches.
SRMA_MAMBA_DIR = next(filter(os.path.exists, srma_mamba_paths), None)
if SRMA_MAMBA_DIR is None:
    print("Warning: SRMA-Mamba directory not found. Model imports may fail.")
else:
    sys.path.insert(0, SRMA_MAMBA_DIR)
    print(f"Found model code at: {SRMA_MAMBA_DIR}")
# Defaults for the model factory; the import of build_SRMAMamba further down
# overwrites them on success.
BUILD_SRMAMAMBA_AVAILABLE = False
build_SRMAMamba = None

# Probe for the mamba_ssm extension. A missing module is tolerated (slow
# fallback) unless REQUIRE_CUDA_EXTENSIONS=true, in which case start-up aborts.
try:
    import mamba_ssm
except ImportError as exc:
    HAS_MAMBA_SSM = False
    print("ERROR: mamba_ssm not found. This is CRITICAL for speed. Model will use slow fallback.")
    print(" To install: Run setup.sh or: pip install mamba-ssm>=2.2.2")
    if os.environ.get('REQUIRE_CUDA_EXTENSIONS', 'false').lower() == 'true':
        raise ImportError("mamba_ssm is required but not installed. Set REQUIRE_CUDA_EXTENSIONS=false to allow fallback.") from exc
else:
    HAS_MAMBA_SSM = True
    # A build without __version__ is still usable; look the attribute up
    # explicitly instead of hiding every possible error behind a bare except.
    version = getattr(mamba_ssm, '__version__', None)
    if version is not None:
        print(f"mamba_ssm CUDA extension loaded (version: {version}) - fast path enabled")
    else:
        print("mamba_ssm CUDA extension loaded - fast path enabled")
# Probe for the selective_scan_cuda_oflex extension. Absence is reported and
# tolerated unless REQUIRE_CUDA_EXTENSIONS=true, which turns it into a hard
# ImportError.
try:
    import selective_scan_cuda_oflex
except ImportError:
    HAS_SELECTIVE_SCAN_CUDA = False
    print("ERROR: selective_scan_cuda_oflex not found. This is CRITICAL for speed. Model will use slow fallback.")
    print(" To install: Run setup.sh or: cd SRMA-Mamba/selective_scan && pip install -e .")
    _require_ext = os.environ.get('REQUIRE_CUDA_EXTENSIONS', 'false').lower()
    if _require_ext == 'true':
        raise ImportError("selective_scan_cuda_oflex is required but not installed. Set REQUIRE_CUDA_EXTENSIONS=false to allow fallback.")
else:
    HAS_SELECTIVE_SCAN_CUDA = True
    print("selective_scan_cuda_oflex CUDA extension loaded - fast path enabled")
def _pip_install_mamba_ssm():
    """Install mamba-ssm with pip, retrying once without build isolation.

    Returns the ``subprocess.CompletedProcess`` of the last attempt made.
    ``subprocess.TimeoutExpired`` (15 minutes per attempt) propagates to the
    caller.
    """
    pip_install = [sys.executable, "-m", "pip", "install", "--no-cache-dir"]
    requirement = "mamba-ssm>=2.2.2"
    print("Attempting mamba-ssm installation (method 1)...")
    outcome = subprocess.run(
        pip_install + [requirement],
        capture_output=True,
        text=True,
        timeout=900,
    )
    if outcome.returncode != 0:
        print("Method 1 failed. Trying method 2 (no build isolation)...")
        outcome = subprocess.run(
            pip_install + ["--no-build-isolation", requirement],
            capture_output=True,
            text=True,
            timeout=900,
        )
    return outcome


# Import the model factory. If the import fails because mamba-ssm is missing,
# try to install it at runtime and import again. Every failure path leaves
# BUILD_SRMAMAMBA_AVAILABLE False and lets the app start without the model.
try:
    from configs.model_configs import build_SRMAMamba
    BUILD_SRMAMAMBA_AVAILABLE = True
    print("β Successfully imported build_SRMAMamba")
except ImportError as e:
    BUILD_SRMAMAMBA_AVAILABLE = False
    error_str = str(e)
    print(f"Import error: {error_str}")
    if 'mamba_ssm' in error_str or 'mamba-ssm' in error_str:
        print("β mamba-ssm not found. Attempting runtime installation...")
        print("This may take 5-10 minutes. Please wait...")
        # Build-time settings exported for the pip subprocesses; an existing
        # CUDA_HOME is respected.
        os.environ['FORCE_CUDA'] = '1'
        if 'CUDA_HOME' not in os.environ:
            os.environ['CUDA_HOME'] = '/usr/local/cuda'
        try:
            result = _pip_install_mamba_ssm()
            if result.returncode == 0:
                print("β mamba-ssm installed successfully")
                # The package appeared after interpreter start-up: drop the
                # finders' cached directory listings so the retry can see it.
                importlib.invalidate_caches()
                try:
                    from configs.model_configs import build_SRMAMamba
                    BUILD_SRMAMAMBA_AVAILABLE = True
                    print("β Successfully imported build_SRMAMamba after installation")
                except ImportError as e2:
                    print(f"β Still cannot import after installation: {e2}")
                    print("β App will start but model loading will fail")
                    BUILD_SRMAMAMBA_AVAILABLE = False
            else:
                print(f"β Installation failed. Output: {result.stdout[:500]}")
                print(f"β Error: {result.stderr[:500]}")
                print("β App will start but model loading will fail")
        except subprocess.TimeoutExpired:
            print("β Installation timed out after 15 minutes")
            print("β App will start but model loading will fail")
            BUILD_SRMAMAMBA_AVAILABLE = False
        except Exception as install_error:
            # Deliberate best-effort boundary: any other failure during the
            # install or the retry import is reported, never raised.
            print(f"β Installation error: {install_error}")
            print("β App will start but model loading will fail")
            BUILD_SRMAMAMBA_AVAILABLE = False
    else:
        print(f"β Import error (not mamba-ssm related): {e}")
        print("β App will start but model loading will fail")