hysts (HF staff) committed
Commit
09e24f9
1 Parent(s): d9aa325
Files changed (5)
  1. .pre-commit-config.yaml +2 -12
  2. README.md +1 -1
  3. app.py +94 -128
  4. model.py +19 -23
  5. requirements.txt +1 -1
.pre-commit-config.yaml CHANGED
@@ -21,11 +21,11 @@ repos:
       - id: docformatter
         args: ['--in-place']
   - repo: https://github.com/pycqa/isort
-    rev: 5.10.1
+    rev: 5.12.0
     hooks:
       - id: isort
   - repo: https://github.com/pre-commit/mirrors-mypy
-    rev: v0.812
+    rev: v0.991
     hooks:
       - id: mypy
         args: ['--ignore-missing-imports']
@@ -34,13 +34,3 @@ repos:
     hooks:
       - id: yapf
         args: ['--parallel', '--in-place']
-  - repo: https://github.com/kynan/nbstripout
-    rev: 0.5.0
-    hooks:
-      - id: nbstripout
-        args: ['--extra-keys', 'metadata.interpreter metadata.kernelspec cell.metadata.pycharm']
-  - repo: https://github.com/nbQA-dev/nbQA
-    rev: 1.3.1
-    hooks:
-      - id: nbqa-isort
-      - id: nbqa-yapf
README.md CHANGED
@@ -4,7 +4,7 @@ emoji: 🦀
 colorFrom: gray
 colorTo: purple
 sdk: gradio
-sdk_version: 3.0.15
+sdk_version: 3.19.1
 app_file: app.py
 pinned: false
 ---
app.py CHANGED
@@ -2,7 +2,6 @@
 
 from __future__ import annotations
 
-import argparse
 import pathlib
 import tarfile
 
@@ -15,21 +14,7 @@ DESCRIPTION = '''# ViTPose
 This is an unofficial demo for [https://github.com/ViTAE-Transformer/ViTPose](https://github.com/ViTAE-Transformer/ViTPose).
 
 Related app: [https://huggingface.co/spaces/Gradio-Blocks/ViTPose](https://huggingface.co/spaces/Gradio-Blocks/ViTPose)
-
 '''
-FOOTER = '<img id="visitor-badge" alt="visitor badge" src="https://visitor-badge.glitch.me/badge?page_id=hysts.vitpose_video" />'
-
-
-def parse_args() -> argparse.Namespace:
-    parser = argparse.ArgumentParser()
-    parser.add_argument('--device', type=str, default='cpu')
-    parser.add_argument('--theme', type=str)
-    parser.add_argument('--share', action='store_true')
-    parser.add_argument('--port', type=int)
-    parser.add_argument('--disable-queue',
-                        dest='enable_queue',
-                        action='store_false')
-    return parser.parse_args()
 
 
 def set_example_video(example: list) -> dict:
@@ -43,116 +28,97 @@ def extract_tar() -> None:
         f.extractall('mmdet_configs')
 
 
-def main():
-    args = parse_args()
-
-    extract_tar()
-
-    model = AppModel(device=args.device)
-
-    with gr.Blocks(theme=args.theme, css='style.css') as demo:
-        gr.Markdown(DESCRIPTION)
-
-        with gr.Row():
-            with gr.Column():
-                input_video = gr.Video(label='Input Video',
-                                       format='mp4',
-                                       elem_id='input_video')
-                with gr.Group():
-                    detector_name = gr.Dropdown(
-                        list(model.det_model.MODEL_DICT.keys()),
-                        value=model.det_model.model_name,
-                        label='Detector')
-                    pose_model_name = gr.Dropdown(
-                        list(model.pose_model.MODEL_DICT.keys()),
-                        value=model.pose_model.model_name,
-                        label='Pose Model')
-                    det_score_threshold = gr.Slider(
-                        0,
-                        1,
-                        step=0.05,
-                        value=0.5,
-                        label='Box Score Threshold')
-                    max_num_frames = gr.Slider(
-                        1,
-                        300,
-                        step=1,
-                        value=60,
-                        label='Maximum Number of Frames')
-                    predict_button = gr.Button(value='Predict')
-                    pose_preds = gr.Variable()
-
-                paths = sorted(pathlib.Path('videos').rglob('*.mp4'))
-                example_videos = gr.Dataset(components=[input_video],
-                                            samples=[[path.as_posix()]
-                                                     for path in paths])
-
-            with gr.Column():
-                with gr.Group():
-                    result = gr.Video(label='Result',
-                                      format='mp4',
-                                      elem_id='result')
-                    vis_kpt_score_threshold = gr.Slider(
-                        0,
-                        1,
-                        step=0.05,
-                        value=0.3,
-                        label='Visualization Score Threshold')
-                    vis_dot_radius = gr.Slider(1,
-                                               10,
-                                               step=1,
-                                               value=4,
-                                               label='Dot Radius')
-                    vis_line_thickness = gr.Slider(1,
-                                                   10,
-                                                   step=1,
-                                                   value=2,
-                                                   label='Line Thickness')
-                    redraw_button = gr.Button(value='Redraw')
-
-        gr.Markdown(FOOTER)
-
-        detector_name.change(fn=model.det_model.set_model,
-                             inputs=detector_name,
-                             outputs=None)
-        pose_model_name.change(fn=model.pose_model.set_model,
-                               inputs=pose_model_name,
-                               outputs=None)
-        predict_button.click(fn=model.run,
-                             inputs=[
-                                 input_video,
-                                 detector_name,
-                                 pose_model_name,
-                                 det_score_threshold,
-                                 max_num_frames,
-                                 vis_kpt_score_threshold,
-                                 vis_dot_radius,
-                                 vis_line_thickness,
-                             ],
-                             outputs=[
-                                 result,
-                                 pose_preds,
-                             ])
-        redraw_button.click(fn=model.visualize_pose_results,
-                            inputs=[
-                                input_video,
-                                pose_preds,
-                                vis_kpt_score_threshold,
-                                vis_dot_radius,
-                                vis_line_thickness,
-                            ],
-                            outputs=result)
-
-        example_videos.click(fn=set_example_video,
-                             inputs=example_videos,
-                             outputs=input_video)
-
-    demo.launch(
-        enable_queue=args.enable_queue,
-        server_port=args.port,
-        share=args.share,
-    )
-
-
-if __name__ == '__main__':
-    main()
+extract_tar()
+
+model = AppModel()
+
+with gr.Blocks(css='style.css') as demo:
+    gr.Markdown(DESCRIPTION)
+
+    with gr.Row():
+        with gr.Column():
+            input_video = gr.Video(label='Input Video',
+                                   format='mp4',
+                                   elem_id='input_video')
+            detector_name = gr.Dropdown(list(
+                model.det_model.MODEL_DICT.keys()),
+                                        value=model.det_model.model_name,
+                                        label='Detector')
+            pose_model_name = gr.Dropdown(list(
+                model.pose_model.MODEL_DICT.keys()),
+                                          value=model.pose_model.model_name,
+                                          label='Pose Model')
+            det_score_threshold = gr.Slider(0,
+                                            1,
+                                            step=0.05,
+                                            value=0.5,
+                                            label='Box Score Threshold')
+            max_num_frames = gr.Slider(1,
+                                       300,
+                                       step=1,
+                                       value=60,
+                                       label='Maximum Number of Frames')
+            predict_button = gr.Button(value='Predict')
+            pose_preds = gr.Variable()
+
+            paths = sorted(pathlib.Path('videos').rglob('*.mp4'))
+            example_videos = gr.Dataset(components=[input_video],
+                                        samples=[[path.as_posix()]
+                                                 for path in paths])
+
+        with gr.Column():
+            result = gr.Video(label='Result', format='mp4', elem_id='result')
+            vis_kpt_score_threshold = gr.Slider(
+                0,
+                1,
+                step=0.05,
+                value=0.3,
+                label='Visualization Score Threshold')
+            vis_dot_radius = gr.Slider(1,
+                                       10,
+                                       step=1,
+                                       value=4,
+                                       label='Dot Radius')
+            vis_line_thickness = gr.Slider(1,
+                                           10,
+                                           step=1,
+                                           value=2,
+                                           label='Line Thickness')
+            redraw_button = gr.Button(value='Redraw')
+
+    detector_name.change(fn=model.det_model.set_model,
+                         inputs=detector_name,
+                         outputs=None)
+    pose_model_name.change(fn=model.pose_model.set_model,
+                           inputs=pose_model_name,
+                           outputs=None)
+    predict_button.click(fn=model.run,
+                         inputs=[
+                             input_video,
+                             detector_name,
+                             pose_model_name,
+                             det_score_threshold,
+                             max_num_frames,
+                             vis_kpt_score_threshold,
+                             vis_dot_radius,
+                             vis_line_thickness,
+                         ],
+                         outputs=[
+                             result,
+                             pose_preds,
+                         ])
+    redraw_button.click(fn=model.visualize_pose_results,
+                        inputs=[
+                            input_video,
+                            pose_preds,
+                            vis_kpt_score_threshold,
+                            vis_dot_radius,
+                            vis_line_thickness,
+                        ],
+                        outputs=result)
+
+    example_videos.click(fn=set_example_video,
+                         inputs=example_videos,
+                         outputs=input_video)
+
+demo.queue().launch(show_api=False)
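The app.py change drops the argparse-driven main() entry point (the --device, --theme, --share, --port, and --disable-queue flags and the visitor-badge footer are removed): the Blocks UI is now built at module level and launched with demo.queue().launch(show_api=False). Below is a minimal, self-contained sketch of that launch pattern; the components and callback are placeholders, not the actual ViTPose interface.

```python
# Sketch of the module-level Blocks + queue pattern adopted in app.py.
# The textbox/button UI is a placeholder, not the real ViTPose demo.
import gradio as gr


def flip_text(text: str) -> str:
    # Placeholder callback standing in for AppModel.run.
    return text[::-1]


with gr.Blocks() as demo:
    gr.Markdown('# Demo')
    with gr.Row():
        input_text = gr.Textbox(label='Input')
        output_text = gr.Textbox(label='Output')
    run_button = gr.Button(value='Run')
    run_button.click(fn=flip_text, inputs=input_text, outputs=output_text)

# queue() replaces the removed --disable-queue flag; show_api=False hides the API page.
demo.queue().launch(show_api=False)
```

Hardware is no longer selectable from the command line; the device is chosen inside model.py instead (see below).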
model.py CHANGED
@@ -1,6 +1,7 @@
 from __future__ import annotations
 
 import os
+import shlex
 import subprocess
 import sys
 import tempfile
@@ -11,9 +12,10 @@ if os.getenv('SYSTEM') == 'spaces':
     mim.uninstall('mmcv-full', confirm_yes=True)
     mim.install('mmcv-full==1.5.0', is_yes=True)
 
-    subprocess.call('pip uninstall -y opencv-python'.split())
-    subprocess.call('pip uninstall -y opencv-python-headless'.split())
-    subprocess.call('pip install opencv-python-headless==4.5.5.64'.split())
+    subprocess.call(shlex.split('pip uninstall -y opencv-python'))
+    subprocess.call(shlex.split('pip uninstall -y opencv-python-headless'))
+    subprocess.call(
+        shlex.split('pip install opencv-python-headless==4.5.5.64'))
 
 import cv2
 import huggingface_hub
@@ -27,7 +29,7 @@ from mmdet.apis import inference_detector, init_detector
 from mmpose.apis import (inference_top_down_pose_model, init_pose_model,
                          process_mmdet_results, vis_pose_result)
 
-HF_TOKEN = os.environ['HF_TOKEN']
+HF_TOKEN = os.getenv('HF_TOKEN')
 
 
 class DetModel:
@@ -58,8 +60,9 @@ class DetModel:
         },
     }
 
-    def __init__(self, device: str | torch.device):
-        self.device = torch.device(device)
+    def __init__(self):
+        self.device = torch.device(
+            'cuda:0' if torch.cuda.is_available() else 'cpu')
         self._load_all_models_once()
         self.model_name = 'YOLOX-l'
         self.model = self._load_model(self.model_name)
@@ -131,8 +134,9 @@ class PoseModel:
         },
     }
 
-    def __init__(self, device: str | torch.device):
-        self.device = torch.device(device)
+    def __init__(self):
+        self.device = torch.device(
+            'cuda:0' if torch.cuda.is_available() else 'cpu')
         self.model_name = 'ViTPose-B (multi-task train, COCO)'
         self.model = self._load_model(self.model_name)
 
@@ -199,9 +203,9 @@ class PoseModel:
 
 
 class AppModel:
-    def __init__(self, device: str | torch.device):
-        self.det_model = DetModel(device)
-        self.pose_model = PoseModel(device)
+    def __init__(self):
+        self.det_model = DetModel()
+        self.pose_model = PoseModel()
 
     def run(
         self, video_path: str, det_model_name: str, pose_model_name: str,
@@ -222,8 +226,8 @@ class AppModel:
         preds_all = []
 
         fourcc = cv2.VideoWriter_fourcc(*'mp4v')
-        temp_file = tempfile.NamedTemporaryFile(suffix='.mp4')
-        writer = cv2.VideoWriter(temp_file.name, fourcc, fps, (width, height))
+        out_file = tempfile.NamedTemporaryFile(suffix='.mp4', delete=False)
+        writer = cv2.VideoWriter(out_file.name, fourcc, fps, (width, height))
         for _ in range(max_num_frames):
             ok, frame = cap.read()
             if not ok:
@@ -238,10 +242,6 @@ class AppModel:
         cap.release()
         writer.release()
 
-        out_file = tempfile.NamedTemporaryFile(suffix='.mp4', delete=False)
-        subprocess.run(
-            f'ffmpeg -y -loglevel quiet -stats -i {temp_file.name} -c:v libx264 {out_file.name}'
-            .split())
         return out_file.name, preds_all
 
     def visualize_pose_results(self, video_path: str,
@@ -257,8 +257,8 @@ class AppModel:
         fps = cap.get(cv2.CAP_PROP_FPS)
 
         fourcc = cv2.VideoWriter_fourcc(*'mp4v')
-        temp_file = tempfile.NamedTemporaryFile(suffix='.mp4')
-        writer = cv2.VideoWriter(temp_file.name, fourcc, fps, (width, height))
+        out_file = tempfile.NamedTemporaryFile(suffix='.mp4', delete=False)
+        writer = cv2.VideoWriter(out_file.name, fourcc, fps, (width, height))
         for pose_preds in pose_preds_all:
             ok, frame = cap.read()
             if not ok:
@@ -271,8 +271,4 @@ class AppModel:
         cap.release()
         writer.release()
 
-        out_file = tempfile.NamedTemporaryFile(suffix='.mp4', delete=False)
-        subprocess.run(
-            f'ffmpeg -y -loglevel quiet -stats -i {temp_file.name} -c:v libx264 {out_file.name}'
-            .split())
         return out_file.name
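In model.py, the constructors no longer take a device argument (the device is inferred from torch.cuda.is_available()), HF_TOKEN is read with os.getenv so it may be None outside the Space, the pip commands are split with shlex, and frames are written straight into a delete=False temporary file, dropping the extra ffmpeg re-encode to H.264 (the output now stays mp4v-encoded). A small self-contained sketch of the device-selection and video-writing pattern, with assumed placeholder values (30 fps, 640x480 black frames):

```python
# Sketch of the patterns model.py switches to: CUDA auto-detection and writing
# directly into a persistent temporary file with cv2.VideoWriter.
import tempfile

import cv2
import numpy as np
import torch

# Device is inferred at construction time instead of being passed in.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

fourcc = cv2.VideoWriter_fourcc(*'mp4v')
# delete=False keeps the file around so it can be served after the writer closes.
out_file = tempfile.NamedTemporaryFile(suffix='.mp4', delete=False)
writer = cv2.VideoWriter(out_file.name, fourcc, 30.0, (640, 480))
for _ in range(30):
    frame = np.zeros((480, 640, 3), dtype=np.uint8)  # placeholder black frame
    writer.write(frame)
writer.release()
print(device, out_file.name)
```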
requirements.txt CHANGED
@@ -1,7 +1,7 @@
 mmcv-full==1.5.0
 mmdet==2.24.1
 mmpose==0.25.1
-numpy==1.22.4
+numpy==1.23.5
 opencv-python-headless==4.5.5.64
 openmim==0.1.5
 timm==0.5.4