radames commited on
Commit
35536db
1 Parent(s): e04439f

single transform function

Browse files
Files changed (2) hide show
  1. interface/app.py +13 -36
  2. interface/model_loader.py +189 -240
interface/app.py CHANGED
@@ -1,5 +1,6 @@
1
  import gradio as gr
2
- import sys
 
3
  sys.path.append(".")
4
  sys.path.append("..")
5
  from model_loader import Model
@@ -34,7 +35,7 @@ def random_sample(model_name: str):
34
  return pil_img, model_name, latents
35
 
36
 
37
- def zoom(model_state, latents_state, dx=0, dy=0, dz=0, sxsy=[128, 128]):
38
  model = models[model_state]
39
  dx = dx
40
  dy = dy
@@ -42,34 +43,10 @@ def zoom(model_state, latents_state, dx=0, dy=0, dz=0, sxsy=[128, 128]):
42
  sx = sxsy[0]
43
  sy = sxsy[1]
44
  stop_points = []
45
- img, latents_state = model.zoom(
46
- latents_state, dz, sxsy=[sx, sy], stop_points=stop_points
47
- ) # dz, sxsy=[sx, sy], stop_points=stop_points)
48
- pil_img = cv_to_pil(img)
49
- return pil_img, latents_state
50
 
51
-
52
- def translate(model_state, latents_state, dx=0, dy=0, dz=0, sxsy=[128, 128]):
53
- model = models[model_state]
54
-
55
- dx = dx
56
- dy = dy
57
- dz = dz
58
- sx = sxsy[0]
59
- sy = sxsy[1]
60
- stop_points = []
61
- zi = False
62
- zo = False
63
-
64
- img, latents_state = model.translate(
65
- latents_state,
66
- [dx, dy],
67
- sxsy=[sx, sy],
68
- stop_points=stop_points,
69
- zoom_in=zi,
70
- zoom_out=zo,
71
  )
72
-
73
  pil_img = cv_to_pil(img)
74
  return pil_img, latents_state
75
 
@@ -109,6 +86,7 @@ with gr.Blocks() as block:
109
  with gr.Row():
110
  button = gr.Button("Random sample")
111
  reset_btn = gr.Button("Reset")
 
112
 
113
  dx = gr.Slider(
114
  minimum=-256, maximum=256, step_size=0.1, label="dx", value=0.0
@@ -117,14 +95,13 @@ with gr.Blocks() as block:
117
  minimum=-256, maximum=256, step_size=0.1, label="dy", value=0.0
118
  )
119
  dz = gr.Slider(
120
- minimum=-256, maximum=256, step_size=0.1, label="dz", value=0.0
121
  )
122
-
123
- with gr.Row():
124
- change_style_bt = gr.Button("Change style")
125
 
126
  with gr.Column():
127
- image = gr.Image(type="pil", label="")
 
128
  image.select(image_click, inputs=None, outputs=sxsy)
129
  button.click(
130
  random_sample, inputs=[model_name], outputs=[image, model_state, latents_state]
@@ -141,19 +118,19 @@ with gr.Blocks() as block:
141
  outputs=[image, latents_state],
142
  )
143
  dx.change(
144
- translate,
145
  inputs=[model_state, latents_state, dx, dy, dz, sxsy],
146
  outputs=[image, latents_state],
147
  show_progress=False,
148
  )
149
  dy.change(
150
- translate,
151
  inputs=[model_state, latents_state, dx, dy, dz, sxsy],
152
  outputs=[image, latents_state],
153
  show_progress=False,
154
  )
155
  dz.change(
156
- zoom,
157
  inputs=[model_state, latents_state, dx, dy, dz, sxsy],
158
  outputs=[image, latents_state],
159
  show_progress=False,
 
1
  import gradio as gr
2
+ import sys
3
+
4
  sys.path.append(".")
5
  sys.path.append("..")
6
  from model_loader import Model
 
35
  return pil_img, model_name, latents
36
 
37
 
38
+ def transform(model_state, latents_state, dx=0, dy=0, dz=0, sxsy=[128, 128]):
39
  model = models[model_state]
40
  dx = dx
41
  dy = dy
 
43
  sx = sxsy[0]
44
  sy = sxsy[1]
45
  stop_points = []
 
 
 
 
 
46
 
47
+ img, latents_state = model.transform(
48
+ latents_state, dz, dxy=[dx, dy], sxsy=[sx, sy], stop_points=stop_points
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
49
  )
 
50
  pil_img = cv_to_pil(img)
51
  return pil_img, latents_state
52
 
 
86
  with gr.Row():
87
  button = gr.Button("Random sample")
88
  reset_btn = gr.Button("Reset")
89
+ change_style_bt = gr.Button("Change style")
90
 
91
  dx = gr.Slider(
92
  minimum=-256, maximum=256, step_size=0.1, label="dx", value=0.0
 
95
  minimum=-256, maximum=256, step_size=0.1, label="dy", value=0.0
96
  )
97
  dz = gr.Slider(
98
+ minimum=-5, maximum=5, step_size=0.01, label="dz", value=0.0
99
  )
100
+ image = gr.Image(type="pil", label="").style(height=500)
 
 
101
 
102
  with gr.Column():
103
+ html = gr.HTML(label="output")
104
+
105
  image.select(image_click, inputs=None, outputs=sxsy)
106
  button.click(
107
  random_sample, inputs=[model_name], outputs=[image, model_state, latents_state]
 
118
  outputs=[image, latents_state],
119
  )
120
  dx.change(
121
+ transform,
122
  inputs=[model_state, latents_state, dx, dy, dz, sxsy],
123
  outputs=[image, latents_state],
124
  show_progress=False,
125
  )
126
  dy.change(
127
+ transform,
128
  inputs=[model_state, latents_state, dx, dy, dz, sxsy],
129
  outputs=[image, latents_state],
130
  show_progress=False,
131
  )
132
  dz.change(
133
+ transform,
134
  inputs=[model_state, latents_state, dx, dy, dz, sxsy],
135
  outputs=[image, latents_state],
136
  show_progress=False,
interface/model_loader.py CHANGED
@@ -1,240 +1,189 @@
1
- import os
2
- from argparse import Namespace
3
- import numpy as np
4
- import torch
5
-
6
- from models.StyleGANControler import StyleGANControler
7
-
8
-
9
- class Model:
10
- def __init__(
11
- self, checkpoint_path, truncation=0.5, use_average_code_as_input=False
12
- ):
13
- self.truncation = truncation
14
- self.use_average_code_as_input = use_average_code_as_input
15
- ckpt = torch.load(checkpoint_path, map_location="cpu")
16
- opts = ckpt["opts"]
17
- opts["checkpoint_path"] = checkpoint_path
18
- self.opts = Namespace(**ckpt["opts"])
19
- self.net = StyleGANControler(self.opts)
20
- self.net.eval()
21
- self.net.cuda()
22
- self.target_layers = [0, 1, 2, 3, 4, 5]
23
-
24
- def random_sample(self):
25
- z1 = torch.randn(1, 512).to("cuda")
26
- x1, w1, f1 = self.net.decoder(
27
- [z1],
28
- input_is_latent=False,
29
- randomize_noise=False,
30
- return_feature_map=True,
31
- return_latents=True,
32
- truncation=self.truncation,
33
- truncation_latent=self.net.latent_avg[0],
34
- )
35
- w1_initial = w1.clone()
36
- x1 = self.net.face_pool(x1)
37
- image = (
38
- ((x1.detach()[0].permute(1, 2, 0) + 1.0) * 127.5).cpu().numpy()[:, :, ::-1]
39
- )
40
- return (
41
- image,
42
- {
43
- "w1": w1.cpu().detach().numpy(),
44
- "w1_initial": w1_initial.cpu().detach().numpy(),
45
- },
46
- ) # return latent vector along with the image
47
-
48
- def latents_to_tensor(self, latents):
49
- w1 = latents["w1"]
50
- w1_initial = latents["w1_initial"]
51
-
52
- w1 = torch.tensor(w1).to("cuda")
53
- w1_initial = torch.tensor(w1_initial).to("cuda")
54
-
55
- x1, w1, f1 = self.net.decoder(
56
- [w1],
57
- input_is_latent=True,
58
- randomize_noise=False,
59
- return_feature_map=True,
60
- return_latents=True,
61
- )
62
- x1, w1_initial, f1 = self.net.decoder(
63
- [w1_initial],
64
- input_is_latent=True,
65
- randomize_noise=False,
66
- return_feature_map=True,
67
- return_latents=True,
68
- )
69
-
70
- return (w1, w1_initial, f1)
71
-
72
- def zoom(self, latents, dz, sxsy=[0, 0], stop_points=[]):
73
- w1, w1_initial, f1 = self.latents_to_tensor(latents)
74
- w1 = w1_initial.clone()
75
-
76
- vec_num = abs(dz) / 5
77
- dz = 100 * np.sign(dz)
78
- x = torch.from_numpy(np.array([[[1.0, 0, dz]]], dtype=np.float32)).cuda()
79
- f1 = torch.nn.functional.interpolate(f1, (256, 256))
80
- y = f1[:, :, sxsy[1], sxsy[0]].unsqueeze(0)
81
-
82
- if len(stop_points) > 0:
83
- x = torch.cat(
84
- [x, torch.zeros(x.shape[0], len(stop_points), x.shape[2]).cuda()], dim=1
85
- )
86
- tmp = []
87
- for sp in stop_points:
88
- tmp.append(f1[:, :, sp[1], sp[0]].unsqueeze(1))
89
- y = torch.cat([y, torch.cat(tmp, dim=1)], dim=1)
90
-
91
- if not self.use_average_code_as_input:
92
- w_hat = self.net.encoder(
93
- w1[:, self.target_layers].detach(),
94
- x.detach(),
95
- y.detach(),
96
- alpha=vec_num,
97
- )
98
- w1 = w1.clone()
99
- w1[:, self.target_layers] = w_hat
100
- else:
101
- w_hat = self.net.encoder(
102
- self.net.latent_avg.unsqueeze(0)[:, self.target_layers].detach(),
103
- x.detach(),
104
- y.detach(),
105
- alpha=vec_num,
106
- )
107
- w1 = w1.clone()
108
- w1[:, self.target_layers] = (
109
- w1.clone()[:, self.target_layers]
110
- + w_hat
111
- - self.net.latent_avg.unsqueeze(0)[:, self.target_layers]
112
- )
113
-
114
- x1, _ = self.net.decoder([w1], input_is_latent=True, randomize_noise=False)
115
-
116
- x1 = self.net.face_pool(x1)
117
- result = (
118
- ((x1.detach()[0].permute(1, 2, 0) + 1.0) * 127.5).cpu().numpy()[:, :, ::-1]
119
- )
120
- return (
121
- result,
122
- {
123
- "w1": w1.cpu().detach().numpy(),
124
- "w1_initial": w1_initial.cpu().detach().numpy(),
125
- },
126
- ) # return latent vector along with the image
127
-
128
- def translate(
129
- self, latents, dxy, sxsy=[0, 0], stop_points=[], zoom_in=False, zoom_out=False
130
- ):
131
- w1, w1_initial, f1 = self.latents_to_tensor(latents)
132
- w1 = w1_initial.clone()
133
- dz = -5.0 if zoom_in else 0.0
134
- dz = 5.0 if zoom_out else dz
135
-
136
- dxyz = np.array([dxy[0], dxy[1], dz], dtype=np.float32)
137
- dxy_norm = np.linalg.norm(dxyz[:2], ord=2)
138
- dxyz[:2] = dxyz[:2] / dxy_norm
139
- vec_num = dxy_norm / 10
140
-
141
- x = torch.from_numpy(np.array([[dxyz]], dtype=np.float32)).cuda()
142
- f1 = torch.nn.functional.interpolate(f1, (256, 256))
143
- y = f1[:, :, sxsy[1], sxsy[0]].unsqueeze(0)
144
-
145
- if len(stop_points) > 0:
146
- x = torch.cat(
147
- [x, torch.zeros(x.shape[0], len(stop_points), x.shape[2]).cuda()], dim=1
148
- )
149
- tmp = []
150
- for sp in stop_points:
151
- tmp.append(f1[:, :, sp[1], sp[0]].unsqueeze(1))
152
- y = torch.cat([y, torch.cat(tmp, dim=1)], dim=1)
153
-
154
- if not self.use_average_code_as_input:
155
- w_hat = self.net.encoder(
156
- w1[:, self.target_layers].detach(),
157
- x.detach(),
158
- y.detach(),
159
- alpha=vec_num,
160
- )
161
- w1 = w1.clone()
162
- w1[:, self.target_layers] = w_hat
163
- else:
164
- w_hat = self.net.encoder(
165
- self.net.latent_avg.unsqueeze(0)[:, self.target_layers].detach(),
166
- x.detach(),
167
- y.detach(),
168
- alpha=vec_num,
169
- )
170
- w1 = w1.clone()
171
- w1[:, self.target_layers] = (
172
- w1.clone()[:, self.target_layers]
173
- + w_hat
174
- - self.net.latent_avg.unsqueeze(0)[:, self.target_layers]
175
- )
176
-
177
- x1, _ = self.net.decoder([w1], input_is_latent=True, randomize_noise=False)
178
-
179
- x1 = self.net.face_pool(x1)
180
- result = (
181
- ((x1.detach()[0].permute(1, 2, 0) + 1.0) * 127.5).cpu().numpy()[:, :, ::-1]
182
- )
183
- return (
184
- result,
185
- {
186
- "w1": w1.cpu().detach().numpy(),
187
- "w1_initial": w1_initial.cpu().detach().numpy(),
188
- },
189
- )
190
-
191
- def change_style(self, latents):
192
- w1, w1_initial, f1 = self.latents_to_tensor(latents)
193
- w1 = w1_initial.clone()
194
-
195
- z1 = torch.randn(1, 512).to("cuda")
196
- x1, w2 = self.net.decoder(
197
- [z1],
198
- input_is_latent=False,
199
- randomize_noise=False,
200
- return_latents=True,
201
- truncation=self.truncation,
202
- truncation_latent=self.net.latent_avg[0],
203
- )
204
- w1[:, 6:] = w2.detach()[:, 0]
205
- x1, w1_new = self.net.decoder(
206
- [w1],
207
- input_is_latent=True,
208
- randomize_noise=False,
209
- return_latents=True,
210
- )
211
- result = (
212
- ((x1.detach()[0].permute(1, 2, 0) + 1.0) * 127.5).cpu().numpy()[:, :, ::-1]
213
- )
214
- return (
215
- result,
216
- {
217
- "w1": w1_new.cpu().detach().numpy(),
218
- "w1_initial": w1_new.cpu().detach().numpy(),
219
- },
220
- )
221
-
222
- def reset(self, latents):
223
- w1, w1_initial, f1 = self.latents_to_tensor(latents)
224
- x1, w1_new, f1 = self.net.decoder(
225
- [w1_initial],
226
- input_is_latent=True,
227
- randomize_noise=False,
228
- return_feature_map=True,
229
- return_latents=True,
230
- )
231
- result = (
232
- ((x1.detach()[0].permute(1, 2, 0) + 1.0) * 127.5).cpu().numpy()[:, :, ::-1]
233
- )
234
- return (
235
- result,
236
- {
237
- "w1": w1_new.cpu().detach().numpy(),
238
- "w1_initial": w1_new.cpu().detach().numpy(),
239
- },
240
- )
 
1
+ import os
2
+ from argparse import Namespace
3
+ import numpy as np
4
+ import torch
5
+
6
+ from models.StyleGANControler import StyleGANControler
7
+
8
+
9
+ class Model:
10
+ def __init__(
11
+ self, checkpoint_path, truncation=0.5, use_average_code_as_input=False
12
+ ):
13
+ self.truncation = truncation
14
+ self.use_average_code_as_input = use_average_code_as_input
15
+ ckpt = torch.load(checkpoint_path, map_location="cpu")
16
+ opts = ckpt["opts"]
17
+ opts["checkpoint_path"] = checkpoint_path
18
+ self.opts = Namespace(**ckpt["opts"])
19
+ self.net = StyleGANControler(self.opts)
20
+ self.net.eval()
21
+ self.net.cuda()
22
+ self.target_layers = [0, 1, 2, 3, 4, 5]
23
+
24
+ def random_sample(self):
25
+ z1 = torch.randn(1, 512).to("cuda")
26
+ x1, w1, f1 = self.net.decoder(
27
+ [z1],
28
+ input_is_latent=False,
29
+ randomize_noise=False,
30
+ return_feature_map=True,
31
+ return_latents=True,
32
+ truncation=self.truncation,
33
+ truncation_latent=self.net.latent_avg[0],
34
+ )
35
+ w1_initial = w1.clone()
36
+ x1 = self.net.face_pool(x1)
37
+ image = (
38
+ ((x1.detach()[0].permute(1, 2, 0) + 1.0) * 127.5).cpu().numpy()[:, :, ::-1]
39
+ )
40
+ return (
41
+ image,
42
+ {
43
+ "w1": w1.cpu().detach().numpy(),
44
+ "w1_initial": w1_initial.cpu().detach().numpy(),
45
+ },
46
+ ) # return latent vector along with the image
47
+
48
+ def latents_to_tensor(self, latents):
49
+ w1 = latents["w1"]
50
+ w1_initial = latents["w1_initial"]
51
+
52
+ w1 = torch.tensor(w1).to("cuda")
53
+ w1_initial = torch.tensor(w1_initial).to("cuda")
54
+
55
+ x1, w1, f1 = self.net.decoder(
56
+ [w1],
57
+ input_is_latent=True,
58
+ randomize_noise=False,
59
+ return_feature_map=True,
60
+ return_latents=True,
61
+ )
62
+ x1, w1_initial, f1 = self.net.decoder(
63
+ [w1_initial],
64
+ input_is_latent=True,
65
+ randomize_noise=False,
66
+ return_feature_map=True,
67
+ return_latents=True,
68
+ )
69
+
70
+ return (w1, w1_initial, f1)
71
+
72
+ def transform(
73
+ self,
74
+ latents,
75
+ dz,
76
+ dxy,
77
+ sxsy=[0, 0],
78
+ stop_points=[],
79
+ zoom_in=False,
80
+ zoom_out=False,
81
+ ):
82
+ w1, w1_initial, f1 = self.latents_to_tensor(latents)
83
+ w1 = w1_initial.clone()
84
+
85
+ dxyz = np.array([dxy[0], dxy[1], dz], dtype=np.float32)
86
+ dxy_norm = np.linalg.norm(dxyz[:2], ord=2)
87
+ dxyz[:2] = dxyz[:2] / dxy_norm
88
+ vec_num = dxy_norm / 10
89
+
90
+ x = torch.from_numpy(np.array([[dxyz]], dtype=np.float32)).cuda()
91
+ f1 = torch.nn.functional.interpolate(f1, (256, 256))
92
+ y = f1[:, :, sxsy[1], sxsy[0]].unsqueeze(0)
93
+
94
+ if len(stop_points) > 0:
95
+ x = torch.cat(
96
+ [x, torch.zeros(x.shape[0], len(stop_points), x.shape[2]).cuda()], dim=1
97
+ )
98
+ tmp = []
99
+ for sp in stop_points:
100
+ tmp.append(f1[:, :, sp[1], sp[0]].unsqueeze(1))
101
+ y = torch.cat([y, torch.cat(tmp, dim=1)], dim=1)
102
+
103
+ if not self.use_average_code_as_input:
104
+ w_hat = self.net.encoder(
105
+ w1[:, self.target_layers].detach(),
106
+ x.detach(),
107
+ y.detach(),
108
+ alpha=vec_num,
109
+ )
110
+ w1 = w1.clone()
111
+ w1[:, self.target_layers] = w_hat
112
+ else:
113
+ w_hat = self.net.encoder(
114
+ self.net.latent_avg.unsqueeze(0)[:, self.target_layers].detach(),
115
+ x.detach(),
116
+ y.detach(),
117
+ alpha=vec_num,
118
+ )
119
+ w1 = w1.clone()
120
+ w1[:, self.target_layers] = (
121
+ w1.clone()[:, self.target_layers]
122
+ + w_hat
123
+ - self.net.latent_avg.unsqueeze(0)[:, self.target_layers]
124
+ )
125
+
126
+ x1, _ = self.net.decoder([w1], input_is_latent=True, randomize_noise=False)
127
+
128
+ x1 = self.net.face_pool(x1)
129
+ result = (
130
+ ((x1.detach()[0].permute(1, 2, 0) + 1.0) * 127.5).cpu().numpy()[:, :, ::-1]
131
+ )
132
+ return (
133
+ result,
134
+ {
135
+ "w1": w1.cpu().detach().numpy(),
136
+ "w1_initial": w1_initial.cpu().detach().numpy(),
137
+ },
138
+ )
139
+
140
+ def change_style(self, latents):
141
+ w1, w1_initial, f1 = self.latents_to_tensor(latents)
142
+ w1 = w1_initial.clone()
143
+
144
+ z1 = torch.randn(1, 512).to("cuda")
145
+ x1, w2 = self.net.decoder(
146
+ [z1],
147
+ input_is_latent=False,
148
+ randomize_noise=False,
149
+ return_latents=True,
150
+ truncation=self.truncation,
151
+ truncation_latent=self.net.latent_avg[0],
152
+ )
153
+ w1[:, 6:] = w2.detach()[:, 0]
154
+ x1, w1_new = self.net.decoder(
155
+ [w1],
156
+ input_is_latent=True,
157
+ randomize_noise=False,
158
+ return_latents=True,
159
+ )
160
+ result = (
161
+ ((x1.detach()[0].permute(1, 2, 0) + 1.0) * 127.5).cpu().numpy()[:, :, ::-1]
162
+ )
163
+ return (
164
+ result,
165
+ {
166
+ "w1": w1_new.cpu().detach().numpy(),
167
+ "w1_initial": w1_new.cpu().detach().numpy(),
168
+ },
169
+ )
170
+
171
+ def reset(self, latents):
172
+ w1, w1_initial, f1 = self.latents_to_tensor(latents)
173
+ x1, w1_new, f1 = self.net.decoder(
174
+ [w1_initial],
175
+ input_is_latent=True,
176
+ randomize_noise=False,
177
+ return_feature_map=True,
178
+ return_latents=True,
179
+ )
180
+ result = (
181
+ ((x1.detach()[0].permute(1, 2, 0) + 1.0) * 127.5).cpu().numpy()[:, :, ::-1]
182
+ )
183
+ return (
184
+ result,
185
+ {
186
+ "w1": w1_new.cpu().detach().numpy(),
187
+ "w1_initial": w1_new.cpu().detach().numpy(),
188
+ },
189
+ )