"""Streamlit demo script: pick a CT series from Imaging Data Commons (IDC),
convert it to NIfTI with dcm2niix, collect text/point/box prompts and call
``run()`` (imported from ``utils``) to produce a segmentation that can be
downloaded as ``.nii.gz``.
"""
import base64
import glob
import os
import random
import shutil
import subprocess
import tempfile

import dcm2niix
import matplotlib.pyplot as plt
import monai.transforms as transforms
import nibabel as nib
import numpy as np
import pandas as pd
import streamlit as st
from idc_index import index
from PIL import Image, ImageDraw
from streamlit_drawable_canvas import st_canvas
from streamlit_image_coordinates import streamlit_image_coordinates

from model.data_process.demo_data_process import process_ct_gt
from utils import (
    show_points,
    make_fig,
    reflect_points_into_model,
    initial_rectangle,
    reflect_json_data_to_3D_box,
    reflect_box_into_model,
    run,
)

print('script run')

# Further improvement:
# use a singleton decorator or the cache-data class
# https://docs.streamlit.io/develop/api-reference/caching-and-state/st.experimental_singleton
# https://docs.streamlit.io/develop/concepts/architecture/caching

# Warning icon passed to st.warning (must be a real emoji, not "??").
WARNING_ICON = "⚠️"

DEMO_SAMPLES_DIR = "model/asset/idc_samples"
SERIE_UID_SAMPLE_DIR = "model/asset/idc_serieUID_sample"
RANDOM_SAMPLE_DIR = "model/asset/idc_serieUID_random_sample"


def download_idc_data_serieUID(serieUID_lst, output_folder):
    """Download IDC series and convert each one to ``.nii.gz``.

    ``output_folder`` is deleted and recreated first. For the series at
    position ``idx`` the DICOM files go to ``ddl_series{idx}_dcm`` and the
    dcm2niix output to ``ddl_series{idx}_nii`` inside ``output_folder``.

    Args:
        serieUID_lst: iterable of SeriesInstanceUID values to download.
        output_folder: directory that receives the per-series sub-folders.

    Returns:
        Sorted list of paths matching ``<output_folder>/*nii/*.nii.gz``
        (empty when nothing was converted).
    """
    client = index.IDCClient()

    if os.path.exists(output_folder):
        shutil.rmtree(output_folder)
    os.makedirs(output_folder)

    for idx, serieUID_ddl in enumerate(serieUID_lst):
        sample_dcm_dir = os.path.join(output_folder, f"ddl_series{idx}_dcm")
        sample_nii_dir = os.path.join(output_folder, f"ddl_series{idx}_nii")
        for sample_dir in (sample_dcm_dir, sample_nii_dir):
            if os.path.exists(sample_dir):
                shutil.rmtree(sample_dir)
            os.makedirs(sample_dir)
        client.download_from_selection(seriesInstanceUID=serieUID_ddl,
                                       downloadDir=sample_dcm_dir)
        # The dcm2niix exit status is not checked; callers must cope with an
        # empty result list.
        subprocess.call(["dcm2niix", "-o", sample_nii_dir, "-z", "y",
                         "-f", "IDC_%i", "-g", "y", sample_dcm_dir])

    return sorted(glob.glob(os.path.join(output_folder, "*nii/*.nii.gz")))


def _instance_count(idc_index_df):
    """Return the ``instanceCount`` column as numbers (NaN when unparsable).

    Comparing the raw column against a string such as ``'100'`` is a
    lexicographic comparison (``'20' > '100'``), so convert first.
    """
    return pd.to_numeric(idc_index_df['instanceCount'], errors='coerce')


def get_random_sample_idc_from_bodypart(bodypart_selected):
    """Pick a random CT series with more than 100 instances for a body part.

    Args:
        bodypart_selected: value compared against ``BodyPartExamined``.

    Returns:
        Tuple ``(series_uid, viewer_url)``.

    Raises:
        IndexError: when no series in the IDC index matches the filter.
    """
    client = index.IDCClient()
    idc_df = client.index
    matches = (
        idc_df['Modality'].isin(["CT"])
        & (idc_df['BodyPartExamined'] == bodypart_selected)
        & (_instance_count(idc_df) > 100)
    )
    matching_series_list = idc_df[matches]['SeriesInstanceUID'].values
    # select random series from the list
    random_series_uid = random.choice(matching_series_list)
    random_series_viewer_url = client.get_viewer_URL(random_series_uid)
    return random_series_uid, random_series_viewer_url


def retrieve_idc_index_body_parts():
    """Return the unique ``BodyPartExamined`` values of CT series whose
    instance count is below 150."""
    idc_client = index.IDCClient()
    idc_df = idc_client.index
    matches = idc_df['Modality'].isin(['CT']) & (_instance_count(idc_df) < 150)
    body_parts = idc_df[matches]['BodyPartExamined'].unique()
    return body_parts


def _download_random_series(body_part, output_folder=RANDOM_SAMPLE_DIR, max_attempts=5):
    """Sample a random series for ``body_part`` and download it.

    A sample is accepted only when the conversion yields exactly one
    ``.nii.gz`` file; otherwise a new series is sampled, up to
    ``max_attempts`` times.

    Returns:
        Path of the converted file, or ``None`` when no series matched or
        every attempt failed the single-file requirement.
    """
    for _ in range(max_attempts):
        try:
            serieUID, _ohif_link = get_random_sample_idc_from_bodypart(body_part)
        except IndexError:
            # Nothing in the index matches this body part.
            return None
        converted = download_idc_data_serieUID([str(serieUID)], output_folder)
        if len(converted) == 1:
            return converted[0]
        print("serieUID NOT FILLING BASIC REQs --> MORE THAN 1 NII FILE OR NO NII FILE")
    return None


#############################################
# Recomputed from the current selection further down on every script run.
st.session_state.option = None

if 'idc_data' not in st.session_state:
    # First run of this session: fetch the fixed demo cases.
    case_list = download_idc_data_serieUID(
        serieUID_lst=[
            "1.3.6.1.4.1.14519.5.2.1.8421.4008.125612661111422710051062993644",
            "1.3.6.1.4.1.14519.5.2.1.3344.4008.552105302448832783460360105045",
            "1.3.6.1.4.1.14519.5.2.1.3344.4008.217290429362492484143666931850",
            "1.3.6.1.4.1.14519.5.2.1.3344.4008.315023636447426194723399171147",
            "1.3.6.1.4.1.14519.5.2.1.3344.4008.307374355712319704057189924161",
        ],
        output_folder=DEMO_SAMPLES_DIR,
    )
    st.session_state.idc_data = True
else:
    case_list = sorted(glob.glob(DEMO_SAMPLES_DIR + "/*nii/*.nii.gz"))


def _init_session_state():
    """Fill in every session-state key this script reads, if missing."""
    defaults = {
        'idc_serieUID_sample': None,
        'option': None,
        'text_prompt': None,
        'reset_demo_case': False,
        'preds_3D': None,
        'preds_3D_ori': None,
        'data_item': None,
        'points': [],
        'use_text_prompt': False,
        'use_text_serieUID': False,
        'use_point_prompt': False,
        'use_box_prompt': False,
        'rectangle_3Dbox': [0, 0, 0, 0, 0, 0],
        'irregular_box': False,
        'running': False,
        'transparency': 0.25,
    }
    for key, value in defaults.items():
        if key not in st.session_state:
            st.session_state[key] = value


# init session_state
_init_session_state()
#############################################

#############################################
# reset functions
def clear_prompts():
    """Drop all point prompts and reset the 3D box prompt."""
    st.session_state.points = []
    st.session_state.rectangle_3Dbox = [0, 0, 0, 0, 0, 0]


def reset_demo_case():
    """Forget the loaded case so that the next selection is re-processed."""
    st.session_state.data_item = None
    st.session_state.idc_serieUID_sample = None
    st.session_state.reset_demo_case = True
    st.session_state.idc_bodypart_selected = False
    clear_prompts()


def clear_file():
    """Reset everything when the demo-case source changes."""
    st.session_state.option = None
    st.session_state.idc_serieUID_sample = None
    st.session_state.idc_bodypart_selected = False
    process_ct_gt.clear()
    reset_demo_case()
    clear_prompts()
#############################################


st.image("idc_intro_extended.jpg")
st.write("Below is an example on how to select a SeriesInstanceUID from Imaging Data Commons (IDC) to further use in this demo:")
st.image("https://github.com/ccosmin97/huggingface_idc_demos/raw/main/idc_serieUID_selection.gif")
st.write("Below is an overview of the SegVol method and authors acknowledgement.")
st.image(Image.open('model/asset/overview back.png'), use_column_width=True)

github_col, arxive_col = st.columns(2)
with github_col:
    st.write('SegVol GitHub repo:https://github.com/BAAI-DCAI/SegVol')
with arxive_col:
    st.write('SegVol Paper:https://arxiv.org/abs/2311.13385')

# modify demo case here
demo_type = st.radio(
    "Demo case source",
    ["Select an IDC demo case from tcga_lihc collection",
     "Filter by DICOM SeriesInstanceUID",
     "Random sampling based on BodyPartExamined"],
    on_change=clear_file
)

if demo_type == "Select an IDC demo case from tcga_lihc collection":
    uploaded_file = st.selectbox(
        "Select a demo case",
        case_list,
        index=None,
        placeholder="Select a demo case...",
        on_change=reset_demo_case)
elif demo_type == "Filter by DICOM SeriesInstanceUID":
    with st.form("Filter by DICOM SeriesInstanceUID"):
        uploaded_serieUID = st.text_input("Enter a DICOM SeriesInstanceUID", value=None)
        # reset_demo_case (not just clear_prompts) so that a newly submitted
        # UID replaces the previously processed volume.
        submitted = st.form_submit_button("Submit", on_click=reset_demo_case)
        if submitted:
            series_uid = (uploaded_serieUID or "").strip()
            if not series_uid:
                st.warning("Please enter a SeriesInstanceUID before submitting.")
            else:
                converted = download_idc_data_serieUID([series_uid], SERIE_UID_SAMPLE_DIR)
                if converted:
                    st.session_state.idc_serieUID_sample = converted[0]
                else:
                    st.error("No NIfTI file could be produced for this SeriesInstanceUID.")
        uploaded_file = st.session_state.idc_serieUID_sample
else:  # demo_type == "Random sampling based on BodyPartExamined"
    with st.form("Filter by DICOM BodyPartExamined Tag"):
        # body_part_list = retrieve_idc_index_body_parts()
        body_part_selected = st.selectbox(
            "Select a bodypart to randomly sample a CT scan from",
            ["ABDOMEN", "LUNG", "LIVER", "PELVIS"],
            index=None,
            placeholder="Select a bodypart to pick a SeriesInstanceUID from...")
        submitted = st.form_submit_button("Submit", on_click=reset_demo_case)
        if submitted:
            if body_part_selected is None:
                st.warning("Please select a bodypart before submitting.")
            else:
                st.session_state.idc_serieUID_sample = _download_random_series(body_part_selected)
                if st.session_state.idc_serieUID_sample is None:
                    st.error("Could not find a usable CT series for this bodypart, please try again.")
        uploaded_file = st.session_state.idc_serieUID_sample

st.session_state.option = uploaded_file

# (Re)load the volume when a case is selected and either a reset was
# requested or nothing has been loaded yet.
if st.session_state.option is not None and (
        st.session_state.reset_demo_case or st.session_state.data_item is None):
    st.session_state.data_item = process_ct_gt(st.session_state.option)
    st.session_state.reset_demo_case = False
    st.session_state.preds_3D = None
    st.session_state.preds_3D_ori = None

prompt_col1, prompt_col2 = st.columns(2)

with prompt_col1:
    st.session_state.use_text_prompt = st.toggle('Semantic prompt')
    text_prompt_type = st.radio(
        "Semantic prompt type",
        ["Predefined", "Custom"],
        disabled=(not st.session_state.use_text_prompt)
    )
    if text_prompt_type == "Predefined":
        pre_text = st.selectbox(
            "Predefined anatomical category:",
            ['liver', 'right kidney', 'spleen', 'pancreas', 'aorta', 'inferior vena cava',
             'right adrenal gland', 'left adrenal gland', 'gallbladder', 'esophagus',
             'stomach', 'duodenum', 'left kidney'],
            index=None,
            disabled=(not st.session_state.use_text_prompt)
        )
    else:
        pre_text = st.text_input('Enter an Anatomical word or phrase:', None, max_chars=20,
                                 disabled=(not st.session_state.use_text_prompt))
    # An empty string counts as "no text prompt".
    st.session_state.text_prompt = pre_text or None

with prompt_col2:
    spatial_prompt_on = st.toggle('Spatial prompt', on_change=clear_prompts)
    spatial_prompt = st.radio(
        "Spatial prompt type",
        ["Point prompt", "Box prompt"],
        on_change=clear_prompts,
        disabled=(not spatial_prompt_on))
    st.session_state.enforce_zoom = st.checkbox('Enforce zoom-out-zoom-in')

    # The spatial prompt flags are only set while the toggle is on.
    st.session_state.use_point_prompt = spatial_prompt_on and spatial_prompt == "Point prompt"
    st.session_state.use_box_prompt = spatial_prompt_on and spatial_prompt == "Box prompt"

if not st.session_state.use_text_prompt:
    st.session_state.text_prompt = None

if st.session_state.option is None:
    st.write('please select demo case first')
else:
    image_3D = st.session_state.data_item['z_image'][0].numpy()
    col_control1, col_control2 = st.columns(2)
    with col_control1:
        max_index_z = image_3D.shape[0] - 1
        # Clamp the default so it never exceeds the slider maximum.
        selected_index_z = st.slider('X-Y view', 0, max_index_z, min(162, max_index_z),
                                     key='xy', disabled=st.session_state.running)
    with col_control2:
        max_index_y = image_3D.shape[1] - 1
        selected_index_y = st.slider('X-Z view', 0, max_index_y, min(162, max_index_y),
                                     key='xz', disabled=st.session_state.running)

    if st.session_state.use_box_prompt:
        top, bottom = st.select_slider(
            'Top and bottom of box',
            options=range(0, 325),
            value=(0, 324),
            disabled=st.session_state.running
        )
        st.session_state.rectangle_3Dbox[0] = top
        st.session_state.rectangle_3Dbox[3] = bottom

    col_image1, col_image2 = st.columns(2)

    if st.session_state.preds_3D is not None:
        st.session_state.transparency = st.slider('Mask opacity', 0.0, 1.0, 0.25,
                                                  disabled=st.session_state.running)

    with col_image1:
        image_z_array = image_3D[selected_index_z]

        preds_z_array = None
        if st.session_state.preds_3D is not None:
            preds_z_array = st.session_state.preds_3D[selected_index_z]

        image_z = make_fig(image_z_array, preds_z_array, st.session_state.points,
                           selected_index_z, 'xy')

        if st.session_state.use_point_prompt:
            value_xy = streamlit_image_coordinates(image_z, width=325)

            if value_xy is not None:
                point_ax_xy = (selected_index_z, value_xy['y'], value_xy['x'])
                if len(st.session_state.points) >= 3:
                    st.warning('Max point num is 3', icon=WARNING_ICON)
                elif point_ax_xy not in st.session_state.points:
                    st.session_state.points.append(point_ax_xy)
                    print('point_ax_xy add rerun')
                    st.rerun()
        elif st.session_state.use_box_prompt:
            canvas_result_xy = st_canvas(
                fill_color="rgba(255, 165, 0, 0.3)",  # Fixed fill color with some opacity
                stroke_width=3,
                stroke_color='#2909F1',
                background_image=image_z,
                update_streamlit=True,
                height=325,
                width=325,
                drawing_mode='transform',
                point_display_radius=0,
                key="canvas_xy",
                initial_drawing=initial_rectangle,
                display_toolbar=True
            )
            # Best effort: the canvas may not have produced any object yet.
            # Catch Exception only, so that BaseException-derived control
            # flow raised inside st.* calls is not swallowed.
            try:
                box_angle = canvas_result_xy.json_data['objects'][0]['angle']
                print(box_angle)
                if box_angle != 0:
                    st.warning('Rotating is undefined behavior', icon=WARNING_ICON)
                    st.session_state.irregular_box = True
                else:
                    st.session_state.irregular_box = False
                    reflect_json_data_to_3D_box(canvas_result_xy.json_data, view='xy')
            except Exception:
                print('exception')
        else:
            st.image(image_z, use_column_width=False)

    with col_image2:
        image_y_array = image_3D[:, selected_index_y, :]

        preds_y_array = None
        if st.session_state.preds_3D is not None:
            preds_y_array = st.session_state.preds_3D[:, selected_index_y, :]

        image_y = make_fig(image_y_array, preds_y_array, st.session_state.points,
                           selected_index_y, 'xz')

        if st.session_state.use_point_prompt:
            value_yz = streamlit_image_coordinates(image_y, width=325)

            if value_yz is not None:
                point_ax_xz = (value_yz['y'], selected_index_y, value_yz['x'])
                if len(st.session_state.points) >= 3:
                    st.warning('Max point num is 3', icon=WARNING_ICON)
                elif point_ax_xz not in st.session_state.points:
                    st.session_state.points.append(point_ax_xz)
                    print('point_ax_xz add rerun')
                    st.rerun()
        elif st.session_state.use_box_prompt:
            box = st.session_state.rectangle_3Dbox
            # Only draw the box outline on slices that intersect it.
            if box[1] <= selected_index_y <= box[4]:
                draw = ImageDraw.Draw(image_y)
                # rectangle xz view (upper-left and lower-right)
                rectangle_coords = [(box[2], box[0]), (box[5], box[3])]
                # Draw the rectangle on the image
                draw.rectangle(rectangle_coords, outline='#2909F1', width=3)
            st.image(image_y, use_column_width=False)
        else:
            st.image(image_y, use_column_width=False)


col1, col2, col3 = st.columns(3)

with col1:
    nothing_to_clear = (
        len(st.session_state.points) == 0
        and not st.session_state.use_box_prompt
        and st.session_state.preds_3D is None
    )
    if st.button("Clear", use_container_width=True,
                 disabled=(st.session_state.option is None or nothing_to_clear)):
        clear_prompts()
        st.session_state.preds_3D = None
        st.session_state.preds_3D_ori = None
        st.rerun()

with col2:
    img_nii = None
    if st.session_state.preds_3D_ori is not None and st.session_state.data_item is not None:
        meta_dict = st.session_state.data_item['meta']
        foreground_start_coord = st.session_state.data_item['foreground_start_coord']
        foreground_end_coord = st.session_state.data_item['foreground_end_coord']
        original_shape = st.session_state.data_item['ori_shape']
        pred_array = st.session_state.preds_3D_ori

        # Paste the prediction back into an array of the original shape.
        original_array = np.zeros(original_shape)
        original_array[foreground_start_coord[0]:foreground_end_coord[0],
                       foreground_start_coord[1]:foreground_end_coord[1],
                       foreground_start_coord[2]:foreground_end_coord[2]] = pred_array
        original_array = original_array.transpose(2, 1, 0)
        img_nii = nib.Nifti1Image(original_array, affine=meta_dict['affine'])

        # NOTE(review): the button is rendered only when a prediction exists;
        # confirm whether it should instead be shown disabled when there is
        # no result.
        with tempfile.NamedTemporaryFile(suffix=".nii.gz") as tmpfile:
            nib.save(img_nii, tmpfile.name)
            with open(tmpfile.name, "rb") as f:
                bytes_data = f.read()
        st.download_button(
            label="Download result(.nii.gz)",
            data=bytes_data,
            file_name="segvol_preds.nii.gz",
            mime="application/octet-stream",
            disabled=img_nii is None
        )

with col3:
    run_button_name = 'Run' if not st.session_state.running else 'Running'
    no_prompt_given = (
        st.session_state.text_prompt is None
        and len(st.session_state.points) == 0
        and st.session_state.use_box_prompt is False
    )
    if st.button(run_button_name, type="primary", use_container_width=True,
                 disabled=(
                     st.session_state.data_item is None
                     or no_prompt_given
                     or st.session_state.irregular_box
                     or st.session_state.running
                 )):
        # Rerun first so that the widgets are drawn disabled while running.
        st.session_state.running = True
        st.rerun()

if st.session_state.running:
    st.session_state.running = False
    with st.status("Running...", expanded=False):
        run()
    st.rerun()