"""Gradio app for FLAMeS multiple sclerosis lesion segmentation.

Downloads the trained nnU-Net model (Dataset004_WML) from the Hugging Face Hub,
runs ``nnUNetv2_predict`` on an uploaded FLAIR image, and shows orthogonal
preview slices of the input image and of the predicted lesion mask.
"""

import os
import shutil
import subprocess
import zipfile

import gradio as gr
import matplotlib.pyplot as plt
import nibabel as nib
import numpy as np
import spaces  # Import spaces for GPU decoration
import torch
from huggingface_hub import hf_hub_download
from scipy.ndimage import center_of_mass, zoom

# Define paths
MODEL_DIR = "./model"  # Local directory to store the downloaded model
DATASET_DIR = os.path.join(MODEL_DIR, "Dataset004_WML")  # Directory for Dataset004_WML
INPUT_DIR = "/tmp/input"
OUTPUT_DIR = "/tmp/output"

# Hugging Face Model Repository
REPO_ID = "FrancescoLR/FLAMeS-model"  # Replace with your actual model repository ID

NIFTI_GZ_SUFFIX = ".nii.gz"
SLICE_SIZE = 180  # Edge length (in resampled voxels) of each preview plane


def download_model():
    """Download Dataset004_WML.zip from REPO_ID and extract it into MODEL_DIR.

    Skipped when DATASET_DIR already exists. DATASET_DIR is not created up
    front, so an interrupted download is retried on the next start instead of
    leaving an empty folder that looks like a finished download.
    NOTE(review): assumes the archive's top-level folder is Dataset004_WML,
    i.e. extraction is what creates DATASET_DIR — confirm against the archive.
    """
    if os.path.exists(DATASET_DIR):
        return
    os.makedirs(MODEL_DIR, exist_ok=True)
    print("Downloading Dataset004_WML.zip...")
    zip_path = hf_hub_download(repo_id=REPO_ID, filename="Dataset004_WML.zip", cache_dir=MODEL_DIR)
    # zipfile raises on a corrupt archive, unlike the unchecked `unzip` subprocess.
    with zipfile.ZipFile(zip_path) as archive:
        archive.extractall(MODEL_DIR)
    print("Dataset004_WML downloaded and extracted.")


def resample_to_isotropic(data, affine, target_spacing=1.0):
    """
    Resample a 3D image to isotropic voxel size.

    Parameters:
        data (numpy.ndarray): The input 3D image data.
        affine (numpy.ndarray): The 4x4 voxel-to-world affine of ``data``.
        target_spacing (float): Desired isotropic voxel spacing (in mm).

    Returns:
        resampled_data (numpy.ndarray): Resampled image data (linear interpolation).
        resampled_affine (numpy.ndarray): Affine with each voxel-axis column
            rescaled to the new spacing (up to rounding of the output shape);
            the translation column is unchanged.
    """
    # Column j of the affine maps array axis j to world space, so its norm is
    # the voxel size along that axis.
    current_spacing = np.sqrt((affine[:3, :3] ** 2).sum(axis=0))
    scaling_factors = current_spacing / target_spacing

    resampled_data = zoom(data, zoom=scaling_factors, order=1)  # Linear interpolation

    # Divide each *column* by its axis' zoom factor (broadcast over rows).
    resampled_affine = np.array(affine, dtype=float)
    resampled_affine[:3, :3] = resampled_affine[:3, :3] / scaling_factors[np.newaxis, :]
    return resampled_data, resampled_affine


def _load_isotropic(nifti_path, target_spacing=1.0):
    """Load a NIfTI file and return its voxel data resampled to isotropic spacing."""
    img = nib.load(nifti_path)
    resampled_data, _ = resample_to_isotropic(img.get_fdata(), img.affine, target_spacing=target_spacing)
    return resampled_data


def _nonzero_center(data):
    """Return the rounded integer center of mass of the non-zero voxels of ``data``.

    Falls back to the geometric center of the array when no voxel is non-zero,
    because center_of_mass would return NaN, which cannot be used as an index.
    """
    mask = data > 0
    if not mask.any():
        return np.array(data.shape) // 2
    return np.round(center_of_mass(mask)).astype(int)


def _extract_2d_slice(data, center, axis, slice_size):
    """Return a ``slice_size`` x ``slice_size`` plane of 3D ``data`` through ``center``, normal to ``axis``.

    The crop is placed so that ``center`` lands at index ``slice_size // 2`` in
    both in-plane dimensions; parts of the window outside the volume are
    zero-padded on the side where they fall, keeping the center in place.
    """
    half_size = slice_size // 2
    index = [slice(None)] * 3
    # Clamp so an out-of-range center still yields a valid plane.
    index[axis] = int(np.clip(center[axis], 0, data.shape[axis] - 1))
    plane = data[tuple(index)]

    crop, pad = [], []
    remaining_axes = [i for i in range(3) if i != axis]
    for dim, ax in enumerate(remaining_axes):
        length = plane.shape[dim]
        start = int(center[ax]) - half_size  # Window start; may lie outside the volume
        lo = min(max(start, 0), length)
        hi = max(min(start + slice_size, length), lo)
        before = min(max(-start, 0), slice_size)  # Window cells before index 0
        after = slice_size - before - (hi - lo)
        crop.append(slice(lo, hi))
        pad.append((before, after))

    return np.pad(plane[tuple(crop)], pad, mode="constant", constant_values=0)


def _save_orthogonal_slices(data, output_image_path, slice_size, center):
    """Save three orthogonal planes of ``data`` through ``center`` side by side as one PNG."""
    # Array axes 2, 1, 0 are displayed as axial, coronal and sagittal.
    # NOTE(review): this labelling assumes an RAS-like voxel order — confirm for other orientations.
    planes = [_extract_2d_slice(data, center, axis, slice_size) for axis in (2, 1, 0)]

    fig, axes = plt.subplots(1, 3, figsize=(12, 4))
    try:
        for ax, plane in zip(axes, planes):
            # Every view gets the same 90-degree clockwise turn (k=-1).
            ax.imshow(np.rot90(plane, k=-1), cmap="gray", origin="lower")
            ax.axis("off")
        fig.tight_layout()
        fig.savefig(output_image_path, bbox_inches="tight", pad_inches=0)
    finally:
        # Always release the figure, even if saving fails.
        plt.close(fig)


def extract_middle_slices(nifti_path, output_image_path, slice_size=SLICE_SIZE, center=None):
    """
    Save axial, coronal and sagittal slices of a 3D NIfTI image as a single PNG.

    The image is first resampled to 1 mm isotropic voxels. If ``center`` is
    provided (in resampled voxel coordinates) it is used; otherwise the center
    of mass of the non-zero voxels is computed. Each plane is a
    ``slice_size`` x ``slice_size`` window centered on that point.
    """
    data = _load_isotropic(nifti_path, target_spacing=1.0)
    if center is None:
        center = _nonzero_center(data)
    _save_orthogonal_slices(data, output_image_path, slice_size, center)


# Function to run nnUNet inference
@spaces.GPU  # Decorate the function to allocate GPU for its execution
def run_nnunet_predict(nifti_file):
    """Segment MS lesions in an uploaded FLAIR image with nnU-Net.

    Parameters:
        nifti_file: The ``gr.File`` input value. Depending on the Gradio
            version this is a path string or a tempfile-like object with a
            ``.name`` attribute.

    Returns:
        tuple[str, str, str]: Paths to the renamed lesion mask, the input
        preview PNG and the mask preview PNG (the three Gradio outputs).

    Raises:
        gr.Error: When no valid .nii.gz file was uploaded, nnU-Net fails, or no
            output mask is produced. Gradio shows the message to the user
            instead of failing on a wrong number of return values.
    """
    if nifti_file is None:
        raise gr.Error("Please upload a FLAIR image (.nii.gz).")
    if isinstance(nifti_file, (str, os.PathLike)):
        uploaded_path = os.fspath(nifti_file)
    else:
        uploaded_path = nifti_file.name

    # Keep the original filename (without extension) to name the outputs.
    original_filename = os.path.basename(uploaded_path)
    if not original_filename.lower().endswith(NIFTI_GZ_SUFFIX):
        raise gr.Error("Please upload a gzipped NIfTI file (.nii.gz).")
    base_filename = original_filename[: -len(NIFTI_GZ_SUFFIX)]

    os.makedirs(INPUT_DIR, exist_ok=True)
    os.makedirs(OUTPUT_DIR, exist_ok=True)

    # nnU-Net names inputs <case>_<channel>.nii.gz: case "image", channel 0000.
    input_path = os.path.join(INPUT_DIR, "image_0000.nii.gz")
    # shutil.move also works when the upload lives on another filesystem (os.rename does not).
    shutil.move(uploaded_path, input_path)

    # Debugging: List files in the /tmp/input directory
    print("Files in /tmp/input:")
    print(os.listdir(INPUT_DIR))

    # Remove a leftover prediction so a failed run cannot return a stale mask.
    output_file = os.path.join(OUTPUT_DIR, "image.nii.gz")
    if os.path.exists(output_file):
        os.remove(output_file)

    command = [
        "nnUNetv2_predict",
        "-i", INPUT_DIR,
        "-o", OUTPUT_DIR,
        "-d", "004",  # Dataset ID
        "-c", "3d_fullres",  # Configuration
        "-tr", "nnUNetTrainer_8000epochs",
        "-device", "cuda",  # Explicitly use GPU
    ]
    # nnU-Net locates the trained model through the nnUNet_results variable;
    # pass it to the child only instead of mutating this process' environment.
    env = {**os.environ, "nnUNet_results": MODEL_DIR}
    try:
        subprocess.run(command, check=True, env=env)
    except (subprocess.CalledProcessError, FileNotFoundError) as e:
        raise gr.Error(f"Error: {e}") from e

    print("Files in /tmp/output:")
    print(os.listdir(OUTPUT_DIR))

    if not os.path.exists(output_file):
        raise gr.Error("Error: Output file not found.")

    # Rename the output file to match the original input filename
    new_output_file = os.path.join(OUTPUT_DIR, f"{base_filename}_LesionMask.nii.gz")
    os.replace(output_file, new_output_file)

    # Both previews share the input's center so image and mask planes line up.
    input_data = _load_isotropic(input_path, target_spacing=1.0)
    center = _nonzero_center(input_data)
    input_slice_path = os.path.join(OUTPUT_DIR, f"{base_filename}_input_slice.png")
    output_slice_path = os.path.join(OUTPUT_DIR, f"{base_filename}_output_slice.png")
    _save_orthogonal_slices(input_data, input_slice_path, SLICE_SIZE, center)
    extract_middle_slices(new_output_file, output_slice_path, center=center)

    # Return paths for the Gradio interface
    return new_output_file, input_slice_path, output_slice_path


# Gradio interface with adjusted layout
with gr.Blocks() as demo:
    gr.Markdown("""
# FLAMeS: Multiple Sclerosis Lesion Segmentation

Upload a skull-stripped FLAIR image (.nii.gz) to generate a binary segmentation of multiple sclerosis lesions.
""")

    with gr.Row():
        with gr.Column(scale=1):
            flair_input = gr.File(label="Upload FLAIR Image (.nii.gz)")
            submit_button = gr.Button("Submit")
        with gr.Column(scale=2):
            seg_output = gr.File(label="Download Segmentation Mask")
            input_img = gr.Image(label="Input: FLAIR image")
            output_img = gr.Image(label="Output: Lesion Mask")

    gr.Markdown("""
**If you find this tool useful, please consider citing:**

1. A Deep Learning-Based Pipeline for Longitudinal White Matter Lesion Segmentation Using Diverse FLAIR Images  
   F. La Rosa, J. Dos Santos Silva, et al.  
   *Multiple Sclerosis Journal.* DOI: [10.1177/13524585231169437](https://doi.org/10.1177/13524585231169437)

2. nnU-Net: A Self-Configuring Method for Deep Learning-Based Biomedical Image Segmentation  
   F. Isensee, et al.  
   *Nature Methods.* DOI: [10.1038/s41592-020-01008-z](https://www.nature.com/articles/s41592-020-01008-z)
""")

    submit_button.click(
        fn=run_nnunet_predict,
        inputs=[flair_input],
        outputs=[seg_output, input_img, output_img],
    )

# Debugging GPU environment
if torch.cuda.is_available():
    print(f"GPU is available: {torch.cuda.get_device_name(0)}")
else:
    print("No GPU available. Falling back to CPU.")
# Run nvidia-smi without a shell, and only when it is installed.
if shutil.which("nvidia-smi"):
    subprocess.run(["nvidia-smi"], check=False)

download_model()

if __name__ == "__main__":
    demo.launch(share=True)