skytnt commited on
Commit
7df097b
1 Parent(s): dd1aa6d
Files changed (3) hide show
  1. .gitignore +117 -0
  2. app.py +125 -0
  3. requirements.txt +4 -0
.gitignore ADDED
@@ -0,0 +1,117 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Byte-compiled / optimized / DLL files
2
+ __pycache__/
3
+ *.py[cod]
4
+ *$py.class
5
+
6
+ # C extensions
7
+ *.so
8
+
9
+ # Distribution / packaging
10
+ .Python
11
+ build/
12
+ develop-eggs/
13
+ dist/
14
+ downloads/
15
+ eggs/
16
+ .eggs/
17
+ lib/
18
+ lib64/
19
+ parts/
20
+ sdist/
21
+ var/
22
+ wheels/
23
+ *.egg-info/
24
+ .installed.cfg
25
+ *.egg
26
+ MANIFEST
27
+
28
+ # PyInstaller
29
+ # Usually these files are written by a python script from a template
30
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
31
+ *.manifest
32
+ *.spec
33
+
34
+ # Installer logs
35
+ pip-log.txt
36
+ pip-delete-this-directory.txt
37
+
38
+ # Unit test / coverage reports
39
+ htmlcov/
40
+ .tox/
41
+ .nox/
42
+ .coverage
43
+ .coverage.*
44
+ .cache
45
+ nosetests.xml
46
+ coverage.xml
47
+ *.cover
48
+ .hypothesis/
49
+ .pytest_cache/
50
+
51
+ # Translations
52
+ *.mo
53
+ *.pot
54
+
55
+ # Django stuff:
56
+ *.log
57
+ local_settings.py
58
+ db.sqlite3
59
+
60
+ # Flask stuff:
61
+ instance/
62
+ .webassets-cache
63
+
64
+ # Scrapy stuff:
65
+ .scrapy
66
+
67
+ # Sphinx documentation
68
+ docs/_build/
69
+
70
+ # PyBuilder
71
+ target/
72
+
73
+ # Jupyter Notebook
74
+ .ipynb_checkpoints
75
+
76
+ # IPython
77
+ profile_default/
78
+ ipython_config.py
79
+
80
+ # pyenv
81
+ .python-version
82
+
83
+ # celery beat schedule file
84
+ celerybeat-schedule
85
+
86
+ # SageMath parsed files
87
+ *.sage.py
88
+
89
+ # Environments
90
+ .env
91
+ .venv
92
+ env/
93
+ venv/
94
+ ENV/
95
+ env.bak/
96
+ venv.bak/
97
+
98
+ # Spyder project settings
99
+ .spyderproject
100
+ .spyproject
101
+
102
+ # Rope project settings
103
+ .ropeproject
104
+
105
+ # mkdocs documentation
106
+ /site
107
+
108
+ # mypy
109
+ .mypy_cache/
110
+ .dmypy.json
111
+ dmypy.json
112
+
113
+ # Pyre type checker
114
+ .pyre/
115
+
116
+ .idea/
117
+ video.mp4
app.py ADDED
@@ -0,0 +1,125 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import random
2
+
3
+ import gradio as gr
4
+ import huggingface_hub
5
+ import imageio
6
+ import numpy as np
7
+ import onnxruntime as rt
8
+ from numpy.random import RandomState
9
+ from skimage import transform
10
+
11
+
12
+ class Model:
13
+ def __init__(self):
14
+ self.g_synthesis = None
15
+ self.g_mapping = None
16
+ self.load_models()
17
+
18
+ def load_models(self):
19
+ providers = ['CUDAExecutionProvider', 'CPUExecutionProvider']
20
+ g_mapping_path = huggingface_hub.hf_hub_download("skytnt/waifu-gan", "g_mapping.onnx")
21
+ g_synthesis_path = huggingface_hub.hf_hub_download("skytnt/waifu-gan", "g_synthesis.onnx")
22
+ self.g_mapping = rt.InferenceSession(g_mapping_path, providers=providers)
23
+ self.g_synthesis = rt.InferenceSession(g_synthesis_path, providers=providers)
24
+
25
+ def get_img(self, w):
26
+ img = self.g_synthesis.run(None, {'w': w})[0]
27
+ return (img.transpose(0, 2, 3, 1) * 127.5 + 128).clip(0, 255).astype(np.uint8)[0]
28
+
29
+ def get_w(self, z, psi1, psi2):
30
+ return self.g_mapping.run(None, {'z': z, 'psi': np.asarray([psi1, psi2], dtype=np.float32)})[0]
31
+
32
+ def gen_video(self, w1, w2, path, frame_num=10):
33
+ video = imageio.get_writer(path, mode='I', fps=frame_num // 2, codec='libx264', bitrate='16M')
34
+ lin = np.linspace(0, 1, frame_num)
35
+ for i in range(0, frame_num):
36
+ img = self.get_img(((1 - lin[i]) * w1) + (lin[i] * w2))
37
+ video.append_data(img)
38
+ video.close()
39
+
40
+
41
+ def get_thumbnail(img):
42
+ img_new = np.full((192, 288, 3), 200, dtype=np.uint8)
43
+ img_new[:, 80:208] = transform.resize(img, (192, 128), preserve_range=True)
44
+ return img_new
45
+
46
+
47
+ def gen_fn(method, seed, psi1, psi2):
48
+ if method == 0:
49
+ seed = random.randint(0, 2 ** 32 -1)
50
+ z = RandomState(int(seed)).randn(1, 1024)
51
+ w = model.get_w(z.astype(dtype=np.float32), psi1, psi2)
52
+ img_out = model.get_img(w)
53
+ return img_out, seed, w, get_thumbnail(img_out)
54
+
55
+
56
+ def gen_video_fn(w1, w2, frame):
57
+ if w1 is None or w2 is None:
58
+ return None
59
+ model.gen_video(w1, w2, "video.mp4", int(frame))
60
+ return "video.mp4"
61
+
62
+
63
+ if __name__ == '__main__':
64
+ model = Model()
65
+
66
+ app = gr.Blocks()
67
+ with app:
68
+ gr.Markdown("# Waifu GAN\n\n"
69
+ "![visitor badge](https://visitor-badge.glitch.me/badge?page_id=skytnt.waifu-gan)\n\n")
70
+ with gr.Tabs():
71
+ with gr.TabItem("generate image"):
72
+ with gr.Row():
73
+ with gr.Column():
74
+ with gr.Row():
75
+ gen_input1 = gr.Radio(label="method", value="random",
76
+ choices=["random", "use seed"], type="index")
77
+ gen_input2 = gr.Slider(minimum=0, maximum=2 ** 32 - 1, step=1, value=0,
78
+ label="seed")
79
+ gen_input3 = gr.Slider(minimum=0, maximum=1, step=0.01, value=1, label="truncation psi 1")
80
+ gen_input4 = gr.Slider(minimum=0, maximum=1, step=0.01, value=1, label="truncation psi 2")
81
+ with gr.Group():
82
+ gen_submit = gr.Button("Generate", variant="primary")
83
+ with gr.Column():
84
+ gen_output1 = gr.Image(label="output image")
85
+ select_img_input_w1 = gr.Variable()
86
+ select_img_input_img1 = gr.Variable()
87
+
88
+ with gr.TabItem("generate video"):
89
+ with gr.Row():
90
+ with gr.Column():
91
+ gr.Markdown("## generate video between 2 images")
92
+ with gr.Row():
93
+ with gr.Column():
94
+ gr.Markdown("please select image 1")
95
+ select_img1_dropdown = gr.Radio(label="source", value="current generated image",
96
+ choices=["current generated image"], type="index")
97
+ with gr.Group():
98
+ select_img1_button = gr.Button("Select", variant="primary")
99
+ select_img1_output_img = gr.Image(label="selected image 1")
100
+ select_img1_output_w = gr.Variable()
101
+ with gr.Column():
102
+ gr.Markdown("please select image 2")
103
+ select_img2_dropdown = gr.Radio(label="source", value="current generated image",
104
+ choices=["current generated image"], type="index")
105
+ with gr.Group():
106
+ select_img2_button = gr.Button("Select", variant="primary")
107
+ select_img2_output_img = gr.Image(label="selected image 2")
108
+ select_img2_output_w = gr.Variable()
109
+ generate_video_frame = gr.Slider(minimum=10, maximum=30, step=1, label="frame", value=15)
110
+ with gr.Group():
111
+ generate_video_button = gr.Button("Generate", variant="primary")
112
+ with gr.Column():
113
+ generate_video_output = gr.Video(label="output video")
114
+ gen_submit.click(gen_fn, [gen_input1, gen_input2, gen_input3, gen_input4],
115
+ [gen_output1, gen_input2, select_img_input_w1, select_img_input_img1])
116
+ select_img1_button.click(lambda i, img1, w1: (img1, w1),
117
+ [select_img1_dropdown, select_img_input_img1, select_img_input_w1],
118
+ [select_img1_output_img, select_img1_output_w])
119
+ select_img2_button.click(lambda i, img1, w1: (img1, w1),
120
+ [select_img2_dropdown, select_img_input_img1, select_img_input_w1],
121
+ [select_img2_output_img, select_img2_output_w])
122
+ generate_video_button.click(gen_video_fn,
123
+ [select_img1_output_w, select_img2_output_w, generate_video_frame],
124
+ [generate_video_output])
125
+ app.launch()
requirements.txt ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ onnx
2
+ onnxruntime-gpu
3
+ scikit-image
4
+ imageio-ffmpeg