hysts HF staff commited on
Commit
9bdd97c
1 Parent(s): 41476e2
Files changed (4) hide show
  1. .gitmodules +3 -0
  2. app.py +161 -0
  3. bizarre-pose-estimator +1 -0
  4. requirements.txt +2 -0
.gitmodules ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ [submodule "bizarre-pose-estimator"]
2
+ path = bizarre-pose-estimator
3
+ url = https://github.com/ShuhongChen/bizarre-pose-estimator
app.py ADDED
@@ -0,0 +1,161 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+
3
+ from __future__ import annotations
4
+
5
+ import argparse
6
+ import functools
7
+ import os
8
+ import pathlib
9
+ import subprocess
10
+ import sys
11
+
12
+ # workaround for https://github.com/gradio-app/gradio/issues/483
13
+ command = 'pip install -U gradio==2.7.0'
14
+ subprocess.call(command.split())
15
+
16
+ import gradio as gr
17
+ import huggingface_hub
18
+ import PIL.Image
19
+ import torch
20
+ import torchvision
21
+
22
+ sys.path.insert(0, 'bizarre-pose-estimator')
23
+
24
+ from _util.twodee_v0 import I as ImageWrapper
25
+
26
+ TOKEN = os.environ['TOKEN']
27
+
28
+ MODEL_REPO = 'hysts/bizarre-pose-estimator-models'
29
+ MODEL_PATH = 'tagger.pth'
30
+ LABEL_PATH = 'tags.txt'
31
+
32
+
33
+ def parse_args() -> argparse.Namespace:
34
+ parser = argparse.ArgumentParser()
35
+ parser.add_argument('--device', type=str, default='cpu')
36
+ parser.add_argument('--score-slider-step', type=float, default=0.05)
37
+ parser.add_argument('--score-threshold', type=float, default=0.5)
38
+ parser.add_argument('--theme', type=str, default='dark-grass')
39
+ parser.add_argument('--live', action='store_true')
40
+ parser.add_argument('--share', action='store_true')
41
+ parser.add_argument('--port', type=int)
42
+ parser.add_argument('--disable-queue',
43
+ dest='enable_queue',
44
+ action='store_false')
45
+ parser.add_argument('--allow-flagging', type=str, default='never')
46
+ parser.add_argument('--allow-screenshot', action='store_true')
47
+ return parser.parse_args()
48
+
49
+
50
+ def download_sample_images() -> list[pathlib.Path]:
51
+ image_dir = pathlib.Path('samples')
52
+ image_dir.mkdir(exist_ok=True)
53
+
54
+ dataset_repo = 'hysts/sample-images-TADNE'
55
+ n_images = 36
56
+ paths = []
57
+ for index in range(n_images):
58
+ path = huggingface_hub.hf_hub_download(dataset_repo,
59
+ f'{index:02d}.jpg',
60
+ repo_type='dataset',
61
+ cache_dir=image_dir.as_posix(),
62
+ use_auth_token=TOKEN)
63
+ paths.append(pathlib.Path(path))
64
+ return paths
65
+
66
+
67
+ @torch.inference_mode()
68
+ def predict(image: PIL.Image.Image, score_threshold: float,
69
+ device: torch.device, model: torch.nn.Module,
70
+ labels: list[str]) -> dict[str, float]:
71
+ data = ImageWrapper(image).resize_square(256).alpha_bg(
72
+ c='w').convert('RGB').tensor()
73
+ data = data.to(device).unsqueeze(0)
74
+
75
+ preds = model(data)[0]
76
+ preds = torch.sigmoid(preds)
77
+ preds = preds.cpu().numpy().astype(float)
78
+
79
+ res = dict()
80
+ for prob, label in zip(preds, labels):
81
+ if prob < score_threshold:
82
+ continue
83
+ res[label] = prob
84
+ return res
85
+
86
+
87
+ def load_model(device: torch.device) -> torch.nn.Module:
88
+ model_path = huggingface_hub.hf_hub_download(MODEL_REPO,
89
+ MODEL_PATH,
90
+ use_auth_token=TOKEN)
91
+ state_dict = torch.load(model_path)
92
+
93
+ model = torchvision.models.resnet50(num_classes=1062)
94
+ model.load_state_dict(state_dict)
95
+ model.to(device)
96
+ model.eval()
97
+
98
+ return model
99
+
100
+
101
+ def load_labels() -> list[str]:
102
+ label_path = huggingface_hub.hf_hub_download(MODEL_REPO,
103
+ LABEL_PATH,
104
+ use_auth_token=TOKEN)
105
+ with open(label_path) as f:
106
+ labels = [line.strip() for line in f.readlines()]
107
+ return labels
108
+
109
+
110
+ def main():
111
+ gr.close_all()
112
+
113
+ args = parse_args()
114
+ device = torch.device(args.device)
115
+
116
+ image_paths = download_sample_images()
117
+ examples = [[path.as_posix(), args.score_threshold]
118
+ for path in image_paths]
119
+
120
+ model = load_model(device)
121
+ labels = load_labels()
122
+
123
+ func = functools.partial(predict,
124
+ device=device,
125
+ model=model,
126
+ labels=labels)
127
+ func = functools.update_wrapper(func, predict)
128
+
129
+ repo_url = 'https://github.com/ShuhongChen/bizarre-pose-estimator'
130
+ title = 'ShuhongChen/bizarre-pose-estimator (tagger)'
131
+ description = f'A demo for {repo_url}'
132
+ article = None
133
+
134
+ gr.Interface(
135
+ func,
136
+ [
137
+ gr.inputs.Image(type='pil', label='Input'),
138
+ gr.inputs.Slider(0,
139
+ 1,
140
+ step=args.score_slider_step,
141
+ default=args.score_threshold,
142
+ label='Score Threshold'),
143
+ ],
144
+ gr.outputs.Label(label='Output'),
145
+ theme=args.theme,
146
+ title=title,
147
+ description=description,
148
+ article=article,
149
+ examples=examples,
150
+ allow_screenshot=args.allow_screenshot,
151
+ allow_flagging=args.allow_flagging,
152
+ live=args.live,
153
+ ).launch(
154
+ enable_queue=args.enable_queue,
155
+ server_port=args.port,
156
+ share=args.share,
157
+ )
158
+
159
+
160
+ if __name__ == '__main__':
161
+ main()
bizarre-pose-estimator ADDED
@@ -0,0 +1 @@
 
 
1
+ Subproject commit 7382ec234fa40cd8a6ec4a28b4639209199bc035
requirements.txt ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ torch>=1.10.1
2
+ torchvision>=0.11.2