FrancescoLR commited on
Commit
2fa3177
·
1 Parent(s): 6c748cb

Updated app.py

Browse files
Files changed (2) hide show
  1. app.py +28 -3
  2. requirements.txt +1 -0
app.py CHANGED
@@ -4,6 +4,8 @@ import os
4
  import shutil
5
  from huggingface_hub import hf_hub_download
6
  import torch
 
 
7
  import spaces # Import spaces for GPU decoration
8
 
9
  # Define paths
@@ -24,6 +26,19 @@ def download_model():
24
  subprocess.run(["unzip", "-o", zip_path, "-d", MODEL_DIR])
25
  print("Dataset004_WML downloaded and extracted.")
26
 
 
 
 
 
 
 
 
 
 
 
 
 
 
27
  # Function to run nnUNet inference
28
  @spaces.GPU # Decorate the function to allocate GPU for its execution
29
  def run_nnunet_predict(nifti_file):
@@ -64,18 +79,28 @@ def run_nnunet_predict(nifti_file):
64
  new_output_file = os.path.join(OUTPUT_DIR, f"{base_filename}_LesionMask.nii.gz")
65
  if os.path.exists(output_file):
66
  os.rename(output_file, new_output_file)
67
- return new_output_file
 
 
 
 
 
 
 
68
  else:
69
  return "Error: Output file not found."
70
  except subprocess.CalledProcessError as e:
71
  return f"Error: {e}"
72
 
73
-
74
  # Gradio Interface
75
  interface = gr.Interface(
76
  fn=run_nnunet_predict,
77
  inputs=gr.File(label="Upload FLAIR Image (.nii.gz)"),
78
- outputs=gr.File(label="Download Segmentation Mask"),
 
 
 
 
79
  title="FLAMeS: Multiple Sclerosis Lesion Segmentation",
80
  description="Upload a skull-stripped FLAIR image (.nii.gz) to generate a binary segmentation of MS lesions."
81
  )
 
4
  import shutil
5
  from huggingface_hub import hf_hub_download
6
  import torch
7
+ import nibabel as nib
8
+ import matplotlib.pyplot as plt
9
  import spaces # Import spaces for GPU decoration
10
 
11
  # Define paths
 
26
  subprocess.run(["unzip", "-o", zip_path, "-d", MODEL_DIR])
27
  print("Dataset004_WML downloaded and extracted.")
28
 
29
+ def extract_middle_slice(nifti_path, output_image_path):
30
+ """
31
+ Extracts a middle slice from a 3D NIfTI image and saves it as a PNG file.
32
+ """
33
+ img = nib.load(nifti_path)
34
+ data = img.get_fdata()
35
+ middle_slice_index = data.shape[2] // 2 # Middle slice along the z-axis
36
+ plt.figure(figsize=(6, 6))
37
+ plt.imshow(data[:, :, middle_slice_index], cmap="gray")
38
+ plt.axis("off")
39
+ plt.savefig(output_image_path, bbox_inches="tight", pad_inches=0)
40
+ plt.close()
41
+
42
  # Function to run nnUNet inference
43
  @spaces.GPU # Decorate the function to allocate GPU for its execution
44
  def run_nnunet_predict(nifti_file):
 
79
  new_output_file = os.path.join(OUTPUT_DIR, f"{base_filename}_LesionMask.nii.gz")
80
  if os.path.exists(output_file):
81
  os.rename(output_file, new_output_file)
82
+
83
+ # Extract and save 2D slices
84
+ input_slice_path = os.path.join(OUTPUT_DIR, f"{base_filename}_input_slice.png")
85
+ output_slice_path = os.path.join(OUTPUT_DIR, f"{base_filename}_output_slice.png")
86
+ extract_middle_slice(input_path, input_slice_path)
87
+ extract_middle_slice(new_output_file, output_slice_path)
88
+
89
+ return input_slice_path, output_slice_path, new_output_file
90
  else:
91
  return "Error: Output file not found."
92
  except subprocess.CalledProcessError as e:
93
  return f"Error: {e}"
94
 
 
95
  # Gradio Interface
96
  interface = gr.Interface(
97
  fn=run_nnunet_predict,
98
  inputs=gr.File(label="Upload FLAIR Image (.nii.gz)"),
99
+ outputs=[
100
+ gr.Image(label="Input Middle Slice"),
101
+ gr.Image(label="Output Middle Slice"),
102
+ gr.File(label="Download Segmentation Mask")
103
+ ],
104
  title="FLAMeS: Multiple Sclerosis Lesion Segmentation",
105
  description="Upload a skull-stripped FLAIR image (.nii.gz) to generate a binary segmentation of MS lesions."
106
  )
requirements.txt CHANGED
@@ -6,4 +6,5 @@ torchvision
6
  torchaudio
7
  nnunetv2
8
  nibabel
 
9
  numpy
 
6
  torchaudio
7
  nnunetv2
8
  nibabel
9
+ matplotlib
10
  numpy