senior-sigan committed
Commit: 424d90e
1 parent: 54c0a4b

init project

Files changed (7):
  1. .gitignore +170 -0
  2. .python-version +1 -0
  3. app.py +103 -0
  4. encoder.py +32 -0
  5. examples/elon-musk.jpg +0 -0
  6. face_detector.py +135 -0
  7. generator.py +27 -0
.gitignore ADDED
@@ -0,0 +1,170 @@
+ #### joe made this: http://goel.io/joe
+ #### macos ####
+ # General
+ *.DS_Store
+ .AppleDouble
+ .LSOverride
+
+ # Icon must end with two \r
+ Icon
+
+
+ # Thumbnails
+ ._*
+
+ # Files that might appear in the root of a volume
+ .DocumentRevisions-V100
+ .fseventsd
+ .Spotlight-V100
+ .TemporaryItems
+ .Trashes
+ .VolumeIcon.icns
+ .com.apple.timemachine.donotpresent
+
+ # Directories potentially created on remote AFP share
+ .AppleDB
+ .AppleDesktop
+ Network Trash Folder
+ Temporary Items
+ .apdisk
+ #### linux ####
+ *~
+
+ # temporary files which can be created if a process still has a handle open of a deleted file
+ .fuse_hidden*
+
+ # KDE directory preferences
+ .directory
+
+ # Linux trash folder which might appear on any partition or disk
+ .Trash-*
+
+ # .nfs files are created when an open file is removed but is still being accessed
+ .nfs*
+ #### windows ####
+ # Windows thumbnail cache files
+ Thumbs.db
+ ehthumbs.db
+ ehthumbs_vista.db
+
+ # Dump file
+ *.stackdump
+
+ # Folder config file
+ Desktop.ini
+
+ # Recycle Bin used on file shares
+ $RECYCLE.BIN/
+
+ # Windows Installer files
+ *.cab
+ *.msi
+ *.msm
+ *.msp
+
+ # Windows shortcuts
+ *.lnk
+ #### python ####
+ # Byte-compiled / optimized / DLL files
+ __pycache__/
+ *.py[cod]
+ *$py.class
+
+ # C extensions
+ *.so
+
+ # Distribution / packaging
+ .Python
+ build/
+ develop-eggs/
+ dist/
+ downloads/
+ eggs/
+ .eggs/
+ lib/
+ lib64/
+ parts/
+ sdist/
+ var/
+ wheels/
+ *.egg-info/
+ .installed.cfg
+ *.egg
+
+ # PyInstaller
+ # Usually these files are written by a python script from a template
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
+ *.manifest
+ *.spec
+
+ # Installer logs
+ pip-log.txt
+ pip-delete-this-directory.txt
+
+ # Unit test / coverage reports
+ htmlcov/
+ .tox/
+ .coverage
+ .coverage.*
+ .cache
+ nosetests.xml
+ coverage.xml
+ *.cover
+ .hypothesis/
+
+ # Translations
+ *.mo
+ *.pot
+
+ # Django stuff:
+ *.log
+ local_settings.py
+
+ # Flask stuff:
+ instance/
+ .webassets-cache
+
+ # Scrapy stuff:
+ .scrapy
+
+ # Sphinx documentation
+ docs/_build/
+
+ # PyBuilder
+ target/
+
+ # Jupyter Notebook
+ .ipynb_checkpoints
+
+ # celery beat schedule file
+ celerybeat-schedule
+
+ # SageMath parsed files
+ *.sage.py
+
+ # Environments
+ .env
+ .venv
+ env/
+ venv/
+ ENV/
+
+ # Spyder project settings
+ .spyderproject
+ .spyproject
+
+ # Rope project settings
+ .ropeproject
+
+ # mkdocs documentation
+ /site
+
+ # mypy
+ .mypy_cache/
+
+ #### jetbrains ####
+
+ .idea/
+ .docker/
+ .pytest_cache/
+ *.db
.python-version ADDED
@@ -0,0 +1 @@
+ miniforge3
app.py ADDED
@@ -0,0 +1,103 @@
+ import argparse
+ import functools
+ import pathlib
+ import gradio as gr
+ import PIL.Image
+ from encoder import Encoder
+
+ from face_detector import FaceAligner
+ from generator import Generator
+ from huggingface_hub import hf_hub_download
+
+
+ def parse_args() -> argparse.Namespace:
+     parser = argparse.ArgumentParser()
+     parser.add_argument('--share', action='store_true')
+     parser.add_argument('--port', type=int)
+     parser.add_argument('--disable-queue',
+                         dest='enable_queue',
+                         action='store_false')
+     return parser.parse_args()
+
+
+ def load_examples():
+     image_dir = pathlib.Path('examples')
+     images = sorted(image_dir.glob('*.jpg'))
+     return [path.as_posix() for path in images]
+
+
+ def predict(
+     image: PIL.Image.Image,
+     face_aligner: FaceAligner,
+     encoder: Encoder,
+     generator: Generator,
+ ):
+     images = face_aligner.align(image)
+
+     gen_imgs = []
+     for img in images:
+         x = encoder.predict(img)
+         gen_img = generator.predict(x)
+         gen_imgs.append(gen_img)
+
+     return gen_imgs
+
+
+ def load_models():
+     encoder_path = hf_hub_download(
+         'senior-sigan/nijigenka',
+         'encoder.onnx',
+     )
+     generator_path = hf_hub_download(
+         'senior-sigan/nijigenka',
+         'face2art.onnx',
+     )
+     shape_predictor_path = hf_hub_download(
+         'senior-sigan/nijigenka',
+         'shape_predictor_68_face_landmarks.bin',
+     )
+
+     face_aligner = FaceAligner(
+         image_size=512,
+         shape_predictor_path=shape_predictor_path,
+     )
+     encoder = Encoder(model_path=encoder_path)
+     generator = Generator(model_path=generator_path)
+
+     return face_aligner, encoder, generator
+
+
+ def main():
+     args = parse_args()
+     gr.close_all()
+
+     face_aligner, encoder, generator = load_models()
+
+     func = functools.partial(
+         predict,
+         face_aligner=face_aligner,
+         encoder=encoder,
+         generator=generator,
+     )
+     func = functools.update_wrapper(func, predict)
+
+     iface = gr.Interface(
+         fn=func,
+         inputs=[
+             gr.inputs.Image(type='pil', label='Input')
+         ],
+         outputs=gr.outputs.Carousel(['image']),
+         examples=load_examples(),
+         title='Nijigenka: Portrait to Art',
+         allow_flagging='never',
+         theme='huggingface',
+     )
+     iface.launch(
+         enable_queue=args.enable_queue,
+         server_port=args.port,
+         share=args.share,
+     )
+
+
+ if __name__ == '__main__':
+     main()
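
The Gradio UI is optional for testing: load_models() fetches the ONNX models and the dlib landmark file from the senior-sigan/nijigenka Hub repo, and predict() can then be called directly. A minimal smoke-test sketch (the output filenames are illustrative, not part of the app):

```python
# Sketch: drive the pipeline from app.py without launching Gradio.
# Assumes network access so hf_hub_download can fetch the model files.
import PIL.Image

from app import load_models, predict

face_aligner, encoder, generator = load_models()
image = PIL.Image.open('examples/elon-musk.jpg')

# predict() returns one stylized image per face detected in the input.
for i, gen_img in enumerate(predict(image, face_aligner, encoder, generator)):
    gen_img.save(f'output_{i}.png')
```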
encoder.py ADDED
@@ -0,0 +1,32 @@
+ import PIL.Image
+ import onnxruntime
+ import numpy as np
+
+
+ def to_tensor(
+     image: PIL.Image.Image,
+     image_size=(256, 256),
+ ) -> np.ndarray:
+     image = image.resize(image_size)
+     x = np.asarray(image) / 255.0
+     x = np.transpose(x, (2, 0, 1))
+     x = (x - 0.5) / 0.5
+     x = np.expand_dims(x, axis=0).astype(np.float32)
+     return x
+
+
+ class Encoder(object):
+     def __init__(
+         self,
+         model_path: str = 'encoder.onnx',
+     ) -> None:
+         self.input_name = 'input_0'
+         self.output_name = 'output_0'
+         self.session = onnxruntime.InferenceSession(model_path, None)
+
+     def predict(self, image: PIL.Image.Image) -> np.ndarray:
+         x = to_tensor(image)
+         output = self.session.run([self.output_name], {
+             self.input_name: x,
+         })[0][0]
+         return output
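
to_tensor defines the model's input contract: resize to 256×256, scale to [0, 1], reorder HWC to CHW, normalize to [-1, 1], and add a batch axis. A quick sanity-check sketch (the synthetic image is arbitrary):

```python
# Sketch: verify to_tensor's output layout: float32, NCHW, values in [-1, 1].
import numpy as np
import PIL.Image

from encoder import to_tensor

img = PIL.Image.new('RGB', (640, 480), color=(255, 0, 128))
x = to_tensor(img)
assert x.shape == (1, 3, 256, 256)
assert x.dtype == np.float32
assert -1.0 <= x.min() and x.max() <= 1.0
```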
examples/elon-musk.jpg ADDED
face_detector.py ADDED
@@ -0,0 +1,135 @@
+ from typing import Iterator, List, Tuple
+
+ import dlib
+ import numpy as np
+ import PIL.Image
+ import scipy.ndimage
+
+
+ class FaceAligner(object):
+     def __init__(
+         self,
+         shape_predictor_path: str = 'shape_predictor_68_face_landmarks.dat',
+         image_size: int = 512,
+     ) -> None:
+         self.image_size = image_size
+         self.detector = dlib.get_frontal_face_detector()
+         self.shape_predictor = dlib.shape_predictor(shape_predictor_path)
+
+     def align(self, image: PIL.Image.Image) -> List[PIL.Image.Image]:
+         landmarks = self.get_landmarks(image)
+
+         return [image_align(
+             image,
+             face_landmarks,
+             output_size=self.image_size,
+             transform_size=self.image_size * 2,
+         ) for face_landmarks in landmarks]
+
+     def get_landmarks(
+         self,
+         image: PIL.Image.Image,
+     ) -> Iterator[List[Tuple[int, int]]]:
+         img = np.asarray(image.convert('L'))
+         dets = self.detector(img, 1)
+
+         for detection in dets:
+             try:
+                 parts = self.shape_predictor(img, detection).parts()
+                 face_landmarks = [(point.x, point.y) for point in parts]
+                 yield face_landmarks
+             except Exception as exc:
+                 print(f'Exception in get_landmarks(): {exc}')
+
+
+ def image_align(
+     img: PIL.Image.Image,
+     face_landmarks: List[Tuple[int, int]],
+     output_size: int = 1024,
+     transform_size: int = 4096,
+     enable_padding: bool = True,
+     x_scale: float = 1,
+     y_scale: float = 1,
+     em_scale: float = 0.1,
+ ) -> PIL.Image.Image:
+     # Align function from FFHQ dataset pre-processing step
+     # https://github.com/NVlabs/ffhq-dataset/blob/master/download_ffhq.py
+
+     lm = np.array(face_landmarks)
+     lm_chin = lm[0: 17]  # left-right
+     lm_eyebrow_left = lm[17: 22]  # left-right
+     lm_eyebrow_right = lm[22: 27]  # left-right
+     lm_nose = lm[27: 31]  # top-down
+     lm_nostrils = lm[31: 36]  # top-down
+     lm_eye_left = lm[36: 42]  # left-clockwise
+     lm_eye_right = lm[42: 48]  # left-clockwise
+     lm_mouth_outer = lm[48: 60]  # left-clockwise
+     lm_mouth_inner = lm[60: 68]  # left-clockwise
+
+     # Calculate auxiliary vectors.
+     eye_left = np.mean(lm_eye_left, axis=0)
+     eye_right = np.mean(lm_eye_right, axis=0)
+     eye_avg = (eye_left + eye_right) * 0.5
+     eye_to_eye = eye_right - eye_left
+     mouth_left = lm_mouth_outer[0]
+     mouth_right = lm_mouth_outer[6]
+     mouth_avg = (mouth_left + mouth_right) * 0.5
+     eye_to_mouth = mouth_avg - eye_avg
+
+     # Choose oriented crop rectangle.
+     x = eye_to_eye - np.flipud(eye_to_mouth) * [-1, 1]
+     x /= np.hypot(*x)
+     x *= max(np.hypot(*eye_to_eye) * 2.0, np.hypot(*eye_to_mouth) * 1.8)
+     x *= x_scale
+     y = np.flipud(x) * [-y_scale, y_scale]
+     c = eye_avg + eye_to_mouth * em_scale
+     quad = np.stack([c - x - y, c - x + y, c + x + y, c + x - y])
+     qsize = np.hypot(*x) * 2
+
+     # Shrink.
+     shrink = int(np.floor(qsize / output_size * 0.5))
+     if shrink > 1:
+         rsize = (int(np.rint(float(img.size[0]) / shrink)),
+                  int(np.rint(float(img.size[1]) / shrink)))
+         img = img.resize(rsize, PIL.Image.ANTIALIAS)
+         quad /= shrink
+         qsize /= shrink
+
+     # Crop.
+     border = max(int(np.rint(qsize * 0.1)), 3)
+     crop = (int(np.floor(min(quad[:, 0]))), int(np.floor(min(quad[:, 1]))),
+             int(np.ceil(max(quad[:, 0]))), int(np.ceil(max(quad[:, 1]))))
+     crop = (max(crop[0] - border, 0), max(crop[1] - border, 0),
+             min(crop[2] + border, img.size[0]), min(crop[3] + border, img.size[1]))
+     if crop[2] - crop[0] < img.size[0] or crop[3] - crop[1] < img.size[1]:
+         img = img.crop(crop)
+         quad -= crop[0:2]
+
+     # Pad.
+     pad = (int(np.floor(min(quad[:, 0]))), int(np.floor(min(quad[:, 1]))),
+            int(np.ceil(max(quad[:, 0]))), int(np.ceil(max(quad[:, 1]))))
+     pad = (max(-pad[0] + border, 0), max(-pad[1] + border, 0),
+            max(pad[2] - img.size[0] + border, 0), max(pad[3] - img.size[1] + border, 0))
+     if enable_padding and max(pad) > border - 4:
+         pad = np.maximum(pad, int(np.rint(qsize * 0.3)))
+         img = np.pad(np.float32(img),
+                      ((pad[1], pad[3]), (pad[0], pad[2]), (0, 0)), 'reflect')
+         h, w, _ = img.shape
+         y, x, _ = np.ogrid[:h, :w, :1]
+         mask = np.maximum(1.0 - np.minimum(np.float32(x) / pad[0], np.float32(w - 1 - x) / pad[2]),
+                           1.0 - np.minimum(np.float32(y) / pad[1], np.float32(h - 1 - y) / pad[3]))
+         blur = qsize * 0.02
+         blurred = scipy.ndimage.gaussian_filter(img, [blur, blur, 0])
+         img += (blurred - img) * np.clip(mask * 3.0 + 1.0, 0.0, 1.0)
+         img += (np.median(img, axis=(0, 1)) - img) * np.clip(mask, 0.0, 1.0)
+         img = np.uint8(np.clip(np.rint(img), 0, 255))
+         img = PIL.Image.fromarray(img, 'RGB')
+         quad += pad[:2]
+
+     # Transform.
+     img = img.transform((transform_size, transform_size), PIL.Image.QUAD,
+                         (quad + 0.5).flatten(), PIL.Image.BILINEAR)
+     if output_size < transform_size:
+         img = img.resize((output_size, output_size), PIL.Image.ANTIALIAS)
+
+     return img
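
FaceAligner can also be used on its own to inspect the FFHQ-style crops before stylization. A sketch, assuming the dlib 68-landmark model is available at the default path (app.py downloads it from the Hub as shape_predictor_68_face_landmarks.bin); output filenames are illustrative:

```python
# Sketch: align every detected face in the example image and save the crops.
import PIL.Image

from face_detector import FaceAligner

aligner = FaceAligner(
    shape_predictor_path='shape_predictor_68_face_landmarks.dat',
    image_size=512,
)
faces = aligner.align(PIL.Image.open('examples/elon-musk.jpg'))
for i, face in enumerate(faces):
    face.save(f'face_{i}.png')  # one 512x512 aligned crop per face
```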
generator.py ADDED
@@ -0,0 +1,27 @@
+ import PIL.Image
+ import onnxruntime
+ import numpy as np
+
+
+ def to_image(tensor: np.ndarray) -> PIL.Image.Image:
+     tensor = tensor * 0.5 + 0.5
+     tensor = np.clip(tensor * 255, 0, 255).astype(np.uint8)
+     tensor = np.transpose(tensor, (1, 2, 0))
+     return PIL.Image.fromarray(tensor)
+
+
+ class Generator(object):
+     def __init__(
+         self,
+         model_path: str = 'generator.onnx',
+     ) -> None:
+         self.input_name = 'input_0'
+         self.output_name = 'output_0'
+         self.session = onnxruntime.InferenceSession(model_path, None)
+
+     def predict(self, x: np.ndarray) -> PIL.Image.Image:
+         x = np.expand_dims(x, 0)
+         output = self.session.run([self.output_name], {
+             self.input_name: x,
+         })[0][0]
+         return to_image(output)
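
to_image is the inverse of the normalization in encoder.to_tensor: it maps a CHW tensor in [-1, 1] back to an 8-bit HWC PIL image. A round-trip sketch on random data:

```python
# Sketch: to_image undoes the (x - 0.5) / 0.5 scaling used on the input side.
import numpy as np

from generator import to_image

chw = np.random.uniform(-1.0, 1.0, size=(3, 64, 64)).astype(np.float32)
img = to_image(chw)
assert img.size == (64, 64) and img.mode == 'RGB'
```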