Jianfeng777 commited on
Commit
808a161
1 Parent(s): 9a86fb3

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +764 -0
app.py ADDED
@@ -0,0 +1,764 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import cv2
2
+ from mmdeploy_runtime import Detector, Segmentor, Classifier
3
+ import numpy as np
4
+ import gradio as gr
5
+ import math
6
+ import os
7
+
8
+
9
+ # Load models globally to avoid redundancy
10
+
11
+ helmet_detector = Detector(model_path='/mnt/e/AI/mmdeploy/output/helmet', device_name='cuda', device_id=0)
12
+ red_tree_segmentor = Segmentor(model_path='/mnt/e/AI/mmdeploy/output/red_tree', device_name='cuda', device_id=0)
13
+ vest_detector = Detector(model_path='/mnt/e/AI/mmdeploy/output/vest_detection', device_name='cuda', device_id=0)
14
+ car_detector = Detector(model_path='/mnt/e/AI/mmdeploy/output/car_calculation', device_name='cuda', device_id=0)
15
+ crack_classifier = Classifier(model_path='/mnt/e/AI/mmdeploy/output/crack_classification', device_name='cuda', device_id=0)
16
+ disease_object_detector = Detector(model_path='/mnt/e/AI/mmdeploy/output/disease_object_detection', device_name='cuda', device_id=0)
17
+ crack_segmentor = Segmentor(model_path='/mnt/e/AI/mmdeploy/output/crack_detection2', device_name='cuda', device_id=0)
18
+ leaf_disease_segmentor = Segmentor(model_path='/mnt/e/AI/mmdeploy/output/disease_leaf', device_name='cuda', device_id=0)
19
+ single_label_disease_segmentor = Segmentor(model_path='/mnt/e/AI/mmdeploy/output/disease_detection', device_name='cuda', device_id=0)
20
+ fall_detector = Detector(model_path='/mnt/e/AI/mmdeploy/output/fall_detection_fastercnn', device_name='cuda', device_id=0)
21
+ mask_detector = Detector(model_path='/mnt/e/AI/mmdeploy/output/mask_detection', device_name='cuda', device_id=0)
22
+ smoker_detector_object = Detector(model_path='/mnt/e/AI/mmdeploy/output/smoker_nonsmoker', device_name='cuda', device_id=0)
23
+
24
+ def smoker_detector(frame, confidence_threshold=0.3):
25
+ SMOKE_LABELS = ['smoker', 'nonsmoker'] # 新的标签列表
26
+ bboxes, labels, masks = smoker_detector_object(frame) # 修改检测器名字
27
+
28
+ # 获取有效的bbox索引
29
+ valid_indices = [(i, SMOKE_LABELS[label]) for i, label in enumerate(labels) if SMOKE_LABELS[label] == 'smoker' and bboxes[i][4] >= confidence_threshold]
30
+
31
+ smoker_count = 0
32
+
33
+ for i, label_name in valid_indices:
34
+ bbox = bboxes[i]
35
+ [left, top, right, bottom], score = bbox[0:4].astype(int), bbox[4]
36
+
37
+ if label_name == 'smoker':
38
+ color = (255, 0, 0) # 绿色用于'smoker'
39
+ smoker_count += 1
40
+
41
+ line_thickness = 2
42
+ font_scale = 0.8
43
+ cv2.rectangle(frame, (left, top), (right, bottom), color, thickness=line_thickness)
44
+ label_text = f"{label_name} ({score:.2f})"
45
+ cv2.putText(frame, label_text, (left, top - 10), cv2.FONT_HERSHEY_SIMPLEX, font_scale, color, line_thickness)
46
+
47
+ if masks and masks[i].size:
48
+ mask = masks[i]
49
+ blue, green, red = cv2.split(frame)
50
+ if mask.shape == frame.shape[:2]:
51
+ mask_img = blue
52
+ else:
53
+ x0 = int(max(math.floor(bbox[0]) - 1, 0))
54
+ y0 = int(max(math.floor(bbox[1]) - 1, 0))
55
+ mask_img = blue[y0:y0 + mask.shape[0], x0:x0 + mask.shape[1]]
56
+ cv2.bitwise_or(mask, mask_img, mask_img)
57
+ frame = cv2.merge([blue, green, red])
58
+
59
+ # 显示smoker的数量
60
+ frame_height, frame_width = frame.shape[:2]
61
+ summary_text = f"Smokers: {smoker_count}"
62
+ cv2.putText(frame, summary_text, (frame_width - 200, 30), cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 255, 255), 2)
63
+
64
+ return frame, smoker_count
65
+
66
+
67
+
68
+ def crack_classification(frame, confidence_threshold=0.5):
69
+ # 定义标签
70
+ labels_dict = {0: 'Negative', 1: 'Positive'}
71
+
72
+ # 使用裂缝分类器进行预测
73
+ result = crack_classifier(frame)
74
+
75
+ # 获取最大置信度的标签ID
76
+ label_id, score = max(result, key=lambda x: x[1])
77
+
78
+ if label_id == 1 and score > confidence_threshold: # 如果检测到有裂缝,并且置信度超过阈值
79
+ seg = crack_segmentor(frame)
80
+ crack_pixel_count = np.sum(seg == 1)
81
+ current_palette = [(255, 255, 255), (255, 0, 0)] # 背景为白色,裂缝为红色
82
+ color_seg = np.zeros((seg.shape[0], seg.shape[1], 3), dtype=np.uint8)
83
+ for label, color in enumerate(current_palette):
84
+ color_seg[seg == label, :] = color
85
+ frame = frame * 0.5 + color_seg * 0.5
86
+ frame = frame.astype(np.uint8)
87
+ elif label_id == 0 and score <= confidence_threshold:
88
+ crack_pixel_count = None
89
+ else:
90
+ crack_pixel_count = None
91
+ label_id = 0 # 这里我默认设置为0,即"Negative",但你可以根据实际情况进行调整
92
+
93
+ # 在图像右上角显示预测结果和置信度
94
+ label_text = labels_dict[label_id] + f" ({score:.2f})"
95
+ color = (255, 0, 0) if label_id == 1 else (0, 255, 0) # 裂缝为红色,否则为绿色
96
+ font_scale = 0.8
97
+ line_thickness = 2
98
+ text_size = cv2.getTextSize(label_text, cv2.FONT_HERSHEY_SIMPLEX, font_scale, line_thickness)[0]
99
+ cv2.putText(frame, label_text, (frame.shape[1] - text_size[0] - 10, text_size[1] + 10), cv2.FONT_HERSHEY_SIMPLEX, font_scale, color, line_thickness)
100
+
101
+ return frame, labels_dict[label_id], crack_pixel_count
102
+
103
+
104
+ def crack_detection(frame):
105
+ # 使用裂缝检测器进行检测
106
+ seg = crack_segmentor(frame)
107
+ crack_pixel_count = np.sum(seg == 1)
108
+
109
+ # 如果检测到裂缝,进行可视化处理
110
+ if crack_pixel_count > 0:
111
+ current_palette = [(255, 255, 255), (255, 0, 0)] # 背景为白色,裂缝为红色
112
+ color_seg = np.zeros((seg.shape[0], seg.shape[1], 3), dtype=np.uint8)
113
+ for label, color in enumerate(current_palette):
114
+ color_seg[seg == label, :] = color
115
+ frame = frame * 0.5 + color_seg * 0.5
116
+ frame = frame.astype(np.uint8)
117
+
118
+ # 在图像右上角显示检测到的裂缝像素数量
119
+ label_text = f"Crack Pixels: {crack_pixel_count}"
120
+ color = (255, 0, 0) if crack_pixel_count > 0 else (0, 255, 0) # 如果有裂缝则为红色,否则为绿色
121
+ font_scale = 0.8
122
+ line_thickness = 2
123
+ text_size = cv2.getTextSize(label_text, cv2.FONT_HERSHEY_SIMPLEX, font_scale, line_thickness)[0]
124
+ cv2.putText(frame, label_text, (frame.shape[1] - text_size[0] - 10, text_size[1] + 10), cv2.FONT_HERSHEY_SIMPLEX, font_scale, color, line_thickness)
125
+
126
+ return frame, crack_pixel_count
127
+
128
+
129
+ def car_calculation(frame, confidence_threshold=0.7):
130
+ CAR_LABEL = 'car' # 这里只有一个车辆标签
131
+ bboxes, labels, masks = car_detector(frame)
132
+ valid_indices = [i for i, label in enumerate(labels) if bboxes[i][4] >= confidence_threshold]
133
+
134
+ car_count = 0
135
+
136
+ for i in valid_indices:
137
+ bbox = bboxes[i]
138
+ [left, top, right, bottom], score = bbox[0:4].astype(int), bbox[4]
139
+
140
+ color = (0, 255, 0) # 使用绿色标记车辆
141
+ line_thickness = 2
142
+ font_scale = 0.8
143
+
144
+ cv2.rectangle(frame, (left, top), (right, bottom), color, thickness=line_thickness)
145
+ label_text = CAR_LABEL + f" ({score:.2f})"
146
+ cv2.putText(frame, label_text, (left, top - 10), cv2.FONT_HERSHEY_SIMPLEX, font_scale, color, line_thickness)
147
+
148
+ if masks and masks[i].size:
149
+ mask = masks[i]
150
+ blue, green, red = cv2.split(frame)
151
+ if mask.shape == frame.shape[:2]:
152
+ mask_img = blue
153
+ else:
154
+ x0 = int(max(math.floor(bbox[0]) - 1, 0))
155
+ y0 = int(max(math.floor(bbox[1]) - 1, 0))
156
+ mask_img = blue[y0:y0 + mask.shape[0], x0:x0 + mask.shape[1]]
157
+ cv2.bitwise_or(mask, mask_img, mask_img)
158
+ frame = cv2.merge([blue, green, red])
159
+
160
+ car_count += 1
161
+
162
+ return frame, car_count
163
+
164
+
165
+
166
+ def vest_detection(frame, confidence_threshold=0.3):
167
+ VEST_LABELS = ['other_clothes', 'vest'] # 新的标签列表
168
+ bboxes, labels, masks = vest_detector(frame)
169
+
170
+ # 获取有效的bbox索引
171
+ valid_indices = [(i, VEST_LABELS[label]) for i, label in enumerate(labels) if VEST_LABELS[label] in ['vest', 'other_clothes'] and bboxes[i][4] >= confidence_threshold]
172
+
173
+ vest_count = 0
174
+ other_clothes_count = 0
175
+
176
+ for i, label_name in valid_indices:
177
+ bbox = bboxes[i]
178
+ [left, top, right, bottom], score = bbox[0:4].astype(int), bbox[4]
179
+
180
+ if label_name == 'vest':
181
+ color = (0, 255, 255) # 黄色用于'vest'
182
+ vest_count += 1
183
+ else:
184
+ color = (255, 0, 0) # 蓝色用于'other_clothes'
185
+ other_clothes_count += 1
186
+
187
+ line_thickness = 2
188
+ font_scale = 0.8
189
+ cv2.rectangle(frame, (left, top), (right, bottom), color, thickness=line_thickness)
190
+ label_text = f"{label_name} ({score:.2f})"
191
+ cv2.putText(frame, label_text, (left, top - 10), cv2.FONT_HERSHEY_SIMPLEX, font_scale, color, line_thickness)
192
+
193
+ if masks and masks[i].size:
194
+ mask = masks[i]
195
+ blue, green, red = cv2.split(frame)
196
+ if mask.shape == frame.shape[:2]:
197
+ mask_img = blue
198
+ else:
199
+ x0 = int(max(math.floor(bbox[0]) - 1, 0))
200
+ y0 = int(max(math.floor(bbox[1]) - 1, 0))
201
+ mask_img = blue[y0:y0 + mask.shape[0], x0:x0 + mask.shape[1]]
202
+ cv2.bitwise_or(mask, mask_img, mask_img)
203
+ frame = cv2.merge([blue, green, red])
204
+
205
+ # 显示vest和other_clothes的数量和置信度
206
+ frame_height, frame_width = frame.shape[:2]
207
+ summary_text = f"Vests: {vest_count}, Other Clothes: {other_clothes_count}"
208
+ cv2.putText(frame, summary_text, (frame_width - 300, 30), cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 255, 255), 2)
209
+
210
+ return frame, vest_count, other_clothes_count
211
+
212
+ def detect_falls(frame, confidence_threshold=0.5):
213
+ # 假设输入图像是RGB格式,转换为BGR
214
+ frame = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)
215
+
216
+ LABELS = ['fall', 'person']
217
+ # 初始化摔倒计数器
218
+ fall_count = 0
219
+
220
+ # 使用模型进行检测
221
+ bboxes, labels, masks = fall_detector(frame)
222
+
223
+ for bbox, label_id in zip(bboxes, labels):
224
+ [left, top, right, bottom], score = bbox[0:4].astype(int), bbox[4]
225
+ if score < confidence_threshold:
226
+ continue
227
+ if LABELS[label_id] == 'fall': # 仅显示摔倒的标注框
228
+ cv2.rectangle(frame, (left, top), (right, bottom), (0, 0, 255), 2)
229
+ label_text = f"{LABELS[label_id]}: {int(score*100)}%"
230
+ cv2.putText(frame, label_text, (left, top - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 255), 1)
231
+ # 递增摔倒计数器
232
+ fall_count += 1
233
+
234
+ # 转换图像回RGB格式
235
+ frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
236
+
237
+ # 返回处理后的图像和摔倒的数量
238
+ return frame, fall_count
239
+
240
+ def leaf_disease_detection(frame, confidence_threshold=0.3):
241
+ # 假设输入图像是RGB格式,转换为BGR
242
+ frame = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)
243
+
244
+ LABELS = ['disease']
245
+ # 初始化病害计数器
246
+ disease_count = 0
247
+ bboxes, labels, masks = disease_object_detector(frame)
248
+ indices = [i for i in range(len(bboxes))]
249
+ for index, bbox, label_id in zip(indices, bboxes, labels):
250
+ [left, top, right, bottom], score = bbox[0:4].astype(int), bbox[4]
251
+ if score < confidence_threshold:
252
+ continue
253
+ cv2.rectangle(frame, (left, top), (right, bottom), (0, 0, 255), 1)
254
+ label_text = f"{LABELS[label_id]}: {int(score*100)}%"
255
+ cv2.putText(frame, label_text, (left, top - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 255), 1)
256
+
257
+ if masks[index].size:
258
+ mask = masks[index]
259
+ blue, green, red = cv2.split(frame)
260
+ if mask.shape == frame.shape[:2]:
261
+ mask_img = blue
262
+ else:
263
+ x0 = int(max(math.floor(bbox[0]) - 1, 0))
264
+ y0 = int(max(math.floor(bbox[1]) - 1, 0))
265
+ mask_img = blue[y0:y0 + mask.shape[0], x0:x0 + mask.shape[1]]
266
+ cv2.bitwise_or(mask, mask_img, mask_img)
267
+ frame = cv2.merge([blue, green, red])
268
+ # 递增病害计数器
269
+ disease_count += 1
270
+
271
+ # 转换图像回RGB格式
272
+ frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
273
+
274
+ # 返回处理后的图像、病害计数和保存的图像路径
275
+ return frame, disease_count
276
+
277
+ def detect_masks(frame, confidence_threshold=0.5):
278
+ # 假设输入图像是RGB格式,转换为BGR
279
+ frame = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)
280
+
281
+ LABELS = ['unfit', 'mask', 'nomask']
282
+ # 初始化三个标签的计数器
283
+ mask_count, nomask_count, unfit_count = 0, 0, 0
284
+
285
+ # 使用模型进行检测
286
+ bboxes, labels, masks = mask_detector(frame)
287
+
288
+ for bbox, label_id in zip(bboxes, labels):
289
+ [left, top, right, bottom], score = bbox[0:4].astype(int), bbox[4]
290
+ if score < confidence_threshold:
291
+ continue
292
+
293
+ # 根据标签ID判断类别,并进行相应的计数
294
+ if LABELS[label_id] == 'mask':
295
+ mask_count += 1
296
+ color = (0, 255, 0)
297
+ elif LABELS[label_id] == 'nomask':
298
+ nomask_count += 1
299
+ color = (0, 0, 255)
300
+ elif LABELS[label_id] == 'unfit':
301
+ unfit_count += 1
302
+ color = (255, 0, 0)
303
+
304
+ cv2.rectangle(frame, (left, top), (right, bottom), color, 2)
305
+ label_text = f"{LABELS[label_id]}: {int(score*100)}%"
306
+ cv2.putText(frame, label_text, (left, top - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 1)
307
+
308
+ # 转换图像回RGB格式
309
+ frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
310
+
311
+ # 返回处理后的图像和每个标签的数量
312
+ return frame, mask_count, nomask_count, unfit_count
313
+
314
+ def helmet_detection(frame, confidence_threshold=0.3):
315
+
316
+ HEL_LABELS = ['head', 'helmet']
317
+ bboxes, labels, masks = helmet_detector(frame)
318
+ valid_indices = [i for i, bbox in enumerate(bboxes) if bbox[4] >= confidence_threshold]
319
+
320
+ helmet_count = 0
321
+ head_count = 0
322
+
323
+ for i in valid_indices:
324
+ bbox = bboxes[i]
325
+ label_id = labels[i]
326
+ [left, top, right, bottom], score = bbox[0:4].astype(int), bbox[4]
327
+
328
+ if HEL_LABELS[label_id] == 'helmet':
329
+ color = (0, 255, 0) # Green color for 'helmet'
330
+ line_thickness = 1
331
+ font_scale = 0.5
332
+ elif HEL_LABELS[label_id] == 'head':
333
+ color = (255, 0, 0) # Red color for 'head'
334
+ line_thickness = 1 # Increased line thickness for 'head' boxes
335
+ font_scale = 0.5 # Decreased font size for 'head' labels
336
+
337
+ cv2.rectangle(frame, (left, top), (right, bottom), color, thickness=line_thickness)
338
+ label_text = HEL_LABELS[label_id] + f" ({score:.2f})"
339
+ cv2.putText(frame, label_text, (left, top - 10), cv2.FONT_HERSHEY_SIMPLEX, font_scale, color, line_thickness)
340
+
341
+ if HEL_LABELS[label_id] == 'helmet':
342
+ helmet_count += 1
343
+ elif HEL_LABELS[label_id] == 'head':
344
+ head_count += 1
345
+
346
+ return frame, helmet_count, head_count
347
+
348
+
349
+
350
+ def human_calculation(frame, confidence_threshold=0.3):
351
+ """
352
+ Process the given image to count the number of humans.
353
+ """
354
+ HEL_LABELS = ['head', 'helmet']
355
+ bboxes, labels, masks = helmet_detector(frame)
356
+
357
+ human_count = 0 # Initialize human count
358
+
359
+ for i in range(len(bboxes)):
360
+ bbox = bboxes[i]
361
+ label_id = labels[i]
362
+ score = bbox[4]
363
+
364
+ # Check if the label is 'head' or 'helmet' and the score is greater than confidence_threshold
365
+ if HEL_LABELS[label_id] in ['head', 'helmet'] and score > confidence_threshold:
366
+ human_count += 1
367
+ [left, top, right, bottom] = bbox[0:4].astype(int)
368
+ cv2.rectangle(frame, (left, top), (right, bottom), (0, 0, 255), thickness=1) # Red color for boxes
369
+ label_text = f"human ({score:.2f})" # Include confidence score in label_text
370
+ cv2.putText(frame, label_text, (left, top - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 255), 1)
371
+
372
+
373
+ return frame, human_count
374
+
375
+
376
+ def red_tree(img):
377
+ img_bgr = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
378
+ def get_palette(num_classes=2):
379
+ return [(255, 255, 255), (255, 0, 0)]
380
+ seg = red_tree_segmentor(img_bgr)
381
+ red_tree_pixel_count = np.sum(seg == 1)
382
+ current_palette = get_palette()
383
+ color_seg = np.zeros((seg.shape[0], seg.shape[1], 3), dtype=np.uint8)
384
+ for label, color in enumerate(current_palette):
385
+ color_seg[seg == label, :] = color
386
+ color_seg_bgr = color_seg[..., ::-1]
387
+
388
+ img_bgr = img_bgr * 0.5 + color_seg_bgr * 0.5
389
+ img_bgr = img_bgr.astype(np.uint8)
390
+ img_rgb = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB)
391
+
392
+ return img_rgb, red_tree_pixel_count
393
+
394
+
395
+ def leaf_disease(img):
396
+ img_bgr = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
397
+
398
+ def get_palette(num_classes=3):
399
+ return [(255, 255, 255), (0, 255, 0), (255, 0, 0)]
400
+
401
+ seg = leaf_disease_segmentor(img_bgr)
402
+
403
+ leaf_pixel_count = np.sum(seg == 1)
404
+ disease_pixel_count = np.sum(seg == 2)
405
+
406
+ current_palette = get_palette()
407
+ color_seg = np.zeros((seg.shape[0], seg.shape[1], 3), dtype=np.uint8)
408
+
409
+ for label, color in enumerate(current_palette):
410
+ color_seg[seg == label, :] = color
411
+
412
+ color_seg_bgr = color_seg[..., ::-1]
413
+ img_bgr = img_bgr * 0.5 + color_seg_bgr * 0.5
414
+ img_bgr = img_bgr.astype(np.uint8)
415
+ img_rgb = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB)
416
+
417
+ return img_rgb, leaf_pixel_count, disease_pixel_count
418
+
419
+ def single_label_disease(img):
420
+ img_bgr = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
421
+
422
+ def get_palette(num_classes=2):
423
+ return [(255, 255, 255), (255, 0, 0)]
424
+
425
+ seg = single_label_disease_segmentor(img_bgr)
426
+
427
+ disease_pixel_count = np.sum(seg == 1)
428
+
429
+ current_palette = get_palette()
430
+ color_seg = np.zeros((seg.shape[0], seg.shape[1], 3), dtype=np.uint8)
431
+
432
+ for label, color in enumerate(current_palette):
433
+ color_seg[seg == label, :] = color
434
+
435
+ color_seg_bgr = color_seg[..., ::-1]
436
+ img_bgr = img_bgr * 0.5 + color_seg_bgr * 0.5
437
+ img_bgr = img_bgr.astype(np.uint8)
438
+ img_rgb = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB)
439
+
440
+
441
+ return img_rgb, disease_pixel_count
442
+
443
+
444
+
445
+ def get_image_examples():
446
+ image_dir = "/mnt/e/AI/mmdeploy/gradio/photo"
447
+ image_files = [f for f in os.listdir(image_dir) if f.endswith(('.png', '.jpg', '.jpeg'))]
448
+ image_files.sort(key=lambda f: int(''.join(filter(str.isdigit, f)))) # 按数字排序
449
+ example_choices = [
450
+ '红树林识别', '红树林识别', '红树林识别',
451
+ '安全帽检测', '安全帽检测', '安全帽检测',
452
+ '人数统计', '人数统计', '人数统计',
453
+ '反光衣检测','反光衣检测','反光衣检测',
454
+ '道路车辆统计', '道路车辆统计', '道路车辆统计',
455
+ '裂缝识别', '裂缝识别', '裂缝识别',
456
+ '吸烟检测','吸烟检测','吸烟检测',
457
+ '树叶病害识别1','树叶病害识别1','树叶病害识别1',
458
+ '树叶病害识别2','树叶病害识别2','树叶病害识别2',
459
+ '树叶病害检测3','树叶病害检测3','树叶病害检测3',
460
+ '摔倒检测','摔倒检测','摔倒检测',
461
+ '口罩佩戴检测','口罩佩戴检测','口罩佩戴检测',
462
+ ]
463
+
464
+ confidence_thresholds = [
465
+ 0, 0, 0,
466
+ 0.7, 0.8, 0.6,
467
+ 0.3, 0.8, 0.5,
468
+ 0.8, 0.7, 0.8,
469
+ 0.5, 0.2, 0.7,
470
+ 0, 0, 0,
471
+ 0.6, 0.9, 0.5,
472
+ 0, 0, 0,
473
+ 0, 0, 0,
474
+ 0.4, 0.4, 0.5,
475
+ 0.9, 0.9, 0.5,
476
+ 0.8, 0.6, 0.5
477
+ ]
478
+ examples = [[example_choices[i], f"{image_dir}/{image_file}", confidence_thresholds[i]] for i, image_file in enumerate(image_files)]
479
+ return examples
480
+
481
+
482
+
483
+
484
+ model_choices = ['红树林识别','裂缝识别','树叶病害识别1','树叶病害识别2','树叶病害检测3', '安全帽检测','反光衣检测', '吸烟检测','摔倒检测', '口罩佩戴检测','人数统计','道路车辆统计']
485
+
486
+
487
+ def create_blank_image(width=640, height=480, color=(255, 255, 255)):
488
+ blank_image = np.zeros((height, width, 3), np.uint8)
489
+ blank_image[:, :] = color
490
+ return blank_image
491
+
492
+ def process_image(model_choice, image_array=None, confidence_threshold=0.3):
493
+ output_text = '当前未有图片输入,请上传图片后再次点击运行。'
494
+
495
+ if image_array is None:
496
+ img = create_blank_image()
497
+ else:
498
+ if model_choice not in model_choices:
499
+ model_choice = "安全帽检测"
500
+ # 以下是模型选择和执行逻辑
501
+ if model_choice == "红树林识别":
502
+ img, red_tree_pixel_count = red_tree(image_array) # 语义分割模型
503
+ output_text = f"红树林的像素点有 {red_tree_pixel_count} 个。"
504
+ elif model_choice == "安全帽检测":
505
+ img, helmet_count, head_count = helmet_detection(image_array, confidence_threshold)
506
+ output_text = f"佩戴安全帽的人数为:{helmet_count},未佩戴安全帽的人数为:{head_count}。"
507
+ elif model_choice == "人数统计":
508
+ img, human_count = human_calculation(image_array, confidence_threshold)
509
+ output_text = f"该图片人员总人数为: {human_count}。"
510
+ elif model_choice == "反光衣检测":
511
+ img, vest_count, other_clothes_count= vest_detection(image_array, confidence_threshold)
512
+ output_text = f"该图片中总有 {vest_count} 人配备了反光衣,{other_clothes_count} 人没有配备反光衣。"
513
+ elif model_choice == "道路车辆统计":
514
+ img, car_count = car_calculation(image_array, confidence_threshold)
515
+ output_text = f"该道路上目前共有 {car_count} 台车辆。"
516
+ elif model_choice == "裂缝识别":
517
+ img, crack_result, crack_pixel_count = crack_classification(image_array, confidence_threshold)
518
+ if crack_result == "Positive":
519
+ output_text = f"该图片内存在裂缝,裂缝的像素点有 {crack_pixel_count} 个。"
520
+ else:
521
+ output_text = "该图片不存在裂缝。"
522
+ elif model_choice == "树叶病害检测3":
523
+ img, disease_count = leaf_disease_detection(image_array, confidence_threshold)
524
+ if disease_count > 0:
525
+ output_text = f"共检测到 {disease_count} 处病害。"
526
+ else:
527
+ output_text = "并未检测到病害。"
528
+ elif model_choice == "吸烟检测":
529
+ img, smoker_count = smoker_detector(image_array, confidence_threshold)
530
+ output_text = f"当前图片有 {smoker_count} 人在吸烟。"
531
+ elif model_choice == "树叶病害识别1":
532
+ img, leaf_pixel_count, disease_pixel_count = leaf_disease(image_array) # 语义分割模型
533
+ if disease_pixel_count == 0:
534
+ output_text = "该树叶并未出现病害。"
535
+ else:
536
+ output_text = f"病害的像素点有 {disease_pixel_count} 个。"
537
+ elif model_choice == "树叶病害识别2":
538
+ img, disease_pixel_count = single_label_disease(image_array) # 语义分割模型
539
+ output_text = f"病害的像素点有 {disease_pixel_count} 个。"
540
+ elif model_choice == "摔倒检测": # 您可以根据实际情况调整模型选择的名称
541
+ img, fall_count = detect_falls(image_array,confidence_threshold)
542
+ output_text = f"图像中摔倒的人数为 {fall_count} 人。"
543
+ elif model_choice == "口罩佩戴检测": # 您可以根据实际情况调整模型选择的名称
544
+ img, mask_count, nomask_count, unfit_count = detect_masks(image_array,confidence_threshold)
545
+ output_text = f"当前佩戴口罩的人数为 {mask_count},未正确佩戴口罩的人数为 {unfit_count},没有佩戴口罩的人数为 {nomask_count}。"
546
+
547
+ return img, output_text
548
+
549
+ def process_video(model_choice, video=None, confidence_threshold=0.3):
550
+
551
+ # 内部函数:创建空白视频
552
+ def create_blank_video(filename, duration=5, fps=30, width=640, height=480, color=(255, 255, 255)):
553
+ fourcc = cv2.VideoWriter_fourcc(*'mp4v') # 使用mp4v编解码器
554
+ out = cv2.VideoWriter(filename, fourcc, fps, (width, height))
555
+ blank_image = np.zeros((height, width, 3), np.uint8)
556
+ blank_image[:, :] = color
557
+ for _ in range(int(fps * duration)):
558
+ out.write(blank_image)
559
+ out.release()
560
+
561
+ # 检查视频是否存在
562
+ if video is None:
563
+ video_output_path = '/mnt/e/AI/mmdeploy/gradio/video/none.mp4'
564
+ create_blank_video(video_output_path)
565
+ output_text2 = '当前未有视频输入,请上传视频后再次点击运行。'
566
+ return video_output_path, output_text2
567
+ else:
568
+ video_output_path = '/mnt/e/AI/mmdeploy/gradio/video/output_video.mp4'
569
+ cap = cv2.VideoCapture(video)
570
+ if not cap.isOpened():
571
+ raise ValueError("无法打开视频文件")
572
+ fps = int(cap.get(cv2.CAP_PROP_FPS))
573
+ num_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
574
+ # 获取输入视频的分辨率
575
+ width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
576
+ height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
577
+ # 使用帧采样的逻辑,但考虑到所有帧都需要处理,我们使用间隔为1的采样。
578
+ clip_len, frame_interval, num_clips = 1, 1, num_frames
579
+ avg_interval = (num_frames - clip_len * frame_interval + 1) / float(num_clips)
580
+ frame_inds = []
581
+ for i in range(num_clips):
582
+ clip_offset = int(i * avg_interval + avg_interval / 2.0)
583
+ for j in range(clip_len):
584
+ ind = (j * frame_interval + clip_offset) % num_frames
585
+ if num_frames <= clip_len * frame_interval - 1:
586
+ ind = j % num_frames
587
+ frame_inds.append(ind)
588
+
589
+ fourcc = cv2.VideoWriter_fourcc(*'mp4v')
590
+ processed_frames = []
591
+ for frame_id in sorted(frame_inds):
592
+ cap.set(cv2.CAP_PROP_POS_FRAMES, frame_id) # 设置读取特定的帧
593
+ ret, frame = cap.read()
594
+ if not ret:
595
+ break
596
+ # 将帧率添加到视频的左上角
597
+ cv2.putText(frame, "FPS: {}".format(fps), (10, 30),
598
+ cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 0, 255), 2, cv2.LINE_AA)
599
+
600
+ if model_choice == "红树林识别":
601
+ # 在此处调用红树林模型处理帧
602
+ processed_frame, red_tree_pixel_count = red_tree(frame)
603
+ # 在处理后的帧的右上角添加文字
604
+ cv2.putText(processed_frame, "Number of pixels in this frame: {}".format(red_tree_pixel_count),
605
+ (processed_frame.shape[1] - 300, 30), cv2.FONT_HERSHEY_SIMPLEX,
606
+ 0.7, (255, 255, 255), 2)
607
+
608
+
609
+ elif model_choice == "安全帽检测":
610
+ # 在此处调用安全帽检测模型处理帧
611
+ processed_frame, helmet_count, head_count = helmet_detection(frame, confidence_threshold)
612
+
613
+ cv2.putText(processed_frame, "Number of people wearing helmets: {}".format(helmet_count),
614
+ (processed_frame.shape[1] - 400, 30), cv2.FONT_HERSHEY_SIMPLEX,
615
+ 0.7, (255, 255, 255), 2)
616
+
617
+ # 在上一行文字下方添加表示未佩戴安全帽的人数的文字
618
+ cv2.putText(processed_frame, "Number of people without helmets: {}".format(head_count - helmet_count),
619
+ (processed_frame.shape[1] - 450, 60), cv2.FONT_HERSHEY_SIMPLEX,
620
+ 0.7, (255, 255, 255), 2)
621
+
622
+
623
+ elif model_choice == "人数统计":
624
+ # 在此处调用人数统计模型处理帧
625
+ processed_frame, human_count = human_calculation(frame, confidence_threshold)
626
+ cv2.putText(processed_frame, "Current number of people: {}".format(human_count),
627
+ (processed_frame.shape[1] - 300, 30), cv2.FONT_HERSHEY_SIMPLEX,
628
+ 0.7, (255, 255, 255), 2)
629
+
630
+ elif model_choice == "反光衣检测":
631
+ # 在此处调用反光衣检测模型处理帧
632
+ processed_frame, vest_count, other_clothes_count= vest_detection(image_array, confidence_threshold)
633
+ cv2.putText(processed_frame, "Number of reflective vests: {}".format(vest_count),
634
+ (processed_frame.shape[1] - 350, 30), cv2.FONT_HERSHEY_SIMPLEX,
635
+ 0.7, (255, 255, 255), 2)
636
+ cv2.putText(processed_frame, "Number without reflective vests: {}".format(other_clothes_count),
637
+ (processed_frame.shape[1] - 450, 60), cv2.FONT_HERSHEY_SIMPLEX,
638
+ 0.7, (255, 255, 255), 2)
639
+
640
+ elif model_choice == "道路车辆统计":
641
+ # 在此处调用道路车辆统计模型处理帧
642
+ processed_frame, car_count = car_calculation(frame, confidence_threshold)
643
+ cv2.putText(processed_frame, "Number of vehicles: {}".format(car_count),
644
+ (processed_frame.shape[1] - 250, 30), cv2.FONT_HERSHEY_SIMPLEX,
645
+ 0.7, (255, 255, 255), 2)
646
+
647
+ elif model_choice == "裂缝识别":
648
+ # 在此处调用裂缝识别模型处理帧
649
+ processed_frame, crack_pixel_count= crack_detection(frame)
650
+
651
+ elif model_choice == "树叶病害检测3":
652
+ # 在此处调用树叶病害检测模型处理帧
653
+ processed_frame, disease_count= leaf_disease_detection(frame, confidence_threshold)
654
+ # 在图像右上角显示叶片的病害数量
655
+ label_text = f"Leaf Disease Count: {disease_count}"
656
+ color = (0, 0, 255) # 红色
657
+ font_scale = 0.8
658
+ line_thickness = 2
659
+ text_size = cv2.getTextSize(label_text, cv2.FONT_HERSHEY_SIMPLEX, font_scale, line_thickness)[0]
660
+ cv2.putText(processed_frame, label_text, (processed_frame.shape[1] - text_size[0] - 10, text_size[1] + 10), cv2.FONT_HERSHEY_SIMPLEX, font_scale, color, line_thickness)
661
+
662
+
663
+
664
+ elif model_choice == "吸烟检测":
665
+ # 在此处调用吸烟检测模型处理帧
666
+ processed_frame, smoker_count = smoker_detector(frame, confidence_threshold)
667
+
668
+ # 准备要显示的文本
669
+ text = f"吸烟者数量: {smoker_count}"
670
+
671
+ # 获取文本大小
672
+ text_size = cv2.getTextSize(text, cv2.FONT_HERSHEY_SIMPLEX, 0.6, 2)[0]
673
+
674
+ # 计算文本的位置,以便它出现在帧的右上角
675
+ text_position = (processed_frame.shape[1] - text_size[0] - 10, text_size[1] + 10)
676
+
677
+ # 将文本绘制到帧上
678
+ cv2.putText(processed_frame, text, text_position, cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0, 0, 255), 2)
679
+
680
+ elif model_choice == "树叶病害识别1":
681
+ # 在此处调用树叶病害识别模型处理帧
682
+ processed_frame, _, disease_pixel_count = leaf_disease(frame)
683
+ text = f"Current disease pixel count on the leaf: {disease_pixel_count}"
684
+ text_size = cv2.getTextSize(text, cv2.FONT_HERSHEY_SIMPLEX, 0.6, 2)[0]
685
+ cv2.putText(processed_frame, text,
686
+ (processed_frame.shape[1] - text_size[0] - 10, text_size[1] + 10),
687
+ cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0, 0, 255), 2)
688
+
689
+
690
+ elif model_choice == "树叶病害识别2":
691
+ # 在此处调用树叶病害识别模型处理帧
692
+ processed_frame, disease_pixel_count= single_label_disease(frame)
693
+ text = f"Current disease pixel count on the leaf: {disease_pixel_count}"
694
+ text_size = cv2.getTextSize(text, cv2.FONT_HERSHEY_SIMPLEX, 0.6, 2)[0]
695
+ cv2.putText(processed_frame, text,
696
+ (processed_frame.shape[1] - text_size[0] - 10, text_size[1] + 10),
697
+ cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0, 0, 255), 2)
698
+
699
+ elif model_choice == "口罩佩戴检测": # 您可以根据实际情况调整模型选择的名称
700
+ processed_frame, mask_count, nomask_count, unfit_count = detect_masks(frame,confidence_threshold)
701
+ cv2.putText(processed_frame, "Number wearing masks: {}".format(mask_count),
702
+ (processed_frame.shape[1] - 350, 30), cv2.FONT_HERSHEY_SIMPLEX,
703
+ 0.7, (255, 255, 255), 2)
704
+
705
+ cv2.putText(processed_frame, "Number not wearing masks: {}".format(nomask_count),
706
+ (processed_frame.shape[1] - 400, 60), cv2.FONT_HERSHEY_SIMPLEX,
707
+ 0.7, (255, 255, 255), 2)
708
+
709
+ cv2.putText(processed_frame, "Number wearing masks incorrectly: {}".format(unfit_count),
710
+ (processed_frame.shape[1] - 500, 90), cv2.FONT_HERSHEY_SIMPLEX,
711
+ 0.7, (255, 255, 255), 2)
712
+
713
+ elif model_choice == "摔倒检测":
714
+
715
+ processed_frame, fall_count= detect_falls(frame,confidence_threshold)
716
+ cv2.putText(processed_frame, "Number of people who fell: {}".format(fall_count),
717
+ (processed_frame.shape[1] - 350, 30), cv2.FONT_HERSHEY_SIMPLEX,
718
+ 0.7, (255, 255, 255), 2)
719
+
720
+ processed_frames.append(processed_frame)
721
+ out = cv2.VideoWriter(video_output_path, fourcc, fps, (width,height))
722
+ for frame in processed_frames:
723
+ out.write(frame)
724
+ out.release()
725
+ cap.release()
726
+ output_text2 = '请点击蓝色按钮下载视频。'
727
+ return video_output_path, output_text2
728
+
729
+ with gr.Blocks() as demo:
730
+ gr.Markdown("# <center>启云科技AI识别示例样板V1.12</center>")
731
+ gr.Markdown("请上传图像或视频进行预测")
732
+ with gr.Tab("AI图像处理"):
733
+ with gr.Row():
734
+ image_input2 = gr.Image(label="上传图像", type="numpy")
735
+ with gr.Column():
736
+ image_input1 = gr.Dropdown(choices=model_choices, label="模型选择")
737
+ image_input3 = gr.Slider(minimum=0, maximum=1, step=0.1, label="置信度阈值")
738
+ with gr.Row():
739
+ image_output1 = gr.Image(label="处理后的图像", type="numpy")
740
+ with gr.Column():
741
+ image_output2 = gr.Textbox(label="图像输出信息")
742
+ image_button = gr.Button('请点击按钮进行图像预测')
743
+ gr.Examples(get_image_examples(),inputs=[image_input1, image_input2, image_input3],outputs=[image_output1, image_output2], fn=process_image ,examples_per_page=6 ,cache_examples=True)
744
+ with gr.Tab("AI视频处理"):
745
+
746
+ with gr.Row():
747
+ video_input2 = gr.Video(label = '上传视频', format='mp4',interactive = True)
748
+ with gr.Column():
749
+ video_input1 = gr.Dropdown(choices=model_choices, label="模型选择")
750
+ video_input3 = gr.Slider(minimum=0, maximum=1,
751
+ step=0.1, label="置信度阈值")
752
+ with gr.Row():
753
+ video_output1 = gr.File(label='处理后的视频', type='file')
754
+ with gr.Column():
755
+ video_output2 = gr.Textbox(label = '视频输出信息')
756
+ video_button = gr.Button('请点击按钮进行视频预测')
757
+ with gr.Accordion("平台简介"):
758
+ gr.Markdown("红树林识别模型、裂缝识别模型、树叶病害识别模型、安全帽检测模型、反光衣检测模型、吸烟检测模型、口罩佩戴检测、摔倒检测、人数统计模型及道路车辆统计模型展示平台。")
759
+ image_button.click(process_image, inputs = [image_input1, image_input2, image_input3], outputs=[image_output1, image_output2])
760
+ video_button.click(process_video, inputs=[video_input1,video_input2, video_input3], outputs=[video_output1, video_output2])
761
+
762
+ demo.launch(share=True)
763
+
764
+