import streamlit as st
import sahi.utils.mmdet
import sahi.model
from PIL import Image
import random
from utils import imagecompare
from utils import sahi_mmdet_inference
import pathlib
import os

# Download locations of the MMDetection checkpoints, one constant per model.
MMDET_YOLACT_MODEL_URL = ""
MMDET_YOLOX_MODEL_URL = ""
MMDET_FASTERRCNN_MODEL_URL = ""

# Example images: (source URL, local file name) pairs, fetched one after the
# other every time the script runs.
for _image_url, _image_name in (
    ("", "apple_tree.jpg"),
    ("", "highway.jpg"),
    ("", "highway2.jpg"),
    ("", "highway3.jpg"),
):
    sahi.utils.file.download_from_url(_image_url, _image_name)


@st.cache(allow_output_mutation=True, show_spinner=False)
def get_mmdet_model(model_name: str):
    """Download weights and config for *model_name* and build a detection model.

    The function is wrapped in ``st.cache`` so Streamlit can reuse the result
    across script reruns instead of downloading and loading the model again.

    Args:
        model_name: One of ``"yolact"``, ``"yolox"`` or ``"faster_rcnn"``.

    Returns:
        A ``sahi.model.MmdetDetectionModel`` created with a confidence
        threshold of 0.4 on the CPU.

    Raises:
        ValueError: If *model_name* is not one of the supported names.
    """
    # model_name -> (weights URL, local weights path, mmdet config file name)
    model_specs = {
        "yolact": (MMDET_YOLACT_MODEL_URL, "", ""),
        "yolox": (MMDET_YOLOX_MODEL_URL, "", ""),
        "faster_rcnn": (MMDET_FASTERRCNN_MODEL_URL, "", ""),
    }

    # Reject unknown names up front; previously they fell through every branch
    # and crashed later with an UnboundLocalError on model_path.
    if model_name not in model_specs:
        supported = ", ".join(sorted(model_specs))
        raise ValueError(
            f"Unsupported model_name {model_name!r}; expected one of: {supported}"
        )

    model_url, model_path, config_file_name = model_specs[model_name]

    sahi.utils.file.download_from_url(
        model_url,
        model_path,
    )
    config_path = sahi.utils.mmdet.download_mmdet_config(
        model_name=model_name,
        config_file_name=config_file_name,
    )

    detection_model = sahi.model.MmdetDetectionModel(
        model_path=model_path,
        config_path=config_path,
        confidence_threshold=0.4,
        device="cpu",
    )

    return detection_model


# Configure the browser tab title/icon, the page layout and the initial
# sidebar state for this demo.
st.set_page_config(
    page_title="Small Object Detection with SAHI + YOLOX",
    page_icon="🚀",
    layout="centered",
    initial_sidebar_state="auto",
)

# Page heading. unsafe_allow_html=True is passed although the string in this
# copy holds plain text only.
# NOTE(review): presumably HTML markup around the heading was lost - confirm.
st.markdown(
    """

Small Object Detection

""",
    unsafe_allow_html=True,
)

# Project links shown under the heading.
# NOTE(review): the opening of this st.markdown call and the link markup were
# missing in the reviewed copy; only the visible text could be restored.
st.markdown(
    """

SAHI Github | YOLOX Github | SAHI+YOLOv5 Demo
Follow me on twitter, linkedin and medium for more..

""",
    unsafe_allow_html=True,
)

st.write("##")

# Left column: input image, right column: SAHI parameters.
col1, col2, col3 = st.columns([6, 1, 6])
with col1:
    st.markdown("##### Set input image:")

    # Input image supplied by the user.
    image_file = st.file_uploader(
        "Upload an image to test:", type=["jpg", "jpeg", "png"]
    )

    def slider_func(option):
        """Return the label ("1".."4") shown on the slider for an example file."""
        option_to_id = {
            "apple_tree.jpg": str(1),
            "highway.jpg": str(2),
            "highway2.jpg": str(3),
            "highway3.jpg": str(4),
        }
        return option_to_id[option]

    # Input image picked from the examples downloaded at start-up.
    slider = st.select_slider(
        "Or select from example images:",
        options=["apple_tree.jpg", "highway.jpg", "highway2.jpg", "highway3.jpg"],
        format_func=slider_func,
    )

    # Preview the selected example image.
    # NOTE(review): the reviewed copy read `image = st.image(image, ...)`, which
    # uses `image` before assignment; Image.open(slider) is reconstructed here
    # because the slider options are the local example file names - confirm.
    image = Image.open(slider)
    st.image(image, caption=slider, width=300)

with col3:
    st.markdown("##### Set SAHI parameters:")

    model_name = "yolox"
    slice_size = st.number_input("slice_size", min_value=256, value=512, step=256)
    overlap_ratio = st.number_input(
        "overlap_ratio", min_value=0.0, max_value=0.6, value=0.2, step=0.2
    )
    postprocess_type = st.selectbox(
        "postprocess_type", options=["NMS", "UNIONMERGE"], index=1
    )
    postprocess_match_metric = st.selectbox(
        "postprocess_match_metric", options=["IOU", "IOS"], index=1
    )
    postprocess_match_threshold = st.number_input(
        "postprocess_match_threshold", value=0.5, step=0.1
    )
    postprocess_class_agnostic = st.checkbox("postprocess_class_agnostic", value=True)

# Second row: the submit button sits in the narrow middle column.
col1, col2, col3 = st.columns([6, 1, 6])
with col2:
    submit = st.button("Submit")

# An uploaded file takes precedence over the selected example image.
# NOTE(review): both right-hand sides were missing in the reviewed copy;
# Image.open(...) is reconstructed from the PIL import - confirm.
if image_file is not None:
    image = Image.open(image_file)
else:
    image = Image.open(slider)


class SpinnerTexts:
    """Rotating pool of spinner messages that avoids recently shown ones."""

    def __init__(self):
        # Indices of the most recently returned texts, oldest first.
        self.ind_history_list = []
        self.text_list = [
            "Meanwhile check out [MMDetection Colab notebook of SAHI](!",
            "Meanwhile check out [YOLOv5 Colab notebook of SAHI](!",
            "Meanwhile check out [aerial object detection with SAHI](!",
            "Meanwhile check out [COCO Utilities of SAHI](!",
            "Meanwhile check out [FiftyOne utilities of SAHI](!",
            "Meanwhile [give a Github star to SAHI](!",
            "Meanwhile see [how easy is to install SAHI](!",
            "Meanwhile check out [Medium blogpost of SAHI](!",
            "Meanwhile try out [YOLOv5 HF Spaces demo of SAHI](!",
        ]

    def _store(self, ind):
        """Record *ind* as shown, dropping the oldest entry once six are kept."""
        if len(self.ind_history_list) == 6:
            self.ind_history_list.pop(0)
        self.ind_history_list.append(ind)

    def get(self):
        """Return a text whose index is not in the recent history.

        Index 0 is tried first; random indices are drawn only while the
        candidate is still in the history. The history holds at most six of
        the nine indices, so the loop always finds a free one.
        """
        ind = 0
        while ind in self.ind_history_list:
            ind = random.randint(0, len(self.text_list) - 1)
        self._store(ind)
        return self.text_list[ind]


# Keep one SpinnerTexts per session so its history survives script reruns.
if "last_spinner_texts" not in st.session_state:
    st.session_state["last_spinner_texts"] = SpinnerTexts()

if submit:
    # perform prediction
    with st.spinner(
        text="Downloading model weight.. "
        + st.session_state["last_spinner_texts"].get()
    ):
        detection_model = get_mmdet_model(model_name)

    if model_name == "yolox":
        image_size = 416
    else:
        image_size = 640

    with st.spinner(
        text="Performing prediction.. " + st.session_state["last_spinner_texts"].get()
    ):
        output_1, output_2 = sahi_mmdet_inference(
            image,
            detection_model,
            image_size=image_size,
            slice_height=slice_size,
            slice_width=slice_size,
            overlap_height_ratio=overlap_ratio,
            overlap_width_ratio=overlap_ratio,
            postprocess_type=postprocess_type,
            postprocess_match_metric=postprocess_match_metric,
            postprocess_match_threshold=postprocess_match_threshold,
            postprocess_class_agnostic=postprocess_class_agnostic,
        )

    # Show the two outputs in a comparison widget.
    st.markdown("##### YOLOX Standard vs SAHI Prediction:")
    static_component = imagecompare(
        output_1,
        output_2,
        label1="YOLOX",
        label2="SAHI+YOLOX",
        width=700,
        starting_position=50,
        show_labels=True,
        make_responsive=True,
    )