hysts (HF staff) committed on
Commit 51e2977
1 Parent(s): a414687
Files changed (3)
  1. app.py +16 -217
  2. model.py +189 -0
  3. style.css +11 -0
app.py CHANGED
@@ -3,35 +3,17 @@
 from __future__ import annotations

 import argparse
-import os
 import pathlib
-import subprocess
-import sys
 import tarfile

-if os.getenv('SYSTEM') == 'spaces':
-    import mim
-
-    mim.uninstall('mmcv-full', confirm_yes=True)
-    mim.install('mmcv-full==1.5.0', is_yes=True)
-
-    subprocess.call('pip uninstall -y opencv-python'.split())
-    subprocess.call('pip uninstall -y opencv-python-headless'.split())
-    subprocess.call('pip install opencv-python-headless==4.5.5.64'.split())
-
 import gradio as gr
-import huggingface_hub
-import numpy as np
-import torch
-import torch.nn as nn

-sys.path.insert(0, 'ViTPose/')
+from model import DetModel, PoseModel

-from mmdet.apis import inference_detector, init_detector
-from mmpose.apis import (inference_top_down_pose_model, init_pose_model,
-                         process_mmdet_results, vis_pose_result)
+DESCRIPTION = '''# ViTPose

-TOKEN = os.environ['TOKEN']
+This is an unofficial demo for [https://github.com/ViTAE-Transformer/ViTPose](https://github.com/ViTAE-Transformer/ViTPose).'''
+FOOTER = '<img id="visitor-badge" alt="visitor badge" src="https://visitor-badge.glitch.me/badge?page_id=hysts.vitpose" />'


 def parse_args() -> argparse.Namespace:
@@ -46,168 +28,6 @@ def parse_args() -> argparse.Namespace:
     return parser.parse_args()


-class DetModel:
-    def __init__(self, device: str | torch.device):
-        self.device = torch.device(device)
-        self.models = self._load_models()
-        self.model_name = 'YOLOX-l'
-
-    def _load_models(self) -> dict[str, nn.Module]:
-        model_dict = {
-            'YOLOX-tiny': {
-                'config':
-                'mmdet_configs/configs/yolox/yolox_tiny_8x8_300e_coco.py',
-                'model':
-                'https://download.openmmlab.com/mmdetection/v2.0/yolox/yolox_tiny_8x8_300e_coco/yolox_tiny_8x8_300e_coco_20211124_171234-b4047906.pth',
-            },
-            'YOLOX-s': {
-                'config':
-                'mmdet_configs/configs/yolox/yolox_s_8x8_300e_coco.py',
-                'model':
-                'https://download.openmmlab.com/mmdetection/v2.0/yolox/yolox_s_8x8_300e_coco/yolox_s_8x8_300e_coco_20211121_095711-4592a793.pth',
-            },
-            'YOLOX-l': {
-                'config':
-                'mmdet_configs/configs/yolox/yolox_l_8x8_300e_coco.py',
-                'model':
-                'https://download.openmmlab.com/mmdetection/v2.0/yolox/yolox_l_8x8_300e_coco/yolox_l_8x8_300e_coco_20211126_140236-d3bd2b23.pth',
-            },
-            'YOLOX-x': {
-                'config':
-                'mmdet_configs/configs/yolox/yolox_x_8x8_300e_coco.py',
-                'model':
-                'https://download.openmmlab.com/mmdetection/v2.0/yolox/yolox_x_8x8_300e_coco/yolox_x_8x8_300e_coco_20211126_140254-1ef88d67.pth',
-            },
-        }
-        models = {
-            key: init_detector(dic['config'], dic['model'], device=self.device)
-            for key, dic in model_dict.items()
-        }
-        return models
-
-    def set_model_name(self, name: str) -> None:
-        self.model_name = name
-
-    def detect_and_visualize(
-            self, image: np.ndarray,
-            score_threshold: float) -> tuple[list[np.ndarray], np.ndarray]:
-        out = self.detect(image)
-        vis = self.visualize_detection_results(image, out, score_threshold)
-        return out, vis
-
-    def detect(self, image: np.ndarray) -> list[np.ndarray]:
-        image = image[:, :, ::-1]  # RGB -> BGR
-        model = self.models[self.model_name]
-        out = inference_detector(model, image)
-        return out
-
-    def visualize_detection_results(
-            self,
-            image: np.ndarray,
-            detection_results: list[np.ndarray],
-            score_threshold: float = 0.3) -> np.ndarray:
-        person_det = [detection_results[0]] + [np.array([]).reshape(0, 5)]
-
-        image = image[:, :, ::-1]  # RGB -> BGR
-        model = self.models[self.model_name]
-        vis = model.show_result(image,
-                                person_det,
-                                score_thr=score_threshold,
-                                bbox_color=None,
-                                text_color=(200, 200, 200),
-                                mask_color=None)
-        return vis[:, :, ::-1]  # BGR -> RGB
-
-
-class PoseModel:
-    def __init__(self, device: str | torch.device):
-        self.device = torch.device(device)
-        self.models = self._load_models()
-        self.model_name = 'ViTPose-B (multi-task train, COCO)'
-
-    def _load_models(self) -> dict[str, nn.Module]:
-        model_dict = {
-            'ViTPose-B (single-task train)': {
-                'config':
-                'ViTPose/configs/body/2d_kpt_sview_rgb_img/topdown_heatmap/coco/ViTPose_base_coco_256x192.py',
-                'model': 'models/vitpose-b.pth',
-            },
-            'ViTPose-L (single-task train)': {
-                'config':
-                'ViTPose/configs/body/2d_kpt_sview_rgb_img/topdown_heatmap/coco/ViTPose_large_coco_256x192.py',
-                'model': 'models/vitpose-l.pth',
-            },
-            'ViTPose-B (multi-task train, COCO)': {
-                'config':
-                'ViTPose/configs/body/2d_kpt_sview_rgb_img/topdown_heatmap/coco/ViTPose_base_coco_256x192.py',
-                'model': 'models/vitpose-b-multi-coco.pth',
-            },
-            'ViTPose-L (multi-task train, COCO)': {
-                'config':
-                'ViTPose/configs/body/2d_kpt_sview_rgb_img/topdown_heatmap/coco/ViTPose_large_coco_256x192.py',
-                'model': 'models/vitpose-l-multi-coco.pth',
-            },
-        }
-        models = dict()
-        for key, dic in model_dict.items():
-            ckpt_path = huggingface_hub.hf_hub_download('hysts/ViTPose',
-                                                        dic['model'],
-                                                        use_auth_token=TOKEN)
-            model = init_pose_model(dic['config'],
-                                    ckpt_path,
-                                    device=self.device)
-            models[key] = model
-        return models
-
-    def set_model_name(self, name: str) -> None:
-        self.model_name = name
-
-    def predict_pose_and_visualize(
-        self,
-        image: np.ndarray,
-        det_results: list[np.ndarray],
-        box_score_threshold: float,
-        kpt_score_threshold: float,
-        vis_dot_radius: int,
-        vis_line_thickness: int,
-    ) -> tuple[list[dict[str, np.ndarray]], np.ndarray]:
-        out = self.predict_pose(image, det_results, box_score_threshold)
-        vis = self.visualize_pose_results(image, out, kpt_score_threshold,
-                                          vis_dot_radius, vis_line_thickness)
-        return out, vis
-
-    def predict_pose(
-            self,
-            image: np.ndarray,
-            det_results: list[np.ndarray],
-            box_score_threshold: float = 0.5) -> list[dict[str, np.ndarray]]:
-        image = image[:, :, ::-1]  # RGB -> BGR
-        model = self.models[self.model_name]
-        person_results = process_mmdet_results(det_results, 1)
-        out, _ = inference_top_down_pose_model(model,
-                                               image,
-                                               person_results=person_results,
-                                               bbox_thr=box_score_threshold,
-                                               format='xyxy')
-        return out
-
-    def visualize_pose_results(self,
-                               image: np.ndarray,
-                               pose_results: list[np.ndarray],
-                               kpt_score_threshold: float = 0.3,
-                               vis_dot_radius: int = 4,
-                               vis_line_thickness: int = 1) -> np.ndarray:
-        image = image[:, :, ::-1]  # RGB -> BGR
-        model = self.models[self.model_name]
-        vis = vis_pose_result(model,
-                              image,
-                              pose_results,
-                              kpt_score_thr=kpt_score_threshold,
-                              radius=vis_dot_radius,
-                              thickness=vis_line_thickness)
-        return vis[:, :, ::-1]  # BGR -> RGB
-
-
 def set_example_image(example: list) -> dict:
     return gr.Image.update(value=example[0])

@@ -227,17 +47,8 @@ def main():
     det_model = DetModel(device=args.device)
     pose_model = PoseModel(device=args.device)

-    css = '''
-    h1#title {
-        text-align: center;
-    }
-    '''
-
-    with gr.Blocks(theme=args.theme, css=css) as demo:
-        gr.Markdown('''<h1 id="title">ViTPose</h1>
-
-This is an unofficial demo for [https://github.com/ViTAE-Transformer/ViTPose](https://github.com/ViTAE-Transformer/ViTPose).'''
-                    )
+    with gr.Blocks(theme=args.theme, css='style.css') as demo:
+        gr.Markdown(DESCRIPTION)

         with gr.Box():
             gr.Markdown('## Step 1')
@@ -318,14 +129,10 @@ This is an unofficial demo for [https://github.com/ViTAE-Transformer/ViTPose](ht
             with gr.Row():
                 redraw_pose_button = gr.Button(value='Redraw')

-        gr.Markdown(
-            '<center><img src="https://visitor-badge.glitch.me/badge?page_id=hysts.vitpose" alt="visitor badge"/></center>'
-        )
+        gr.Markdown(FOOTER)

         detector_name.change(fn=det_model.set_model_name,
-                             inputs=[
-                                 detector_name,
-                             ],
+                             inputs=detector_name,
                              outputs=None)
         detect_button.click(fn=det_model.detect_and_visualize,
                             inputs=[
@@ -342,14 +149,10 @@ This is an unofficial demo for [https://github.com/ViTAE-Transformer/ViTPose](ht
                                     det_preds,
                                     vis_det_score_threshold,
                                 ],
-                                outputs=[
-                                    detection_visualization,
-                                ])
+                                outputs=detection_visualization)

         pose_model_name.change(fn=pose_model.set_model_name,
-                               inputs=[
-                                   pose_model_name,
-                               ],
+                               inputs=pose_model_name,
                                outputs=None)
         predict_button.click(fn=pose_model.predict_pose_and_visualize,
                              inputs=[
@@ -372,17 +175,13 @@ This is an unofficial demo for [https://github.com/ViTAE-Transformer/ViTPose](ht
                                      vis_dot_radius,
                                      vis_line_thickness,
                                  ],
-                                 outputs=[
-                                     pose_visualization,
-                                 ])
+                                 outputs=pose_visualization)

-        example_images.click(fn=set_example_image,
-                             inputs=[
-                                 example_images,
-                             ],
-                             outputs=[
-                                 input_image,
-                             ])
+        example_images.click(
+            fn=set_example_image,
+            inputs=example_images,
+            outputs=input_image,
+        )

     demo.launch(
         enable_queue=args.enable_queue,
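The change moves DetModel and PoseModel out of app.py into the new model.py and wires the Gradio events to single components instead of one-element lists. A minimal sketch of how the two classes compose outside the UI, assuming the mmdet/mmpose environment, the checkpoints, and the HF_TOKEN secret required by model.py are available (the image array is a placeholder):

# Hedged sketch, not part of the commit: drives the refactored classes the
# same way the Gradio callbacks in app.py do.
import numpy as np

from model import DetModel, PoseModel

det_model = DetModel(device='cpu')    # loads the YOLOX detectors
pose_model = PoseModel(device='cpu')  # downloads and loads the ViTPose checkpoints

image = np.zeros((480, 640, 3), dtype=np.uint8)  # placeholder RGB image
det_preds, det_vis = det_model.detect_and_visualize(image, score_threshold=0.5)
pose_preds, pose_vis = pose_model.predict_pose_and_visualize(
    image,
    det_preds,
    box_score_threshold=0.5,
    kpt_score_threshold=0.3,
    vis_dot_radius=4,
    vis_line_thickness=1)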
model.py ADDED
@@ -0,0 +1,189 @@
+from __future__ import annotations
+
+import os
+import subprocess
+import sys
+
+if os.getenv('SYSTEM') == 'spaces':
+    import mim
+
+    mim.uninstall('mmcv-full', confirm_yes=True)
+    mim.install('mmcv-full==1.5.0', is_yes=True)
+
+    subprocess.call('pip uninstall -y opencv-python'.split())
+    subprocess.call('pip uninstall -y opencv-python-headless'.split())
+    subprocess.call('pip install opencv-python-headless==4.5.5.64'.split())
+
+import huggingface_hub
+import numpy as np
+import torch
+import torch.nn as nn
+
+sys.path.insert(0, 'ViTPose/')
+
+from mmdet.apis import inference_detector, init_detector
+from mmpose.apis import (inference_top_down_pose_model, init_pose_model,
+                         process_mmdet_results, vis_pose_result)
+
+HF_TOKEN = os.environ['HF_TOKEN']
+
+
+class DetModel:
+    def __init__(self, device: str | torch.device):
+        self.device = torch.device(device)
+        self.models = self._load_models()
+        self.model_name = 'YOLOX-l'
+
+    def _load_models(self) -> dict[str, nn.Module]:
+        model_dict = {
+            'YOLOX-tiny': {
+                'config':
+                'mmdet_configs/configs/yolox/yolox_tiny_8x8_300e_coco.py',
+                'model':
+                'https://download.openmmlab.com/mmdetection/v2.0/yolox/yolox_tiny_8x8_300e_coco/yolox_tiny_8x8_300e_coco_20211124_171234-b4047906.pth',
+            },
+            'YOLOX-s': {
+                'config':
+                'mmdet_configs/configs/yolox/yolox_s_8x8_300e_coco.py',
+                'model':
+                'https://download.openmmlab.com/mmdetection/v2.0/yolox/yolox_s_8x8_300e_coco/yolox_s_8x8_300e_coco_20211121_095711-4592a793.pth',
+            },
+            'YOLOX-l': {
+                'config':
+                'mmdet_configs/configs/yolox/yolox_l_8x8_300e_coco.py',
+                'model':
+                'https://download.openmmlab.com/mmdetection/v2.0/yolox/yolox_l_8x8_300e_coco/yolox_l_8x8_300e_coco_20211126_140236-d3bd2b23.pth',
+            },
+            'YOLOX-x': {
+                'config':
+                'mmdet_configs/configs/yolox/yolox_x_8x8_300e_coco.py',
+                'model':
+                'https://download.openmmlab.com/mmdetection/v2.0/yolox/yolox_x_8x8_300e_coco/yolox_x_8x8_300e_coco_20211126_140254-1ef88d67.pth',
+            },
+        }
+        models = {
+            key: init_detector(dic['config'], dic['model'], device=self.device)
+            for key, dic in model_dict.items()
+        }
+        return models
+
+    def set_model_name(self, name: str) -> None:
+        self.model_name = name
+
+    def detect_and_visualize(
+            self, image: np.ndarray,
+            score_threshold: float) -> tuple[list[np.ndarray], np.ndarray]:
+        out = self.detect(image)
+        vis = self.visualize_detection_results(image, out, score_threshold)
+        return out, vis
+
+    def detect(self, image: np.ndarray) -> list[np.ndarray]:
+        image = image[:, :, ::-1]  # RGB -> BGR
+        model = self.models[self.model_name]
+        out = inference_detector(model, image)
+        return out
+
+    def visualize_detection_results(
+            self,
+            image: np.ndarray,
+            detection_results: list[np.ndarray],
+            score_threshold: float = 0.3) -> np.ndarray:
+        person_det = [detection_results[0]] + [np.array([]).reshape(0, 5)]
+
+        image = image[:, :, ::-1]  # RGB -> BGR
+        model = self.models[self.model_name]
+        vis = model.show_result(image,
+                                person_det,
+                                score_thr=score_threshold,
+                                bbox_color=None,
+                                text_color=(200, 200, 200),
+                                mask_color=None)
+        return vis[:, :, ::-1]  # BGR -> RGB
+
+
+class PoseModel:
+    def __init__(self, device: str | torch.device):
+        self.device = torch.device(device)
+        self.models = self._load_models()
+        self.model_name = 'ViTPose-B (multi-task train, COCO)'
+
+    def _load_models(self) -> dict[str, nn.Module]:
+        model_dict = {
+            'ViTPose-B (single-task train)': {
+                'config':
+                'ViTPose/configs/body/2d_kpt_sview_rgb_img/topdown_heatmap/coco/ViTPose_base_coco_256x192.py',
+                'model': 'models/vitpose-b.pth',
+            },
+            'ViTPose-L (single-task train)': {
+                'config':
+                'ViTPose/configs/body/2d_kpt_sview_rgb_img/topdown_heatmap/coco/ViTPose_large_coco_256x192.py',
+                'model': 'models/vitpose-l.pth',
+            },
+            'ViTPose-B (multi-task train, COCO)': {
+                'config':
+                'ViTPose/configs/body/2d_kpt_sview_rgb_img/topdown_heatmap/coco/ViTPose_base_coco_256x192.py',
+                'model': 'models/vitpose-b-multi-coco.pth',
+            },
+            'ViTPose-L (multi-task train, COCO)': {
+                'config':
+                'ViTPose/configs/body/2d_kpt_sview_rgb_img/topdown_heatmap/coco/ViTPose_large_coco_256x192.py',
+                'model': 'models/vitpose-l-multi-coco.pth',
+            },
+        }
+        models = dict()
+        for key, dic in model_dict.items():
+            ckpt_path = huggingface_hub.hf_hub_download(
+                'hysts/ViTPose', dic['model'], use_auth_token=HF_TOKEN)
+            model = init_pose_model(dic['config'],
+                                    ckpt_path,
+                                    device=self.device)
+            models[key] = model
+        return models
+
+    def set_model_name(self, name: str) -> None:
+        self.model_name = name
+
+    def predict_pose_and_visualize(
+        self,
+        image: np.ndarray,
+        det_results: list[np.ndarray],
+        box_score_threshold: float,
+        kpt_score_threshold: float,
+        vis_dot_radius: int,
+        vis_line_thickness: int,
+    ) -> tuple[list[dict[str, np.ndarray]], np.ndarray]:
+        out = self.predict_pose(image, det_results, box_score_threshold)
+        vis = self.visualize_pose_results(image, out, kpt_score_threshold,
+                                          vis_dot_radius, vis_line_thickness)
+        return out, vis
+
+    def predict_pose(
+            self,
+            image: np.ndarray,
+            det_results: list[np.ndarray],
+            box_score_threshold: float = 0.5) -> list[dict[str, np.ndarray]]:
+        image = image[:, :, ::-1]  # RGB -> BGR
+        model = self.models[self.model_name]
+        person_results = process_mmdet_results(det_results, 1)
+        out, _ = inference_top_down_pose_model(model,
+                                               image,
+                                               person_results=person_results,
+                                               bbox_thr=box_score_threshold,
+                                               format='xyxy')
+        return out
+
+    def visualize_pose_results(self,
+                               image: np.ndarray,
+                               pose_results: list[np.ndarray],
+                               kpt_score_threshold: float = 0.3,
+                               vis_dot_radius: int = 4,
+                               vis_line_thickness: int = 1) -> np.ndarray:
+        image = image[:, :, ::-1]  # RGB -> BGR
+        model = self.models[self.model_name]
+        vis = vis_pose_result(model,
+                              image,
+                              pose_results,
+                              kpt_score_thr=kpt_score_threshold,
+                              radius=vis_dot_radius,
+                              thickness=vis_line_thickness)
+        return vis[:, :, ::-1]  # BGR -> RGB
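Both visualization and pose inference use only the person class of the detector output: visualize_detection_results keeps detection_results[0] (the COCO person class) plus one empty placeholder class, and predict_pose routes the detections through process_mmdet_results(det_results, 1) to extract person boxes. A small sketch of that filtering, assuming the usual mmdet convention of one (N, 5) array of [x1, y1, x2, y2, score] rows per class:

# Hedged illustration of the person-only filtering above (not part of the commit).
import numpy as np

num_classes = 80  # COCO classes; person is index 0 in mmdet's per-class output list
detection_results = [np.zeros((0, 5), dtype=np.float32) for _ in range(num_classes)]
detection_results[0] = np.array([[10., 20., 110., 220., 0.9]])  # one person box

# Same construction as DetModel.visualize_detection_results: keep the person class
# and pad with one empty class so show_result still receives a list of arrays.
person_det = [detection_results[0]] + [np.array([]).reshape(0, 5)]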
style.css ADDED
@@ -0,0 +1,11 @@
+h1 {
+  text-align: center;
+}
+div#result {
+  max-width: 600px;
+  max-height: 600px;
+}
+img#visitor-badge {
+  display: block;
+  margin: auto;
+}
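The new stylesheet targets elements the app creates: h1 matches the '# ViTPose' heading rendered from DESCRIPTION, img#visitor-badge matches the id set in FOOTER, and div#result presumably matches a component given elem_id='result' elsewhere in app.py (that part is not shown in this diff). A hedged sketch of how such an id is attached in Gradio, with 'result' as an assumed elem_id:

# Illustrative only; the 'result' elem_id is an assumption, not part of this commit.
import gradio as gr

with gr.Blocks(css='style.css') as demo:
    gr.Markdown('# ViTPose')                    # rendered as an <h1>, centered by style.css
    gr.Image(label='Result', elem_id='result')  # elem_id becomes the element id, so div#result applies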