Spaces:
Running
Running
# Code copied and modified from: https://huggingface.co/spaces/BAAI/SegVol | |
import tempfile | |
from pathlib import Path | |
import nibabel as nib | |
import numpy as np | |
from PIL import ImageDraw | |
from streamlit_drawable_canvas import st_canvas | |
from streamlit_image_coordinates import streamlit_image_coordinates | |
import nibabel as nib | |
import SimpleITK as sitk | |
import streamlit as st | |
import utils | |
from utils import ( | |
initial_rectangle, | |
make_fig, | |
reflect_box_into_model, | |
reflect_json_data_to_3D_box, | |
run, | |
) | |
# from viewer import BasicViewer

print("script run")
st.title("MRSegmentator")

#############################################
# init session_state


def _seed_session_state(sentinel_key, **defaults):
    """Write *defaults* into ``st.session_state`` unless *sentinel_key* is present.

    Streamlit re-executes this script on every interaction, so the seeding only
    takes effect while the sentinel key is still missing; afterwards the values
    stored by the app are left untouched.
    """
    if sentinel_key in st.session_state:
        return
    for name, value in defaults.items():
        st.session_state[name] = value


_seed_session_state("option", option=None)
_seed_session_state("reset_demo_case", reset_demo_case=False)
# preds_path is keyed off preds_3D: both are seeded together, only when preds_3D is missing.
_seed_session_state("preds_3D", preds_3D=None, preds_path=None)
_seed_session_state("data_item", data_item=None)
_seed_session_state("rectangle_3Dbox", rectangle_3Dbox=[0, 0, 0, 0, 0, 0])
_seed_session_state("running", running=False)
_seed_session_state("transparency", transparency=0.25)

# Demo volumes offered in "Select" mode; they are later resolved relative to
# the directory containing this script (Path(__file__).parent).
case_list = [
    "images/amos_0541_MRI.nii.gz",
    "images/amos_0571_MRI.nii.gz",
    "images/amos_0001_CT.nii.gz",
]
############################################# | |
############################################# | |
# reset functions | |
def clear_prompts():
    """Reset ``rectangle_3Dbox`` in session state to a fresh list of six zeros."""
    st.session_state["rectangle_3Dbox"] = [0] * 6
def reset_demo_case():
    """Drop the cached volume and flag that the selected case must be reloaded.

    Registered as the ``on_change`` callback of the case selectbox and the file
    uploader. The ``reset_demo_case`` session flag makes the main script re-read
    the image on its next run; the box prompt is cleared as well.
    """
    state = st.session_state
    state["data_item"] = None
    state["reset_demo_case"] = True
    clear_prompts()
def clear_file():
    """Forget the current case entirely.

    Registered as the ``on_change`` callback of the "Demo case source" radio.
    Clears the selected ``option`` and then delegates to ``reset_demo_case()``,
    which drops the cached volume, sets the reload flag and clears the prompts.
    """
    st.session_state.option = None
    # reset_demo_case() already calls clear_prompts(), so no second call is needed.
    reset_demo_case()
############################################# | |
# Header row: repository link on the left, paper link on the right.
github_col, arxive_col = st.columns(2)
github_col.write("Git: https://github.com/hhaentze/mrsegmentator")
arxive_col.write("Paper: https://arxiv.org/abs/2405.06463")

# modify demo case here
# Switching the source discards the current case via clear_file().
demo_type = st.radio(
    "Demo case source",
    ["Select", "Upload"],
    on_change=clear_file,
)
with tempfile.TemporaryDirectory() as tmpdirname:
    # Per-script-run scratch directory: uploaded volumes are written here and the
    # directory is passed to run(). It is removed when this script run finishes.

    # modify demo case here
    if demo_type == "Select":
        uploaded_file = st.selectbox(
            "Select a demo case",
            case_list,
            index=None,
            placeholder="Select a demo case...",
            on_change=reset_demo_case,
        )
    else:
        uploaded_file = st.file_uploader(
            "Upload demo case(nii.gz)", type="nii.gz", on_change=reset_demo_case
        )
        if uploaded_file is not None:
            # The file name is supplied by the client; keep only its last path
            # component so a crafted name (e.g. "../x.nii.gz") cannot write
            # outside tmpdirname.
            upload_path = Path(tmpdirname) / Path(uploaded_file.name).name
            upload_path.write_bytes(uploaded_file.getvalue())
            uploaded_file = str(upload_path)

    # Either a demo path relative to this script, an absolute temp path, or None.
    st.session_state.option = uploaded_file

    # Load the volume when the selection changed (flag set by reset_demo_case)
    # or when nothing is cached yet for the current selection.
    if st.session_state.option is not None and (
        st.session_state.reset_demo_case or st.session_state.data_item is None
    ):
        # Joining an absolute upload path onto the script dir yields the upload path itself.
        image_path = Path(__file__).parent / str(uploaded_file)
        st.session_state.data_item = utils.read_image(image_path)
        st.session_state.data_item_ori = sitk.ReadImage(image_path)
        st.session_state.reset_demo_case = False
        # Predictions belong to the previous volume, so discard them.
        st.session_state.preds_3D = None
        st.session_state.preds_path = None

    if st.session_state.option is None:
        st.write("please select demo case first")
    else:
        # NOTE(review): presumably a numpy array from utils.read_image (uses
        # .min/.max/.shape and slicing) — confirm in utils.
        image_3D = st.session_state.data_item

        # Intensity window used when rendering the slices.
        px_min, px_max = int(image_3D.min()), int(image_3D.max())
        px_range = st.slider(
            "Select intensity range",
            px_min,
            px_max,
            (px_min, px_max),
        )

        col_control1, col_control2 = st.columns(2)
        with col_control1:
            # Axial slices index axis 0 of the volume.
            selected_index_z = st.slider(
                "Axial view",
                0,
                image_3D.shape[0] - 1,
                image_3D.shape[0] // 2,
                key="xy",
                disabled=st.session_state.running,
            )
        with col_control2:
            # Coronal slices index axis 1 of the volume.
            selected_index_y = st.slider(
                "Coronal view",
                0,
                image_3D.shape[1] - 1,
                image_3D.shape[1] // 2,
                key="xz",
                disabled=st.session_state.running,
            )

        col_image1, col_image2 = st.columns(2)

        if st.session_state.preds_3D is not None:
            st.session_state.transparency = st.slider(
                "Mask opacity", 0.0, 1.0, 0.5, disabled=st.session_state.running
            )

        with col_image1:
            image_z_array = image_3D[selected_index_z]
            preds_z_array = None
            if st.session_state.preds_3D is not None:
                preds_z_array = st.session_state.preds_3D[selected_index_z]
            image_z = make_fig(
                image_z_array, preds_z_array, px_range, st.session_state.transparency
            )
            st.image(image_z, use_column_width=False)

        with col_image2:
            image_y_array = image_3D[:, selected_index_y, :]
            preds_y_array = None
            if st.session_state.preds_3D is not None:
                preds_y_array = st.session_state.preds_3D[:, selected_index_y, :]
            image_y = make_fig(
                image_y_array, preds_y_array, px_range, st.session_state.transparency
            )
            st.image(image_y, use_column_width=False)

    ######################################################
    col1, col2, col3 = st.columns(3)

    with col1:
        if st.button(
            "Clear",
            use_container_width=True,
            disabled=(st.session_state.option is None or (st.session_state.preds_3D is None)),
        ):
            clear_prompts()
            st.session_state.preds_3D = None
            st.session_state.preds_path = None
            st.rerun()

    with col2:
        if st.session_state.preds_3D is not None and st.session_state.data_item is not None:
            # Serialise the segmentation in its own throw-away directory instead of
            # reopening a still-open NamedTemporaryFile by name (not possible on
            # Windows), and without touching the directory passed to run().
            # NOTE(review): assumes run() stores preds_3D_ori alongside preds_3D.
            with tempfile.TemporaryDirectory() as download_dir:
                seg_path = Path(download_dir) / "segmentation.nii.gz"
                sitk.WriteImage(st.session_state.preds_3D_ori, str(seg_path))
                bytes_data = seg_path.read_bytes()
            st.download_button(
                label="Download result(.nii.gz)",
                data=bytes_data,
                file_name="segmentation.nii.gz",
                mime="application/octet-stream",
                disabled=False,
            )

    with col3:
        run_button_name = "Running" if st.session_state.running else "Run"
        if st.button(
            run_button_name,
            type="primary",
            use_container_width=True,
            disabled=(st.session_state.data_item is None or st.session_state.running),
        ):
            # Set the flag and rerender first so the sliders above render disabled
            # during the run below.
            st.session_state.running = True
            st.rerun()

    if st.session_state.running:
        # The flag is cleared before run(), so an exception inside run() does
        # not leave the widgets stuck in the disabled state.
        st.session_state.running = False
        with st.status("Running...", expanded=False):
            run(tmpdirname)
        st.rerun()