import gradio as gr
import subprocess
import os
import shutil
from huggingface_hub import hf_hub_download
import torch
import nibabel as nib
import matplotlib as mpl
import matplotlib.pyplot as plt
import spaces  # Import spaces for GPU decoration
import numpy as np
from scipy.ndimage import center_of_mass, zoom, label, generate_binary_structure

# Define paths
MODEL_DIR = "./model"  # Local directory to store the downloaded model
DATASET_DIR = os.path.join(MODEL_DIR, "Dataset004_WML")  # Directory for Dataset004_WML
INPUT_DIR = "/tmp/input"
OUTPUT_DIR = "/tmp/output"

# Hugging Face Model Repository
REPO_ID = "FrancescoLR/FLAMeS-model"  # Replace with your actual model repository ID


def setup_hd_bet(repo_dir="./HD-BET"):
    """
    Clones the HD-BET repository and installs it in editable mode using pip.

    Parameters:
        repo_dir (str): Directory where HD-BET will be cloned and installed.
    """
    if not os.path.exists(repo_dir):
        print("Cloning HD-BET repository...")
        subprocess.run(["git", "clone", "https://github.com/MIC-DKFZ/HD-BET", repo_dir], check=True)
    else:
        print("HD-BET repository already exists.")

    # Install the HD-BET package from source
    print("Installing HD-BET using pip...")
    subprocess.run(["pip", "install", "-e", "."], cwd=repo_dir, check=True)

# Function to download the Dataset004_WML folder
def download_model():
    if not os.path.exists(DATASET_DIR):
        os.makedirs(DATASET_DIR, exist_ok=True)
        print("Downloading Dataset004_WML.zip...")
        zip_path = hf_hub_download(repo_id=REPO_ID, filename="Dataset004_WML.zip", cache_dir=MODEL_DIR)
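        # hf_hub_download returns the local path of the cached zip, which is then extracted into MODEL_DIR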
        subprocess.run(["unzip", "-o", zip_path, "-d", MODEL_DIR])
        print("Dataset004_WML downloaded and extracted.")
        
def resample_to_isotropic(data, affine, target_spacing=1.0):
    """
    Resamples a 3D NIfTI image to isotropic voxel size.
    
    Parameters:
        data (numpy.ndarray): The input 3D image data.
        affine (numpy.ndarray): The affine transformation matrix.
        target_spacing (float): Desired isotropic voxel spacing (in mm).
    
    Returns:
        resampled_data (numpy.ndarray): Resampled image data.
        resampled_affine (numpy.ndarray): Updated affine matrix.
    """
    # Extract current voxel dimensions from the affine matrix
    current_spacing = np.sqrt((affine[:3, :3] ** 2).sum(axis=0))

    # Compute the scaling factors for resampling
    scaling_factors = current_spacing / target_spacing

    # Resample the data using zoom
    resampled_data = zoom(data, zoom=scaling_factors, order=1)  # Linear interpolation

    # Update the affine so each voxel axis (column) reflects the new spacing
    resampled_affine = affine.copy()
    resampled_affine[:3, :3] = affine[:3, :3] / scaling_factors[np.newaxis, :]

    return resampled_data, resampled_affine


def extract_middle_slices(nifti_path, output_image_path, slice_size=180, center=None, label_components=False):
    """
    Extracts slices from a 3D NIfTI image. 
    If label_components=True, it assigns different labels (colors) to each connected component (26-connectivity)
    and returns the labeled 3D mask.
    
    Returns:
        labeled_data (np.ndarray): The 3D array (either labeled or original).
    """
    # Load NIfTI image
    img = nib.load(nifti_path)
    data = img.get_fdata()
    affine = img.affine

    # Resample the image to 1 mm isotropic
    resampled_data, _ = resample_to_isotropic(data, affine, target_spacing=1.0)

    # Optionally label connected components
    if label_components:
        structure = generate_binary_structure(3, 3)  # 3D, 26-connectivity
        labeled_data, num_features = label(data > 0, structure=structure)
        labeled_data_resampled, num_features = label(resampled_data > 0, structure=structure)
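        # labeled_data stays at the original resolution (it is returned and saved as a NIfTI),
        # while labeled_data_resampled is only used for the 2D preview slices below.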
    else:
        labeled_data = resampled_data
        num_features = None  # Not needed if we're not labeling
        labeled_data_resampled = resampled_data

    # Compute or reuse the center of mass
    if center is None:
        com = center_of_mass(labeled_data_resampled > 0)
        center = np.round(com).astype(int)

    # Define half the slice size
    half_size = slice_size // 2

    # Function to extract and pad slices
    def extract_2d_slice(data, center, axis):
        slices = [slice(None)] * 3
        slices[axis] = center[axis]
        extracted_slice = data[tuple(slices)]

        remaining_axes = [i for i in range(3) if i != axis]
        cropped_slice = extracted_slice[
            max(center[remaining_axes[0]] - half_size, 0):min(center[remaining_axes[0]] + half_size, extracted_slice.shape[0]),
            max(center[remaining_axes[1]] - half_size, 0):min(center[remaining_axes[1]] + half_size, extracted_slice.shape[1]),
        ]

        pad_height = slice_size - cropped_slice.shape[0]
        pad_width = slice_size - cropped_slice.shape[1]
        padded_slice = np.pad(cropped_slice,
                              ((pad_height // 2, pad_height - pad_height // 2),
                               (pad_width // 2, pad_width - pad_width // 2)),
                              mode='constant', constant_values=0)
        return padded_slice

    # Extract slices
    axial_slice = extract_2d_slice(labeled_data_resampled, center, axis=2)
    coronal_slice = extract_2d_slice(labeled_data_resampled, center, axis=1)
    sagittal_slice = extract_2d_slice(labeled_data_resampled, center, axis=0)

    # Rotate each slice into a conventional viewing orientation
    axial_slice = np.rot90(axial_slice, k=-1)
    coronal_slice = np.rot90(coronal_slice, k=3)  # equivalent to k=1 followed by k=2
    sagittal_slice = np.rot90(sagittal_slice, k=3)  # equivalent to k=1 followed by k=2

    # Create subplots
    fig, axes = plt.subplots(1, 3, figsize=(12, 4))

    # Choose colormap
    if label_components:
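        # Index 0 of the colormap is black (background); the remaining entries are
        # shuffled pastel colors so neighboring lesion components get distinct colors.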
        # Create 256 pastel colors
        pastel = plt.cm.Pastel1(np.linspace(0, 1, 256))
        np.random.seed(42)  # For reproducibility
        shuffled_colors = pastel[1:].copy()
        np.random.shuffle(shuffled_colors)
        final_colors = np.vstack([np.array([0, 0, 0, 1]), shuffled_colors])
 
        custom_cmap = mpl.colors.ListedColormap(final_colors)
        cmap = custom_cmap  # Colorful
        vmin = 0
        vmax = num_features
    else:
        cmap = "gray"  # Normal
        vmin = None
        vmax = None

    # Plot slices
    for idx, slice_data in enumerate([axial_slice, coronal_slice, sagittal_slice]):
        ax = axes[idx]
        im = ax.imshow(slice_data, cmap=cmap, origin="lower", vmin=vmin, vmax=vmax)
        ax.axis("off")

    # Save figure
    plt.tight_layout()
    plt.savefig(output_image_path, bbox_inches="tight", pad_inches=0)
    plt.close()

    # Return the labeled mask
    return labeled_data
    
# Function to run nnUNet inference
@spaces.GPU(duration=90)  # Decorate the function to allocate GPU for its execution
def run_nnunet_predict(nifti_file, hd_bet=False):
    # Prepare directories
    os.makedirs(INPUT_DIR, exist_ok=True)
    os.makedirs(OUTPUT_DIR, exist_ok=True)

    # Extract the original filename without the extension
    original_filename = os.path.basename(nifti_file.name)
    base_filename = original_filename.replace(".nii.gz", "")
    
    # Save the uploaded file to the input directory
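    # nnU-Net expects input channels named <case>_0000.nii.gz, where _0000 identifies the first (and here only) modality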
    input_path = os.path.join(INPUT_DIR, "image_0000.nii.gz")
    shutil.move(nifti_file.name, input_path)  # Move the uploaded file to the expected input location (works across filesystems)

    if hd_bet:
        # Apply skull-stripping with HD-BET (overwrites the staged input file in place)
        hd_bet_output_path = os.path.join(INPUT_DIR, "image_0000.nii.gz")
        try:
            subprocess.run([
                "hd-bet",
                "-i", input_path,
                "-o", hd_bet_output_path,
                "-device", "cuda",  # or "cpu"
                "--disable_tta",
            ], check=True)
            print("Skull-stripping completed.")
            input_path = hd_bet_output_path
        except subprocess.CalledProcessError as e:
            return f"HD-BET Error: {e}"
    
    # Debugging: List files in the /tmp/input directory
    print("Files in /tmp/input:")
    print(os.listdir(INPUT_DIR))

    # Set environment variables for nnUNet
    os.environ["nnUNet_results"] = MODEL_DIR

    # Construct and run the nnUNetv2_predict command
    command = [
        "nnUNetv2_predict",
        "-i", INPUT_DIR,
        "-o", OUTPUT_DIR,
        "-d", "004",                  # Dataset ID
        "-c", "3d_fullres",           # Configuration
        "-tr", "nnUNetTrainer_8000epochs",
        "-device", "cuda",  # Explicitly use GPU
    ]
    print("Files in /tmp/output:")
    print(os.listdir(OUTPUT_DIR))
    try:
        subprocess.run(command, check=True)

        # Rename the output file to match the original input filename
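        # nnU-Net names predictions after the case identifier, so the result for image_0000.nii.gz is written as image.nii.gz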
        output_file = os.path.join(OUTPUT_DIR, "image.nii.gz")
        new_output_file = os.path.join(OUTPUT_DIR, f"{base_filename}_LesionMask.nii.gz")
        if os.path.exists(output_file):
            os.rename(output_file, new_output_file)

            # Compute center of mass for the input image
            img = nib.load(input_path)
            data = img.get_fdata()
            affine = img.affine
            resampled_data, _ = resample_to_isotropic(data, affine, target_spacing=1.0)
            com = center_of_mass(resampled_data > 0)  # Center of mass
            center = np.round(com).astype(int)        # Round to integer
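            # Reuse the same center for both previews so the input image and lesion mask slices are spatially aligned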

            # Extract and save 2D slices
            input_slice_path = os.path.join(OUTPUT_DIR, f"{base_filename}_input_slice.png")
            output_slice_path = os.path.join(OUTPUT_DIR, f"{base_filename}_output_slice.png")
            extract_middle_slices(input_path, input_slice_path, center=center)
            labeled_mask = extract_middle_slices(new_output_file, output_slice_path, center=center, label_components=True)

            # Load the binary lesion mask to get its affine
            output_img = nib.load(new_output_file)

            labeled_mask_path = os.path.join(OUTPUT_DIR, f"{base_filename}_LabeledClusters.nii.gz")
            nib.save(nib.Nifti1Image(labeled_mask.astype(np.int16), output_img.affine), labeled_mask_path)

            # Return paths for the Gradio interface
            return new_output_file, input_slice_path, output_slice_path, labeled_mask_path
        else:
            return "Error: Output file not found."
    except subprocess.CalledProcessError as e:
        return f"Error: {e}"

# Gradio interface with adjusted layout
with gr.Blocks() as demo:
    gr.Markdown("""
   # 🔥 FLAMeS: FLAIR Lesion Segmentation for Multiple Sclerosis

    Upload a FLAIR brain MRI in NIfTI format (.nii.gz) to generate a binary segmentation of multiple sclerosis lesions.  
    FLAMeS is based on the nnUNet framework<sup>2</sup> and was trained on 668 MRI scans acquired using Siemens, GE, and Philips 1.5T and 3T scanners<sup>1</sup>.
    We suggest skull-stripping the image in advance using [SynthStrip](https://surfer.nmr.mgh.harvard.edu/docs/synthstrip/) with the `--no-csf` flag for optimal results. If that's not feasible, you can still upload your image as-is and enable the "Apply skull-stripping" option below.

    Inference takes approximately 1 minute per MRI, and processing is limited to one scan at a time due to Hugging Face's ZeroGPU usage constraints. To process multiple cases at once, install [nnU-Net v2](https://github.com/MIC-DKFZ/nnUNet), download the [FLAMeS model](https://huggingface.co/FrancescoLR/FLAMeS-model), and run it locally on your own GPU or CPU.
    
    **Disclaimer:** Uploaded data is stored temporarily, no one has access to it, and it is deleted when the app is closed. For details, see [Gradio's file access guide](https://www.gradio.app/main/guides/file-access). Human subjects data should only be uploaded for processing if permitted by your institution's human subjects protection office.
    This is a research tool and is not intended for clinical use. Clinical decisions should not be based on the outputs of this tool.
    
    """)

    with gr.Row():
        with gr.Column(scale=1):
            flair_input = gr.File(label="Upload a FLAIR Image (.nii.gz)")
            hd_bet = gr.Checkbox(label="Apply skull-stripping", value=False)
            submit_button = gr.Button("Submit")
        with gr.Column(scale=2):
            seg_output = gr.File(label="Download the Lesion Segmentation Mask")
            clusters_output = gr.File(label="Download the Labeled Lesion Segmentation Mask")
            input_img = gr.Image(label="Input: FLAIR image")
            output_img = gr.Image(label="Output: Binary Lesion Mask")
            

    gr.Markdown("""
    **If you find this tool useful, please consider citing:**

    1. FLAMeS: A Robust Deep Learning Model for Automated Multiple Sclerosis Lesion Segmentation  
   Dereskewicz, E., La Rosa, F., dos Santos Silva, J., Sizer, E., Kohli, A., Wynen, M., ... & Beck, E. S.  
   *medRxiv* (2025)  
   DOI: [10.1101/2025.05.19.25327707](https://doi.org/10.1101/2025.05.19.25327707)

    2. nnU-Net: A Self-Configuring Method for Deep Learning-Based Biomedical Image Segmentation  
   Isensee, F., Jaeger, P. F., Kohl, S. A., Petersen, J., & Maier-Hein, K. H.  
   *Nature Methods.* 2021;18(2):203-211.  
   DOI: [10.1038/s41592-020-01008-z](https://www.nature.com/articles/s41592-020-01008-z)
    """)

    submit_button.click(
        fn=run_nnunet_predict,
        inputs=[flair_input, hd_bet],
        outputs=[seg_output, input_img, output_img, clusters_output]
    )

# Debugging GPU environment
if torch.cuda.is_available():
    print(f"GPU is available: {torch.cuda.get_device_name(0)}")
else:
    print("No GPU available. Falling back to CPU.")
    os.system("nvidia-smi")

setup_hd_bet()
download_model()

if __name__ == "__main__":
    demo.launch(share=True)