hysts HF staff committed on
Commit
24679fa
1 Parent(s): c1b78a1
Files changed (7) hide show
  1. .gitignore +1 -0
  2. .gitmodules +9 -0
  3. app.py +153 -0
  4. face_detection +1 -0
  5. face_parsing +1 -0
  6. requirements.txt +4 -0
  7. roi_tanh_warping +1 -0
.gitignore ADDED
@@ -0,0 +1 @@
 
 
1
+ images
.gitmodules ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ [submodule "face_detection"]
2
+ path = face_detection
3
+ url = https://github.com/ibug-group/face_detection
4
+ [submodule "face_parsing"]
5
+ path = face_parsing
6
+ url = https://github.com/hhj1897/face_parsing
7
+ [submodule "roi_tanh_warping"]
8
+ path = roi_tanh_warping
9
+ url = https://github.com/ibug-group/roi_tanh_warping
app.py ADDED
@@ -0,0 +1,153 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+
3
+ from __future__ import annotations
4
+
5
+ import argparse
6
+ import functools
7
+ import os
8
+ import pathlib
9
+ import sys
10
+ import tarfile
11
+
12
+ import gradio as gr
13
+ import huggingface_hub
14
+ import numpy as np
15
+ import torch
16
+
17
+ sys.path.insert(0, 'face_detection')
18
+ sys.path.insert(0, 'face_parsing')
19
+ sys.path.insert(0, 'roi_tanh_warping')
20
+
21
+ from ibug.face_detection import RetinaFacePredictor
22
+ from ibug.face_parsing.parser import WEIGHT, FaceParser
23
+ from ibug.face_parsing.utils import label_colormap
24
+
25
# Upstream project that this demo wraps; also interpolated into DESCRIPTION.
REPO_URL = 'https://github.com/hhj1897/face_parsing'
# Text shown by the Gradio interface (passed as title/description/article).
TITLE = 'hhj1897/face_parsing'
DESCRIPTION = f'This is a demo for {REPO_URL}.'
ARTICLE = None

# Access token forwarded to huggingface_hub when fetching the sample images.
# Looked up eagerly: a missing TOKEN variable raises KeyError at import time.
TOKEN = os.environ['TOKEN']
31
+
32
+
33
def parse_args() -> argparse.Namespace:
    """Parse the command-line options of the demo.

    Returns:
        Namespace with the attributes ``device``, ``theme``, ``live``,
        ``share``, ``port``, ``enable_queue``, ``allow_flagging`` and
        ``allow_screenshot``.
    """
    # (flag, add_argument keyword arguments), registered in this order.
    option_table = (
        ('--device', {'type': str, 'default': 'cpu'}),
        ('--theme', {'type': str}),
        ('--live', {'action': 'store_true'}),
        ('--share', {'action': 'store_true'}),
        ('--port', {'type': int}),
        # Queueing is on by default; this switch turns it off.
        ('--disable-queue', {'dest': 'enable_queue', 'action': 'store_false'}),
        ('--allow-flagging', {'type': str, 'default': 'never'}),
        ('--allow-screenshot', {'action': 'store_true'}),
    )

    cli = argparse.ArgumentParser()
    for flag, options in option_table:
        cli.add_argument(flag, **options)
    return cli.parse_args()
46
+
47
+
48
def load_sample_images() -> list[pathlib.Path]:
    """Return the sample images, downloading them on first use.

    If the local ``images`` directory does not exist, it is created and the
    archives ``000.tar`` and ``001.tar`` from the ``hysts/input-images``
    dataset repository are downloaded and extracted into it.

    If the download or extraction fails, the directory is removed again
    before the error propagates. Otherwise the mere existence of ``images``
    would make every later run skip the download and silently return an
    empty or incomplete list.

    Returns:
        Sorted paths of all ``*.jpg`` files found recursively under
        ``images``.
    """
    import shutil  # only needed for the failure clean-up below

    image_dir = pathlib.Path('images')
    if not image_dir.exists():
        image_dir.mkdir()
        dataset_repo = 'hysts/input-images'
        filenames = ['000.tar', '001.tar']
        try:
            for name in filenames:
                path = huggingface_hub.hf_hub_download(dataset_repo,
                                                       name,
                                                       repo_type='dataset',
                                                       use_auth_token=TOKEN)
                # NOTE(review): extractall() trusts the member paths stored
                # in the archive; only use it with archives from a trusted
                # repository.
                with tarfile.open(path) as f:
                    f.extractall(image_dir.as_posix())
        except BaseException:
            # Also covers KeyboardInterrupt: never leave a half-populated
            # directory behind, since its existence disables the download.
            shutil.rmtree(image_dir, ignore_errors=True)
            raise
    return sorted(image_dir.rglob('*.jpg'))
62
+
63
+
64
def load_detector(device: torch.device) -> RetinaFacePredictor:
    """Create the face detector used by the demo.

    Args:
        device: Device handed to ``RetinaFacePredictor``.

    Returns:
        A ``RetinaFacePredictor`` built from the ``'mobilenet0.25'`` model
        with ``threshold=0.8`` (presumably the minimum detection score;
        confirm against the face_detection package).
    """
    backbone = RetinaFacePredictor.get_model('mobilenet0.25')
    return RetinaFacePredictor(threshold=0.8, device=device, model=backbone)
70
+
71
+
72
def load_model(model_name: str, device: torch.device) -> FaceParser:
    """Create the face parser described by ``model_name``.

    Args:
        model_name: Name of the form ``'<encoder>-<decoder>-<num_classes>'``;
            it must split into exactly three parts on ``'-'``.
        device: Device handed to ``FaceParser``.

    Returns:
        The ``FaceParser`` instance, with the class count additionally
        stored on it as ``num_classes`` for later colormap lookups.

    Raises:
        ValueError: If the name does not have exactly three ``'-'``-separated
            parts or the last part is not an integer.
    """
    encoder_name, decoder_name, class_field = model_name.split('-')
    class_count = int(class_field)

    face_parser = FaceParser(
        device=device,
        encoder=encoder_name,
        decoder=decoder_name,
        num_classes=class_count,
    )
    # Recorded explicitly so callers can read it back from the parser.
    face_parser.num_classes = class_count
    return face_parser
81
+
82
+
83
def predict(image: np.ndarray, model_name: str, max_num_faces: int,
            detector: RetinaFacePredictor,
            models: dict[str, FaceParser]) -> np.ndarray:
    """Overlay face-parsing masks on the input image.

    Args:
        image: Input image in RGB channel order.
        model_name: Key into ``models`` selecting the parser to run.
        max_num_faces: Upper bound on the number of detections parsed.
        detector: Face detector, called on the BGR image.
        models: Mapping from model name to a loaded parser.

    Returns:
        ``uint8`` image in RGB order: a 50/50 blend of the input and the
        colored masks.

    Raises:
        RuntimeError: If the detector returns no faces.
    """
    face_parser = models[model_name]
    palette = label_colormap(face_parser.num_classes)

    # The detector and parser are run on BGR data (rgb=False).
    bgr = image[:, :, ::-1]

    detections = detector(bgr, rgb=False)
    if len(detections) == 0:
        raise RuntimeError('No face was found.')

    # Rank by column 4 (presumably the detection score), highest first,
    # keep the top ``max_num_faces`` and then reverse that selection so
    # the best-ranked face is processed last.
    ranked = sorted(list(detections), key=lambda det: -det[4])
    selected = ranked[:max_num_faces][::-1]
    masks = face_parser.predict_img(bgr, selected, rgb=False)

    overlay = np.zeros_like(bgr)
    for mask in masks:
        colored = palette[mask]
        painted = colored > 0
        # Later masks overwrite earlier ones wherever they have color.
        overlay[painted] = colored[painted]

    # The overlay is flipped to match the BGR image before blending.
    blended = bgr.astype(float) * 0.5 + overlay[:, :, ::-1] * 0.5
    blended = np.clip(np.round(blended), 0, 255).astype(np.uint8)
    # BGR -> RGB for the caller.
    return blended[:, :, ::-1]
106
+
107
+
108
def main():
    """Load the models, build the Gradio interface and launch it."""
    gr.close_all()

    args = parse_args()
    device = torch.device(args.device)

    detector = load_detector(device)

    # One parser per available weight set, keyed by its name.
    model_names = list(WEIGHT.keys())
    models = {}
    for name in model_names:
        models[name] = load_model(name, device=device)

    # Bind the heavy objects so the interface only passes the user inputs;
    # update_wrapper copies predict's metadata onto the partial.
    handler = functools.update_wrapper(
        functools.partial(predict, detector=detector, models=models),
        predict)

    sample_paths = load_sample_images()
    examples = [[sample.as_posix(), model_names[1], 10]
                for sample in sample_paths]

    image_input = gr.inputs.Image(type='numpy', label='Input')
    model_input = gr.inputs.Radio(model_names,
                                  type='value',
                                  default=model_names[1],
                                  label='Model')
    faces_input = gr.inputs.Slider(1,
                                   20,
                                   step=1,
                                   default=10,
                                   label='Max Number of Faces')
    image_output = gr.outputs.Image(type='numpy', label='Output')

    demo = gr.Interface(
        handler,
        [image_input, model_input, faces_input],
        image_output,
        examples=examples,
        title=TITLE,
        description=DESCRIPTION,
        article=ARTICLE,
        theme=args.theme,
        allow_screenshot=args.allow_screenshot,
        allow_flagging=args.allow_flagging,
        live=args.live,
    )
    demo.launch(
        enable_queue=args.enable_queue,
        server_port=args.port,
        share=args.share,
    )


if __name__ == '__main__':
    main()
face_detection ADDED
@@ -0,0 +1 @@
 
 
1
+ Subproject commit bc1e392b11d731fa20b1397c8ff3faed5e7fc76e
face_parsing ADDED
@@ -0,0 +1 @@
 
 
1
+ Subproject commit 8ce84123d0433e6ed389b33e5d3dc2a6a1609d70
requirements.txt ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ numpy==1.22.3
2
+ opencv-python-headless==4.5.5.64
3
+ torch==1.11.0
4
+ torchvision==0.12.0
roi_tanh_warping ADDED
@@ -0,0 +1 @@
 
 
1
+ Subproject commit f9cb77ed9d4ce4e40f026b2425d62efe517691c9