# Hugging Face Space — status: Running on Zero (ZeroGPU). File size: 14,542 bytes.
import gradio as gr
import subprocess
import os
import shutil
from huggingface_hub import hf_hub_download
import torch
import nibabel as nib
import matplotlib as mpl
import matplotlib.pyplot as plt
import spaces # Import spaces for GPU decoration
import numpy as np
from scipy.ndimage import center_of_mass, zoom, label, generate_binary_structure
# --- Paths ---------------------------------------------------------------
# nnU-Net results root: the model archive is unpacked here and nnUNet_results
# is pointed at it before inference.
MODEL_DIR = "./model"
# Trained FLAMeS dataset folder inside the nnU-Net results root.
DATASET_DIR = os.path.join(MODEL_DIR, "Dataset004_WML")
# Scratch folders for the nnU-Net input image and its predictions.
INPUT_DIR, OUTPUT_DIR = "/tmp/input", "/tmp/output"
# --- Hugging Face Hub ----------------------------------------------------
# Model repository that hosts Dataset004_WML.zip.
REPO_ID = "FrancescoLR/FLAMeS-model"
import os
import subprocess
def setup_hd_bet(repo_dir="./HD-BET"):
    """
    Clone the HD-BET repository (if missing) and install it in editable mode.

    pip is run as ``<current interpreter> -m pip`` so the package is installed
    into the same environment this app runs in, rather than into whichever
    ``pip`` happens to be first on PATH.

    Parameters:
    repo_dir (str): Directory where HD-BET will be cloned and installed.

    Raises:
    subprocess.CalledProcessError: If ``git clone`` or the pip install fails.
    """
    import sys

    if os.path.exists(repo_dir):
        print("HD-BET repository already exists.")
    else:
        print("Cloning HD-BET repository...")
        subprocess.run(["git", "clone", "https://github.com/MIC-DKFZ/HD-BET", repo_dir], check=True)
    # Install the HD-BET package from source into the running interpreter.
    print("Installing HD-BET using pip...")
    subprocess.run([sys.executable, "-m", "pip", "install", "-e", "."], cwd=repo_dir, check=True)
# Function to download the Dataset004_WML folder
def download_model():
    """
    Download and unpack the FLAMeS nnU-Net results folder (Dataset004_WML).

    Does nothing if DATASET_DIR already exists. Otherwise the archive
    ``Dataset004_WML.zip`` is fetched from REPO_ID (cached under MODEL_DIR) and
    extracted into MODEL_DIR.

    Raises:
    zipfile.BadZipFile: If the downloaded archive is not a valid zip file.
    Any exception raised by ``hf_hub_download`` (e.g. network errors).
    """
    import zipfile

    if os.path.exists(DATASET_DIR):
        return
    os.makedirs(MODEL_DIR, exist_ok=True)
    print("Downloading Dataset004_WML.zip...")
    zip_path = hf_hub_download(repo_id=REPO_ID, filename="Dataset004_WML.zip", cache_dir=MODEL_DIR)
    # Extract with the standard library so a failed extraction raises instead
    # of passing silently. DATASET_DIR is only ensured AFTER a successful
    # extraction, so an interrupted attempt is retried on the next start
    # rather than leaving an empty folder that skips the download forever.
    with zipfile.ZipFile(zip_path) as archive:
        archive.extractall(MODEL_DIR)
    os.makedirs(DATASET_DIR, exist_ok=True)
    print("Dataset004_WML downloaded and extracted.")
def resample_to_isotropic(data, affine, target_spacing=1.0):
    """
    Resamples a 3D NIfTI image to isotropic voxel size.

    Parameters:
    data (numpy.ndarray): The input 3D image data.
    affine (numpy.ndarray): The 4x4 voxel-to-world affine of ``data``.
    target_spacing (float): Desired isotropic voxel spacing (in mm).

    Returns:
    resampled_data (numpy.ndarray): Resampled image data (linear interpolation).
        Its shape is rounded by ``scipy.ndimage.zoom``, so the effective
        spacing can differ marginally from ``target_spacing``.
    resampled_affine (numpy.ndarray): A new float affine whose voxel-axis
        columns are rescaled to ``target_spacing``; ``affine`` is not modified.
    """
    # The spacing along voxel axis j is the norm of COLUMN j of the
    # rotation/zoom part of the affine.
    current_spacing = np.sqrt((affine[:3, :3] ** 2).sum(axis=0))
    # Compute the scaling factors for resampling
    scaling_factors = current_spacing / target_spacing
    # Resample the data using zoom
    resampled_data = zoom(data, zoom=scaling_factors, order=1)  # Linear interpolation
    # Shrink each voxel-axis column by its zoom factor (broadcasting over the
    # last axis). Scaling rows instead only coincides with this for diagonal
    # affines and yields wrong spacings for permuted/oblique orientations.
    resampled_affine = np.array(affine, dtype=float, copy=True)
    resampled_affine[:3, :3] = resampled_affine[:3, :3] / scaling_factors
    return resampled_data, resampled_affine
# NOTE: an incomplete draft of extract_middle_slices (no return value, no
# rendering) used to live here. It was rebound by the full definition that
# follows before any call could reach it, so it has been removed.
def extract_middle_slices(nifti_path, output_image_path, slice_size=180, center=None, label_components=False):
    """
    Render axial, coronal and sagittal slices of a 3D NIfTI image into one PNG.

    The volume is resampled to 1 mm isotropic voxels first. Each slice passes
    through ``center`` and is cropped / zero-padded to ``slice_size`` x
    ``slice_size`` so that ``center`` stays at the same pixel of the window,
    even when it lies close to the volume border.

    Parameters:
    nifti_path (str): NIfTI file to render.
    output_image_path (str): Destination of the PNG.
    slice_size (int): Side length (in resampled voxels) of each panel.
    center (array-like of 3 ints or None): Voxel index in the resampled grid
        that all three slices pass through. If None, the rounded center of mass
        of the non-zero voxels is used (the volume center if there are none).
        Indices outside the volume are clipped to its border.
    label_components (bool): If True, label the connected components of the
        non-zero voxels (26-connectivity) and draw each in its own color;
        otherwise draw intensities in grayscale.

    Returns:
    labeled_data (np.ndarray): With label_components=True, the component labels
        computed on the ORIGINAL (non-resampled) voxel grid; otherwise the
        resampled image data.
    """
    img = nib.load(nifti_path)
    data = img.get_fdata()
    # Resample the image to 1 mm isotropic for display.
    resampled_data, _ = resample_to_isotropic(data, img.affine, target_spacing=1.0)

    if label_components:
        structure = generate_binary_structure(3, 3)  # 3D, 26-connectivity
        # Original-grid labels are returned to the caller; the resampled
        # labels are only used to draw the figure.
        labeled_data, _ = label(data > 0, structure=structure)
        labeled_data_resampled, num_features = label(resampled_data > 0, structure=structure)
    else:
        labeled_data = resampled_data
        labeled_data_resampled = resampled_data
        num_features = None  # Not needed if we're not labeling

    volume_shape = np.array(labeled_data_resampled.shape)
    if center is None:
        foreground = labeled_data_resampled > 0
        if foreground.any():
            center = np.round(center_of_mass(foreground)).astype(int)
        else:
            # center_of_mass of an empty mask is NaN, which cannot be used as an index.
            center = volume_shape // 2
    center = np.clip(np.asarray(center, dtype=int), 0, volume_shape - 1)
    half_size = slice_size // 2

    def extract_2d_slice(volume, axis):
        """Return the plane through ``center`` normal to ``axis``, cropped/padded around ``center``."""
        plane = np.take(volume, center[axis], axis=axis)
        window = []
        padding = []
        for dim, ax in enumerate(a for a in range(3) if a != axis):
            start = center[ax] - half_size
            lo = max(start, 0)
            hi = min(center[ax] + half_size, plane.shape[dim])
            window.append(slice(lo, hi))
            # Pad exactly what the border cut off on the low side before the
            # data, so the requested center stays at index half_size.
            before = lo - start
            padding.append((before, slice_size - (hi - lo) - before))
        return np.pad(plane[tuple(window)], padding, mode="constant", constant_values=0)

    # Quarter-turn rotations to orient each panel for display.
    axial_slice = np.rot90(extract_2d_slice(labeled_data_resampled, axis=2), k=-1)
    coronal_slice = np.rot90(extract_2d_slice(labeled_data_resampled, axis=1), k=3)
    sagittal_slice = np.rot90(extract_2d_slice(labeled_data_resampled, axis=0), k=3)

    if label_components:
        # Black for label 0 (background), then a shuffled pastel palette.
        pastel = plt.cm.Pastel1(np.linspace(0, 1, 256))
        shuffled_colors = pastel[1:].copy()
        # Seeded private generator: reproducible colors without reseeding
        # NumPy's global random state for the rest of the application.
        np.random.RandomState(42).shuffle(shuffled_colors)
        cmap = mpl.colors.ListedColormap(np.vstack([[0, 0, 0, 1], shuffled_colors]))
        vmin, vmax = 0, num_features
    else:
        cmap, vmin, vmax = "gray", None, None

    fig, axes = plt.subplots(1, 3, figsize=(12, 4))
    try:
        for ax, panel in zip(axes, (axial_slice, coronal_slice, sagittal_slice)):
            ax.imshow(panel, cmap=cmap, origin="lower", vmin=vmin, vmax=vmax)
            ax.axis("off")
        fig.tight_layout()
        fig.savefig(output_image_path, bbox_inches="tight", pad_inches=0)
    finally:
        # Close this specific figure (not pyplot's "current" one), even on error.
        plt.close(fig)
    return labeled_data
# Function to run nnUNet inference
@spaces.GPU(duration=90)  # Decorate the function to allocate GPU for its execution
def run_nnunet_predict(nifti_file, hd_bet=False):
    """
    Segment MS lesions in an uploaded FLAIR image with the FLAMeS nnU-Net model.

    Parameters:
    nifti_file: Value produced by the ``gr.File`` input, either a path string
        or an object whose ``.name`` is the path. ``None`` if nothing was uploaded.
    hd_bet (bool): Skull-strip the image with HD-BET before segmentation.

    Returns:
    tuple[str, str, str, str]: Paths of the binary lesion mask (NIfTI), the
        input preview PNG, the output preview PNG and the connected-component
        labeled mask (NIfTI), in the order of the Gradio ``outputs`` list.

    Raises:
    gr.Error: On a missing upload, an HD-BET or nnU-Net failure, or when
        nnU-Net produced no prediction. Gradio shows the message to the user,
        whereas returning a single string for four outputs fails in the UI.
    """
    if nifti_file is None:
        raise gr.Error("Please upload a FLAIR image (.nii.gz) first.")
    uploaded_path = getattr(nifti_file, "name", nifti_file)

    # nnUNetv2_predict is pointed at the whole INPUT_DIR, so start from an
    # empty folder to avoid segmenting leftovers from an earlier run.
    shutil.rmtree(INPUT_DIR, ignore_errors=True)
    os.makedirs(INPUT_DIR, exist_ok=True)
    os.makedirs(OUTPUT_DIR, exist_ok=True)

    # Keep the user's file stem to name the results.
    original_filename = os.path.basename(uploaded_path)
    suffix = ".nii.gz"
    if original_filename.endswith(suffix):
        base_filename = original_filename[: -len(suffix)]
    else:
        base_filename = original_filename

    # Case "image", channel 0000; the prediction comes back as image.nii.gz.
    input_path = os.path.join(INPUT_DIR, "image_0000.nii.gz")
    # shutil.move also works when the upload lives on another filesystem,
    # where os.rename would raise OSError.
    shutil.move(uploaded_path, input_path)

    if hd_bet:
        # Skull-strip in place: HD-BET writes its result over input_path.
        try:
            subprocess.run(
                [
                    "hd-bet",
                    "-i", input_path,
                    "-o", input_path,
                    "-device", "cuda",  # or "cpu"
                    "--disable_tta",
                ],
                check=True,
            )
        except subprocess.CalledProcessError as e:
            raise gr.Error(f"HD-BET Error: {e}") from e
        print("Skull-stripping completed.")

    # Debugging: List files in the /tmp/input directory
    print("Files in /tmp/input:")
    print(os.listdir(INPUT_DIR))

    # nnU-Net looks up trained models under this environment variable.
    os.environ["nnUNet_results"] = MODEL_DIR
    command = [
        "nnUNetv2_predict",
        "-i", INPUT_DIR,
        "-o", OUTPUT_DIR,
        "-d", "004",  # Dataset ID
        "-c", "3d_fullres",  # Configuration
        "-tr", "nnUNetTrainer_8000epochs",
        "-device", "cuda",  # Explicitly use GPU
    ]
    print("Files in /tmp/output:")
    print(os.listdir(OUTPUT_DIR))

    output_file = os.path.join(OUTPUT_DIR, "image.nii.gz")
    # Drop any stale prediction so an old mask can never be returned.
    if os.path.exists(output_file):
        os.remove(output_file)
    try:
        subprocess.run(command, check=True)
    except subprocess.CalledProcessError as e:
        raise gr.Error(f"Error: {e}") from e
    if not os.path.exists(output_file):
        raise gr.Error("Error: Output file not found.")

    # Rename the output file to match the original input filename
    new_output_file = os.path.join(OUTPUT_DIR, f"{base_filename}_LesionMask.nii.gz")
    os.replace(output_file, new_output_file)

    # Slice both previews through the input's center of mass on the 1 mm grid
    # so the input and mask panels line up.
    img = nib.load(input_path)
    resampled_data, _ = resample_to_isotropic(img.get_fdata(), img.affine, target_spacing=1.0)
    center = np.round(center_of_mass(resampled_data > 0)).astype(int)

    input_slice_path = os.path.join(OUTPUT_DIR, f"{base_filename}_input_slice.png")
    output_slice_path = os.path.join(OUTPUT_DIR, f"{base_filename}_output_slice.png")
    extract_middle_slices(input_path, input_slice_path, center=center)
    labeled_mask = extract_middle_slices(new_output_file, output_slice_path, center=center, label_components=True)

    # Save the component labels with the binary mask's affine.
    output_img = nib.load(new_output_file)
    labeled_mask_path = os.path.join(OUTPUT_DIR, f"{base_filename}_LabeledClusters.nii.gz")
    nib.save(nib.Nifti1Image(labeled_mask.astype(np.int16), output_img.affine), labeled_mask_path)

    # Return paths for the Gradio interface
    return new_output_file, input_slice_path, output_slice_path, labeled_mask_path
# Gradio interface with adjusted layout.
# Markdown text sits at column 0 inside the string so no line is rendered as
# an indented code block; blank lines separate the paragraphs.
with gr.Blocks() as demo:
    gr.Markdown("""
# 🔥 FLAMeS: FLAIR Lesion Segmentation for Multiple Sclerosis

Upload a FLAIR brain MRI in NIfTI format (.nii.gz) to generate a binary segmentation of multiple sclerosis lesions.

FLAMeS is based on the nnUNet framework<sup>2</sup> and was trained on 668 MRI scans acquired using Siemens, GE, and Philips 1.5T and 3T scanners<sup>1</sup>.

We suggest skull-stripping the image in advance using [SynthStrip](https://surfer.nmr.mgh.harvard.edu/docs/synthstrip/) with the `--no-csf` flag for optimal results. If that's not feasible, you can still upload your image as-is and enable the "Apply skull-stripping" option below.

Inference takes approximately 1 minute per MRI, with processing limited to one scan at a time due to Hugging Face's zero-GPU usage constraints. To process multiple cases simultaneously, install the [nnUNet v2](https://github.com/MIC-DKFZ/nnUNet), download [FLAMeS's model](https://huggingface.co/FrancescoLR/FLAMeS-model) and run it locally using your own GPU or CPU setup.

**Disclaimer:** Uploaded data is stored temporarily, no one has access to it, and it is deleted when the app is closed. For details, see [Gradio's file access guide](https://www.gradio.app/main/guides/file-access). Human subjects data should only be uploaded for processing if permitted by your institution's human subjects protection office.

This is a research tool and is not intended for clinical use. Clinical decisions should not be based on the outputs of this tool.
""")
    with gr.Row():
        with gr.Column(scale=1):
            flair_input = gr.File(label="Upload a FLAIR Image (.nii.gz)")
            hd_bet = gr.Checkbox(label="Apply skull-stripping", value=False)
            submit_button = gr.Button("Submit")
        with gr.Column(scale=2):
            seg_output = gr.File(label="Download the Lesion Segmentation Mask")
            clusters_output = gr.File(label="Download the Labeled Lesion Segmentation Mask")
            input_img = gr.Image(label="Input: FLAIR image")
            # The preview colors each connected lesion cluster separately.
            output_img = gr.Image(label="Output: Lesion Mask (clusters colored)")
    gr.Markdown("""
**If you find this tool useful, please consider citing:**

1. FLAMeS: A Robust Deep Learning Model for Automated Multiple Sclerosis Lesion Segmentation
Dereskewicz, E., La Rosa, F., dos Santos Silva, J., Sizer, E., Kohli, A., Wynen, M., ... & Beck, E. S.
*medRxiv* (2025)
DOI: [10.1101/2025.05.19.25327707](https://doi.org/10.1101/2025.05.19.25327707)
2. nnU-Net: A Self-Configuring Method for Deep Learning-Based Biomedical Image Segmentation
F. Isensee, P. F. Jaeger, S. A. Kohl, J. Petersen, & K. H. Maier-Hein.
*Nature Methods.* 2021;18(2):203-211.
DOI: [10.1038/s41592-020-01008-z](https://www.nature.com/articles/s41592-020-01008-z)
""")
    # Output order must match the tuple returned by run_nnunet_predict.
    submit_button.click(
        fn=run_nnunet_predict,
        inputs=[flair_input, hd_bet],
        outputs=[seg_output, input_img, output_img, clusters_output],
    )
# Startup diagnostics: report which accelerator torch sees, then the driver view.
print(
    f"GPU is available: {torch.cuda.get_device_name(0)}"
    if torch.cuda.is_available()
    else "No GPU available. Falling back to CPU."
)
os.system("nvidia-smi")
# Install the skull-stripping tool and fetch the model weights before serving.
setup_hd_bet()
download_model()
if __name__ == "__main__":
    demo.launch(share=True)
|