# Code copied and modified from: https://huggingface.co/spaces/BAAI/SegVol
"""Streamlit demo page for MRSegmentator.

The user either picks one of the bundled, pre-segmented demo volumes or uploads a
``.nii.gz`` file, browses an axial and a coronal slice with the segmentation mask
drawn on top, and can download the mask. The "Run" button is currently disabled,
so uploaded volumes are only displayed, not segmented.
"""
import tempfile
from pathlib import Path

import SimpleITK as sitk
from mrsegmentator.utils import add_postfix

import streamlit as st
import utils

print("script run")
st.title("MRSegmentator")
st.write("(On-site segmentation is currently disabled, because we lack access to GPUs)")

#############################################
# init session_state
if "option" not in st.session_state:
    st.session_state.option = None  # path of the currently selected/uploaded volume
if "reset_demo_case" not in st.session_state:
    st.session_state.reset_demo_case = False  # True -> reload the volume on the next run
if "preds_3D" not in st.session_state:
    st.session_state.preds_3D = None
    st.session_state.preds_path = None
if "preds_3D_ori" not in st.session_state:
    # SimpleITK image of the mask, used for the download; initialised here so it is
    # always defined alongside preds_3D.
    st.session_state.preds_3D_ori = None
if "data_item" not in st.session_state:
    st.session_state.data_item = None
if "rectangle_3Dbox" not in st.session_state:
    st.session_state.rectangle_3Dbox = [0, 0, 0, 0, 0, 0]
if "running" not in st.session_state:
    st.session_state.running = False
if "transparency" not in st.session_state:
    st.session_state.transparency = 0.25

case_list = [
    "amos_0517_MRI.nii.gz",
    "amos_0541_MRI.nii.gz",
    "amos_0571_MRI.nii.gz",
]

#############################################


#############################################
# reset functions
def clear_prompts():
    """Reset the 3D bounding-box prompt to its all-zero default."""
    st.session_state.rectangle_3Dbox = [0, 0, 0, 0, 0, 0]


def reset_demo_case():
    """Drop the loaded volume and mask so the next script run reloads the case.

    Used as the ``on_change`` callback of the demo-case selector and the uploader.
    """
    st.session_state.data_item = None
    st.session_state.reset_demo_case = True
    st.session_state.preds_3D = None
    st.session_state.preds_3D_ori = None
    clear_prompts()


def clear_file():
    """Forget the current case entirely; used when the case source changes."""
    st.session_state.option = None
    reset_demo_case()  # also clears the box prompt


#############################################

github_col, arxive_col = st.columns(2)

with github_col:
    st.write("Git: https://github.com/hhaentze/mrsegmentator")

with arxive_col:
    st.write("Paper: https://arxiv.org/abs/2405.06463")


# modify demo case here
demo_type = st.radio("Demo case source", ["Select (presegmented)", "Upload"], on_change=clear_file)

with tempfile.TemporaryDirectory() as tmpdirname:
    # modify demo case here
    if demo_type == "Select (presegmented)":
        selection = st.selectbox(
            "Select a demo case",
            case_list,
            index=None,
            placeholder="Select a demo case...",
            on_change=reset_demo_case,
        )
        if selection:
            uploaded_file = "images/" + selection
            # Read the stored mask only when none is cached. reset_demo_case() clears the
            # cache whenever the selection or source changes, so slider interactions no
            # longer re-read the segmentation from disk on every rerun.
            if st.session_state.preds_3D is None or st.session_state.preds_3D_ori is None:
                seg_path = Path(__file__).parent / ("segmentations/" + add_postfix(selection, "seg"))
                st.session_state.preds_3D = utils.read_image(seg_path)
                st.session_state.preds_3D_ori = sitk.ReadImage(seg_path)
        else:
            uploaded_file = None
    else:
        uploaded_file = st.file_uploader("Upload demo case(nii.gz)", type="nii.gz", on_change=reset_demo_case)
        if uploaded_file is not None:
            # The file name is supplied by the client: keep only its final component so
            # it cannot point outside the temporary directory (e.g. "../x.nii.gz").
            local_path = Path(tmpdirname) / Path(uploaded_file.name).name
            local_path.write_bytes(uploaded_file.getvalue())
            uploaded_file = str(local_path)

    st.session_state.option = uploaded_file

    # (Re)load the volume when a case is chosen and it either changed or was never loaded.
    if st.session_state.option is not None and (
        st.session_state.reset_demo_case or st.session_state.data_item is None
    ):
        st.session_state.data_item = utils.read_image(Path(__file__).parent / str(uploaded_file))
        st.session_state.reset_demo_case = False

    if st.session_state.option is None:
        st.write("please select demo case first")
    else:
        image_3D = st.session_state.data_item
        preds_3D = st.session_state.preds_3D
        # One reduction each instead of recomputing min/max for every slider argument.
        min_px, max_px = int(image_3D.min()), int(image_3D.max())
        px_range = st.slider(
            "Select intensity range",
            min_px,
            max_px,
            (min_px, max_px),
        )
        col_control1, col_control2 = st.columns(2)

        with col_control1:
            selected_index_z = st.slider(
                "Axial view",
                0,
                image_3D.shape[0] - 1,
                image_3D.shape[0] // 2,
                key="xy",
                disabled=st.session_state.running,
            )

        with col_control2:
            selected_index_y = st.slider(
                "Coronal view",
                0,
                image_3D.shape[1] - 1,
                image_3D.shape[1] // 2,
                key="xz",
                disabled=st.session_state.running,
            )

        col_image1, col_image2 = st.columns(2)

        if preds_3D is not None:
            st.session_state.transparency = st.slider("Mask opacity", 0.0, 1.0, 0.35, disabled=st.session_state.running)

        with col_image1:
            # Axial slice: index along axis 0.
            preds_z_array = None if preds_3D is None else preds_3D[selected_index_z]
            image_z = utils.make_fig(
                image_3D[selected_index_z], preds_z_array, px_range, st.session_state.transparency
            )
            st.image(image_z, use_column_width=False)

        with col_image2:
            # Coronal slice: index along axis 1.
            preds_y_array = None if preds_3D is None else preds_3D[:, selected_index_y, :]
            image_y = utils.make_fig(
                image_3D[:, selected_index_y, :], preds_y_array, px_range, st.session_state.transparency
            )
            st.image(image_y, use_column_width=False)

    ######################################################

    col1, col2, col3 = st.columns(3)

    with col1:
        st.markdown("#")
        st.markdown("####")
        st.markdown("####")
        if st.button(
            "Clear",
            use_container_width=True,
            disabled=(st.session_state.option is None or (st.session_state.preds_3D is None)),
        ):
            clear_prompts()
            st.session_state.preds_3D = None
            st.session_state.preds_3D_ori = None  # keep both mask representations in sync
            st.session_state.preds_path = None
            st.rerun()

    with col2:
        st.markdown("#")
        st.markdown("####")
        st.markdown("####")
        if (
            st.session_state.preds_3D is not None
            and st.session_state.preds_3D_ori is not None
            and st.session_state.data_item is not None
        ):
            # Write into a private directory: re-opening a still-open NamedTemporaryFile by
            # name fails on Windows, and tmpdirname is handed to utils.run, so we avoid
            # adding extra files there.
            with tempfile.TemporaryDirectory() as download_dir:
                download_path = Path(download_dir) / "segmentation.nii.gz"
                sitk.WriteImage(st.session_state.preds_3D_ori, str(download_path))
                bytes_data = download_path.read_bytes()
            st.download_button(
                label="Download result (.nii.gz)",
                data=bytes_data,
                file_name="segmentation.nii.gz",
                mime="application/octet-stream",
                disabled=False,
            )

    with col3:
        fold_options = ["Model of Fold 1 (fast)", "Ensemble Segmentation"]
        folds = st.radio("Model selection", fold_options, label_visibility="hidden")
        # Compare against the actual option label; the previous check against
        # "Model of Fold 1" never matched, so the ensemble was always selected.
        if folds == fold_options[0]:
            st.session_state.folds = (0,)
        else:
            st.session_state.folds = (0, 1, 2, 3, 4)

        run_button_name = "Run" if not st.session_state.running else "Running"
        if st.button(
            run_button_name,
            type="primary",
            use_container_width=True,
            disabled=True,
            # disabled=(st.session_state.data_item is None or st.session_state.running),
        ):
            st.session_state.running = True
            st.rerun()

    # Second half of the Run flow: the click sets `running` and reruns; this run does the work.
    if st.session_state.running:
        st.session_state.running = False
        with st.status("Running...", expanded=False):
            utils.run(tmpdirname)
        st.rerun()