tkau commited on
Commit
d55dccd
·
1 Parent(s): f8ce4cf
test_api/__pycache__/algorithm.cpython-39.pyc ADDED
Binary file (4.69 kB). View file
 
test_api/algorithm.py ADDED
@@ -0,0 +1,173 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from ultralytics import YOLO
2
+ import streamlit as st
3
+ import cv2
4
+ #import pafy
5
+ import numpy as np
6
+ import settings
7
+ import requests
8
+
9
+
10
+
11
+ class YoloV8Detection:
12
+
13
+ def __init__(self):
14
+ pass
15
+
16
+ def display_tracker_options(self):
17
+ col1, col2 = st.columns(2)
18
+ with col1:
19
+ display_tracker = st.radio("显示追踪", ('是', '否'))
20
+ is_display_tracker = True if display_tracker == '是' else False
21
+ if is_display_tracker:
22
+ with col2:
23
+ tracker_type = st.radio("追踪器选择", ("bytetrack.yaml", "botsort.yaml"))
24
+ return is_display_tracker, tracker_type
25
+ return is_display_tracker, None
26
+
27
+
28
+
29
+ def _display_detected_frames(self,conf, model, st_frame, image, is_display_tracking=None, tracker=None):
30
+
31
+
32
+ image = cv2.resize(image, (720, int(720*(9/16))))
33
+
34
+ if is_display_tracking:
35
+ res = model.track(image, conf=conf, persist=True, tracker=tracker)
36
+ else:
37
+ res = model.predict(image, conf=conf)
38
+
39
+ res_plotted = res[0].plot()
40
+ try:
41
+ st_frame.image(res_plotted,
42
+ caption='实时检测',
43
+ channels="BGR",
44
+ use_column_width=True
45
+ )
46
+ except requests.exceptions.RequestException as e:
47
+ st.write("Unable to get image, using placeholder")
48
+ st.image("placeholder.png")
49
+
50
+
51
+ def play_rtsp_stream(self,conf, model,display_video=False):
52
+
53
+ source_rtsp = st.sidebar.text_input("rtsp stream url",value="http://127.0.0.1:8999/live/test.live.flv")
54
+ is_display_tracker, tracker = self.display_tracker_options()
55
+
56
+ start = st.sidebar.button('检测目标')
57
+ stop = st.sidebar.button('停止')
58
+
59
+ if start:
60
+ try:
61
+ vid_cap = cv2.VideoCapture(source_rtsp)
62
+ st_frame = st.empty()
63
+
64
+ while (vid_cap.isOpened()):
65
+ success, image = vid_cap.read()
66
+ if success:
67
+ #print("Success")
68
+ if stop:
69
+ break
70
+
71
+ elif display_video:
72
+ self._display_detected_frames(conf,
73
+ model,
74
+ st_frame,
75
+ image,
76
+ is_display_tracker,
77
+ tracker
78
+ )
79
+
80
+ else:
81
+ vid_cap.release()
82
+ print("Error")
83
+ break
84
+ except Exception as e:
85
+ st.sidebar.error("Error loading RTSP stream: " + str(e))
86
+
87
+
88
+
89
+ class BoundaryDetection(YoloV8Detection):
90
+ def __init__(self,intrusion_area= [(100, 200), (400, 200), (500, 300), (100, 300)]):
91
+ self.intrusion_area = intrusion_area
92
+
93
+ def _display_detected_frames(self, conf, model, st_frame, image, is_display_tracking=None, tracker=None):
94
+
95
+ image = cv2.resize(image, (720, int(720*(9/16))))
96
+
97
+ if is_display_tracking:
98
+ res = model.track(image, conf=conf, persist=True, tracker=tracker)
99
+ else:
100
+ res = model.predict(image, conf=conf)
101
+
102
+ cv2.polylines(image, [np.array(self.intrusion_area)], isClosed=True, color=(0, 0, 255), thickness=2)
103
+ intrusion_detected = False
104
+
105
+ for result in res:
106
+ # for xyxy_box in result.boxes.xyxy:
107
+ print(f"boxes是:{result.boxes}\n")
108
+
109
+ box = result.boxes.xyxy
110
+
111
+ print(f"box是:{box}\n")
112
+ if box.size(0) != 0:
113
+ for (xyxy,obj) in zip(box,result.boxes):
114
+ x, y, x2, y2 = int(xyxy[0]), int(xyxy[1]), int(xyxy[2]), int(xyxy[3])
115
+ cx, cy = (x + x2) / 2, (y + y2) / 2
116
+
117
+ print(f"中点坐标是:{cx, cy}\n")
118
+ if cv2.pointPolygonTest(np.array(self.intrusion_area), (cx, cy), False) >= 0:
119
+ # 有人进入特定区域,进行报警
120
+ print("有人入侵!\n")
121
+ intrusion_detected = True
122
+
123
+ # 添加框出入侵者的代码
124
+ cv2.rectangle(image, (x, y), (x2, y2), (0, 255, 0), 2)
125
+ cv2.putText(image, f"{result.names[int(obj.cls[0])]} {obj.conf[0]:.2f}", (x, y - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 0), 1)
126
+
127
+ else:
128
+ continue
129
+
130
+ if intrusion_detected:
131
+ cv2.putText(image, "Intrusion detected!", (10, 30), cv2.FONT_HERSHEY_TRIPLEX, 1, (0, 0, 255), 2)
132
+
133
+ # res_plotted = res[0].plot()
134
+ # res_plotted=cv2.polylines(res_plotted, [np.array(self.intrusion_area)], isClosed=True, color=(0, 0, 255), thickness=2)
135
+
136
+
137
+ try:
138
+ st_frame.image(image,
139
+ caption='实时检测',
140
+ channels="BGR",
141
+ use_column_width=True
142
+ )
143
+ except requests.exceptions.RequestException as e:
144
+ st.write("Unable to get image, using placeholder")
145
+ st.image("placeholder.png")
146
+
147
+
148
+ def play_rtsp_stream(self, conf, model,display_video=False):
149
+ source_rtsp = st.sidebar.text_input("rtsp stream url")
150
+ is_display_tracker, tracker = self.display_tracker_options()
151
+ if st.sidebar.button('检测目标'):
152
+ try:
153
+ vid_cap = cv2.VideoCapture(source_rtsp)
154
+ st_frame = st.empty()
155
+ while (vid_cap.isOpened()):
156
+ success, image = vid_cap.read()
157
+ if success:
158
+
159
+ if display_video:
160
+ self._display_detected_frames(conf,
161
+ model,
162
+ st_frame,
163
+ image,
164
+ is_display_tracker,
165
+ tracker
166
+ )
167
+ else:
168
+ vid_cap.release()
169
+ break
170
+ except Exception as e:
171
+ st.sidebar.error("Error loading RTSP stream: " + str(e))
172
+
173
+
test_api/app.py ADDED
@@ -0,0 +1,154 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ from pathlib import Path
3
+ import PIL
4
+
5
+ # External packages
6
+ import streamlit as st
7
+
8
+ # Local Modules
9
+ import settings
10
+ import helper
11
+ import test_api.algorithm as algorithm
12
+ import multiprocessing
13
+ import time
14
+ import requests
15
+ import socketio
16
+ import base64
17
+ import io
18
+ from PIL import Image
19
+
20
+
21
+ st.set_page_config(
22
+ page_title="YOLOv8 目标检测",
23
+ page_icon="🤖",
24
+ layout="wide",
25
+ initial_sidebar_state="expanded" # 或者 "collapsed"
26
+ )
27
+
28
+
29
+ # # Main page heading
30
+ st.title("YOLOv8 目标检测")
31
+
32
+ # Sidebar
33
+ st.sidebar.header("模型配置")
34
+
35
+ # Model Options
36
+ model_type = st.sidebar.selectbox(
37
+ "任务选择", ['检测', '分割',"越界检测","行为检测"])
38
+
39
+ confidence = float(st.sidebar.slider(
40
+ "选择模型Confidence", 25, 100, 40)) / 100
41
+
42
+ # Selecting Detection Or Segmentation
43
+ if model_type == '检测':
44
+ model_path = Path(settings.DETECTION_MODEL)
45
+ elif model_type == '分割':
46
+ model_path = Path(settings.SEGMENTATION_MODEL)
47
+ elif model_type == "越界检测":
48
+ model_path = Path(settings.SEGMENTATION_MODEL)
49
+ elif model_type == "行为检测":
50
+ model_path = Path(settings.SEGMENTATION_MODEL)
51
+
52
+ # Load Pre-trained ML Model
53
+ try:
54
+ model = helper.load_model(model_path)
55
+ except Exception as ex:
56
+ st.error(f"Unable to load model. Check the specified path: {model_path}")
57
+ st.error(ex)
58
+
59
+ st.sidebar.header("图像/视频 配置")
60
+ source_radio = st.sidebar.radio(
61
+ "选择来源", settings.SOURCES_LIST)
62
+
63
+ source_img = None
64
+ # If image is selected
65
+ if source_radio == settings.IMAGE:
66
+ source_img = st.sidebar.file_uploader(
67
+ "选择一张图像...", type=("jpg", "jpeg", "png", 'bmp', 'webp'))
68
+
69
+ col1, col2 = st.columns(2)
70
+
71
+ with col1:
72
+ try:
73
+ if source_img is None:
74
+ default_image_path = str(settings.DEFAULT_IMAGE)
75
+ default_image = PIL.Image.open(default_image_path)
76
+ st.image(default_image_path, caption="默认图像",
77
+ use_column_width=True)
78
+ else:
79
+ uploaded_image = PIL.Image.open(source_img)
80
+ st.image(source_img, caption="Uploaded Image",
81
+ use_column_width=True)
82
+ except Exception as ex:
83
+ st.error("Error occurred while opening the image.")
84
+ st.error(ex)
85
+
86
+ with col2:
87
+ if source_img is None:
88
+ default_detected_image_path = str(settings.DEFAULT_DETECT_IMAGE)
89
+ default_detected_image = PIL.Image.open(
90
+ default_detected_image_path)
91
+ st.image(default_detected_image_path, caption='检测图像',
92
+ use_column_width=True)
93
+ else:
94
+ if st.sidebar.button('检测目标'):
95
+ res = model.predict(uploaded_image,
96
+ conf=confidence
97
+ )
98
+ boxes = res[0].boxes
99
+ res_plotted = res[0].plot()[:, :, ::-1]
100
+ st.image(res_plotted, caption='Detected Image',
101
+ use_column_width=True)
102
+ try:
103
+ with st.expander("Detection Results"):
104
+ for box in boxes:
105
+ st.write(box.data)
106
+ except Exception as ex:
107
+ # st.write(ex)
108
+ st.write("No image is uploaded yet!")
109
+
110
+
111
+
112
+
113
+ elif source_radio == settings.RTSP:
114
+ if model_type == '检测':
115
+ src ={
116
+ "video_url": "http://127.0.0.1:8999/live/test.live.flv"
117
+ }
118
+ start = st.sidebar.button('检测目标')
119
+ stop = st.sidebar.button('停止')
120
+
121
+ if start:
122
+ response = requests.post('http://192.168.110.232:7555/analyzerControlAdd', json=src)
123
+ print(response.text)
124
+
125
+ elif model_type == '分割':
126
+ yolov8=algorithm.YoloV8Detection()
127
+ yolov8.play_rtsp_stream(confidence,model,display_video=True)
128
+
129
+ elif model_type == "越界检测":
130
+ yolov8_b=algorithm.BoundaryDetection()
131
+ yolov8_b.play_rtsp_stream(confidence,model,display_video=True)
132
+
133
+
134
+ #helper.play_rtsp_stream(confidence, model)
135
+
136
+ else:
137
+ st.error("Please select a valid source type!")
138
+
139
+
140
+ # 创建 SocketIO 实例
141
+ sio = socketio.Client()
142
+
143
+ # 连接到 WebSocket 服务器
144
+ sio.connect('http://192.168.110.232:7555')
145
+ # 接收帧数据并显示
146
+ @sio.on('frame_data')
147
+ def handle_frame_data(data):
148
+ image_bytes = base64.b64decode(data['image'])
149
+ image = Image.open(io.BytesIO(image_bytes))
150
+ st.image(image, caption='Real-time Detection', channels='BGR')
151
+
152
+ # 等待连接断开
153
+ st.text('Waiting for real-time updates...')
154
+ sio.wait()
test_api/helper.py ADDED
@@ -0,0 +1,80 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from ultralytics import YOLO
2
+ import streamlit as st
3
+ import cv2
4
+ #import pafy
5
+ import requests
6
+ import test_api.settings as settings
7
+ import numpy as np
8
+ from test_api.algorithm import *
9
+
10
+ def load_model(model_path):
11
+
12
+ model = YOLO(model_path)
13
+ return model
14
+
15
+
16
+ def display_tracker_options():
17
+ col1, col2 = st.columns(2)
18
+ with col1:
19
+ display_tracker = st.radio("显示追踪", ('Yes', 'No'))
20
+ is_display_tracker = True if display_tracker == 'Yes' else False
21
+ if is_display_tracker:
22
+ with col2:
23
+ tracker_type = st.radio("追踪器", ("bytetrack.yaml", "botsort.yaml"))
24
+ return is_display_tracker, tracker_type
25
+ return is_display_tracker, None
26
+
27
+
28
+
29
+ def _display_detected_frames(conf, model, st_frame, image, is_display_tracking=None, tracker=None):
30
+
31
+
32
+ # Resize the image to a standard size
33
+ image = cv2.resize(image, (720, int(720*(9/16))))
34
+
35
+ # Display object tracking, if specified
36
+ if is_display_tracking:
37
+ res = model.track(image, conf=conf, persist=True, tracker=tracker)
38
+ else:
39
+ # Predict the objects in the image using the YOLOv8 model
40
+ res = model.predict(image, conf=conf)
41
+
42
+ # # Plot the detected objects on the video frame
43
+ res_plotted = res[0].plot()
44
+ try:
45
+ st_frame.image(res_plotted,
46
+ caption='实时检测',
47
+ channels="BGR",
48
+ use_column_width=True
49
+ )
50
+ except requests.exceptions.RequestException as e:
51
+ st.write("Unable to get image, using placeholder")
52
+ st.image("placeholder.png")
53
+
54
+
55
+ def play_rtsp_stream(conf, model):
56
+ source_rtsp = st.sidebar.text_input("rtsp stream url")
57
+ is_display_tracker, tracker = display_tracker_options()
58
+ if st.sidebar.button('检测目标'):
59
+ try:
60
+ vid_cap = cv2.VideoCapture(source_rtsp)
61
+ st_frame = st.empty()
62
+ while (vid_cap.isOpened()):
63
+ success, image = vid_cap.read()
64
+ if success:
65
+ _display_detected_frames(conf,
66
+ model,
67
+ st_frame,
68
+ image,
69
+ is_display_tracker,
70
+ tracker
71
+ )
72
+ else:
73
+ vid_cap.release()
74
+ break
75
+ except Exception as e:
76
+ st.sidebar.error("Error loading RTSP stream: " + str(e))
77
+
78
+
79
+
80
+
test_api/yoloflaskapi.py ADDED
@@ -0,0 +1,96 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import sys
3
+ import base64
4
+ import json
5
+ import cv2
6
+ import numpy as np
7
+ import torch
8
+ from pathlib import Path
9
+ from ultralytics import YOLO
10
+ import json
11
+ from flask import Flask,request_finished
12
+ from flask import Flask, request, jsonify
13
+ import threading
14
+ import time
15
+ from flask import Flask
16
+ from flask_cors import CORS
17
+ from queue import Queue
18
+ from flask_socketio import SocketIO, emit
19
+ import pickle
20
+
21
+ app = Flask(__name__)
22
+ CORS(app)
23
+ socketio = SocketIO(app)
24
+
25
+ algorithm_running = False
26
+
27
+ data = {
28
+ "code": 0,
29
+ "msg": "unknown error",
30
+ }
31
+
32
+ @socketio.on('connect')
33
+ def handle_connect():
34
+ print('Client connected')
35
+
36
+ def algorithm_thread(source_rtsp):
37
+ global algorithm_running
38
+ global data
39
+
40
+
41
+
42
+ vid_cap = cv2.VideoCapture(source_rtsp)
43
+
44
+ while algorithm_running:
45
+ success, image = vid_cap.read()
46
+
47
+ if success:
48
+
49
+ res = model(image)
50
+ res_plotted = res[0].plot()
51
+ # 将帧转换为 base64 编码的字符串
52
+ _, buffer = cv2.imencode('.jpg', res_plotted)
53
+ frame_bytes = buffer.tobytes()
54
+ frame_base64 = base64.b64encode(frame_bytes).decode('utf-8')
55
+
56
+ # 发送帧数据到客户端
57
+ socketio.emit('frame_data', {'image': frame_base64})
58
+
59
+ vid_cap.release()
60
+
61
+ @app.route("/analyzerControlAdd",methods=['POST'])
62
+ def start_algorithm():
63
+ global algorithm_running
64
+ try:
65
+ params = request.get_json()
66
+ except:
67
+ params = request.form
68
+ source_rtsp = params.get("video_url")
69
+ print(f"video_url",source_rtsp)
70
+ if not algorithm_running :
71
+ algorithm_running = True
72
+ print("start algorithm")
73
+ threading.Thread(target=algorithm_thread(source_rtsp)).start()
74
+
75
+
76
+
77
+
78
+
79
+ @app.route("/analyzerControlCancel", methods=['POST'])
80
+ def stop_algorithm():
81
+ global algorithm_running
82
+ global data
83
+ algorithm_running = False
84
+
85
+
86
+ return json.dumps({"code": 1000,"msg": "算法停止!"}, ensure_ascii=False)
87
+
88
+ if __name__ == "__main__":
89
+
90
+
91
+ model = YOLO('../weights/yolov8n.pt')
92
+
93
+ app.run(host="0.0.0.0",port=7555)
94
+
95
+
96
+