# FLAMeS / app.py
# Author: FrancescoLR
# Last commit: "Update app.py" (f499bba, verified)
import gradio as gr
import subprocess
import os
import shutil
from huggingface_hub import hf_hub_download
import torch
import nibabel as nib
import matplotlib as mpl
import matplotlib.pyplot as plt
import spaces # Import spaces for GPU decoration
import numpy as np
from scipy.ndimage import center_of_mass, zoom, label, generate_binary_structure
# Define paths
# Local directory that receives the downloaded model; it is also used as the
# nnUNet_results root when running inference.
MODEL_DIR = "./model"
# Trained-model folder that Dataset004_WML.zip unpacks into, inside MODEL_DIR.
DATASET_DIR = os.path.join(MODEL_DIR, "Dataset004_WML")
# Scratch folders for a single inference request (input image / predictions).
INPUT_DIR, OUTPUT_DIR = "/tmp/input", "/tmp/output"

# Hugging Face Hub repository that hosts Dataset004_WML.zip.
REPO_ID = "FrancescoLR/FLAMeS-model"
import os
import subprocess
def setup_hd_bet(repo_dir="./HD-BET"):
    """
    Clone the HD-BET repository (if needed) and install it in editable mode.

    The package is installed with ``<current interpreter> -m pip`` rather than
    a bare ``pip`` executable, so it lands in the same environment as the
    running app instead of whichever ``pip`` happens to be first on PATH.

    Parameters:
        repo_dir (str): Directory where HD-BET is cloned and installed from.

    Raises:
        subprocess.CalledProcessError: If ``git clone`` or ``pip install`` fails.
    """
    import sys  # only needed to locate the running interpreter

    if os.path.exists(repo_dir):
        print("HD-BET repository already exists.")
    else:
        print("Cloning HD-BET repository...")
        subprocess.run(
            ["git", "clone", "https://github.com/MIC-DKFZ/HD-BET", repo_dir],
            check=True,
        )

    # Install the HD-BET package from source
    print("Installing HD-BET using pip...")
    subprocess.run(
        [sys.executable, "-m", "pip", "install", "-e", "."],
        cwd=repo_dir,
        check=True,
    )
# Function to download the Dataset004_WML folder
def download_model():
    """
    Download and unpack the FLAMeS nnU-Net model unless it is already present.

    ``Dataset004_WML.zip`` is fetched from the Hub repository ``REPO_ID``
    (cached under ``MODEL_DIR``) and unzipped into ``MODEL_DIR``, which is
    expected to create ``DATASET_DIR``.

    The download is skipped only when ``DATASET_DIR`` exists *and* is not
    empty. ``DATASET_DIR`` is deliberately not created up front: an empty folder
    left behind by a failed download would otherwise suppress the download on
    every later start-up.

    Raises:
        subprocess.CalledProcessError: If ``unzip`` reports an error.
    """
    if os.path.isdir(DATASET_DIR) and os.listdir(DATASET_DIR):
        return
    os.makedirs(MODEL_DIR, exist_ok=True)
    print("Downloading Dataset004_WML.zip...")
    zip_path = hf_hub_download(repo_id=REPO_ID, filename="Dataset004_WML.zip", cache_dir=MODEL_DIR)
    unzip_cmd = ["unzip", "-o", zip_path, "-d", MODEL_DIR]
    result = subprocess.run(unzip_cmd)
    # Info-ZIP unzip exits with 1 when it only emitted warnings (extraction
    # still completed); any other non-zero status means extraction failed.
    if result.returncode not in (0, 1):
        raise subprocess.CalledProcessError(result.returncode, unzip_cmd)
    print("Dataset004_WML downloaded and extracted.")
def resample_to_isotropic(data, affine, target_spacing=1.0):
    """
    Resample a 3D image to isotropic voxel size.

    Parameters:
        data (numpy.ndarray): The input 3D image data.
        affine (numpy.ndarray): 4x4 voxel-to-world affine of ``data``.
        target_spacing (float): Desired isotropic voxel spacing (in mm).

    Returns:
        resampled_data (numpy.ndarray): Resampled image data (linear interpolation).
        resampled_affine (numpy.ndarray): Copy of ``affine`` whose voxel axes are
            rescaled to ``target_spacing``; the input ``affine`` is not modified.
    """
    # Work on a float copy so integer affines can be rescaled too.
    affine = np.asarray(affine, dtype=float)
    # Column j of the 3x3 block is the world-space step of voxel axis j, so its
    # norm is the current voxel size along that axis.
    current_spacing = np.sqrt((affine[:3, :3] ** 2).sum(axis=0))
    scaling_factors = current_spacing / target_spacing
    resampled_data = zoom(data, zoom=scaling_factors, order=1)  # Linear interpolation
    resampled_affine = affine.copy()
    # Rescale each *column* (voxel axis) by its own factor. Dividing rows
    # instead is only correct for diagonal affines and corrupts permuted or
    # oblique orientations.
    resampled_affine[:3, :3] /= scaling_factors[np.newaxis, :]
    return resampled_data, resampled_affine
def _center_of_nonzero(volume):
    """Return the rounded center of mass of ``volume > 0`` as an int array.

    Falls back to the geometric center of the volume when it has no non-zero
    voxels (``center_of_mass`` would return NaNs in that case).
    """
    mask = volume > 0
    if not mask.any():
        return np.array(volume.shape) // 2
    return np.round(center_of_mass(mask)).astype(int)


def _extract_2d_slice(volume, center, axis, slice_size):
    """Cut the plane through ``center`` normal to ``axis``, cropped around the
    center and zero-padded to a ``slice_size`` x ``slice_size`` array."""
    half_size = slice_size // 2
    index = [slice(None)] * 3
    index[axis] = center[axis]
    plane = volume[tuple(index)]
    # The two remaining volume axes become the plane's rows and columns, in order.
    row_axis, col_axis = [a for a in range(3) if a != axis]
    cropped = plane[
        max(center[row_axis] - half_size, 0):min(center[row_axis] + half_size, plane.shape[0]),
        max(center[col_axis] - half_size, 0):min(center[col_axis] + half_size, plane.shape[1]),
    ]
    pad_h = slice_size - cropped.shape[0]
    pad_w = slice_size - cropped.shape[1]
    return np.pad(
        cropped,
        ((pad_h // 2, pad_h - pad_h // 2), (pad_w // 2, pad_w - pad_w // 2)),
        mode="constant",
        constant_values=0,
    )


def _label_colormap(num_labels, seed=42):
    """Colormap with black for label 0 and one pastel color per label 1..num_labels.

    Pastel1's colors are shuffled with a local, seeded generator so the palette
    is reproducible without reseeding NumPy's global random state. With
    ``num_labels + 1`` entries, every integer label maps to exactly one entry
    when plotted with ``vmin=0, vmax=num_labels``.
    """
    base = np.asarray(plt.cm.Pastel1.colors)
    palette = base[np.random.default_rng(seed).permutation(len(base))]
    colors = palette[np.arange(num_labels) % len(palette)]
    return mpl.colors.ListedColormap(np.vstack([[0.0, 0.0, 0.0], colors]))


def extract_middle_slices(nifti_path, output_image_path, slice_size=180, center=None, label_components=False):
    """
    Save axial, coronal and sagittal slices of a 3D NIfTI image as a single PNG.

    The image is resampled to 1 mm isotropic voxels. One slice per orthogonal
    plane is taken through ``center``, cropped and zero-padded to
    ``slice_size`` x ``slice_size``, and the three are plotted side by side.

    Parameters:
        nifti_path (str): Path of the NIfTI image to display.
        output_image_path (str): Where the PNG figure is written.
        slice_size (int): Side length (in resampled voxels) of each panel.
        center (sequence of int, optional): Voxel coordinates, in the resampled
            1 mm grid, that the slices pass through. It defaults to the rounded
            center of mass of the non-zero voxels, or to the volume center if
            there are none. Passing the same center for an image and its mask
            keeps the two figures aligned.
        label_components (bool): If True, treat the image as a mask. Its
            connected components (26-connectivity) are labeled and each one is
            drawn in its own color.

    Returns:
        np.ndarray: With ``label_components=True``, the connected-component
        labels of the mask at its *original* resolution. Otherwise, the
        resampled image data.
    """
    # Load NIfTI image
    img = nib.load(nifti_path)
    data = img.get_fdata()
    # Resample the image to 1 mm isotropic
    resampled_data, _ = resample_to_isotropic(data, img.affine, target_spacing=1.0)

    if label_components:
        structure = generate_binary_structure(3, 3)  # 3D, 26-connectivity
        # The original-resolution labels are returned to the caller. The
        # resampled labels (and their count) only drive the figure.
        labeled_data, _ = label(data > 0, structure=structure)
        display_data, num_features = label(resampled_data > 0, structure=structure)
        cmap = _label_colormap(num_features)
        vmin, vmax = 0, max(num_features, 1)
        interpolation = "nearest"  # never blend neighbouring label IDs into other colors
    else:
        labeled_data = display_data = resampled_data
        cmap, vmin, vmax, interpolation = "gray", None, None, None

    if center is None:
        center = _center_of_nonzero(display_data)
    # Keep slice indices inside the volume (e.g. for a center computed elsewhere).
    center = np.clip(
        np.rint(np.asarray(center, dtype=float)).astype(int),
        0,
        np.array(display_data.shape) - 1,
    )

    # Axial (axis 2), coronal (axis 1) and sagittal (axis 0) planes, each
    # rotated 90 degrees clockwise (np.rot90 with k=-1) before plotting.
    panels = [
        np.rot90(_extract_2d_slice(display_data, center, axis, slice_size), k=-1)
        for axis in (2, 1, 0)
    ]

    # Use the figure object directly (not pyplot's "current figure") and always
    # release it, even if saving fails.
    fig, axes = plt.subplots(1, 3, figsize=(12, 4))
    try:
        for ax, panel in zip(axes, panels):
            ax.imshow(panel, cmap=cmap, origin="lower", vmin=vmin, vmax=vmax, interpolation=interpolation)
            ax.axis("off")
        fig.tight_layout()
        fig.savefig(output_image_path, bbox_inches="tight", pad_inches=0)
    finally:
        plt.close(fig)

    return labeled_data
# Function to run nnUNet inference
@spaces.GPU(duration=90)  # Decorate the function to allocate GPU for its execution
def run_nnunet_predict(nifti_file, hd_bet=False):
    """
    Segment MS lesions in an uploaded FLAIR image with the FLAMeS nnU-Net model.

    Steps:
    1. Copy the upload into INPUT_DIR as ``image_0000.nii.gz`` (nnU-Net's
       single-channel naming).
    2. Optionally skull-strip it with HD-BET.
    3. Run ``nnUNetv2_predict``.
    4. Write PNG previews and a connected-component labeled copy of the mask
       into OUTPUT_DIR.

    Parameters:
        nifti_file: The uploaded file from ``gr.File``. This is either an object
            whose ``.name`` is the file path or a plain path string.
        hd_bet (bool): Apply HD-BET skull-stripping before segmentation.

    Returns:
        tuple[str, str, str, str]: Paths of the binary lesion mask, the input
        preview PNG, the mask preview PNG and the labeled-cluster mask. The order
        matches the four Gradio outputs.

    Raises:
        gr.Error: On failure, so Gradio shows the message in the UI. Returning a
            single error string cannot fill the four outputs.
    """
    if nifti_file is None:
        raise gr.Error("Please upload a FLAIR image (.nii.gz) first.")

    # Start from a clean input folder: nnUNetv2_predict processes every file in it.
    shutil.rmtree(INPUT_DIR, ignore_errors=True)
    os.makedirs(INPUT_DIR, exist_ok=True)
    os.makedirs(OUTPUT_DIR, exist_ok=True)

    # Extract the original filename without the NIfTI extension
    upload_path = getattr(nifti_file, "name", nifti_file)
    original_filename = os.path.basename(upload_path)
    base_filename = original_filename
    for suffix in (".nii.gz", ".nii"):
        if original_filename.endswith(suffix):
            base_filename = original_filename[: -len(suffix)]
            break

    input_path = os.path.join(INPUT_DIR, "image_0000.nii.gz")
    try:
        if original_filename.endswith(".nii.gz"):
            # Copy rather than move, so the same upload can be submitted again.
            shutil.copyfile(upload_path, input_path)
        else:
            # Other readable formats (e.g. uncompressed .nii) are re-saved as
            # gzipped NIfTI to match the .nii.gz name nnU-Net is given.
            nib.save(nib.load(upload_path), input_path)
    except Exception as e:  # request boundary: report any read/copy failure in the UI
        raise gr.Error(f"Could not read the uploaded image: {e}") from e

    if hd_bet:
        # Apply skull-stripping with HD-BET, overwriting the input image in place.
        try:
            subprocess.run([
                "hd-bet",
                "-i", input_path,
                "-o", input_path,
                "-device", "cuda",  # or "cpu"
                "--disable_tta",
            ], check=True)
        except (subprocess.CalledProcessError, OSError) as e:
            raise gr.Error(f"HD-BET Error: {e}") from e
        print("Skull-stripping completed.")

    # Debugging: List files in the /tmp/input directory
    print("Files in /tmp/input:")
    print(os.listdir(INPUT_DIR))

    command = [
        "nnUNetv2_predict",
        "-i", INPUT_DIR,
        "-o", OUTPUT_DIR,
        "-d", "004",  # Dataset ID
        "-c", "3d_fullres",  # Configuration
        "-tr", "nnUNetTrainer_8000epochs",
        "-device", "cuda",  # Explicitly use GPU
    ]
    # nnU-Net locates the trained model via nnUNet_results. Pass it only to the
    # child process instead of mutating this process's environment.
    env = dict(os.environ, nnUNet_results=MODEL_DIR)
    try:
        subprocess.run(command, check=True, env=env)
    except (subprocess.CalledProcessError, OSError) as e:
        raise gr.Error(f"Error: {e}") from e

    print("Files in /tmp/output:")
    print(os.listdir(OUTPUT_DIR))

    # nnU-Net names the prediction after the case id ("image"); rename it to
    # match the uploaded file. os.replace also overwrites a stale result.
    output_file = os.path.join(OUTPUT_DIR, "image.nii.gz")
    if not os.path.exists(output_file):
        raise gr.Error("Error: Output file not found.")
    new_output_file = os.path.join(OUTPUT_DIR, f"{base_filename}_LesionMask.nii.gz")
    os.replace(output_file, new_output_file)

    # Compute one slice center (on the 1 mm grid) from the input image so the
    # image preview and the mask preview are cut through the same voxel.
    img = nib.load(input_path)
    resampled_data, _ = resample_to_isotropic(img.get_fdata(), img.affine, target_spacing=1.0)
    center = np.round(center_of_mass(resampled_data > 0)).astype(int)

    # Extract and save 2D slices
    input_slice_path = os.path.join(OUTPUT_DIR, f"{base_filename}_input_slice.png")
    output_slice_path = os.path.join(OUTPUT_DIR, f"{base_filename}_output_slice.png")
    extract_middle_slices(input_path, input_slice_path, center=center)
    labeled_mask = extract_middle_slices(new_output_file, output_slice_path, center=center, label_components=True)

    # The labels are at the mask's original resolution, so reuse its affine.
    mask_affine = nib.load(new_output_file).affine
    labeled_mask_path = os.path.join(OUTPUT_DIR, f"{base_filename}_LabeledClusters.nii.gz")
    nib.save(nib.Nifti1Image(labeled_mask.astype(np.int16), mask_affine), labeled_mask_path)

    # Return paths for the Gradio interface
    return new_output_file, input_slice_path, output_slice_path, labeled_mask_path
# Gradio interface with adjusted layout
# (Markdown text is kept at column 0 inside the strings so it never renders as
# an indented code block, whatever dedenting the Markdown component applies.)
with gr.Blocks() as demo:
    gr.Markdown("""
# 🔥 FLAMeS: FLAIR Lesion Segmentation for Multiple Sclerosis

Upload a FLAIR brain MRI in NIfTI format (.nii.gz) to generate a binary segmentation of multiple sclerosis lesions.
FLAMeS is based on the nnU-Net framework<sup>2</sup> and was trained on 668 MRI scans acquired using Siemens, GE, and Philips 1.5T and 3T scanners<sup>1</sup>.

We suggest skull-stripping the image in advance using [SynthStrip](https://surfer.nmr.mgh.harvard.edu/docs/synthstrip/) with the `--no-csf` flag for optimal results. If that's not feasible, you can still upload your image as-is and enable the "Apply skull-stripping" option below.

Inference takes approximately 1 minute per MRI, with processing limited to one scan at a time due to Hugging Face's ZeroGPU usage constraints. To process multiple cases simultaneously, install [nnU-Net v2](https://github.com/MIC-DKFZ/nnUNet), download [FLAMeS's model](https://huggingface.co/FrancescoLR/FLAMeS-model), and run it locally using your own GPU or CPU setup.

**Disclaimer:** Uploaded data is stored temporarily, no one has access to it, and it is deleted when the app is closed. For details, see [Gradio's file access guide](https://www.gradio.app/main/guides/file-access). Human subjects data should only be uploaded for processing if permitted by your institution's human subjects protection office.

This is a research tool and is not intended for clinical use. Clinical decisions should not be based on the outputs of this tool.
""")
    with gr.Row():
        with gr.Column(scale=1):
            flair_input = gr.File(label="Upload a FLAIR Image (.nii.gz)")
            hd_bet = gr.Checkbox(label="Apply skull-stripping", value=False)
            submit_button = gr.Button("Submit")
        with gr.Column(scale=2):
            seg_output = gr.File(label="Download the Lesion Segmentation Mask")
            clusters_output = gr.File(label="Download the Labeled Lesion Segmentation Mask")
            input_img = gr.Image(label="Input: FLAIR image")
            # The preview colors each connected lesion separately.
            output_img = gr.Image(label="Output: Lesion Mask (colored by connected component)")
    gr.Markdown("""
**If you find this tool useful, please consider citing:**

1. FLAMeS: A Robust Deep Learning Model for Automated Multiple Sclerosis Lesion Segmentation<br>
   Dereskewicz, E., La Rosa, F., dos Santos Silva, J., Sizer, E., Kohli, A., Wynen, M., ... & Beck, E. S.<br>
   *medRxiv* (2025).<br>
   DOI: [10.1101/2025.05.19.25327707](https://doi.org/10.1101/2025.05.19.25327707)
2. nnU-Net: A Self-Configuring Method for Deep Learning-Based Biomedical Image Segmentation<br>
   F. Isensee, P. F. Jaeger, S. A. Kohl, J. Petersen, & K. H. Maier-Hein.<br>
   *Nature Methods.* 2021;18(2):203-211.<br>
   DOI: [10.1038/s41592-020-01008-z](https://www.nature.com/articles/s41592-020-01008-z)
""")
    # Output order must match the tuple returned by run_nnunet_predict.
    submit_button.click(
        fn=run_nnunet_predict,
        inputs=[flair_input, hd_bet],
        outputs=[seg_output, input_img, output_img, clusters_output],
    )
# Debugging GPU environment: report what the host exposes at start-up.
if not torch.cuda.is_available():
    print("No GPU available. Falling back to CPU.")
else:
    print(f"GPU is available: {torch.cuda.get_device_name(0)}")
os.system("nvidia-smi")

# Prepare external dependencies (HD-BET install, model weights) before serving.
for _startup_task in (setup_hd_bet, download_model):
    _startup_task()

if __name__ == "__main__":
    demo.launch(share=True)