isLinXu commited on
Commit
e2881a8
·
1 Parent(s): 2accf59

update app

Browse files
Files changed (2) hide show
  1. app.py +308 -0
  2. requirements.txt +20 -0
app.py ADDED
@@ -0,0 +1,308 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ os.system("pip install 'git+https://github.com/facebookresearch/detectron2.git'")
3
+
4
+ import PIL.Image
5
+ import gradio as gr
6
+ import torch
7
+ import numpy as np
8
+ import cv2
9
+
10
+ from detectron2.config import get_cfg
11
+ from detectron2.data.detection_utils import read_image
12
+ import atexit
13
+ import bisect
14
+ import multiprocessing as mp
15
+ from collections import deque
16
+ import cv2
17
+ import torch
18
+
19
+ from detectron2.data import MetadataCatalog
20
+ from detectron2.engine.defaults import DefaultPredictor
21
+ from detectron2.utils.video_visualizer import VideoVisualizer
22
+ from detectron2.utils.visualizer import ColorMode, Visualizer
23
+
24
+ import warnings
25
+ warnings.filterwarnings("ignore")
26
+
27
+ class VisualizationDemo:
28
+ def __init__(self, cfg, device, instance_mode=ColorMode.IMAGE, parallel=False):
29
+ """
30
+ Args:
31
+ cfg (CfgNode):
32
+ instance_mode (ColorMode):
33
+ parallel (bool): whether to run the model in different processes from visualization.
34
+ Useful since the visualization logic can be slow.
35
+ """
36
+ self.metadata = MetadataCatalog.get(
37
+ cfg.DATASETS.TEST[0] if len(cfg.DATASETS.TEST) else "__unused"
38
+ )
39
+ self.cpu_device = torch.device("cpu")
40
+ self.instance_mode = instance_mode
41
+
42
+ self.parallel = parallel
43
+ if parallel:
44
+ num_gpu = torch.cuda.device_count()
45
+ print("num_gpu: ", num_gpu)
46
+ self.predictor = AsyncPredictor(cfg, num_gpus=num_gpu)
47
+ else:
48
+ cfg.defrost()
49
+ # print("cfg: ", cfg)
50
+ cfg.MODEL.DEVICE = device
51
+
52
+ self.predictor = DefaultPredictor(cfg)
53
+
54
+ def run_on_image(self, image):
55
+ """
56
+ Args:
57
+ image (np.ndarray): an image of shape (H, W, C) (in BGR order).
58
+ This is the format used by OpenCV.
59
+
60
+ Returns:
61
+ predictions (dict): the output of the model.
62
+ vis_output (VisImage): the visualized image output.
63
+ """
64
+ vis_output = None
65
+ predictions = self.predictor(image)
66
+ # Convert image from OpenCV BGR format to Matplotlib RGB format.
67
+ image = image[:, :, ::-1]
68
+ visualizer = Visualizer(image, self.metadata, instance_mode=self.instance_mode)
69
+ if "panoptic_seg" in predictions:
70
+ panoptic_seg, segments_info = predictions["panoptic_seg"]
71
+ vis_output = visualizer.draw_panoptic_seg_predictions(
72
+ panoptic_seg.to(self.cpu_device), segments_info
73
+ )
74
+ else:
75
+ if "sem_seg" in predictions:
76
+ vis_output = visualizer.draw_sem_seg(
77
+ predictions["sem_seg"].argmax(dim=0).to(self.cpu_device)
78
+ )
79
+ if "instances" in predictions:
80
+ instances = predictions["instances"].to(self.cpu_device)
81
+ vis_output = visualizer.draw_instance_predictions(predictions=instances)
82
+
83
+ return predictions, vis_output
84
+
85
+ def _frame_from_video(self, video):
86
+ while video.isOpened():
87
+ success, frame = video.read()
88
+ if success:
89
+ yield frame
90
+ else:
91
+ break
92
+
93
+ def run_on_video(self, video):
94
+ """
95
+ Visualizes predictions on frames of the input video.
96
+
97
+ Args:
98
+ video (cv2.VideoCapture): a :class:`VideoCapture` object, whose source can be
99
+ either a webcam or a video file.
100
+
101
+ Yields:
102
+ ndarray: BGR visualizations of each video frame.
103
+ """
104
+ video_visualizer = VideoVisualizer(self.metadata, self.instance_mode)
105
+
106
+ def process_predictions(frame, predictions):
107
+ frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
108
+ if "panoptic_seg" in predictions:
109
+ panoptic_seg, segments_info = predictions["panoptic_seg"]
110
+ vis_frame = video_visualizer.draw_panoptic_seg_predictions(
111
+ frame, panoptic_seg.to(self.cpu_device), segments_info
112
+ )
113
+ elif "instances" in predictions:
114
+ predictions = predictions["instances"].to(self.cpu_device)
115
+ vis_frame = video_visualizer.draw_instance_predictions(frame, predictions)
116
+ elif "sem_seg" in predictions:
117
+ vis_frame = video_visualizer.draw_sem_seg(
118
+ frame, predictions["sem_seg"].argmax(dim=0).to(self.cpu_device)
119
+ )
120
+
121
+ # Converts Matplotlib RGB format to OpenCV BGR format
122
+ vis_frame = cv2.cvtColor(vis_frame.get_image(), cv2.COLOR_RGB2BGR)
123
+ return vis_frame
124
+
125
+ frame_gen = self._frame_from_video(video)
126
+ if self.parallel:
127
+ buffer_size = self.predictor.default_buffer_size
128
+
129
+ frame_data = deque()
130
+
131
+ for cnt, frame in enumerate(frame_gen):
132
+ frame_data.append(frame)
133
+ self.predictor.put(frame)
134
+
135
+ if cnt >= buffer_size:
136
+ frame = frame_data.popleft()
137
+ predictions = self.predictor.get()
138
+ yield process_predictions(frame, predictions)
139
+
140
+ while len(frame_data):
141
+ frame = frame_data.popleft()
142
+ predictions = self.predictor.get()
143
+ yield process_predictions(frame, predictions)
144
+ else:
145
+ for frame in frame_gen:
146
+ yield process_predictions(frame, self.predictor(frame))
147
+
148
+
149
+ class AsyncPredictor:
150
+ """
151
+ A predictor that runs the model asynchronously, possibly on >1 GPUs.
152
+ Because rendering the visualization takes considerably amount of time,
153
+ this helps improve throughput a little bit when rendering videos.
154
+ """
155
+
156
+ class _StopToken:
157
+ pass
158
+
159
+ class _PredictWorker(mp.Process):
160
+ def __init__(self, cfg, task_queue, result_queue):
161
+ self.cfg = cfg
162
+ self.task_queue = task_queue
163
+ self.result_queue = result_queue
164
+ super().__init__()
165
+
166
+ def run(self):
167
+ predictor = DefaultPredictor(self.cfg)
168
+
169
+ while True:
170
+ task = self.task_queue.get()
171
+ if isinstance(task, AsyncPredictor._StopToken):
172
+ break
173
+ idx, data = task
174
+ result = predictor(data)
175
+ self.result_queue.put((idx, result))
176
+
177
+ def __init__(self, cfg, num_gpus: int = 1):
178
+ """
179
+ Args:
180
+ cfg (CfgNode):
181
+ num_gpus (int): if 0, will run on CPU
182
+ """
183
+ num_workers = max(num_gpus, 1)
184
+ self.task_queue = mp.Queue(maxsize=num_workers * 3)
185
+ self.result_queue = mp.Queue(maxsize=num_workers * 3)
186
+ self.procs = []
187
+ for gpuid in range(max(num_gpus, 1)):
188
+ cfg = cfg.clone()
189
+ cfg.defrost()
190
+ cfg.MODEL.DEVICE = "cuda:{}".format(gpuid) if num_gpus > 0 else "cpu"
191
+ self.procs.append(
192
+ AsyncPredictor._PredictWorker(cfg, self.task_queue, self.result_queue)
193
+ )
194
+
195
+ self.put_idx = 0
196
+ self.get_idx = 0
197
+ self.result_rank = []
198
+ self.result_data = []
199
+
200
+ for p in self.procs:
201
+ p.start()
202
+ atexit.register(self.shutdown)
203
+
204
+ def put(self, image):
205
+ self.put_idx += 1
206
+ self.task_queue.put((self.put_idx, image))
207
+
208
+ def get(self):
209
+ self.get_idx += 1 # the index needed for this request
210
+ if len(self.result_rank) and self.result_rank[0] == self.get_idx:
211
+ res = self.result_data[0]
212
+ del self.result_data[0], self.result_rank[0]
213
+ return res
214
+
215
+ while True:
216
+ # make sure the results are returned in the correct order
217
+ idx, res = self.result_queue.get()
218
+ if idx == self.get_idx:
219
+ return res
220
+ insert = bisect.bisect(self.result_rank, idx)
221
+ self.result_rank.insert(insert, idx)
222
+ self.result_data.insert(insert, res)
223
+
224
+ def __len__(self):
225
+ return self.put_idx - self.get_idx
226
+
227
+ def __call__(self, image):
228
+ self.put(image)
229
+ return self.get()
230
+
231
+ def shutdown(self):
232
+ for _ in self.procs:
233
+ self.task_queue.put(AsyncPredictor._StopToken())
234
+
235
+ @property
236
+ def default_buffer_size(self):
237
+ return len(self.procs) * 5
238
+
239
+
240
+ detectron2_model_list = {
241
+ "COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x":{
242
+ "config_file": "configs/COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml",
243
+ "ckpts": "detectron2://COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x/137849600/model_final_f10217.pkl"
244
+ },
245
+ }
246
+
247
+
248
+ # def dtectron2_instance_inference(image, config_file, ckpts, device):
249
+ # cfg = get_cfg()
250
+ # cfg.merge_from_file(config_file)
251
+ # cfg.MODEL.WEIGHTS = ckpts
252
+ # cfg.MODEL.DEVICE = "cpu"
253
+ # cfg.output = "output_img.jpg"
254
+ # visualization_demo = VisualizationDemo(cfg, device=device)
255
+ # if image:
256
+ # intput_path = "intput_img.jpg"
257
+ # image.save(intput_path)
258
+ # image = read_image(intput_path, format="BGR")
259
+ # predictions, vis_output = visualization_demo.run_on_image(image)
260
+ # output_image = PIL.Image.fromarray(vis_output.get_image())
261
+ # # print("predictions: ", predictions)
262
+ # return output_image
263
+
264
+ def dtectron2_instance_inference(image, input_model_name, device):
265
+ cfg = get_cfg()
266
+ config_file = detectron2_model_list[input_model_name]["config_file"]
267
+ ckpts = detectron2_model_list[input_model_name]["ckpts"]
268
+ cfg.merge_from_file(config_file)
269
+ cfg.MODEL.WEIGHTS = ckpts
270
+ cfg.MODEL.DEVICE = "cpu"
271
+ cfg.output = "output_img.jpg"
272
+ visualization_demo = VisualizationDemo(cfg, device=device)
273
+ if image:
274
+ intput_path = "intput_img.jpg"
275
+ image.save(intput_path)
276
+ image = read_image(intput_path, format="BGR")
277
+ predictions, vis_output = visualization_demo.run_on_image(image)
278
+ output_image = PIL.Image.fromarray(vis_output.get_image())
279
+ # print("predictions: ", predictions)
280
+ return output_image
281
+
282
+ def download_test_img():
283
+ # Images
284
+ torch.hub.download_url_to_file(
285
+ 'https://user-images.githubusercontent.com/59380685/268517006-d8d4d3b3-964a-4f4d-8458-18c7eb75a4f2.jpg',
286
+ '000000502136.jpg')
287
+
288
+
289
+ if __name__ == '__main__':
290
+ input_image = gr.inputs.Image(type='pil', label='Input Image')
291
+ input_model_name = gr.inputs.Dropdown(list(detectron2_model_list.keys()), label="Model Name", default="COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x")
292
+ input_device = gr.inputs.Dropdown(["cpu", "cuda"], label="Devices", default="cpu")
293
+ output_image = gr.outputs.Image(type='pil', label='Output Image')
294
+ output_predictions = gr.outputs.Textbox(type='text', label='Output Predictions')
295
+
296
+ title = "Detectron2 web demo"
297
+ description = "<div align='center'><img src='https://raw.githubusercontent.com/facebookresearch/detectron2/8c4a333ceb8df05348759443d0206302485890e0/.github/Detectron2-Logo-Horz.svg' width='450''/><div>" \
298
+ "<p style='text-align: center'><a href='https://github.com/facebookresearch/detectron2'>Detectron2</a> Detectron2 是 Facebook AI Research 的下一代库,提供最先进的检测和分割算法。它是Detectron 和maskrcnn-benchmark的后继者 。它支持 Facebook 中的许多计算机视觉研究项目和生产应用。" \
299
+ "Detectron2 is a platform for object detection, segmentation and other visual recognition tasks..</p>"
300
+ article = "<p style='text-align: center'><a href='https://github.com/facebookresearch/detectron2'>Detectron2</a></p>" \
301
+ "<p style='text-align: center'><a href='https://github.com/facebookresearch/detectron2'>gradio build by gatilin</a></a></p>"
302
+ download_test_img()
303
+
304
+ examples = [["000000502136.jpg", "COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x", "cpu"]]
305
+ gr.Interface(fn=dtectron2_instance_inference,
306
+ inputs=[input_image, input_model_name, input_device],
307
+ outputs=output_image,examples=examples,
308
+ title=title, description=description, article=article).launch()
requirements.txt ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ wget~=3.2
2
+ opencv-python~=4.6.0.66
3
+ numpy~=1.23.0
4
+ torch~=1.13.1
5
+ torchvision~=0.14.1
6
+ pillow~=9.4.0
7
+ gradio~=3.42.0
8
+ ultralytics~=8.0.169
9
+ pyyaml~=6.0
10
+ wandb~=0.13.11
11
+ tqdm~=4.65.0
12
+ matplotlib~=3.7.1
13
+ pandas~=2.0.0
14
+ seaborn~=0.12.2
15
+ requests~=2.31.0
16
+ psutil~=5.9.4
17
+ thop~=0.1.1-2209072238
18
+ timm~=0.9.2
19
+ super-gradients~=3.2.0
20
+ openmim