import base64
import io
import mimetypes
import os
import uuid
from typing import Union

import numpy
import sahi.predict
import sahi.utils
import sahi.utils.cv
import streamlit.components.v1 as components
from PIL import Image

# Directory where the non-in-memory path of ``image_comparison`` stages
# JPEG files before base64-encoding them.
TEMP_DIR = "temp"

# Pillow modes that can be written directly as JPEG. Other modes (e.g. RGBA,
# P) are converted to RGB first, because saving them as JPEG raises.
_JPEG_MODES = ("L", "RGB", "CMYK")


def _to_jpeg_compatible(image: Image.Image) -> Image.Image:
    """Return *image* unchanged if JPEG can store its mode, else an RGB copy."""
    if image.mode in _JPEG_MODES:
        return image
    return image.convert("RGB")


def _visualize(image, object_prediction_list) -> Image.Image:
    """Draw *object_prediction_list* onto a fresh array copy of *image*."""
    # numpy.array copies the pixels, so each call draws on its own canvas and
    # the standard / sliced visualizations never share an array.
    visual_result = sahi.utils.cv.visualize_object_predictions(
        image=numpy.array(image),
        object_prediction_list=object_prediction_list,
    )
    return Image.fromarray(visual_result["image"])


def sahi_mmdet_inference(
    image,
    detection_model,
    slice_height=512,
    slice_width=512,
    overlap_height_ratio=0.2,
    overlap_width_ratio=0.2,
    image_size=640,
    postprocess_type="UNIONMERGE",
    postprocess_match_metric="IOS",
    postprocess_match_threshold=0.5,
    postprocess_class_agnostic=False,
):
    """Run standard and sliced SAHI inference on *image* and render both.

    Parameters
    ----------
    image:
        Image given to sahi's predictors and converted with ``numpy.array``
        for drawing. Presumably a PIL image; confirm against callers.
    detection_model:
        sahi detection model. The function name suggests an MMDetection
        wrapper; confirm.
    slice_height, slice_width, overlap_height_ratio, overlap_width_ratio:
        Slicing parameters, used only by the sliced prediction.
    image_size:
        Forwarded to both ``get_prediction`` and ``get_sliced_prediction``.
    postprocess_type, postprocess_match_metric, postprocess_match_threshold,
    postprocess_class_agnostic:
        Post-processing options, used only by the sliced prediction.

    Returns
    -------
    tuple of (PIL.Image.Image, PIL.Image.Image)
        The visualization of the standard prediction, then that of the
        sliced prediction.
    """
    # standard (whole-image) inference
    standard_result = sahi.predict.get_prediction(
        image=image, detection_model=detection_model, image_size=image_size
    )
    standard_visual = _visualize(image, standard_result.object_prediction_list)

    # sliced inference
    sliced_result = sahi.predict.get_sliced_prediction(
        image=image,
        detection_model=detection_model,
        image_size=image_size,
        slice_height=slice_height,
        slice_width=slice_width,
        overlap_height_ratio=overlap_height_ratio,
        overlap_width_ratio=overlap_width_ratio,
        postprocess_type=postprocess_type,
        postprocess_match_metric=postprocess_match_metric,
        postprocess_match_threshold=postprocess_match_threshold,
        postprocess_class_agnostic=postprocess_class_agnostic,
    )
    sliced_visual = _visualize(image, sliced_result.object_prediction_list)

    return standard_visual, sliced_visual


def pillow_to_base64(image: Image.Image) -> str:
    """Encode *image* as a JPEG ``data:`` URI entirely in memory."""
    in_mem_file = io.BytesIO()
    # subsampling=0 disables chroma subsampling and quality=100 is the
    # highest JPEG quality setting, which keeps the encoding close to the source.
    _to_jpeg_compatible(image).save(
        in_mem_file, format="JPEG", subsampling=0, quality=100
    )
    image_str = base64.b64encode(in_mem_file.getvalue()).decode("utf-8")
    return f"data:image/jpeg;base64,{image_str}"


def local_file_to_base64(image_path: str) -> str:
    """Read the file at *image_path* and return it as a base64 ``data:`` URI.

    The MIME type is guessed from the file extension. It falls back to
    ``image/jpeg`` when the extension is unknown.
    """
    mime_type, _ = mimetypes.guess_type(image_path)
    # `with` guarantees the handle is closed even if read() raises.
    with open(image_path, "rb") as file_:
        img_bytes = file_.read()
    image_str = base64.b64encode(img_bytes).decode("utf-8")
    return f"data:{mime_type or 'image/jpeg'};base64,{image_str}"


def pillow_local_file_to_base64(image: Image.Image) -> str:
    """Save *image* as a uniquely named JPEG in TEMP_DIR and return its data URI."""
    os.makedirs(TEMP_DIR, exist_ok=True)
    img_path = os.path.join(TEMP_DIR, f"{uuid.uuid4()}.jpg")
    # The format is inferred from the ".jpg" extension.
    _to_jpeg_compatible(image).save(img_path, subsampling=0, quality=100)
    return local_file_to_base64(img_path)


def image_comparison(
    img1: Union[str, os.PathLike, Image.Image],
    img2: Union[str, os.PathLike, Image.Image],
    label1: str = "1",
    label2: str = "2",
    width: int = 700,
    show_labels: bool = True,
    starting_position: int = 50,
    make_responsive: bool = True,
    in_memory=False,
):
    """Create a new juxtapose component.

    Parameters
    ----------
    img1: str, PosixPath, PIL.Image or URL
        Input image to compare
    img2: str, PosixPath, PIL.Image or URL
        Input image to compare
    label1: str
        Label for image 1
    label2: str
        Label for image 2
    width: int
        Width of the component in px; the height is derived from img1's
        aspect ratio
    show_labels: bool
        Show given labels on images
    starting_position: int
        Starting position of the slider as percent (0-100)
    make_responsive: bool
        Enable responsive mode
    in_memory: bool
        Handle pillow to base64 conversion in memory without saving to local

    Returns
    -------
    static_component
        Whatever ``streamlit.components.v1.html`` returns for the rendered
        block.
    """
    # Normalize both inputs to Pillow images first. The size lookup below
    # needs a Pillow image, and img1 may be a path or URL.
    img1_pillow = sahi.utils.cv.read_image_as_pil(img1)
    img2_pillow = sahi.utils.cv.read_image_as_pil(img2)

    # Size the frame from img1's aspect ratio. The 0.95 factor makes it
    # slightly shorter than the aspect-correct height.
    img_width, img_height = img1_pillow.size
    h_to_w = img_height / img_width
    height = (width * h_to_w) * 0.95

    if in_memory:
        # create base64 str from pillow images
        img1 = pillow_to_base64(img1_pillow)
        img2 = pillow_to_base64(img2_pillow)
    else:
        # clean temp dir
        # NOTE(review): this deletes every .jpg in TEMP_DIR, including files
        # another concurrent session may have just written. Confirm this is
        # only used single-user.
        os.makedirs(TEMP_DIR, exist_ok=True)
        for file_ in os.listdir(TEMP_DIR):
            if file_.endswith(".jpg"):
                os.remove(os.path.join(TEMP_DIR, file_))
        # create base64 str from pillow images
        img1 = pillow_local_file_to_base64(img1_pillow)
        img2 = pillow_local_file_to_base64(img2_pillow)

    # load css + js
    cdn_path = "https://cdn.knightlab.com/libs/juxtapose/latest"
    # NOTE(review): css_block, js_block and the template below are empty
    # apart from the interpolations. cdn_path, img1/img2, the labels and the
    # slider options are never referenced, so the juxtapose <link>/<script>
    # tags and slider markup look stripped. Confirm against the upstream file.
    css_block = f''
    js_block = f''

    # write html block
    htmlcode = f"""
        {css_block}
        {js_block}
        """

    static_component = components.html(htmlcode, height=height, width=width)

    return static_component