lukiod committed
Commit 065e508 · verified · 1 Parent(s): b1bba6c

Create app.py

Files changed (1): app.py (+208, -0)
app.py ADDED
@@ -0,0 +1,208 @@
+ import streamlit as st
+ import torch
+ import os
+ import tempfile
+ import time
+
+ # nnU-Net and visualization imports
+ from nnunetv2.inference.predict_from_raw_data import nnUNetPredictor
+ import pyvista as pv
+ import nibabel as nib
+ import numpy as np
+ from matplotlib import cm
+ from matplotlib.colors import ListedColormap
+ from stpyvista import stpyvista
+
+ # --- Caching the nnU-Net Predictor ---
+ # This is crucial for performance. The model is loaded once and stored in memory.
+ @st.cache_resource
+ def load_predictor(model_folder):
+     """
+     Loads and initializes the nnUNetPredictor.
+     The @st.cache_resource decorator ensures this function is only run once.
+     """
+     st.write("Initializing nnU-Net predictor... (This may take a moment)")
+
+     # Instantiate the predictor
+     predictor = nnUNetPredictor(
+         tile_step_size=0.5,
+         use_gaussian=True,
+         use_mirroring=True,
+         perform_everything_on_device=True,
+         device=torch.device('cuda', 0) if torch.cuda.is_available() else torch.device('cpu'),
+         verbose=False,
+         verbose_preprocessing=False,
+         allow_tqdm=True
+     )
+
+     # Initialize from the trained model folder
+     try:
+         predictor.initialize_from_trained_model_folder(
+             model_folder,
+             use_folds=(0,),  # Assuming you want to use fold 0
+             checkpoint_name='checkpoint_final.pth',
+         )
+         st.success("nnU-Net predictor initialized successfully!")
+         return predictor
+     except Exception as e:
+         st.error(f"Failed to initialize predictor from {model_folder}. Error: {e}")
+         return None
+
+ # --- Visualization Function (from your script) ---
+ def generate_visualization(base_image_path, mask_path):
+     """
+     Generates a PyVista plot of the base image and the segmentation mask.
+     """
+     # Load base CT scan
+     img = nib.load(base_image_path)
+     img_data = img.get_fdata()
+     img_data = (img_data - np.min(img_data)) / np.ptp(img_data)  # Normalize 0–1
+
+     # Load segmentation mask
+     mask = nib.load(mask_path)
+     mask_data = mask.get_fdata().astype(np.uint8)
+
+     # Label dictionary (from your script)
+     label_dict = {
+         1: "Lower Jawbone", 2: "Upper Jawbone", 3: "Left Inferior Alveolar Canal",
+         4: "Right Inferior Alveolar Canal", 5: "Left Maxillary Sinus", 6: "Right Maxillary Sinus",
+         7: "Pharynx", 8: "Bridge", 9: "Crown", 10: "Implant", 11: "Upper Right Central Incisor",
+         12: "Upper Right Lateral Incisor", 13: "Upper Right Canine", 14: "Upper Right First Premolar",
+         15: "Upper Right Second Premolar", 16: "Upper Right First Molar", 17: "Upper Right Second Molar",
+         18: "Upper Right Third Molar", 21: "Upper Left Central Incisor",
+         22: "Upper Left Lateral Incisor", 23: "Upper Left Canine", 24: "Upper Left First Premolar",
+         25: "Upper Left Second Premolar", 26: "Upper Left First Molar", 27: "Upper Left Second Molar",
+         28: "Upper Left Third Molar", 31: "Lower Left Central Incisor",
+         32: "Lower Left Lateral Incisor", 33: "Lower Left Canine", 34: "Lower Left First Premolar",
+         35: "Lower Left Second Premolar", 36: "Lower Left First Molar", 37: "Lower Left Second Molar",
+         38: "Lower Left Third Molar", 41: "Lower Right Central Incisor",
+         42: "Lower Right Lateral Incisor", 43: "Lower Right Canine", 44: "Lower Right First Premolar",
+         45: "Lower Right Second Premolar", 46: "Lower Right First Molar", 47: "Lower Right Second Molar",
+         48: "Lower Right Third Molar"
+     }
+
+     # Generate color map
+     num_labels = max(label_dict.keys()) + 1
+     colors = np.vstack([
+         [[0, 0, 0, 0]],
+         cm.get_cmap('tab20b')(np.linspace(0, 1, 20)),
+         cm.get_cmap('tab20c')(np.linspace(0, 1, 20)),
+         cm.get_cmap('gist_rainbow')(np.linspace(0, 1, num_labels))
+     ])[:, :4]
+     colors = colors[:num_labels]
+     colormap = ListedColormap(colors)
+
+     # Wrap data in PyVista objects
+     vol_img = pv.wrap(img_data)
+     vol_mask = pv.wrap(mask_data)
+
+     # Create plotter
+     plotter = pv.Plotter(window_size=[800, 800])
+     plotter.add_volume(vol_img, cmap="bone", opacity="sigmoid", name="CT Scan")
+     plotter.add_volume(
+         vol_mask,
+         cmap=colormap,
+         opacity=[0, 0.5],  # Make label 0 transparent
+         mapper='gpu',  # Use GPU for better performance
+         name="Segmentation Mask"
+     )
+     plotter.camera_position = 'xy'
+
+     return plotter
+
+
+ # --- Main Streamlit App ---
+ def main():
+     st.set_page_config(layout="wide", page_title="nnU-Net Inference App")
+
+     st.title("🦷 nnU-Net Inference and 3D Visualization")
+     st.markdown("Upload a medical image, run nnU-Net for segmentation, and visualize the results in 3D.")
+
+     # --- Sidebar for Inputs ---
+     st.sidebar.header("1. Configure Model")
+     # IMPORTANT: Update this path to your default nnU-Net results folder
+     default_model_path = "/path/to/your/nnUNet_results/Dataset114_ToothFairy2/nnUNetTrainer__nnUNetPlans__3d_fullres"
+     model_folder = st.sidebar.text_input(
+         "Enter path to trained model folder:",
+         value=default_model_path
+     )
+
+     if not os.path.isdir(model_folder):
+         st.sidebar.error("Model folder not found. Please provide a valid path.")
+         st.stop()
+
+     # Load the model (will be cached)
+     predictor = load_predictor(model_folder)
+     if predictor is None:
+         st.stop()
+
+     st.sidebar.header("2. Upload Image")
+     uploaded_file = st.sidebar.file_uploader(
+         "Choose a NIfTI file (.nii.gz)",
+         type=['nii.gz']
+     )
+
+     # --- Main Panel for Execution and Visualization ---
+     if uploaded_file is not None:
+         if st.sidebar.button("✨ Run Prediction and Visualize"):
+             # Use a temporary directory for safety and automatic cleanup
+             with tempfile.TemporaryDirectory() as temp_dir:
+                 input_dir = os.path.join(temp_dir, 'input')
+                 output_dir = os.path.join(temp_dir, 'output')
+                 os.makedirs(input_dir, exist_ok=True)
+                 os.makedirs(output_dir, exist_ok=True)
+
+                 # Save the uploaded file to the temp input directory
+                 # The filename needs the _0000 suffix for nnU-Net's default file prediction
+                 base_name = uploaded_file.name.replace(".nii.gz", "")
+                 input_file_path = os.path.join(input_dir, f"{base_name}_0000.nii.gz")
+
+                 with open(input_file_path, "wb") as f:
+                     f.write(uploaded_file.getbuffer())
+
+                 st.info(f"File '{uploaded_file.name}' saved to temporary location.")
+
+                 # --- Run Prediction ---
+                 with st.spinner("🧠 Running nnU-Net inference... This can take a while."):
+                     start_time = time.time()
+
+                     # We use predict_from_files as it's the most efficient for file-based workflows
+                     predictor.predict_from_files(
+                         input_dir,
+                         output_dir,
+                         save_probabilities=False,
+                         overwrite=True,
+                         num_processes_preprocessing=2,
+                         num_processes_segmentation_export=2
+                     )
+
+                     end_time = time.time()
+                     st.success(f"Inference complete! 🎉 (Time taken: {end_time - start_time:.2f} seconds)")
+
+                 # Find the output file
+                 output_files = os.listdir(output_dir)
+                 if not output_files:
+                     st.error("Prediction failed. No output file was generated.")
+                     st.stop()
+
+                 output_mask_path = os.path.join(output_dir, output_files[0])
+
+                 # --- Generate Visualization ---
+                 with st.spinner("🎨 Generating 3D visualization..."):
+                     plotter = generate_visualization(input_file_path, output_mask_path)
+                     stpyvista(plotter, key="pv_plot")
+
+                 # --- Provide Download Link for the Mask ---
+                 with open(output_mask_path, "rb") as f:
+                     st.download_button(
+                         label="⬇️ Download Segmentation Mask",
+                         data=f,
+                         file_name=f"predicted_{uploaded_file.name}",
+                         mime="application/gzip"
+                     )
+
+     else:
+         st.info("Please upload a file to begin.")
+
+ if __name__ == '__main__':
+     main()
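
For quick verification outside the UI, below is a minimal, headless sketch of the same prediction path the app wires together. It reuses the nnUNetPredictor calls exactly as they appear in app.py; MODEL_FOLDER, INPUT_DIR, and OUTPUT_DIR are placeholder paths you would substitute for your own setup, and the input folder is assumed to already contain a NIfTI file named with the _0000 channel suffix, as the app enforces when saving an upload.

    """Headless smoke test for the nnU-Net model used by app.py.

    Assumptions (placeholders, adjust to your setup):
      MODEL_FOLDER -- a trained nnU-Net results folder
      INPUT_DIR    -- contains e.g. case_0000.nii.gz
      OUTPUT_DIR   -- predictions are written here
    """
    import os
    import torch
    from nnunetv2.inference.predict_from_raw_data import nnUNetPredictor

    MODEL_FOLDER = "/path/to/nnUNet_results/DatasetXXX/nnUNetTrainer__nnUNetPlans__3d_fullres"  # placeholder
    INPUT_DIR = "/path/to/input"    # placeholder
    OUTPUT_DIR = "/path/to/output"  # placeholder


    def run_smoke_test():
        os.makedirs(OUTPUT_DIR, exist_ok=True)

        # Same predictor configuration as load_predictor() in app.py
        predictor = nnUNetPredictor(
            tile_step_size=0.5,
            use_gaussian=True,
            use_mirroring=True,
            perform_everything_on_device=True,
            device=torch.device('cuda', 0) if torch.cuda.is_available() else torch.device('cpu'),
            verbose=False,
            verbose_preprocessing=False,
            allow_tqdm=True,
        )
        predictor.initialize_from_trained_model_folder(
            MODEL_FOLDER,
            use_folds=(0,),  # fold 0, as in the app
            checkpoint_name='checkpoint_final.pth',
        )

        # Same call the app makes after saving the upload with the _0000 suffix
        predictor.predict_from_files(
            INPUT_DIR,
            OUTPUT_DIR,
            save_probabilities=False,
            overwrite=True,
            num_processes_preprocessing=2,
            num_processes_segmentation_export=2,
        )
        print("Predicted masks:", os.listdir(OUTPUT_DIR))


    if __name__ == '__main__':
        run_smoke_test()

If this script completes and writes a mask into OUTPUT_DIR, the Streamlit app only additionally needs the UI and rendering packages visible in its imports (streamlit, pyvista, nibabel, matplotlib, stpyvista) and can be launched with streamlit run app.py.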