"""Streamlit front-end for interactive MedSAM segmentation (Click, Box, Everything modes)."""
import streamlit as st

# set_page_config must be the first Streamlit command, so it is issued before
# the remaining imports (project modules may touch Streamlit at import time).
st.set_page_config(layout="wide")

import base64
import random

import numpy as np
import pandas as pd
from PIL import Image
from streamlit_drawable_canvas import st_canvas

from utils import utils

DEFAULT_IMG_TAG = 'model-architecture'
with open("figures/medsam.png", "rb") as image_file:
    encoded_string = base64.b64encode(image_file.read()).decode('utf-8')
# NOTE(review): IMAGE_TAG_BASE64 is identical to DEFAULT_IMG_TAG and never
# interpolates `encoded_string`, so the .replace() in main() is currently a
# no-op. Presumably this was meant to inline figures/medsam.png as a base64
# data URI; confirm against index.html.
IMAGE_TAG_BASE64 = f'model-architecture'

PREDICTOR_MODEL, AUTOMASK_MODEL = utils.get_model(checkpoint='checkpoint/medsam_vit_b.pth')

# Display width (px) of every drawing canvas. Canvas coordinates are mapped
# back to image coordinates by dividing by `scale = _CONTAINER_WIDTH / width`.
_CONTAINER_WIDTH = 700
# Fill colour the click canvas uses for every point it draws.
_CLICK_FILL_COLOR = "rgba(255, 255, 0, 0.8)"
# Point fills treated as foreground prompts (label 1); anything else is 0.
# The original only accepted green, which the canvas never draws, so every
# click was sent to the model as a background point.
_POSITIVE_POINT_FILLS = frozenset({"rgba(0, 255, 0, 0.8)", _CLICK_FILL_COLOR})


def _scaled_geometry(bg_image):
    """Return ``(scale, (width, height))`` fitting *bg_image* to the canvas width."""
    width, height = bg_image.size[:2]
    scale = _CONTAINER_WIDTH / width
    return scale, (_CONTAINER_WIDTH, int(height * scale))


def _ensure_embeddings(predictor_model, bg_image):
    """Compute image embeddings on *predictor_model* unless already set."""
    if not predictor_model.is_image_set:
        np_image = np.asanyarray(bg_image)
        with st.spinner(text="Extracting embeddings.."):
            predictor_model.set_image(np_image)


def _init_result_image(bg_image, scaled_wh):
    """Seed ``session_state.result_image`` with the resized input on first use."""
    if 'result_image' not in st.session_state:
        st.session_state.result_image = bg_image.resize(scaled_wh)


def _rerun():
    """Trigger a script rerun, preferring ``st.rerun`` over the deprecated API."""
    rerun = getattr(st, "rerun", None) or st.experimental_rerun
    rerun()


def _handle_hidden_mask(bg_image):
    """Handle the "Show mask" checkbox being off.

    The first time this happens in a session, ``rerun_once`` is set and the
    script is rerun (the rerun call does not return). On later calls the flag
    is cleared/kept False and the unmodified image is returned as an array.
    """
    if 'rerun_once' in st.session_state:
        if st.session_state.rerun_once:
            st.session_state.rerun_once = False
    else:
        st.session_state.rerun_once = True
    st.session_state.display_result = True
    st.warning("Mask view is disabled", icon="❗")
    if st.session_state.rerun_once:
        _rerun()
    return np.asarray(bg_image)


def _composite(bg_array, mask_rgba, scaled_wh):
    """Alpha-blend the RGBA mask image over *bg_array* and resize for display."""
    overlay = Image.fromarray(mask_rgba).convert('RGBA')
    base = Image.fromarray(bg_array).convert('RGBA')
    return Image.alpha_composite(base, overlay).convert("RGB").resize(scaled_wh)


def _masks_to_result(masks, bg_image, scaled_wh):
    """Render prompt-based *masks* in a random colour over *bg_image*.

    Returns the input image as an array (with a warning) when *masks* is empty.
    """
    if len(masks) == 0:
        st.warning("No Masks Found", icon="❗")
        return np.asarray(bg_image)
    bg_array = np.asarray(bg_image)
    # One random base colour from utils.get_color() plus a fixed 0.6 alpha.
    color = np.concatenate([random.choice(utils.get_color()), np.array([0.6])], axis=0)
    result_image = _composite(bg_array, utils.show_click(masks, color), scaled_wh)
    st.session_state.display_result = True
    return result_image


def process_box(predictor_model, show_mask, radius_width):
    """Segment the objects enclosed by rectangles drawn on a canvas.

    Args:
        predictor_model: prompt-based predictor (from ``utils.get_model``).
        show_mask: when False, the image is returned without a mask overlay.
        radius_width: stroke width of the drawn rectangles.

    Returns:
        A resized PIL image with the mask overlay, or the input image as a
        numpy array when there is nothing to overlay.
    """
    bg_image = st.session_state['image']
    scale, scaled_wh = _scaled_geometry(bg_image)
    _ensure_embeddings(predictor_model, bg_image)
    _init_result_image(bg_image, scaled_wh)

    box_canvas = st_canvas(
        fill_color="rgba(255, 255, 0, 0)",
        background_image=bg_image,
        drawing_mode='rect',
        stroke_color="rgba(0, 255, 0, 0.6)",
        stroke_width=radius_width,
        width=_CONTAINER_WIDTH,
        height=scaled_wh[1],
        point_display_radius=12,
        update_streamlit=True,
        key="box",
    )

    if not show_mask:
        return _handle_hidden_mask(bg_image)
    if box_canvas.json_data is None:
        return np.asarray(bg_image)

    df = pd.json_normalize(box_canvas.json_data["objects"])
    if df.empty:
        # Nothing drawn yet: don't query the model with empty prompts.
        return np.asarray(bg_image)

    center_point, center_label, input_box = [], [], []
    for _, row in df.iterrows():
        # Canvas coordinates -> original image coordinates.
        x = int(row["left"] / scale)
        y = int(row["top"] / scale)
        w = int(row["width"] / scale)
        h = int(row["height"] / scale)
        center_point.append([x + w / 2, y + h / 2])
        center_label.append([1])
        input_box.append([x, y, x + w, y + h])

    masks = []
    if predictor_model:
        masks = utils.model_predict_masks_box(predictor_model, center_point, center_label, input_box)
    return _masks_to_result(masks, bg_image, scaled_wh)


def process_click(predictor_model, show_mask, radius_width):
    """Segment objects indicated by points clicked on a canvas.

    Args:
        predictor_model: prompt-based predictor (from ``utils.get_model``).
        show_mask: when False, the image is returned without a mask overlay.
        radius_width: display radius of the drawn points.

    Returns:
        A resized PIL image with the mask overlay, or the input image as a
        numpy array when there is nothing to overlay.
    """
    bg_image = st.session_state['image']
    scale, scaled_wh = _scaled_geometry(bg_image)
    _ensure_embeddings(predictor_model, bg_image)
    _init_result_image(bg_image, scaled_wh)

    click_canvas = st_canvas(
        fill_color=_CLICK_FILL_COLOR,
        background_image=bg_image,
        drawing_mode='point',
        width=_CONTAINER_WIDTH,
        height=scaled_wh[1],
        point_display_radius=radius_width,
        stroke_width=2,
        update_streamlit=True,
        key="point",
    )

    if not show_mask:
        return _handle_hidden_mask(bg_image)
    if click_canvas.json_data is None:
        return np.asarray(bg_image)

    df = pd.json_normalize(click_canvas.json_data["objects"])
    if df.empty:
        # Nothing clicked yet: don't query the model with empty prompts.
        return np.asarray(bg_image)

    input_points, input_labels = [], []
    for _, row in df.iterrows():
        # Centre of the drawn point's bounding box, mapped to image coordinates.
        x = int(row["left"] + row["width"] / 2)
        y = int(row["top"] + row["height"] / 2)
        input_points.append([int(x / scale), int(y / scale)])
        input_labels.append(1 if row['fill'] in _POSITIVE_POINT_FILLS else 0)

    masks = []
    if predictor_model:
        masks = utils.model_predict_masks_click(predictor_model, input_points, input_labels)
    return _masks_to_result(masks, bg_image, scaled_wh)


def process_everything(automask_model, show_mask, radius_width):
    """Run automatic ("everything") mask generation on the current image.

    The canvas is display-only here (``update_streamlit=False``).

    Returns:
        A resized PIL image with all masks overlaid, or the input image as a
        numpy array when masks are hidden, the model is missing, or none found.
    """
    bg_image = st.session_state['image']
    _, scaled_wh = _scaled_geometry(bg_image)
    _init_result_image(bg_image, scaled_wh)

    st_canvas(
        fill_color="rgba(255, 255, 0, 0.8)",
        background_image=bg_image,
        drawing_mode='freedraw',
        width=_CONTAINER_WIDTH,
        height=scaled_wh[1],
        point_display_radius=radius_width,
        stroke_width=2,
        update_streamlit=False,
        key="everything",
    )

    if not show_mask:
        return _handle_hidden_mask(bg_image)
    if not automask_model:
        return np.asarray(bg_image)

    bg_array = np.asarray(bg_image)
    masks = utils.model_predict_masks_everything(automask_model, bg_array)
    im_masked = utils.show_everything(masks)
    if len(im_masked) == 0:
        st.warning("No Masks Found", icon="❗")
        return bg_array
    result_image = _composite(bg_array, im_masked, scaled_wh)
    st.session_state.display_result = True
    return result_image


def image_preprocess_callback(predictor_model, option):
    """``on_change`` handler for the uploader: load or clear the working image.

    On upload, the image is converted to RGB and stored in
    ``session_state.image``. Embeddings are precomputed unless *option* is
    'Everything'; in that case the predictor is reset instead. When the
    upload is cleared, the result state and predictor are reset.
    """
    if 'uploaded_image' not in st.session_state:
        return

    if st.session_state.uploaded_image is None:
        with st.spinner(text="Cleaning up!"):
            if 'display_result' in st.session_state:
                st.session_state.display_result = False
            if 'image' in st.session_state:
                st.session_state.image = None
            if 'result_image' in st.session_state:
                del st.session_state['result_image']
            if predictor_model:
                predictor_model.reset_image()
        return

    with st.spinner(text="Uploading image..."):
        image = Image.open(st.session_state.uploaded_image).convert("RGB")
        if predictor_model and option != 'Everything':
            np_image = np.asanyarray(image)
            with st.spinner(text="Extracting embeddings.."):
                predictor_model.set_image(np_image)
        elif predictor_model:
            predictor_model.reset_image()
        st.session_state.image = image


def main():
    """Render the page: header HTML, controls, uploader, and input/result canvases."""
    with open('index.html', encoding='utf-8') as f:
        html_content = f.read()
    html_content = html_content.replace(DEFAULT_IMG_TAG, IMAGE_TAG_BASE64)
    st.components.v1.html(html_content, width=None, height=925, scrolling=False)

    with st.container():
        col1, col2, col3, col4 = st.columns(4)
        with col1:
            option = st.selectbox('Segmentation mode', ('Click', 'Box', 'Everything'))
        with col2:
            st.write("Show or Hide Mask")
            show_mask = st.checkbox('Show mask', value=True)
        with col3:
            mask_threshold = st.slider('SAM Confidence Threshold', 0.0, 1.0, 0.5, 0.05)
            PREDICTOR_MODEL.model.mask_threshold = mask_threshold
        with col4:
            radius_width = st.slider('Radius/Width for Click/Box', 0, 20, 5, 1)

    with st.container():
        st.write("Upload Image")
        st.file_uploader(
            label='Upload image',
            type=['png', 'jpg', 'tif'],
            key='uploaded_image',
            on_change=image_preprocess_callback,
            args=(PREDICTOR_MODEL, option,),
            label_visibility="hidden",
        )
        result_image = None
        canvas_input, canvas_output = st.columns(2)
        if 'image' in st.session_state:
            with canvas_input:
                st.write("Select Interest Area/Objects")
                if st.session_state.image is not None:
                    with st.spinner(text="Computing masks"):
                        if option == 'Click':
                            result_image = process_click(PREDICTOR_MODEL, show_mask, radius_width)
                        elif option == 'Box':
                            result_image = process_box(PREDICTOR_MODEL, show_mask, radius_width)
                        else:
                            result_image = process_everything(AUTOMASK_MODEL, show_mask, radius_width)
            if st.session_state.get('display_result'):
                with canvas_output:
                    if result_image is not None:
                        st.write("Result")
                        st.image(result_image)
                    else:
                        st.warning("No result found, please set input prompt", icon="⚠️")
                    st.success('Process completed!', icon="✅")
        else:
            st.cache_data.clear()


if __name__ == '__main__':
    main()