xfys commited on
Commit
32a9efd
1 Parent(s): 8c87f17

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +49 -40
app.py CHANGED
@@ -9,7 +9,14 @@ from yolov5 import detect
9
  from PIL import Image
10
 
11
  # 目标检测
12
- def Detect(image):
 
 
 
 
 
 
 
13
  # 创建临时文件夹
14
  temp_path = tempfile.TemporaryDirectory(dir="./")
15
  temp_dir = temp_path.name
@@ -21,7 +28,7 @@ def Detect(image):
21
  # 结果图片的存储目录
22
  temp_result_path = os.path.join(temp_dir, "tempresult")
23
  # 对临时图片进行检测
24
- detect.run(source=temp_image_path, data="test_image/FLIR.yaml", weights="weights/best.pt", project=f'./{temp_dir}',name = 'tempresult', hide_conf=False, conf_thres=0.35)
25
  # 结果图片的路径
26
  temp_result_path = os.path.join(temp_result_path, os.listdir(temp_result_path)[0])
27
  # 读取结果图片
@@ -32,25 +39,34 @@ def Detect(image):
32
 
33
  # 候选图片
34
  example_image= [
35
- "./test_image/video-2SReBn5LtAkL5HMj2-frame-005072-MA7NCLQGoqq9aHaiL.jpg",
36
- "./test_image/video-2rsjnZFyGQGeynfbv-frame-003708-6fPQbB7jtibwaYAE7.jpg",
37
- "./test_image/video-2SReBn5LtAkL5HMj2-frame-000317-HTgPBFgZyPdwQnNvE.jpg",
38
- "./test_image/video-jNQtRj6NGycZDEXpe-frame-002515-J3YntG8ntvZheKK3P.jpg",
39
- "./test_image/video-kDDWXrnLSoSdHCZ7S-frame-003063-eaKjPvPskDPjenZ8S.jpg",
40
- "./test_image/video-r68Yr9RPWEp5fW2ZF-frame-000333-X6K5iopqbmjKEsSqN.jpg"
 
 
41
  ]
42
 
43
  # 目标追踪
44
- def Track(video, tracking_method):
45
  # 存储临时视频的文件夹
46
  temp_dir = "./temp"
47
  # 先清空temp文件夹
48
  shutil.rmtree("./temp")
49
  os.mkdir("./temp")
 
 
 
 
 
 
 
50
  # 获取视频的名字
51
  video_name = os.path.basename(video)
52
  # 对视频进行检测
53
- track.run(source=video, yolo_weights=Path("weights/best2.pt"),reid_weights=Path("weights/osnet_x0_25_msmt17.pt") , project=Path(f'./{temp_dir}'),name = 'tempresult', tracking_method=tracking_method)
54
  # 结果视频的路径
55
  temp_result_path = os.path.join(f'./{temp_dir}', "tempresult", video_name)
56
  # 返回结果视频的路径
@@ -58,24 +74,35 @@ def Track(video, tracking_method):
58
 
59
  # 候选视频
60
  example_video= [
61
- ["./video/5.mp4", "bytetrack"],
62
- ["./video/bicyclecity.mp4", "strongsort"],
63
- ["./video/9.mp4", "bytetrack"],
64
- ["./video/8.mp4", "strongsort"],
65
- ["./video/4.mp4", "bytetrack"],
66
- ["./video/car.mp4", "strongsort"],
 
 
67
  ]
68
 
69
  iface_Image = gr.Interface(fn=Detect,
70
- inputs=gr.Image(label="上传一张红外图像,仅支持jpg格式"),
 
 
 
 
71
  outputs=gr.Image(label="检测结果"),
72
- examples=example_image)
 
73
 
74
  iface_video = gr.Interface(fn=Track,
75
- inputs=[gr.Video(label="上传段红外视频,仅支持mp4格式"),
76
- gr.Radio(["bytetrack", "strongsort"],
77
- label="track methond",
78
- info="选择追踪器",
 
 
 
 
79
  value="bytetrack")],
80
  outputs=gr.Video(label="追踪结果"),
81
  examples=example_video
@@ -84,22 +111,4 @@ iface_video = gr.Interface(fn=Track,
84
  demo = gr.TabbedInterface([iface_video, iface_Image], tab_names=["目标追踪", "目标检测"], title="红外目标检测追踪")
85
 
86
  demo.launch()
87
- #iface_Image.launch()
88
-
89
-
90
-
91
-
92
-
93
-
94
-
95
-
96
-
97
-
98
-
99
-
100
-
101
-
102
-
103
-
104
-
105
 
 
9
  from PIL import Image
10
 
11
  # 目标检测
12
+ def Detect(image, image_type):
13
+ if image_type == "红外图像":
14
+ pt = "best.pt"
15
+ cnf = "FLIR.yaml"
16
+ else:
17
+ pt = "yolov5s.pt"
18
+ cnf = "coco128.yaml"
19
+
20
  # 创建临时文件夹
21
  temp_path = tempfile.TemporaryDirectory(dir="./")
22
  temp_dir = temp_path.name
 
28
  # 结果图片的存储目录
29
  temp_result_path = os.path.join(temp_dir, "tempresult")
30
  # 对临时图片进行检测
31
+ detect.run(source=temp_image_path, data=f"test_image/{cnf}", weights=f"weights/{pt}", project=f'./{temp_dir}',name = 'tempresult', hide_conf=False, conf_thres=0.35)
32
  # 结果图片的路径
33
  temp_result_path = os.path.join(temp_result_path, os.listdir(temp_result_path)[0])
34
  # 读取结果图片
 
39
 
40
  # 候选图片
41
  example_image= [
42
+ ["./test_image/1.jpg", "红外图像"],
43
+ ["./test_image/2.jpg", "红外图像"],
44
+ ["./test_image/3.jpg", "红外图像"],
45
+ ["./test_image/8.jpg", "红外图像"],
46
+ ["./test_image/5.jpg", "红外图像"],
47
+ # ["./test_image/6.jpg]", "红外图像"],
48
+ ["./test_image/4.jpg", "可见光图像"],
49
+ ["./test_image/7.jpg", "可见光图像"]
50
  ]
51
 
52
  # 目标追踪
53
+ def Track(video, video_type, tracking_method):
54
  # 存储临时视频的文件夹
55
  temp_dir = "./temp"
56
  # 先清空temp文件夹
57
  shutil.rmtree("./temp")
58
  os.mkdir("./temp")
59
+
60
+ # 获取视频的形式
61
+ if video_type == "红外视频":
62
+ pt = "best2.pt"
63
+ else:
64
+ pt = "yolov5s.pt"
65
+
66
  # 获取视频的名字
67
  video_name = os.path.basename(video)
68
  # 对视频进行检测
69
+ track.run(source=video, yolo_weights=Path(f"weights/{pt}"),reid_weights=Path("weights/osnet_x0_25_msmt17.pt") , project=Path(f'./{temp_dir}'), name = 'tempresult', tracking_method=tracking_method)
70
  # 结果视频的路径
71
  temp_result_path = os.path.join(f'./{temp_dir}', "tempresult", video_name)
72
  # 返回结果视频的路径
 
74
 
75
  # 候选视频
76
  example_video= [
77
+ ["./video/5.mp4", "红外视频", "bytetrack"],
78
+ ["./video/bicyclecity.mp4","红外视频", "strongsort"],
79
+ ["./video/9.mp4", "红外视频", "bytetrack"],
80
+ ["./video/8.mp4", "红外视频", "strongsort"],
81
+ ["./video/4.mp4", "红外视频", "bytetrack"],
82
+ ["./video/car.mp4", "红外视频", "strongsort"],
83
+ ["./video/caixukun.mp4", "可见光视频", "bytetrack"],
84
+ ["./video/palace.mp4", "可见光视频", "strongsort"],
85
  ]
86
 
87
  iface_Image = gr.Interface(fn=Detect,
88
+ inputs=[gr.Image(label="上传一张图像(jpg格式)"),
89
+ gr.Radio(["红外图像", "可见光图像"],
90
+ label="image type",
91
+ info="选择图片的形式",
92
+ value="红外图像")],
93
  outputs=gr.Image(label="检测结果"),
94
+ examples=example_image
95
+ )
96
 
97
  iface_video = gr.Interface(fn=Track,
98
+ inputs=[gr.Video(label="上传一段视频(mp4格式)"),
99
+ gr.Radio(["红外视频", "可见光视频"],
100
+ label="video type",
101
+ info="选择视频的形式",
102
+ value="bytetrack"),
103
+ gr.Radio(["bytetrack", "strongsort"],
104
+ label="track methond",
105
+ info="建议使用bytetrack, strongsort在cpu上运行很慢",
106
  value="bytetrack")],
107
  outputs=gr.Video(label="追踪结果"),
108
  examples=example_video
 
111
  demo = gr.TabbedInterface([iface_video, iface_Image], tab_names=["目标追踪", "目标检测"], title="红外目标检测追踪")
112
 
113
  demo.launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
114