Kedreamix commited on
Commit
4350164
1 Parent(s): 5ca7884

上传主要代码

Browse files
Files changed (8) hide show
  1. Pipfile +19 -0
  2. app.py +400 -0
  3. get_yaml.py +14 -0
  4. instructions.md +9 -0
  5. packages.txt +2 -0
  6. predict.py +194 -0
  7. requirements.txt +15 -0
  8. yolo.py +422 -0
Pipfile ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ [[source]]
2
+ name = "pypi"
3
+ url = "https://pypi.org/simple"
4
+ verify_ssl = true
5
+
6
+ [dev-packages]
7
+
8
+ [packages]
9
+ streamlit = ">0.49.0"
10
+ opencv-python = "*"
11
+ numpy = "*"
12
+ torchvision = "0.9.1"
13
+ torch = "1.8.1"
14
+ Pillow = "8.2.0"
15
+ pyyaml = "6.0"
16
+ matplotlib = "*"
17
+ opencv-python-headless = "4.5.2.52"
18
+ av = "*"
19
+ streamlit-webrtc = "0.36.1"
app.py ADDED
@@ -0,0 +1,400 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Create an Object Detection Web App using PyTorch and Streamlit."""
2
+ # import libraries
3
+ from PIL import Image
4
+ from torchvision import models, transforms
5
+ import torch
6
+ import streamlit as st
7
+ from yolo import YOLO
8
+ import os
9
+ import urllib
10
+ import numpy as np
11
+ from streamlit_webrtc import webrtc_streamer, WebRtcMode, RTCConfiguration
12
+ import av
13
+ # 设置网页的icon
14
+ st.set_page_config(page_title='Gesture Detector', page_icon='✌',
15
+ layout='centered', initial_sidebar_state='expanded')
16
+
17
+ RTC_CONFIGURATION = RTCConfiguration(
18
+ {
19
+ "RTCIceServer": [{
20
+ "urls": ["stun:stun.l.google.com:19302"],
21
+ "username": "pikachu",
22
+ "credential": "1234",
23
+ }]
24
+ }
25
+ )
26
+ def main():
27
+ # Render the readme as markdown using st.markdown.
28
+ readme_text = st.markdown(open("instructions.md",encoding='utf-8').read())
29
+
30
+
31
+ # Once we have the dependencies, add a selector for the app mode on the sidebar.
32
+ st.sidebar.title("What to do")
33
+ app_mode = st.sidebar.selectbox("Choose the app mode",
34
+ ["Show instructions", "Run the app", "Show the source code"])
35
+ if app_mode == "Show instructions":
36
+ st.sidebar.success('To continue select "Run the app".')
37
+ elif app_mode == "Show the source code":
38
+ readme_text.empty()
39
+ st.code(open("app.py",encoding='utf-8').read())
40
+ elif app_mode == "Run the app":
41
+ # Download external dependencies.
42
+ for filename in EXTERNAL_DEPENDENCIES.keys():
43
+ download_file(filename)
44
+
45
+ readme_text.empty()
46
+ run_the_app()
47
+
48
+ # External files to download.
49
+ EXTERNAL_DEPENDENCIES = {
50
+ "yolov4_tiny.pth": {
51
+ "url": "https://github.com/Dreaming-future/my_weights/releases/download/v1.3/yolov4_tiny.pth",
52
+ "size": 23631189
53
+ },
54
+ "yolov4_SE.pth": {
55
+ "url": "https://github.com/Dreaming-future/my_weights/releases/download/v1.3/yolov4_SE.pth",
56
+ "size": 23806027
57
+ },
58
+ "yolov4_CBAM.pth":{
59
+ "url": "https://github.com/Dreaming-future/my_weights/releases/download/v1.3/yolov4_CBAM.pth",
60
+ "size": 23981478
61
+ },
62
+ "yolov4_ECA.pth":{
63
+ "url": "https://github.com/Dreaming-future/my_weights/releases/download/v1.3/yolov4_ECA.pth",
64
+ "size": 23632688
65
+ },
66
+ "yolov4_weights_ep150_608.pth":{
67
+ "url": "https://github.com/Dreaming-future/my_weights/releases/download/v1.3/yolov4_weights_ep150_608.pth",
68
+ "size": 256423031
69
+ },
70
+ "yolov4_weights_ep150_416.pth":{
71
+ "url": "https://github.com/Dreaming-future/my_weights/releases/download/v1.3/yolov4_weights_ep150_416.pth",
72
+ "size": 256423031
73
+ },
74
+ }
75
+
76
+
77
+ # This file downloader demonstrates Streamlit animation.
78
+ def download_file(file_path):
79
+ # Don't download the file twice. (If possible, verify the download using the file length.)
80
+ if os.path.exists(file_path):
81
+ if "size" not in EXTERNAL_DEPENDENCIES[file_path]:
82
+ return
83
+ elif os.path.getsize(file_path) == EXTERNAL_DEPENDENCIES[file_path]["size"]:
84
+ return
85
+ # print(os.path.getsize(file_path))
86
+ # These are handles to two visual elements to animate.
87
+ weights_warning, progress_bar = None, None
88
+ try:
89
+ weights_warning = st.warning("Downloading %s..." % file_path)
90
+ progress_bar = st.progress(0)
91
+ with open(file_path, "wb") as output_file:
92
+ with urllib.request.urlopen(EXTERNAL_DEPENDENCIES[file_path]["url"]) as response:
93
+ length = int(response.info()["Content-Length"])
94
+ counter = 0.0
95
+ MEGABYTES = 2.0 ** 20.0
96
+ while True:
97
+ data = response.read(8192)
98
+ if not data:
99
+ break
100
+ counter += len(data)
101
+ output_file.write(data)
102
+
103
+ # We perform animation by overwriting the elements.
104
+ weights_warning.warning("Downloading %s... (%6.2f/%6.2f MB)" %
105
+ (file_path, counter / MEGABYTES, length / MEGABYTES))
106
+ progress_bar.progress(min(counter / length, 1.0))
107
+ except Exception as e:
108
+ print(e)
109
+ # Finally, we remove these visual elements by calling .empty().
110
+ finally:
111
+ if weights_warning is not None:
112
+ weights_warning.empty()
113
+ if progress_bar is not None:
114
+ progress_bar.empty()
115
+
116
+ # This is the main app app itself, which appears when the user selects "Run the app".
117
+ def run_the_app():
118
+ class Config():
119
+ def __init__(self, weights = 'yolov4_tiny.pth', tiny = True, phi = 0, shape = 416,nms_iou = 0.3, confidence = 0.5):
120
+ self.weights = weights
121
+ self.tiny = tiny
122
+ self.phi = phi
123
+ self.cuda = False
124
+ self.shape = shape
125
+ self.confidence = confidence
126
+ self.nms_iou = nms_iou
127
+ # set title of app
128
+ st.markdown('<h1 align="center">✌ Gesture Detection</h1>',
129
+ unsafe_allow_html=True)
130
+ st.sidebar.markdown("# Gesture Detection on?")
131
+ activities = ["Example","Image", "Camera", "FPS", "Heatmap","Real Time", "Video"]
132
+ choice = st.sidebar.selectbox("Choose among the given options:", activities)
133
+ phi = st.sidebar.selectbox("yolov4-tiny 使用的自注意力模式:",('0tiny','1SE','2CABM','3ECA'))
134
+ print("")
135
+
136
+ tiny = st.sidebar.checkbox('是否使用 yolov4 tiny 模型')
137
+ if not tiny:
138
+ shape = st.sidebar.selectbox("Choose shape to Input:", [416,608])
139
+ conf,nms = object_detector_ui()
140
+ @st.cache
141
+ def get_yolo(tiny,phi,conf,nms,shape=416):
142
+ weights = 'yolov4_tiny.pth'
143
+ if tiny:
144
+ if phi == '0tiny':
145
+ weights = 'yolov4_tiny.pth'
146
+ elif phi == '1SE':
147
+ weights = 'yolov4_SE.pth'
148
+ elif phi == '2CABM':
149
+ weights = 'yolov4_CBAM.pth'
150
+ elif phi == '3ECA':
151
+ weights = 'yolov4_ECA.pth'
152
+ else:
153
+ if shape == 608:
154
+ weights = 'yolov4_weights_ep150_608.pth'
155
+ elif shape == 416:
156
+ weights = 'yolov4_weights_ep150_416.pth'
157
+ opt = Config(weights = weights, tiny = tiny , phi = int(phi[0]), shape = shape,nms_iou = nms, confidence = conf)
158
+ yolo = YOLO(opt)
159
+ return yolo
160
+
161
+ if tiny:
162
+ yolo = get_yolo(tiny, phi, conf, nms)
163
+ st.write("YOLOV4 tiny 模型加载完毕")
164
+ else:
165
+ yolo = get_yolo(tiny, phi, conf, nms, shape)
166
+ st.write("YOLOV4 模型加载完毕")
167
+
168
+ if choice == 'Image':
169
+ detect_image(yolo)
170
+ elif choice =='Camera':
171
+ detect_camera(yolo)
172
+ elif choice == 'FPS':
173
+ detect_fps(yolo)
174
+ elif choice == "Heatmap":
175
+ detect_heatmap(yolo)
176
+ elif choice == "Example":
177
+ detect_example(yolo)
178
+ elif choice == "Real Time":
179
+ detect_realtime(yolo)
180
+ elif choice == "Video":
181
+ detect_video(yolo)
182
+
183
+
184
+
185
+ # This sidebar UI lets the user select parameters for the YOLO object detector.
186
+ def object_detector_ui():
187
+ st.sidebar.markdown("# Model")
188
+ confidence_threshold = st.sidebar.slider("Confidence threshold", 0.0, 1.0, 0.5, 0.01)
189
+ overlap_threshold = st.sidebar.slider("Overlap threshold", 0.0, 1.0, 0.3, 0.01)
190
+ return confidence_threshold, overlap_threshold
191
+
192
+ def predict(image,yolo):
193
+ """Return predictions.
194
+
195
+ Parameters
196
+ ----------
197
+ :param image: uploaded image
198
+ :type image: jpg
199
+ :rtype: list
200
+ :return: none
201
+ """
202
+ crop = False
203
+ count = False
204
+ try:
205
+ # image = Image.open(image)
206
+ r_image = yolo.detect_image(image, crop = crop, count=count)
207
+ transform = transforms.Compose([transforms.ToTensor()])
208
+ result = transform(r_image)
209
+ st.image(result.permute(1,2,0).numpy(), caption = 'Processed Image.', use_column_width = True)
210
+ except Exception as e:
211
+ print(e)
212
+
213
+ def fps(image,yolo):
214
+ test_interval = 50
215
+ tact_time = yolo.get_FPS(image, test_interval)
216
+ st.write(str(tact_time) + ' seconds, ', str(1/tact_time),'FPS, @batch_size 1')
217
+ return tact_time
218
+ # print(str(tact_time) + ' seconds, ' + str(1/tact_time) + 'FPS, @batch_size 1')
219
+
220
+
221
+ def detect_image(yolo):
222
+ # enable users to upload images for the model to make predictions
223
+ file_up = st.file_uploader("Upload an image", type = ["jpg","png","jpeg"])
224
+ classes = ["up","down","left","right","front","back","clockwise","anticlockwise"]
225
+ class_to_idx = {cls: idx for (idx, cls) in enumerate(classes)}
226
+ st.sidebar.markdown("See the model preformance and play with it")
227
+ if file_up is not None:
228
+ with st.spinner(text='Preparing Image'):
229
+ # display image that user uploaded
230
+ image = Image.open(file_up)
231
+ st.image(image, caption = 'Uploaded Image.', use_column_width = True)
232
+ st.balloons()
233
+ detect = st.button("开始检测Image")
234
+ if detect:
235
+ st.write("")
236
+ st.write("Just a second ...")
237
+ predict(image,yolo)
238
+ st.balloons()
239
+
240
+
241
+
242
+ def detect_camera(yolo):
243
+ picture = st.camera_input("Take a picture")
244
+ if picture:
245
+ filters_to_funcs = {
246
+ "No filter": predict,
247
+ "Heatmap": heatmap,
248
+ "FPS": fps,
249
+ }
250
+ filters = st.selectbox("...and now, apply a filter!", filters_to_funcs.keys())
251
+ image = Image.open(picture)
252
+ with st.spinner(text='Preparing Image'):
253
+ filters_to_funcs[filters](image,yolo)
254
+ st.balloons()
255
+
256
+ def detect_fps(yolo):
257
+ file_up = st.file_uploader("Upload an image", type = ["jpg","png","jpeg"])
258
+ classes = ["up","down","left","right","front","back","clockwise","anticlockwise"]
259
+ class_to_idx = {cls: idx for (idx, cls) in enumerate(classes)}
260
+ st.sidebar.markdown("See the model preformance and play with it")
261
+ if file_up is not None:
262
+ # display image that user uploaded
263
+ image = Image.open(file_up)
264
+ st.image(image, caption = 'Uploaded Image.', use_column_width = True)
265
+ st.balloons()
266
+ detect = st.button("开始检测 FPS")
267
+ if detect:
268
+ with st.spinner(text='Preparing Image'):
269
+ st.write("")
270
+ st.write("Just a second ...")
271
+ tact_time = fps(image,yolo)
272
+ # st.write(str(tact_time) + ' seconds, ', str(1/tact_time),'FPS, @batch_size 1')
273
+ st.balloons()
274
+
275
+ def heatmap(image,yolo):
276
+ heatmap_save_path = "heatmap_vision.png"
277
+ yolo.detect_heatmap(image, heatmap_save_path)
278
+ img = Image.open(heatmap_save_path)
279
+ transform = transforms.Compose([transforms.ToTensor()])
280
+ result = transform(img)
281
+ st.image(result.permute(1,2,0).numpy(), caption = 'Processed Image.', use_column_width = True)
282
+
283
+ def detect_heatmap(yolo):
284
+ file_up = st.file_uploader("Upload an image", type = ["jpg","png","jpeg"])
285
+ classes = ["up","down","left","right","front","back","clockwise","anticlockwise"]
286
+ class_to_idx = {cls: idx for (idx, cls) in enumerate(classes)}
287
+ st.sidebar.markdown("See the model preformance and play with it")
288
+ if file_up is not None:
289
+ # display image that user uploaded
290
+ image = Image.open(file_up)
291
+ st.image(image, caption = 'Uploaded Image.', use_column_width = True)
292
+ st.balloons()
293
+ detect = st.button("开始检测 heatmap")
294
+ if detect:
295
+ with st.spinner(text='Preparing Heatmap'):
296
+ st.write("")
297
+ st.write("Just a second ...")
298
+ heatmap(image,yolo)
299
+ st.balloons()
300
+
301
+ def detect_example(yolo):
302
+ st.sidebar.title("Choose an Image as a example")
303
+ images = os.listdir('./img')
304
+ images.sort()
305
+ image = st.sidebar.selectbox("Image Name", images)
306
+ st.sidebar.markdown("See the model preformance and play with it")
307
+ image = Image.open(os.path.join('img',image))
308
+ st.image(image, caption = 'Choose Image.', use_column_width = True)
309
+ st.balloons()
310
+ detect = st.button("开始检测Image")
311
+ if detect:
312
+ st.write("")
313
+ st.write("Just a second ...")
314
+ predict(image,yolo)
315
+ st.balloons()
316
+
317
+ def detect_realtime(yolo):
318
+
319
+ class VideoProcessor:
320
+ def recv(self, frame):
321
+ img = frame.to_ndarray(format="bgr24")
322
+ img = Image.fromarray(img)
323
+ crop = False
324
+ count = False
325
+ r_image = yolo.detect_image(img, crop = crop, count=count)
326
+ transform = transforms.Compose([transforms.ToTensor()])
327
+ result = transform(r_image)
328
+ result = result.permute(1,2,0).numpy()
329
+ result = (result * 255).astype(np.uint8)
330
+ return av.VideoFrame.from_ndarray(result, format="bgr24")
331
+
332
+ webrtc_ctx = webrtc_streamer(
333
+ key="example",
334
+ mode=WebRtcMode.SENDRECV,
335
+ rtc_configuration=RTC_CONFIGURATION,
336
+ media_stream_constraints={"video": True, "audio": False},
337
+ async_processing=False,
338
+ video_processor_factory=VideoProcessor
339
+ )
340
+
341
+ import cv2
342
+ import time
343
+ def detect_video(yolo):
344
+ file_up = st.file_uploader("Upload a video", type = ["mp4"])
345
+ print(file_up)
346
+ classes = ["up","down","left","right","front","back","clockwise","anticlockwise"]
347
+
348
+ if file_up is not None:
349
+ video_path = 'video.mp4'
350
+ st.video(file_up)
351
+ with open(video_path, 'wb') as f:
352
+ f.write(file_up.read())
353
+ detect = st.button("开始检测 Video")
354
+
355
+ if detect:
356
+ video_save_path = 'video2.mp4'
357
+ # display image that user uploaded
358
+ capture = cv2.VideoCapture(video_path)
359
+
360
+ video_fps = st.slider("Video FPS", 5, 30, int(capture.get(cv2.CAP_PROP_FPS)), 1)
361
+ fourcc = cv2.VideoWriter_fourcc(*'XVID')
362
+ size = (int(capture.get(cv2.CAP_PROP_FRAME_WIDTH)), int(capture.get(cv2.CAP_PROP_FRAME_HEIGHT)))
363
+ out = cv2.VideoWriter(video_save_path, fourcc, video_fps, size)
364
+
365
+
366
+
367
+ while(True):
368
+ # 读取某一帧
369
+ ref, frame = capture.read()
370
+ if not ref:
371
+ break
372
+ # 转变成Image
373
+ # frame = Image.fromarray(np.uint8(frame))
374
+ # 格式转变,BGRtoRGB
375
+ frame = cv2.cvtColor(frame,cv2.COLOR_BGR2RGB)
376
+ # 转变成Image
377
+ frame = Image.fromarray(np.uint8(frame))
378
+ # 进行检测
379
+ frame = np.array(yolo.detect_image(frame))
380
+ # RGBtoBGR满足opencv显示格式
381
+ frame = cv2.cvtColor(frame,cv2.COLOR_RGB2BGR)
382
+
383
+ # print("fps= %.2f"%(fps))
384
+ # frame = cv2.putText(frame, "fps= %.2f"%(fps), (0, 40), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 0), 2)
385
+ out.write(frame)
386
+
387
+ out.release()
388
+ capture.release()
389
+ print("Save processed video to the path :" + video_save_path)
390
+
391
+ with open(video_save_path, "rb") as file:
392
+ btn = st.download_button(
393
+ label="Download Video",
394
+ data=file,
395
+ file_name="video.mp4",
396
+ )
397
+ st.balloons()
398
+
399
+ if __name__ == "__main__":
400
+ main()
get_yaml.py ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import sys
3
+ import yaml
4
+
5
+ def get_config():
6
+ yaml_path = 'model_data/gesture.yaml'
7
+ f = open(yaml_path,'r',encoding='utf-8')
8
+ config = yaml.load(f,Loader =yaml.FullLoader)
9
+ f.close()
10
+ return config
11
+
12
+ if __name__ == "__main__":
13
+ config = get_config()
14
+ print(config)
instructions.md ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ # ✌ Gesture Detection
2
+
3
+
4
+ 这是一个基于无人机视觉图像手势识别控制系统,选择了YOLOv4模型进行训练
5
+
6
+ **YOLOv4 = CSPDarknet53(主干) + SPP** **附加模块(颈** **) +** **PANet** **路径聚合(颈** **) + YOLOv3(头部)**
7
+
8
+ ![img](https://pdf.cdn.readpaper.com/parsed/fetch_target/699143cdb334ecfc63caf8192472490c_0_Figure_1.png)
9
+
packages.txt ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ freeglut3-dev
2
+ libgtk2.0-dev
predict.py ADDED
@@ -0,0 +1,194 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #-----------------------------------------------------------------------#
2
+ # predict.py将单张图片预测、摄像头检测、FPS测试和目录遍历检测等功能
3
+ # 整合到了一个py文件中,通过指定mode进行模式的修改。
4
+ #-----------------------------------------------------------------------#
5
+ import time
6
+ import yaml
7
+ import cv2
8
+ import numpy as np
9
+ from PIL import Image
10
+ from get_yaml import get_config
11
+ from yolo import YOLO
12
+ import argparse
13
+ if __name__ == "__main__":
14
+ parser = argparse.ArgumentParser()
15
+ parser.add_argument('--weights',type=str,default='model_data/yolotiny_SE_ep100.pth',help='initial weights path')
16
+ parser.add_argument('--tiny',action='store_true',help='使用yolotiny模型')
17
+ parser.add_argument('--phi',type=int,default=1,help='yolov4tiny注意力机制类型')
18
+ parser.add_argument('--mode',type=str,choices=['dir_predict', 'video', 'fps','predict','heatmap','export_onnx'],default="dir_predict",help='预测的模式')
19
+ parser.add_argument('--cuda',action='store_true',help='表示是否使用GPU')
20
+ parser.add_argument('--shape',type=int,default=416,help='输入图像的shape')
21
+ parser.add_argument('--video',type=str,default='',help='需要检测的视频文件')
22
+ parser.add_argument('--save-video',type=str,default='',help='保存视频的位置')
23
+ parser.add_argument('--confidence',type=float,default=0.5,help='只有得分大于置信度的预测框会被保留下来')
24
+ parser.add_argument('--nms_iou',type=float,default=0.3,help='非极大抑制所用到的nms_iou大小')
25
+ opt = parser.parse_args()
26
+ print(opt)
27
+
28
+ # 配置文件
29
+ config = get_config()
30
+ yolo = YOLO(opt)
31
+
32
+ #----------------------------------------------------------------------------------------------------------#
33
+ # mode用于指定测试的模式:
34
+ # 'predict' 表示单张图片预测,如果想对预测过程进行修改,如保存图片,截取对象等,可以先看下方详细的注释
35
+ # 'video' 表示视频检测,可调用摄像头或者视频进行检测,详情查看下方注释。
36
+ # 'fps' 表示测试fps,使用的图片是img里面的street.jpg,详情查看下方注释。
37
+ # 'dir_predict' 表示遍历文件夹进行检测并保存。默认遍历img文件夹,保存img_out文件夹,详情查看下方注释。
38
+ # 'heatmap' 表示进行预测结果的热力图可视化,详情查看下方注释。
39
+ # 'export_onnx' 表示将模型导出为onnx,需要pytorch1.7.1以上。
40
+ #----------------------------------------------------------------------------------------------------------#
41
+ mode = opt.mode
42
+ #-------------------------------------------------------------------------#
43
+ # crop 指定了是否在单张图片预测后对目标进行截取
44
+ # count 指定了是否进行目标的计数
45
+ # crop、count仅在mode='predict'时有效
46
+ #-------------------------------------------------------------------------#
47
+ crop = False
48
+ count = False
49
+ #----------------------------------------------------------------------------------------------------------#
50
+ # video_path 用于指定视频的路径,当video_path=0时表示检测摄像头
51
+ # 想要检测视频,则设置如video_path = "xxx.mp4"即可,代表读取出根目录下的xxx.mp4文件。
52
+ # video_save_path 表示视频保存的路径,当video_save_path=""时表示不保存
53
+ # 想要保存视频,则设置如video_save_path = "yyy.mp4"即可,代表保存为根目录下的yyy.mp4文件。
54
+ # video_fps 用于保存的视频的fps
55
+ #
56
+ # video_path、video_save_path和video_fps仅在mode='video'时有效
57
+ # 保存视频时需要ctrl+c退出或者运行到最后一帧才会完成完整的保存步骤。
58
+ #----------------------------------------------------------------------------------------------------------#
59
+ video_path = 0 if opt.video == '' else opt.video
60
+ video_save_path = opt.save_video
61
+ video_fps = 25.0
62
+ #----------------------------------------------------------------------------------------------------------#
63
+ # test_interval 用于指定测量fps的时候,图片检测的次数。理论上test_interval越大,fps越准确。
64
+ # fps_image_path 用于指定测试的fps图片
65
+ #
66
+ # test_interval和fps_image_path仅在mode='fps'有效
67
+ #----------------------------------------------------------------------------------------------------------#
68
+ test_interval = 100
69
+ fps_image_path = "img/up.jpg"
70
+ #-------------------------------------------------------------------------#
71
+ # dir_origin_path 指定了用于检测的图片的文件夹路径
72
+ # dir_save_path 指定了检测完图片的保存路径
73
+ #
74
+ # dir_origin_path和dir_save_path���在mode='dir_predict'时有效
75
+ #-------------------------------------------------------------------------#
76
+ dir_origin_path = "img/"
77
+ dir_save_path = "img_out/"
78
+ #-------------------------------------------------------------------------#
79
+ # heatmap_save_path 热力图的保存路径,默认保存在model_data下
80
+ #
81
+ # heatmap_save_path仅在mode='heatmap'有效
82
+ #-------------------------------------------------------------------------#
83
+ heatmap_save_path = "model_data/heatmap_vision.png"
84
+ #-------------------------------------------------------------------------#
85
+ # simplify 使用Simplify onnx
86
+ # onnx_save_path 指定了onnx的保存路径
87
+ #-------------------------------------------------------------------------#
88
+ simplify = True
89
+ onnx_save_path = "model_data/models.onnx"
90
+
91
+ if mode == "predict":
92
+ '''
93
+ 1、如果想要进行检测完的图片的保存,利用r_image.save("img.jpg")即可保存,直接在predict.py里进行修改即可。
94
+ 2、如果想要获得预测框的坐标,可以进入yolo.detect_image函数,在绘图部分读取top,left,bottom,right这四个值。
95
+ 3、如果想要利用预测框截取下目标,可以进入yolo.detect_image函数,在绘图部分利用获取到的top,left,bottom,right这四个值
96
+ 在原图上利用矩阵的方式进行截取。
97
+ 4、如果想要在预测图上写额外的字,比如检测到的特定目标的数量,可以进入yolo.detect_image函数,在绘图部分对predicted_class进行判断,
98
+ 比如判断if predicted_class == 'car': 即可判断当前目标是否为车,然后记录数量即可。利用draw.text即可写字。
99
+ '''
100
+ while True:
101
+ img = input('Input image filename:')
102
+ try:
103
+ image = Image.open(img)
104
+ except:
105
+ print('Open Error! Try again!')
106
+ continue
107
+ else:
108
+ r_image = yolo.detect_image(image, crop = crop, count=count)
109
+ r_image.show()
110
+ r_image.save(dir_save_path + 'img_result.jpg')
111
+
112
+ elif mode == "video":
113
+ capture = cv2.VideoCapture(video_path)
114
+ if video_save_path != '':
115
+ fourcc = cv2.VideoWriter_fourcc(*'XVID')
116
+ size = (int(capture.get(cv2.CAP_PROP_FRAME_WIDTH)), int(capture.get(cv2.CAP_PROP_FRAME_HEIGHT)))
117
+ out = cv2.VideoWriter(video_save_path, fourcc, video_fps, size)
118
+
119
+ ref, frame = capture.read()
120
+ if not ref:
121
+ raise ValueError("未能正确读取摄像头(视频),请注意是否正确安装摄像头(是否正确填写视频路径)。")
122
+
123
+ fps = 0.0
124
+ while(True):
125
+ t1 = time.time()
126
+ # 读取某一帧
127
+ ref, frame = capture.read()
128
+ if not ref:
129
+ break
130
+ # 格式转变,BGRtoRGB
131
+ frame = cv2.cvtColor(frame,cv2.COLOR_BGR2RGB)
132
+ # 转变成Image
133
+ frame = Image.fromarray(np.uint8(frame))
134
+ # 进行检测
135
+ frame = np.array(yolo.detect_image(frame))
136
+ # RGBtoBGR满足opencv显示格式
137
+ frame = cv2.cvtColor(frame,cv2.COLOR_RGB2BGR)
138
+
139
+ fps = ( fps + (1./(time.time()-t1)) ) / 2
140
+ print("fps= %.2f"%(fps))
141
+ frame = cv2.putText(frame, "fps= %.2f"%(fps), (0, 40), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 0), 2)
142
+
143
+ cv2.imshow("video",frame)
144
+ c= cv2.waitKey(1) & 0xff
145
+ if video_save_path != '':
146
+ out.write(frame)
147
+
148
+ if c==27:
149
+ capture.release()
150
+ break
151
+
152
+ print("Video Detection Done!")
153
+ capture.release()
154
+ if video_save_path != '':
155
+ print("Save processed video to the path :" + video_save_path)
156
+ out.release()
157
+ cv2.destroyAllWindows()
158
+
159
+ elif mode == "fps":
160
+ img = Image.open(fps_image_path)
161
+ tact_time = yolo.get_FPS(img, test_interval)
162
+ print(str(tact_time) + ' seconds, ' + str(1/tact_time) + 'FPS, @batch_size 1')
163
+
164
+ elif mode == "dir_predict":
165
+ import os
166
+
167
+ from tqdm import tqdm
168
+
169
+ img_names = os.listdir(dir_origin_path)
170
+ for img_name in tqdm(img_names):
171
+ if img_name.lower().endswith(('.bmp', '.dib', '.png', '.jpg', '.jpeg', '.pbm', '.pgm', '.ppm', '.tif', '.tiff')):
172
+ image_path = os.path.join(dir_origin_path, img_name)
173
+ image = Image.open(image_path)
174
+ r_image = yolo.detect_image(image)
175
+ if not os.path.exists(dir_save_path):
176
+ os.makedirs(dir_save_path)
177
+ r_image.save(os.path.join(dir_save_path, img_name.replace(".jpg", ".png")), quality=95, subsampling=0)
178
+
179
+ elif mode == "heatmap":
180
+ while True:
181
+ img = input('Input image filename:')
182
+ try:
183
+ image = Image.open(img)
184
+ except:
185
+ print('Open Error! Try again!')
186
+ continue
187
+ else:
188
+ yolo.detect_heatmap(image, heatmap_save_path)
189
+
190
+ elif mode == "export_onnx":
191
+ yolo.convert_to_onnx(simplify, onnx_save_path)
192
+
193
+ else:
194
+ raise AssertionError("Please specify the correct mode: 'predict', 'video', 'fps', 'heatmap', 'export_onnx', 'dir_predict'.")
requirements.txt ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ scipy
2
+ numpy
3
+ matplotlib
4
+ opencv_python
5
+ torch==1.8.1
6
+ torchvision==0.9.1
7
+ tqdm==4.60.0
8
+ Pillow==8.2.0
9
+ h5py==2.10.0
10
+ tensorboard
11
+ pyyaml==6.0
12
+ torchinfo
13
+ labelimg==1.8.6
14
+ streamlit==1.8.1
15
+ opencv-python-headless==4.5.2.52
yolo.py ADDED
@@ -0,0 +1,422 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import colorsys
2
+ import os
3
+ import time
4
+
5
+ import numpy as np
6
+ import torch
7
+ import torch.nn as nn
8
+ from PIL import ImageDraw, ImageFont
9
+
10
+ from nets.yolo import YoloBody
11
+ from nets.yolo_tiny import YoloBodytiny
12
+ from utils.utils import (cvtColor, get_anchors, get_classes, preprocess_input,
13
+ resize_image)
14
+ from utils.utils_bbox import DecodeBox
15
+ from get_yaml import get_config
16
+ import argparse
17
+ '''
18
+ 训练自己的数据集必看注释!
19
+ '''
20
+ class YOLO(object):
21
+ # 配置文件
22
+ config = get_config()
23
+ _defaults = {
24
+ #--------------------------------------------------------------------------#
25
+ # 使用自己训练好的模型进行预测一定要修改model_path和classes_path!
26
+ # model_path指向logs文件夹下的权值文件,classes_path指向model_data下的txt
27
+ #
28
+ # 训练好后logs文件夹下存在多个权值文件,选择验证集损失较低的即可。
29
+ # 验证集损失较低不代表mAP较高,仅代表该权值在验证集上泛化性能较好。
30
+ # 如果出现shape不匹配,同时要注意训练时的model_path和classes_path参数的修改
31
+ #--------------------------------------------------------------------------#
32
+ "class_names" : config['classes'],
33
+ "num_classes" : config['nc'],
34
+ #---------------------------------------------------------------------#
35
+ # anchors_path代表先验框对应的txt文件,一般不修改。
36
+ # anchors_mask用于帮助代码找到对应的先验框,一般不修改。
37
+ #---------------------------------------------------------------------#
38
+ "anchors_path" : 'model_data/yolo_anchors.txt',
39
+ "anchors_mask" : [[6, 7, 8], [3, 4, 5], [0, 1, 2]],
40
+ #---------------------------------------------------------------------#
41
+ # 只有得分大于置信度的预测框会被保留下来
42
+ #---------------------------------------------------------------------#
43
+ "confidence" : 0.5, # 0.5,
44
+ #---------------------------------------------------------------------#
45
+ # 非极大抑制所用到的nms_iou大小
46
+ #---------------------------------------------------------------------#
47
+ "nms_iou" : 0.3, # 0.3,
48
+ #---------------------------------------------------------------------#
49
+ # 该变量用于控制是否使用letterbox_image对输入图像进行不失真的resize,
50
+ # 在多次测试后,发现关闭letterbox_image直接resize的效果更好
51
+ #---------------------------------------------------------------------#
52
+ "letterbox_image" : config['letterbox_image'], # False,
53
+ }
54
+
55
+
56
+
57
+ @classmethod
58
+ def get_defaults(cls, n):
59
+ if n in cls._defaults:
60
+ return cls._defaults[n]
61
+ else:
62
+ return "Unrecognized attribute name '" + n + "'"
63
+
64
+ #---------------------------------------------------#
65
+ # 初始化YOLO
66
+ #---------------------------------------------------#
67
+ def __init__(self, opt, **kwargs):
68
+ self.__dict__.update(self._defaults)
69
+ for name, value in kwargs.items():
70
+ setattr(self, name, value)
71
+ self.phi = opt.phi
72
+ self.tiny = opt.tiny
73
+ self.cuda = opt.cuda
74
+ self.input_shape = [opt.shape,opt.shape]
75
+ self.model_path = opt.weights
76
+ self.phi = opt.phi
77
+ self.confidence = opt.confidence
78
+ self.nms_iou = opt.nms_iou
79
+ if self.tiny:
80
+ self.anchors_mask = [[3,4,5], [1,2,3]]
81
+ self.anchors_path = 'model_data/yolotiny_anchors.txt'
82
+ #---------------------------------------------------#
83
+ # 获得种类和先验框的数量
84
+ #---------------------------------------------------#
85
+ # self.class_names, self.num_classes = get_classes(self.classes_path)
86
+ self.anchors, self.num_anchors = get_anchors(self.anchors_path)
87
+ self.bbox_util = DecodeBox(self.anchors, self.num_classes, (self.input_shape[0], self.input_shape[1]), self.anchors_mask)
88
+
89
+ #---------------------------------------------------#
90
+ # 画框设置不同的颜色
91
+ #---------------------------------------------------#
92
+ hsv_tuples = [(x / self.num_classes, 1., 1.) for x in range(self.num_classes)]
93
+ self.colors = list(map(lambda x: colorsys.hsv_to_rgb(*x), hsv_tuples))
94
+ self.colors = list(map(lambda x: (int(x[0] * 255), int(x[1] * 255), int(x[2] * 255)), self.colors))
95
+ self.generate()
96
+
97
+ #---------------------------------------------------#
98
+ # 生成模型
99
+ #---------------------------------------------------#
100
+ def generate(self, onnx=False):
101
+ #---------------------------------------------------#
102
+ # 建立yolo模型,载入yolo模型的权重
103
+ #---------------------------------------------------#
104
+
105
+ if not self.tiny:
106
+ self.net = YoloBody(self.anchors_mask, self.num_classes)
107
+ elif self.tiny:
108
+ self.net = YoloBodytiny(self.anchors_mask, self.num_classes, self.phi)
109
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
110
+ self.net.load_state_dict(torch.load(self.model_path, map_location=device))
111
+ self.net = self.net.eval()
112
+
113
+ print('{} model, anchors, and classes loaded.'.format(self.model_path))
114
+ if not onnx:
115
+ if self.cuda:
116
+ self.net = nn.DataParallel(self.net)
117
+ self.net = self.net.cuda()
118
+
119
+ #---------------------------------------------------#
120
+ # 检测图片
121
+ #---------------------------------------------------#
122
+ def detect_image(self, image, crop = False, count = False):
123
+ #---------------------------------------------------#
124
+ # 计算输入图片的高和宽
125
+ #---------------------------------------------------#
126
+ image_shape = np.array(np.shape(image)[0:2])
127
+ #---------------------------------------------------------#
128
+ # 在这里将图像转换成RGB图像,防止灰度图在预测时报错。
129
+ # 代码仅仅支持RGB图像的预测,所有其它类型的图像都会转化成RGB
130
+ #---------------------------------------------------------#
131
+ image = cvtColor(image)
132
+ #---------------------------------------------------------#
133
+ # 给图像增加灰条,实现不失真的resize
134
+ # 也可以直接resize进行识别
135
+ #---------------------------------------------------------#
136
+ image_data = resize_image(image, (self.input_shape[1],self.input_shape[0]), self.letterbox_image)
137
+ #---------------------------------------------------------#
138
+ # 添加上batch_size维度
139
+ #---------------------------------------------------------#
140
+ image_data = np.expand_dims(np.transpose(preprocess_input(np.array(image_data, dtype='float32')), (2, 0, 1)), 0)
141
+
142
+ with torch.no_grad():
143
+ images = torch.from_numpy(image_data)
144
+ if self.cuda:
145
+ images = images.cuda()
146
+ #---------------------------------------------------------#
147
+ # 将图像输入网络当中进行预测!
148
+ #---------------------------------------------------------#
149
+ outputs = self.net(images)
150
+ outputs = self.bbox_util.decode_box(outputs)
151
+ #---------------------------------------------------------#
152
+ # 将预测框进行堆叠,然后进行非极大抑制
153
+ #---------------------------------------------------------#
154
+ results = self.bbox_util.non_max_suppression(torch.cat(outputs, 1), self.num_classes, self.input_shape,
155
+ image_shape, self.letterbox_image, conf_thres = self.confidence, nms_thres = self.nms_iou)
156
+
157
+ if results[0] is None:
158
+ return image
159
+
160
+ top_label = np.array(results[0][:, 6], dtype = 'int32')
161
+ top_conf = results[0][:, 4] * results[0][:, 5]
162
+ top_boxes = results[0][:, :4]
163
+ #---------------------------------------------------------#
164
+ # 设置字体与边框厚度
165
+ #---------------------------------------------------------#
166
+ font = ImageFont.truetype(font='model_data/simhei.ttf', size=np.floor(3e-2 * image.size[1] + 0.5).astype('int32'))
167
+ thickness = int(max((image.size[0] + image.size[1]) // np.mean(self.input_shape), 1))
168
+ #---------------------------------------------------------#
169
+ # 计数
170
+ #---------------------------------------------------------#
171
+ if count:
172
+ print("top_label:", top_label)
173
+ classes_nums = np.zeros([self.num_classes])
174
+ for i in range(self.num_classes):
175
+ num = np.sum(top_label == i)
176
+ if num > 0:
177
+ print(self.class_names[i], " : ", num)
178
+ classes_nums[i] = num
179
+ print("classes_nums:", classes_nums)
180
+ #---------------------------------------------------------#
181
+ # 是否进行目标的裁剪
182
+ #---------------------------------------------------------#
183
+ if crop:
184
+ for i, c in list(enumerate(top_label)):
185
+ top, left, bottom, right = top_boxes[i]
186
+ top = max(0, np.floor(top).astype('int32'))
187
+ left = max(0, np.floor(left).astype('int32'))
188
+ bottom = min(image.size[1], np.floor(bottom).astype('int32'))
189
+ right = min(image.size[0], np.floor(right).astype('int32'))
190
+
191
+ dir_save_path = "img_crop"
192
+ if not os.path.exists(dir_save_path):
193
+ os.makedirs(dir_save_path)
194
+ crop_image = image.crop([left, top, right, bottom])
195
+ crop_image.save(os.path.join(dir_save_path, "crop_" + str(i) + ".png"), quality=95, subsampling=0)
196
+ print("save crop_" + str(i) + ".png to " + dir_save_path)
197
+ #---------------------------------------------------------#
198
+ # 图像绘制
199
+ #---------------------------------------------------------#
200
+ for i, c in list(enumerate(top_label)):
201
+ predicted_class = self.class_names[int(c)]
202
+ box = top_boxes[i]
203
+ score = top_conf[i]
204
+
205
+ top, left, bottom, right = box
206
+
207
+ top = max(0, np.floor(top).astype('int32'))
208
+ left = max(0, np.floor(left).astype('int32'))
209
+ bottom = min(image.size[1], np.floor(bottom).astype('int32'))
210
+ right = min(image.size[0], np.floor(right).astype('int32'))
211
+
212
+ label = '{} {:.2f}'.format(predicted_class, score)
213
+ draw = ImageDraw.Draw(image)
214
+ label_size = draw.textsize(label, font)
215
+ label = label.encode('utf-8')
216
+ print(label, top, left, bottom, right)
217
+
218
+ if top - label_size[1] >= 0:
219
+ text_origin = np.array([left, top - label_size[1]])
220
+ else:
221
+ text_origin = np.array([left, top + 1])
222
+
223
+ for i in range(thickness):
224
+ draw.rectangle([left + i, top + i, right - i, bottom - i], outline=self.colors[c])
225
+ draw.rectangle([tuple(text_origin), tuple(text_origin + label_size)], fill=self.colors[c])
226
+ draw.text(text_origin, str(label,'UTF-8'), fill=(0, 0, 0), font=font)
227
+ del draw
228
+
229
+ return image
230
+
231
+ def get_FPS(self, image, test_interval):
232
+ image_shape = np.array(np.shape(image)[0:2])
233
+ #---------------------------------------------------------#
234
+ # 在这里将图像转换成RGB图像,防止灰度图在预测时报错。
235
+ # 代码仅仅支持RGB图像的预测,所有其它类型的图像都会转化成RGB
236
+ #---------------------------------------------------------#
237
+ image = cvtColor(image)
238
+ #---------------------------------------------------------#
239
+ # 给图像增加灰条,实现不失真的resize
240
+ # 也可以直接resize进行识别
241
+ #---------------------------------------------------------#
242
+ image_data = resize_image(image, (self.input_shape[1],self.input_shape[0]), self.letterbox_image)
243
+ #---------------------------------------------------------#
244
+ # 添加上batch_size维度
245
+ #---------------------------------------------------------#
246
+ image_data = np.expand_dims(np.transpose(preprocess_input(np.array(image_data, dtype='float32')), (2, 0, 1)), 0)
247
+
248
+ with torch.no_grad():
249
+ images = torch.from_numpy(image_data)
250
+ if self.cuda:
251
+ images = images.cuda()
252
+ #---------------------------------------------------------#
253
+ # 将图像输入网络当中进行预测!
254
+ #---------------------------------------------------------#
255
+ outputs = self.net(images)
256
+ outputs = self.bbox_util.decode_box(outputs)
257
+ #---------------------------------------------------------#
258
+ # 将预测框进行堆叠,然后进行非极大抑制
259
+ #---------------------------------------------------------#
260
+ results = self.bbox_util.non_max_suppression(torch.cat(outputs, 1), self.num_classes, self.input_shape,
261
+ image_shape, self.letterbox_image, conf_thres=self.confidence, nms_thres=self.nms_iou)
262
+
263
+ t1 = time.time()
264
+ for _ in range(test_interval):
265
+ with torch.no_grad():
266
+ #---------------------------------------------------------#
267
+ # 将图像输入网络当中进行预测!
268
+ #---------------------------------------------------------#
269
+ outputs = self.net(images)
270
+ outputs = self.bbox_util.decode_box(outputs)
271
+ #---------------------------------------------------------#
272
+ # 将预测框进行堆叠,然后进行非极大抑制
273
+ #---------------------------------------------------------#
274
+ results = self.bbox_util.non_max_suppression(torch.cat(outputs, 1), self.num_classes, self.input_shape,
275
+ image_shape, self.letterbox_image, conf_thres=self.confidence, nms_thres=self.nms_iou)
276
+
277
+ t2 = time.time()
278
+ tact_time = (t2 - t1) / test_interval
279
+ return tact_time
280
+
281
+ def detect_heatmap(self, image, heatmap_save_path):
282
+ import cv2
283
+ import matplotlib.pyplot as plt
284
+ def sigmoid(x):
285
+ y = 1.0 / (1.0 + np.exp(-x))
286
+ return y
287
+ #---------------------------------------------------------#
288
+ # 在这里将图像转换成RGB图像,防止灰度图在预测时报错。
289
+ # 代码仅仅支持RGB图像的预测,所有其它类型的图像都会转化成RGB
290
+ #---------------------------------------------------------#
291
+ image = cvtColor(image)
292
+ #---------------------------------------------------------#
293
+ # 给图像增加灰条,实现不失真的resize
294
+ # 也可以直接resize进行识别
295
+ #---------------------------------------------------------#
296
+ image_data = resize_image(image, (self.input_shape[1],self.input_shape[0]), self.letterbox_image)
297
+ #---------------------------------------------------------#
298
+ # 添加上batch_size维度
299
+ #---------------------------------------------------------#
300
+ image_data = np.expand_dims(np.transpose(preprocess_input(np.array(image_data, dtype='float32')), (2, 0, 1)), 0)
301
+
302
+ with torch.no_grad():
303
+ images = torch.from_numpy(image_data)
304
+ if self.cuda:
305
+ images = images.cuda()
306
+ #---------------------------------------------------------#
307
+ # 将图像输入网络当中进行预测!
308
+ #---------------------------------------------------------#
309
+ outputs = self.net(images)
310
+
311
+ plt.imshow(image, alpha=1)
312
+ plt.axis('off')
313
+ mask = np.zeros((image.size[1], image.size[0]))
314
+ for sub_output in outputs:
315
+ sub_output = sub_output.cpu().numpy()
316
+ b, c, h, w = np.shape(sub_output)
317
+ sub_output = np.transpose(np.reshape(sub_output, [b, 3, -1, h, w]), [0, 3, 4, 1, 2])[0]
318
+ score = np.max(sigmoid(sub_output[..., 4]), -1)
319
+ score = cv2.resize(score, (image.size[0], image.size[1]))
320
+ normed_score = (score * 255).astype('uint8')
321
+ mask = np.maximum(mask, normed_score)
322
+
323
+ plt.imshow(mask, alpha=0.5, interpolation='nearest', cmap="jet")
324
+
325
+ plt.axis('off')
326
+ plt.subplots_adjust(top=1, bottom=0, right=1, left=0, hspace=0, wspace=0)
327
+ plt.margins(0, 0)
328
+ plt.savefig(heatmap_save_path, dpi=200, bbox_inches='tight', pad_inches = -0.1)
329
+ print("Save to the " + heatmap_save_path)
330
+ plt.show()
331
+
332
+ def convert_to_onnx(self, simplify, model_path):
333
+ import onnx
334
+ self.generate(onnx=True)
335
+
336
+ im = torch.zeros(1, 3, *self.input_shape).to('cpu') # image size(1, 3, 512, 512) BCHW
337
+ input_layer_names = ["images"]
338
+ output_layer_names = ["output"]
339
+
340
+ # Export the model
341
+ print(f'Starting export with onnx {onnx.__version__}.')
342
+ torch.onnx.export(self.net,
343
+ im,
344
+ f = model_path,
345
+ verbose = False,
346
+ opset_version = 12,
347
+ training = torch.onnx.TrainingMode.EVAL,
348
+ do_constant_folding = True,
349
+ input_names = input_layer_names,
350
+ output_names = output_layer_names,
351
+ dynamic_axes = None)
352
+
353
+ # Checks
354
+ model_onnx = onnx.load(model_path) # load onnx model
355
+ onnx.checker.check_model(model_onnx) # check onnx model
356
+
357
+ # Simplify onnx
358
+ if simplify:
359
+ import onnxsim
360
+ print(f'Simplifying with onnx-simplifier {onnxsim.__version__}.')
361
+ model_onnx, check = onnxsim.simplify(
362
+ model_onnx,
363
+ dynamic_input_shape=False,
364
+ input_shapes=None)
365
+ assert check, 'assert check failed'
366
+ onnx.save(model_onnx, model_path)
367
+
368
+ print('Onnx model save as {}'.format(model_path))
369
+
370
+ def get_map_txt(self, image_id, image, class_names, map_out_path):
371
+ f = open(os.path.join(map_out_path, "detection-results/"+image_id+".txt"),"w")
372
+ image_shape = np.array(np.shape(image)[0:2])
373
+ #---------------------------------------------------------#
374
+ # 在这里将图像转换成RGB图像,防止灰度图在预测时报错。
375
+ # 代码仅仅支持RGB图像的预测,所有其它类型的图像都会转化成RGB
376
+ #---------------------------------------------------------#
377
+ image = cvtColor(image)
378
+ #---------------------------------------------------------#
379
+ # 给图像增加灰条,实现不失真的resize
380
+ # 也可以直接resize进行识别
381
+ #---------------------------------------------------------#
382
+ image_data = resize_image(image, (self.input_shape[1],self.input_shape[0]), self.letterbox_image)
383
+ #---------------------------------------------------------#
384
+ # 添加上batch_size维度
385
+ #---------------------------------------------------------#
386
+ image_data = np.expand_dims(np.transpose(preprocess_input(np.array(image_data, dtype='float32')), (2, 0, 1)), 0)
387
+
388
+ with torch.no_grad():
389
+ images = torch.from_numpy(image_data)
390
+ if self.cuda:
391
+ images = images.cuda()
392
+ #---------------------------------------------------------#
393
+ # 将图像输入网络当中进行预测!
394
+ #---------------------------------------------------------#
395
+ outputs = self.net(images)
396
+ outputs = self.bbox_util.decode_box(outputs)
397
+ #---------------------------------------------------------#
398
+ # 将预测框进行堆叠,然后进行非极大抑制
399
+ #---------------------------------------------------------#
400
+ results = self.bbox_util.non_max_suppression(torch.cat(outputs, 1), self.num_classes, self.input_shape,
401
+ image_shape, self.letterbox_image, conf_thres = self.confidence, nms_thres = self.nms_iou)
402
+
403
+ if results[0] is None:
404
+ return
405
+
406
+ top_label = np.array(results[0][:, 6], dtype = 'int32')
407
+ top_conf = results[0][:, 4] * results[0][:, 5]
408
+ top_boxes = results[0][:, :4]
409
+
410
+ for i, c in list(enumerate(top_label)):
411
+ predicted_class = self.class_names[int(c)]
412
+ box = top_boxes[i]
413
+ score = str(top_conf[i])
414
+
415
+ top, left, bottom, right = box
416
+ if predicted_class not in class_names:
417
+ continue
418
+
419
+ f.write("%s %s %s %s %s %s\n" % (predicted_class, score[:6], str(int(left)), str(int(top)), str(int(right)),str(int(bottom))))
420
+
421
+ f.close()
422
+ return