Commit ea476a5 by radames
1 Parent(s): 0585438

gradio app code

Files changed (4):
  1. .gitignore +4 -0
  2. interface/app.py +151 -0
  3. interface/model_loader.py +242 -0
  4. requirements.txt +8 -0
.gitignore ADDED
@@ -0,0 +1,4 @@
+ __pycache__
+ venv
+ pretrained_models/
+ pretrained_models.tar.gz
interface/app.py ADDED
@@ -0,0 +1,151 @@
+ import gradio as gr
+
+ from .model_loader import Model
+ from PIL import Image
+ import cv2
+
+ # models from the pretrained_models/latent_transformer folder
+ models_files = {
+     "anime": "pretrained_models/latent_transformer/anime.pt",
+     "car": "pretrained_models/latent_transformer/car.pt",
+     "cat": "pretrained_models/latent_transformer/cat.pt",
+     "church": "pretrained_models/latent_transformer/church.pt",
+     "ffhq": "pretrained_models/latent_transformer/ffhq.pt",
+ }
+
+ models = {name: Model(path) for name, path in models_files.items()}
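+ # NOTE: this loads all five checkpoints eagerly at import time, so every path
+ # in models_files must already exist on disk (and a CUDA device is required).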
+
+
+ def cv_to_pil(img):
+     # The model returns BGR float arrays; convert to an RGB PIL image.
+     return Image.fromarray(cv2.cvtColor(img.astype("uint8"), cv2.COLOR_BGR2RGB))
+
+
+ def random_sample(model_name: str):
+     model = models[model_name]
+     img, latents = model.random_sample()
+     pil_img = cv_to_pil(img)
+     return pil_img, model_name, latents
+
+
+ def zoom(dx, dy, dz, model_state, latents_state):
+     model = models[model_state]
+     # Fixed handle point on the 256x256 feature map.
+     sx = 100
+     sy = 100
+     stop_points = []
+     img, latents_state = model.zoom(
+         latents_state, dz, sxsy=[sx, sy], stop_points=stop_points
+     )
+     pil_img = cv_to_pil(img)
+     return pil_img, latents_state
+
+
+ def translate(dx, dy, dz, model_state, latents_state):
+     model = models[model_state]
+     # Drag from the centre of the 256x256 feature map.
+     sx = 128
+     sy = 128
+     stop_points = []
+     zoom_in = False
+     zoom_out = False
+
+     img, latents_state = model.translate(
+         latents_state,
+         [dx, dy],
+         sxsy=[sx, sy],
+         stop_points=stop_points,
+         zoom_in=zoom_in,
+         zoom_out=zoom_out,
+     )
+
+     pil_img = cv_to_pil(img)
+     return pil_img, latents_state
+
+
+ def change_style(image: Image.Image, model_state, latents_state):
+     model = models[model_state]
+     img, latents_state = model.change_style(latents_state)
+     pil_img = cv_to_pil(img)
+     return pil_img, latents_state
+
+
+ def reset(model_state, latents_state):
+     model = models[model_state]
+     img, latents_state = model.reset(latents_state)
+     pil_img = cv_to_pil(img)
+     return pil_img, latents_state
+
+
+ with gr.Blocks() as block:
+     model_state = gr.State(value="cat")
+     latents_state = gr.State({})
+     gr.Markdown("# UserControllableLT: User controllable latent transformer")
+     gr.Markdown("## Select model")
+     with gr.Row():
+         with gr.Column():
+             model_name = gr.Dropdown(
+                 choices=list(models_files.keys()),
+                 label="Select Pretrained Model",
+                 value="cat",
+             )
+             with gr.Row():
+                 button = gr.Button("Random sample")
+                 reset_btn = gr.Button("Reset")
+
+             dx = gr.Slider(minimum=-128, maximum=128, step=0.1, label="dx", value=0.0)
+             dy = gr.Slider(minimum=-128, maximum=128, step=0.1, label="dy", value=0.0)
+             dz = gr.Slider(minimum=-128, maximum=128, step=0.1, label="dz", value=0.0)
+
+             with gr.Row():
+                 change_style_bt = gr.Button("Change style")
+
+         with gr.Column():
+             image = gr.Image(type="pil", label="")
+
+     button.click(
+         random_sample, inputs=[model_name], outputs=[image, model_state, latents_state]
+     )
+     reset_btn.click(
+         reset,
+         inputs=[model_state, latents_state],
+         outputs=[image, latents_state],
+     )
+     change_style_bt.click(
+         change_style,
+         inputs=[image, model_state, latents_state],
+         outputs=[image, latents_state],
+     )
+     # dx/dy drive translation, dz drives zoom; each slider change re-renders.
+     dx.change(
+         translate,
+         inputs=[dx, dy, dz, model_state, latents_state],
+         outputs=[image, latents_state],
+         show_progress=False,
+     )
+     dy.change(
+         translate,
+         inputs=[dx, dy, dz, model_state, latents_state],
+         outputs=[image, latents_state],
+         show_progress=False,
+     )
+     dz.change(
+         zoom,
+         inputs=[dx, dy, dz, model_state, latents_state],
+         outputs=[image, latents_state],
+         show_progress=False,
+     )
+
+
+ block.launch()
interface/model_loader.py ADDED
@@ -0,0 +1,242 @@
+ from argparse import Namespace
+
+ import numpy as np
+ import torch
+
+ from models.StyleGANControler import StyleGANControler
+
+
+ class Model:
+     def __init__(
+         self, checkpoint_path, truncation=0.5, use_average_code_as_input=False
+     ):
+         self.truncation = truncation
+         self.use_average_code_as_input = use_average_code_as_input
+         ckpt = torch.load(checkpoint_path, map_location="cpu")
+         opts = ckpt["opts"]
+         opts["checkpoint_path"] = checkpoint_path
+         self.opts = Namespace(**opts)
+         self.net = StyleGANControler(self.opts)
+         self.net.eval()
+         self.net.cuda()  # the loader assumes a CUDA device throughout
+         self.target_layers = [0, 1, 2, 3, 4, 5]
+
+     def random_sample(self):
+         # Sample a random z, decode it, and keep both the current and the
+         # initial W+ latents so later edits can be reset.
+         z1 = torch.randn(1, 512).to("cuda")
+         x1, w1, f1 = self.net.decoder(
+             [z1],
+             input_is_latent=False,
+             randomize_noise=False,
+             return_feature_map=True,
+             return_latents=True,
+             truncation=self.truncation,
+             truncation_latent=self.net.latent_avg[0],
+         )
+         w1_initial = w1.clone()
+         x1 = self.net.face_pool(x1)
+         # Map the [-1, 1] tensor to a 0-255 BGR array (the UI converts to RGB).
+         image = (
+             ((x1.detach()[0].permute(1, 2, 0) + 1.0) * 127.5).cpu().numpy()[:, :, ::-1]
+         )
+         return (
+             image,
+             {
+                 "w1": w1.cpu().detach().numpy(),
+                 "w1_initial": w1_initial.cpu().detach().numpy(),
+             },
+         )  # return the latent vectors along with the image
+
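+     # A `latents` dict (as returned above) holds numpy W+ codes; the Gradio
+     # app keeps it in gr.State and passes it back on every subsequent edit.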
+     def latents_to_tensor(self, latents):
+         # Move the stored numpy latents back onto the GPU.
+         w1 = torch.tensor(latents["w1"]).to("cuda")
+         w1_initial = torch.tensor(latents["w1_initial"]).to("cuda")
+
+         # Re-decode the current latents, then recover the feature map f1
+         # from the initial latents.
+         x1, w1 = self.net.decoder(
+             [w1],
+             input_is_latent=True,
+             randomize_noise=False,
+             return_feature_map=False,
+             return_latents=True,
+             truncation=self.truncation,
+             truncation_latent=self.net.latent_avg[0],
+         )
+         x1, _, f1 = self.net.decoder(
+             [w1_initial],
+             input_is_latent=False,
+             randomize_noise=False,
+             return_feature_map=True,
+             return_latents=True,
+             truncation=self.truncation,
+             truncation_latent=self.net.latent_avg[0],
+         )
+         return (w1, w1_initial, f1)
+
+     def zoom(self, latents, dz, sxsy=[0, 0], stop_points=[]):
+         w1, w1_initial, f1 = self.latents_to_tensor(latents)
+
+         # The zoom is encoded as a fixed flow vector along z; |dz| sets alpha.
+         vec_num = abs(dz) / 5
+         dz = 100 * np.sign(dz)
+         x = torch.from_numpy(np.array([[[1.0, 0, dz]]], dtype=np.float32)).cuda()
+         f1 = torch.nn.functional.interpolate(f1, (256, 256))
+         y = f1[:, :, sxsy[1], sxsy[0]].unsqueeze(0)
+
+         # Stop points act as extra zero-motion anchors on the feature map.
+         if len(stop_points) > 0:
+             x = torch.cat(
+                 [x, torch.zeros(x.shape[0], len(stop_points), x.shape[2]).cuda()], dim=1
+             )
+             tmp = []
+             for sp in stop_points:
+                 tmp.append(f1[:, :, sp[1], sp[0]].unsqueeze(1))
+             y = torch.cat([y, torch.cat(tmp, dim=1)], dim=1)
+
+         if not self.use_average_code_as_input:
+             w_hat = self.net.encoder(
+                 w1[:, self.target_layers].detach(),
+                 x.detach(),
+                 y.detach(),
+                 alpha=vec_num,
+             )
+             w1 = w1.clone()
+             w1[:, self.target_layers] = w_hat
+         else:
+             w_hat = self.net.encoder(
+                 self.net.latent_avg.unsqueeze(0)[:, self.target_layers].detach(),
+                 x.detach(),
+                 y.detach(),
+                 alpha=vec_num,
+             )
+             w1 = w1.clone()
+             w1[:, self.target_layers] = (
+                 w1.clone()[:, self.target_layers]
+                 + w_hat
+                 - self.net.latent_avg.unsqueeze(0)[:, self.target_layers]
+             )
+
+         x1, _ = self.net.decoder([w1], input_is_latent=True, randomize_noise=False)
+         x1 = self.net.face_pool(x1)
+         result = (
+             ((x1.detach()[0].permute(1, 2, 0) + 1.0) * 127.5).cpu().numpy()[:, :, ::-1]
+         )
+         return (
+             result,
+             {
+                 "w1": w1.cpu().detach().numpy(),
+                 "w1_initial": w1_initial.cpu().detach().numpy(),
+             },
+         )  # return the latent vectors along with the image
+
+     def translate(
+         self, latents, dxy, sxsy=[0, 0], stop_points=[], zoom_in=False, zoom_out=False
+     ):
+         w1, w1_initial, f1 = self.latents_to_tensor(latents)
+
+         dz = -5.0 if zoom_in else 0.0
+         dz = 5.0 if zoom_out else dz
+
+         # Normalise the drag direction; its magnitude (scaled) becomes alpha.
+         dxyz = np.array([dxy[0], dxy[1], dz], dtype=np.float32)
+         dxy_norm = np.linalg.norm(dxyz[:2], ord=2)
+         if dxy_norm > 0:  # avoid a divide-by-zero when both sliders are at 0
+             dxyz[:2] = dxyz[:2] / dxy_norm
+         vec_num = dxy_norm / 10
+
+         x = torch.from_numpy(np.array([[dxyz]], dtype=np.float32)).cuda()
+         f1 = torch.nn.functional.interpolate(f1, (256, 256))
+         y = f1[:, :, sxsy[1], sxsy[0]].unsqueeze(0)
+
+         # Stop points act as extra zero-motion anchors on the feature map.
+         if len(stop_points) > 0:
+             x = torch.cat(
+                 [x, torch.zeros(x.shape[0], len(stop_points), x.shape[2]).cuda()], dim=1
+             )
+             tmp = []
+             for sp in stop_points:
+                 tmp.append(f1[:, :, sp[1], sp[0]].unsqueeze(1))
+             y = torch.cat([y, torch.cat(tmp, dim=1)], dim=1)
+
+         if not self.use_average_code_as_input:
+             w_hat = self.net.encoder(
+                 w1[:, self.target_layers].detach(),
+                 x.detach(),
+                 y.detach(),
+                 alpha=vec_num,
+             )
+             w1 = w1.clone()
+             w1[:, self.target_layers] = w_hat
+         else:
+             w_hat = self.net.encoder(
+                 self.net.latent_avg.unsqueeze(0)[:, self.target_layers].detach(),
+                 x.detach(),
+                 y.detach(),
+                 alpha=vec_num,
+             )
+             w1 = w1.clone()
+             w1[:, self.target_layers] = (
+                 w1.clone()[:, self.target_layers]
+                 + w_hat
+                 - self.net.latent_avg.unsqueeze(0)[:, self.target_layers]
+             )
+
+         x1, _ = self.net.decoder([w1], input_is_latent=True, randomize_noise=False)
+         x1 = self.net.face_pool(x1)
+         result = (
+             ((x1.detach()[0].permute(1, 2, 0) + 1.0) * 127.5).cpu().numpy()[:, :, ::-1]
+         )
+         return (
+             result,
+             {
+                 "w1": w1.cpu().detach().numpy(),
+                 "w1_initial": w1_initial.cpu().detach().numpy(),
+             },
+         )
+
+     def change_style(self, latents):
+         w1, w1_initial, f1 = self.latents_to_tensor(latents)
+
+         # Resample the fine layers (6 and up) from a new random style code.
+         z1 = torch.randn(1, 512).to("cuda")
+         x1, w2 = self.net.decoder(
+             [z1],
+             input_is_latent=False,
+             randomize_noise=False,
+             return_latents=True,
+             truncation=self.truncation,
+             truncation_latent=self.net.latent_avg[0],
+         )
+         w1[:, 6:] = w2.detach()[:, 0]
+         x1, w1_new, f1 = self.net.decoder(
+             [w1],
+             input_is_latent=True,
+             randomize_noise=False,
+             return_feature_map=True,
+             return_latents=True,
+         )
+         result = (
+             ((x1.detach()[0].permute(1, 2, 0) + 1.0) * 127.5).cpu().numpy()[:, :, ::-1]
+         )
+         return (
+             result,
+             {
+                 "w1": w1_new.cpu().detach().numpy(),
+                 "w1_initial": w1_initial.cpu().detach().numpy(),
+             },
+         )
+
+     def reset(self, latents):
+         # Decode the stored initial latents to restore the original image.
+         w1, w1_initial, f1 = self.latents_to_tensor(latents)
+         x1, w1_new, f1 = self.net.decoder(
+             [w1_initial],
+             input_is_latent=True,
+             randomize_noise=False,
+             return_feature_map=True,
+             return_latents=True,
+         )
+         result = (
+             ((x1.detach()[0].permute(1, 2, 0) + 1.0) * 127.5).cpu().numpy()[:, :, ::-1]
+         )
+         return (
+             result,
+             {
+                 "w1": w1_new.cpu().detach().numpy(),
+                 "w1_initial": w1_initial.cpu().detach().numpy(),
+             },
+         )
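Note: `interface/model_loader.py` hard-codes CUDA throughout (`self.net.cuda()`, `.to("cuda")`), so the app will not start on a CPU-only machine. If a CPU fallback were wanted, one minimal sketch (the `DEVICE` constant is hypothetical, not part of this commit) would resolve the device once and reuse it everywhere:

```python
import torch

# Hypothetical: pick CUDA when available, else CPU, and reuse it wherever the
# loader currently hard-codes "cuda" (net.cuda(), tensor.to("cuda"), .cuda()).
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

z1 = torch.randn(1, 512).to(DEVICE)  # e.g. replaces torch.randn(1, 512).to("cuda")
```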
requirements.txt ADDED
@@ -0,0 +1,8 @@
+ flask
+ torch
+ opencv-python
+ Pillow
+ ninja==1.10.2
+ einops==0.3.2
+ gradio
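With these four files in place, a local run (assuming the pretrained checkpoints have already been unpacked into `pretrained_models/latent_transformer/`, which `.gitignore` keeps out of version control) is `pip install -r requirements.txt` followed by `python -m interface.app` from the repository root; the `-m` form is needed because `app.py` imports `model_loader` relatively.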