sahi-yolox / app.py
fcakyon
links open to new tab
4c15790
raw history blame
No virus
8.57 kB
import streamlit as st
import sahi.utils.mmdet
import sahi.model
from PIL import Image
import random
from utils import imagecompare
from utils import sahi_mmdet_inference
import pathlib
import os
# Download URLs of the MMDetection checkpoints (hosted on download.openmmlab.com)
# that get_mmdet_model() can fetch, one constant per supported model name.
MMDET_YOLACT_MODEL_URL = "https://download.openmmlab.com/mmdetection/v2.0/yolact/yolact_r50_1x8_coco/yolact_r50_1x8_coco_20200908-f38d58df.pth"
MMDET_YOLOX_MODEL_URL = "https://download.openmmlab.com/mmdetection/v2.0/yolox/yolox_tiny_8x8_300e_coco/yolox_tiny_8x8_300e_coco_20210806_234250-4ff3b67e.pth"
MMDET_FASTERRCNN_MODEL_URL = "https://download.openmmlab.com/mmdetection/v2.0/faster_rcnn/faster_rcnn_r50_fpn_2x_coco/faster_rcnn_r50_fpn_2x_coco_bbox_mAP-0.384_20200504_210434-a5d8aa15.pth"
# Example images: (source URL, local file name) pairs. Each file is fetched at
# script start so the example slider further down can open it by file name.
_EXAMPLE_IMAGES = (
    (
        "https://user-images.githubusercontent.com/34196005/142730935-2ace3999-a47b-49bb-83e0-2bdd509f1c90.jpg",
        "apple_tree.jpg",
    ),
    (
        "https://user-images.githubusercontent.com/34196005/142730936-1b397756-52e5-43be-a949-42ec0134d5d8.jpg",
        "highway.jpg",
    ),
    (
        "https://user-images.githubusercontent.com/34196005/142742871-bf485f84-0355-43a3-be86-96b44e63c3a2.jpg",
        "highway2.jpg",
    ),
    (
        "https://user-images.githubusercontent.com/34196005/142742872-1fefcc4d-d7e6-4c43-bbb7-6b5982f7e4ba.jpg",
        "highway3.jpg",
    ),
)
for _image_url, _image_file_name in _EXAMPLE_IMAGES:
    sahi.utils.file.download_from_url(_image_url, _image_file_name)
@st.cache(allow_output_mutation=True, show_spinner=False)
def get_mmdet_model(model_name: str):
    """Download the weights and config for ``model_name`` and build the model.

    The result is cached by Streamlit (see the decorator), so repeated calls
    with the same name reuse the already constructed model object.

    Args:
        model_name: One of ``"yolact"``, ``"yolox"`` or ``"faster_rcnn"``.

    Returns:
        A ``sahi.model.MmdetDetectionModel`` built from the downloaded weights
        and config, with ``confidence_threshold=0.4`` and ``device="cpu"``.

    Raises:
        ValueError: If ``model_name`` is not one of the supported names.
    """
    # model name -> (weights URL, local weights path, MMDetection config file)
    supported_models = {
        "yolact": (
            MMDET_YOLACT_MODEL_URL,
            "yolact.pt",
            "yolact_r50_1x8_coco.py",
        ),
        "yolox": (
            MMDET_YOLOX_MODEL_URL,
            "yolox.pt",
            "yolox_tiny_8x8_300e_coco.py",
        ),
        "faster_rcnn": (
            MMDET_FASTERRCNN_MODEL_URL,
            "faster_rcnn.pt",
            "faster_rcnn_r50_fpn_2x_coco.py",
        ),
    }
    try:
        model_url, model_path, config_file_name = supported_models[model_name]
    except KeyError:
        # Fail with an explicit message instead of an UnboundLocalError on
        # the variables that would otherwise never be assigned.
        raise ValueError(
            f"Unsupported model_name {model_name!r}; "
            f"expected one of {sorted(supported_models)}"
        ) from None

    sahi.utils.file.download_from_url(
        model_url,
        model_path,
    )
    config_path = sahi.utils.mmdet.download_mmdet_config(
        model_name=model_name, config_file_name=config_file_name
    )
    detection_model = sahi.model.MmdetDetectionModel(
        model_path=model_path,
        config_path=config_path,
        confidence_threshold=0.4,
        device="cpu",
    )
    return detection_model
# Configure the browser tab title/icon and the page layout.
# The icon is the rocket emoji (U+1F680), written as an escape sequence so it
# cannot be corrupted if this file is ever re-saved with a different encoding.
st.set_page_config(
    page_title="Small Object Detection with SAHI + YOLOX",
    page_icon="\U0001F680",
    layout="centered",
    initial_sidebar_state="auto",
)
# Centered page title, rendered as raw HTML.
_TITLE_HTML = """
<h2 style='text-align: center'>
Small Object Detection <br />
with SAHI + YOLOX
</h2>
"""

# Centered row of project and social links; every anchor uses target='_blank'.
_LINKS_HTML = """
<p style='text-align: center'>
<a href='https://github.com/obss/sahi' target='_blank'>SAHI Github</a> | <a href='https://github.com/open-mmlab/mmdetection/tree/master/configs/yolox' target='_blank'>YOLOX Github</a> | <a href='https://huggingface.co/spaces/fcakyon/sahi-yolov5' target='_blank'>SAHI+YOLOv5 Demo</a>
<br />
Follow me on <a href='https://twitter.com/fcakyon' target='_blank'>twitter</a>, <a href='https://www.linkedin.com/in/fcakyon/' target='_blank'>linkedin</a> and <a href='https://fcakyon.medium.com/' target='_blank'>medium</a> for more..
</p>
"""

# Both snippets contain HTML tags, so HTML rendering has to be enabled.
for _header_html in (_TITLE_HTML, _LINKS_HTML):
    st.markdown(_header_html, unsafe_allow_html=True)

# Presumably acts as a vertical spacer before the input widgets.
st.write("##")
# First row: input selection on the left, SAHI parameters on the right.
# The narrow middle column is left empty in this row.
col1, col2, col3 = st.columns([6, 1, 6])
with col1:
    st.markdown("##### Set input image:")

    # Option 1: the user uploads their own image.
    image_file = st.file_uploader(
        "Upload an image to test:", type=["jpg", "jpeg", "png"]
    )

    def slider_func(option):
        """Return the short label ("1".."4") shown for an example file name."""
        option_to_id = {
            "apple_tree.jpg": "1",
            "highway.jpg": "2",
            "highway2.jpg": "3",
            "highway3.jpg": "4",
        }
        return option_to_id[option]

    # Option 2: pick one of the bundled example images.
    slider = st.select_slider(
        "Or select from example images:",
        options=["apple_tree.jpg", "highway.jpg", "highway2.jpg", "highway3.jpg"],
        format_func=slider_func,
    )

    # Preview the image that will actually be processed: an uploaded file
    # takes priority over the example selected on the slider.
    if image_file is not None:
        image = Image.open(image_file)
        preview_caption = image_file.name
    else:
        image = Image.open(slider)
        preview_caption = slider
    st.image(image, caption=preview_caption, width=300)
with col3:
    st.markdown("##### Set SAHI parameters:")

    # The model is fixed for this demo; it is not exposed as a widget.
    model_name = "yolox"

    # Slice edge length; the same value is later used for height and width.
    slice_size = st.number_input("slice_size", min_value=256, value=512, step=256)

    # Overlap between neighbouring slices; the same value is later used for
    # both the height and the width ratio.
    overlap_ratio = st.number_input(
        "overlap_ratio", min_value=0.0, max_value=0.6, value=0.2, step=0.2
    )

    # Settings forwarded unchanged to the inference helper.
    postprocess_type = st.selectbox(
        "postprocess_type", options=["NMS", "UNIONMERGE"], index=1
    )
    postprocess_match_metric = st.selectbox(
        "postprocess_match_metric", options=["IOU", "IOS"], index=1
    )
    # Both match metrics are overlap ratios, so the threshold is limited to
    # the [0.0, 1.0] range to keep meaningless values out of the UI.
    postprocess_match_threshold = st.number_input(
        "postprocess_match_threshold",
        min_value=0.0,
        max_value=1.0,
        value=0.5,
        step=0.1,
    )
    postprocess_class_agnostic = st.checkbox("postprocess_class_agnostic", value=True)
# Second row with the same column widths as the first; only the narrow middle
# column is used here, to hold the submit button.
col1, col2, col3 = st.columns([6, 1, 6])
with col2:
    submit = st.button("Submit")

# Resolve the image to run inference on: an uploaded file wins, otherwise the
# example file currently selected on the slider is opened.
image_source = slider if image_file is None else image_file
image = Image.open(image_source)
class SpinnerTexts:
    """Rotating pool of promotional messages shown next to spinners.

    ``get()`` returns a random message while avoiding the most recently shown
    ones, so a user does not see the same text on consecutive waits.

    Attributes are plain lists so an instance can be kept in session state:
        ind_history_list: indices into ``text_list`` of recently shown texts,
            oldest first.
        text_list: the pool of messages (markdown strings).
    """

    # Upper bound on how many recently shown indices are excluded from a pick.
    _HISTORY_SIZE = 6

    def __init__(self):
        self.ind_history_list = []
        self.text_list = [
            "Meanwhile check out [MMDetection Colab notebook of SAHI](https://colab.research.google.com/github/obss/sahi/blob/main/demo/inference_for_mmdetection.ipynb)!",
            "Meanwhile check out [YOLOv5 Colab notebook of SAHI](https://colab.research.google.com/github/obss/sahi/blob/main/demo/inference_for_yolov5.ipynb)!",
            "Meanwhile check out [aerial object detection with SAHI](https://blog.ml6.eu/how-to-detect-small-objects-in-very-large-images-70234bab0f98?gi=b434299595d4)!",
            "Meanwhile check out [COCO Utilities of SAHI](https://github.com/obss/sahi/blob/main/docs/COCO.md)!",
            "Meanwhile check out [FiftyOne utilities of SAHI](https://github.com/obss/sahi#fiftyone-utilities)!",
            "Meanwhile [give a Github star to SAHI](https://github.com/obss/sahi/stargazers)!",
            "Meanwhile see [how easy is to install SAHI](https://github.com/obss/sahi#getting-started)!",
            "Meanwhile check out [Medium blogpost of SAHI](https://medium.com/codable/sahi-a-vision-library-for-performing-sliced-inference-on-large-images-small-objects-c8b086af3b80)!",
            "Meanwhile try out [YOLOv5 HF Spaces demo of SAHI](https://huggingface.co/spaces/fcakyon/sahi-yolov5)!",
        ]

    def _store(self, ind):
        """Record ``ind`` as the most recently shown index and trim the history."""
        self.ind_history_list.append(ind)
        # Always leave at least one text outside the history, otherwise no
        # candidate would remain for the next pick when the pool is small.
        limit = min(self._HISTORY_SIZE, max(len(self.text_list) - 1, 0))
        overflow = len(self.ind_history_list) - limit
        if overflow > 0:
            del self.ind_history_list[:overflow]

    def get(self):
        """Return a random text that is not among the recently shown ones.

        Raises:
            IndexError: If ``text_list`` is empty.
        """
        recent = set(self.ind_history_list)
        candidates = [i for i in range(len(self.text_list)) if i not in recent]
        if not candidates:
            # Only reachable if text_list was shrunk after the history was
            # recorded; fall back to the whole pool rather than failing.
            candidates = list(range(len(self.text_list)))
        # Sampling directly from the eligible indices keeps every eligible
        # text equally likely and always terminates.
        ind = random.choice(candidates)
        self._store(ind)
        return self.text_list[ind]
# Create the SpinnerTexts instance once and keep it in Streamlit's session
# state, so the same object (and its history of shown texts) is reused
# whenever the key is already present.
_SPINNER_STATE_KEY = "last_spinner_texts"
if _SPINNER_STATE_KEY not in st.session_state:
    st.session_state[_SPINNER_STATE_KEY] = SpinnerTexts()
if submit:
    spinner_texts = st.session_state["last_spinner_texts"]

    # Step 1: obtain the detection model while a spinner with a rotating
    # message is displayed.
    with st.spinner(text="Downloading model weight.. " + spinner_texts.get()):
        detection_model = get_mmdet_model(model_name)

    # The "yolox" model is run at 416, any other model at 640.
    image_size = 416 if model_name == "yolox" else 640

    # Step 2: run inference; the helper returns two outputs which are shown
    # below as "YOLOX" (output_1) and "SAHI+YOLOX" (output_2).
    with st.spinner(text="Performing prediction.. " + spinner_texts.get()):
        output_1, output_2 = sahi_mmdet_inference(
            image,
            detection_model,
            image_size=image_size,
            slice_height=slice_size,
            slice_width=slice_size,
            overlap_height_ratio=overlap_ratio,
            overlap_width_ratio=overlap_ratio,
            postprocess_type=postprocess_type,
            postprocess_match_metric=postprocess_match_metric,
            postprocess_match_threshold=postprocess_match_threshold,
            postprocess_class_agnostic=postprocess_class_agnostic,
        )

    # Step 3: side-by-side comparison of the two outputs.
    st.markdown("##### YOLOX Standard vs SAHI Prediction:")
    comparison_options = dict(
        label1="YOLOX",
        label2="SAHI+YOLOX",
        width=700,
        starting_position=50,
        show_labels=True,
        make_responsive=True,
    )
    static_component, image_1_path, image_2_path = imagecompare(
        output_1, output_2, **comparison_options
    )

    # NOTE(review): `st.magic` is not a documented public Streamlit API;
    # confirm this call works on the deployed Streamlit version.
    st.magic(static_component)

    # NOTE(review): the lines below look like leftover debugging output (path
    # of the first image, whether it exists on disk, and the image itself);
    # confirm whether they should stay user-visible.
    st.write(image_1_path)
    st.write(pathlib.Path(image_1_path).exists())
    st.image(Image.open(image_1_path))