# SegVolOnIDC / app.py
# (Hugging Face Space file; last commit 16e917e "Update app.py" by cciausu97)
import streamlit as st
from streamlit_drawable_canvas import st_canvas
from streamlit_image_coordinates import streamlit_image_coordinates
from idc_index import index
import os
import glob
import shutil
import dcm2niix
import subprocess
import random
import base64
from model.data_process.demo_data_process import process_ct_gt
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image, ImageDraw
import monai.transforms as transforms
from utils import show_points, make_fig, reflect_points_into_model, initial_rectangle, reflect_json_data_to_3D_box, reflect_box_into_model, run
import nibabel as nib
import tempfile
print("script run")
# TODO(further improvement): move the expensive setup below into a cached
# singleton (st.cache_resource, the successor of st.experimental_singleton)
# or a cached data call (st.cache_data) instead of redoing it per session.
# https://docs.streamlit.io/develop/api-reference/caching-and-state/st.experimental_singleton
# https://docs.streamlit.io/develop/concepts/architecture/caching
def download_idc_data_serieUID(serieUID_lst, output_folder):
    """Download IDC series and convert each one to compressed NIfTI with dcm2niix.

    ``output_folder`` is wiped and recreated on every call. The series at
    position ``idx`` of ``serieUID_lst`` is downloaded into
    ``<output_folder>/ddl_series{idx}_dcm`` and converted into
    ``<output_folder>/ddl_series{idx}_nii``.

    Args:
        serieUID_lst: iterable of DICOM SeriesInstanceUID strings to fetch.
        output_folder: directory that receives the downloaded/converted data.

    Returns:
        Sorted list of paths of every ``.nii.gz`` file produced under
        ``output_folder``; empty if no conversion produced a file.
    """
    client = index.IDCClient()

    # Start from a clean folder so volumes from a previous call never leak
    # into the returned list.
    if os.path.exists(output_folder):
        shutil.rmtree(output_folder)
    os.makedirs(output_folder)

    for idx, serie_uid in enumerate(serieUID_lst):
        dcm_dir = os.path.join(output_folder, f"ddl_series{idx}_dcm")
        nii_dir = os.path.join(output_folder, f"ddl_series{idx}_nii")
        # output_folder was just recreated empty, so these cannot pre-exist.
        os.makedirs(dcm_dir)
        os.makedirs(nii_dir)

        client.download_from_selection(seriesInstanceUID=serie_uid, downloadDir=dcm_dir)
        # -z y: gzip-compressed output; -f IDC_%i: name files after the PatientID.
        # NOTE(review): '-g y' asks dcm2niix to write its defaults file — confirm intended.
        return_code = subprocess.call(["dcm2niix", "-o", nii_dir, "-z", "y",
                                       "-f", "IDC_%i", "-g", "y", dcm_dir])
        if return_code != 0:
            # Not fatal: callers decide what to do when no NIfTI file appears.
            print(f"dcm2niix exited with code {return_code} for SeriesInstanceUID {serie_uid}")

    # Sorted so the order (e.g. of the demo-case dropdown) does not depend on
    # filesystem listing order.
    return sorted(glob.glob(os.path.join(output_folder, "*nii/*.nii.gz")))
def get_random_sample_idc_from_bodypart(bodypart_selected):
    """Randomly pick an IDC CT series whose BodyPartExamined equals ``bodypart_selected``.

    Only series with more than 100 instances are eligible.

    Args:
        bodypart_selected: exact ``BodyPartExamined`` value to match (e.g. "LIVER").

    Returns:
        Tuple ``(series_instance_uid, viewer_url)`` for the sampled series.

    Raises:
        ValueError: if no series satisfies the filters.
    """
    import pandas as pd  # client.index is filtered as a pandas DataFrame

    client = index.IDCClient()
    idc_df = client.index
    # instanceCount may be stored as text; comparing it to the string '100'
    # is lexicographic ('99' > '100' is True), so compare numerically instead.
    instance_count = pd.to_numeric(idc_df['instanceCount'], errors='coerce')
    is_match = (
        idc_df['Modality'].isin(["CT"])
        & (idc_df['BodyPartExamined'] == bodypart_selected)
        & (instance_count > 100)
    )
    matching_series_list = idc_df.loc[is_match, 'SeriesInstanceUID'].values
    if len(matching_series_list) == 0:
        raise ValueError(
            f"No CT series with more than 100 instances for BodyPartExamined={bodypart_selected!r}"
        )

    random_series_uid = random.choice(matching_series_list)
    random_series_viewer_url = client.get_viewer_URL(random_series_uid)
    return random_series_uid, random_series_viewer_url
def retrieve_idc_index_body_parts():
    """Return the distinct BodyPartExamined values of IDC CT series with fewer than 150 instances.

    NOTE(review): this uses ``< 150`` while the random sampler uses ``> 100`` —
    confirm the intended threshold.
    """
    import pandas as pd  # idc_client.index is filtered as a pandas DataFrame

    idc_client = index.IDCClient()
    idc_df = idc_client.index
    # Numeric comparison: comparing to the string '150' ordered counts lexicographically.
    instance_count = pd.to_numeric(idc_df['instanceCount'], errors='coerce')
    is_small_ct = idc_df['Modality'].isin(['CT']) & (instance_count < 150)
    return idc_df.loc[is_small_ct, 'BodyPartExamined'].unique()
#############################################
# SeriesInstanceUIDs offered in the "tcga_lihc collection" demo mode.
DEMO_SERIES_UIDS = [
    "1.3.6.1.4.1.14519.5.2.1.8421.4008.125612661111422710051062993644",
    "1.3.6.1.4.1.14519.5.2.1.3344.4008.552105302448832783460360105045",
    "1.3.6.1.4.1.14519.5.2.1.3344.4008.217290429362492484143666931850",
    "1.3.6.1.4.1.14519.5.2.1.3344.4008.315023636447426194723399171147",
    "1.3.6.1.4.1.14519.5.2.1.3344.4008.307374355712319704057189924161",
]
DEMO_SAMPLES_FOLDER = "model/asset/idc_samples"

# Reset on every rerun; the demo-case widgets further down assign the
# current selection.
st.session_state.option = None

# session_state is per browser session, so gating the download on 'idc_data'
# alone made every new visitor re-download the demo cases and rmtree the
# shared folder while other sessions could be reading from it. Reuse the
# converted volumes when they are already on disk.
# NOTE(review): a partially completed earlier download is also reused as-is.
case_list = sorted(glob.glob(os.path.join(DEMO_SAMPLES_FOLDER, "*nii", "*.nii.gz")))
if 'idc_data' not in st.session_state and not case_list:
    case_list = sorted(download_idc_data_serieUID(
        serieUID_lst=DEMO_SERIES_UIDS,
        output_folder=DEMO_SAMPLES_FOLDER,
    ))
st.session_state.idc_data = True

# Per-session defaults, applied only to keys not yet present. The dict is
# rebuilt on every script run, so its mutable lists are never shared
# between sessions.
session_defaults = {
    'idc_serieUID_sample': None,
    'text_prompt': None,
    'reset_demo_case': False,
    'preds_3D': None,
    'preds_3D_ori': None,
    'data_item': None,
    'points': [],
    'use_text_prompt': False,
    'use_text_serieUID': False,
    'use_point_prompt': False,
    'use_box_prompt': False,
    'rectangle_3Dbox': [0, 0, 0, 0, 0, 0],
    'irregular_box': False,
    'running': False,
    'transparency': 0.25,
}
for state_key, default_value in session_defaults.items():
    if state_key not in st.session_state:
        st.session_state[state_key] = default_value
#############################################
#############################################
#############################################
# Callbacks that reset parts of the session state.


def clear_prompts():
    """Drop all spatial prompts: the clicked points and the 3D box coordinates."""
    st.session_state.rectangle_3Dbox = [0] * 6
    st.session_state.points = []
def reset_demo_case():
    """Forget the loaded case and flag it so the next run reprocesses the selection.

    Also clears the downloaded-sample path, the body-part flag, and all
    spatial prompts.
    """
    st.session_state['data_item'] = None
    st.session_state['idc_serieUID_sample'] = None
    st.session_state['reset_demo_case'] = True
    st.session_state['idc_bodypart_selected'] = False
    clear_prompts()
def clear_file():
    """Reset everything tied to the current demo-case source (radio ``on_change``).

    Clears the selected file, drops cached preprocessing results, and resets
    the demo case, which in turn clears the downloaded-sample path, the
    body-part flag, and all prompts.
    """
    st.session_state.option = None
    # NOTE(review): presumably process_ct_gt is a Streamlit-cached function;
    # clearing it makes a re-downloaded file at the same path be reprocessed.
    process_ct_gt.clear()
    # reset_demo_case() already resets idc_serieUID_sample, idc_bodypart_selected
    # and the prompts, so those are not repeated here.
    reset_demo_case()
#############################################
# Intro: IDC banner, IDC selection walkthrough GIF, SegVol overview and links.
st.image("idc_intro_extended.jpg")
st.write("Below is an example on how to select a SeriesInstanceUID from Imaging Data Commons (IDC) to further use in this demo:")
st.image("https://github.com/ccosmin97/huggingface_idc_demos/raw/main/idc_serieUID_selection.gif")
st.write("Below is an overview of the SegVol method and authors acknowledgement.")
overview_image = Image.open('model/asset/overview back.png')
st.image(overview_image, use_column_width=True)
repo_col, paper_col = st.columns(2)
repo_col.write('SegVol GitHub repo:https://github.com/BAAI-DCAI/SegVol')
paper_col.write('SegVol Paper:https://arxiv.org/abs/2311.13385')
# Demo-case source selection. Whichever mode is active, its result ends up in
# `uploaded_file` (a NIfTI path or None) and then in st.session_state.option.
SERIEUID_SAMPLE_FOLDER = "model/asset/idc_serieUID_sample"
RANDOM_SAMPLE_FOLDER = "model/asset/idc_serieUID_random_sample"
MAX_RANDOM_SAMPLE_ATTEMPTS = 5


def _download_series_as_nifti(serie_uid, output_folder):
    """Download one series into ``output_folder``; return its first NIfTI path or None."""
    nifti_paths = sorted(download_idc_data_serieUID([serie_uid], output_folder))
    return nifti_paths[0] if nifti_paths else None


def _sample_series_from_bodypart(body_part, output_folder, max_attempts=MAX_RANDOM_SAMPLE_ATTEMPTS):
    """Draw random CT series for ``body_part`` until one converts to exactly one NIfTI file.

    A fresh series is sampled on every attempt; re-downloading the same UID is
    unlikely to change the dcm2niix output. Returns the NIfTI path, or None
    when all ``max_attempts`` draws fail.
    """
    for _ in range(max_attempts):
        serie_uid, _viewer_url = get_random_sample_idc_from_bodypart(body_part)
        # download_idc_data_serieUID wipes output_folder itself before each attempt.
        nifti_paths = download_idc_data_serieUID([str(serie_uid)], output_folder)
        if len(nifti_paths) == 1:
            return nifti_paths[0]
        print(f"serieUID {serie_uid} NOT FILLING BASIC REQs --> "
              f"{len(nifti_paths)} NII FILES (expected exactly 1)")
    return None


demo_type = st.radio(
    "Demo case source",
    ["Select an IDC demo case from tcga_lihc collection",
     "Filter by DICOM SeriesInstanceUID",
     "Random sampling based on BodyPartExamined"],
    on_change=clear_file,
)

if demo_type == "Select an IDC demo case from tcga_lihc collection":
    uploaded_file = st.selectbox(
        "Select a demo case",
        case_list,
        index=None,
        placeholder="Select a demo case...",
        on_change=reset_demo_case)
elif demo_type == "Filter by DICOM SeriesInstanceUID":
    with st.form("Filter by DICOM SeriesInstanceUID"):
        uploaded_serieUID = st.text_input("Enter a DICOM SeriesInstanceUID", value=None)
        # reset_demo_case (not only clear_prompts): otherwise a second submitted
        # series never replaced the already-processed data_item.
        submitted = st.form_submit_button("Submit", on_click=reset_demo_case)
        if submitted:
            # text_input returns None when left empty; never download "None".
            serie_uid = (uploaded_serieUID or "").strip()
            if not serie_uid:
                st.warning("Please enter a SeriesInstanceUID before submitting.")
            else:
                st.session_state.idc_serieUID_sample = _download_series_as_nifti(
                    serie_uid, SERIEUID_SAMPLE_FOLDER)
                if st.session_state.idc_serieUID_sample is None:
                    st.warning(f"No NIfTI volume could be produced for SeriesInstanceUID {serie_uid}.")
        uploaded_file = st.session_state.idc_serieUID_sample
else:  # "Random sampling based on BodyPartExamined"
    with st.form("Filter by DICOM BodyPartExamined Tag"):
        body_part_selected = st.selectbox(
            "Select a bodypart to randomly sample a CT scan from",
            ["ABDOMEN", "LUNG", "LIVER", "PELVIS"],
            index=None,
            placeholder="Select a bodypart to pick a SeriesInstanceUID from...")
        submitted = st.form_submit_button("Submit", on_click=reset_demo_case)
        if submitted:
            if body_part_selected is None:
                st.warning("Please select a body part before submitting.")
            else:
                st.session_state.idc_serieUID_sample = _sample_series_from_bodypart(
                    body_part_selected, RANDOM_SAMPLE_FOLDER)
                if st.session_state.idc_serieUID_sample is None:
                    st.warning(
                        f"Could not find a {body_part_selected} series converting to a single "
                        f"NIfTI volume after {MAX_RANDOM_SAMPLE_ATTEMPTS} attempts; please retry.")
        uploaded_file = st.session_state.idc_serieUID_sample

st.session_state.option = uploaded_file
# (Re)process the selected case when the selection was flagged as changed or
# nothing has been loaded yet; a new volume invalidates earlier predictions.
needs_processing = st.session_state.option is not None and (
    st.session_state.reset_demo_case or st.session_state.data_item is None
)
if needs_processing:
    st.session_state.data_item = process_ct_gt(st.session_state.option)
    st.session_state.preds_3D = None
    st.session_state.preds_3D_ori = None
    st.session_state.reset_demo_case = False
# Prompt configuration: semantic (text) prompt on the left, spatial prompt on the right.
prompt_col1, prompt_col2 = st.columns(2)

with prompt_col1:
    st.session_state.use_text_prompt = st.toggle('Semantic prompt')
    text_controls_disabled = not st.session_state.use_text_prompt
    text_prompt_type = st.radio(
        "Semantic prompt type",
        ["Predefined", "Custom"],
        disabled=text_controls_disabled,
    )
    if text_prompt_type == "Predefined":
        pre_text = st.selectbox(
            "Predefined anatomical category:",
            ['liver', 'right kidney', 'spleen', 'pancreas', 'aorta', 'inferior vena cava',
             'right adrenal gland', 'left adrenal gland', 'gallbladder', 'esophagus',
             'stomach', 'duodenum', 'left kidney'],
            index=None,
            disabled=text_controls_disabled,
        )
    else:
        pre_text = st.text_input('Enter an Anatomical word or phrase:', None, max_chars=20,
                                 disabled=text_controls_disabled)
    # An empty string counts as "no prompt"; None passes through unchanged.
    st.session_state.text_prompt = pre_text or None

with prompt_col2:
    spatial_prompt_on = st.toggle('Spatial prompt', on_change=clear_prompts)
    spatial_prompt = st.radio(
        "Spatial prompt type",
        ["Point prompt", "Box prompt"],
        on_change=clear_prompts,
        disabled=(not spatial_prompt_on))
    st.session_state.enforce_zoom = st.checkbox('Enforce zoom-out-zoom-in')
    # A spatial prompt mode is only active while the toggle is on.
    st.session_state.use_point_prompt = spatial_prompt_on and spatial_prompt == "Point prompt"
    st.session_state.use_box_prompt = spatial_prompt_on and spatial_prompt == "Box prompt"

if not st.session_state.use_text_prompt:
    st.session_state.text_prompt = None
# Slice viewers: the X-Y plane (first axis fixed) and the X-Z plane (second
# axis fixed) of the preprocessed volume, with point/box prompt input and the
# prediction overlay drawn by make_fig.
# st.warning validates `icon` as an emoji; the former "??" was a mis-encoded
# emoji and raised as soon as a warning had to be shown.
WARNING_ICON = "⚠️"
# On-screen size (px) of the clickable images and the box-drawing canvas.
VIEW_SIZE_PX = 325

if st.session_state.option is None:
    st.write('please select demo case first')
else:
    image_3D = st.session_state.data_item['z_image'][0].numpy()
    depth, height = image_3D.shape[0], image_3D.shape[1]
    col_control1, col_control2 = st.columns(2)
    with col_control1:
        # Default to the middle slice (162 for a 325-slice volume, as before);
        # a fixed 162 failed for volumes with fewer than 163 slices.
        selected_index_z = st.slider('X-Y view', 0, depth - 1, depth // 2,
                                     key='xy', disabled=st.session_state.running)
    with col_control2:
        selected_index_y = st.slider('X-Z view', 0, height - 1, height // 2,
                                     key='xz', disabled=st.session_state.running)
        if st.session_state.use_box_prompt:
            # NOTE(review): range tied to the 325 px display size — confirm whether
            # box entries 0/3 are in display pixels or volume indices.
            top, bottom = st.select_slider(
                'Top and bottom of box',
                options=range(0, VIEW_SIZE_PX),
                value=(0, VIEW_SIZE_PX - 1),
                disabled=st.session_state.running
            )
            st.session_state.rectangle_3Dbox[0] = top
            st.session_state.rectangle_3Dbox[3] = bottom

    col_image1, col_image2 = st.columns(2)
    if st.session_state.preds_3D is not None:
        st.session_state.transparency = st.slider('Mask opacity', 0.0, 1.0, 0.25,
                                                  disabled=st.session_state.running)

    with col_image1:
        image_z_array = image_3D[selected_index_z]
        preds_z_array = None
        if st.session_state.preds_3D is not None:
            preds_z_array = st.session_state.preds_3D[selected_index_z]
        image_z = make_fig(image_z_array, preds_z_array, st.session_state.points, selected_index_z, 'xy')
        if st.session_state.use_point_prompt:
            value_xy = streamlit_image_coordinates(image_z, width=VIEW_SIZE_PX)
            if value_xy is not None:
                # Point stored in image_3D axis order: (slice, clicked row, clicked column).
                point_ax_xy = (selected_index_z, value_xy['y'], value_xy['x'])
                if len(st.session_state.points) >= 3:
                    st.warning('Max point num is 3', icon=WARNING_ICON)
                elif point_ax_xy not in st.session_state.points:
                    st.session_state.points.append(point_ax_xy)
                    print('point_ax_xy add rerun')
                    st.rerun()
        elif st.session_state.use_box_prompt:
            canvas_result_xy = st_canvas(
                fill_color="rgba(255, 165, 0, 0.3)",  # Fixed fill color with some opacity
                stroke_width=3,
                stroke_color='#2909F1',
                background_image=image_z,
                update_streamlit=True,
                height=VIEW_SIZE_PX,
                width=VIEW_SIZE_PX,
                drawing_mode='transform',
                point_display_radius=0,
                key="canvas_xy",
                initial_drawing=initial_rectangle,
                display_toolbar=True
            )
            json_data = canvas_result_xy.json_data
            canvas_objects = json_data.get('objects') if json_data else None
            # With no drawn object yet, keep the current box and irregular flag.
            if canvas_objects:
                if canvas_objects[0].get('angle', 0) != 0:
                    st.warning('Rotating is undefined behavior', icon=WARNING_ICON)
                    st.session_state.irregular_box = True
                else:
                    st.session_state.irregular_box = False
                    try:
                        reflect_json_data_to_3D_box(json_data, view='xy')
                    except Exception as err:  # best effort: log instead of crashing the page
                        print(f'could not map canvas box to the 3D box: {err!r}')
        else:
            st.image(image_z, use_column_width=False)

    with col_image2:
        image_y_array = image_3D[:, selected_index_y, :]
        preds_y_array = None
        if st.session_state.preds_3D is not None:
            preds_y_array = st.session_state.preds_3D[:, selected_index_y, :]
        image_y = make_fig(image_y_array, preds_y_array, st.session_state.points, selected_index_y, 'xz')
        if st.session_state.use_point_prompt:
            value_yz = streamlit_image_coordinates(image_y, width=VIEW_SIZE_PX)
            if value_yz is not None:
                # Point stored in image_3D axis order: (clicked row, slice, clicked column).
                point_ax_xz = (value_yz['y'], selected_index_y, value_yz['x'])
                if len(st.session_state.points) >= 3:
                    st.warning('Max point num is 3', icon=WARNING_ICON)
                elif point_ax_xz not in st.session_state.points:
                    st.session_state.points.append(point_ax_xz)
                    print('point_ax_xz add rerun')
                    st.rerun()
        elif st.session_state.use_box_prompt:
            box = st.session_state.rectangle_3Dbox
            # Only outline the box on X-Z slices that lie within its second-axis extent.
            if box[1] <= selected_index_y <= box[4]:
                draw = ImageDraw.Draw(image_y)
                # Upper-left and lower-right corners of the box in the X-Z view.
                rectangle_coords = [(box[2], box[0]), (box[5], box[3])]
                draw.rectangle(rectangle_coords, outline='#2909F1', width=3)
            st.image(image_y, use_column_width=False)
        else:
            st.image(image_y, use_column_width=False)
# Action row: clear prompts/results, download the prediction, run inference.
col1, col2, col3 = st.columns(3)

with col1:
    nothing_to_clear = (
        len(st.session_state.points) == 0
        and not st.session_state.use_box_prompt
        and st.session_state.preds_3D is None
    )
    if st.button("Clear", use_container_width=True,
                 disabled=(st.session_state.option is None or nothing_to_clear)):
        clear_prompts()
        st.session_state.preds_3D = None
        st.session_state.preds_3D_ori = None
        st.rerun()

with col2:
    img_nii = None
    # Bound up front: without a prediction the button below still renders
    # (disabled); previously `bytes_data` was unbound on that path.
    bytes_data = b""
    if st.session_state.preds_3D_ori is not None and st.session_state.data_item is not None:
        data_item = st.session_state.data_item
        meta_dict = data_item['meta']
        foreground_start_coord = data_item['foreground_start_coord']
        foreground_end_coord = data_item['foreground_end_coord']
        original_shape = data_item['ori_shape']
        pred_array = st.session_state.preds_3D_ori
        # Paste the prediction into a zero volume of the original shape, at the
        # foreground start/end coordinates recorded during preprocessing.
        original_array = np.zeros(original_shape)
        original_array[foreground_start_coord[0]:foreground_end_coord[0],
                       foreground_start_coord[1]:foreground_end_coord[1],
                       foreground_start_coord[2]:foreground_end_coord[2]] = pred_array
        # NOTE(review): presumably reverses the axis order used in preprocessing so
        # the array matches meta_dict['affine'] — confirm against process_ct_gt.
        original_array = original_array.transpose(2, 1, 0)
        img_nii = nib.Nifti1Image(original_array, affine=meta_dict['affine'])
        with tempfile.NamedTemporaryFile(suffix=".nii.gz") as tmpfile:
            nib.save(img_nii, tmpfile.name)
            with open(tmpfile.name, "rb") as f:
                bytes_data = f.read()
    st.download_button(
        label="Download result(.nii.gz)",
        data=bytes_data,
        file_name="segvol_preds.nii.gz",
        mime="application/octet-stream",
        disabled=img_nii is None
    )

with col3:
    run_button_name = 'Run' if not st.session_state.running else 'Running'
    run_disabled = (
        st.session_state.data_item is None
        or (st.session_state.text_prompt is None
            and len(st.session_state.points) == 0
            and st.session_state.use_box_prompt is False)
        or st.session_state.irregular_box
        or st.session_state.running
    )
    if st.button(run_button_name, type="primary", use_container_width=True,
                 disabled=run_disabled):
        # Two-phase run: flag, rerun so widgets render disabled, then infer below.
        st.session_state.running = True
        st.rerun()

if st.session_state.running:
    st.session_state.running = False
    with st.status("Running...", expanded=False):
        run()
    st.rerun()