Spaces:
Running
on
Zero
Running
on
Zero
File size: 11,842 Bytes
7801205 c02fad1 18f25b9 2fa3177 5dd0ea6 d21c058 121d535 2cac7f2 7801205 2cac7f2 60723f6 2cac7f2 e388d15 121d535 2cac7f2 121d535 a74b6ce 2fa3177 a74b6ce 2fa3177 121d535 2fa3177 121d535 79fd983 a74b6ce 121d535 a74b6ce 79fd983 842b2ca a74b6ce 44da582 4c7435a c237131 842b2ca 0496106 a74b6ce 0ac3af1 0496106 44da582 0ac3af1 70fbe25 0ac3af1 842b2ca 0496106 0ac3af1 121d535 79fd983 2d8f0f2 e45af6e 2d8f0f2 e45af6e 2d8f0f2 79fd983 a74b6ce 79fd983 842b2ca 79fd983 842b2ca 79fd983 842b2ca 79fd983 2fa3177 121d535 2cac7f2 801067b 7801205 6c748cb 9c5c250 6c748cb 7801205 7a2ca4b 4d847b9 7a2ca4b 7801205 ed9fa70 7801205 2cac7f2 7801205 5dd0ea6 6c748cb 7801205 c5b67f9 7801205 6c748cb 7801205 6c748cb 5dd0ea6 6c748cb 2fa3177 c237131 2fa3177 c237131 2fa3177 c5b67f9 9c5c250 5dd0ea6 7801205 c5df5bb df77e6e 43eda01 c444218 43eda01 c444218 988b12f 9dda281 7926ecf 9dda281 d15884e 988b12f 43eda01 c5df5bb 527dbff c5df5bb ae7e73e da3fc9c c5df5bb ae7e73e 43eda01 df77e6e c5df5bb 702eb13 c5df5bb 702eb13 c5df5bb da3fc9c c5df5bb da3fc9c 812aa8d 18f25b9 812aa8d 18f25b9 c5df5bb 18f25b9 da3fc9c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 |
import gradio as gr
import subprocess
import os
import shutil
from huggingface_hub import hf_hub_download
import torch
import nibabel as nib
import matplotlib.pyplot as plt
import spaces # Import spaces for GPU decoration
import numpy as np
from scipy.ndimage import center_of_mass, zoom
# Filesystem locations used by the app
MODEL_DIR = "./model"  # Local directory used as the download cache and extraction target for the model
DATASET_DIR = os.path.join(MODEL_DIR, "Dataset004_WML")  # Folder whose presence marks the model as already downloaded
INPUT_DIR = "/tmp/input"  # Folder passed to nnUNetv2_predict as its input (-i)
OUTPUT_DIR = "/tmp/output"  # Folder passed to nnUNetv2_predict as its output (-o)
# Hugging Face Hub repository from which Dataset004_WML.zip is downloaded
REPO_ID = "FrancescoLR/FLAMeS-model"
# Function to download the Dataset004_WML folder
def download_model():
    """
    Download and extract the Dataset004_WML model folder unless it is already present.

    The archive ``Dataset004_WML.zip`` is fetched from the ``REPO_ID`` repository
    (using ``MODEL_DIR`` as the Hub cache directory) and unpacked into ``MODEL_DIR``.
    Presumably the archive contains a top-level ``Dataset004_WML`` folder, which is
    what ``DATASET_DIR`` points at — verify against the published archive.

    Returns:
        None

    Raises:
        Exception: whatever ``hf_hub_download`` or the archive extraction raises;
            failures are deliberately not swallowed so that a broken model
            install is visible at startup.
    """
    # Only a non-empty folder counts as "already downloaded": an empty one can be
    # left behind by an interrupted attempt and must not block a retry.
    if os.path.isdir(DATASET_DIR) and os.listdir(DATASET_DIR):
        return

    os.makedirs(MODEL_DIR, exist_ok=True)
    print("Downloading Dataset004_WML.zip...")
    zip_path = hf_hub_download(repo_id=REPO_ID, filename="Dataset004_WML.zip", cache_dir=MODEL_DIR)
    # Extract with the standard library instead of shelling out to `unzip`:
    # no dependency on an external binary, and a corrupt archive raises instead
    # of being silently ignored. Existing files are overwritten, like `unzip -o`.
    shutil.unpack_archive(zip_path, MODEL_DIR, format="zip")
    print("Dataset004_WML downloaded and extracted.")
def resample_to_isotropic(data, affine, target_spacing=1.0):
    """
    Resamples a 3D NIfTI image to isotropic voxel size.

    Parameters:
        data (numpy.ndarray): The input 3D image data.
        affine (numpy.ndarray): The 4x4 voxel-to-world affine transformation matrix.
        target_spacing (float): Desired isotropic voxel spacing (in mm).

    Returns:
        resampled_data (numpy.ndarray): Resampled image data (linear interpolation).
        resampled_affine (numpy.ndarray): Copy of the affine whose voxel-axis
            columns are rescaled to the new spacing; the input affine is not modified.

    Raises:
        ValueError: If target_spacing is not strictly positive.
    """
    if target_spacing <= 0:
        raise ValueError("target_spacing must be positive")

    # Work in floating point so the in-place division below is valid even if
    # the caller passes an integer-typed matrix.
    affine = np.asarray(affine, dtype=float)

    # Voxel size along each voxel axis is the norm of the matching affine COLUMN
    current_spacing = np.sqrt((affine[:3, :3] ** 2).sum(axis=0))

    # Compute the scaling factors for resampling (one per voxel axis)
    scaling_factors = current_spacing / target_spacing

    # Resample the data using zoom
    resampled_data = zoom(data, zoom=scaling_factors, order=1)  # Linear interpolation

    # Update the affine matrix to reflect the new voxel dimensions. Because the
    # spacings were measured per column, each column (not each row) must be
    # divided by its factor; dividing rows is only correct for diagonal affines.
    resampled_affine = affine.copy()
    resampled_affine[:3, :3] /= scaling_factors[np.newaxis, :]

    return resampled_data, resampled_affine
def _crop_and_pad_slice(volume, center, axis, slice_size):
    """Return the 2D plane through center[axis], cropped around the center and zero-padded to slice_size x slice_size."""
    half_size = slice_size // 2

    # Fix the requested axis to a single index to obtain a 2D plane
    index = [slice(None)] * 3
    index[axis] = center[axis]
    plane = volume[tuple(index)]

    # Crop around the center along the two remaining axes, clipped to the plane
    row_axis, col_axis = [i for i in range(3) if i != axis]
    cropped = plane[
        max(center[row_axis] - half_size, 0):min(center[row_axis] + half_size, plane.shape[0]),
        max(center[col_axis] - half_size, 0):min(center[col_axis] + half_size, plane.shape[1]),
    ]

    # Pad symmetrically with zeros so every view has the same size
    pad_rows = slice_size - cropped.shape[0]
    pad_cols = slice_size - cropped.shape[1]
    return np.pad(
        cropped,
        ((pad_rows // 2, pad_rows - pad_rows // 2),
         (pad_cols // 2, pad_cols - pad_cols // 2)),
        mode='constant', constant_values=0,
    )


def extract_middle_slices(nifti_path, output_image_path, slice_size=180, center=None):
    """
    Extracts slices from a 3D NIfTI image. If a center is provided, it uses it;
    otherwise, computes the center of mass of non-zero voxels. Slices are taken
    along the three voxel axes and saved side by side as a single PNG.

    Parameters:
        nifti_path (str): Path of the NIfTI image to load.
        output_image_path (str): Path of the PNG file to write.
        slice_size (int): Side length, in pixels of the resampled grid, of each
            cropped and zero-padded view.
        center (sequence of 3 ints or None): Voxel index, in the 1 mm resampled
            grid, through which the slices are taken. Passing the same center
            for two images of the same geometry yields matching views.

    Returns:
        None. The figure is written to output_image_path.
    """
    # Load NIfTI image and resample it to 1 mm isotropic
    img = nib.load(nifti_path)
    resampled_data, _ = resample_to_isotropic(img.get_fdata(), img.affine, target_spacing=1.0)
    shape = np.array(resampled_data.shape[:3])

    # Compute or reuse the center
    if center is None:
        foreground = resampled_data > 0
        if foreground.any():
            center = np.round(center_of_mass(foreground)).astype(int)
        else:
            # center_of_mass is NaN for an all-zero volume; use the geometric center instead
            center = shape // 2

    # Keep the center inside the volume so the single-slice index is always valid
    center = np.clip(np.asarray(center, dtype=int), 0, shape - 1)

    # One panel per voxel axis, in the order 2, 1, 0. The original comments label
    # these axial, coronal and sagittal, which assumes x/y/z voxel ordering — not
    # checked here. Every view is rotated 90 degrees clockwise (k=-1), which is
    # what the former k=1 followed by k=2 rotations amounted to.
    fig, axes = plt.subplots(1, 3, figsize=(12, 4))
    try:
        for ax, axis in zip(axes, (2, 1, 0)):
            view = _crop_and_pad_slice(resampled_data, center, axis, slice_size)
            ax.imshow(np.rot90(view, k=-1), cmap="gray", origin="lower")
            ax.axis("off")

        # Save the figure
        fig.tight_layout()
        fig.savefig(output_image_path, bbox_inches="tight", pad_inches=0)
    finally:
        # Always release this figure, even if rendering or saving fails
        plt.close(fig)
# Function to run nnUNet inference
@spaces.GPU(duration=70)  # Decorate the function to allocate GPU for its execution
def run_nnunet_predict(nifti_file):
    """
    Run nnUNetv2 inference on an uploaded FLAIR image and render preview slices.

    Parameters:
        nifti_file: Uploaded file object from the Gradio File component; only
            its ``.name`` attribute (path of the uploaded file) is used.

    Returns:
        tuple: (path of the lesion mask NIfTI, path of the input preview PNG,
        path of the mask preview PNG), matching the three outputs wired to the
        submit button.

    Raises:
        gr.Error: If no file was uploaded, if nnUNetv2_predict exits with a
            non-zero status, or if the expected output file is missing. The
            handler feeds three output components, so a failure is raised
            instead of being returned as a single string.
    """
    if nifti_file is None:
        raise gr.Error("Please upload a FLAIR image (.nii.gz) before submitting.")

    # Prepare directories
    os.makedirs(INPUT_DIR, exist_ok=True)
    os.makedirs(OUTPUT_DIR, exist_ok=True)

    # Extract the original filename without the .nii.gz extension
    original_filename = os.path.basename(nifti_file.name)
    suffix = ".nii.gz"
    if original_filename.endswith(suffix):
        base_filename = original_filename[:-len(suffix)]
    else:
        base_filename = original_filename

    # Move the uploaded file to the expected input location. shutil.move is used
    # because os.rename fails when source and destination are on different filesystems.
    input_path = os.path.join(INPUT_DIR, "image_0000.nii.gz")
    shutil.move(nifti_file.name, input_path)

    # Debugging: List files in the /tmp/input directory
    print("Files in /tmp/input:")
    print(os.listdir(INPUT_DIR))

    # Set environment variables for nnUNet
    os.environ["nnUNet_results"] = MODEL_DIR

    # Construct and run the nnUNetv2_predict command
    command = [
        "nnUNetv2_predict",
        "-i", INPUT_DIR,
        "-o", OUTPUT_DIR,
        "-d", "004",  # Dataset ID
        "-c", "3d_fullres",  # Configuration
        "-tr", "nnUNetTrainer_8000epochs",
        "-device", "cuda",  # Explicitly use GPU
    ]
    print("Files in /tmp/output:")
    print(os.listdir(OUTPUT_DIR))

    try:
        subprocess.run(command, check=True)
    except subprocess.CalledProcessError as e:
        raise gr.Error(f"Error: {e}") from e

    # Rename the output file to match the original input filename
    output_file = os.path.join(OUTPUT_DIR, "image.nii.gz")
    new_output_file = os.path.join(OUTPUT_DIR, f"{base_filename}_LesionMask.nii.gz")
    if not os.path.exists(output_file):
        raise gr.Error("Error: Output file not found.")
    os.rename(output_file, new_output_file)

    # Compute the center once on the input image so that the input preview and
    # the mask preview show the same slices
    img = nib.load(input_path)
    resampled_data, _ = resample_to_isotropic(img.get_fdata(), img.affine, target_spacing=1.0)
    foreground = resampled_data > 0
    if foreground.any():
        center = np.round(center_of_mass(foreground)).astype(int)
    else:
        # center_of_mass is NaN for an all-zero image; use the geometric center instead
        center = np.array(resampled_data.shape[:3]) // 2

    # Extract and save 2D slices
    input_slice_path = os.path.join(OUTPUT_DIR, f"{base_filename}_input_slice.png")
    output_slice_path = os.path.join(OUTPUT_DIR, f"{base_filename}_output_slice.png")
    extract_middle_slices(input_path, input_slice_path, center=center)
    extract_middle_slices(new_output_file, output_slice_path, center=center)

    # Return paths for the Gradio interface
    return new_output_file, input_slice_path, output_slice_path
# Gradio interface with adjusted layout
# NOTE(review): the nesting of the layout contexts below is reconstructed from
# statement order — confirm against the deployed layout.
with gr.Blocks() as demo:
    # Introductory text shown at the top of the page
    gr.Markdown("""
# 🔥 FLAMeS: FLAIR Lesion Segmentation for Multiple Sclerosis
Upload a skull-stripped FLAIR brain MRI in NIfTI (.nii.gz) format to generate a binary segmentation of multiple sclerosis lesions.
FLAMeS is based on the nnUNet framework<sup>2</sup> and was trained on 668 MRI scans acquired using Siemens, GE, and Philips 1.5T and 3T scanners<sup>1</sup>.
For skull-stripping, we suggest using [SynthStrip](https://surfer.nmr.mgh.harvard.edu/docs/synthstrip/) with the `--no-csf` flag for optimal results.
Inference takes approximately 1 minute per MRI, with processing limited to one scan at a time due to Hugging Face's zero-GPU usage constraints. To process multiple cases simultaneously, download [FLAMeS's model](https://huggingface.co/FrancescoLR/FLAMeS-model) and run it locally using your own GPU or CPU setup.
**Disclaimer:** Uploaded data is stored temporarily, no one has access to it, and it is deleted when the app is closed. For details, see [Gradio's file access guide](https://www.gradio.app/main/guides/file-access). Human subjects data should only be uploaded for processing if permitted by your institution's human subjects protection office.
This is a research tool and is not intended for clinical use. Clinical decisions should not be based on the outputs of this tool.
""")
    with gr.Row():
        # Narrow column: file upload and submit button
        with gr.Column(scale=1):
            flair_input = gr.File(label="Upload a FLAIR Image (.nii.gz)")
            submit_button = gr.Button("Submit")
        # Wide column: downloadable mask plus the two preview images
        with gr.Column(scale=2):
            seg_output = gr.File(label="Download the Lesion Segmentation Mask")
            input_img = gr.Image(label="Input: FLAIR image")
            output_img = gr.Image(label="Output: Lesion Mask")
    # Citation block; items 1 and 2 are the targets of the <sup> markers in the intro text
    gr.Markdown("""
**If you find this tool useful, please consider citing:**
1. A Deep Learning-Based Pipeline for Longitudinal White Matter Lesion Segmentation Using Diverse FLAIR Images
F. La Rosa, J. Dos Santos Silva, W. A. Mullins, H. Greenspan, J. F. Sumowski, D. S. Reich, & E. S. Beck.
*ACTRIMS Forum 2023. Multiple Sclerosis Journal.* 2023;29(2_suppl):18-242.
DOI: [10.1177/13524585231169437](https://doi.org/10.1177/13524585231169437)
2. nnU-Net: A Self-Configuring Method for Deep Learning-Based Biomedical Image Segmentation
F. Isensee, P. F. Jaeger, S. A. Kohl, J. Petersen, & K. H. Maier-Hein.
*Nature Methods.* 2021;18(2):203-211.
DOI: [10.1038/s41592-020-01008-z](https://www.nature.com/articles/s41592-020-01008-z)
""")
    # Wire the button: the uploaded file goes to run_nnunet_predict, whose three
    # return values fill the mask download and the two preview images, in that order
    submit_button.click(
        fn=run_nnunet_predict,
        inputs=[flair_input],
        outputs=[seg_output, input_img, output_img]
    )
# Debugging GPU environment: report whether torch can see a CUDA device
if torch.cuda.is_available():
    print(f"GPU is available: {torch.cuda.get_device_name(0)}")
else:
    print("No GPU available. Falling back to CPU.")
# NOTE(review): assumed to run unconditionally rather than only in the else
# branch — confirm against the original indentation.
os.system("nvidia-smi")
# Make sure the model files are present; this runs whenever the module is loaded,
# not only when it is executed as a script
download_model()
# Start the Gradio app only when the file is run directly
if __name__ == "__main__":
    demo.launch(share=True)
|