import os
import yaml
import torch
import nibabel as nib
import numpy as np
import gradio as gr
from typing import Tuple
import tempfile
import shutil
import matplotlib.pyplot as plt
import matplotlib
matplotlib.use('Agg')  # Use non-interactive backend
import cv2  # For Gaussian blur
import io  # For saving plots to memory
import base64  # For encoding plots
import uuid  # For unique IDs
import traceback  # For detailed error printing
import SimpleITK as sitk
import itk
from scipy.signal import medfilt
import skimage.filters
from monai.transforms import Compose, LoadImaged, EnsureChannelFirstd, Resized, NormalizeIntensityd, ToTensord

from model import ViTBackboneNet, Classifier, SingleScanModelBP

# Optional HD-BET import (packaged locally, as in the MCI app)
try:
    from HD_BET.run import run_hd_bet
    from HD_BET.hd_bet import hd_bet
except Exception as e:
    print(f"Warning: HD_BET not available: {e}")
    run_hd_bet = None
    hd_bet = None

APP_DIR = os.path.dirname(__file__)
TEMPLATE_DIR = os.path.join(APP_DIR, "golden_image", "mni_templates")
PARAMS_RIGID_PATH = os.path.join(TEMPLATE_DIR, "Parameters_Rigid.txt")
DEFAULT_TEMPLATE_PATH = os.path.join(TEMPLATE_DIR, "temp_head.nii.gz")
FLAIR_TEMPLATE_PATH = os.path.join(TEMPLATE_DIR, "nihpd_asym_04.5-18.5_t2w.nii")
T1C_TEMPLATE_PATH = os.path.join(TEMPLATE_DIR, "nihpd_asym_13.0-18.5_t1w.nii")
HD_BET_CONFIG_PATH = os.path.join(APP_DIR, "HD_BET", "config.py")
HD_BET_MODEL_DIR = os.path.join(APP_DIR, "hdbet_model")

def load_config() -> dict:
    cfg_path = os.path.join(APP_DIR, "config.yml")
    if os.path.exists(cfg_path):
        with open(cfg_path, "r") as f:
            return yaml.safe_load(f)
    # Defaults
    return {
        "gpu": {"device": "cpu"},
        "infer": {
            "checkpoints": "./checkpoints/idh_model.ckpt",
            "simclr_checkpoint": "./checkpoints/simclr_vitb.ckpt",
            "threshold": 0.5,
            "image_size": [96, 96, 96],
        },
    }

def build_model(cfg: dict):
    device = torch.device(cfg.get("gpu", {}).get("device", "cpu"))
    infer_cfg = cfg.get("infer", {})
    simclr_path = os.path.join(APP_DIR, infer_cfg.get("simclr_checkpoint", ""))
    ckpt_path = os.path.join(APP_DIR, infer_cfg.get("checkpoints", ""))

    backbone = ViTBackboneNet(simclr_ckpt_path=simclr_path)
    classifier = Classifier(d_model=768, num_classes=1)
    model = SingleScanModelBP(backbone, classifier)

    # Load finetuned checkpoint (Lightning or plain state_dict)
    if os.path.exists(ckpt_path):
        checkpoint = torch.load(ckpt_path, map_location="cpu", weights_only=False)
        if "state_dict" in checkpoint:
            state_dict = checkpoint["state_dict"]
            new_state_dict = {}
            for key, value in state_dict.items():
                if key.startswith("model."):
                    new_state_dict[key[len("model."):]] = value
                else:
                    new_state_dict[key] = value
        else:
            new_state_dict = checkpoint
        model.load_state_dict(new_state_dict, strict=False)
    else:
        print(f"Warning: Finetuned checkpoint not found at {ckpt_path}. Model will use backbone-only weights.")

    model.to(device)
    model.eval()
    return model, device
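
# Typical usage, as done in build_interface below:
#   cfg = load_config()
#   model, device = build_model(cfg)
# The finetuned checkpoint may be a PyTorch Lightning checkpoint (keys prefixed
# with "model."), which is why that prefix is stripped before load_state_dict.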

# ---------------- Preprocessing (Registration + Enhancement + Skull Stripping) ----------------

def bias_field_correction(img_array: np.ndarray) -> np.ndarray:
    image = sitk.GetImageFromArray(img_array.astype(np.float32))
    if image.GetPixelID() != sitk.sitkFloat32:
        image = sitk.Cast(image, sitk.sitkFloat32)
    maskImage = sitk.OtsuThreshold(image, 0, 1, 200)
    corrector = sitk.N4BiasFieldCorrectionImageFilter()
    numberFittingLevels = 4
    max_iters = [min(50 * (2 ** i), 200) for i in range(numberFittingLevels)]
    corrector.SetMaximumNumberOfIterations(max_iters)
    corrected_image = corrector.Execute(image, maskImage)
    return sitk.GetArrayFromImage(corrected_image)

def denoise(volume: np.ndarray, kernel_size: int = 3) -> np.ndarray:
    return medfilt(volume, kernel_size)
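
# Note: scipy.signal.medfilt with a scalar kernel_size applies a window of that
# size along every axis, so kernel_size=3 on a 3D volume is a 3x3x3 median filter.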

def rescale_intensity(volume: np.ndarray, percentils=[0.5, 99.5], bins_num=256) -> np.ndarray:
    volume_float = volume.astype(np.float32)
    try:
        t = skimage.filters.threshold_otsu(volume_float, nbins=256)
        volume_masked = np.copy(volume_float)
        volume_masked[volume_masked < t] = 0
        obj_volume = volume_masked[np.where(volume_masked > 0)]
    except ValueError:
        obj_volume = volume_float.flatten()

    if obj_volume.size == 0:
        obj_volume = volume_float.flatten()
        min_value = np.min(obj_volume)
        max_value = np.max(obj_volume)
    else:
        min_value = np.percentile(obj_volume, percentils[0])
        max_value = np.percentile(obj_volume, percentils[1])

    denom = max_value - min_value
    if denom < 1e-6:
        denom = 1e-6

    if bins_num == 0:
        output_volume = (volume_float - min_value) / denom
        output_volume = np.clip(output_volume, 0.0, 1.0)
    else:
        output_volume = np.round((volume_float - min_value) / denom * (bins_num - 1))
        output_volume = np.clip(output_volume, 0, bins_num - 1)

    return output_volume.astype(np.float32)
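
# Worked example: with the default bins_num=256, a voxel at the 0.5th-percentile
# foreground intensity maps to about 0 and one at the 99.5th percentile maps to
# about 255, i.e. intensities are rescaled onto the 0..(bins_num-1) range; with
# bins_num=0 the output is instead clipped to [0, 1].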

def equalize_hist(volume: np.ndarray, bins_num=256) -> np.ndarray:
    mask = volume > 1e-6
    obj_volume = volume[mask]
    if obj_volume.size == 0:
        return volume
    hist, bins = np.histogram(obj_volume, bins_num, range=(obj_volume.min(), obj_volume.max()))
    cdf = hist.cumsum()
    cdf_normalized = (bins_num - 1) * cdf / float(cdf[-1])
    equalized_obj_volume = np.interp(obj_volume, bins[:-1], cdf_normalized)
    equalized_volume = np.copy(volume)
    equalized_volume[mask] = equalized_obj_volume
    return equalized_volume.astype(np.float32)

def run_enhance_on_file(input_nifti_path: str, output_nifti_path: str):
    """
    Simplified enhancement - just copy the file, since N4 bias correction is now
    applied during registration. This keeps the existing preprocessing pipeline
    interface intact.
    """
    print(f"Enhancement step (N4 already applied during registration): {input_nifti_path}")
    # N4 bias correction is handled in registration, so this is a passthrough copy
    shutil.copy2(input_nifti_path, output_nifti_path)
    print(f"Enhancement complete (passthrough): {output_nifti_path}")

def register_image_sitk(input_nifti_path: str, output_nifti_path: str, template_path: str, interp_type='linear'):
    """
    MRI registration with SimpleITK, matching the provided script's approach.

    Args:
        input_nifti_path: Path to the input NIfTI file
        output_nifti_path: Path to save the registered output
        template_path: Path to the template image
        interp_type: Interpolation type ('linear', 'bspline', 'nearest_neighbor')
    """
    print(f"Registering {input_nifti_path} to template {template_path}")

    # Read template (fixed) and moving images
    fixed_img = sitk.ReadImage(template_path, sitk.sitkFloat32)
    moving_img = sitk.ReadImage(input_nifti_path, sitk.sitkFloat32)

    # Apply N4 bias correction to the moving image
    moving_img = sitk.N4BiasFieldCorrection(moving_img)

    # Resample the fixed image to 1 mm isotropic spacing
    old_size = fixed_img.GetSize()
    old_spacing = fixed_img.GetSpacing()
    new_spacing = (1, 1, 1)
    new_size = [
        int(round((old_size[0] * old_spacing[0]) / float(new_spacing[0]))),
        int(round((old_size[1] * old_spacing[1]) / float(new_spacing[1]))),
        int(round((old_size[2] * old_spacing[2]) / float(new_spacing[2])))
    ]

    # Map the interpolation name to a SimpleITK interpolator
    if interp_type == 'linear':
        interp_type = sitk.sitkLinear
    elif interp_type == 'bspline':
        interp_type = sitk.sitkBSpline
    elif interp_type == 'nearest_neighbor':
        interp_type = sitk.sitkNearestNeighbor
    else:
        interp_type = sitk.sitkLinear

    # Resample the fixed image
    resample = sitk.ResampleImageFilter()
    resample.SetOutputSpacing(new_spacing)
    resample.SetSize(new_size)
    resample.SetOutputOrigin(fixed_img.GetOrigin())
    resample.SetOutputDirection(fixed_img.GetDirection())
    resample.SetInterpolator(interp_type)
    resample.SetDefaultPixelValue(fixed_img.GetPixelIDValue())
    resample.SetOutputPixelType(sitk.sitkFloat32)
    fixed_img = resample.Execute(fixed_img)

    # Initialize the transform
    transform = sitk.CenteredTransformInitializer(
        fixed_img,
        moving_img,
        sitk.Euler3DTransform(),
        sitk.CenteredTransformInitializerFilter.GEOMETRY)

    # Set up the registration method
    registration_method = sitk.ImageRegistrationMethod()
    registration_method.SetMetricAsMattesMutualInformation(numberOfHistogramBins=50)
    registration_method.SetMetricSamplingStrategy(registration_method.RANDOM)
    registration_method.SetMetricSamplingPercentage(0.01)
    registration_method.SetInterpolator(sitk.sitkLinear)
    registration_method.SetOptimizerAsGradientDescent(
        learningRate=1.0,
        numberOfIterations=100,
        convergenceMinimumValue=1e-6,
        convergenceWindowSize=10)
    registration_method.SetOptimizerScalesFromPhysicalShift()
    registration_method.SetShrinkFactorsPerLevel(shrinkFactors=[4, 2, 1])
    registration_method.SetSmoothingSigmasPerLevel(smoothingSigmas=[2, 1, 0])
    registration_method.SmoothingSigmasAreSpecifiedInPhysicalUnitsOn()
    registration_method.SetInitialTransform(transform)

    # Execute the registration
    final_transform = registration_method.Execute(fixed_img, moving_img)

    # Apply the transform and save the registered image
    moving_img_resampled = sitk.Resample(
        moving_img,
        fixed_img,
        final_transform,
        sitk.sitkLinear,
        0.0,
        moving_img.GetPixelID())
    sitk.WriteImage(moving_img_resampled, output_nifti_path)
    print(f"Registration complete. Saved to: {output_nifti_path}")

def register_image(input_nifti_path: str, output_nifti_path: str):
    """Wrapper kept for compatibility - now uses SimpleITK registration with the default template."""
    if not os.path.exists(DEFAULT_TEMPLATE_PATH):
        raise FileNotFoundError(f"Template file missing: {DEFAULT_TEMPLATE_PATH}")
    register_image_sitk(input_nifti_path, output_nifti_path, DEFAULT_TEMPLATE_PATH)

def run_skull_stripping(input_nifti_path: str, output_dir: str):
    """
    Brain extraction using direct HD-BET integration, matching the script approach.

    Args:
        input_nifti_path: Path to input NIfTI file
        output_dir: Directory to save skull-stripped output

    Returns:
        tuple: (output_file_path, output_mask_path)
    """
    print(f"Running HD-BET skull stripping on {input_nifti_path}")
    if hd_bet is None:
        raise RuntimeError("HD-BET not available. Please include HD_BET and hdbet_model in src/IDH.")
    if not os.path.exists(HD_BET_MODEL_DIR):
        raise FileNotFoundError(f"HD-BET models not found at {HD_BET_MODEL_DIR}")
    os.makedirs(output_dir, exist_ok=True)

    # Get the base filename and prepare HD-BET compatible naming
    base_name = os.path.basename(input_nifti_path).replace('.nii.gz', '').replace('.nii', '')

    # HD-BET expects files with an _0000 suffix - stage a temporary copy
    temp_input_dir = os.path.join(output_dir, "temp_input")
    os.makedirs(temp_input_dir, exist_ok=True)

    # Copy the input file with the _0000 suffix for HD-BET
    temp_input_path = os.path.join(temp_input_dir, f"{base_name}_0000.nii.gz")
    shutil.copy2(input_nifti_path, temp_input_path)

    # Set device
    device = "0" if torch.cuda.is_available() else "cpu"

    try:
        # Check that the bundled model file is present
        model_file = os.path.join(HD_BET_MODEL_DIR, '0.model')
        if os.path.exists(model_file):
            print(f"Local model file exists at: {model_file}")
        else:
            print(f"Warning: Model file not found at: {model_file}")
            # List directory contents for debugging
            if os.path.exists(HD_BET_MODEL_DIR):
                print(f"Contents of {HD_BET_MODEL_DIR}: {os.listdir(HD_BET_MODEL_DIR)}")
            else:
                print(f"Directory {HD_BET_MODEL_DIR} does not exist")

        # Run HD-BET directly on the temporary directory
        print(f"Running hd_bet with input_dir: {temp_input_dir}, output_dir: {output_dir}")
        hd_bet(temp_input_dir, output_dir, device=device, mode='fast', tta=0)

        # HD-BET writes outputs using the staged naming convention
        output_file_path = os.path.join(output_dir, f"{base_name}_0000.nii.gz")
        output_mask_path = os.path.join(output_dir, f"{base_name}_0000_mask.nii.gz")

        # Rename to the expected format for compatibility
        final_output_path = os.path.join(output_dir, f"{base_name}_bet.nii.gz")
        final_mask_path = os.path.join(output_dir, f"{base_name}_bet_mask.nii.gz")
        if os.path.exists(output_file_path):
            shutil.move(output_file_path, final_output_path)
        if os.path.exists(output_mask_path):
            shutil.move(output_mask_path, final_mask_path)

        # Clean up the temporary directory
        shutil.rmtree(temp_input_dir, ignore_errors=True)

        if not os.path.exists(final_output_path):
            raise RuntimeError(f"HD-BET did not produce output file: {final_output_path}")

        print(f"Skull stripping complete. Output saved to: {final_output_path}")
        return final_output_path, final_mask_path
    except Exception as e:
        # Clean up on error
        shutil.rmtree(temp_input_dir, ignore_errors=True)
        raise RuntimeError(f"HD-BET skull stripping failed: {str(e)}")

# ---------------- Saliency Generation ----------------

def extract_attention_map(vit_model, image, layer_idx=-1, img_size=(96, 96, 96), patch_size=16):
    """
    Extracts an attention map from a Vision Transformer (ViT) model.

    This function wraps the attention blocks of the ViT to capture the attention
    weights during a forward pass, then processes those weights into a 3D saliency
    map reflecting where the model focuses on the input image.
    """
    attention_maps = {}
    original_attns = {}

    # A wrapper module that intercepts and stores attention weights from a ViT block.
    class AttentionWithWeights(torch.nn.Module):
        def __init__(self, original_attn_module):
            super().__init__()
            self.original_attn_module = original_attn_module
            self.attn_weights = None

        def forward(self, x):
            # The original attention module may not return the attention weights,
            # so this wrapper recomputes them from the fused qkv projection,
            # following the standard ViT attention mechanism.
            output = self.original_attn_module(x)
            if hasattr(self.original_attn_module, 'qkv'):
                qkv = self.original_attn_module.qkv(x)
                batch_size, seq_len, _ = x.shape
                # qkv is fused with shape (batch_size, seq_len, 3 * num_heads * head_dim)
                qkv = qkv.reshape(batch_size, seq_len, 3, self.original_attn_module.num_heads, -1)
                qkv = qkv.permute(2, 0, 3, 1, 4)
                q, k, v = qkv[0], qkv[1], qkv[2]
                attn = (q @ k.transpose(-2, -1)) * self.original_attn_module.scale
                self.attn_weights = attn.softmax(dim=-1)
            return output

    # Store the original attention modules and replace them with wrappers
    for i, block in enumerate(vit_model.blocks):
        if hasattr(block, 'attn'):
            original_attns[i] = block.attn
            block.attn = AttentionWithWeights(block.attn)

    try:
        # Run a forward pass to execute the wrapped modules and capture weights
        with torch.no_grad():
            _ = vit_model(image)

        # Collect the captured attention weights from each block
        for i, block in enumerate(vit_model.blocks):
            if hasattr(block.attn, 'attn_weights') and block.attn.attn_weights is not None:
                attention_maps[f"layer_{i}"] = block.attn.attn_weights.detach()
    finally:
        # Restore the original attention modules
        for i, original_attn in original_attns.items():
            vit_model.blocks[i].attn = original_attn

    if not attention_maps:
        raise RuntimeError("Could not extract any attention maps. Please check the ViT model structure.")

    # Select the attention map from the specified layer
    if layer_idx < 0:
        layer_idx = len(attention_maps) + layer_idx
    layer_name = f"layer_{layer_idx}"
    if layer_name not in attention_maps:
        raise ValueError(f"Layer {layer_idx} not found. Available layers: {list(attention_maps.keys())}")
    layer_attn = attention_maps[layer_name]

    # Average attention across all heads
    head_attn = layer_attn[0].mean(dim=0)

    # Attention from the [CLS] token to all image patches
    cls_attn = head_attn[0, 1:]

    # Reshape the 1D attention vector into a 3D volume of patches
    patches_per_dim = img_size[0] // patch_size
    total_patches = patches_per_dim ** 3

    # Pad or truncate if the number of patches doesn't align
    if cls_attn.shape[0] != total_patches:
        if cls_attn.shape[0] > total_patches:
            cls_attn = cls_attn[:total_patches]
        else:
            padded = torch.zeros(total_patches, device=cls_attn.device)
            padded[:cls_attn.shape[0]] = cls_attn
            cls_attn = padded
    cls_attn_3d = cls_attn.reshape(patches_per_dim, patches_per_dim, patches_per_dim)
    cls_attn_3d = cls_attn_3d.unsqueeze(0).unsqueeze(0)  # Add batch and channel dims

    # Upsample the attention map to the full image resolution
    upsampled_attn = torch.nn.functional.interpolate(
        cls_attn_3d,
        size=img_size,
        mode='trilinear',
        align_corners=False
    ).squeeze()

    # Normalize the map to [0, 1] for visualization (epsilon guards a constant map)
    upsampled_attn = upsampled_attn.cpu().numpy()
    upsampled_attn = (upsampled_attn - upsampled_attn.min()) / max(upsampled_attn.max() - upsampled_attn.min(), 1e-8)
    return upsampled_attn
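
# With the defaults used here (img_size 96^3, patch_size 16), the ViT sees
# 96 // 16 = 6 patches per axis, i.e. 6**3 = 216 patches, so the [CLS] attention
# vector has 216 entries that are reshaped to a 6x6x6 grid and trilinearly
# upsampled back to 96x96x96 for overlay on the input volume.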

def generate_saliency_dual(model, input_tensor, layer_idx=-1):
    """
    Generate saliency maps for the dual-input IDH model.

    Args:
        model: The complete IDH model
        input_tensor: Dual input tensor (batch_size, 2, C, D, H, W)
        layer_idx: ViT layer to visualize

    Returns:
        tuple: (flair_input_3d, t1c_input_3d, flair_saliency_3d)
    """
    print("Generating saliency maps for dual input...")
    try:
        # Split the dual input into its two modalities
        # input_tensor shape: [batch_size, 2, C, D, H, W]
        flair_tensor = input_tensor[:, 0]  # [batch, C, D, H, W]
        t1c_tensor = input_tensor[:, 1]    # [batch, C, D, H, W]

        # Get the ViT backbone
        vit_model = model.backbone.backbone

        # Generate an attention map only for FLAIR
        flair_attn = extract_attention_map(vit_model, flair_tensor, layer_idx)

        # Convert the input tensors to numpy for visualization
        flair_input_3d = flair_tensor.squeeze().cpu().detach().numpy()
        t1c_input_3d = t1c_tensor.squeeze().cpu().detach().numpy()

        print("Saliency maps generated successfully.")
        return flair_input_3d, t1c_input_3d, flair_attn
    except Exception as e:
        print(f"Error during saliency generation: {e}")
        traceback.print_exc()
        return None, None, None

# ---------------- Visualization Functions ----------------

def create_slice_plots_dual(flair_data_3d, t1c_data_3d, flair_saliency_3d, slice_index):
    """Create slice plots for the simplified dual-input visualization: T1c, FLAIR, and FLAIR attention."""
    print(f"Generating plots for slice index: {slice_index}")
    if any(data is None for data in [flair_data_3d, t1c_data_3d, flair_saliency_3d]):
        return None, None, None

    # Check bounds - axis 2 is used for axial slices
    if not (0 <= slice_index < flair_data_3d.shape[2]):
        print(f"Error: Slice index {slice_index} out of bounds (0-{flair_data_3d.shape[2]-1}).")
        return None, None, None

    def save_plot_to_numpy(fig):
        with io.BytesIO() as buf:
            fig.savefig(buf, format='png', bbox_inches='tight', pad_inches=0, dpi=75)
            plt.close(fig)
            buf.seek(0)
            img_arr = plt.imread(buf, format='png')
        return (img_arr * 255).astype(np.uint8)

    try:
        # Extract axial slices along axis 2 (last dimension)
        flair_slice = flair_data_3d[:, :, slice_index]
        t1c_slice = t1c_data_3d[:, :, slice_index]
        flair_saliency_slice = flair_saliency_3d[:, :, slice_index]

        # Normalize input slices against the whole volume's 1st-99th percentiles
        def normalize_slice(slice_data, volume_data):
            p1, p99 = np.percentile(volume_data, (1, 99))
            denom = max(p99 - p1, 1e-6)
            return np.clip((slice_data - p1) / denom, 0, 1)

        flair_slice_norm = normalize_slice(flair_slice, flair_data_3d)
        t1c_slice_norm = normalize_slice(t1c_slice, t1c_data_3d)

        # Process the saliency slice: clip negatives, blur, normalize by volume max
        def process_saliency_slice(saliency_slice, saliency_volume):
            saliency_slice = np.copy(saliency_slice)
            saliency_slice[saliency_slice < 0] = 0
            saliency_slice_blurred = cv2.GaussianBlur(saliency_slice, (15, 15), 0)
            s_max = max(np.max(saliency_volume[saliency_volume >= 0]), 1e-6)
            saliency_slice_norm = saliency_slice_blurred / s_max
            return np.where(saliency_slice_norm > 0.0, saliency_slice_norm, 0)

        flair_sal_processed = process_saliency_slice(flair_saliency_slice, flair_saliency_3d)

        # Create plots
        plots = []

        # T1c input
        fig1, ax1 = plt.subplots(figsize=(6, 6))
        ax1.imshow(t1c_slice_norm, cmap='gray', interpolation='none', origin='lower')
        ax1.axis('off')
        ax1.set_title('T1c Input', fontsize=14, color='white', pad=10)
        plots.append(save_plot_to_numpy(fig1))

        # FLAIR input
        fig2, ax2 = plt.subplots(figsize=(6, 6))
        ax2.imshow(flair_slice_norm, cmap='gray', interpolation='none', origin='lower')
        ax2.axis('off')
        ax2.set_title('FLAIR Input', fontsize=14, color='white', pad=10)
        plots.append(save_plot_to_numpy(fig2))

        # FLAIR attention
        fig3, ax3 = plt.subplots(figsize=(6, 6))
        ax3.imshow(flair_sal_processed, cmap='magma', interpolation='none', origin='lower', vmin=0)
        ax3.axis('off')
        ax3.set_title('FLAIR Attention', fontsize=14, color='white', pad=10)
        plots.append(save_plot_to_numpy(fig3))

        print(f"Generated 3 plots successfully for axial slice {slice_index}.")
        return tuple(plots)
    except Exception as e:
        print(f"Error generating plots for slice {slice_index}: {e}")
        traceback.print_exc()
        return tuple([None] * 3)

# ---------------- Inference ----------------

def get_dual_validation_transform(image_size: Tuple[int, int, int]):
    return Compose([
        LoadImaged(keys=["image1", "image2"]),
        EnsureChannelFirstd(keys=["image1", "image2"]),
        Resized(keys=["image1", "image2"], spatial_size=tuple(image_size), mode="trilinear"),
        NormalizeIntensityd(keys=["image1", "image2"], nonzero=True, channel_wise=True),
        ToTensord(keys=["image1", "image2"]),
    ])


def preprocess_dual_nifti(flair_path: str, t1c_path: str, image_size: Tuple[int, int, int], device: torch.device) -> torch.Tensor:
    transform = get_dual_validation_transform(image_size)
    sample = {"image1": flair_path, "image2": t1c_path}
    sample = transform(sample)
    img1 = sample["image1"]  # (C, D, H, W)
    img2 = sample["image2"]  # (C, D, H, W)
    images = torch.stack([img1, img2], dim=0).unsqueeze(0).to(device)
    return images
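
# Shape sketch (assuming single-channel NIfTI inputs and the default 96^3 size):
# each transformed image is (1, 96, 96, 96); stacking FLAIR and T1c and adding a
# batch dimension gives the (1, 2, 1, 96, 96, 96) tensor expected by the model
# and unpacked again in generate_saliency_dual.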

def predict_idh(flair_file, t1c_file, threshold: float, do_preprocess: bool, generate_saliency: bool, cfg: dict, model, device):
    try:
        if flair_file is None or t1c_file is None:
            return {"error": "Please upload both FLAIR and T1c NIfTI files (.nii.gz)."}, None, None, None, gr.Slider(visible=False), {"input_paths": None, "saliency_paths": None, "num_slices": 0}

        flair_path = flair_file.name if hasattr(flair_file, 'name') else flair_file
        t1c_path = t1c_file.name if hasattr(t1c_file, 'name') else t1c_file
        if not (flair_path.endswith(".nii") or flair_path.endswith(".nii.gz")):
            return {"error": "FLAIR must be a NIfTI file (.nii or .nii.gz)."}, None, None, None, gr.Slider(visible=False), {"input_paths": None, "saliency_paths": None, "num_slices": 0}
        if not (t1c_path.endswith(".nii") or t1c_path.endswith(".nii.gz")):
            return {"error": "T1c must be a NIfTI file (.nii or .nii.gz)."}, None, None, None, gr.Slider(visible=False), {"input_paths": None, "saliency_paths": None, "num_slices": 0}

        work_dir = tempfile.mkdtemp()
        flair_final_path, t1c_final_path = flair_path, t1c_path
        try:
            # Optional preprocessing pipeline
            if do_preprocess:
                # Registration (use modality-specific templates)
                flair_reg = os.path.join(work_dir, "flair_registered.nii.gz")
                t1c_reg = os.path.join(work_dir, "t1c_registered.nii.gz")
                register_image_sitk(flair_path, flair_reg, FLAIR_TEMPLATE_PATH)
                register_image_sitk(t1c_path, t1c_reg, T1C_TEMPLATE_PATH)

                # Enhancement
                flair_enh = os.path.join(work_dir, "flair_enhanced.nii.gz")
                t1c_enh = os.path.join(work_dir, "t1c_enhanced.nii.gz")
                run_enhance_on_file(flair_reg, flair_enh)
                run_enhance_on_file(t1c_reg, t1c_enh)

                # Skull stripping
                skullstrip_dir = os.path.join(work_dir, "skullstripped")
                flair_bet, _ = run_skull_stripping(flair_enh, skullstrip_dir)
                t1c_bet, _ = run_skull_stripping(t1c_enh, skullstrip_dir)
                flair_final_path, t1c_final_path = flair_bet, t1c_bet

            # Prediction
            image_size = cfg.get("infer", {}).get("image_size", [96, 96, 96])
            input_tensor = preprocess_dual_nifti(flair_final_path, t1c_final_path, image_size, device)
            with torch.no_grad():
                logits = model(input_tensor)
                prob = torch.sigmoid(logits).cpu().numpy().flatten()[0].item()
            predicted_class = int(prob >= threshold)

            prediction_result = {
                "IDH_mutant_probability": float(prob),
                "threshold": float(threshold),
                "predicted_class": int(predicted_class),
                "preprocessing": bool(do_preprocess),
                "class_label": "IDH-mutant" if predicted_class == 1 else "IDH-wildtype"
            }

            # Initialize saliency outputs
            t1c_input_img = flair_input_img = flair_attn_img = None
            slider_update = gr.Slider(visible=False)
            saliency_state = {"input_paths": None, "saliency_paths": None, "num_slices": 0}

            # Generate saliency maps if requested
            if generate_saliency:
                print("--- Generating Saliency Maps ---")
                try:
                    flair_input_3d, t1c_input_3d, flair_saliency_3d = generate_saliency_dual(model, input_tensor, layer_idx=-1)
                    if all(data is not None for data in [flair_input_3d, t1c_input_3d, flair_saliency_3d]):
                        num_slices = flair_input_3d.shape[2]  # Axis 2 holds the axial slices
                        center_slice_index = num_slices // 2

                        # Save numpy arrays for the slider callback
                        unique_id = str(uuid.uuid4())
                        temp_paths = []
                        for name, data in [("flair_input", flair_input_3d), ("t1c_input", t1c_input_3d),
                                           ("flair_saliency", flair_saliency_3d)]:
                            path = os.path.join(work_dir, f"{unique_id}_{name}.npy")
                            np.save(path, data)
                            temp_paths.append(path)

                        # Generate initial plots for the center slice
                        plots = create_slice_plots_dual(flair_input_3d, t1c_input_3d, flair_saliency_3d, center_slice_index)
                        if plots and all(p is not None for p in plots):
                            t1c_input_img, flair_input_img, flair_attn_img = plots

                        # Update state and slider
                        saliency_state = {
                            "input_paths": temp_paths[:2],     # [flair_input, t1c_input]
                            "saliency_paths": temp_paths[2:],  # [flair_saliency]
                            "num_slices": num_slices
                        }
                        slider_update = gr.Slider(value=center_slice_index, minimum=0, maximum=num_slices-1, step=1, label="Select Slice", visible=True)
                        print("--- Saliency Generation Complete ---")
                    else:
                        print("Warning: Saliency generation failed - some outputs were None")
                except Exception as e:
                    print(f"Error during saliency generation: {e}")
                    traceback.print_exc()

            return (prediction_result, t1c_input_img, flair_input_img, flair_attn_img, slider_update, saliency_state)
        except Exception as e:
            shutil.rmtree(work_dir, ignore_errors=True)
            return {"error": f"Processing failed: {str(e)}"}, None, None, None, gr.Slider(visible=False), {"input_paths": None, "saliency_paths": None, "num_slices": 0}
    except Exception as e:
        return {"error": str(e)}, None, None, None, gr.Slider(visible=False), {"input_paths": None, "saliency_paths": None, "num_slices": 0}

def update_slice_viewer_dual(slice_index, current_state):
    """Update the slice viewer for the dual-input saliency visualization."""
    input_paths = current_state.get("input_paths", [])
    saliency_paths = current_state.get("saliency_paths", [])
    if not input_paths or not saliency_paths or len(input_paths) != 2 or len(saliency_paths) != 1:
        print(f"Warning: Invalid state for slice viewer update: {current_state}")
        return None, None, None
    try:
        # Load the cached numpy arrays
        flair_input_3d = np.load(input_paths[0])
        t1c_input_3d = np.load(input_paths[1])
        flair_saliency_3d = np.load(saliency_paths[0])

        # Validate the slice index (axis 2 holds the axial slices)
        slice_index = int(slice_index)
        if not (0 <= slice_index < flair_input_3d.shape[2]):
            print(f"Warning: Invalid slice index {slice_index}")
            return None, None, None

        # Generate new plots
        plots = create_slice_plots_dual(flair_input_3d, t1c_input_3d, flair_saliency_3d, slice_index)
        return plots if plots else tuple([None] * 3)
    except Exception as e:
        print(f"Error updating slice viewer for index {slice_index}: {e}")
        traceback.print_exc()
        return tuple([None] * 3)
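
# The slider callback relies on the gr.State payload written in predict_idh:
#   {"input_paths": [<flair_input>.npy, <t1c_input>.npy],
#    "saliency_paths": [<flair_saliency>.npy],
#    "num_slices": <axial slice count>}
# The .npy files live in the temporary work_dir created during prediction, so they
# remain readable for as long as that directory persists.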

def build_interface():
    cfg = load_config()
    model, device = build_model(cfg)
    default_threshold = float(cfg.get("infer", {}).get("threshold", 0.5))

    with gr.Blocks(title="BrainIAC: IDH Classification", css="""
        #header-row {
            min-height: 150px;
            align-items: center;
        }
        .logo-img img {
            height: 150px;
            object-fit: contain;
        }
    """) as demo:
        # --- Header with logos ---
        with gr.Row(elem_id="header-row"):
            with gr.Column(scale=1):
                gr.Image(os.path.join(APP_DIR, "static/images/kannlab.png"),
                         show_label=False, interactive=False,
                         show_download_button=False,
                         container=False,
                         elem_classes=["logo-img"])
            with gr.Column(scale=3):
                gr.Markdown(
                    "<h1 style='text-align: center; margin-bottom: 2.5rem'>"
                    "BrainIAC: IDH Classification"
                    "</h1>"
                )
            with gr.Column(scale=1):
                gr.Image(os.path.join(APP_DIR, "static/images/brainiac.jpeg"),
                         show_label=False, interactive=False,
                         show_download_button=False,
                         container=False,
                         elem_classes=["logo-img"])

        # --- Model description section ---
        with gr.Accordion("ℹ️ Model Details and Usage Guide", open=False):
            gr.Markdown("""
            ### 🧠 BrainIAC: IDH Classification

            **Model Description**

            A Vision Transformer (ViT) model with BrainIAC as the pre-trained backbone, designed to predict IDH mutation status from dual MRI sequences (FLAIR + T1c).

            **Training Dataset**
            - **Subjects**: Trained on FLAIR and T1c MRI scans from glioma patients in the UCSF-PDGM dataset
            - **Imaging Modalities**: FLAIR and T1c (contrast-enhanced T1-weighted)
            - **Preprocessing**: N4 bias correction, MNI registration, and skull stripping (HD-BET)

            **Input**
            - Format: NIfTI (.nii or .nii.gz)
            - Required sequences: FLAIR and T1c (both required)
            - Image size: Automatically resized to 96×96×96 voxels

            **Output**
            - Binary classification: IDH-mutant or IDH-wildtype
            - Probability score for IDH mutation
            - Attention map visualization

            **Intended Use**
            - Research use only!

            **NOTE**
            - Requires both FLAIR and T1c sequences
            - Not validated on other MRI sequences
            - Not validated for brain pathologies other than gliomas
            - Upload PHI data at your own risk!
            - The model is hosted on a cloud-based CPU instance
            - The data is not stored, shared, or collected for any purpose!

            **Preprocessing Pipeline**

            When enabled, preprocessing performs:
            1. **Registration**: SimpleITK-based registration to template space with a mutual information metric and 1 mm isotropic resampling
            2. **N4 Bias Correction**: Applied during the registration step
            3. **Skull Stripping**: Removes non-brain tissue using direct HD-BET integration

            **Attention Maps**

            When enabled, generates ViT attention maps showing which brain regions the model focuses on for its prediction.
            """)

        # gr.State stores paths to the cached numpy arrays for the slider callback
        saliency_state = gr.State({"input_paths": None, "saliency_paths": None, "num_slices": 0})

        # Main content
        gr.Markdown("**Upload FLAIR and T1c NIfTI volumes** — Optional preprocessing performs registration to MNI, enhancement, and skull stripping.")
        with gr.Row():
            with gr.Column(scale=1):
                with gr.Group():
                    gr.Markdown("### Controls")
                    flair_input = gr.File(label="FLAIR (.nii or .nii.gz)")
                    t1c_input = gr.File(label="T1c (.nii or .nii.gz)")
                    preprocess_checkbox = gr.Checkbox(value=False, label="Preprocess NIfTI (registration + enhancement + skull stripping)")
                    generate_saliency_checkbox = gr.Checkbox(value=True, label="Generate Attention Maps")
                    threshold_input = gr.Slider(minimum=0.0, maximum=1.0, value=default_threshold, step=0.01, label="Decision Threshold")
                    predict_btn = gr.Button("Predict IDH Status", variant="primary")
            with gr.Column(scale=2):
                with gr.Group():
                    gr.Markdown("### Classification Result")
                    output_json = gr.JSON(label="Prediction")

                # Saliency visualization section
                with gr.Group():
                    gr.Markdown("### Attention Map Viewer (Axial Slice)")
                    slice_slider = gr.Slider(label="Select Slice", minimum=0, maximum=0, step=1, value=0, visible=False)
                    with gr.Row():
                        with gr.Column():
                            gr.Markdown("<p style='text-align: center;'>T1c Input</p>")
                            t1c_input_img = gr.Image(label="T1c Input", type="numpy", show_label=False)
                        with gr.Column():
                            gr.Markdown("<p style='text-align: center;'>FLAIR Input</p>")
                            flair_input_img = gr.Image(label="FLAIR Input", type="numpy", show_label=False)
                        with gr.Column():
                            gr.Markdown("<p style='text-align: center;'>FLAIR Attention</p>")
                            flair_attn_img = gr.Image(label="Attention Mask", type="numpy", show_label=False)

        # Wire up the components
        predict_btn.click(
            fn=lambda f, t, prep, gen_sal, thr: predict_idh(f, t, thr, prep, gen_sal, cfg, model, device),
            inputs=[flair_input, t1c_input, preprocess_checkbox, generate_saliency_checkbox, threshold_input],
            outputs=[output_json, t1c_input_img, flair_input_img, flair_attn_img, slice_slider, saliency_state],
        )
        slice_slider.change(
            fn=update_slice_viewer_dual,
            inputs=[slice_slider, saliency_state],
            outputs=[t1c_input_img, flair_input_img, flair_attn_img]
        )

    return demo

if __name__ == "__main__":
    iface = build_interface()
    iface.launch(server_name="0.0.0.0", server_port=7860)
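
# Local run sketch (assumes the checkpoints, MNI templates, and HD-BET weights
# referenced above are present next to this file, and that the file is saved as
# app.py):
#   python app.py
# then open http://localhost:7860. On a Hugging Face Space the same entry point
# is launched automatically on port 7860.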