import os
import time
import torch
from monai.inferers import SlidingWindowInferer
from config import BUILD_SRMAMAMBA_AVAILABLE, build_SRMAMamba, SRMA_MAMBA_DIR
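
# Module-level state: the per-modality model cache (MODEL_T1 / MODEL_T2), the
# active compute device, and the shared MONAI sliding-window inferer, all
# populated by load_model() and released by clear_gpu_memory().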
MODEL_T1 = None
MODEL_T2 = None
DEVICE = torch.device('cpu')
WINDOW_INFER = None


def clear_gpu_memory():
    """Unload any cached models and the inferer, then release GPU memory."""
    global MODEL_T1, MODEL_T2, WINDOW_INFER
    if torch.cuda.is_available():
        if MODEL_T1 is not None:
            del MODEL_T1
            MODEL_T1 = None
        if MODEL_T2 is not None:
            del MODEL_T2
            MODEL_T2 = None
        if WINDOW_INFER is not None:
            del WINDOW_INFER
            WINDOW_INFER = None
        torch.cuda.empty_cache()
        torch.cuda.synchronize()
        print(" ✓ GPU memory cleared (models unloaded)")


def load_model(modality='T1'):
    """Build the SRMA-Mamba model for the given modality ('T1' or 'T2'),
    load its checkpoint, and prepare the sliding-window inferer."""
    global MODEL_T1, MODEL_T2, DEVICE, WINDOW_INFER, BUILD_SRMAMAMBA_AVAILABLE
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.synchronize()
    if not BUILD_SRMAMAMBA_AVAILABLE or build_SRMAMamba is None:
        error_msg = "Model builder (build_SRMAMamba) is not available. Please check the logs for import errors."
        print(f"✗ {error_msg}")
        raise ImportError(error_msg)
    print(f"Loading {modality} model...")
    # Select the compute device, retrying CUDA initialization with exponential
    # backoff in case the GPU is still waking up.
    if torch.cuda.is_available():
        try:
            max_retries = 3
            retry_delay = 2
            for attempt in range(max_retries):
                try:
                    torch.cuda.empty_cache()
                    test_tensor = torch.zeros(1).cuda()
                    del test_tensor
                    torch.cuda.synchronize()
                    DEVICE = torch.device('cuda')
                    print(f"✓ Using device: {DEVICE}")
                    break
                except RuntimeError as e:
                    if "CUDA" in str(e) and attempt < max_retries - 1:
                        print(f"⚠ GPU wake-up attempt {attempt + 1}/{max_retries}: {e}")
                        print(f"⚠ Waiting {retry_delay}s for GPU to wake up...")
                        time.sleep(retry_delay)
                        retry_delay *= 2
                    else:
                        raise
        except Exception as e:
            print(f"⚠ CUDA available but failed to initialize: {e}. Falling back to CPU.")
            DEVICE = torch.device('cpu')
    else:
        DEVICE = torch.device('cpu')
        print(f"ℹ CUDA not available. Using device: {DEVICE}")
    # Pick sliding-window settings (patch size, per-pass batch, overlap) based
    # on how much GPU memory is currently free.
    if DEVICE.type == 'cuda':
        torch.cuda.empty_cache()
        torch.cuda.synchronize()
        allocated = torch.cuda.memory_allocated(0) / (1024**3)
        reserved = torch.cuda.memory_reserved(0) / (1024**3)
        total = torch.cuda.get_device_properties(0).total_memory / (1024**3)
        free_memory_gb = total - allocated
        print(f" ℹ GPU memory: {allocated:.2f} GB allocated, {reserved:.2f} GB reserved, {free_memory_gb:.2f} GB free (total: {total:.2f} GB)")
        if free_memory_gb < 1.0:
            print(f" ⚠ CRITICAL: Very low free memory ({free_memory_gb:.2f} GB). Using ultra-minimal settings.")
            size = [192, 192, 32]
            batch_size = 1
            overlap = 0.25
        elif free_memory_gb < 2.0:
            print(f" ⚠ WARNING: Low free memory ({free_memory_gb:.2f} GB). Using minimal settings.")
            size = [192, 192, 32]
            batch_size = 1
            overlap = 0.25
        elif free_memory_gb < 5.0:
            size = [224, 224, 48]
            batch_size = 1
            overlap = 0.2
        elif free_memory_gb > 40:
            print(f" Very high VRAM GPU detected ({free_memory_gb:.2f} GB free). Using optimal settings for maximum speed.")
            size = [256, 256, 80]
            batch_size = 2
            overlap = 0.1
        elif free_memory_gb > 30:
            print(f" High VRAM GPU detected ({free_memory_gb:.2f} GB free). Using optimal settings for speed.")
            size = [256, 256, 64]
            batch_size = 2
            overlap = 0.1
        elif free_memory_gb > 25:
            print(f" ✓ Large VRAM GPU detected ({free_memory_gb:.2f} GB free). Using optimal settings.")
            size = [256, 256, 64]
            batch_size = 1
            overlap = 0.1
        elif free_memory_gb > 20:
            size = [256, 256, 64]
            batch_size = 1
            overlap = 0.1
        elif free_memory_gb > 15:
            size = [256, 256, 64]
            batch_size = 1
            overlap = 0.1
        elif free_memory_gb > 10:
            size = [224, 224, 64]
            batch_size = 1
            overlap = 0.1
        elif free_memory_gb > 8:
            size = [224, 224, 48]
            batch_size = 1
            overlap = 0.2
        else:
            size = [192, 192, 48]
            batch_size = 1
            overlap = 0.2
    else:
        size = [224, 224, 64]
        batch_size = 1
        overlap = 0.15
    print(f" ℹ Sliding window config: roi_size={size}, sw_batch_size={batch_size}, overlap={overlap}")
    print("Building model architecture...")
    if SRMA_MAMBA_DIR:
        original_cwd = os.getcwd()
        try:
            os.chdir(SRMA_MAMBA_DIR)
            print(f"Changed working directory to: {SRMA_MAMBA_DIR}")
            model = build_SRMAMamba()
            print("✓ Model architecture built")
        finally:
            os.chdir(original_cwd)
    else:
        model = build_SRMAMamba()
        print("✓ Model architecture built")
    model = model.to(DEVICE)
    print(f"✓ Model moved to {DEVICE}")
    # Locate the checkpoint locally first; fall back to downloading it from the Hugging Face Hub.
    checkpoint_path = f"checkpoint_{modality}.pth"
    possible_paths = [
        checkpoint_path,
        os.path.join(os.path.dirname(__file__), checkpoint_path),
        f"../../Chkpoints/checkpoint_{modality}.pth",
        f"Chkpoints/checkpoint_{modality}.pth",
        f"../Chkpoints/checkpoint_{modality}.pth",
        f"Model/Chkpoints/checkpoint_{modality}.pth",
        os.path.join(os.path.dirname(__file__), f"Chkpoints/checkpoint_{modality}.pth"),
    ]
    found = False
    for path in possible_paths:
        abs_path = os.path.abspath(path)
        if os.path.exists(path) or os.path.exists(abs_path):
            checkpoint_path = path if os.path.exists(path) else abs_path
            found = True
            print(f"✓ Found checkpoint at: {checkpoint_path}")
            break
    if not found:
        try:
            from huggingface_hub import hf_hub_download
            repo_id = os.environ.get("HF_MODEL_REPO", "HarshithReddy01/srmamamba-liver-segmentation")
            print(f"Attempting to download checkpoint from Hugging Face: {repo_id}")
            checkpoint_path = hf_hub_download(
                repo_id=repo_id,
                filename=f"checkpoint_{modality}.pth",
                cache_dir="."
            )
            found = True
            print(f"✓ Downloaded checkpoint to: {checkpoint_path}")
        except Exception as e:
            error_msg = f"Checkpoint not found. Searched: {possible_paths}. Hugging Face download failed: {str(e)}"
            print(f"✗ {error_msg}")
            raise FileNotFoundError(error_msg)
    print(f"Loading checkpoint weights from: {checkpoint_path}")
    try:
        checkpoint = torch.load(checkpoint_path, map_location=DEVICE)
        if isinstance(checkpoint, dict) and 'state_dict' in checkpoint:
            model.load_state_dict(checkpoint['state_dict'])
        else:
            model.load_state_dict(checkpoint)
        print("✓ Checkpoint loaded successfully")
    except Exception as e:
        print(f"✗ Failed to load checkpoint: {e}")
        raise
    model.eval()
    print("✓ Model set to evaluation mode")
    if DEVICE.type == 'cuda':
        import config
        from packaging import version
        # Newer PyTorch (>= 2.9 here) exposes the fp32_precision API; older
        # versions use the allow_tf32 flags instead.
        torch_version = version.parse(torch.__version__)
        if torch_version >= version.parse("2.9.0"):
            torch.backends.cuda.matmul.fp32_precision = 'tf32'
            torch.backends.cudnn.conv.fp32_precision = 'tf32'
            tf32_matmul = torch.backends.cuda.matmul.fp32_precision
            tf32_conv = torch.backends.cudnn.conv.fp32_precision
        else:
            torch.backends.cuda.matmul.allow_tf32 = True
            torch.backends.cudnn.allow_tf32 = True
            tf32_matmul = 'tf32' if torch.backends.cuda.matmul.allow_tf32 else 'ieee'
            tf32_conv = 'tf32' if torch.backends.cudnn.allow_tf32 else 'ieee'
        torch.backends.cudnn.benchmark = True
        print(f"TF32 enabled: matmul={tf32_matmul}, conv={tf32_conv}")
        print("cuDNN benchmarking enabled")
        if config.ENABLE_TORCH_COMPILE:
            try:
                compile_mode = os.environ.get('TORCH_COMPILE_MODE', 'reduce-overhead')
                if compile_mode == 'max-autotune':
                    print(" → Compiling with max-autotune (may take 2-5 min on first run)...")
                    model = torch.compile(model, mode='max-autotune', fullgraph=False)
                    print("✓ Model compiled with torch.compile (mode=max-autotune, fullgraph=False)")
                elif compile_mode == 'default':
                    print(" → Compiling with default mode (may take 1-3 min on first run)...")
                    model = torch.compile(model, fullgraph=False)
                    print("✓ Model compiled with torch.compile (mode=default, fullgraph=False)")
                else:
                    print(" → Compiling with reduce-overhead (faster first run, ~30-60s)...")
                    model = torch.compile(model, mode='reduce-overhead', fullgraph=False)
                    print("✓ Model compiled with torch.compile (mode=reduce-overhead, fullgraph=False)")
            except Exception as e:
                print(f" ⚠ torch.compile failed: {e}. Continuing without compilation.")
        else:
            print(" ℹ torch.compile disabled (set ENABLE_TORCH_COMPILE=true to enable)")
        # Re-check free memory after the model weights are on the GPU and adjust
        # the sliding-window settings if needed.
        torch.cuda.empty_cache()
        torch.cuda.synchronize()
        allocated_after_load = torch.cuda.memory_allocated(0) / (1024**3)
        free_after_load = (torch.cuda.get_device_properties(0).total_memory - torch.cuda.memory_allocated(0)) / (1024**3)
        print(f" ℹ GPU memory after model load: {allocated_after_load:.2f} GB allocated, {free_after_load:.2f} GB free")
        if free_after_load < 1.0:
            print(f" ⚠ CRITICAL: Only {free_after_load:.2f} GB free after model load. Using ultra-minimal settings.")
            size = [192, 192, 32]
            batch_size = 1
            overlap = 0.25
        elif free_after_load < 2.0:
            print(f" ⚠ WARNING: Low free memory ({free_after_load:.2f} GB) after model load. Adjusting to minimal settings.")
            size = [192, 192, 32]
            batch_size = 1
            overlap = 0.25
        elif free_after_load > 40:
            print(f" Excellent free memory ({free_after_load:.2f} GB) after model load. Using optimal settings for maximum speed.")
            size = [256, 256, 80]
            batch_size = 2
            overlap = 0.1
        elif free_after_load > 30:
            print(f" Excellent free memory ({free_after_load:.2f} GB) after model load. Using optimal settings for speed.")
            size = [256, 256, 64]
            batch_size = 2
            overlap = 0.1
        elif free_after_load > 25:
            print(f" ✓ Good free memory ({free_after_load:.2f} GB) after model load. Using optimal settings.")
            size = [256, 256, 64]
            batch_size = 1
            overlap = 0.1
        elif free_after_load > 20:
            print(f" ✓ Good free memory ({free_after_load:.2f} GB) after model load. Using optimal settings.")
            size = [256, 256, 64]
            batch_size = 1
            overlap = 0.1
        elif free_after_load > 15:
            size = [256, 256, 64]
            batch_size = 1
            overlap = 0.1
        elif free_after_load < 5.0 and (size[0] > 224 or batch_size > 1):
            print(f" ⚠ WARNING: Limited free memory ({free_after_load:.2f} GB). Reducing window size and batch size.")
            size = [224, 224, 48]
            batch_size = 1
            overlap = 0.1
        aggregation_device = 'cuda'
        if free_after_load < 2.0:
            aggregation_device = 'cpu'
            print(f" ⚠ Very low VRAM ({free_after_load:.2f} GB), using CPU aggregation to prevent OOM")
        else:
            print(f" ✓ Using GPU aggregation for maximum speed (VRAM: {free_after_load:.2f} GB free)")
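        # The MONAI SlidingWindowInferer tiles each input volume into roi_size
        # patches, runs sw_batch_size patches per forward pass on sw_device, and
        # blends the overlapping predictions into a full-size output held on
        # `device` (here, aggregation_device). Keeping aggregation on the CPU
        # trades speed for a smaller GPU memory footprint.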
        WINDOW_INFER = SlidingWindowInferer(
            roi_size=size,
            sw_batch_size=batch_size,
            overlap=overlap,
            sw_device='cuda',
            device=aggregation_device
        )
        print(f"✓ Sliding window inferer created (GPU compute, {aggregation_device.upper()} aggregation)")
    if DEVICE.type == 'cuda':
        if config.ENABLE_TORCH_COMPILE:
            print(" Running warm-up inference to trigger compilation and kernel autotuning...")
            print(" This may take 30-60s (reduce-overhead) or 2-5min (max-autotune) on first run...")
        else:
            print(" Running warm-up inference to trigger kernel autotuning...")
        try:
            dummy_input = torch.randn(1, 1, size[0], size[1], size[2], device=DEVICE, dtype=torch.float32)
            dummy_input = dummy_input.contiguous(memory_format=torch.channels_last_3d)
            warmup_start = time.time()
            with torch.no_grad():
                from torch.amp import autocast
                with autocast(device_type='cuda'):
                    _ = model(dummy_input)
                torch.cuda.synchronize()
            warmup_time = time.time() - warmup_start
            del dummy_input, _
            torch.cuda.empty_cache()
            if config.ENABLE_TORCH_COMPILE:
                print(f" Warm-up completed in {warmup_time:.1f}s (compilation + kernel autotuning)")
            else:
                print(f" Warm-up completed in {warmup_time:.1f}s (kernels autotuned)")
        except RuntimeError as e:
            if "out of memory" in str(e):
                print(f" Warm-up OOM (non-critical): {e}")
                print(" Will use progressive fallback during inference")
            else:
                print(f" Warm-up failed (non-critical): {e}")
        except Exception as e:
            print(f" Warm-up failed (non-critical): {e}")
    if modality == 'T1':
        MODEL_T1 = model
    else:
        MODEL_T2 = model
    print(f"✓ {modality} model loaded and ready")
    return model
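

# --- Usage sketch (illustrative, not part of the original workflow) ---
# A minimal example of how this module might be exercised, assuming a
# checkpoint for the requested modality is available locally or on the
# Hugging Face Hub. The `volume` tensor below is a placeholder; real inputs
# come from the application's preprocessing pipeline, shaped (1, 1, H, W, D).
if __name__ == "__main__":
    model = load_model('T1')
    if WINDOW_INFER is not None:  # the inferer is only built on the CUDA path above
        volume = torch.randn(1, 1, 256, 256, 64, device=DEVICE)  # placeholder input
        with torch.no_grad():
            # WINDOW_INFER tiles the volume and runs the model patch by patch.
            logits = WINDOW_INFER(inputs=volume, network=model)
        print("Output shape:", tuple(logits.shape))
    clear_gpu_memory()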