hysts (HF staff) committed
Commit 9b1e028
1 parent: 5de7473
Files changed (6)
  1. .pre-commit-config.yaml +59 -36
  2. .style.yapf +0 -5
  3. README.md +1 -1
  4. app.py +88 -108
  5. model.py +84 -102
  6. style.css +1 -4
.pre-commit-config.yaml CHANGED
@@ -1,37 +1,60 @@
- exclude: ^(ViTPose/|mmdet_configs/configs/)
  repos:
- - repo: https://github.com/pre-commit/pre-commit-hooks
-   rev: v4.2.0
-   hooks:
-   - id: check-executables-have-shebangs
-   - id: check-json
-   - id: check-merge-conflict
-   - id: check-shebang-scripts-are-executable
-   - id: check-toml
-   - id: check-yaml
-   - id: double-quote-string-fixer
-   - id: end-of-file-fixer
-   - id: mixed-line-ending
-     args: ['--fix=lf']
-   - id: requirements-txt-fixer
-   - id: trailing-whitespace
- - repo: https://github.com/myint/docformatter
-   rev: v1.4
-   hooks:
-   - id: docformatter
-     args: ['--in-place']
- - repo: https://github.com/pycqa/isort
-   rev: 5.12.0
-   hooks:
-   - id: isort
- - repo: https://github.com/pre-commit/mirrors-mypy
-   rev: v0.991
-   hooks:
-   - id: mypy
-     args: ['--ignore-missing-imports']
-     additional_dependencies: ['types-python-slugify']
- - repo: https://github.com/google/yapf
-   rev: v0.32.0
-   hooks:
-   - id: yapf
-     args: ['--parallel', '--in-place']

  repos:
+   - repo: https://github.com/pre-commit/pre-commit-hooks
+     rev: v4.6.0
+     hooks:
+       - id: check-executables-have-shebangs
+       - id: check-json
+       - id: check-merge-conflict
+       - id: check-shebang-scripts-are-executable
+       - id: check-toml
+       - id: check-yaml
+       - id: end-of-file-fixer
+       - id: mixed-line-ending
+         args: ["--fix=lf"]
+       - id: requirements-txt-fixer
+       - id: trailing-whitespace
+   - repo: https://github.com/myint/docformatter
+     rev: v1.7.5
+     hooks:
+       - id: docformatter
+         args: ["--in-place"]
+   - repo: https://github.com/pycqa/isort
+     rev: 5.13.2
+     hooks:
+       - id: isort
+         args: ["--profile", "black"]
+   - repo: https://github.com/pre-commit/mirrors-mypy
+     rev: v1.10.0
+     hooks:
+       - id: mypy
+         args: ["--ignore-missing-imports"]
+         additional_dependencies:
+           [
+             "types-python-slugify",
+             "types-requests",
+             "types-PyYAML",
+             "types-pytz",
+           ]
+   - repo: https://github.com/psf/black
+     rev: 24.4.2
+     hooks:
+       - id: black
+         language_version: python3.10
+         args: ["--line-length", "119"]
+   - repo: https://github.com/kynan/nbstripout
+     rev: 0.7.1
+     hooks:
+       - id: nbstripout
+         args:
+           [
+             "--extra-keys",
+             "metadata.interpreter metadata.kernelspec cell.metadata.pycharm",
+           ]
+   - repo: https://github.com/nbQA-dev/nbQA
+     rev: 1.8.5
+     hooks:
+       - id: nbqa-black
+       - id: nbqa-pyupgrade
+         args: ["--py37-plus"]
+       - id: nbqa-isort
+         args: ["--float-to-top"]
.style.yapf DELETED
@@ -1,5 +0,0 @@
- [style]
- based_on_style = pep8
- blank_line_before_nested_class_or_def = false
- spaces_before_comment = 2
- split_before_logical_operator = true
README.md CHANGED
@@ -4,7 +4,7 @@ emoji: 📊
  colorFrom: yellow
  colorTo: indigo
  sdk: gradio
- sdk_version: 3.36.1
  app_file: app.py
  pinned: false
  suggested_hardware: t4-small

  colorFrom: yellow
  colorTo: indigo
  sdk: gradio
+ sdk_version: 4.36.1
  app_file: app.py
  pinned: false
  suggested_hardware: t4-small
app.py CHANGED
@@ -9,14 +9,14 @@ import gradio as gr

  from model import AppDetModel, AppPoseModel

- DESCRIPTION = '# [ViTPose](https://github.com/ViTAE-Transformer/ViTPose)'


  def extract_tar() -> None:
-     if pathlib.Path('mmdet_configs/configs').exists():
          return
-     with tarfile.open('mmdet_configs/configs.tar') as f:
-         f.extractall('mmdet_configs')


  extract_tar()
@@ -24,135 +24,115 @@ extract_tar()
  det_model = AppDetModel()
  pose_model = AppPoseModel()

- with gr.Blocks(css='style.css') as demo:
      gr.Markdown(DESCRIPTION)

-     with gr.Box():
-         gr.Markdown('## Step 1')
          with gr.Row():
              with gr.Column():
                  with gr.Row():
-                     input_image = gr.Image(label='Input Image', type='numpy')
                  with gr.Row():
                      detector_name = gr.Dropdown(
-                         label='Detector',
-                         choices=list(det_model.MODEL_DICT.keys()),
-                         value=det_model.model_name)
                  with gr.Row():
-                     detect_button = gr.Button('Detect')
-                     det_preds = gr.Variable()
              with gr.Column():
                  with gr.Row():
-                     detection_visualization = gr.Image(
-                         label='Detection Result',
-                         type='numpy',
-                         elem_id='det-result')
                  with gr.Row():
                      vis_det_score_threshold = gr.Slider(
-                         label='Visualization Score Threshold',
-                         minimum=0,
-                         maximum=1,
-                         step=0.05,
-                         value=0.5)
                  with gr.Row():
-                     redraw_det_button = gr.Button(value='Redraw')

          with gr.Row():
-             paths = sorted(pathlib.Path('images').rglob('*.jpg'))
-             example_images = gr.Examples(examples=[[path.as_posix()]
-                                                    for path in paths],
-                                          inputs=input_image)

-     with gr.Box():
-         gr.Markdown('## Step 2')
          with gr.Row():
              with gr.Column():
                  with gr.Row():
                      pose_model_name = gr.Dropdown(
-                         label='Pose Model',
-                         choices=list(pose_model.MODEL_DICT.keys()),
-                         value=pose_model.model_name)
-                     det_score_threshold = gr.Slider(label='Box Score Threshold',
-                                                     minimum=0,
-                                                     maximum=1,
-                                                     step=0.05,
-                                                     value=0.5)
                  with gr.Row():
-                     predict_button = gr.Button('Predict')
-                     pose_preds = gr.Variable()
              with gr.Column():
                  with gr.Row():
-                     pose_visualization = gr.Image(label='Result',
-                                                   type='numpy',
-                                                   elem_id='pose-result')
                  with gr.Row():
                      vis_kpt_score_threshold = gr.Slider(
-                         label='Visualization Score Threshold',
-                         minimum=0,
-                         maximum=1,
-                         step=0.05,
-                         value=0.3)
                  with gr.Row():
-                     vis_dot_radius = gr.Slider(label='Dot Radius',
-                                                minimum=1,
-                                                maximum=10,
-                                                step=1,
-                                                value=4)
                  with gr.Row():
-                     vis_line_thickness = gr.Slider(label='Line Thickness',
-                                                    minimum=1,
-                                                    maximum=10,
-                                                    step=1,
-                                                    value=2)
                  with gr.Row():
-                     redraw_pose_button = gr.Button('Redraw')
-
-     detector_name.change(fn=det_model.set_model,
-                          inputs=detector_name,
-                          outputs=None)
-     detect_button.click(fn=det_model.run,
-                         inputs=[
-                             detector_name,
-                             input_image,
-                             vis_det_score_threshold,
-                         ],
-                         outputs=[
-                             det_preds,
-                             detection_visualization,
-                         ])
-     redraw_det_button.click(fn=det_model.visualize_detection_results,
-                             inputs=[
-                                 input_image,
-                                 det_preds,
-                                 vis_det_score_threshold,
-                             ],
-                             outputs=detection_visualization)
-
-     pose_model_name.change(fn=pose_model.set_model,
-                            inputs=pose_model_name,
-                            outputs=None)
-     predict_button.click(fn=pose_model.run,
-                          inputs=[
-                              pose_model_name,
-                              input_image,
-                              det_preds,
-                              det_score_threshold,
-                              vis_kpt_score_threshold,
-                              vis_dot_radius,
-                              vis_line_thickness,
-                          ],
-                          outputs=[
-                              pose_preds,
-                              pose_visualization,
-                          ])
-     redraw_pose_button.click(fn=pose_model.visualize_pose_results,
-                              inputs=[
-                                  input_image,
-                                  pose_preds,
-                                  vis_kpt_score_threshold,
-                                  vis_dot_radius,
-                                  vis_line_thickness,
-                              ],
-                              outputs=pose_visualization)
-
-     demo.queue(max_size=10).launch()

  from model import AppDetModel, AppPoseModel

+ DESCRIPTION = "# [ViTPose](https://github.com/ViTAE-Transformer/ViTPose)"


  def extract_tar() -> None:
+     if pathlib.Path("mmdet_configs/configs").exists():
          return
+     with tarfile.open("mmdet_configs/configs.tar") as f:
+         f.extractall("mmdet_configs")


  extract_tar()

  det_model = AppDetModel()
  pose_model = AppPoseModel()

+ with gr.Blocks(css="style.css") as demo:
      gr.Markdown(DESCRIPTION)

+     with gr.Group():
+         gr.Markdown("## Step 1")
          with gr.Row():
              with gr.Column():
                  with gr.Row():
+                     input_image = gr.Image(label="Input Image", type="numpy")
                  with gr.Row():
                      detector_name = gr.Dropdown(
+                         label="Detector", choices=list(det_model.MODEL_DICT.keys()), value=det_model.model_name
+                     )
                  with gr.Row():
+                     detect_button = gr.Button("Detect")
+                     det_preds = gr.State()
              with gr.Column():
                  with gr.Row():
+                     detection_visualization = gr.Image(label="Detection Result", type="numpy", elem_id="det-result")
                  with gr.Row():
                      vis_det_score_threshold = gr.Slider(
+                         label="Visualization Score Threshold", minimum=0, maximum=1, step=0.05, value=0.5
+                     )
                  with gr.Row():
+                     redraw_det_button = gr.Button(value="Redraw")

          with gr.Row():
+             paths = sorted(pathlib.Path("images").rglob("*.jpg"))
+             example_images = gr.Examples(examples=[[path.as_posix()] for path in paths], inputs=input_image)

+     with gr.Group():
+         gr.Markdown("## Step 2")
          with gr.Row():
              with gr.Column():
                  with gr.Row():
                      pose_model_name = gr.Dropdown(
+                         label="Pose Model", choices=list(pose_model.MODEL_DICT.keys()), value=pose_model.model_name
+                     )
+                     det_score_threshold = gr.Slider(
+                         label="Box Score Threshold", minimum=0, maximum=1, step=0.05, value=0.5
+                     )
                  with gr.Row():
+                     predict_button = gr.Button("Predict")
+                     pose_preds = gr.State()
              with gr.Column():
                  with gr.Row():
+                     pose_visualization = gr.Image(label="Result", type="numpy", elem_id="pose-result")
                  with gr.Row():
                      vis_kpt_score_threshold = gr.Slider(
+                         label="Visualization Score Threshold", minimum=0, maximum=1, step=0.05, value=0.3
+                     )
                  with gr.Row():
+                     vis_dot_radius = gr.Slider(label="Dot Radius", minimum=1, maximum=10, step=1, value=4)
                  with gr.Row():
+                     vis_line_thickness = gr.Slider(label="Line Thickness", minimum=1, maximum=10, step=1, value=2)
                  with gr.Row():
+                     redraw_pose_button = gr.Button("Redraw")
+
+     detector_name.change(fn=det_model.set_model, inputs=detector_name)
+     detect_button.click(
+         fn=det_model.run,
+         inputs=[
+             detector_name,
+             input_image,
+             vis_det_score_threshold,
+         ],
+         outputs=[
+             det_preds,
+             detection_visualization,
+         ],
+     )
+     redraw_det_button.click(
+         fn=det_model.visualize_detection_results,
+         inputs=[
+             input_image,
+             det_preds,
+             vis_det_score_threshold,
+         ],
+         outputs=detection_visualization,
+     )
+
+     pose_model_name.change(fn=pose_model.set_model, inputs=pose_model_name)
+     predict_button.click(
+         fn=pose_model.run,
+         inputs=[
+             pose_model_name,
+             input_image,
+             det_preds,
+             det_score_threshold,
+             vis_kpt_score_threshold,
+             vis_dot_radius,
+             vis_line_thickness,
+         ],
+         outputs=[
+             pose_preds,
+             pose_visualization,
+         ],
+     )
+     redraw_pose_button.click(
+         fn=pose_model.visualize_pose_results,
+         inputs=[
+             input_image,
+             pose_preds,
+             vis_kpt_score_threshold,
+             vis_dot_radius,
+             vis_line_thickness,
+         ],
+         outputs=pose_visualization,
+     )
+
+ if __name__ == "__main__":
+     demo.queue(max_size=10).launch()
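
The Gradio migration above replaces gr.Box with gr.Group and gr.Variable with gr.State. A minimal sketch of that pattern, assuming Gradio 4.x is installed; the detect function and component names here are illustrative placeholders, not taken from app.py:

import gradio as gr


def detect(image):
    # Placeholder for det_model.run: returns raw predictions plus a visualization.
    preds = {"boxes": []}
    return preds, image


with gr.Blocks() as demo:
    with gr.Group():  # gr.Group stands in for the gr.Box container removed in Gradio 4
        input_image = gr.Image(label="Input Image", type="numpy")
        detect_button = gr.Button("Detect")
        result = gr.Image(label="Detection Result", type="numpy")
    det_preds = gr.State()  # replaces gr.Variable(); carries data between steps without rendering UI
    detect_button.click(fn=detect, inputs=input_image, outputs=[det_preds, result])

if __name__ == "__main__":
    demo.queue(max_size=10).launch()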
model.py CHANGED
@@ -6,15 +6,15 @@ import shlex
  import subprocess
  import sys

- if os.getenv('SYSTEM') == 'spaces':
      import mim

-     mim.uninstall('mmcv-full', confirm_yes=True)
-     mim.install('mmcv-full==1.5.0', is_yes=True)

-     subprocess.run(shlex.split('pip uninstall -y opencv-python'))
-     subprocess.run(shlex.split('pip uninstall -y opencv-python-headless'))
-     subprocess.run(shlex.split('pip install opencv-python-headless==4.8.0.74'))

  import huggingface_hub
  import numpy as np
@@ -22,47 +22,42 @@ import torch
  import torch.nn as nn

  app_dir = pathlib.Path(__file__).parent
- submodule_dir = app_dir / 'ViTPose'
  sys.path.insert(0, submodule_dir.as_posix())

  from mmdet.apis import inference_detector, init_detector
- from mmpose.apis import (inference_top_down_pose_model, init_pose_model,
-                          process_mmdet_results, vis_pose_result)


  class DetModel:
      MODEL_DICT = {
-         'YOLOX-tiny': {
-             'config':
-             'mmdet_configs/configs/yolox/yolox_tiny_8x8_300e_coco.py',
-             'model':
-             'https://download.openmmlab.com/mmdetection/v2.0/yolox/yolox_tiny_8x8_300e_coco/yolox_tiny_8x8_300e_coco_20211124_171234-b4047906.pth',
          },
-         'YOLOX-s': {
-             'config':
-             'mmdet_configs/configs/yolox/yolox_s_8x8_300e_coco.py',
-             'model':
-             'https://download.openmmlab.com/mmdetection/v2.0/yolox/yolox_s_8x8_300e_coco/yolox_s_8x8_300e_coco_20211121_095711-4592a793.pth',
          },
-         'YOLOX-l': {
-             'config':
-             'mmdet_configs/configs/yolox/yolox_l_8x8_300e_coco.py',
-             'model':
-             'https://download.openmmlab.com/mmdetection/v2.0/yolox/yolox_l_8x8_300e_coco/yolox_l_8x8_300e_coco_20211126_140236-d3bd2b23.pth',
          },
-         'YOLOX-x': {
-             'config':
-             'mmdet_configs/configs/yolox/yolox_x_8x8_300e_coco.py',
-             'model':
-             'https://download.openmmlab.com/mmdetection/v2.0/yolox/yolox_x_8x8_300e_coco/yolox_x_8x8_300e_coco_20211126_140254-1ef88d67.pth',
          },
      }

      def __init__(self):
-         self.device = torch.device(
-             'cuda:0' if torch.cuda.is_available() else 'cpu')
          self._load_all_models_once()
-         self.model_name = 'YOLOX-l'
          self.model = self._load_model(self.model_name)

      def _load_all_models_once(self) -> None:
@@ -71,7 +66,7 @@ class DetModel:

      def _load_model(self, name: str) -> nn.Module:
          d = self.MODEL_DICT[name]
-         return init_detector(d['config'], d['model'], device=self.device)

      def set_model(self, name: str) -> None:
          if name == self.model_name:
@@ -79,9 +74,7 @@
          self.model_name = name
          self.model = self._load_model(name)

-     def detect_and_visualize(
-             self, image: np.ndarray,
-             score_threshold: float) -> tuple[list[np.ndarray], np.ndarray]:
          out = self.detect(image)
          vis = self.visualize_detection_results(image, out, score_threshold)
          return out, vis
@@ -92,57 +85,46 @@ class DetModel:
          return out

      def visualize_detection_results(
-             self,
-             image: np.ndarray,
-             detection_results: list[np.ndarray],
-             score_threshold: float = 0.3) -> np.ndarray:
          person_det = [detection_results[0]] + [np.array([]).reshape(0, 5)] * 79

          image = image[:, :, ::-1]  # RGB -> BGR
-         vis = self.model.show_result(image,
-                                      person_det,
-                                      score_thr=score_threshold,
-                                      bbox_color=None,
-                                      text_color=(200, 200, 200),
-                                      mask_color=None)
          return vis[:, :, ::-1]  # BGR -> RGB


  class AppDetModel(DetModel):
-     def run(self, model_name: str, image: np.ndarray,
-             score_threshold: float) -> tuple[list[np.ndarray], np.ndarray]:
          self.set_model(model_name)
          return self.detect_and_visualize(image, score_threshold)


  class PoseModel:
      MODEL_DICT = {
-         'ViTPose-B (single-task train)': {
-             'config':
-             'ViTPose/configs/body/2d_kpt_sview_rgb_img/topdown_heatmap/coco/ViTPose_base_coco_256x192.py',
-             'model': 'models/vitpose-b.pth',
          },
-         'ViTPose-L (single-task train)': {
-             'config':
-             'ViTPose/configs/body/2d_kpt_sview_rgb_img/topdown_heatmap/coco/ViTPose_large_coco_256x192.py',
-             'model': 'models/vitpose-l.pth',
          },
-         'ViTPose-B (multi-task train, COCO)': {
-             'config':
-             'ViTPose/configs/body/2d_kpt_sview_rgb_img/topdown_heatmap/coco/ViTPose_base_coco_256x192.py',
-             'model': 'models/vitpose-b-multi-coco.pth',
          },
-         'ViTPose-L (multi-task train, COCO)': {
-             'config':
-             'ViTPose/configs/body/2d_kpt_sview_rgb_img/topdown_heatmap/coco/ViTPose_large_coco_256x192.py',
-             'model': 'models/vitpose-l-multi-coco.pth',
          },
      }

      def __init__(self):
-         self.device = torch.device(
-             'cuda:0' if torch.cuda.is_available() else 'cpu')
-         self.model_name = 'ViTPose-B (multi-task train, COCO)'
          self.model = self._load_model(self.model_name)

      def _load_all_models_once(self) -> None:
@@ -151,9 +133,8 @@ class PoseModel:

      def _load_model(self, name: str) -> nn.Module:
          d = self.MODEL_DICT[name]
-         ckpt_path = huggingface_hub.hf_hub_download('public-data/ViTPose',
-                                                     d['model'])
-         model = init_pose_model(d['config'], ckpt_path, device=self.device)
          return model

      def set_model(self, name: str) -> None:
@@ -172,50 +153,51 @@
          vis_line_thickness: int,
      ) -> tuple[list[dict[str, np.ndarray]], np.ndarray]:
          out = self.predict_pose(image, det_results, box_score_threshold)
-         vis = self.visualize_pose_results(image, out, kpt_score_threshold,
-                                           vis_dot_radius, vis_line_thickness)
          return out, vis

      def predict_pose(
-             self,
-             image: np.ndarray,
-             det_results: list[np.ndarray],
-             box_score_threshold: float = 0.5) -> list[dict[str, np.ndarray]]:
          image = image[:, :, ::-1]  # RGB -> BGR
          person_results = process_mmdet_results(det_results, 1)
-         out, _ = inference_top_down_pose_model(self.model,
-                                                image,
-                                                person_results=person_results,
-                                                bbox_thr=box_score_threshold,
-                                                format='xyxy')
          return out

-     def visualize_pose_results(self,
-                                image: np.ndarray,
-                                pose_results: list[np.ndarray],
-                                kpt_score_threshold: float = 0.3,
-                                vis_dot_radius: int = 4,
-                                vis_line_thickness: int = 1) -> np.ndarray:
          image = image[:, :, ::-1]  # RGB -> BGR
-         vis = vis_pose_result(self.model,
-                               image,
-                               pose_results,
-                               kpt_score_thr=kpt_score_threshold,
-                               radius=vis_dot_radius,
-                               thickness=vis_line_thickness)
          return vis[:, :, ::-1]  # BGR -> RGB


  class AppPoseModel(PoseModel):
      def run(
-             self, model_name: str, image: np.ndarray,
-             det_results: list[np.ndarray], box_score_threshold: float,
-             kpt_score_threshold: float, vis_dot_radius: int,
-             vis_line_thickness: int
      ) -> tuple[list[dict[str, np.ndarray]], np.ndarray]:
          self.set_model(model_name)
-         return self.predict_pose_and_visualize(image, det_results,
-                                                box_score_threshold,
-                                                kpt_score_threshold,
-                                                vis_dot_radius,
-                                                vis_line_thickness)

  import subprocess
  import sys

+ if os.getenv("SYSTEM") == "spaces":
      import mim

+     mim.uninstall("mmcv-full", confirm_yes=True)
+     mim.install("mmcv-full==1.5.0", is_yes=True)

+     subprocess.run(shlex.split("pip uninstall -y opencv-python"))
+     subprocess.run(shlex.split("pip uninstall -y opencv-python-headless"))
+     subprocess.run(shlex.split("pip install opencv-python-headless==4.8.0.74"))

  import huggingface_hub
  import numpy as np

  import torch.nn as nn

  app_dir = pathlib.Path(__file__).parent
+ submodule_dir = app_dir / "ViTPose"
  sys.path.insert(0, submodule_dir.as_posix())

  from mmdet.apis import inference_detector, init_detector
+ from mmpose.apis import (
+     inference_top_down_pose_model,
+     init_pose_model,
+     process_mmdet_results,
+     vis_pose_result,
+ )


  class DetModel:
      MODEL_DICT = {
+         "YOLOX-tiny": {
+             "config": "mmdet_configs/configs/yolox/yolox_tiny_8x8_300e_coco.py",
+             "model": "https://download.openmmlab.com/mmdetection/v2.0/yolox/yolox_tiny_8x8_300e_coco/yolox_tiny_8x8_300e_coco_20211124_171234-b4047906.pth",
          },
+         "YOLOX-s": {
+             "config": "mmdet_configs/configs/yolox/yolox_s_8x8_300e_coco.py",
+             "model": "https://download.openmmlab.com/mmdetection/v2.0/yolox/yolox_s_8x8_300e_coco/yolox_s_8x8_300e_coco_20211121_095711-4592a793.pth",
          },
+         "YOLOX-l": {
+             "config": "mmdet_configs/configs/yolox/yolox_l_8x8_300e_coco.py",
+             "model": "https://download.openmmlab.com/mmdetection/v2.0/yolox/yolox_l_8x8_300e_coco/yolox_l_8x8_300e_coco_20211126_140236-d3bd2b23.pth",
          },
+         "YOLOX-x": {
+             "config": "mmdet_configs/configs/yolox/yolox_x_8x8_300e_coco.py",
+             "model": "https://download.openmmlab.com/mmdetection/v2.0/yolox/yolox_x_8x8_300e_coco/yolox_x_8x8_300e_coco_20211126_140254-1ef88d67.pth",
          },
      }

      def __init__(self):
+         self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
          self._load_all_models_once()
+         self.model_name = "YOLOX-l"
          self.model = self._load_model(self.model_name)

      def _load_all_models_once(self) -> None:

      def _load_model(self, name: str) -> nn.Module:
          d = self.MODEL_DICT[name]
+         return init_detector(d["config"], d["model"], device=self.device)

      def set_model(self, name: str) -> None:
          if name == self.model_name:

          self.model_name = name
          self.model = self._load_model(name)

+     def detect_and_visualize(self, image: np.ndarray, score_threshold: float) -> tuple[list[np.ndarray], np.ndarray]:
          out = self.detect(image)
          vis = self.visualize_detection_results(image, out, score_threshold)
          return out, vis

          return out

      def visualize_detection_results(
+         self, image: np.ndarray, detection_results: list[np.ndarray], score_threshold: float = 0.3
+     ) -> np.ndarray:
          person_det = [detection_results[0]] + [np.array([]).reshape(0, 5)] * 79

          image = image[:, :, ::-1]  # RGB -> BGR
+         vis = self.model.show_result(
+             image, person_det, score_thr=score_threshold, bbox_color=None, text_color=(200, 200, 200), mask_color=None
+         )
          return vis[:, :, ::-1]  # BGR -> RGB


  class AppDetModel(DetModel):
+     def run(self, model_name: str, image: np.ndarray, score_threshold: float) -> tuple[list[np.ndarray], np.ndarray]:
          self.set_model(model_name)
          return self.detect_and_visualize(image, score_threshold)


  class PoseModel:
      MODEL_DICT = {
+         "ViTPose-B (single-task train)": {
+             "config": "ViTPose/configs/body/2d_kpt_sview_rgb_img/topdown_heatmap/coco/ViTPose_base_coco_256x192.py",
+             "model": "models/vitpose-b.pth",
          },
+         "ViTPose-L (single-task train)": {
+             "config": "ViTPose/configs/body/2d_kpt_sview_rgb_img/topdown_heatmap/coco/ViTPose_large_coco_256x192.py",
+             "model": "models/vitpose-l.pth",
          },
+         "ViTPose-B (multi-task train, COCO)": {
+             "config": "ViTPose/configs/body/2d_kpt_sview_rgb_img/topdown_heatmap/coco/ViTPose_base_coco_256x192.py",
+             "model": "models/vitpose-b-multi-coco.pth",
          },
+         "ViTPose-L (multi-task train, COCO)": {
+             "config": "ViTPose/configs/body/2d_kpt_sview_rgb_img/topdown_heatmap/coco/ViTPose_large_coco_256x192.py",
+             "model": "models/vitpose-l-multi-coco.pth",
          },
      }

      def __init__(self):
+         self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+         self.model_name = "ViTPose-B (multi-task train, COCO)"
          self.model = self._load_model(self.model_name)

      def _load_all_models_once(self) -> None:

      def _load_model(self, name: str) -> nn.Module:
          d = self.MODEL_DICT[name]
+         ckpt_path = huggingface_hub.hf_hub_download("public-data/ViTPose", d["model"])
+         model = init_pose_model(d["config"], ckpt_path, device=self.device)
          return model

      def set_model(self, name: str) -> None:

          vis_line_thickness: int,
      ) -> tuple[list[dict[str, np.ndarray]], np.ndarray]:
          out = self.predict_pose(image, det_results, box_score_threshold)
+         vis = self.visualize_pose_results(image, out, kpt_score_threshold, vis_dot_radius, vis_line_thickness)
          return out, vis

      def predict_pose(
+         self, image: np.ndarray, det_results: list[np.ndarray], box_score_threshold: float = 0.5
+     ) -> list[dict[str, np.ndarray]]:
          image = image[:, :, ::-1]  # RGB -> BGR
          person_results = process_mmdet_results(det_results, 1)
+         out, _ = inference_top_down_pose_model(
+             self.model, image, person_results=person_results, bbox_thr=box_score_threshold, format="xyxy"
+         )
          return out

+     def visualize_pose_results(
+         self,
+         image: np.ndarray,
+         pose_results: list[np.ndarray],
+         kpt_score_threshold: float = 0.3,
+         vis_dot_radius: int = 4,
+         vis_line_thickness: int = 1,
+     ) -> np.ndarray:
          image = image[:, :, ::-1]  # RGB -> BGR
+         vis = vis_pose_result(
+             self.model,
+             image,
+             pose_results,
+             kpt_score_thr=kpt_score_threshold,
+             radius=vis_dot_radius,
+             thickness=vis_line_thickness,
+         )
          return vis[:, :, ::-1]  # BGR -> RGB


  class AppPoseModel(PoseModel):
      def run(
+         self,
+         model_name: str,
+         image: np.ndarray,
+         det_results: list[np.ndarray],
+         box_score_threshold: float,
+         kpt_score_threshold: float,
+         vis_dot_radius: int,
+         vis_line_thickness: int,
      ) -> tuple[list[dict[str, np.ndarray]], np.ndarray]:
          self.set_model(model_name)
+         return self.predict_pose_and_visualize(
+             image, det_results, box_score_threshold, kpt_score_threshold, vis_dot_radius, vis_line_thickness
+         )
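
For orientation, a hypothetical usage sketch of the two classes defined above, assuming the Space environment (ViTPose submodule, mmdet/mmpose, extracted mmdet configs, and model weights) is already in place; the image path and threshold values are placeholders rather than values from the repository:

import cv2

from model import AppDetModel, AppPoseModel

det_model = AppDetModel()
pose_model = AppPoseModel()

# Both classes expect an RGB numpy array; they convert to BGR internally.
image = cv2.cvtColor(cv2.imread("person.jpg"), cv2.COLOR_BGR2RGB)  # placeholder path

# Step 1: person detection, returning raw detections and a visualization.
det_preds, det_vis = det_model.run("YOLOX-l", image, 0.5)

# Step 2: top-down pose estimation on the detected boxes.
pose_preds, pose_vis = pose_model.run("ViTPose-B (multi-task train, COCO)", image, det_preds, 0.5, 0.3, 4, 2)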
 
 
style.css CHANGED
@@ -1,5 +1,6 @@
  h1 {
    text-align: center;
  }
  div#det-result {
    max-width: 600px;
@@ -9,7 +10,3 @@ div#pose-result {
    max-width: 600px;
    max-height: 600px;
  }
- img#visitor-badge {
-   display: block;
-   margin: auto;
- }

  h1 {
    text-align: center;
+   display: block;
  }
  div#det-result {
    max-width: 600px;

    max-width: 600px;
    max-height: 600px;
  }