Spaces:
Runtime error
Runtime error
senior-sigan
commited on
Commit
•
424d90e
1
Parent(s):
54c0a4b
init project
Browse files- .gitignore +170 -0
- .python-version +1 -0
- app.py +103 -0
- encoder.py +32 -0
- examples/elon-musk.jpg +0 -0
- face_detector.py +135 -0
- generator.py +28 -0
.gitignore
ADDED
@@ -0,0 +1,170 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#### joe made this: http://goel.io/joe
|
2 |
+
#### macos ####
|
3 |
+
# General
|
4 |
+
*.DS_Store
|
5 |
+
.AppleDouble
|
6 |
+
.LSOverride
|
7 |
+
|
8 |
+
# Icon must end with two \r
|
9 |
+
Icon
|
10 |
+
|
11 |
+
|
12 |
+
# Thumbnails
|
13 |
+
._*
|
14 |
+
|
15 |
+
# Files that might appear in the root of a volume
|
16 |
+
.DocumentRevisions-V100
|
17 |
+
.fseventsd
|
18 |
+
.Spotlight-V100
|
19 |
+
.TemporaryItems
|
20 |
+
.Trashes
|
21 |
+
.VolumeIcon.icns
|
22 |
+
.com.apple.timemachine.donotpresent
|
23 |
+
|
24 |
+
# Directories potentially created on remote AFP share
|
25 |
+
.AppleDB
|
26 |
+
.AppleDesktop
|
27 |
+
Network Trash Folder
|
28 |
+
Temporary Items
|
29 |
+
.apdisk
|
30 |
+
#### linux ####
|
31 |
+
*~
|
32 |
+
|
33 |
+
# temporary files which can be created if a process still has a handle open of a deleted file
|
34 |
+
.fuse_hidden*
|
35 |
+
|
36 |
+
# KDE directory preferences
|
37 |
+
.directory
|
38 |
+
|
39 |
+
# Linux trash folder which might appear on any partition or disk
|
40 |
+
.Trash-*
|
41 |
+
|
42 |
+
# .nfs files are created when an open file is removed but is still being accessed
|
43 |
+
.nfs*
|
44 |
+
#### windows ####
|
45 |
+
# Windows thumbnail cache files
|
46 |
+
Thumbs.db
|
47 |
+
ehthumbs.db
|
48 |
+
ehthumbs_vista.db
|
49 |
+
|
50 |
+
# Dump file
|
51 |
+
*.stackdump
|
52 |
+
|
53 |
+
# Folder config file
|
54 |
+
Desktop.ini
|
55 |
+
|
56 |
+
# Recycle Bin used on file shares
|
57 |
+
$RECYCLE.BIN/
|
58 |
+
|
59 |
+
# Windows Installer files
|
60 |
+
*.cab
|
61 |
+
*.msi
|
62 |
+
*.msm
|
63 |
+
*.msp
|
64 |
+
|
65 |
+
# Windows shortcuts
|
66 |
+
*.lnk
|
67 |
+
#### python ####
|
68 |
+
# Byte-compiled / optimized / DLL files
|
69 |
+
__pycache__/
|
70 |
+
*.py[cod]
|
71 |
+
*$py.class
|
72 |
+
|
73 |
+
# C extensions
|
74 |
+
*.so
|
75 |
+
|
76 |
+
# Distribution / packaging
|
77 |
+
.Python
|
78 |
+
build/
|
79 |
+
develop-eggs/
|
80 |
+
dist/
|
81 |
+
downloads/
|
82 |
+
eggs/
|
83 |
+
.eggs/
|
84 |
+
lib/
|
85 |
+
lib64/
|
86 |
+
parts/
|
87 |
+
sdist/
|
88 |
+
var/
|
89 |
+
wheels/
|
90 |
+
*.egg-info/
|
91 |
+
.installed.cfg
|
92 |
+
*.egg
|
93 |
+
|
94 |
+
# PyInstaller
|
95 |
+
# Usually these files are written by a python script from a template
|
96 |
+
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
97 |
+
*.manifest
|
98 |
+
*.spec
|
99 |
+
|
100 |
+
# Installer logs
|
101 |
+
pip-log.txt
|
102 |
+
pip-delete-this-directory.txt
|
103 |
+
|
104 |
+
# Unit test / coverage reports
|
105 |
+
htmlcov/
|
106 |
+
.tox/
|
107 |
+
.coverage
|
108 |
+
.coverage.*
|
109 |
+
.cache
|
110 |
+
nosetests.xml
|
111 |
+
coverage.xml
|
112 |
+
*.cover
|
113 |
+
.hypothesis/
|
114 |
+
|
115 |
+
# Translations
|
116 |
+
*.mo
|
117 |
+
*.pot
|
118 |
+
|
119 |
+
# Django stuff:
|
120 |
+
*.log
|
121 |
+
local_settings.py
|
122 |
+
|
123 |
+
# Flask stuff:
|
124 |
+
instance/
|
125 |
+
.webassets-cache
|
126 |
+
|
127 |
+
# Scrapy stuff:
|
128 |
+
.scrapy
|
129 |
+
|
130 |
+
# Sphinx documentation
|
131 |
+
docs/_build/
|
132 |
+
|
133 |
+
# PyBuilder
|
134 |
+
target/
|
135 |
+
|
136 |
+
# Jupyter Notebook
|
137 |
+
.ipynb_checkpoints
|
138 |
+
|
139 |
+
# celery beat schedule file
|
140 |
+
celerybeat-schedule
|
141 |
+
|
142 |
+
# SageMath parsed files
|
143 |
+
*.sage.py
|
144 |
+
|
145 |
+
# Environments
|
146 |
+
.env
|
147 |
+
.venv
|
148 |
+
env/
|
149 |
+
venv/
|
150 |
+
ENV/
|
151 |
+
|
152 |
+
# Spyder project settings
|
153 |
+
.spyderproject
|
154 |
+
.spyproject
|
155 |
+
|
156 |
+
# Rope project settings
|
157 |
+
.ropeproject
|
158 |
+
|
159 |
+
# mkdocs documentation
|
160 |
+
/site
|
161 |
+
|
162 |
+
# mypy
|
163 |
+
.mypy_cache/
|
164 |
+
|
165 |
+
#### jetbrains ####
|
166 |
+
|
167 |
+
.idea/
|
168 |
+
.docker/
|
169 |
+
.pytest_cache/
|
170 |
+
*.db
|
.python-version
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
miniforge3
|
app.py
ADDED
@@ -0,0 +1,103 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
import functools
|
3 |
+
import pathlib
|
4 |
+
import gradio as gr
|
5 |
+
import PIL.Image
|
6 |
+
from encoder import Encoder
|
7 |
+
|
8 |
+
from face_detector import FaceAligner
|
9 |
+
from generator import Generator
|
10 |
+
from huggingface_hub import hf_hub_download
|
11 |
+
|
12 |
+
|
13 |
+
def parse_args() -> argparse.Namespace:
|
14 |
+
parser = argparse.ArgumentParser()
|
15 |
+
parser.add_argument('--share', action='store_true')
|
16 |
+
parser.add_argument('--port', type=int)
|
17 |
+
parser.add_argument('--disable-queue',
|
18 |
+
dest='enable_queue',
|
19 |
+
action='store_false')
|
20 |
+
return parser.parse_args()
|
21 |
+
|
22 |
+
|
23 |
+
def load_examples():
|
24 |
+
image_dir = pathlib.Path('examples')
|
25 |
+
images = sorted(image_dir.glob('*.jpg'))
|
26 |
+
return [path.as_posix() for path in images]
|
27 |
+
|
28 |
+
|
29 |
+
def predict(
|
30 |
+
image: PIL.Image.Image,
|
31 |
+
face_aligner: FaceAligner,
|
32 |
+
encoder: Encoder,
|
33 |
+
generator: Generator,
|
34 |
+
):
|
35 |
+
images = face_aligner.align(image)
|
36 |
+
|
37 |
+
gen_imgs = []
|
38 |
+
for img in images:
|
39 |
+
x = encoder.predict(img)
|
40 |
+
gen_img = generator.predict(x)
|
41 |
+
gen_imgs.append(gen_img)
|
42 |
+
|
43 |
+
return gen_imgs
|
44 |
+
|
45 |
+
|
46 |
+
def load_models():
|
47 |
+
encoder_path = hf_hub_download(
|
48 |
+
'senior-sigan/nijigenka',
|
49 |
+
'encoder.onnx',
|
50 |
+
)
|
51 |
+
generator_path = hf_hub_download(
|
52 |
+
'senior-sigan/nijigenka',
|
53 |
+
'face2art.onnx',
|
54 |
+
)
|
55 |
+
shape_predictor_path = hf_hub_download(
|
56 |
+
'senior-sigan/nijigenka',
|
57 |
+
'shape_predictor_68_face_landmarks.bin',
|
58 |
+
)
|
59 |
+
|
60 |
+
face_aligner = FaceAligner(
|
61 |
+
image_size=512,
|
62 |
+
shape_predictor_path=shape_predictor_path,
|
63 |
+
)
|
64 |
+
encoder = Encoder(model_path=encoder_path)
|
65 |
+
generator = Generator(model_path=generator_path)
|
66 |
+
|
67 |
+
return face_aligner, encoder, generator
|
68 |
+
|
69 |
+
|
70 |
+
def main():
|
71 |
+
args = parse_args()
|
72 |
+
gr.close_all()
|
73 |
+
|
74 |
+
face_aligner, encoder, generator = load_models()
|
75 |
+
|
76 |
+
func = functools.partial(
|
77 |
+
predict,
|
78 |
+
face_aligner=face_aligner,
|
79 |
+
encoder=encoder,
|
80 |
+
generator=generator,
|
81 |
+
)
|
82 |
+
func = functools.update_wrapper(func, predict)
|
83 |
+
|
84 |
+
iface = gr.Interface(
|
85 |
+
fn=func,
|
86 |
+
inputs=[
|
87 |
+
gr.inputs.Image(type='pil', label='Input')
|
88 |
+
],
|
89 |
+
outputs=gr.outputs.Carousel(['image']),
|
90 |
+
examples=load_examples(),
|
91 |
+
title='Nijigenka: Portrait to Art',
|
92 |
+
allow_flagging='never',
|
93 |
+
theme='huggingface',
|
94 |
+
)
|
95 |
+
iface.launch(
|
96 |
+
enable_queue=args.enable_queue,
|
97 |
+
server_port=args.port,
|
98 |
+
share=args.share,
|
99 |
+
)
|
100 |
+
|
101 |
+
|
102 |
+
if __name__ == '__main__':
|
103 |
+
main()
|
encoder.py
ADDED
@@ -0,0 +1,32 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import PIL.Image
|
2 |
+
import onnxruntime
|
3 |
+
import numpy as np
|
4 |
+
|
5 |
+
|
6 |
+
def to_tensor(
|
7 |
+
image: PIL.Image.Image,
|
8 |
+
image_size=(256, 256),
|
9 |
+
) -> np.ndarray:
|
10 |
+
image = image.resize(image_size)
|
11 |
+
x = np.asarray(image) / 255.0
|
12 |
+
x = np.transpose(x, (2, 0, 1))
|
13 |
+
x = (x - 0.5) / 0.5
|
14 |
+
x = np.expand_dims(x, axis=0).astype(np.float32)
|
15 |
+
return x
|
16 |
+
|
17 |
+
|
18 |
+
class Encoder(object):
|
19 |
+
def __init__(
|
20 |
+
self,
|
21 |
+
model_path: str = 'encoder.onnx',
|
22 |
+
) -> None:
|
23 |
+
self.input_name = 'input_0'
|
24 |
+
self.output_name = 'output_0'
|
25 |
+
self.session = onnxruntime.InferenceSession(model_path, None)
|
26 |
+
|
27 |
+
def predict(self, image: PIL.Image.Image) -> np.ndarray:
|
28 |
+
x = to_tensor(image)
|
29 |
+
output = self.session.run([self.output_name], {
|
30 |
+
self.input_name: x,
|
31 |
+
})[0][0]
|
32 |
+
return output
|
examples/elon-musk.jpg
ADDED
face_detector.py
ADDED
@@ -0,0 +1,135 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Iterator, List, Tuple
|
2 |
+
|
3 |
+
import dlib
|
4 |
+
import numpy as np
|
5 |
+
import PIL.Image
|
6 |
+
import scipy.ndimage
|
7 |
+
|
8 |
+
|
9 |
+
class FaceAligner(object):
|
10 |
+
def __init__(
|
11 |
+
self,
|
12 |
+
shape_predictor_path: str = 'shape_predictor_68_face_landmarks.dat',
|
13 |
+
image_size: int = 512,
|
14 |
+
) -> None:
|
15 |
+
self.image_size = image_size
|
16 |
+
self.detector = dlib.get_frontal_face_detector()
|
17 |
+
self.shape_predictor = dlib.shape_predictor(shape_predictor_path)
|
18 |
+
|
19 |
+
def align(self, image: PIL.Image.Image) -> List[PIL.Image.Image]:
|
20 |
+
landmarks = self.get_landmarks(image)
|
21 |
+
|
22 |
+
return [image_align(
|
23 |
+
image,
|
24 |
+
face_landmarks,
|
25 |
+
output_size=self.image_size,
|
26 |
+
transform_size=self.image_size * 2,
|
27 |
+
) for face_landmarks in landmarks]
|
28 |
+
|
29 |
+
def get_landmarks(
|
30 |
+
self,
|
31 |
+
image: PIL.Image.Image,
|
32 |
+
) -> Iterator[List[Tuple[int, int]]]:
|
33 |
+
img = np.asarray(image.convert('L'))
|
34 |
+
dets = self.detector(img, 1)
|
35 |
+
|
36 |
+
for detection in dets:
|
37 |
+
try:
|
38 |
+
parts = self.shape_predictor(img, detection).parts()
|
39 |
+
face_landmarks = [(point.x, point.y) for point in parts]
|
40 |
+
yield face_landmarks
|
41 |
+
except:
|
42 |
+
print("Exception in get_landmarks()!")
|
43 |
+
|
44 |
+
|
45 |
+
def image_align(
|
46 |
+
img: PIL.Image.Image,
|
47 |
+
face_landmarks: List[Tuple[int, int]],
|
48 |
+
output_size: int = 1024,
|
49 |
+
transform_size: int = 4096,
|
50 |
+
enable_padding: bool = True,
|
51 |
+
x_scale: float = 1,
|
52 |
+
y_scale: float = 1,
|
53 |
+
em_scale: float = 0.1,
|
54 |
+
) -> PIL.Image.Image:
|
55 |
+
# Align function from FFHQ dataset pre-processing step
|
56 |
+
# https://github.com/NVlabs/ffhq-dataset/blob/master/download_ffhq.py
|
57 |
+
|
58 |
+
lm = np.array(face_landmarks)
|
59 |
+
lm_chin = lm[0: 17] # left-right
|
60 |
+
lm_eyebrow_left = lm[17: 22] # left-right
|
61 |
+
lm_eyebrow_right = lm[22: 27] # left-right
|
62 |
+
lm_nose = lm[27: 31] # top-down
|
63 |
+
lm_nostrils = lm[31: 36] # top-down
|
64 |
+
lm_eye_left = lm[36: 42] # left-clockwise
|
65 |
+
lm_eye_right = lm[42: 48] # left-clockwise
|
66 |
+
lm_mouth_outer = lm[48: 60] # left-clockwise
|
67 |
+
lm_mouth_inner = lm[60: 68] # left-clockwise
|
68 |
+
|
69 |
+
# Calculate auxiliary vectors.
|
70 |
+
eye_left = np.mean(lm_eye_left, axis=0)
|
71 |
+
eye_right = np.mean(lm_eye_right, axis=0)
|
72 |
+
eye_avg = (eye_left + eye_right) * 0.5
|
73 |
+
eye_to_eye = eye_right - eye_left
|
74 |
+
mouth_left = lm_mouth_outer[0]
|
75 |
+
mouth_right = lm_mouth_outer[6]
|
76 |
+
mouth_avg = (mouth_left + mouth_right) * 0.5
|
77 |
+
eye_to_mouth = mouth_avg - eye_avg
|
78 |
+
|
79 |
+
# Choose oriented crop rectangle.
|
80 |
+
x = eye_to_eye - np.flipud(eye_to_mouth) * [-1, 1]
|
81 |
+
x /= np.hypot(*x)
|
82 |
+
x *= max(np.hypot(*eye_to_eye) * 2.0, np.hypot(*eye_to_mouth) * 1.8)
|
83 |
+
x *= x_scale
|
84 |
+
y = np.flipud(x) * [-y_scale, y_scale]
|
85 |
+
c = eye_avg + eye_to_mouth * em_scale
|
86 |
+
quad = np.stack([c - x - y, c - x + y, c + x + y, c + x - y])
|
87 |
+
qsize = np.hypot(*x) * 2
|
88 |
+
|
89 |
+
# Shrink.
|
90 |
+
shrink = int(np.floor(qsize / output_size * 0.5))
|
91 |
+
if shrink > 1:
|
92 |
+
rsize = (int(np.rint(float(img.size[0]) / shrink)),
|
93 |
+
int(np.rint(float(img.size[1]) / shrink)))
|
94 |
+
img = img.resize(rsize, PIL.Image.ANTIALIAS)
|
95 |
+
quad /= shrink
|
96 |
+
qsize /= shrink
|
97 |
+
|
98 |
+
# Crop.
|
99 |
+
border = max(int(np.rint(qsize * 0.1)), 3)
|
100 |
+
crop = (int(np.floor(min(quad[:, 0]))), int(np.floor(min(quad[:, 1]))), int(
|
101 |
+
np.ceil(max(quad[:, 0]))), int(np.ceil(max(quad[:, 1]))))
|
102 |
+
crop = (max(crop[0] - border, 0), max(crop[1] - border, 0),
|
103 |
+
min(crop[2] + border, img.size[0]), min(crop[3] + border, img.size[1]))
|
104 |
+
if crop[2] - crop[0] < img.size[0] or crop[3] - crop[1] < img.size[1]:
|
105 |
+
img = img.crop(crop)
|
106 |
+
quad -= crop[0:2]
|
107 |
+
|
108 |
+
# Pad.
|
109 |
+
pad = (int(np.floor(min(quad[:, 0]))), int(np.floor(min(quad[:, 1]))), int(
|
110 |
+
np.ceil(max(quad[:, 0]))), int(np.ceil(max(quad[:, 1]))))
|
111 |
+
pad = (max(-pad[0] + border, 0), max(-pad[1] + border, 0), max(pad[2] -
|
112 |
+
img.size[0] + border, 0), max(pad[3] - img.size[1] + border, 0))
|
113 |
+
if enable_padding and max(pad) > border - 4:
|
114 |
+
pad = np.maximum(pad, int(np.rint(qsize * 0.3)))
|
115 |
+
img = np.pad(np.float32(img),
|
116 |
+
((pad[1], pad[3]), (pad[0], pad[2]), (0, 0)), 'reflect')
|
117 |
+
h, w, _ = img.shape
|
118 |
+
y, x, _ = np.ogrid[:h, :w, :1]
|
119 |
+
mask = np.maximum(1.0 - np.minimum(np.float32(x) / pad[0], np.float32(
|
120 |
+
w-1-x) / pad[2]), 1.0 - np.minimum(np.float32(y) / pad[1], np.float32(h-1-y) / pad[3]))
|
121 |
+
blur = qsize * 0.02
|
122 |
+
img += (scipy.ndimage.gaussian_filter(img,
|
123 |
+
[blur, blur, 0]) - img) * np.clip(mask * 3.0 + 1.0, 0.0, 1.0)
|
124 |
+
img += (np.median(img, axis=(0, 1)) - img) * np.clip(mask, 0.0, 1.0)
|
125 |
+
img = np.uint8(np.clip(np.rint(img), 0, 255))
|
126 |
+
img = PIL.Image.fromarray(img, 'RGB')
|
127 |
+
quad += pad[:2]
|
128 |
+
|
129 |
+
# Transform.
|
130 |
+
img = img.transform((transform_size, transform_size),
|
131 |
+
PIL.Image.QUAD, (quad + 0.5).flatten(), PIL.Image.BILINEAR)
|
132 |
+
if output_size < transform_size:
|
133 |
+
img = img.resize((output_size, output_size), PIL.Image.ANTIALIAS)
|
134 |
+
|
135 |
+
return img
|
generator.py
ADDED
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from xml.etree.ElementTree import PI
|
2 |
+
import PIL.Image
|
3 |
+
import onnxruntime
|
4 |
+
import numpy as np
|
5 |
+
|
6 |
+
|
7 |
+
def to_image(tensor: np.ndarray) -> PIL.Image.Image:
|
8 |
+
tensor = tensor * 0.5 + 0.5
|
9 |
+
tensor = np.clip(tensor * 255, 0, 255).astype(np.uint8)
|
10 |
+
tensor = np.transpose(tensor, (1, 2, 0))
|
11 |
+
return PIL.Image.fromarray(tensor)
|
12 |
+
|
13 |
+
|
14 |
+
class Generator(object):
|
15 |
+
def __init__(
|
16 |
+
self,
|
17 |
+
model_path: str = 'generator.onnx',
|
18 |
+
) -> None:
|
19 |
+
self.input_name = 'input_0'
|
20 |
+
self.output_name = 'output_0'
|
21 |
+
self.session = onnxruntime.InferenceSession(model_path, None)
|
22 |
+
|
23 |
+
def predict(self, x: np.ndarray) -> PIL.Image.Image:
|
24 |
+
x = np.expand_dims(x, 0)
|
25 |
+
output = self.session.run([self.output_name], {
|
26 |
+
self.input_name: x,
|
27 |
+
})[0][0]
|
28 |
+
return to_image(output)
|