Spaces:
Runtime error
Runtime error
isLinXu
committed on
Commit
·
e2881a8
1
Parent(s):
2accf59
update app
Browse files- app.py +308 -0
- requirements.txt +20 -0
app.py
ADDED
@@ -0,0 +1,308 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import importlib
import importlib.util
import os
import subprocess
import sys

# detectron2 is installed at start-up because it is not listed in
# requirements.txt. Skip the (slow) source build when the package is already
# importable, call pip through the running interpreter instead of a shell
# string, and fail loudly here rather than at the later `from detectron2 ...`
# import if the installation did not succeed.
if importlib.util.find_spec("detectron2") is None:
    subprocess.run(
        [
            sys.executable,
            "-m",
            "pip",
            "install",
            "git+https://github.com/facebookresearch/detectron2.git",
        ],
        check=True,
    )
    # Make the freshly installed package visible to the import system.
    importlib.invalidate_caches()
|
3 |
+
|
4 |
+
import PIL.Image
|
5 |
+
import gradio as gr
|
6 |
+
import torch
|
7 |
+
import numpy as np
|
8 |
+
import cv2
|
9 |
+
|
10 |
+
from detectron2.config import get_cfg
|
11 |
+
from detectron2.data.detection_utils import read_image
|
12 |
+
import atexit
|
13 |
+
import bisect
|
14 |
+
import multiprocessing as mp
|
15 |
+
from collections import deque
|
16 |
+
import cv2
|
17 |
+
import torch
|
18 |
+
|
19 |
+
from detectron2.data import MetadataCatalog
|
20 |
+
from detectron2.engine.defaults import DefaultPredictor
|
21 |
+
from detectron2.utils.video_visualizer import VideoVisualizer
|
22 |
+
from detectron2.utils.visualizer import ColorMode, Visualizer
|
23 |
+
|
24 |
+
import warnings
|
25 |
+
warnings.filterwarnings("ignore")
|
26 |
+
|
27 |
+
class VisualizationDemo:
    """Run a detectron2 predictor on images or video frames and draw its output."""

    def __init__(self, cfg, device, instance_mode=ColorMode.IMAGE, parallel=False):
        """
        Args:
            cfg (CfgNode): model configuration. It is cloned before being
                modified, so the caller's object is left untouched.
            device (str): value written to ``cfg.MODEL.DEVICE`` when the model
                runs in-process. Not used when ``parallel`` is True, because
                :class:`AsyncPredictor` assigns one device per worker itself.
            instance_mode (ColorMode): passed through to the visualizers.
            parallel (bool): whether to run the model in different processes from visualization.
                Useful since the visualization logic can be slow.
        """
        self.metadata = MetadataCatalog.get(
            cfg.DATASETS.TEST[0] if len(cfg.DATASETS.TEST) else "__unused"
        )
        self.cpu_device = torch.device("cpu")
        self.instance_mode = instance_mode

        self.parallel = parallel
        if parallel:
            num_gpu = torch.cuda.device_count()
            print("num_gpu: ", num_gpu)
            self.predictor = AsyncPredictor(cfg, num_gpus=num_gpu)
        else:
            # Work on a private copy: defrosting and editing the caller's cfg
            # in place would leak the device override back to the caller.
            cfg = cfg.clone()
            cfg.defrost()
            cfg.MODEL.DEVICE = device

            self.predictor = DefaultPredictor(cfg)

    def run_on_image(self, image):
        """
        Args:
            image (np.ndarray): an image of shape (H, W, C) (in BGR order).
                This is the format used by OpenCV.

        Returns:
            predictions (dict): the output of the model.
            vis_output (VisImage): the visualized image output, or None when
                the predictions contain none of "panoptic_seg", "sem_seg" or
                "instances".
        """
        vis_output = None
        predictions = self.predictor(image)
        # Convert image from OpenCV BGR format to Matplotlib RGB format.
        image = image[:, :, ::-1]
        visualizer = Visualizer(image, self.metadata, instance_mode=self.instance_mode)
        if "panoptic_seg" in predictions:
            panoptic_seg, segments_info = predictions["panoptic_seg"]
            vis_output = visualizer.draw_panoptic_seg_predictions(
                panoptic_seg.to(self.cpu_device), segments_info
            )
        else:
            # Semantic and instance results may both be present; instances are
            # drawn last, so their visualization is the one returned.
            if "sem_seg" in predictions:
                vis_output = visualizer.draw_sem_seg(
                    predictions["sem_seg"].argmax(dim=0).to(self.cpu_device)
                )
            if "instances" in predictions:
                instances = predictions["instances"].to(self.cpu_device)
                vis_output = visualizer.draw_instance_predictions(predictions=instances)

        return predictions, vis_output

    def _frame_from_video(self, video):
        """Yield frames from ``video`` until it is closed or a read fails."""
        while video.isOpened():
            success, frame = video.read()
            if success:
                yield frame
            else:
                break

    def run_on_video(self, video):
        """
        Visualizes predictions on frames of the input video.

        Args:
            video (cv2.VideoCapture): a :class:`VideoCapture` object, whose source can be
                either a webcam or a video file.

        Yields:
            ndarray: BGR visualizations of each video frame.

        Raises:
            ValueError: if the model output for a frame contains none of
                "panoptic_seg", "instances" or "sem_seg".
        """
        video_visualizer = VideoVisualizer(self.metadata, self.instance_mode)

        def process_predictions(frame, predictions):
            frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            if "panoptic_seg" in predictions:
                panoptic_seg, segments_info = predictions["panoptic_seg"]
                vis_frame = video_visualizer.draw_panoptic_seg_predictions(
                    frame, panoptic_seg.to(self.cpu_device), segments_info
                )
            elif "instances" in predictions:
                predictions = predictions["instances"].to(self.cpu_device)
                vis_frame = video_visualizer.draw_instance_predictions(frame, predictions)
            elif "sem_seg" in predictions:
                vis_frame = video_visualizer.draw_sem_seg(
                    frame, predictions["sem_seg"].argmax(dim=0).to(self.cpu_device)
                )
            else:
                # Without this branch vis_frame would be unbound below.
                raise ValueError(
                    "model output contains none of 'panoptic_seg', 'instances', 'sem_seg'"
                )

            # Converts Matplotlib RGB format to OpenCV BGR format
            vis_frame = cv2.cvtColor(vis_frame.get_image(), cv2.COLOR_RGB2BGR)
            return vis_frame

        frame_gen = self._frame_from_video(video)
        if self.parallel:
            buffer_size = self.predictor.default_buffer_size

            # Frames whose predictions have been requested but not yet read.
            frame_data = deque()

            for cnt, frame in enumerate(frame_gen):
                frame_data.append(frame)
                self.predictor.put(frame)

                # Start consuming only once the pipeline has been primed.
                if cnt >= buffer_size:
                    frame = frame_data.popleft()
                    predictions = self.predictor.get()
                    yield process_predictions(frame, predictions)

            # Drain the frames still in flight after the video has ended.
            while len(frame_data):
                frame = frame_data.popleft()
                predictions = self.predictor.get()
                yield process_predictions(frame, predictions)
        else:
            for frame in frame_gen:
                yield process_predictions(frame, self.predictor(frame))
|
147 |
+
|
148 |
+
|
149 |
+
class AsyncPredictor:
    """
    A predictor that runs the model asynchronously, possibly on >1 GPUs.
    Because rendering the visualization takes a considerable amount of time,
    this helps improve throughput a little bit when rendering videos.
    """

    class _StopToken:
        """Sentinel put on the task queue to tell a worker to exit its loop."""

    class _PredictWorker(mp.Process):
        """Process that builds one predictor and serves tasks from a shared queue."""

        def __init__(self, cfg, task_queue, result_queue):
            super().__init__()
            self.cfg = cfg
            self.task_queue = task_queue
            self.result_queue = result_queue

        def run(self):
            """Answer ``(idx, data)`` tasks until a stop token is received."""
            predictor = DefaultPredictor(self.cfg)

            task = self.task_queue.get()
            while not isinstance(task, AsyncPredictor._StopToken):
                idx, data = task
                self.result_queue.put((idx, predictor(data)))
                task = self.task_queue.get()

    def __init__(self, cfg, num_gpus: int = 1):
        """
        Args:
            cfg (CfgNode):
            num_gpus (int): if 0, will run on CPU
        """
        num_workers = max(num_gpus, 1)
        queue_capacity = num_workers * 3
        self.task_queue = mp.Queue(maxsize=queue_capacity)
        self.result_queue = mp.Queue(maxsize=queue_capacity)

        self.procs = []
        for gpuid in range(num_workers):
            # Each worker gets its own config copy pinned to one device.
            worker_cfg = cfg.clone()
            worker_cfg.defrost()
            worker_cfg.MODEL.DEVICE = f"cuda:{gpuid}" if num_gpus > 0 else "cpu"
            worker = AsyncPredictor._PredictWorker(
                worker_cfg, self.task_queue, self.result_queue
            )
            self.procs.append(worker)

        # Counters of submitted / consumed requests, plus two parallel lists
        # holding results that arrived ahead of their turn, sorted by index.
        self.put_idx = 0
        self.get_idx = 0
        self.result_rank = []
        self.result_data = []

        for worker in self.procs:
            worker.start()
        atexit.register(self.shutdown)

    def put(self, image):
        """Submit ``image`` for prediction under the next request index."""
        self.put_idx += 1
        self.task_queue.put((self.put_idx, image))

    def get(self):
        """Return the result of the oldest unconsumed request, blocking if needed."""
        self.get_idx += 1  # the index needed for this request
        wanted = self.get_idx

        # The wanted result may already be parked from an earlier call.
        if self.result_rank and self.result_rank[0] == wanted:
            self.result_rank.pop(0)
            return self.result_data.pop(0)

        while True:
            # make sure the results are returned in the correct order
            idx, res = self.result_queue.get()
            if idx == wanted:
                return res
            position = bisect.bisect(self.result_rank, idx)
            self.result_rank.insert(position, idx)
            self.result_data.insert(position, res)

    def __len__(self):
        """Number of requests submitted but not yet consumed."""
        return self.put_idx - self.get_idx

    def __call__(self, image):
        """Submit ``image`` and wait for its result."""
        self.put(image)
        return self.get()

    def shutdown(self):
        """Queue one stop token per worker process."""
        for _ in self.procs:
            self.task_queue.put(AsyncPredictor._StopToken())

    @property
    def default_buffer_size(self):
        """Five in-flight requests per worker process."""
        return len(self.procs) * 5
|
238 |
+
|
239 |
+
|
240 |
+
# Models selectable in the demo, keyed by the name shown in the dropdown.
# Each entry gives the config file to merge and the checkpoint assigned to
# cfg.MODEL.WEIGHTS.
detectron2_model_list = {
    "COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x": dict(
        config_file="configs/COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml",
        ckpts="detectron2://COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x/137849600/model_final_f10217.pkl",
    ),
}
|
246 |
+
|
247 |
+
|
248 |
+
# def dtectron2_instance_inference(image, config_file, ckpts, device):
|
249 |
+
# cfg = get_cfg()
|
250 |
+
# cfg.merge_from_file(config_file)
|
251 |
+
# cfg.MODEL.WEIGHTS = ckpts
|
252 |
+
# cfg.MODEL.DEVICE = "cpu"
|
253 |
+
# cfg.output = "output_img.jpg"
|
254 |
+
# visualization_demo = VisualizationDemo(cfg, device=device)
|
255 |
+
# if image:
|
256 |
+
# intput_path = "intput_img.jpg"
|
257 |
+
# image.save(intput_path)
|
258 |
+
# image = read_image(intput_path, format="BGR")
|
259 |
+
# predictions, vis_output = visualization_demo.run_on_image(image)
|
260 |
+
# output_image = PIL.Image.fromarray(vis_output.get_image())
|
261 |
+
# # print("predictions: ", predictions)
|
262 |
+
# return output_image
|
263 |
+
|
264 |
+
def dtectron2_instance_inference(image, input_model_name, device):
    """Run the selected detectron2 model on ``image`` and return the drawn result.

    Args:
        image (PIL.Image.Image | None): picture to analyse; ``None`` when
            nothing was submitted.
        input_model_name (str): key into ``detectron2_model_list``.
        device (str): value for ``cfg.MODEL.DEVICE``, e.g. ``"cpu"`` or ``"cuda"``.

    Returns:
        PIL.Image.Image | None: the visualization, or ``None`` when no image
        was given.

    Raises:
        KeyError: if ``input_model_name`` is not in ``detectron2_model_list``.
    """
    # Without an input there is nothing to predict; returning early avoids
    # building the model and the unbound-result error of falling through.
    if image is None:
        return None

    model_entry = detectron2_model_list[input_model_name]
    cfg = get_cfg()
    cfg.merge_from_file(model_entry["config_file"])
    cfg.MODEL.WEIGHTS = model_entry["ckpts"]
    cfg.MODEL.DEVICE = device
    visualization_demo = VisualizationDemo(cfg, device=device)

    # Convert in memory to the (H, W, C) BGR array run_on_image expects,
    # instead of writing a JPEG under a fixed name and reading it back: that
    # file was shared between concurrent requests and added a lossy re-encode.
    rgb_array = np.asarray(image.convert("RGB"))
    bgr_array = np.ascontiguousarray(rgb_array[:, :, ::-1])

    _, vis_output = visualization_demo.run_on_image(bgr_array)
    return PIL.Image.fromarray(vis_output.get_image())
|
281 |
+
|
282 |
+
def download_test_img():
    """Download the sample picture used by the demo's example row."""
    sample_url = 'https://user-images.githubusercontent.com/59380685/268517006-d8d4d3b3-964a-4f4d-8458-18c7eb75a4f2.jpg'
    # Relative path: the file lands in the current working directory.
    target_name = '000000502136.jpg'
    torch.hub.download_url_to_file(sample_url, target_name)
|
287 |
+
|
288 |
+
|
289 |
+
if __name__ == '__main__':
    # Build the input/output widgets with the top-level component classes;
    # the gr.inputs / gr.outputs namespaces are deprecated aliases.
    input_image = gr.Image(type='pil', label='Input Image')
    input_model_name = gr.Dropdown(
        list(detectron2_model_list.keys()),
        label="Model Name",
        value="COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x",
    )
    input_device = gr.Dropdown(["cpu", "cuda"], label="Devices", value="cpu")
    output_image = gr.Image(type='pil', label='Output Image')

    title = "Detectron2 web demo"
    # HTML shown above the interface (logo followed by a short introduction).
    description = (
        "<div align='center'><img src='https://raw.githubusercontent.com/facebookresearch/detectron2/8c4a333ceb8df05348759443d0206302485890e0/.github/Detectron2-Logo-Horz.svg' width='450'/></div>"
        "<p style='text-align: center'><a href='https://github.com/facebookresearch/detectron2'>Detectron2</a> Detectron2 是 Facebook AI Research 的下一代库,提供最先进的检测和分割算法。它是Detectron 和maskrcnn-benchmark的后继者 。它支持 Facebook 中的许多计算机视觉研究项目和生产应用。"
        "Detectron2 is a platform for object detection, segmentation and other visual recognition tasks.</p>"
    )
    # HTML shown below the interface.
    article = (
        "<p style='text-align: center'><a href='https://github.com/facebookresearch/detectron2'>Detectron2</a></p>"
        "<p style='text-align: center'><a href='https://github.com/facebookresearch/detectron2'>gradio build by gatilin</a></p>"
    )

    # The example row below refers to this file, so fetch it before launching.
    download_test_img()

    examples = [["000000502136.jpg", "COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x", "cpu"]]
    gr.Interface(
        fn=dtectron2_instance_inference,
        inputs=[input_image, input_model_name, input_device],
        outputs=output_image,
        examples=examples,
        title=title,
        description=description,
        article=article,
    ).launch()
|
requirements.txt
ADDED
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
wget~=3.2
|
2 |
+
opencv-python~=4.6.0.66
|
3 |
+
numpy~=1.23.0
|
4 |
+
torch~=1.13.1
|
5 |
+
torchvision~=0.14.1
|
6 |
+
pillow~=9.4.0
|
7 |
+
gradio~=3.42.0
|
8 |
+
ultralytics~=8.0.169
|
9 |
+
pyyaml~=6.0
|
10 |
+
wandb~=0.13.11
|
11 |
+
tqdm~=4.65.0
|
12 |
+
matplotlib~=3.7.1
|
13 |
+
pandas~=2.0.0
|
14 |
+
seaborn~=0.12.2
|
15 |
+
requests~=2.31.0
|
16 |
+
psutil~=5.9.4
|
17 |
+
thop~=0.1.1-2209072238
|
18 |
+
timm~=0.9.2
|
19 |
+
super-gradients~=3.2.0
|
20 |
+
openmim
|