# sahi-yolox / app.py
# fcakyon
# fix incorrect import in app
# 0329b55
# raw history blame
# No virus
# 7.59 kB
import streamlit as st
import sahi.utils.mmdet
import sahi.model
import sahi.predict
from PIL import Image
import numpy
# Checkpoint URLs for the MMDetection models that get_mmdet_model() can load.
MMDET_YOLACT_MODEL_URL = "https://download.openmmlab.com/mmdetection/v2.0/yolact/yolact_r50_1x8_coco/yolact_r50_1x8_coco_20200908-f38d58df.pth"
MMDET_YOLOX_MODEL_URL = "https://download.openmmlab.com/mmdetection/v2.0/yolox/yolox_tiny_8x8_300e_coco/yolox_tiny_8x8_300e_coco_20210806_234250-4ff3b67e.pth"
MMDET_FASTERRCNN_MODEL_URL = "https://download.openmmlab.com/mmdetection/v2.0/faster_rcnn/faster_rcnn_r50_fpn_2x_coco/faster_rcnn_r50_fpn_2x_coco_bbox_mAP-0.384_20200504_210434-a5d8aa15.pth"

# Example images: each (url, file name) pair is handed to download_from_url on
# every script run; the file names are the ones the example slider opens later.
_EXAMPLE_IMAGE_SOURCES = (
    (
        "https://user-images.githubusercontent.com/34196005/142730935-2ace3999-a47b-49bb-83e0-2bdd509f1c90.jpg",
        "apple_tree.jpg",
    ),
    (
        "https://user-images.githubusercontent.com/34196005/142730936-1b397756-52e5-43be-a949-42ec0134d5d8.jpg",
        "highway.jpg",
    ),
    (
        "https://user-images.githubusercontent.com/34196005/142742871-bf485f84-0355-43a3-be86-96b44e63c3a2.jpg",
        "highway2.jpg",
    ),
    (
        "https://user-images.githubusercontent.com/34196005/142742872-1fefcc4d-d7e6-4c43-bbb7-6b5982f7e4ba.jpg",
        "highway3.jpg",
    ),
)
for _example_url, _example_name in _EXAMPLE_IMAGE_SOURCES:
    sahi.utils.file.download_from_url(_example_url, _example_name)
@st.cache(allow_output_mutation=True, show_spinner=False)
def get_mmdet_model(model_name: str):
    """Download the weights and config for ``model_name`` and build a detection model.

    Args:
        model_name: One of ``"yolact"``, ``"yolox"`` or ``"fasterrcnn"``.

    Returns:
        A ``sahi.model.MmdetDetectionModel`` built from the downloaded weight
        file and config, with ``confidence_threshold=0.4`` and ``device="cpu"``.

    Raises:
        ValueError: If ``model_name`` is not one of the supported names.
    """
    # model_name -> (local weight file, checkpoint URL,
    #                mmdet config model name, mmdet config file name)
    model_specs = {
        "yolact": (
            "yolact.pt",
            MMDET_YOLACT_MODEL_URL,
            "yolact",
            "yolact_r50_1x8_coco.py",
        ),
        "yolox": (
            "yolox.pt",
            MMDET_YOLOX_MODEL_URL,
            "yolox",
            "yolox_tiny_8x8_300e_coco.py",
        ),
        "fasterrcnn": (
            "fasterrcnn.pt",
            MMDET_FASTERRCNN_MODEL_URL,
            "faster_rcnn",
            "faster_rcnn_r50_fpn_2x_coco.py",
        ),
    }

    # Fail early with a clear message instead of hitting an unbound
    # ``model_path`` further down when the name matches no known model.
    if model_name not in model_specs:
        supported = ", ".join(sorted(model_specs))
        raise ValueError(
            f"Unsupported model_name {model_name!r}; expected one of: {supported}"
        )

    model_path, weights_url, config_model_name, config_file_name = model_specs[
        model_name
    ]
    sahi.utils.file.download_from_url(weights_url, model_path)
    config_path = sahi.utils.mmdet.download_mmdet_config(
        model_name=config_model_name, config_file_name=config_file_name
    )

    detection_model = sahi.model.MmdetDetectionModel(
        model_path=model_path,
        config_path=config_path,
        confidence_threshold=0.4,
        device="cpu",
    )
    return detection_model
def sahi_mmdet_inference(
    image,
    detection_model,
    slice_height=512,
    slice_width=512,
    overlap_height_ratio=0.2,
    overlap_width_ratio=0.2,
    image_size=640,
    postprocess_type="UNIONMERGE",
    postprocess_match_metric="IOS",
    postprocess_match_threshold=0.5,
    postprocess_class_agnostic=False,
):
    """Run standard and sliced prediction on ``image`` and visualize both.

    Args:
        image: Input image; it is passed to the SAHI prediction functions and
            converted with ``numpy.array`` for drawing (the app passes a PIL image).
        detection_model: Model object forwarded to both prediction calls.
        slice_height, slice_width: Forwarded to ``get_sliced_prediction``.
        overlap_height_ratio, overlap_width_ratio: Forwarded to
            ``get_sliced_prediction``.
        image_size: Forwarded to both the standard and the sliced prediction.
        postprocess_type, postprocess_match_metric, postprocess_match_threshold,
        postprocess_class_agnostic: Forwarded to ``get_sliced_prediction``.

    Returns:
        A ``(standard, sliced)`` tuple of PIL images with the predictions drawn.
    """

    def _render(prediction_result):
        # Draw on a fresh array copy of the input each time so the two
        # visualizations never share pixel data.
        drawn = sahi.utils.cv.visualize_object_predictions(
            image=numpy.array(image),
            object_prediction_list=prediction_result.object_prediction_list,
        )
        return Image.fromarray(drawn["image"])

    # Whole-image inference.
    standard_result = sahi.predict.get_prediction(
        image=image, detection_model=detection_model, image_size=image_size
    )
    standard_view = _render(standard_result)

    # Slice-by-slice inference with the requested slicing/postprocess options.
    sliced_result = sahi.predict.get_sliced_prediction(
        image=image,
        detection_model=detection_model,
        image_size=image_size,
        slice_height=slice_height,
        slice_width=slice_width,
        overlap_height_ratio=overlap_height_ratio,
        overlap_width_ratio=overlap_width_ratio,
        postprocess_type=postprocess_type,
        postprocess_match_metric=postprocess_match_metric,
        postprocess_match_threshold=postprocess_match_threshold,
        postprocess_class_agnostic=postprocess_class_agnostic,
    )
    sliced_view = _render(sliced_result)

    return standard_view, sliced_view
# Page-level configuration for the demo.
st.set_page_config(
    page_title="SAHI + MMDetection Demo",
    page_icon="",
    layout="centered",
    initial_sidebar_state="auto",
)

# Centered page title (closing tag now matches the opening <h2>).
st.markdown(
    "<h2 style='text-align: center'> SAHI + MMDetection Demo </h2>",
    unsafe_allow_html=True,
)
# Short description with project links.
st.markdown(
    "<p style='text-align: center'>SAHI is a lightweight vision library for performing large scale object detection/ instance segmentation.. <a href='https://github.com/obss/sahi'>SAHI Github</a> | <a href='https://medium.com/codable/sahi-a-vision-library-for-performing-sliced-inference-on-large-images-small-objects-c8b086af3b80'>SAHI Blog</a> | <a href='https://github.com/fcakyon/yolov5-pip'>YOLOv5 Github</a> </p>",
    unsafe_allow_html=True,
)
# Section heading for the input widgets (closing tag now matches the opening <h3>).
st.markdown(
    "<h3 style='text-align: center'> Parameters: </h3>",
    unsafe_allow_html=True,
)
# Layout: image inputs in the left column, a narrow spacer column, and the
# model / slicing / postprocess options in the right column.
col1, col2, col3 = st.columns([6, 1, 6])

with col1:
    image_file = st.file_uploader(
        "Upload an image to test:", type=["jpg", "jpeg", "png"]
    )

    def slider_func(option):
        """Return the short numeric label shown for an example image file name."""
        labels_by_file = {
            "apple_tree.jpg": "1",
            "highway.jpg": "2",
            "highway2.jpg": "3",
            "highway3.jpg": "4",
        }
        return labels_by_file[option]

    slider = st.select_slider(
        "Or select from example images:",
        options=["apple_tree.jpg", "highway.jpg", "highway2.jpg", "highway3.jpg"],
        format_func=slider_func,
    )
    # Preview the currently selected example image.
    image = Image.open(slider)
    st.image(image, caption=slider, width=300)

with col3:
    model_name = st.selectbox(
        "Select MMDetection model:",
        options=("fasterrcnn", "yolact", "yolox"),
    )
    slice_size = st.number_input(
        "slice_size", min_value=256, value=512, step=256
    )
    overlap_ratio = st.number_input(
        "overlap_ratio", min_value=0.0, max_value=0.6, value=0.2, step=0.2
    )
    postprocess_type = st.selectbox(
        "postprocess_type", options=["NMS", "UNIONMERGE"], index=1
    )
    postprocess_match_metric = st.selectbox(
        "postprocess_match_metric", options=["IOU", "IOS"], index=1
    )
    postprocess_match_threshold = st.number_input(
        "postprocess_match_threshold", value=0.5, step=0.1
    )
    postprocess_class_agnostic = st.checkbox(
        "postprocess_class_agnostic", value=True
    )
# Second row of columns; the submit button sits in the narrow middle column.
col1, col2, col3 = st.columns([6, 1, 6])
with col2:
    submit = st.button("Submit")

# An uploaded file takes precedence over the selected example image.
if image_file is not None:
    image = Image.open(image_file)
else:
    image = Image.open(slider)

if submit:
    # perform prediction
    st.markdown(
        "<h3 style='text-align: center'> Results: </h3>",
        unsafe_allow_html=True,
    )

    with st.spinner(text="Downloading model weight.."):
        detection_model = get_mmdet_model(model_name)

    # The yolox model is run at 416; the other models at 640.
    image_size = 416 if model_name == "yolox" else 640

    with st.spinner(
        text="Performing prediction.. Meanwhile check out [other features of SAHI](https://github.com/obss/sahi/blob/main/README.md)!"
    ):
        # Forward the postprocess widget values as well; previously they were
        # collected in the UI but never passed on, so changing them had no effect.
        output_1, output_2 = sahi_mmdet_inference(
            image,
            detection_model,
            image_size=image_size,
            slice_height=slice_size,
            slice_width=slice_size,
            overlap_height_ratio=overlap_ratio,
            overlap_width_ratio=overlap_ratio,
            postprocess_type=postprocess_type,
            postprocess_match_metric=postprocess_match_metric,
            postprocess_match_threshold=postprocess_match_threshold,
            postprocess_class_agnostic=postprocess_class_agnostic,
        )

    st.markdown(f"##### Standard {model_name} Prediction:")
    st.image(output_1, width=700)

    st.markdown(f"##### Sliced {model_name} Prediction:")
    st.image(output_2, width=700)