hysts HF staff committed on
Commit
5a9bbeb
•
1 Parent(s): 1da2a40
Files changed (3) hide show
  1. README.md +1 -1
  2. app.py +126 -162
  3. model.py +12 -9
README.md CHANGED
@@ -4,7 +4,7 @@ emoji: 📊
4
  colorFrom: yellow
5
  colorTo: indigo
6
  sdk: gradio
7
- sdk_version: 3.1.1
8
  app_file: app.py
9
  pinned: false
10
  ---
 
4
  colorFrom: yellow
5
  colorTo: indigo
6
  sdk: gradio
7
+ sdk_version: 3.21.0
8
  app_file: app.py
9
  pinned: false
10
  ---
app.py CHANGED
@@ -10,22 +10,7 @@ import gradio as gr
10
 
11
  from model import AppDetModel, AppPoseModel
12
 
13
- DESCRIPTION = '''# ViTPose
14
-
15
- This is an unofficial demo for [https://github.com/ViTAE-Transformer/ViTPose](https://github.com/ViTAE-Transformer/ViTPose).'''
16
- FOOTER = '<img id="visitor-badge" alt="visitor badge" src="https://visitor-badge.glitch.me/badge?page_id=hysts.vitpose" />'
17
-
18
-
19
- def parse_args() -> argparse.Namespace:
20
- parser = argparse.ArgumentParser()
21
- parser.add_argument('--device', type=str, default='cpu')
22
- parser.add_argument('--theme', type=str)
23
- parser.add_argument('--share', action='store_true')
24
- parser.add_argument('--port', type=int)
25
- parser.add_argument('--disable-queue',
26
- dest='enable_queue',
27
- action='store_false')
28
- return parser.parse_args()
29
 
30
 
31
  def set_example_image(example: list) -> dict:
@@ -39,161 +24,140 @@ def extract_tar() -> None:
39
  f.extractall('mmdet_configs')
40
 
41
 
42
- def main():
43
- args = parse_args()
44
-
45
- extract_tar()
46
-
47
- det_model = AppDetModel(device=args.device)
48
- pose_model = AppPoseModel(device=args.device)
49
-
50
- with gr.Blocks(theme=args.theme, css='style.css') as demo:
51
- gr.Markdown(DESCRIPTION)
52
-
53
- with gr.Box():
54
- gr.Markdown('## Step 1')
55
- with gr.Row():
56
- with gr.Column():
57
- with gr.Row():
58
- input_image = gr.Image(label='Input Image',
59
- type='numpy')
60
- with gr.Row():
61
- detector_name = gr.Dropdown(list(
62
- det_model.MODEL_DICT.keys()),
63
- value=det_model.model_name,
64
- label='Detector')
65
- with gr.Row():
66
- detect_button = gr.Button(value='Detect')
67
- det_preds = gr.Variable()
68
- with gr.Column():
69
- with gr.Row():
70
- detection_visualization = gr.Image(
71
- label='Detection Result',
72
- type='numpy',
73
- elem_id='det-result')
74
- with gr.Row():
75
- vis_det_score_threshold = gr.Slider(
76
- 0,
77
- 1,
78
- step=0.05,
79
- value=0.5,
80
- label='Visualization Score Threshold')
81
- with gr.Row():
82
- redraw_det_button = gr.Button(value='Redraw')
83
-
84
- with gr.Row():
85
- paths = sorted(pathlib.Path('images').rglob('*.jpg'))
86
- example_images = gr.Dataset(components=[input_image],
87
- samples=[[path.as_posix()]
88
- for path in paths])
89
-
90
- with gr.Box():
91
- gr.Markdown('## Step 2')
92
- with gr.Row():
93
- with gr.Column():
94
- with gr.Row():
95
- pose_model_name = gr.Dropdown(
96
- list(pose_model.MODEL_DICT.keys()),
97
- value=pose_model.model_name,
98
- label='Pose Model')
99
- det_score_threshold = gr.Slider(
100
- 0,
101
- 1,
102
  step=0.05,
103
- value=0.5,
104
- label='Box Score Threshold')
105
- with gr.Row():
106
- predict_button = gr.Button(value='Predict')
107
- pose_preds = gr.Variable()
108
- with gr.Column():
109
- with gr.Row():
110
- pose_visualization = gr.Image(label='Result',
111
- type='numpy',
112
- elem_id='pose-result')
113
- with gr.Row():
114
- vis_kpt_score_threshold = gr.Slider(
115
- 0,
116
- 1,
117
- step=0.05,
118
- value=0.3,
119
- label='Visualization Score Threshold')
120
- with gr.Row():
121
- vis_dot_radius = gr.Slider(1,
122
- 10,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
123
  step=1,
124
- value=4,
125
- label='Dot Radius')
126
- with gr.Row():
127
- vis_line_thickness = gr.Slider(1,
128
- 10,
129
- step=1,
130
- value=2,
131
- label='Line Thickness')
132
- with gr.Row():
133
- redraw_pose_button = gr.Button(value='Redraw')
134
-
135
- gr.Markdown(FOOTER)
136
-
137
- detector_name.change(fn=det_model.set_model,
138
- inputs=detector_name,
139
- outputs=None)
140
- detect_button.click(fn=det_model.run,
 
141
  inputs=[
142
- detector_name,
143
  input_image,
 
144
  vis_det_score_threshold,
145
  ],
146
- outputs=[
147
- det_preds,
148
- detection_visualization,
149
- ])
150
- redraw_det_button.click(fn=det_model.visualize_detection_results,
151
- inputs=[
152
- input_image,
153
- det_preds,
154
- vis_det_score_threshold,
155
- ],
156
- outputs=detection_visualization)
157
-
158
- pose_model_name.change(fn=pose_model.set_model,
159
- inputs=pose_model_name,
160
- outputs=None)
161
- predict_button.click(fn=pose_model.run,
 
 
 
 
162
  inputs=[
163
- pose_model_name,
164
  input_image,
165
- det_preds,
166
- det_score_threshold,
167
  vis_kpt_score_threshold,
168
  vis_dot_radius,
169
  vis_line_thickness,
170
  ],
171
- outputs=[
172
- pose_preds,
173
- pose_visualization,
174
- ])
175
- redraw_pose_button.click(fn=pose_model.visualize_pose_results,
176
- inputs=[
177
- input_image,
178
- pose_preds,
179
- vis_kpt_score_threshold,
180
- vis_dot_radius,
181
- vis_line_thickness,
182
- ],
183
- outputs=pose_visualization)
184
-
185
- example_images.click(
186
- fn=set_example_image,
187
- inputs=example_images,
188
- outputs=input_image,
189
- )
190
-
191
- demo.launch(
192
- enable_queue=args.enable_queue,
193
- server_port=args.port,
194
- share=args.share,
195
- )
196
-
197
-
198
- if __name__ == '__main__':
199
- main()
 
10
 
11
  from model import AppDetModel, AppPoseModel
12
 
13
+ DESCRIPTION = '# [ViTPose](https://github.com/ViTAE-Transformer/ViTPose)'
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
14
 
15
 
16
  def set_example_image(example: list) -> dict:
 
24
  f.extractall('mmdet_configs')
25
 
26
 
27
+ extract_tar()
28
+
29
+ det_model = AppDetModel()
30
+ pose_model = AppPoseModel()
31
+
32
+ with gr.Blocks(css='style.css') as demo:
33
+ gr.Markdown(DESCRIPTION)
34
+
35
+ with gr.Box():
36
+ gr.Markdown('## Step 1')
37
+ with gr.Row():
38
+ with gr.Column():
39
+ with gr.Row():
40
+ input_image = gr.Image(label='Input Image', type='numpy')
41
+ with gr.Row():
42
+ detector_name = gr.Dropdown(
43
+ label='Detector',
44
+ choices=list(det_model.MODEL_DICT.keys()),
45
+ value=det_model.model_name)
46
+ with gr.Row():
47
+ detect_button = gr.Button('Detect')
48
+ det_preds = gr.Variable()
49
+ with gr.Column():
50
+ with gr.Row():
51
+ detection_visualization = gr.Image(
52
+ label='Detection Result',
53
+ type='numpy',
54
+ elem_id='det-result')
55
+ with gr.Row():
56
+ vis_det_score_threshold = gr.Slider(
57
+ label='Visualization Score Threshold',
58
+ minimum=0,
59
+ maximum=1,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
60
  step=0.05,
61
+ value=0.5)
62
+ with gr.Row():
63
+ redraw_det_button = gr.Button(value='Redraw')
64
+
65
+ with gr.Row():
66
+ paths = sorted(pathlib.Path('images').rglob('*.jpg'))
67
+ example_images = gr.Examples(examples=[[path.as_posix()]
68
+ for path in paths],
69
+ inputs=input_image)
70
+
71
+ with gr.Box():
72
+ gr.Markdown('## Step 2')
73
+ with gr.Row():
74
+ with gr.Column():
75
+ with gr.Row():
76
+ pose_model_name = gr.Dropdown(
77
+ label='Pose Model',
78
+ choices=list(pose_model.MODEL_DICT.keys()),
79
+ value=pose_model.model_name)
80
+ det_score_threshold = gr.Slider(label='Box Score Threshold',
81
+ minimum=0,
82
+ maximum=1,
83
+ step=0.05,
84
+ value=0.5)
85
+ with gr.Row():
86
+ predict_button = gr.Button('Predict')
87
+ pose_preds = gr.Variable()
88
+ with gr.Column():
89
+ with gr.Row():
90
+ pose_visualization = gr.Image(label='Result',
91
+ type='numpy',
92
+ elem_id='pose-result')
93
+ with gr.Row():
94
+ vis_kpt_score_threshold = gr.Slider(
95
+ label='Visualization Score Threshold',
96
+ minimum=0,
97
+ maximum=1,
98
+ step=0.05,
99
+ value=0.3)
100
+ with gr.Row():
101
+ vis_dot_radius = gr.Slider(label='Dot Radius',
102
+ minimum=1,
103
+ maximum=10,
104
+ step=1,
105
+ value=4)
106
+ with gr.Row():
107
+ vis_line_thickness = gr.Slider(label='Line Thickness',
108
+ minimum=1,
109
+ maximum=10,
110
  step=1,
111
+ value=2)
112
+ with gr.Row():
113
+ redraw_pose_button = gr.Button('Redraw')
114
+
115
+ detector_name.change(fn=det_model.set_model,
116
+ inputs=detector_name,
117
+ outputs=None)
118
+ detect_button.click(fn=det_model.run,
119
+ inputs=[
120
+ detector_name,
121
+ input_image,
122
+ vis_det_score_threshold,
123
+ ],
124
+ outputs=[
125
+ det_preds,
126
+ detection_visualization,
127
+ ])
128
+ redraw_det_button.click(fn=det_model.visualize_detection_results,
129
  inputs=[
 
130
  input_image,
131
+ det_preds,
132
  vis_det_score_threshold,
133
  ],
134
+ outputs=detection_visualization)
135
+
136
+ pose_model_name.change(fn=pose_model.set_model,
137
+ inputs=pose_model_name,
138
+ outputs=None)
139
+ predict_button.click(fn=pose_model.run,
140
+ inputs=[
141
+ pose_model_name,
142
+ input_image,
143
+ det_preds,
144
+ det_score_threshold,
145
+ vis_kpt_score_threshold,
146
+ vis_dot_radius,
147
+ vis_line_thickness,
148
+ ],
149
+ outputs=[
150
+ pose_preds,
151
+ pose_visualization,
152
+ ])
153
+ redraw_pose_button.click(fn=pose_model.visualize_pose_results,
154
  inputs=[
 
155
  input_image,
156
+ pose_preds,
 
157
  vis_kpt_score_threshold,
158
  vis_dot_radius,
159
  vis_line_thickness,
160
  ],
161
+ outputs=pose_visualization)
162
+
163
+ demo.queue(api_open=False).launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
model.py CHANGED
@@ -2,6 +2,7 @@ from __future__ import annotations
2
 
3
  import os
4
  import pathlib
 
5
  import subprocess
6
  import sys
7
 
@@ -11,9 +12,9 @@ if os.getenv('SYSTEM') == 'spaces':
11
  mim.uninstall('mmcv-full', confirm_yes=True)
12
  mim.install('mmcv-full==1.5.0', is_yes=True)
13
 
14
- subprocess.run('pip uninstall -y opencv-python'.split())
15
- subprocess.run('pip uninstall -y opencv-python-headless'.split())
16
- subprocess.run('pip install opencv-python-headless==4.5.5.64'.split())
17
 
18
  import huggingface_hub
19
  import numpy as np
@@ -21,14 +22,14 @@ import torch
21
  import torch.nn as nn
22
 
23
  app_dir = pathlib.Path(__file__).parent
24
- submodule_dir = app_dir / 'ViTPose/'
25
  sys.path.insert(0, submodule_dir.as_posix())
26
 
27
  from mmdet.apis import inference_detector, init_detector
28
  from mmpose.apis import (inference_top_down_pose_model, init_pose_model,
29
  process_mmdet_results, vis_pose_result)
30
 
31
- HF_TOKEN = os.environ['HF_TOKEN']
32
 
33
 
34
  class DetModel:
@@ -59,8 +60,9 @@ class DetModel:
59
  },
60
  }
61
 
62
- def __init__(self, device: str | torch.device):
63
- self.device = torch.device(device)
 
64
  self._load_all_models_once()
65
  self.model_name = 'YOLOX-l'
66
  self.model = self._load_model(self.model_name)
@@ -139,8 +141,9 @@ class PoseModel:
139
  },
140
  }
141
 
142
- def __init__(self, device: str | torch.device):
143
- self.device = torch.device(device)
 
144
  self.model_name = 'ViTPose-B (multi-task train, COCO)'
145
  self.model = self._load_model(self.model_name)
146
 
 
2
 
3
  import os
4
  import pathlib
5
+ import shlex
6
  import subprocess
7
  import sys
8
 
 
12
  mim.uninstall('mmcv-full', confirm_yes=True)
13
  mim.install('mmcv-full==1.5.0', is_yes=True)
14
 
15
+ subprocess.run(shlex.split('pip uninstall -y opencv-python'))
16
+ subprocess.run(shlex.split('pip uninstall -y opencv-python-headless'))
17
+ subprocess.run(shlex.split('pip install opencv-python-headless==4.5.5.64'))
18
 
19
  import huggingface_hub
20
  import numpy as np
 
22
  import torch.nn as nn
23
 
24
  app_dir = pathlib.Path(__file__).parent
25
+ submodule_dir = app_dir / 'ViTPose'
26
  sys.path.insert(0, submodule_dir.as_posix())
27
 
28
  from mmdet.apis import inference_detector, init_detector
29
  from mmpose.apis import (inference_top_down_pose_model, init_pose_model,
30
  process_mmdet_results, vis_pose_result)
31
 
32
+ HF_TOKEN = os.getenv('HF_TOKEN')
33
 
34
 
35
  class DetModel:
 
60
  },
61
  }
62
 
63
+ def __init__(self):
64
+ self.device = torch.device(
65
+ 'cuda:0' if torch.cuda.is_available() else 'cpu')
66
  self._load_all_models_once()
67
  self.model_name = 'YOLOX-l'
68
  self.model = self._load_model(self.model_name)
 
141
  },
142
  }
143
 
144
+ def __init__(self):
145
+ self.device = torch.device(
146
+ 'cuda:0' if torch.cuda.is_available() else 'cpu')
147
  self.model_name = 'ViTPose-B (multi-task train, COCO)'
148
  self.model = self._load_model(self.model_name)
149