hysts (HF staff) committed
Commit d25bfc0
1 Parent(s): 51e2977
Files changed (3)
  1. app.py +16 -11
  2. model.py +114 -85
  3. style.css +5 -1
app.py CHANGED
@@ -8,7 +8,7 @@ import tarfile
 
 import gradio as gr
 
-from model import DetModel, PoseModel
+from model import AppDetModel, AppPoseModel
 
 DESCRIPTION = '''# ViTPose
 
@@ -44,8 +44,8 @@ def main():
 
     extract_tar()
 
-    det_model = DetModel(device=args.device)
-    pose_model = PoseModel(device=args.device)
+    det_model = AppDetModel(device=args.device)
+    pose_model = AppPoseModel(device=args.device)
 
     with gr.Blocks(theme=args.theme, css='style.css') as demo:
         gr.Markdown(DESCRIPTION)
@@ -59,7 +59,7 @@ def main():
                         type='numpy')
                 with gr.Row():
                     detector_name = gr.Dropdown(list(
-                        det_model.models.keys()),
+                        det_model.MODEL_DICT.keys()),
                         value=det_model.model_name,
                         label='Detector')
                 with gr.Row():
@@ -68,7 +68,9 @@ def main():
             with gr.Column():
                 with gr.Row():
                     detection_visualization = gr.Image(
-                        label='Detection Result', type='numpy')
+                        label='Detection Result',
+                        type='numpy',
+                        elem_id='det-result')
                 with gr.Row():
                     vis_det_score_threshold = gr.Slider(
                         0,
@@ -91,7 +93,7 @@ def main():
            with gr.Column():
                 with gr.Row():
                     pose_model_name = gr.Dropdown(
-                        list(pose_model.models.keys()),
+                        list(pose_model.MODEL_DICT.keys()),
                         value=pose_model.model_name,
                         label='Pose Model')
                 det_score_threshold = gr.Slider(
@@ -106,7 +108,8 @@ def main():
            with gr.Column():
                 with gr.Row():
                     pose_visualization = gr.Image(label='Result',
-                                                  type='numpy')
+                                                  type='numpy',
+                                                  elem_id='pose-result')
                 with gr.Row():
                     vis_kpt_score_threshold = gr.Slider(
                         0,
@@ -131,11 +134,12 @@ def main():
 
         gr.Markdown(FOOTER)
 
-        detector_name.change(fn=det_model.set_model_name,
+        detector_name.change(fn=det_model.set_model,
                              inputs=detector_name,
                              outputs=None)
-        detect_button.click(fn=det_model.detect_and_visualize,
+        detect_button.click(fn=det_model.run,
                             inputs=[
+                                detector_name,
                                 input_image,
                                 vis_det_score_threshold,
                             ],
@@ -151,11 +155,12 @@ def main():
                             ],
                             outputs=detection_visualization)
 
-        pose_model_name.change(fn=pose_model.set_model_name,
+        pose_model_name.change(fn=pose_model.set_model,
                                inputs=pose_model_name,
                                outputs=None)
-        predict_button.click(fn=pose_model.predict_pose_and_visualize,
+        predict_button.click(fn=pose_model.run,
                              inputs=[
+                                 pose_model_name,
                                  input_image,
                                  det_preds,
                                  det_score_threshold,
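
The app.py changes above rewire the Gradio events: the model dropdowns now pass their current value directly into det_model.run / pose_model.run on every click, instead of relying solely on a change callback to mutate hidden state. Below is a minimal, self-contained sketch of that wiring pattern. It is not part of the commit; DummyModel is a hypothetical stand-in for AppDetModel, and the only assumption is that gradio is installed.

import gradio as gr


class DummyModel:
    """Hypothetical stand-in for AppDetModel/AppPoseModel."""

    def run(self, model_name: str, text: str) -> str:
        # The selected dropdown value arrives as the first argument on every click.
        return f'{model_name}: {text}'


model = DummyModel()

with gr.Blocks() as demo:
    model_name = gr.Dropdown(['YOLOX-tiny', 'YOLOX-l'],
                             value='YOLOX-l',
                             label='Model')
    input_text = gr.Textbox(label='Input')
    output_text = gr.Textbox(label='Output')
    run_button = gr.Button('Run')
    # The dropdown component itself is listed among the inputs, mirroring
    # detect_button.click(fn=det_model.run, inputs=[detector_name, ...]).
    run_button.click(fn=model.run,
                     inputs=[model_name, input_text],
                     outputs=output_text)

if __name__ == '__main__':
    demo.launch()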
model.py CHANGED
@@ -29,46 +29,52 @@ HF_TOKEN = os.environ['HF_TOKEN']
 
 
 class DetModel:
+    MODEL_DICT = {
+        'YOLOX-tiny': {
+            'config':
+            'mmdet_configs/configs/yolox/yolox_tiny_8x8_300e_coco.py',
+            'model':
+            'https://download.openmmlab.com/mmdetection/v2.0/yolox/yolox_tiny_8x8_300e_coco/yolox_tiny_8x8_300e_coco_20211124_171234-b4047906.pth',
+        },
+        'YOLOX-s': {
+            'config':
+            'mmdet_configs/configs/yolox/yolox_s_8x8_300e_coco.py',
+            'model':
+            'https://download.openmmlab.com/mmdetection/v2.0/yolox/yolox_s_8x8_300e_coco/yolox_s_8x8_300e_coco_20211121_095711-4592a793.pth',
+        },
+        'YOLOX-l': {
+            'config':
+            'mmdet_configs/configs/yolox/yolox_l_8x8_300e_coco.py',
+            'model':
+            'https://download.openmmlab.com/mmdetection/v2.0/yolox/yolox_l_8x8_300e_coco/yolox_l_8x8_300e_coco_20211126_140236-d3bd2b23.pth',
+        },
+        'YOLOX-x': {
+            'config':
+            'mmdet_configs/configs/yolox/yolox_x_8x8_300e_coco.py',
+            'model':
+            'https://download.openmmlab.com/mmdetection/v2.0/yolox/yolox_x_8x8_300e_coco/yolox_x_8x8_300e_coco_20211126_140254-1ef88d67.pth',
+        },
+    }
+
     def __init__(self, device: str | torch.device):
         self.device = torch.device(device)
-        self.models = self._load_models()
+        self._load_all_models_once()
         self.model_name = 'YOLOX-l'
+        self.model = self._load_model(self.model_name)
+
+    def _load_all_models_once(self) -> None:
+        for name in self.MODEL_DICT:
+            self._load_model(name)
+
+    def _load_model(self, name: str) -> nn.Module:
+        dic = self.MODEL_DICT[name]
+        return init_detector(dic['config'], dic['model'], device=self.device)
 
-    def _load_models(self) -> dict[str, nn.Module]:
-        model_dict = {
-            'YOLOX-tiny': {
-                'config':
-                'mmdet_configs/configs/yolox/yolox_tiny_8x8_300e_coco.py',
-                'model':
-                'https://download.openmmlab.com/mmdetection/v2.0/yolox/yolox_tiny_8x8_300e_coco/yolox_tiny_8x8_300e_coco_20211124_171234-b4047906.pth',
-            },
-            'YOLOX-s': {
-                'config':
-                'mmdet_configs/configs/yolox/yolox_s_8x8_300e_coco.py',
-                'model':
-                'https://download.openmmlab.com/mmdetection/v2.0/yolox/yolox_s_8x8_300e_coco/yolox_s_8x8_300e_coco_20211121_095711-4592a793.pth',
-            },
-            'YOLOX-l': {
-                'config':
-                'mmdet_configs/configs/yolox/yolox_l_8x8_300e_coco.py',
-                'model':
-                'https://download.openmmlab.com/mmdetection/v2.0/yolox/yolox_l_8x8_300e_coco/yolox_l_8x8_300e_coco_20211126_140236-d3bd2b23.pth',
-            },
-            'YOLOX-x': {
-                'config':
-                'mmdet_configs/configs/yolox/yolox_x_8x8_300e_coco.py',
-                'model':
-                'https://download.openmmlab.com/mmdetection/v2.0/yolox/yolox_x_8x8_300e_coco/yolox_x_8x8_300e_coco_20211126_140254-1ef88d67.pth',
-            },
-        }
-        models = {
-            key: init_detector(dic['config'], dic['model'], device=self.device)
-            for key, dic in model_dict.items()
-        }
-        return models
-
-    def set_model_name(self, name: str) -> None:
+    def set_model(self, name: str) -> None:
+        if name == self.model_name:
+            return
         self.model_name = name
+        self.model = self._load_model(name)
 
     def detect_and_visualize(
         self, image: np.ndarray,
@@ -79,8 +85,7 @@ class DetModel:
 
     def detect(self, image: np.ndarray) -> list[np.ndarray]:
         image = image[:, :, ::-1]  # RGB -> BGR
-        model = self.models[self.model_name]
-        out = inference_detector(model, image)
+        out = inference_detector(self.model, image)
         return out
 
     def visualize_detection_results(
@@ -88,60 +93,71 @@ class DetModel:
             image: np.ndarray,
             detection_results: list[np.ndarray],
            score_threshold: float = 0.3) -> np.ndarray:
-        person_det = [detection_results[0]] + [np.array([]).reshape(0, 5)]
+        person_det = [detection_results[0]] + [np.array([]).reshape(0, 5)] * 79
 
         image = image[:, :, ::-1]  # RGB -> BGR
-        model = self.models[self.model_name]
-        vis = model.show_result(image,
-                                person_det,
-                                score_thr=score_threshold,
-                                bbox_color=None,
-                                text_color=(200, 200, 200),
-                                mask_color=None)
+        vis = self.model.show_result(image,
+                                     person_det,
+                                     score_thr=score_threshold,
+                                     bbox_color=None,
+                                     text_color=(200, 200, 200),
+                                     mask_color=None)
         return vis[:, :, ::-1]  # BGR -> RGB
 
 
+class AppDetModel(DetModel):
+    def run(self, model_name: str, image: np.ndarray,
+            score_threshold: float) -> tuple[list[np.ndarray], np.ndarray]:
+        self.set_model(model_name)
+        return self.detect_and_visualize(image, score_threshold)
+
+
 class PoseModel:
+    MODEL_DICT = {
+        'ViTPose-B (single-task train)': {
+            'config':
+            'ViTPose/configs/body/2d_kpt_sview_rgb_img/topdown_heatmap/coco/ViTPose_base_coco_256x192.py',
+            'model': 'models/vitpose-b.pth',
+        },
+        'ViTPose-L (single-task train)': {
+            'config':
+            'ViTPose/configs/body/2d_kpt_sview_rgb_img/topdown_heatmap/coco/ViTPose_large_coco_256x192.py',
+            'model': 'models/vitpose-l.pth',
+        },
+        'ViTPose-B (multi-task train, COCO)': {
+            'config':
+            'ViTPose/configs/body/2d_kpt_sview_rgb_img/topdown_heatmap/coco/ViTPose_base_coco_256x192.py',
+            'model': 'models/vitpose-b-multi-coco.pth',
+        },
+        'ViTPose-L (multi-task train, COCO)': {
+            'config':
+            'ViTPose/configs/body/2d_kpt_sview_rgb_img/topdown_heatmap/coco/ViTPose_large_coco_256x192.py',
+            'model': 'models/vitpose-l-multi-coco.pth',
+        },
+    }
+
     def __init__(self, device: str | torch.device):
         self.device = torch.device(device)
-        self.models = self._load_models()
         self.model_name = 'ViTPose-B (multi-task train, COCO)'
-
-    def _load_models(self) -> dict[str, nn.Module]:
-        model_dict = {
-            'ViTPose-B (single-task train)': {
-                'config':
-                'ViTPose/configs/body/2d_kpt_sview_rgb_img/topdown_heatmap/coco/ViTPose_base_coco_256x192.py',
-                'model': 'models/vitpose-b.pth',
-            },
-            'ViTPose-L (single-task train)': {
-                'config':
-                'ViTPose/configs/body/2d_kpt_sview_rgb_img/topdown_heatmap/coco/ViTPose_large_coco_256x192.py',
-                'model': 'models/vitpose-l.pth',
-            },
-            'ViTPose-B (multi-task train, COCO)': {
-                'config':
-                'ViTPose/configs/body/2d_kpt_sview_rgb_img/topdown_heatmap/coco/ViTPose_base_coco_256x192.py',
-                'model': 'models/vitpose-b-multi-coco.pth',
-            },
-            'ViTPose-L (multi-task train, COCO)': {
-                'config':
-                'ViTPose/configs/body/2d_kpt_sview_rgb_img/topdown_heatmap/coco/ViTPose_large_coco_256x192.py',
-                'model': 'models/vitpose-l-multi-coco.pth',
-            },
-        }
-        models = dict()
-        for key, dic in model_dict.items():
-            ckpt_path = huggingface_hub.hf_hub_download(
-                'hysts/ViTPose', dic['model'], use_auth_token=HF_TOKEN)
-            model = init_pose_model(dic['config'],
-                                    ckpt_path,
-                                    device=self.device)
-            models[key] = model
-        return models
-
-    def set_model_name(self, name: str) -> None:
+        self.model = self._load_model(self.model_name)
+
+    def _load_all_models_once(self) -> None:
+        for name in self.MODEL_DICT:
+            self._load_model(name)
+
+    def _load_model(self, name: str) -> nn.Module:
+        dic = self.MODEL_DICT[name]
+        ckpt_path = huggingface_hub.hf_hub_download('hysts/ViTPose',
+                                                    dic['model'],
+                                                    use_auth_token=HF_TOKEN)
+        model = init_pose_model(dic['config'], ckpt_path, device=self.device)
+        return model
+
+    def set_model(self, name: str) -> None:
+        if name == self.model_name:
+            return
         self.model_name = name
+        self.model = self._load_model(name)
 
     def predict_pose_and_visualize(
         self,
@@ -163,9 +179,8 @@ class PoseModel:
             det_results: list[np.ndarray],
             box_score_threshold: float = 0.5) -> list[dict[str, np.ndarray]]:
         image = image[:, :, ::-1]  # RGB -> BGR
-        model = self.models[self.model_name]
         person_results = process_mmdet_results(det_results, 1)
-        out, _ = inference_top_down_pose_model(model,
+        out, _ = inference_top_down_pose_model(self.model,
                                                image,
                                                person_results=person_results,
                                                bbox_thr=box_score_threshold,
@@ -179,11 +194,25 @@ class PoseModel:
             vis_dot_radius: int = 4,
             vis_line_thickness: int = 1) -> np.ndarray:
         image = image[:, :, ::-1]  # RGB -> BGR
-        model = self.models[self.model_name]
-        vis = vis_pose_result(model,
+        vis = vis_pose_result(self.model,
                               image,
                               pose_results,
                               kpt_score_thr=kpt_score_threshold,
                               radius=vis_dot_radius,
                               thickness=vis_line_thickness)
         return vis[:, :, ::-1]  # BGR -> RGB
+
+
+class AppPoseModel(PoseModel):
+    def run(
+        self, model_name: str, image: np.ndarray,
+        det_results: list[np.ndarray], box_score_threshold: float,
+        kpt_score_threshold: float, vis_dot_radius: int,
+        vis_line_thickness: int
+    ) -> tuple[list[dict[str, np.ndarray]], np.ndarray]:
+        self.set_model(model_name)
+        return self.predict_pose_and_visualize(image, det_results,
+                                               box_score_threshold,
+                                               kpt_score_threshold,
+                                               vis_dot_radius,
+                                               vis_line_thickness)
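
In model.py, the per-instance model registry becomes a class attribute (MODEL_DICT), each instance keeps a single cached self.model, and set_model() reloads weights only when the requested name actually differs from the current one; the AppDetModel/AppPoseModel subclasses add a run() wrapper that calls set_model() before inference. The following is a dependency-free sketch of that caching pattern, not the commit's code: DetModelSketch and its configs are made up, and the only assumption is that torch is installed.

from __future__ import annotations

import torch
import torch.nn as nn


class DetModelSketch:
    # Class-level registry, analogous to MODEL_DICT in model.py.
    MODEL_DICT = {
        'small': {'in_features': 4, 'out_features': 2},
        'large': {'in_features': 8, 'out_features': 4},
    }

    def __init__(self, device: str | torch.device):
        self.device = torch.device(device)
        self.model_name = 'small'
        self.model = self._load_model(self.model_name)

    def _load_model(self, name: str) -> nn.Module:
        cfg = self.MODEL_DICT[name]
        # Stand-in for init_detector()/init_pose_model(): build and move to device.
        return nn.Linear(cfg['in_features'], cfg['out_features']).to(self.device)

    def set_model(self, name: str) -> None:
        # Reload only when the selection actually changes.
        if name == self.model_name:
            return
        self.model_name = name
        self.model = self._load_model(name)


if __name__ == '__main__':
    m = DetModelSketch(device='cpu')
    m.set_model('small')   # no-op: same name, cached model kept
    m.set_model('large')   # rebuilds self.model
    print(m.model)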
style.css CHANGED
@@ -1,7 +1,11 @@
 h1 {
   text-align: center;
 }
-div#result {
+div#det-result {
+  max-width: 600px;
+  max-height: 600px;
+}
+div#pose-result {
   max-width: 600px;
   max-height: 600px;
 }