Spaces: Running on Zero
| import gradio as gr | |
| import subprocess | |
| import os | |
| import shutil | |
| from huggingface_hub import hf_hub_download | |
| import torch | |
| import nibabel as nib | |
| import matplotlib as mpl | |
| import matplotlib.pyplot as plt | |
| import spaces # Import spaces for GPU decoration | |
| import numpy as np | |
| from scipy.ndimage import center_of_mass, zoom, label, generate_binary_structure | |
# Hugging Face Hub repository hosting the zipped FLAMeS nnUNet results folder.
REPO_ID = "FrancescoLR/FLAMeS-model"

# Local model location; MODEL_DIR is also exported as nnUNet_results at inference time.
MODEL_DIR = "./model"
DATASET_DIR = os.path.join(MODEL_DIR, "Dataset004_WML")

# Scratch folders handed to nnUNetv2_predict as its input (-i) and output (-o).
INPUT_DIR = "/tmp/input"
OUTPUT_DIR = "/tmp/output"
| import os | |
| import subprocess | |
def setup_hd_bet(repo_dir="./HD-BET"):
    """Clone the HD-BET repository (if absent) and install it in editable mode.

    Parameters:
        repo_dir (str): Directory where HD-BET will be cloned and installed.

    Raises:
        subprocess.CalledProcessError: If ``git clone`` or ``pip install`` fails.
    """
    import sys  # local import: sys is not part of the file-level imports

    if not os.path.exists(repo_dir):
        print("Cloning HD-BET repository...")
        subprocess.run(["git", "clone", "https://github.com/MIC-DKFZ/HD-BET", repo_dir], check=True)
    else:
        print("HD-BET repository already exists.")

    # Install the HD-BET package from source. Running pip as ``python -m pip``
    # with the current interpreter guarantees the package lands in the same
    # environment as this app; a bare ``pip`` on PATH may belong to another Python.
    print("Installing HD-BET using pip...")
    subprocess.run([sys.executable, "-m", "pip", "install", "-e", "."], cwd=repo_dir, check=True)
# Function to download the Dataset004_WML folder
def download_model():
    """Download and extract the FLAMeS nnUNet results folder if it is missing.

    ``Dataset004_WML.zip`` is fetched from ``REPO_ID`` on the Hugging Face Hub
    (cached under ``MODEL_DIR``) and extracted into ``MODEL_DIR``. Nothing is
    done when ``DATASET_DIR`` already exists.

    Raises:
        zipfile.BadZipFile: If the downloaded archive is corrupt.
        FileNotFoundError: If extraction did not produce ``DATASET_DIR``.
    """
    import zipfile  # local import: zipfile is not part of the file-level imports

    if os.path.exists(DATASET_DIR):
        return

    # Only the parent directory is created up front. Creating DATASET_DIR before
    # a successful extraction would make a failed attempt look complete and the
    # download would never be retried on later starts.
    os.makedirs(MODEL_DIR, exist_ok=True)
    print("Downloading Dataset004_WML.zip...")
    zip_path = hf_hub_download(repo_id=REPO_ID, filename="Dataset004_WML.zip", cache_dir=MODEL_DIR)

    # Extract in-process (existing files are overwritten, like ``unzip -o``):
    # failures raise instead of being ignored, and no ``unzip`` binary is needed.
    with zipfile.ZipFile(zip_path) as archive:
        archive.extractall(MODEL_DIR)

    if not os.path.isdir(DATASET_DIR):
        raise FileNotFoundError(f"Extracting {zip_path} did not create {DATASET_DIR}")
    print("Dataset004_WML downloaded and extracted.")
def resample_to_isotropic(data, affine, target_spacing=1.0, order=1):
    """Resample a 3D image volume to isotropic voxel size.

    Parameters:
        data (numpy.ndarray): The input 3D image data.
        affine (numpy.ndarray): The 4x4 voxel-to-world affine matrix.
        target_spacing (float): Desired isotropic voxel spacing, in the units of
            the affine (mm for NIfTI).
        order (int): Spline interpolation order passed to ``scipy.ndimage.zoom``
            (default 1 = linear interpolation).

    Returns:
        resampled_data (numpy.ndarray): Resampled image data.
        resampled_affine (numpy.ndarray): Float copy of ``affine`` whose voxel
            axes are rescaled to ``target_spacing``; the input is not modified.
    """
    # Voxel size along each array axis = Euclidean length of the matching
    # column of the 3x3 rotation/zoom block.
    current_spacing = np.sqrt((affine[:3, :3] ** 2).sum(axis=0))
    # Output voxels per input voxel along each axis.
    scaling_factors = current_spacing / target_spacing
    resampled_data = zoom(data, zoom=scaling_factors, order=order)

    # Float copy so integer affines can be rescaled too.
    resampled_affine = np.array(affine, dtype=float)
    # Each column is one voxel axis, so divide the columns (not the rows);
    # dividing rows is only correct for diagonal affines.
    resampled_affine[:3, :3] = resampled_affine[:3, :3] / scaling_factors[np.newaxis, :]
    return resampled_data, resampled_affine
def _label_colormap(seed=42):
    """Build the colormap used to display connected-component labels.

    Entry 0 (background) is opaque black; the other 255 entries are samples of
    matplotlib's ``Pastel1`` shuffled with a fixed seed, so colours are
    reproducible. ``Pastel1`` is a small qualitative palette, so distinct labels
    can share a colour. A private ``RandomState`` is used so NumPy's global
    random state is not reseeded (it yields the same shuffle as seeding globally).
    """
    pastel = plt.cm.Pastel1(np.linspace(0, 1, 256))
    shuffled_colors = pastel[1:].copy()
    np.random.RandomState(seed).shuffle(shuffled_colors)
    final_colors = np.vstack([np.array([0, 0, 0, 1]), shuffled_colors])
    return mpl.colors.ListedColormap(final_colors)


def _center_of_nonzero(volume):
    """Return the integer-rounded centre of mass of ``volume > 0``.

    Falls back to the middle voxel when there are no positive voxels, where
    ``center_of_mass`` would return NaN (not a usable index).
    """
    foreground = volume > 0
    if not foreground.any():
        return np.array(volume.shape) // 2
    return np.round(center_of_mass(foreground)).astype(int)


def _extract_2d_slice(volume, center, axis, slice_size):
    """Return a ``slice_size`` x ``slice_size`` crop of one plane of ``volume``.

    The plane is ``volume`` indexed at ``center[axis]`` along ``axis``. The crop is
    centred on the two remaining coordinates of ``center``. Parts of the window
    that fall outside the volume are zero-filled on the side where they fall, so
    the centre stays in the middle of the image even near the borders.
    """
    index = [slice(None)] * 3
    # Clamp so an out-of-range centre cannot raise or wrap around via negative indexing.
    index[axis] = int(np.clip(center[axis], 0, volume.shape[axis] - 1))
    plane = volume[tuple(index)]

    window = np.zeros((slice_size, slice_size), dtype=plane.dtype)
    half_size = slice_size // 2
    in_plane_center = [center[a] for a in range(3) if a != axis]
    src, dst = [], []
    for dim, c in enumerate(in_plane_center):
        start = int(c) - half_size
        lo = max(start, 0)
        hi = min(start + slice_size, plane.shape[dim])
        if hi <= lo:
            return window  # the window lies entirely outside the volume
        src.append(slice(lo, hi))
        dst.append(slice(lo - start, hi - start))
    window[tuple(dst)] = plane[tuple(src)]
    return window


def extract_middle_slices(nifti_path, output_image_path, slice_size=180, center=None, label_components=False):
    """Save axial, coronal and sagittal views of a NIfTI volume as one PNG.

    The volume is resampled to 1 mm isotropic spacing. One plane per axis is taken
    through ``center`` and cropped/zero-padded to ``slice_size`` x ``slice_size``.

    Parameters:
        nifti_path (str): Path of the NIfTI image to read.
        output_image_path (str): Where to save the 1x3 figure.
        slice_size (int): Side length, in resampled voxels, of each view.
        center (array-like of 3 ints, optional): Voxel coordinates in the resampled
            grid to slice through. Defaults to the rounded centre of mass of the
            positive voxels (middle of the volume if there are none).
        label_components (bool): If True, treat the image as a mask, label its
            26-connected components and draw them with a categorical colormap.

    Returns:
        numpy.ndarray: With ``label_components=True``, the component labels
        computed on the ORIGINAL (not resampled) voxel grid, so they align with the
        file's affine. The figure shows labels recomputed on the resampled grid,
        whose numbering need not match. With ``label_components=False``, the
        resampled image data.
    """
    img = nib.load(nifti_path)
    data = img.get_fdata()
    resampled_data, _ = resample_to_isotropic(data, img.affine, target_spacing=1.0)

    if label_components:
        structure = generate_binary_structure(3, 3)  # 3D, 26-connectivity
        labeled_data, _ = label(data > 0, structure=structure)
        display_data, num_features = label(resampled_data > 0, structure=structure)
        cmap, vmin, vmax = _label_colormap(), 0, num_features
    else:
        labeled_data = display_data = resampled_data
        cmap, vmin, vmax = "gray", None, None

    if center is None:
        center = _center_of_nonzero(display_data)

    # Axial (axis 2), coronal (axis 1) and sagittal (axis 0) planes, each turned
    # with np.rot90(k=-1) for display.
    views = [
        np.rot90(_extract_2d_slice(display_data, center, axis, slice_size), k=-1)
        for axis in (2, 1, 0)
    ]

    fig, axes = plt.subplots(1, 3, figsize=(12, 4))
    try:
        for ax, view in zip(axes, views):
            ax.imshow(view, cmap=cmap, origin="lower", vmin=vmin, vmax=vmax)
            ax.axis("off")
        fig.tight_layout()
        fig.savefig(output_image_path, bbox_inches="tight", pad_inches=0)
    finally:
        # Always release the figure so repeated requests do not accumulate open figures.
        plt.close(fig)

    return labeled_data
# Function to run nnUNet inference.
# Decorated with @spaces.GPU so that, on a ZeroGPU Space, a GPU is attached for
# the duration of the call (the "-device cuda" arguments below rely on it).
# NOTE(review): the default ZeroGPU allocation window is 60 s while the app text
# estimates about 1 minute per MRI; consider @spaces.GPU(duration=...) if runs are cut off.
@spaces.GPU
def run_nnunet_predict(nifti_file, hd_bet=False):
    """Segment MS lesions in an uploaded FLAIR volume with the FLAMeS nnUNet model.

    Parameters:
        nifti_file: The ``gr.File`` value: either a path string or an object whose
            ``.name`` is the uploaded file's path. Expected to be a .nii.gz file.
        hd_bet (bool): If True, skull-strip the image with HD-BET before segmentation.

    Returns:
        tuple[str, str, str, str]: Paths of the lesion-mask NIfTI, the input-slice
        PNG, the output-slice PNG and the labeled-clusters NIfTI, in the order of
        the Gradio ``outputs`` list.

    Raises:
        gr.Error: If no file was uploaded, HD-BET or nnUNet fails, or nnUNet wrote
            no prediction. Gradio displays the message to the user.
    """
    if nifti_file is None:
        raise gr.Error("Please upload a FLAIR image (.nii.gz).")
    # gr.File may hand over a plain path string or a file wrapper exposing ``.name``.
    uploaded_path = getattr(nifti_file, "name", nifti_file)

    # Start every request from empty scratch directories: nnUNet predicts every
    # image found in INPUT_DIR, and earlier uploads/results should not linger on disk.
    for directory in (INPUT_DIR, OUTPUT_DIR):
        shutil.rmtree(directory, ignore_errors=True)
        os.makedirs(directory, exist_ok=True)

    # Base name for the outputs: the upload's name with a trailing NIfTI extension removed.
    original_filename = os.path.basename(uploaded_path)
    base_filename = original_filename
    for extension in (".nii.gz", ".nii"):
        if base_filename.endswith(extension):
            base_filename = base_filename[: -len(extension)]
            break

    # Case "image", channel 0000; nnUNet writes the prediction as image.nii.gz (see below).
    input_path = os.path.join(INPUT_DIR, "image_0000.nii.gz")
    if original_filename.endswith(".nii.gz"):
        # shutil.move falls back to copy+delete across filesystems, where os.rename fails.
        shutil.move(uploaded_path, input_path)
    else:
        # Re-save other readable inputs (e.g. uncompressed .nii) as gzipped NIfTI;
        # merely renaming them to .nii.gz would produce an invalid file.
        nib.save(nib.load(uploaded_path), input_path)

    if hd_bet:
        # Output path equals input path, so the skull-stripped image replaces the upload.
        try:
            subprocess.run([
                "hd-bet",
                "-i", input_path,
                "-o", input_path,
                "-device", "cuda",  # or "cpu"
                "--disable_tta"], check=True)
        except (subprocess.CalledProcessError, FileNotFoundError) as e:
            raise gr.Error(f"HD-BET Error: {e}") from e
        print("Skull-stripping completed.")

    # Debugging: list files in the input directory
    print("Files in /tmp/input:")
    print(os.listdir(INPUT_DIR))

    # nnUNet locates the trained model (Dataset004_WML) through nnUNet_results.
    os.environ["nnUNet_results"] = MODEL_DIR
    command = [
        "nnUNetv2_predict",
        "-i", INPUT_DIR,
        "-o", OUTPUT_DIR,
        "-d", "004",  # Dataset ID
        "-c", "3d_fullres",  # Configuration
        "-tr", "nnUNetTrainer_8000epochs",
        "-device", "cuda",  # Explicitly use GPU
    ]
    try:
        subprocess.run(command, check=True)
    except (subprocess.CalledProcessError, FileNotFoundError) as e:
        raise gr.Error(f"Error: {e}") from e

    # Debugging: list what nnUNet produced
    print("Files in /tmp/output:")
    print(os.listdir(OUTPUT_DIR))

    output_file = os.path.join(OUTPUT_DIR, "image.nii.gz")
    if not os.path.exists(output_file):
        raise gr.Error("Error: Output file not found.")
    # Rename the prediction after the uploaded file (same directory, so os.rename is safe).
    new_output_file = os.path.join(OUTPUT_DIR, f"{base_filename}_LesionMask.nii.gz")
    os.rename(output_file, new_output_file)

    # Slice both images through the input's centre of mass on the 1 mm grid so the
    # input and mask figures show the same planes.
    img = nib.load(input_path)
    resampled_data, _ = resample_to_isotropic(img.get_fdata(), img.affine, target_spacing=1.0)
    foreground = resampled_data > 0
    if foreground.any():
        center = np.round(center_of_mass(foreground)).astype(int)
    else:
        # center_of_mass of an empty mask is NaN; use the middle voxel instead.
        center = np.array(resampled_data.shape) // 2

    input_slice_path = os.path.join(OUTPUT_DIR, f"{base_filename}_input_slice.png")
    output_slice_path = os.path.join(OUTPUT_DIR, f"{base_filename}_output_slice.png")
    extract_middle_slices(input_path, input_slice_path, center=center)
    labeled_mask = extract_middle_slices(new_output_file, output_slice_path, center=center, label_components=True)

    # Save the connected-component labels with the lesion mask's affine.
    mask_img = nib.load(new_output_file)
    labeled_mask_path = os.path.join(OUTPUT_DIR, f"{base_filename}_LabeledClusters.nii.gz")
    nib.save(nib.Nifti1Image(labeled_mask.astype(np.int16), mask_img.affine), labeled_mask_path)

    # Return paths for the Gradio interface
    return new_output_file, input_slice_path, output_slice_path, labeled_mask_path
# Gradio interface with adjusted layout.
# Paragraphs in the Markdown below are separated by blank lines; without them
# Markdown merges consecutive lines into a single paragraph.
with gr.Blocks() as demo:
    gr.Markdown("""
# 🔥 FLAMeS: FLAIR Lesion Segmentation for Multiple Sclerosis

Upload a FLAIR brain MRI in NIfTI format (.nii.gz) to generate a binary segmentation of multiple sclerosis lesions.

FLAMeS is based on the nnUNet framework<sup>2</sup> and was trained on 668 MRI scans acquired using Siemens, GE, and Philips 1.5T and 3T scanners<sup>1</sup>.

We suggest skull-stripping the image in advance using [SynthStrip](https://surfer.nmr.mgh.harvard.edu/docs/synthstrip/) with the `--no-csf` flag for optimal results. If that's not feasible, you can still upload your image as-is and enable the "Apply skull-stripping" option below.

Inference takes approximately 1 minute per MRI, with processing limited to one scan at a time due to Hugging Face's zero-GPU usage constraints. To process multiple cases simultaneously, install the [nnUNet v2](https://github.com/MIC-DKFZ/nnUNet), download [FLAMeS's model](https://huggingface.co/FrancescoLR/FLAMeS-model) and run it locally using your own GPU or CPU setup.

**Disclaimer:** Uploaded data is stored temporarily, no one has access to it, and it is deleted when the app is closed. For details, see [Gradio's file access guide](https://www.gradio.app/main/guides/file-access). Human subjects data should only be uploaded for processing if permitted by your institution's human subjects protection office.

This is a research tool and is not intended for clinical use. Clinical decisions should not be based on the outputs of this tool.
""")
    with gr.Row():
        # Left column: inputs and the submit button.
        with gr.Column(scale=1):
            flair_input = gr.File(label="Upload a FLAIR Image (.nii.gz)")
            hd_bet = gr.Checkbox(label="Apply skull-stripping", value=False)
            submit_button = gr.Button("Submit")
        # Right column: downloadable masks and preview images.
        with gr.Column(scale=2):
            seg_output = gr.File(label="Download the Lesion Segmentation Mask")
            clusters_output = gr.File(label="Download the Labeled Lesion Segmentation Mask")
            input_img = gr.Image(label="Input: FLAIR image")
            # The preview is drawn with label_components=True (each connected
            # lesion coloured), so it is not a binary mask.
            output_img = gr.Image(label="Output: Labeled Lesion Clusters")
    gr.Markdown("""
**If you find this tool useful, please consider citing:**

1. A Novel Convolutional Neural Network for Automated Multiple Sclerosis Brain Lesion Segmentation
Dereskewicz, E., La Rosa, F., dos Santos Silva, J., Sizer, E., Kohli, A., Wynen, M., ... & Beck, E. S.
*Journal of Neuroimaging* 35.5 (2025): e70085
DOI: [10.1111/jon.70085](https://doi.org/10.1111/jon.70085)
2. nnU-Net: A Self-Configuring Method for Deep Learning-Based Biomedical Image Segmentation
F. Isensee, P. F. Jaeger, S. A. Kohl, J. Petersen, & K. H. Maier-Hein.
*Nature Methods.* 2021;18(2):203-211.
DOI: [10.1038/s41592-020-01008-z](https://www.nature.com/articles/s41592-020-01008-z)
""")
    # Output order must match run_nnunet_predict's return tuple:
    # (mask file, input preview, output preview, labeled-clusters file).
    submit_button.click(
        fn=run_nnunet_predict,
        inputs=[flair_input, hd_bet],
        outputs=[seg_output, input_img, output_img, clusters_output]
    )
# Startup diagnostics: report whether PyTorch can see a CUDA device, then dump
# the nvidia-smi table to the logs.
print(
    f"GPU is available: {torch.cuda.get_device_name(0)}"
    if torch.cuda.is_available()
    else "No GPU available. Falling back to CPU."
)
os.system("nvidia-smi")

# Environment preparation executed at import time: install HD-BET from source,
# then make sure the FLAMeS model folder is present locally.
setup_hd_bet()
download_model()
def _launch_app():
    """Start the Gradio server for ``demo``.

    NOTE(review): share=True requests a public Gradio share link when run outside
    Hugging Face Spaces; given that users may upload human-subjects data, confirm
    a public link is intended for local runs.
    """
    demo.launch(share=True)


if __name__ == "__main__":
    _launch_app()