|
|
import streamlit as st |
|
|
import torch |
|
|
import os |
|
|
import tempfile |
|
|
import time |
|
|
|
|
|
|
|
|
from nnunetv2.inference.predict_from_raw_data import nnUNetPredictor |
|
|
import pyvista as pv |
|
|
import nibabel as nib |
|
|
import numpy as np |
|
|
from matplotlib import cm |
|
|
from matplotlib.colors import ListedColormap |
|
|
from stpyvista import stpyvista |
|
|
|
|
|
|
|
|
|
|
|
@st.cache_resource
def _build_predictor(model_folder, use_folds=(0,), checkpoint_name='checkpoint_final.pth'):
    """Construct and initialize an nnUNetPredictor, cached per argument set.

    This helper deliberately lets initialization errors propagate: Streamlit
    does not cache a call that raises, so a failed load is retried on the next
    run instead of a ``None`` result being cached for the same arguments.
    """
    st.write("Initializing nnU-Net predictor... (This may take a moment)")

    predictor = nnUNetPredictor(
        tile_step_size=0.5,
        use_gaussian=True,
        use_mirroring=True,
        perform_everything_on_device=True,
        # Prefer the first CUDA device when available, otherwise run on CPU.
        device=torch.device('cuda', 0) if torch.cuda.is_available() else torch.device('cpu'),
        verbose=False,
        verbose_preprocessing=False,
        allow_tqdm=True,
    )
    predictor.initialize_from_trained_model_folder(
        model_folder,
        use_folds=use_folds,
        checkpoint_name=checkpoint_name,
    )
    st.success("nnU-Net predictor initialized successfully!")
    return predictor


def load_predictor(model_folder, use_folds=(0,), checkpoint_name='checkpoint_final.pth'):
    """Load and initialize the nnUNetPredictor for ``model_folder``.

    Successful initializations are cached (via ``_build_predictor``) so the
    model is only loaded once per distinct (folder, folds, checkpoint). Failures
    are NOT cached, so fixing the model folder and rerunning retries the load.

    Args:
        model_folder: Path to the trained nnU-Net model folder.
        use_folds: Folds to load; defaults to fold 0 only.
        checkpoint_name: Checkpoint file name inside each fold directory.

    Returns:
        The initialized predictor, or ``None`` if initialization failed (an
        error message is shown in the app in that case).
    """
    try:
        return _build_predictor(model_folder, use_folds, checkpoint_name)
    except Exception as e:  # UI boundary: report any load failure to the user.
        st.error(f"Failed to initialize predictor from {model_folder}. Error: {e}")
        return None
|
|
|
|
|
|
|
|
def _get_cmap(name):
    """Return the registered Matplotlib colormap called ``name``.

    ``matplotlib.cm.get_cmap`` was deprecated in Matplotlib 3.7 and removed in
    3.9, so the ``matplotlib.colormaps`` registry is used when available and
    the old function only as a fallback for older Matplotlib versions.
    """
    try:
        from matplotlib import colormaps
    except ImportError:  # Matplotlib < 3.5 has no colormap registry object.
        return cm.get_cmap(name)
    return colormaps[name]


def generate_visualization(base_image_path, mask_path):
    """Generate a PyVista plot of the base image with the segmentation mask overlaid.

    Args:
        base_image_path: Path to the input NIfTI image.
        mask_path: Path to the predicted label NIfTI; must have the same array
            shape as the image.

    Returns:
        A ``pv.Plotter`` holding two volumes: the min-max normalized image
        ("CT Scan", bone colormap) and the label mask ("Segmentation Mask").

    Raises:
        ValueError: If the image and mask array shapes differ.
    """
    img_data = nib.load(base_image_path).get_fdata(dtype=np.float32)

    # Min-max normalize to [0, 1]. A constant image has zero range and would
    # otherwise divide by zero and fill the volume with NaNs.
    intensity_range = np.ptp(img_data)
    if intensity_range > 0:
        img_data = (img_data - np.min(img_data)) / intensity_range
    else:
        img_data = np.zeros_like(img_data)

    # get_fdata() returns floating point; round before the integer cast so a
    # value such as 2.9999 becomes label 3 rather than being truncated to 2.
    mask_data = np.rint(nib.load(mask_path).get_fdata()).astype(np.uint8)

    if img_data.shape != mask_data.shape:
        raise ValueError(
            f"Image shape {img_data.shape} does not match mask shape {mask_data.shape}."
        )

    # Label id -> structure name. Only the largest id is used below (to size
    # the colormap); ids 19-20, 29-30 and 39-40 are intentionally absent.
    label_dict = {
        1: "Lower Jawbone", 2: "Upper Jawbone", 3: "Left Inferior Alveolar Canal",
        4: "Right Inferior Alveolar Canal", 5: "Left Maxillary Sinus", 6: "Right Maxillary Sinus",
        7: "Pharynx", 8: "Bridge", 9: "Crown", 10: "Implant", 11: "Upper Right Central Incisor",
        12: "Upper Right Lateral Incisor", 13: "Upper Right Canine", 14: "Upper Right First Premolar",
        15: "Upper Right Second Premolar", 16: "Upper Right First Molar", 17: "Upper Right Second Molar",
        18: "Upper Right Third Molar", 21: "Upper Left Central Incisor",
        22: "Upper Left Lateral Incisor", 23: "Upper Left Canine", 24: "Upper Left First Premolar",
        25: "Upper Left Second Premolar", 26: "Upper Left First Molar", 27: "Upper Left Second Molar",
        28: "Upper Left Third Molar", 31: "Lower Left Central Incisor",
        32: "Lower Left Lateral Incisor", 33: "Lower Left Canine", 34: "Lower Left First Premolar",
        35: "Lower Left Second Premolar", 36: "Lower Left First Molar", 37: "Lower Left Second Molar",
        38: "Lower Left Third Molar", 41: "Lower Right Central Incisor",
        42: "Lower Right Lateral Incisor", 43: "Lower Right Canine", 44: "Lower Right First Premolar",
        45: "Lower Right Second Premolar", 46: "Lower Right First Molar", 47: "Lower Right Second Molar",
        48: "Lower Right Third Molar",
    }

    # One colour slot per possible label value 0..max id.
    num_labels = max(label_dict) + 1

    # Row 0 (background) is transparent black; the remaining rows come from
    # tab20b, tab20c, then gist_rainbow, truncated to exactly num_labels rows.
    colors = np.vstack([
        [[0.0, 0.0, 0.0, 0.0]],
        _get_cmap('tab20b')(np.linspace(0, 1, 20)),
        _get_cmap('tab20c')(np.linspace(0, 1, 20)),
        _get_cmap('gist_rainbow')(np.linspace(0, 1, num_labels)),
    ])[:num_labels]
    colormap = ListedColormap(colors)

    # One opacity per label value: background fully transparent and every
    # structure equally visible. A two-point ramp such as [0, 0.5] would be
    # interpolated across all ids, leaving low ids (e.g. 1, 2) nearly invisible.
    label_opacity = np.full(num_labels, 0.5)
    label_opacity[0] = 0.0

    vol_img = pv.wrap(img_data)
    vol_mask = pv.wrap(mask_data)

    plotter = pv.Plotter(window_size=[800, 800])
    plotter.add_volume(vol_img, cmap="bone", opacity="sigmoid", name="CT Scan")
    plotter.add_volume(
        vol_mask,
        cmap=colormap,
        # Pin the scalar range and lookup-table size to the full label set so
        # each label id always gets the same colour, independent of which
        # labels happen to be present in this prediction.
        clim=[0, num_labels - 1],
        n_colors=num_labels,
        opacity=label_opacity,
        mapper='gpu',
        name="Segmentation Mask",
    )
    plotter.camera_position = 'xy'

    return plotter
|
|
|
|
|
|
|
|
|
|
|
def _case_id_from_upload(file_name):
    """Derive an nnU-Net case identifier from a client-supplied upload name.

    Only the final path component is kept (the name comes from the browser and
    is untrusted), and a single trailing ``.nii.gz`` is removed. Falls back to
    ``"case"`` if nothing remains.
    """
    base = os.path.basename(file_name.replace("\\", "/"))
    if base.endswith(".nii.gz"):
        base = base[:-len(".nii.gz")]
    return base or "case"


def _find_output_mask(output_dir, case_id):
    """Return the path of the predicted mask for ``case_id``, or ``None``.

    The output folder may also contain non-image files (nnU-Net can write JSON
    metadata such as dataset.json / plans.json there), so an arbitrary
    directory entry cannot be assumed to be the mask. The exact expected name
    is tried first, then any ``.nii.gz`` file in sorted order.
    """
    expected = os.path.join(output_dir, f"{case_id}.nii.gz")
    if os.path.isfile(expected):
        return expected
    nifti_files = sorted(n for n in os.listdir(output_dir) if n.endswith(".nii.gz"))
    if not nifti_files:
        return None
    return os.path.join(output_dir, nifti_files[0])


def main():
    """Streamlit entry point: configure the model, upload an image, segment and visualize it."""
    st.set_page_config(layout="wide", page_title="nnU-Net Inference App")

    st.title("🦷 nnU-Net Inference and 3D Visualization")
    st.markdown("Upload a medical image, run nnU-Net for segmentation, and visualize the results in 3D.")

    # --- Sidebar step 1: model configuration ---
    st.sidebar.header("1. Configure Model")
    default_model_path = "/path/to/your/nnUNet_results/Dataset114_ToothFairy2/nnUNetTrainer__nnUNetPlans__3d_fullres"
    model_folder = st.sidebar.text_input(
        "Enter path to trained model folder:",
        value=default_model_path,
    )
    if not os.path.isdir(model_folder):
        st.sidebar.error("Model folder not found. Please provide a valid path.")
        st.stop()

    predictor = load_predictor(model_folder)
    if predictor is None:
        st.stop()

    # --- Sidebar step 2: image upload ---
    st.sidebar.header("2. Upload Image")
    uploaded_file = st.sidebar.file_uploader(
        "Choose a NIfTI file (.nii.gz)",
        type=['nii.gz'],
    )
    if uploaded_file is None:
        st.info("Please upload a file to begin.")
        return

    # NOTE(review): st.button is only True on the run right after it is
    # clicked, so the results below disappear on the next rerun (e.g. after
    # pressing the download button). Persist them in st.session_state if they
    # should stay visible.
    if not st.sidebar.button("✨ Run Prediction and Visualize"):
        return

    upload_name = os.path.basename(uploaded_file.name.replace("\\", "/"))
    case_id = _case_id_from_upload(upload_name)

    with tempfile.TemporaryDirectory() as temp_dir:
        input_dir = os.path.join(temp_dir, 'input')
        output_dir = os.path.join(temp_dir, 'output')
        os.makedirs(input_dir, exist_ok=True)
        os.makedirs(output_dir, exist_ok=True)

        # Input files follow nnU-Net's "<case>_<channel>.nii.gz" naming; only
        # channel 0000 is written, i.e. a single-channel model is assumed.
        input_file_path = os.path.join(input_dir, f"{case_id}_0000.nii.gz")
        with open(input_file_path, "wb") as f:
            f.write(uploaded_file.getbuffer())
        st.info(f"File '{uploaded_file.name}' saved to temporary location.")

        with st.spinner("🧠 Running nnU-Net inference... This can take a while."):
            start_time = time.perf_counter()
            try:
                predictor.predict_from_files(
                    input_dir,
                    output_dir,
                    save_probabilities=False,
                    overwrite=True,
                    num_processes_preprocessing=2,
                    num_processes_segmentation_export=2,
                )
            except Exception as e:  # UI boundary: show the failure instead of a raw traceback.
                st.error(f"Prediction failed. Error: {e}")
                st.stop()
            elapsed = time.perf_counter() - start_time
            st.success(f"Inference complete! 🎉 (Time taken: {elapsed:.2f} seconds)")

        output_mask_path = _find_output_mask(output_dir, case_id)
        if output_mask_path is None:
            st.error("Prediction failed. No output file was generated.")
            st.stop()

        with st.spinner("🎨 Generating 3D visualization..."):
            plotter = generate_visualization(input_file_path, output_mask_path)
            stpyvista(plotter, key="pv_plot")

        # Read the mask into memory before the temporary directory is removed.
        with open(output_mask_path, "rb") as f:
            mask_bytes = f.read()

    st.download_button(
        label="⬇️ Download Segmentation Mask",
        data=mask_bytes,
        file_name=f"predicted_{upload_name}",
        mime="application/gzip",
    )
|
|
|
|
|
# Script entry point: build the app when this file is executed directly
# (presumably launched with `streamlit run <this file>`).
if __name__ == '__main__':
    main()