File size: 36,000 Bytes
f3f27e0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
import cv2
from mmdeploy_runtime import Detector, Segmentor, Classifier
import numpy as np
import gradio as gr
import math
import os


# Load every model once at import time so each request reuses the same
# runtime instances instead of re-initialising them.
_MODEL_ROOT = '/mnt/e/AI/mmdeploy/output'


def _load_model(model_cls, subdir):
    """Create an mmdeploy runtime model from ``_MODEL_ROOT/subdir`` on CUDA device 0."""
    return model_cls(model_path=f'{_MODEL_ROOT}/{subdir}', device_name='cuda', device_id=0)


helmet_detector = _load_model(Detector, 'helmet')
red_tree_segmentor = _load_model(Segmentor, 'red_tree')
vest_detector = _load_model(Detector, 'vest_detection')
car_detector = _load_model(Detector, 'car_calculation')
crack_classifier = _load_model(Classifier, 'crack_classification')
disease_object_detector = _load_model(Detector, 'disease_object_detection')
crack_segmentor = _load_model(Segmentor, 'crack_detection2')
leaf_disease_segmentor = _load_model(Segmentor, 'disease_leaf')
single_label_disease_segmentor = _load_model(Segmentor, 'disease_detection')
fall_detector = _load_model(Detector, 'fall_detection_fastercnn')
mask_detector = _load_model(Detector, 'mask_detection')
smoker_detector_object = _load_model(Detector, 'smoker_nonsmoker')

def smoker_detector(frame, confidence_threshold=0.3):
    """Detect smokers in ``frame``, draw them and return the smoker count.

    Only detections labelled ``'smoker'`` whose score is at least
    ``confidence_threshold`` are drawn and counted; ``'nonsmoker'``
    detections are ignored.

    Args:
        frame: Image array passed straight to ``smoker_detector_object``;
            boxes, labels and masks are drawn onto it. NOTE(review): no
            RGB/BGR conversion is done here — confirm the model's expected
            channel order.
        confidence_threshold: Minimum detection score to keep.

    Returns:
        ``(frame, smoker_count)``: the annotated image and the number of
        smokers found.
    """
    SMOKE_LABELS = ['smoker', 'nonsmoker']
    bboxes, labels, masks = smoker_detector_object(frame)

    # Keep only confident 'smoker' detections.
    valid_indices = [i for i, label in enumerate(labels)
                     if SMOKE_LABELS[label] == 'smoker' and bboxes[i][4] >= confidence_threshold]

    color = (255, 0, 0)  # red for an RGB frame (blue if the frame is BGR)
    line_thickness = 2
    font_scale = 0.8
    font = cv2.FONT_HERSHEY_SIMPLEX
    smoker_count = 0

    for i in valid_indices:
        bbox = bboxes[i]
        [left, top, right, bottom], score = bbox[0:4].astype(int), bbox[4]
        smoker_count += 1

        cv2.rectangle(frame, (left, top), (right, bottom), color, thickness=line_thickness)
        cv2.putText(frame, f"smoker ({score:.2f})", (left, top - 10), font, font_scale, color, line_thickness)

        if masks and masks[i].size:
            mask = masks[i]
            blue, green, red = cv2.split(frame)
            if mask.shape == frame.shape[:2]:
                mask_img = blue
            else:
                # The mask covers only the box; place it at the box's top-left corner.
                x0 = int(max(math.floor(bbox[0]) - 1, 0))
                y0 = int(max(math.floor(bbox[1]) - 1, 0))
                mask_img = blue[y0:y0 + mask.shape[0], x0:x0 + mask.shape[1]]
                # Near the right/bottom border the slice is smaller than the mask;
                # crop the mask so cv2.bitwise_or receives equally sized arrays.
                mask = mask[:mask_img.shape[0], :mask_img.shape[1]]
            cv2.bitwise_or(mask, mask_img, mask_img)  # writes into the blue-channel view
            frame = cv2.merge([blue, green, red])

    # Right-align the summary so it is never clipped by the image edge.
    summary_text = f"Smokers: {smoker_count}"
    (text_width, _), _ = cv2.getTextSize(summary_text, font, 1, 2)
    cv2.putText(frame, summary_text, (max(frame.shape[1] - text_width - 10, 0), 30),
                font, 1, (255, 255, 255), 2)

    return frame, smoker_count



def crack_classification(frame, confidence_threshold=0.5):
    """Classify ``frame`` for cracks and, if positive, overlay the crack mask.

    The classifier's highest-scoring class decides the verdict. Only a
    ``Positive`` (label 1) whose score exceeds ``confidence_threshold`` is
    treated as a crack. In that case ``crack_segmentor`` is run and its class
    map is blended 50/50 over the image. Every other outcome is reported as
    ``Negative``. The verdict and score are drawn in the top-right corner.

    Args:
        frame: Image array fed to both the classifier and the segmentor.
        confidence_threshold: Score the positive class must exceed.

    Returns:
        ``(frame, verdict, crack_pixel_count)``. ``verdict`` is ``'Positive'``
        or ``'Negative'``. ``crack_pixel_count`` is the number of pixels
        segmented as class 1, or ``None`` when the verdict is negative.
    """
    labels_dict = {0: 'Negative', 1: 'Positive'}

    result = crack_classifier(frame)
    # Highest-scoring (label_id, score) pair. An empty result counts as a
    # zero-confidence Negative instead of making max() raise ValueError.
    label_id, score = max(result, key=lambda x: x[1], default=(0, 0.0))

    crack_pixel_count = None
    if label_id == 1 and score > confidence_threshold:
        seg = crack_segmentor(frame)
        crack_pixel_count = np.sum(seg == 1)
        current_palette = [(255, 255, 255), (255, 0, 0)]  # class 0 -> white, class 1 -> (255, 0, 0)
        color_seg = np.zeros((seg.shape[0], seg.shape[1], 3), dtype=np.uint8)
        for label, color in enumerate(current_palette):
            color_seg[seg == label, :] = color
        frame = (frame * 0.5 + color_seg * 0.5).astype(np.uint8)
    else:
        # Low-confidence positives (and any unexpected label) are reported as Negative.
        label_id = 0

    label_text = labels_dict[label_id] + f" ({score:.2f})"
    color = (255, 0, 0) if label_id == 1 else (0, 255, 0)
    font_scale = 0.8
    line_thickness = 2
    text_size = cv2.getTextSize(label_text, cv2.FONT_HERSHEY_SIMPLEX, font_scale, line_thickness)[0]
    # Top-right corner, 10 px in from the edges.
    cv2.putText(frame, label_text, (frame.shape[1] - text_size[0] - 10, text_size[1] + 10),
                cv2.FONT_HERSHEY_SIMPLEX, font_scale, color, line_thickness)

    return frame, labels_dict[label_id], crack_pixel_count


def crack_detection(frame):
    """Segment cracks in ``frame`` and annotate the crack pixel count.

    If at least one pixel is segmented as class 1 (crack), the class map is
    painted and blended 50/50 over the image. Class 0 is painted white and
    class 1 ``(255, 0, 0)``. The pixel count is written in the top-right
    corner: in ``(255, 0, 0)`` when cracks were found, ``(0, 255, 0)``
    otherwise.

    Returns:
        ``(frame, crack_pixel_count)``.
    """
    seg = crack_segmentor(frame)
    crack_pixel_count = np.sum(seg == 1)
    found = crack_pixel_count > 0

    if found:
        overlay = np.zeros((seg.shape[0], seg.shape[1], 3), dtype=np.uint8)
        for class_id, rgb in enumerate(((255, 255, 255), (255, 0, 0))):
            overlay[seg == class_id, :] = rgb
        frame = (frame * 0.5 + overlay * 0.5).astype(np.uint8)

    text = f"Crack Pixels: {crack_pixel_count}"
    text_color = (255, 0, 0) if found else (0, 255, 0)
    font, scale, thickness = cv2.FONT_HERSHEY_SIMPLEX, 0.8, 2
    (text_w, text_h), _baseline = cv2.getTextSize(text, font, scale, thickness)
    origin = (frame.shape[1] - text_w - 10, text_h + 10)
    cv2.putText(frame, text, origin, font, scale, text_color, thickness)

    return frame, crack_pixel_count


def car_calculation(frame, confidence_threshold=0.7):
    """Detect cars in ``frame``, draw them and return the car count.

    The model has a single vehicle class, so labels are not inspected.
    Every detection scoring at least ``confidence_threshold`` is counted and
    drawn in ``(0, 255, 0)``.

    Args:
        frame: Image array passed to ``car_detector``; annotations are drawn onto it.
        confidence_threshold: Minimum detection score to keep.

    Returns:
        ``(frame, car_count)``.
    """
    CAR_LABEL = 'car'
    bboxes, _labels, masks = car_detector(frame)
    valid_indices = [i for i, bbox in enumerate(bboxes) if bbox[4] >= confidence_threshold]

    color = (0, 255, 0)
    line_thickness = 2
    font_scale = 0.8

    for i in valid_indices:
        bbox = bboxes[i]
        [left, top, right, bottom], score = bbox[0:4].astype(int), bbox[4]

        cv2.rectangle(frame, (left, top), (right, bottom), color, thickness=line_thickness)
        label_text = CAR_LABEL + f" ({score:.2f})"
        cv2.putText(frame, label_text, (left, top - 10), cv2.FONT_HERSHEY_SIMPLEX, font_scale, color, line_thickness)

        if masks and masks[i].size:
            mask = masks[i]
            blue, green, red = cv2.split(frame)
            if mask.shape == frame.shape[:2]:
                mask_img = blue
            else:
                # The mask covers only the box; place it at the box's top-left corner.
                x0 = int(max(math.floor(bbox[0]) - 1, 0))
                y0 = int(max(math.floor(bbox[1]) - 1, 0))
                mask_img = blue[y0:y0 + mask.shape[0], x0:x0 + mask.shape[1]]
                # Near the right/bottom border the slice is smaller than the mask;
                # crop the mask so cv2.bitwise_or receives equally sized arrays.
                mask = mask[:mask_img.shape[0], :mask_img.shape[1]]
            cv2.bitwise_or(mask, mask_img, mask_img)  # writes into the blue-channel view
            frame = cv2.merge([blue, green, red])

    return frame, len(valid_indices)



def vest_detection(frame, confidence_threshold=0.3):
    """Detect reflective vests vs. other clothing and annotate ``frame``.

    Every detection scoring at least ``confidence_threshold`` is counted
    under its label and boxed. A right-aligned summary line is drawn at the
    top of the image.

    Args:
        frame: Image array passed to ``vest_detector``; boxes, labels and
            masks are drawn onto it.
        confidence_threshold: Minimum detection score to keep.

    Returns:
        ``(frame, vest_count, other_clothes_count)``.
    """
    VEST_LABELS = ['other_clothes', 'vest']
    # NOTE(review): the frame is not converted between RGB/BGR here, so these
    # tuples render as yellow/blue only if the frame is BGR (cyan/red if RGB).
    COLORS = {'vest': (0, 255, 255), 'other_clothes': (255, 0, 0)}
    font = cv2.FONT_HERSHEY_SIMPLEX
    line_thickness = 2
    font_scale = 0.8

    bboxes, labels, masks = vest_detector(frame)
    # Both classes are reported, so the only filter is the score threshold.
    valid_indices = [(i, VEST_LABELS[label]) for i, label in enumerate(labels)
                     if bboxes[i][4] >= confidence_threshold]

    vest_count = 0
    other_clothes_count = 0

    for i, label_name in valid_indices:
        bbox = bboxes[i]
        [left, top, right, bottom], score = bbox[0:4].astype(int), bbox[4]

        if label_name == 'vest':
            vest_count += 1
        else:
            other_clothes_count += 1
        color = COLORS[label_name]

        cv2.rectangle(frame, (left, top), (right, bottom), color, thickness=line_thickness)
        label_text = f"{label_name} ({score:.2f})"
        cv2.putText(frame, label_text, (left, top - 10), font, font_scale, color, line_thickness)

        if masks and masks[i].size:
            mask = masks[i]
            blue, green, red = cv2.split(frame)
            if mask.shape == frame.shape[:2]:
                mask_img = blue
            else:
                # The mask covers only the box; place it at the box's top-left corner.
                x0 = int(max(math.floor(bbox[0]) - 1, 0))
                y0 = int(max(math.floor(bbox[1]) - 1, 0))
                mask_img = blue[y0:y0 + mask.shape[0], x0:x0 + mask.shape[1]]
                # Near the right/bottom border the slice is smaller than the mask;
                # crop the mask so cv2.bitwise_or receives equally sized arrays.
                mask = mask[:mask_img.shape[0], :mask_img.shape[1]]
            cv2.bitwise_or(mask, mask_img, mask_img)  # writes into the blue-channel view
            frame = cv2.merge([blue, green, red])

    # The summary is far wider than a fixed 300 px offset allows; measure it and
    # right-align it so it stays inside the image.
    summary_text = f"Vests: {vest_count}, Other Clothes: {other_clothes_count}"
    (text_width, _), _ = cv2.getTextSize(summary_text, font, 1, 2)
    cv2.putText(frame, summary_text, (max(frame.shape[1] - text_width - 10, 0), 30),
                font, 1, (255, 255, 255), 2)

    return frame, vest_count, other_clothes_count

def detect_falls(frame, confidence_threshold=0.5):
    """Detect fallen people and outline them on the image.

    The input is converted RGB -> BGR before inference and back to RGB
    before returning. Only detections labelled ``'fall'`` with a score of at
    least ``confidence_threshold`` are drawn (in BGR red) and counted.
    ``'person'`` detections are ignored.

    Returns:
        ``(frame, fall_count)`` with ``frame`` in RGB order.
    """
    LABELS = ['fall', 'person']
    bgr = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)
    bboxes, labels, _masks = fall_detector(bgr)

    falls = 0
    for det, label_id in zip(bboxes, labels):
        [x1, y1, x2, y2], confidence = det[0:4].astype(int), det[4]
        if confidence < confidence_threshold:
            continue
        name = LABELS[label_id]
        if name != 'fall':
            continue
        cv2.rectangle(bgr, (x1, y1), (x2, y2), (0, 0, 255), 2)
        cv2.putText(bgr, f"{name}: {int(confidence*100)}%", (x1, y1 - 10),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 255), 1)
        falls += 1

    return cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB), falls

def leaf_disease_detection(frame, confidence_threshold=0.3):
    """Detect diseased regions on a leaf and annotate them.

    The input is converted RGB -> BGR for inference and drawing, then back
    to RGB. Each detection scoring at least ``confidence_threshold`` gets a
    box and a percentage label. When a non-empty mask is available for it,
    the mask is also OR-ed into the blue channel.

    Args:
        frame: RGB image array.
        confidence_threshold: Minimum detection score to keep.

    Returns:
        ``(frame, disease_count)`` with ``frame`` in RGB order.
    """
    frame = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)

    LABELS = ['disease']
    disease_count = 0
    bboxes, labels, masks = disease_object_detector(frame)

    for index, (bbox, label_id) in enumerate(zip(bboxes, labels)):
        [left, top, right, bottom], score = bbox[0:4].astype(int), bbox[4]
        if score < confidence_threshold:
            continue
        cv2.rectangle(frame, (left, top), (right, bottom), (0, 0, 255), 1)
        label_text = f"{LABELS[label_id]}: {int(score*100)}%"
        cv2.putText(frame, label_text, (left, top - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 255), 1)

        # Guard like the other detectors do: if ``masks`` is empty (or shorter
        # than ``bboxes``), indexing it would raise IndexError.
        if masks and index < len(masks) and masks[index].size:
            mask = masks[index]
            blue, green, red = cv2.split(frame)
            if mask.shape == frame.shape[:2]:
                mask_img = blue
            else:
                # The mask covers only the box; place it at the box's top-left corner.
                x0 = int(max(math.floor(bbox[0]) - 1, 0))
                y0 = int(max(math.floor(bbox[1]) - 1, 0))
                mask_img = blue[y0:y0 + mask.shape[0], x0:x0 + mask.shape[1]]
                # Near the right/bottom border the slice is smaller than the mask;
                # crop the mask so cv2.bitwise_or receives equally sized arrays.
                mask = mask[:mask_img.shape[0], :mask_img.shape[1]]
            cv2.bitwise_or(mask, mask_img, mask_img)  # writes into the blue-channel view
            frame = cv2.merge([blue, green, red])

        disease_count += 1

    # Convert back to RGB for display; return the annotated image and the count.
    frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
    return frame, disease_count

def detect_masks(frame, confidence_threshold=0.5):
    """Detect face-mask wearing status and annotate each face.

    The image is converted RGB -> BGR for inference and drawing, and back to
    RGB afterwards. Each detection with score >= ``confidence_threshold`` is
    counted under its label and boxed in that label's BGR colour: mask
    green, nomask red, unfit blue.

    Returns:
        ``(frame, mask_count, nomask_count, unfit_count)``.
    """
    LABELS = ['unfit', 'mask', 'nomask']
    BOX_COLORS = {'mask': (0, 255, 0), 'nomask': (0, 0, 255), 'unfit': (255, 0, 0)}

    bgr = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)
    tally = dict.fromkeys(LABELS, 0)

    bboxes, labels, _masks = mask_detector(bgr)
    for det, label_id in zip(bboxes, labels):
        [x1, y1, x2, y2], confidence = det[0:4].astype(int), det[4]
        if confidence < confidence_threshold:
            continue
        name = LABELS[label_id]
        tally[name] += 1
        color = BOX_COLORS[name]
        cv2.rectangle(bgr, (x1, y1), (x2, y2), color, 2)
        cv2.putText(bgr, f"{name}: {int(confidence*100)}%", (x1, y1 - 10),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 1)

    rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
    return rgb, tally['mask'], tally['nomask'], tally['unfit']

def helmet_detection(frame, confidence_threshold=0.3):
    """Count helmeted and bare heads and draw them onto ``frame``.

    Detections with score >= ``confidence_threshold`` are boxed with line
    thickness 1 and labelled at font scale 0.5 (the same for both classes).
    ``'helmet'`` is drawn in ``(0, 255, 0)`` and ``'head'`` in ``(255, 0, 0)``.

    Returns:
        ``(frame, helmet_count, head_count)``.
    """
    HEL_LABELS = ['head', 'helmet']
    PALETTE = {'helmet': (0, 255, 0), 'head': (255, 0, 0)}
    thickness, scale = 1, 0.5

    bboxes, labels, masks = helmet_detector(frame)
    kept = [i for i, bbox in enumerate(bboxes) if bbox[4] >= confidence_threshold]

    counts = {'helmet': 0, 'head': 0}
    for i in kept:
        box = bboxes[i]
        name = HEL_LABELS[labels[i]]
        [x1, y1, x2, y2], confidence = box[0:4].astype(int), box[4]
        color = PALETTE[name]
        cv2.rectangle(frame, (x1, y1), (x2, y2), color, thickness=thickness)
        cv2.putText(frame, name + f" ({confidence:.2f})", (x1, y1 - 10),
                    cv2.FONT_HERSHEY_SIMPLEX, scale, color, thickness)
        counts[name] += 1

    return frame, counts['helmet'], counts['head']



def human_calculation(frame, confidence_threshold=0.3):
    """
    Count people in ``frame`` using the helmet detector.

    Every 'head' or 'helmet' detection whose score is strictly greater than
    ``confidence_threshold`` counts as one person. It is boxed and labelled
    ``human (score)`` in ``(0, 0, 255)``.

    Returns:
        ``(frame, human_count)``.
    """
    HEL_LABELS = ['head', 'helmet']
    bboxes, labels, masks = helmet_detector(frame)

    people = 0
    for idx, box in enumerate(bboxes):
        name = HEL_LABELS[labels[idx]]
        confidence = box[4]
        if name not in ['head', 'helmet'] or not confidence > confidence_threshold:
            continue
        people += 1
        x1, y1, x2, y2 = box[0:4].astype(int)
        cv2.rectangle(frame, (x1, y1), (x2, y2), (0, 0, 255), thickness=1)
        cv2.putText(frame, f"human ({confidence:.2f})", (x1, y1 - 10),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 255), 1)

    return frame, people


def red_tree(img):
    """Segment mangrove (红树林) pixels and overlay them on the image.

    The RGB input is converted to BGR for ``red_tree_segmentor``, and class 1
    pixels are counted. The class map is painted (class 0 white, class 1
    red) and blended 50/50 with the image.

    Returns:
        ``(img_rgb, red_tree_pixel_count)``.
    """
    bgr = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
    seg = red_tree_segmentor(bgr)
    pixel_count = np.sum(seg == 1)

    overlay_rgb = np.zeros((seg.shape[0], seg.shape[1], 3), dtype=np.uint8)
    for class_id, rgb in enumerate([(255, 255, 255), (255, 0, 0)]):
        overlay_rgb[seg == class_id, :] = rgb

    # Blend in BGR space (palette channel-reversed to match), then return RGB.
    blended = (bgr * 0.5 + overlay_rgb[..., ::-1] * 0.5).astype(np.uint8)
    return cv2.cvtColor(blended, cv2.COLOR_BGR2RGB), pixel_count


def leaf_disease(img):
    """Segment leaf and disease pixels and overlay them on the image.

    The RGB input is converted to BGR for ``leaf_disease_segmentor``. Class 1
    (leaf) and class 2 (disease) pixels are counted. The class map is painted
    white / green / red for classes 0 / 1 / 2 and blended 50/50 with the
    image.

    Returns:
        ``(img_rgb, leaf_pixel_count, disease_pixel_count)``.
    """
    bgr = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
    seg = leaf_disease_segmentor(bgr)

    leaf_pixels = np.sum(seg == 1)
    disease_pixels = np.sum(seg == 2)

    overlay_rgb = np.zeros((seg.shape[0], seg.shape[1], 3), dtype=np.uint8)
    for class_id, rgb in enumerate([(255, 255, 255), (0, 255, 0), (255, 0, 0)]):
        overlay_rgb[seg == class_id, :] = rgb

    # Blend in BGR space (palette channel-reversed to match), then return RGB.
    blended = (bgr * 0.5 + overlay_rgb[..., ::-1] * 0.5).astype(np.uint8)
    return cv2.cvtColor(blended, cv2.COLOR_BGR2RGB), leaf_pixels, disease_pixels

def single_label_disease(img):
    """Segment disease pixels with the single-class model and overlay them.

    The RGB input is converted to BGR for ``single_label_disease_segmentor``,
    and class 1 (disease) pixels are counted. The class map is painted
    (class 0 white, class 1 red) and blended 50/50 with the image.

    Returns:
        ``(img_rgb, disease_pixel_count)``.
    """
    bgr = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
    seg = single_label_disease_segmentor(bgr)
    disease_pixels = np.sum(seg == 1)

    overlay_rgb = np.zeros((seg.shape[0], seg.shape[1], 3), dtype=np.uint8)
    for class_id, rgb in enumerate([(255, 255, 255), (255, 0, 0)]):
        overlay_rgb[seg == class_id, :] = rgb

    # Blend in BGR space (palette channel-reversed to match), then return RGB.
    blended = (bgr * 0.5 + overlay_rgb[..., ::-1] * 0.5).astype(np.uint8)
    return cv2.cvtColor(blended, cv2.COLOR_BGR2RGB), disease_pixels



def get_image_examples(image_dir="/mnt/e/AI/mmdeploy/gradio/photo"):
    """Build the Gradio example rows from the numbered demo images in *image_dir*.

    Images ending in ``.png``/``.jpg``/``.jpeg`` are sorted by the integer
    formed from the digits in their filename. Names with no digits are
    sorted after numbered ones, alphabetically. The i-th image is paired
    with the i-th model name and confidence threshold below, three images
    per model.

    Args:
        image_dir: Directory holding the example images.

    Returns:
        A list of ``[model_choice, image_path, confidence_threshold]`` rows.
        It has at most as many rows as there are configured example slots
        (36); extra images are ignored instead of raising IndexError.
    """
    image_files = [f for f in os.listdir(image_dir) if f.endswith(('.png', '.jpg', '.jpeg'))]

    def _numeric_key(name):
        digits = ''.join(filter(str.isdigit, name))
        # int('') would raise ValueError; put digit-less names after numbered ones.
        return (0, int(digits)) if digits else (1, name)

    image_files.sort(key=_numeric_key)

    example_choices = [
        '红树林识别', '红树林识别', '红树林识别',
        '安全帽检测', '安全帽检测', '安全帽检测',
        '人数统计', '人数统计', '人数统计',
        '反光衣检测', '反光衣检测', '反光衣检测',
        '道路车辆统计', '道路车辆统计', '道路车辆统计',
        '裂缝识别', '裂缝识别', '裂缝识别',
        '吸烟检测', '吸烟检测', '吸烟检测',
        '树叶病害识别1', '树叶病害识别1', '树叶病害识别1',
        '树叶病害识别2', '树叶病害识别2', '树叶病害识别2',
        '树叶病害检测3', '树叶病害检测3', '树叶病害检测3',
        '摔倒检测', '摔倒检测', '摔倒检测',
        '口罩佩戴检测', '口罩佩戴检测', '口罩佩戴检测',
    ]

    confidence_thresholds = [
        0, 0, 0,
        0.7, 0.8, 0.6,
        0.3, 0.8, 0.5,
        0.8, 0.7, 0.8,
        0.5, 0.2, 0.7,
        0, 0, 0,
        0.6, 0.9, 0.5,
        0, 0, 0,
        0, 0, 0,
        0.4, 0.4, 0.5,
        0.9, 0.9, 0.5,
        0.8, 0.6, 0.5,
    ]

    # zip stops at the shortest sequence, so surplus images cannot overrun the tables.
    return [
        [choice, f"{image_dir}/{image_file}", threshold]
        for image_file, choice, threshold in zip(image_files, example_choices, confidence_thresholds)
    ]

  


# Model names offered for selection. process_image() falls back to
# "安全帽检测" for any value not in this list.
model_choices = [
    '红树林识别',
    '裂缝识别',
    '树叶病害识别1',
    '树叶病害识别2',
    '树叶病害检测3',
    '安全帽检测',
    '反光衣检测',
    '吸烟检测',
    '摔倒检测',
    '口罩佩戴检测',
    '人数统计',
    '道路车辆统计',
]


def create_blank_image(width=640, height=480, color=(255, 255, 255)):
    """Return a ``height x width x 3`` uint8 image filled with ``color`` (white by default)."""
    return np.full((height, width, 3), color, dtype=np.uint8)

def process_image(model_choice, image_array=None, confidence_threshold=0.3):
    """Run the selected model on an uploaded image and describe the result.

    Args:
        model_choice: One of ``model_choices``. Any other value falls back
            to "安全帽检测" (helmet detection).
        image_array: The uploaded image, or ``None`` when nothing was uploaded.
        confidence_threshold: Forwarded to the detectors and to the crack
            classifier. The segmentation models do not take it.

    Returns:
        ``(image, message)``: the annotated image (a blank canvas when no
        image was given) and a Chinese summary string.
    """
    if image_array is None:
        # Nothing uploaded: show a blank canvas and ask the user for input.
        return create_blank_image(), '当前未有图片输入,请上传图片后再次点击运行。'

    choice = model_choice if model_choice in model_choices else "安全帽检测"
    thr = confidence_threshold

    if choice == "红树林识别":
        annotated, pixels = red_tree(image_array)  # semantic segmentation
        return annotated, f"红树林的像素点有 {pixels} 个。"

    if choice == "安全帽检测":
        annotated, helmets, heads = helmet_detection(image_array, thr)
        return annotated, f"佩戴安全帽的人数为:{helmets},未佩戴安全帽的人数为:{heads}。"

    if choice == "人数统计":
        annotated, people = human_calculation(image_array, thr)
        return annotated, f"该图片人员总人数为: {people}。"

    if choice == "反光衣检测":
        annotated, vests, others = vest_detection(image_array, thr)
        return annotated, f"该图片中总有 {vests} 人配备了反光衣,{others} 人没有配备反光衣。"

    if choice == "道路车辆统计":
        annotated, cars = car_calculation(image_array, thr)
        return annotated, f"该道路上目前共有 {cars} 台车辆。"

    if choice == "裂缝识别":
        annotated, verdict, crack_pixels = crack_classification(image_array, thr)
        if crack_result_positive := (verdict == "Positive"):
            return annotated, f"该图片内存在裂缝,裂缝的像素点有 {crack_pixels} 个。"
        return annotated, "该图片不存在裂缝。"

    if choice == "树叶病害检测3":
        annotated, found = leaf_disease_detection(image_array, thr)
        message = f"共检测到 {found} 处病害。" if found > 0 else "并未检测到病害。"
        return annotated, message

    if choice == "吸烟检测":
        annotated, smokers = smoker_detector(image_array, thr)
        return annotated, f"当前图片有 {smokers} 人在吸烟。"

    if choice == "树叶病害识别1":
        annotated, _leaf_pixels, disease_pixels = leaf_disease(image_array)  # semantic segmentation
        if disease_pixels == 0:
            return annotated, "该树叶并未出现病害。"
        return annotated, f"病害的像素点有 {disease_pixels} 个。"

    if choice == "树叶病害识别2":
        annotated, disease_pixels = single_label_disease(image_array)  # semantic segmentation
        return annotated, f"病害的像素点有 {disease_pixels} 个。"

    if choice == "摔倒检测":
        annotated, falls = detect_falls(image_array, thr)
        return annotated, f"图像中摔倒的人数为 {falls} 人。"

    if choice == "口罩佩戴检测":
        annotated, masked, unmasked, unfit = detect_masks(image_array, thr)
        return annotated, f"当前佩戴口罩的人数为 {masked},未正确佩戴口罩的人数为 {unfit},没有佩戴口罩的人数为 {unmasked}。"

def _create_blank_video(filename, duration=5, fps=30, width=640, height=480, color=(255, 255, 255)):
    """Write a solid-colour placeholder video to *filename*.

    Used when no input video was supplied, so the output component still has a
    playable file to show.
    """
    fourcc = cv2.VideoWriter_fourcc(*'mp4v')  # mp4v codec, same as the processed output
    out = cv2.VideoWriter(filename, fourcc, fps, (width, height))
    try:
        blank_image = np.zeros((height, width, 3), np.uint8)
        blank_image[:, :] = color
        for _ in range(int(fps * duration)):
            out.write(blank_image)
    finally:
        out.release()


def _draw_top_right_lines(img, lines, color=(255, 255, 255), font_scale=0.7, thickness=2):
    """Draw each string in *lines* right-aligned in the top-right corner of *img*.

    Lines are stacked 30 px apart, starting at y=30. The x position is derived
    from the measured text width, so long labels are not clipped at the right
    edge. It is clamped to 0 for frames narrower than the text. Returns *img*,
    which is modified in place.
    """
    font = cv2.FONT_HERSHEY_SIMPLEX
    frame_width = img.shape[1]
    for row, text in enumerate(lines, start=1):
        (text_width, _), _ = cv2.getTextSize(text, font, font_scale, thickness)
        x = max(frame_width - text_width - 10, 0)
        cv2.putText(img, text, (x, 30 * row), font, font_scale, color, thickness, cv2.LINE_AA)
    return img


def _process_video_frame(model_choice, frame, confidence_threshold):
    """Run the model named by *model_choice* on one frame and overlay its counts.

    *frame* is a frame as returned by ``cv2.VideoCapture.read``. Returns the
    annotated frame produced by the model. For an unrecognised *model_choice*
    the input frame is returned unchanged, so the caller always has a frame to
    write.
    """
    white, red = (255, 255, 255), (0, 0, 255)

    if model_choice == "红树林识别":
        processed, pixel_count = red_tree(frame)
        return _draw_top_right_lines(processed, [f"Number of pixels in this frame: {pixel_count}"])

    if model_choice == "安全帽检测":
        processed, helmet_count, head_count = helmet_detection(frame, confidence_threshold)
        return _draw_top_right_lines(processed, [
            f"Number of people wearing helmets: {helmet_count}",
            f"Number of people without helmets: {head_count - helmet_count}",
        ])

    if model_choice == "人数统计":
        processed, human_count = human_calculation(frame, confidence_threshold)
        return _draw_top_right_lines(processed, [f"Current number of people: {human_count}"])

    if model_choice == "反光衣检测":
        # The original passed an undefined `image_array` here; the current frame is what must be analysed.
        processed, vest_count, other_clothes_count = vest_detection(frame, confidence_threshold)
        return _draw_top_right_lines(processed, [
            f"Number of reflective vests: {vest_count}",
            f"Number without reflective vests: {other_clothes_count}",
        ])

    if model_choice == "道路车辆统计":
        processed, car_count = car_calculation(frame, confidence_threshold)
        return _draw_top_right_lines(processed, [f"Number of vehicles: {car_count}"])

    if model_choice == "裂缝识别":
        # NOTE(review): the image path uses crack_classification(); confirm that
        # crack_detection() is defined in this module and returns (frame, pixel_count).
        processed, _ = crack_detection(frame)
        return processed

    if model_choice == "树叶病害检测3":
        processed, disease_count = leaf_disease_detection(frame, confidence_threshold)
        return _draw_top_right_lines(processed, [f"Leaf Disease Count: {disease_count}"],
                                     color=red, font_scale=0.8)

    if model_choice == "吸烟检测":
        processed, smoker_count = smoker_detector(frame, confidence_threshold)
        # Hershey fonts cannot render CJK characters (they appear as '?'), so the overlay is English.
        return _draw_top_right_lines(processed, [f"Number of smokers: {smoker_count}"],
                                     color=red, font_scale=0.6)

    if model_choice == "树叶病害识别1":
        processed, _, disease_pixel_count = leaf_disease(frame)
        return _draw_top_right_lines(processed, [f"Current disease pixel count on the leaf: {disease_pixel_count}"],
                                     color=red, font_scale=0.6)

    if model_choice == "树叶病害识别2":
        processed, disease_pixel_count = single_label_disease(frame)
        return _draw_top_right_lines(processed, [f"Current disease pixel count on the leaf: {disease_pixel_count}"],
                                     color=red, font_scale=0.6)

    if model_choice == "口罩佩戴检测":
        processed, mask_count, nomask_count, unfit_count = detect_masks(frame, confidence_threshold)
        return _draw_top_right_lines(processed, [
            f"Number wearing masks: {mask_count}",
            f"Number not wearing masks: {nomask_count}",
            f"Number wearing masks incorrectly: {unfit_count}",
        ], color=white)

    if model_choice == "摔倒检测":
        processed, fall_count = detect_falls(frame, confidence_threshold)
        return _draw_top_right_lines(processed, [f"Number of people who fell: {fall_count}"])

    # Unknown or empty selection: pass the frame through instead of failing on an unbound name.
    return frame


def process_video(model_choice, video=None, confidence_threshold=0.3):
    """Run the selected model on every frame of *video* and write an annotated MP4.

    Args:
        model_choice: model name selected in the UI dropdown.
        video: path of the uploaded video as passed by the Gradio Video
            component, or None when nothing was uploaded.
        confidence_threshold: forwarded to the detection models that accept it.

    Returns:
        ``(video_output_path, message)``: path of the written MP4 and a status
        message for the UI. With no input, a blank placeholder video is written
        and returned instead.

    Raises:
        ValueError: if the input video cannot be opened.
    """
    if video is None:
        video_output_path = '/mnt/e/AI/mmdeploy/gradio/video/none.mp4'
        _create_blank_video(video_output_path)
        output_text2 = '当前未有视频输入,请上传视频后再次点击运行。'
        return video_output_path, output_text2

    video_output_path = '/mnt/e/AI/mmdeploy/gradio/video/output_video.mp4'
    cap = cv2.VideoCapture(video)
    if not cap.isOpened():
        cap.release()
        raise ValueError("无法打开视频文件")

    writer = None
    try:
        fps = cap.get(cv2.CAP_PROP_FPS)
        # Some containers report 0 or NaN; VideoWriter needs a positive rate.
        if not (fps > 0):
            fps = 30.0
        width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        fourcc = cv2.VideoWriter_fourcc(*'mp4v')
        fps_label = "FPS: {}".format(int(fps))

        # Read frames sequentially: this visits every frame exactly once without
        # per-frame seeking, and does not depend on CAP_PROP_FRAME_COUNT being accurate.
        while True:
            ret, frame = cap.read()
            if not ret:
                break
            if width <= 0 or height <= 0:
                height, width = frame.shape[:2]

            # Stamp the source frame rate in the top-left corner (drawn before inference).
            cv2.putText(frame, fps_label, (10, 30),
                        cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 0, 255), 2, cv2.LINE_AA)

            processed_frame = _process_video_frame(model_choice, frame, confidence_threshold)

            # VideoWriter silently drops frames whose size differs from the one it was opened with.
            if processed_frame.shape[1] != width or processed_frame.shape[0] != height:
                processed_frame = cv2.resize(processed_frame, (width, height))

            if writer is None:
                writer = cv2.VideoWriter(video_output_path, fourcc, fps, (width, height))
            # Stream straight to disk instead of buffering every frame in memory.
            writer.write(processed_frame)

        if writer is None:
            # No frame was decoded: still create the output file, as before.
            writer = cv2.VideoWriter(video_output_path, fourcc, fps, (width, height))
    finally:
        if writer is not None:
            writer.release()
        cap.release()

    output_text2 = '请点击蓝色按钮下载视频。'
    return video_output_path, output_text2

# Gradio UI: one tab for single-image inference (process_image) and one for
# video inference (process_video). Each tab has a model dropdown and a
# confidence-threshold slider.
with gr.Blocks() as demo:
    gr.Markdown("# <center>启云科技AI识别示例样板V1.12</center>")
    gr.Markdown("请上传图像或视频进行预测")
    with gr.Tab("AI图像处理"):
        with gr.Row():
            image_input2 = gr.Image(label="上传图像", type="numpy")
            with gr.Column():
                image_input1 = gr.Dropdown(choices=model_choices, label="模型选择")
                # Without `value`, gr.Slider starts at its minimum (0), which disables
                # confidence filtering. Default to 0.3 to match process_video's default.
                image_input3 = gr.Slider(minimum=0, maximum=1, step=0.1, value=0.3, label="置信度阈值")
        with gr.Row():
            image_output1 = gr.Image(label="处理后的图像", type="numpy")
            with gr.Column():
                image_output2 = gr.Textbox(label="图像输出信息")
        image_button = gr.Button('请点击按钮进行图像预测')
        # Example rows are (model, image, threshold), the same order as the click handler's inputs.
        gr.Examples(get_image_examples(), inputs=[image_input1, image_input2, image_input3],
                    outputs=[image_output1, image_output2], fn=process_image,
                    examples_per_page=6, cache_examples=True)
    with gr.Tab("AI视频处理"):
        with gr.Row():
            video_input2 = gr.Video(label='上传视频', format='mp4', interactive=True)
            with gr.Column():
                video_input1 = gr.Dropdown(choices=model_choices, label="模型选择")
                # Same default as the image tab (see comment there).
                video_input3 = gr.Slider(minimum=0, maximum=1,
                                         step=0.1, value=0.3, label="置信度阈值")
        with gr.Row():
            video_output1 = gr.File(label='处理后的视频', type='file')
            with gr.Column():
                video_output2 = gr.Textbox(label='视频输出信息')
        video_button = gr.Button('请点击按钮进行视频预测')
    with gr.Accordion("平台简介"):
        gr.Markdown("红树林识别模型、裂缝识别模型、树叶病害识别模型、安全帽检测模型、反光衣检测模型、吸烟检测模型、口罩佩戴检测、摔倒检测、人数统计模型及道路车辆统计模型展示平台。")
    image_button.click(process_image, inputs=[image_input1, image_input2, image_input3], outputs=[image_output1, image_output2])
    video_button.click(process_video, inputs=[video_input1, video_input2, video_input3], outputs=[video_output1, video_output2])

# share=True additionally requests a public Gradio share link.
demo.launch(share=True)