Spaces:
Runtime error
Runtime error
from ultralytics import YOLO | |
import streamlit as st | |
import cv2 | |
#import pafy | |
import numpy as np | |
import settings | |
import requests | |
class YoloV8Detection:
    """Streamlit UI helper that runs a detection model over an RTSP stream.

    The model object is supplied by the caller; this class only uses its
    ``track`` and ``predict`` methods and the ``plot`` method of the first
    returned result.
    """

    def __init__(self):
        pass

    def display_tracker_options(self):
        """Render the tracking controls and return the user's selection.

        Returns:
            A ``(is_display_tracker, tracker_type)`` tuple. ``tracker_type``
            is the chosen tracker config name (``"bytetrack.yaml"`` or
            ``"botsort.yaml"``), or ``None`` when tracking is switched off.
        """
        col1, col2 = st.columns(2)
        with col1:
            display_tracker = st.radio("显示追踪", ('是', '否'))
        is_display_tracker = display_tracker == '是'
        if is_display_tracker:
            with col2:
                tracker_type = st.radio("追踪器选择", ("bytetrack.yaml", "botsort.yaml"))
            return is_display_tracker, tracker_type
        return is_display_tracker, None

    def _display_detected_frames(self, conf, model, st_frame, image, is_display_tracking=None, tracker=None):
        """Run the model on one frame and show the annotated result.

        Args:
            conf: Confidence threshold forwarded to the model call.
            model: Object exposing ``track(...)`` and ``predict(...)``.
            st_frame: Streamlit placeholder whose ``image`` method is called
                with the annotated frame.
            image: Frame to process; it is resized to 720x405 first.
            is_display_tracking: When truthy, ``model.track`` is used instead
                of ``model.predict``.
            tracker: Tracker config name forwarded to ``model.track``.
        """
        image = cv2.resize(image, (720, int(720*(9/16))))
        if is_display_tracking:
            res = model.track(image, conf=conf, persist=True, tracker=tracker)
        else:
            res = model.predict(image, conf=conf)
        res_plotted = res[0].plot()
        try:
            st_frame.image(res_plotted,
                           caption='实时检测',
                           channels="BGR",
                           use_column_width=True
                           )
        except requests.exceptions.RequestException:
            st.write("Unable to get image, using placeholder")
            st.image("placeholder.png")

    def play_rtsp_stream(self, conf, model):
        """Render the sidebar controls and, once started, process the stream.

        Frames are read from the URL typed into the sidebar until the stream
        closes or a read fails. Any exception is reported in the sidebar
        instead of propagating.

        Args:
            conf: Confidence threshold forwarded to the model.
            model: Detection model passed to ``_display_detected_frames``.
        """
        source_rtsp = st.sidebar.text_input("rtsp stream url")
        is_display_tracker, tracker = self.display_tracker_options()
        start = st.sidebar.button('检测预览')
        stop = st.sidebar.button('停止')
        if not start:
            return

        vid_cap = None
        try:
            vid_cap = cv2.VideoCapture(source_rtsp)
            st_frame = st.empty()
            while vid_cap.isOpened():
                success, image = vid_cap.read()
                if not success:
                    print("Error")
                    break
                # `stop` was read once above; it is not re-evaluated per frame.
                if stop:
                    break
                self._display_detected_frames(conf,
                                              model,
                                              st_frame,
                                              image,
                                              is_display_tracker,
                                              tracker
                                              )
        except Exception as e:
            st.sidebar.error("Error loading RTSP stream: " + str(e))
        finally:
            # Release the capture on every exit path (stop, read failure,
            # exception, or a stream that never opened), not only on a
            # failed read.
            if vid_cap is not None:
                vid_cap.release()
class BoundaryDetection(YoloV8Detection):
    """Detection variant that flags objects whose centre lies in a polygon.

    The polygon (``intrusion_area``) is drawn on every frame. Each detected
    box whose centre point falls inside or on the polygon is outlined and
    labelled, and a banner is written on the frame.
    """

    # Immutable default so instances never share one mutable list.
    DEFAULT_INTRUSION_AREA = ((100, 200), (400, 200), (500, 300), (100, 300))

    def __init__(self, intrusion_area=None):
        """Store the polygon to watch.

        Args:
            intrusion_area: Sequence of ``(x, y)`` vertices in the coordinates
                of the resized 720x405 frame. When omitted, a fresh copy of
                ``DEFAULT_INTRUSION_AREA`` is used.
        """
        super().__init__()
        if intrusion_area is None:
            intrusion_area = list(self.DEFAULT_INTRUSION_AREA)
        self.intrusion_area = intrusion_area

    def _display_detected_frames(self, conf, model, st_frame, image, is_display_tracking=None, tracker=None):
        """Run the model on one frame, mark intrusions and show the frame.

        Args:
            conf: Confidence threshold forwarded to the model call.
            model: Object exposing ``track(...)`` and ``predict(...)``.
            st_frame: Streamlit placeholder whose ``image`` method is called
                with the annotated frame.
            image: Frame to process; it is resized to 720x405 first.
            is_display_tracking: When truthy, ``model.track`` is used instead
                of ``model.predict``.
            tracker: Tracker config name forwarded to ``model.track``.
        """
        image = cv2.resize(image, (720, int(720*(9/16))))
        if is_display_tracking:
            res = model.track(image, conf=conf, persist=True, tracker=tracker)
        else:
            res = model.predict(image, conf=conf)

        # OpenCV's polygon routines require 32-bit coordinates, while NumPy's
        # default integer width is platform-dependent (commonly int64), so
        # the dtype is pinned. Built once per frame instead of once per box.
        polygon = np.array(self.intrusion_area, dtype=np.int32)
        cv2.polylines(image, [polygon], isClosed=True, color=(0, 0, 255), thickness=2)

        intrusion_detected = False
        for result in res:
            print(f"boxes是:{result.boxes}\n")
            box = result.boxes.xyxy
            print(f"box是:{box}\n")
            if box.size(0) == 0:
                continue
            for (xyxy, obj) in zip(box, result.boxes):
                x, y, x2, y2 = int(xyxy[0]), int(xyxy[1]), int(xyxy[2]), int(xyxy[3])
                cx, cy = (x + x2) / 2, (y + y2) / 2
                print(f"中点坐标是:{cx, cy}\n")
                # A result >= 0 means the centre is inside or on the edge.
                if cv2.pointPolygonTest(polygon, (cx, cy), False) >= 0:
                    # Something entered the watched area: raise the alarm.
                    print("有人入侵!\n")
                    intrusion_detected = True
                    # Outline and label the intruding detection.
                    cv2.rectangle(image, (x, y), (x2, y2), (0, 255, 0), 2)
                    cv2.putText(image, f"{result.names[int(obj.cls[0])]} {obj.conf[0]:.2f}", (x, y - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 0), 1)

        if intrusion_detected:
            cv2.putText(image, "Intrusion detected!", (10, 30), cv2.FONT_HERSHEY_TRIPLEX, 1, (0, 0, 255), 2)

        try:
            st_frame.image(image,
                           caption='实时检测',
                           channels="BGR",
                           use_column_width=True
                           )
        except requests.exceptions.RequestException:
            st.write("Unable to get image, using placeholder")
            st.image("placeholder.png")

    def play_rtsp_stream(self, conf, model):
        """Render the sidebar controls and, once started, process the stream.

        Frames are read from the URL typed into the sidebar until the stream
        closes or a read fails. Any exception is reported in the sidebar
        instead of propagating.

        Args:
            conf: Confidence threshold forwarded to the model.
            model: Detection model passed to ``_display_detected_frames``.
        """
        source_rtsp = st.sidebar.text_input("rtsp stream url")
        is_display_tracker, tracker = self.display_tracker_options()
        start = st.sidebar.button('检测预览')
        stop = st.sidebar.button('停止')
        if not start:
            return

        vid_cap = None
        try:
            vid_cap = cv2.VideoCapture(source_rtsp)
            st_frame = st.empty()
            while vid_cap.isOpened():
                success, image = vid_cap.read()
                if not success:
                    break
                # `stop` was read once above; it is not re-evaluated per frame.
                if stop:
                    break
                self._display_detected_frames(conf,
                                              model,
                                              st_frame,
                                              image,
                                              is_display_tracker,
                                              tracker
                                              )
        except Exception as e:
            st.sidebar.error("Error loading RTSP stream: " + str(e))
        finally:
            # Release the capture on every exit path (stop, read failure,
            # exception, or a stream that never opened), not only on a
            # failed read.
            if vid_cap is not None:
                vid_cap.release()