qsitj committed on
Commit
f93a294
·
verified ·
1 Parent(s): 4f82037

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +127 -45
app.py CHANGED
@@ -1,28 +1,44 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
 
2
  import torch
3
  from transformers import pipeline
4
 
5
  from PIL import Image
6
 
7
- import matplotlib.pyplot as plt
8
  import matplotlib.patches as patches
9
 
10
- from random import choice
11
- import io
12
 
13
  detector50 = pipeline(model="facebook/detr-resnet-50")
14
 
15
  detector101 = pipeline(model="facebook/detr-resnet-101")
16
 
17
 
18
- import gradio as gr
19
 
20
  COLORS = ["#ff7f7f", "#ff7fbf", "#ff7fff", "#bf7fff",
21
  "#7f7fff", "#7fbfff", "#7fffff", "#7fffbf",
22
  "#7fff7f", "#bfff7f", "#ffff7f", "#ffbf7f"]
23
 
24
  fdic = {
25
- "family" : "Impact",
26
  "style" : "italic",
27
  "size" : 15,
28
  "color" : "yellow",
@@ -30,12 +46,33 @@ fdic = {
30
  }
31
 
32
 
33
- def get_figure(in_pil_img, in_results):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
34
  plt.figure(figsize=(16, 10))
35
  plt.imshow(in_pil_img)
36
- #pyplot.gcf()
37
  ax = plt.gca()
38
 
 
 
39
  for prediction in in_results:
40
  selected_color = choice(COLORS)
41
 
@@ -50,15 +87,8 @@ def get_figure(in_pil_img, in_results):
50
  return plt.gcf()
51
 
52
 
53
- def infer(model, in_pil_img):
54
-
55
- results = None
56
- if model == "detr-resnet-101":
57
- results = detector101(in_pil_img)
58
- else:
59
- results = detector50(in_pil_img)
60
-
61
- figure = get_figure(in_pil_img, results)
62
 
63
  buf = io.BytesIO()
64
  figure.savefig(buf, bbox_inches='tight')
@@ -68,39 +98,91 @@ def infer(model, in_pil_img):
68
  return output_pil_img
69
 
70
 
71
- with gr.Blocks(title="DETR Object Detection - ClassCat",
72
- css=".gradio-container {background:lightyellow;}"
73
- ) as demo:
74
- #sample_index = gr.State([])
75
-
76
- gr.HTML("""<div style="font-family:'Times New Roman', 'Serif'; font-size:16pt; font-weight:bold; text-align:center; color:royalblue;">DETR Object Detection</div>""")
77
-
78
- gr.HTML("""<h4 style="color:navy;">1. Select a model.</h4>""")
79
-
80
- model = gr.Radio(["detr-resnet-50", "detr-resnet-101"], value="detr-resnet-50", label="Model name")
81
 
82
- gr.HTML("""<br/>""")
83
- gr.HTML("""<h4 style="color:navy;">2-a. Select an example by clicking a thumbnail below.</h4>""")
84
- gr.HTML("""<h4 style="color:navy;">2-b. Or upload an image by clicking on the canvas.</h4>""")
85
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
86
  with gr.Row():
87
- input_image = gr.Image(label="Input image", type="pil")
88
- output_image = gr.Image(label="Output image with predicted instances", type="pil")
89
-
90
- gr.Examples(['samples/cats.jpg', 'samples/detectron2.png', 'samples/cat.jpg', 'samples/hotdog.jpg'], inputs=input_image)
91
-
92
- gr.HTML("""<br/>""")
93
- gr.HTML("""<h4 style="color:navy;">3. Then, click "Infer" button to predict object instances. It will take about 10 seconds (on cpu)</h4>""")
94
-
95
- send_btn = gr.Button("Infer")
96
- send_btn.click(fn=infer, inputs=[model, input_image], outputs=[output_image])
97
-
98
- gr.HTML("""<br/>""")
99
- gr.HTML("""<h4 style="color:navy;">Reference</h4>""")
100
- gr.HTML("""<ul>""")
101
- gr.HTML("""<li><a href="https://colab.research.google.com/github/facebookresearch/detr/blob/colab/notebooks/detr_attention.ipynb" target="_blank">Hands-on tutorial for DETR</a>""")
102
- gr.HTML("""</ul>""")
103
 
 
104
 
105
  #demo.queue()
106
  demo.launch(debug=True)
 
1
+ import os
2
+
3
+ from gradio_webrtc import WebRTC
4
+ import requests
5
+ from PIL import Image
6
+
7
+ import matplotlib.pyplot as plt
8
+
9
+ from random import choice
10
+ import io
11
+
12
+ import gradio as gr
13
+
14
+ import cv2
15
+ import numpy as np
16
+
17
+ from io import BytesIO
18
+ import random
19
+ import tempfile
20
+ from pathlib import Path
21
 
22
  import torch
23
  from transformers import pipeline
24
 
25
  from PIL import Image
26
 
 
27
  import matplotlib.patches as patches
28
 
 
 
29
 
30
  detector50 = pipeline(model="facebook/detr-resnet-50")
31
 
32
  detector101 = pipeline(model="facebook/detr-resnet-101")
33
 
34
 
 
35
 
36
  COLORS = ["#ff7f7f", "#ff7fbf", "#ff7fff", "#bf7fff",
37
  "#7f7fff", "#7fbfff", "#7fffff", "#7fffbf",
38
  "#7fff7f", "#bfff7f", "#ffff7f", "#ffbf7f"]
39
 
40
  fdic = {
41
+ # "family" : "Impact",
42
  "style" : "italic",
43
  "size" : 15,
44
  "color" : "yellow",
 
46
  }
47
 
48
 
49
def infer(model, in_pil_img):
    """Run DETR object detection on a PIL image.

    Args:
        model: Model name. "detr-resnet-101" selects `detector101`;
            any other value selects `detector50`.
        in_pil_img: Image handed unchanged to the selected pipeline.

    Returns:
        Whatever the selected detection pipeline returns for the image.

    NOTE(review): a later `def infer(in_pil_img)` in this file rebinds the
    name `infer`, so at call time this two-argument version is shadowed.
    """
    detector = detector101 if model == "detr-resnet-101" else detector50
    return detector(in_pil_img)
58
+
59
+
60
+ #######################################
61
+
62
+
63
def query_data(model, in_pil_img: Image.Image = None):
    """Return the detection results for an image.

    Supports both call shapes used in this file:

    * ``query_data(model, in_pil_img)`` - explicit model name.
    * ``query_data(in_pil_img)`` - image only (this is how `get_figure`
      calls it); the default "detr-resnet-50" model is used.

    The detectors are invoked directly instead of going through `infer`,
    because `infer` is redefined further down as a one-argument function
    that itself calls `get_figure`; delegating to it would fail with a
    TypeError (or recurse).

    Args:
        model: Model name ("detr-resnet-101" or anything else for the
            ResNet-50 detector), or the image itself when called with a
            single argument.
        in_pil_img: The PIL image to analyse, or None for the
            single-argument call shape.

    Returns:
        The selected pipeline's output for the image.
    """
    if in_pil_img is None:
        # Single-argument call: the only positional value is the image.
        model, in_pil_img = "detr-resnet-50", model

    detector = detector101 if model == "detr-resnet-101" else detector50
    return detector(in_pil_img)
65
+
66
+
67
+
68
+ def get_figure(in_pil_img):
69
  plt.figure(figsize=(16, 10))
70
  plt.imshow(in_pil_img)
71
+
72
  ax = plt.gca()
73
 
74
+ in_results = query_data(in_pil_img)
75
+
76
  for prediction in in_results:
77
  selected_color = choice(COLORS)
78
 
 
87
  return plt.gcf()
88
 
89
 
90
+ def infer(in_pil_img):
91
+ figure = get_figure(in_pil_img)
 
 
 
 
 
 
 
92
 
93
  buf = io.BytesIO()
94
  figure.savefig(buf, bbox_inches='tight')
 
98
  return output_pil_img
99
 
100
 
101
def process_single_frame(frame):
    """Annotate one OpenCV frame with the detection overlay.

    Args:
        frame: Image array in BGR channel order (it is converted with
            ``cv2.COLOR_BGR2RGB`` before further processing).

    Returns:
        numpy.ndarray: The rendered matplotlib figure as an RGB array.
    """
    # Convert BGR to RGB.
    rgb_frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)

    # Wrap the pixels in a PIL image.
    pil_image = Image.fromarray(rgb_frame)

    # Draw the image together with its annotations.
    figure = get_figure(pil_image)

    buf = BytesIO()
    try:
        figure.savefig(buf, format='png', bbox_inches='tight', pad_inches=0)
    finally:
        # get_figure opens a new pyplot figure on every call; close it so
        # figures do not pile up while a whole video is being processed.
        plt.close(figure)

    buf.seek(0)
    annotated_image = Image.open(buf).convert('RGB')

    return np.array(annotated_image)
117
+
118
+
119
def infer_video(input_video_path):
    """Yield the annotated frames of a video one by one.

    Each frame is read with OpenCV, passed through `process_single_frame`
    and converted back to BGR before being yielded.

    Args:
        input_video_path: Path of the video, handed to ``cv2.VideoCapture``.

    Yields:
        numpy.ndarray: The annotated frame in BGR channel order.

    Raises:
        ValueError: If the video cannot be opened (raised when the
            generator is first advanced).
    """
    cap = cv2.VideoCapture(input_video_path)
    try:
        if not cap.isOpened():
            raise ValueError("无法打开输入视频文件")

        # Total number of frames as reported by the container.
        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        # A non-positive count means the backend could not determine it;
        # in that case keep reading until read() fails instead of
        # yielding nothing at all.
        count_known = total_frames > 0

        frame_count = 0
        while not count_known or frame_count < total_frames:
            ret, frame = cap.read()
            if not ret:
                if count_known:
                    print(f"提前结束:在第 {frame_count} 帧时无法读取帧")
                break

            frame_count += 1

            # Annotate the frame, then convert back to OpenCV's BGR layout.
            processed_frame = process_single_frame(frame)
            bgr_frame = cv2.cvtColor(processed_frame, cv2.COLOR_RGB2BGR)

            yield bgr_frame

            # Progress report every 30 frames.
            if frame_count % 30 == 0:
                print(f"已处理 {frame_count}/{total_frames} 帧")
    finally:
        # Runs on normal exhaustion, on errors, and when the consumer
        # closes the generator early.
        cap.release()
163
+
164
+
165
# Gradio interface: takes a video as input and streams the annotated frames
# back through a WebRTC component.
with gr.Blocks(title="长沙电网项目",
               css=".gradio-container {background:lightyellow;}"
               ) as demo:
    # The style attribute is delimited with double quotes because the
    # font-family value itself contains single-quoted font names; using
    # single quotes for both would end the attribute after "font-family:".
    gr.HTML("""<div style="font-family:'Times New Roman', 'Serif'; font-size:16pt; font-weight:bold; text-align:center; color:royalblue;">长沙电网项目</div>""")

    with gr.Row():
        input_video = gr.Video(label="输入视频")
        output_video = WebRTC(label="WebRTC Stream",
                              rtc_configuration=None,
                              mode="receive",
                              modality="video")
    detect = gr.Button("Detect", variant="primary")
    # Clicking "Detect" starts streaming the frames yielded by infer_video.
    output_video.stream(
        fn=infer_video,
        inputs=[input_video],
        outputs=[output_video],
        trigger=detect.click
    )

#demo.queue()
# Launch once; the previous revision called launch() twice in a row.
demo.launch(debug=True)