# streamlit-api / helper.py
# Uploaded by tkau — "Upload 15 files (#1)", commit f8ce4cf
from ultralytics import YOLO
import streamlit as st
import cv2
#import pafy
import requests
import settings
import numpy as np
from test_api.algorithm import *
def load_model(model_path):
    """Construct and return a YOLO model from the weights at ``model_path``.

    Args:
        model_path: Path (or model name) handed straight to ``YOLO``.

    Returns:
        The ``YOLO`` instance built from ``model_path``.
    """
    return YOLO(model_path)
def display_tracker_options():
    """Render the object-tracking controls and return the user's selection.

    Two side-by-side columns are created. The left column holds a Yes/No
    radio ("显示追踪" = "show tracking"). The right column ("追踪器" =
    "tracker") holds a radio for the tracker config file, and is populated
    only when tracking is enabled.

    Returns:
        tuple: ``(enabled, tracker)``, where ``enabled`` is ``True`` iff
        "Yes" was chosen, and ``tracker`` is the selected config name
        ("bytetrack.yaml" or "botsort.yaml") or ``None`` when tracking is off.
    """
    left_col, right_col = st.columns(2)
    with left_col:
        choice = st.radio("显示追踪", ('Yes', 'No'))
    tracking_enabled = choice == 'Yes'

    # Guard clause: without tracking there is no tracker to pick.
    if not tracking_enabled:
        return tracking_enabled, None

    with right_col:
        selected_tracker = st.radio("追踪器", ("bytetrack.yaml", "botsort.yaml"))
    return tracking_enabled, selected_tracker
def _display_detected_frames(conf, model, st_frame, image, is_display_tracking=None, tracker=None,
                             frame_width=720):
    """Run YOLO on one frame and render the annotated result into ``st_frame``.

    Args:
        conf: Confidence threshold forwarded to ``model.track`` / ``model.predict``.
        model: A loaded ultralytics ``YOLO`` model.
        st_frame: Streamlit placeholder (e.g. from ``st.empty()``) whose
            content is replaced with each rendered frame.
        image: The frame to process. It is displayed with ``channels="BGR"``,
            i.e. it is expected in OpenCV's BGR channel order.
        is_display_tracking: When truthy, run ``model.track`` with
            ``persist=True`` instead of plain ``model.predict``.
        tracker: Tracker config name such as ``"bytetrack.yaml"``. When
            ``None`` it is not forwarded, so ultralytics uses its own default.
        frame_width: Width in pixels the frame is resized to before inference.
            The height is derived from a fixed 16:9 aspect ratio (720 -> 405).
    """
    # Resize every frame to a fixed 16:9 size so inference cost and the
    # displayed size stay stable regardless of the source resolution.
    frame_height = int(frame_width * (9 / 16))
    image = cv2.resize(image, (frame_width, frame_height))

    if is_display_tracking:
        track_kwargs = {"conf": conf, "persist": True}
        # Forwarding tracker=None would override ultralytics' default tracker
        # config with None; only pass it when a tracker was actually chosen.
        if tracker is not None:
            track_kwargs["tracker"] = tracker
        res = model.track(image, **track_kwargs)
    else:
        # Plain per-frame detection with the YOLOv8 model.
        res = model.predict(image, conf=conf)

    # A single input image yields one result; draw its detections on the frame.
    res_plotted = res[0].plot()
    try:
        st_frame.image(res_plotted,
                       caption='实时检测',  # "real-time detection"
                       channels="BGR",
                       use_column_width=True
                       )
    except requests.exceptions.RequestException:
        # Render the fallback into st_frame itself: writing via st.write /
        # st.image would append new elements to the page on every failing
        # frame while a stream is being played.
        st_frame.image("placeholder.png",
                       caption="Unable to get image, using placeholder")
def play_rtsp_stream(conf, model):
    """Read an RTSP stream URL from the sidebar and run detection on it.

    Renders a sidebar text input for the stream URL plus the tracking options.
    When the "检测目标" ("detect objects") button is pressed, the stream is
    opened and frames are read and rendered until a read fails.

    Args:
        conf: Confidence threshold forwarded to the model.
        model: A loaded ultralytics ``YOLO`` model.
    """
    source_rtsp = st.sidebar.text_input("rtsp stream url")
    is_display_tracker, tracker = display_tracker_options()
    if not st.sidebar.button('检测目标'):
        return

    vid_cap = None
    try:
        vid_cap = cv2.VideoCapture(source_rtsp)
        # An unreachable or empty URL leaves the capture closed; tell the user
        # instead of silently doing nothing.
        if not vid_cap.isOpened():
            st.sidebar.error("Error loading RTSP stream: could not open the stream URL")
            return
        # Single placeholder that is overwritten with each new frame.
        st_frame = st.empty()
        while vid_cap.isOpened():
            success, image = vid_cap.read()
            if not success:
                break
            _display_detected_frames(conf,
                                     model,
                                     st_frame,
                                     image,
                                     is_display_tracker,
                                     tracker
                                     )
    except Exception as e:
        st.sidebar.error("Error loading RTSP stream: " + str(e))
    finally:
        # Always release the capture. Previously it leaked when rendering
        # raised, or when the script was stopped/rerun mid-stream.
        if vid_cap is not None:
            vid_cap.release()