yolov5-ui / app.py
robmarkcole's picture
Upload app.py
5885827
raw
history blame contribute delete
No virus
2.32 kB
import time
from typing import Tuple

import numpy as np
import streamlit as st
import torch
from PIL import Image, ImageDraw, ImageOps

import const
def draw_box(
    draw: "ImageDraw.ImageDraw",
    box: Tuple[float, float, float, float],
    text: str = "",
    color: Tuple[int, int, int] = (255, 255, 0),
    line_width: int = 3,
    font_height: int = 8,
) -> None:
    """
    Draw a bounding box, and optionally a text label, on an image.

    Args:
        draw: Drawing context for the target image (e.g. ``ImageDraw.Draw(img)``);
            only its ``line`` and ``text`` methods are used.
        box: Box corners in ``(y_min, x_min, y_max, x_max)`` order.
        text: Label drawn just above the box's left edge; skipped when empty.
        color: Colour used for both the outline and the label.
        line_width: Outline thickness passed to ``draw.line``.
        font_height: Approximate label height, used to lift the label above the box.
    """
    y_min, x_min, y_max, x_max = box
    left, right, top, bottom = x_min, x_max, y_min, y_max

    # Trace the rectangle as a closed polyline (the last point returns to the start).
    draw.line(
        [(left, top), (left, bottom), (right, bottom), (right, top), (left, top)],
        width=line_width,
        fill=color,
    )

    if text:
        # Place the label above the box. abs() reflects a negative y (box touching
        # the top edge of the image) back inside the canvas instead of off it.
        label_y = abs(top - line_width - font_height)
        draw.text((left + line_width, label_y), text, fill=color)
# ``st.cache`` is deprecated in favour of ``st.cache_resource``, the Streamlit cache
# intended for shared, unhashable objects such as ML models. Prefer it when present
# and fall back to the legacy decorator on older Streamlit releases.
if hasattr(st, "cache_resource"):
    _cache_model = st.cache_resource(show_spinner=True)
else:
    _cache_model = st.cache(allow_output_mutation=True, show_spinner=True)


@_cache_model
def get_model(model_id: str = "yolov5s"):
    """
    Load a pretrained model from the ``ultralytics/yolov5`` torch hub repository.

    The result is cached per ``model_id``, so Streamlit reruns reuse the loaded
    model instead of loading it again.

    Args:
        model_id: Hub entry-point name, e.g. ``"yolov5s"``.

    Returns:
        The object returned by ``torch.hub.load``.
    """
    return torch.hub.load("ultralytics/yolov5", model_id)
# Settings
st.sidebar.title("Settings")
model_id = st.sidebar.selectbox("Pretrained model", const.PRETRAINED_MODELS, index=1)
img_size = st.sidebar.selectbox("Image resize for inference", const.IMAGE_SIZES, index=1)
CONFIDENCE = st.sidebar.slider(
    "Confidence threshold",
    const.MIN_CONF,
    const.MAX_CONF,
    const.DEFAULT_CONF,
)

model = get_model(model_id)
st.title(f"{model_id}")

# Input image: the user's upload if present, otherwise the bundled default.
img_file_buffer = st.file_uploader("Upload an image", type=["png", "jpg", "jpeg"])
if img_file_buffer is not None:
    pil_image = Image.open(img_file_buffer)
else:
    pil_image = Image.open(const.DEFAULT_IMAGE)

# The YOLOv5 hub model rotates PIL inputs according to their EXIF orientation
# before inference, so its boxes are in the upright frame. Apply the same rotation
# here so the boxes line up with the image we draw on. Converting to RGB lets the
# RGB box colour be drawn on greyscale/palette/CMYK uploads and gives the model
# plain 3-channel input.
pil_image = ImageOps.exif_transpose(pil_image).convert("RGB")
st.text(f"Input image width and height: {pil_image.width} x {pil_image.height}")

# Time only the model call; perf_counter is monotonic and high-resolution.
start_time = time.perf_counter()
results = model(pil_image, size=img_size)
end_time = time.perf_counter()

# One row per detection; keep only those above the chosen confidence threshold.
df = results.pandas().xyxy[0]
df = df[df["confidence"] > CONFIDENCE]

draw = ImageDraw.Draw(pil_image)
for _, obj in df.iterrows():
    draw_box(
        draw,
        (obj["ymin"], obj["xmin"], obj["ymax"], obj["xmax"]),
        text=str(obj["name"]),
        color=const.RED,
    )

# NOTE(review): newer Streamlit releases deprecate use_column_width in favour of
# use_container_width; kept here for compatibility with older versions.
st.image(
    np.array(pil_image),
    caption="Processed image",
    use_column_width=True,
)
st.text(f"Time to inference: {round(end_time - start_time, 2)} sec")
st.table(df)