streamlit-api / algorithm.py
tkau's picture
Upload 15 files (#1)
f8ce4cf
from ultralytics import YOLO
import streamlit as st
import cv2
#import pafy
import numpy as np
import settings
import requests
class YoloV8Detection:
    """Streamlit front end that runs a detection model over an RTSP stream.

    Every method draws Streamlit widgets as a side effect, so the class is
    meant to be driven from inside a running Streamlit script.
    """

    def __init__(self):
        pass

    def display_tracker_options(self):
        """Render the tracker radio buttons and report the user's choice.

        Returns:
            A ``(is_display_tracker, tracker_type)`` tuple. ``tracker_type``
            is the selected tracker config name (``"bytetrack.yaml"`` or
            ``"botsort.yaml"``), or ``None`` when tracking is switched off.
        """
        col1, col2 = st.columns(2)
        with col1:
            display_tracker = st.radio("显示追踪", ('是', '否'))
        is_display_tracker = display_tracker == '是'
        if is_display_tracker:
            with col2:
                tracker_type = st.radio("追踪器选择", ("bytetrack.yaml", "botsort.yaml"))
            return is_display_tracker, tracker_type
        return is_display_tracker, None

    def _display_detected_frames(self, conf, model, st_frame, image, is_display_tracking=None, tracker=None):
        """Run the model on one frame and show the annotated result.

        Args:
            conf: Forwarded unchanged as the ``conf`` keyword of the model call.
            model: Object exposing ``track(...)`` and ``predict(...)`` as
                called below (presumably an ultralytics YOLO model -- confirm
                against the caller).
            st_frame: Streamlit placeholder whose ``image`` method receives
                the annotated frame.
            image: Frame to analyse; it is resized to 720x405 first.
            is_display_tracking: When truthy, ``model.track`` is used instead
                of ``model.predict``.
            tracker: Tracker config forwarded to ``model.track``.
        """
        image = cv2.resize(image, (720, int(720*(9/16))))
        if is_display_tracking:
            res = model.track(image, conf=conf, persist=True, tracker=tracker)
        else:
            res = model.predict(image, conf=conf)
        res_plotted = res[0].plot()
        try:
            st_frame.image(res_plotted,
                           caption='实时检测',
                           channels="BGR",
                           use_column_width=True
                           )
        except requests.exceptions.RequestException:
            st.write("Unable to get image, using placeholder")
            st.image("placeholder.png")

    def play_rtsp_stream(self, conf, model):
        """Read frames from a user-supplied RTSP URL and display detections.

        Draws the URL input, the tracker options and the start/stop buttons.
        Frames are processed only when the start button was pressed. Any
        exception raised while streaming is reported in the sidebar.

        Args:
            conf: Passed through to ``_display_detected_frames``.
            model: Passed through to ``_display_detected_frames``.
        """
        source_rtsp = st.sidebar.text_input("rtsp stream url")
        is_display_tracker, tracker = self.display_tracker_options()
        start = st.sidebar.button('检测预览')
        # NOTE: `stop` is read once here and never refreshed inside the loop.
        stop = st.sidebar.button('停止')
        if not start:
            return
        vid_cap = None
        try:
            vid_cap = cv2.VideoCapture(source_rtsp)
            st_frame = st.empty()
            while vid_cap.isOpened():
                success, image = vid_cap.read()
                if not success:
                    print("Error")
                    break
                if stop:
                    break
                self._display_detected_frames(conf,
                                              model,
                                              st_frame,
                                              image,
                                              is_display_tracker,
                                              tracker
                                              )
        except Exception as e:
            st.sidebar.error("Error loading RTSP stream: " + str(e))
        finally:
            # Release on every exit path (read failure, stop, exception) so
            # the capture handle is never leaked.
            if vid_cap is not None:
                vid_cap.release()
class BoundaryDetection(YoloV8Detection):
    """Detection front end that flags objects inside a polygonal area.

    A detection counts as an intrusion when the centre of its bounding box
    lies inside or on the edge of ``intrusion_area``.
    """

    def __init__(self, intrusion_area=None):
        """Store the polygon to watch.

        Args:
            intrusion_area: Sequence of ``(x, y)`` vertices, used as pixel
                coordinates on the resized 720x405 frame. Defaults to
                ``[(100, 200), (400, 200), (500, 300), (100, 300)]``.
        """
        super().__init__()
        if intrusion_area is None:
            # Build a fresh list per instance; a list literal in the
            # signature would be shared by every instance.
            intrusion_area = [(100, 200), (400, 200), (500, 300), (100, 300)]
        self.intrusion_area = intrusion_area

    def _display_detected_frames(self, conf, model, st_frame, image, is_display_tracking=None, tracker=None):
        """Run the model on one frame, mark intrusions and show the frame.

        Args:
            conf: Forwarded unchanged as the ``conf`` keyword of the model call.
            model: Object exposing ``track(...)`` and ``predict(...)`` as
                called below (presumably an ultralytics YOLO model -- confirm
                against the caller).
            st_frame: Streamlit placeholder whose ``image`` method receives
                the annotated frame.
            image: Frame to analyse; it is resized to 720x405 first.
            is_display_tracking: When truthy, ``model.track`` is used instead
                of ``model.predict``.
            tracker: Tracker config forwarded to ``model.track``.
        """
        image = cv2.resize(image, (720, int(720*(9/16))))
        if is_display_tracking:
            res = model.track(image, conf=conf, persist=True, tracker=tracker)
        else:
            res = model.predict(image, conf=conf)

        # OpenCV needs 32-bit points here; the default integer dtype is often
        # 64-bit. Built once per frame instead of once per detection.
        polygon = np.array(self.intrusion_area, dtype=np.int32)
        cv2.polylines(image, [polygon], isClosed=True, color=(0, 0, 255), thickness=2)

        intrusion_detected = False
        for result in res:
            print(f"boxes是:{result.boxes}\n")
            box = result.boxes.xyxy
            print(f"box是:{box}\n")
            # Zipping over an empty set of boxes simply yields nothing.
            for xyxy, obj in zip(box, result.boxes):
                x, y, x2, y2 = int(xyxy[0]), int(xyxy[1]), int(xyxy[2]), int(xyxy[3])
                cx, cy = (x + x2) / 2, (y + y2) / 2
                print(f"中点坐标是:{cx, cy}\n")
                if cv2.pointPolygonTest(polygon, (cx, cy), False) >= 0:
                    # Something entered the watched area: raise the alarm.
                    print("有人入侵!\n")
                    intrusion_detected = True
                    # Box and label the intruding detection.
                    cv2.rectangle(image, (x, y), (x2, y2), (0, 255, 0), 2)
                    cv2.putText(image, f"{result.names[int(obj.cls[0])]} {obj.conf[0]:.2f}", (x, y - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 0), 1)

        if intrusion_detected:
            cv2.putText(image, "Intrusion detected!", (10, 30), cv2.FONT_HERSHEY_TRIPLEX, 1, (0, 0, 255), 2)

        try:
            st_frame.image(image,
                           caption='实时检测',
                           channels="BGR",
                           use_column_width=True
                           )
        except requests.exceptions.RequestException:
            st.write("Unable to get image, using placeholder")
            st.image("placeholder.png")

    def play_rtsp_stream(self, conf, model):
        """Read frames from a user-supplied RTSP URL and display intrusions.

        Draws the URL input, the tracker options and the start/stop buttons.
        Frames are processed only when the start button was pressed. Any
        exception raised while streaming is reported in the sidebar.

        Args:
            conf: Passed through to ``_display_detected_frames``.
            model: Passed through to ``_display_detected_frames``.
        """
        source_rtsp = st.sidebar.text_input("rtsp stream url")
        is_display_tracker, tracker = self.display_tracker_options()
        start = st.sidebar.button('检测预览')
        # NOTE: `stop` is read once here and never refreshed inside the loop.
        stop = st.sidebar.button('停止')
        if not start:
            return
        vid_cap = None
        try:
            vid_cap = cv2.VideoCapture(source_rtsp)
            st_frame = st.empty()
            while vid_cap.isOpened():
                success, image = vid_cap.read()
                if not success:
                    break
                if stop:
                    break
                self._display_detected_frames(conf,
                                              model,
                                              st_frame,
                                              image,
                                              is_display_tracker,
                                              tracker
                                              )
        except Exception as e:
            st.sidebar.error("Error loading RTSP stream: " + str(e))
        finally:
            # Release on every exit path (read failure, stop, exception) so
            # the capture handle is never leaked.
            if vid_cap is not None:
                vid_cap.release()