radames committed on
Commit
8a8b805
1 Parent(s): 67eaa47

add state for initial position

Browse files
Files changed (2) hide show
  1. interface/app.py +20 -14
  2. interface/model_loader.py +240 -240
interface/app.py CHANGED
@@ -28,13 +28,13 @@ def random_sample(model_name: str):
28
  return pil_img, model_name, latents
29
 
30
 
31
- def zoom(model_state, latents_state, dx=0, dy=0, dz=0):
32
  model = models[model_state]
33
  dx = dx
34
  dy = dy
35
  dz = dz
36
- sx = 100
37
- sy = 100
38
  stop_points = []
39
  img, latents_state = model.zoom(
40
  latents_state, dz, sxsy=[sx, sy], stop_points=stop_points
@@ -43,14 +43,14 @@ def zoom(model_state, latents_state, dx=0, dy=0, dz=0):
43
  return pil_img, latents_state
44
 
45
 
46
- def translate(model_state, latents_state, dx=0, dy=0, dz=0):
47
  model = models[model_state]
48
 
49
  dx = dx
50
  dy = dy
51
  dz = dz
52
- sx = 128
53
- sy = 128
54
  stop_points = []
55
  zi = False
56
  zo = False
@@ -82,9 +82,15 @@ def reset(model_state, latents_state):
82
  return pil_img, latents_state
83
 
84
 
 
 
 
 
 
85
  with gr.Blocks() as block:
86
  model_state = gr.State(value="cat")
87
  latents_state = gr.State({})
 
88
  gr.Markdown("# UserControllableLT: User controllable latent transformer")
89
  gr.Markdown("## Select model")
90
  with gr.Row():
@@ -99,13 +105,13 @@ with gr.Blocks() as block:
99
  reset_btn = gr.Button("Reset")
100
 
101
  dx = gr.Slider(
102
- minimum=-128, maximum=128, step_size=0.1, label="dx", value=0.0
103
  )
104
  dy = gr.Slider(
105
- minimum=-128, maximum=128, step_size=0.1, label="dy", value=0.0
106
  )
107
  dz = gr.Slider(
108
- minimum=-128, maximum=128, step_size=0.1, label="dz", value=0.0
109
  )
110
 
111
  with gr.Row():
@@ -113,10 +119,10 @@ with gr.Blocks() as block:
113
 
114
  with gr.Column():
115
  image = gr.Image(type="pil", label="")
 
116
  button.click(
117
  random_sample, inputs=[model_name], outputs=[image, model_state, latents_state]
118
  )
119
-
120
  reset_btn.click(
121
  reset,
122
  inputs=[model_state, latents_state],
@@ -130,22 +136,22 @@ with gr.Blocks() as block:
130
  )
131
  dx.change(
132
  translate,
133
- inputs=[model_state, latents_state, dx, dy, dz],
134
  outputs=[image, latents_state],
135
  show_progress=False,
136
  )
137
  dy.change(
138
  translate,
139
- inputs=[model_state, latents_state, dx, dy, dz],
140
  outputs=[image, latents_state],
141
  show_progress=False,
142
  )
143
  dz.change(
144
  zoom,
145
- inputs=[model_state, latents_state, dx, dy, dz],
146
  outputs=[image, latents_state],
147
  show_progress=False,
148
  )
149
 
150
-
151
  block.launch()
 
28
  return pil_img, model_name, latents
29
 
30
 
31
+ def zoom(model_state, latents_state, dx=0, dy=0, dz=0, sxsy=[128, 128]):
32
  model = models[model_state]
33
  dx = dx
34
  dy = dy
35
  dz = dz
36
+ sx = sxsy[0]
37
+ sy = sxsy[1]
38
  stop_points = []
39
  img, latents_state = model.zoom(
40
  latents_state, dz, sxsy=[sx, sy], stop_points=stop_points
 
43
  return pil_img, latents_state
44
 
45
 
46
+ def translate(model_state, latents_state, dx=0, dy=0, dz=0, sxsy=[128, 128]):
47
  model = models[model_state]
48
 
49
  dx = dx
50
  dy = dy
51
  dz = dz
52
+ sx = sxsy[0]
53
+ sy = sxsy[1]
54
  stop_points = []
55
  zi = False
56
  zo = False
 
82
  return pil_img, latents_state
83
 
84
 
85
def image_click(evt: gr.SelectData):
    """Return the index carried by a Gradio select event.

    Registered via ``image.select(..., outputs=sxsy)``, so the returned value
    replaces the ``sxsy`` state.
    NOTE(review): presumably ``evt.index`` is the clicked pixel position of
    the image component -- confirm against the Gradio version in use.
    """
    return evt.index
88
+
89
+
90
  with gr.Blocks() as block:
91
  model_state = gr.State(value="cat")
92
  latents_state = gr.State({})
93
+ sxsy = gr.State([128, 128])
94
  gr.Markdown("# UserControllableLT: User controllable latent transformer")
95
  gr.Markdown("## Select model")
96
  with gr.Row():
 
105
  reset_btn = gr.Button("Reset")
106
 
107
  dx = gr.Slider(
108
+ minimum=-256, maximum=256, step_size=0.1, label="dx", value=0.0
109
  )
110
  dy = gr.Slider(
111
+ minimum=-256, maximum=256, step_size=0.1, label="dy", value=0.0
112
  )
113
  dz = gr.Slider(
114
+ minimum=-256, maximum=256, step_size=0.1, label="dz", value=0.0
115
  )
116
 
117
  with gr.Row():
 
119
 
120
  with gr.Column():
121
  image = gr.Image(type="pil", label="")
122
+ image.select(image_click, inputs=None, outputs=sxsy)
123
  button.click(
124
  random_sample, inputs=[model_name], outputs=[image, model_state, latents_state]
125
  )
 
126
  reset_btn.click(
127
  reset,
128
  inputs=[model_state, latents_state],
 
136
  )
137
  dx.change(
138
  translate,
139
+ inputs=[model_state, latents_state, dx, dy, dz, sxsy],
140
  outputs=[image, latents_state],
141
  show_progress=False,
142
  )
143
  dy.change(
144
  translate,
145
+ inputs=[model_state, latents_state, dx, dy, dz, sxsy],
146
  outputs=[image, latents_state],
147
  show_progress=False,
148
  )
149
  dz.change(
150
  zoom,
151
+ inputs=[model_state, latents_state, dx, dy, dz, sxsy],
152
  outputs=[image, latents_state],
153
  show_progress=False,
154
  )
155
 
156
+ block.queue()
157
  block.launch()
interface/model_loader.py CHANGED
@@ -1,240 +1,240 @@
1
- import os
2
- from argparse import Namespace
3
- import numpy as np
4
- import torch
5
-
6
- from models.StyleGANControler import StyleGANControler
7
-
8
-
9
- class Model:
10
- def __init__(
11
- self, checkpoint_path, truncation=0.5, use_average_code_as_input=False
12
- ):
13
- self.truncation = truncation
14
- self.use_average_code_as_input = use_average_code_as_input
15
- ckpt = torch.load(checkpoint_path, map_location="cpu")
16
- opts = ckpt["opts"]
17
- opts["checkpoint_path"] = checkpoint_path
18
- self.opts = Namespace(**ckpt["opts"])
19
- self.net = StyleGANControler(self.opts)
20
- self.net.eval()
21
- self.net.cuda()
22
- self.target_layers = [0, 1, 2, 3, 4, 5]
23
-
24
- def random_sample(self):
25
- z1 = torch.randn(1, 512).to("cuda")
26
- x1, w1, f1 = self.net.decoder(
27
- [z1],
28
- input_is_latent=False,
29
- randomize_noise=False,
30
- return_feature_map=True,
31
- return_latents=True,
32
- truncation=self.truncation,
33
- truncation_latent=self.net.latent_avg[0],
34
- )
35
- w1_initial = w1.clone()
36
- x1 = self.net.face_pool(x1)
37
- image = (
38
- ((x1.detach()[0].permute(1, 2, 0) + 1.0) * 127.5).cpu().numpy()[:, :, ::-1]
39
- )
40
- return (
41
- image,
42
- {
43
- "w1": w1.cpu().detach().numpy(),
44
- "w1_initial": w1_initial.cpu().detach().numpy(),
45
- },
46
- ) # return latent vector along with the image
47
-
48
- def latents_to_tensor(self, latents):
49
- w1 = latents["w1"]
50
- w1_initial = latents["w1_initial"]
51
-
52
- w1 = torch.tensor(w1).to("cuda")
53
- w1_initial = torch.tensor(w1_initial).to("cuda")
54
-
55
- x1, w1, f1 = self.net.decoder(
56
- [w1],
57
- input_is_latent=True,
58
- randomize_noise=False,
59
- return_feature_map=True,
60
- return_latents=True,
61
- )
62
- x1, w1_initial, f1 = self.net.decoder(
63
- [w1_initial],
64
- input_is_latent=True,
65
- randomize_noise=False,
66
- return_feature_map=True,
67
- return_latents=True,
68
- )
69
-
70
- return (w1, w1_initial, f1)
71
-
72
- def zoom(self, latents, dz, sxsy=[0, 0], stop_points=[]):
73
- w1, w1_initial, f1 = self.latents_to_tensor(latents)
74
- w1 = w1_initial.clone()
75
-
76
- vec_num = abs(dz) / 5
77
- dz = 100 * np.sign(dz)
78
- x = torch.from_numpy(np.array([[[1.0, 0, dz]]], dtype=np.float32)).cuda()
79
- f1 = torch.nn.functional.interpolate(f1, (256, 256))
80
- y = f1[:, :, sxsy[1], sxsy[0]].unsqueeze(0)
81
-
82
- if len(stop_points) > 0:
83
- x = torch.cat(
84
- [x, torch.zeros(x.shape[0], len(stop_points), x.shape[2]).cuda()], dim=1
85
- )
86
- tmp = []
87
- for sp in stop_points:
88
- tmp.append(f1[:, :, sp[1], sp[0]].unsqueeze(1))
89
- y = torch.cat([y, torch.cat(tmp, dim=1)], dim=1)
90
-
91
- if not self.use_average_code_as_input:
92
- w_hat = self.net.encoder(
93
- w1[:, self.target_layers].detach(),
94
- x.detach(),
95
- y.detach(),
96
- alpha=vec_num,
97
- )
98
- w1 = w1.clone()
99
- w1[:, self.target_layers] = w_hat
100
- else:
101
- w_hat = self.net.encoder(
102
- self.net.latent_avg.unsqueeze(0)[:, self.target_layers].detach(),
103
- x.detach(),
104
- y.detach(),
105
- alpha=vec_num,
106
- )
107
- w1 = w1.clone()
108
- w1[:, self.target_layers] = (
109
- w1.clone()[:, self.target_layers]
110
- + w_hat
111
- - self.net.latent_avg.unsqueeze(0)[:, self.target_layers]
112
- )
113
-
114
- x1, _ = self.net.decoder([w1], input_is_latent=True, randomize_noise=False)
115
-
116
- x1 = self.net.face_pool(x1)
117
- result = (
118
- ((x1.detach()[0].permute(1, 2, 0) + 1.0) * 127.5).cpu().numpy()[:, :, ::-1]
119
- )
120
- return (
121
- result,
122
- {
123
- "w1": w1.cpu().detach().numpy(),
124
- "w1_initial": w1_initial.cpu().detach().numpy(),
125
- },
126
- ) # return latent vector along with the image
127
-
128
- def translate(
129
- self, latents, dxy, sxsy=[0, 0], stop_points=[], zoom_in=False, zoom_out=False
130
- ):
131
- w1, w1_initial, f1 = self.latents_to_tensor(latents)
132
- w1 = w1_initial.clone()
133
- dz = -5.0 if zoom_in else 0.0
134
- dz = 5.0 if zoom_out else dz
135
-
136
- dxyz = np.array([dxy[0], dxy[1], dz], dtype=np.float32)
137
- dxy_norm = np.linalg.norm(dxyz[:2], ord=2)
138
- dxyz[:2] = dxyz[:2] / dxy_norm
139
- vec_num = dxy_norm / 10
140
-
141
- x = torch.from_numpy(np.array([[dxyz]], dtype=np.float32)).cuda()
142
- f1 = torch.nn.functional.interpolate(f1, (256, 256))
143
- y = f1[:, :, sxsy[1], sxsy[0]].unsqueeze(0)
144
-
145
- if len(stop_points) > 0:
146
- x = torch.cat(
147
- [x, torch.zeros(x.shape[0], len(stop_points), x.shape[2]).cuda()], dim=1
148
- )
149
- tmp = []
150
- for sp in stop_points:
151
- tmp.append(f1[:, :, sp[1], sp[0]].unsqueeze(1))
152
- y = torch.cat([y, torch.cat(tmp, dim=1)], dim=1)
153
-
154
- if not self.use_average_code_as_input:
155
- w_hat = self.net.encoder(
156
- w1[:, self.target_layers].detach(),
157
- x.detach(),
158
- y.detach(),
159
- alpha=vec_num,
160
- )
161
- w1 = w1.clone()
162
- w1[:, self.target_layers] = w_hat
163
- else:
164
- w_hat = self.net.encoder(
165
- self.net.latent_avg.unsqueeze(0)[:, self.target_layers].detach(),
166
- x.detach(),
167
- y.detach(),
168
- alpha=vec_num,
169
- )
170
- w1 = w1.clone()
171
- w1[:, self.target_layers] = (
172
- w1.clone()[:, self.target_layers]
173
- + w_hat
174
- - self.net.latent_avg.unsqueeze(0)[:, self.target_layers]
175
- )
176
-
177
- x1, _ = self.net.decoder([w1], input_is_latent=True, randomize_noise=False)
178
-
179
- x1 = self.net.face_pool(x1)
180
- result = (
181
- ((x1.detach()[0].permute(1, 2, 0) + 1.0) * 127.5).cpu().numpy()[:, :, ::-1]
182
- )
183
- return (
184
- result,
185
- {
186
- "w1": w1.cpu().detach().numpy(),
187
- "w1_initial": w1_initial.cpu().detach().numpy(),
188
- },
189
- )
190
-
191
- def change_style(self, latents):
192
- w1, w1_initial, f1 = self.latents_to_tensor(latents)
193
- w1 = w1_initial.clone()
194
-
195
- z1 = torch.randn(1, 512).to("cuda")
196
- x1, w2 = self.net.decoder(
197
- [z1],
198
- input_is_latent=False,
199
- randomize_noise=False,
200
- return_latents=True,
201
- truncation=self.truncation,
202
- truncation_latent=self.net.latent_avg[0],
203
- )
204
- w1[:, 6:] = w2.detach()[:, 0]
205
- x1, w1_new = self.net.decoder(
206
- [w1],
207
- input_is_latent=True,
208
- randomize_noise=False,
209
- return_latents=True,
210
- )
211
- result = (
212
- ((x1.detach()[0].permute(1, 2, 0) + 1.0) * 127.5).cpu().numpy()[:, :, ::-1]
213
- )
214
- return (
215
- result,
216
- {
217
- "w1": w1_new.cpu().detach().numpy(),
218
- "w1_initial": w1_new.cpu().detach().numpy(),
219
- },
220
- )
221
-
222
- def reset(self, latents):
223
- w1, w1_initial, f1 = self.latents_to_tensor(latents)
224
- x1, w1_new, f1 = self.net.decoder(
225
- [w1_initial],
226
- input_is_latent=True,
227
- randomize_noise=False,
228
- return_feature_map=True,
229
- return_latents=True,
230
- )
231
- result = (
232
- ((x1.detach()[0].permute(1, 2, 0) + 1.0) * 127.5).cpu().numpy()[:, :, ::-1]
233
- )
234
- return (
235
- result,
236
- {
237
- "w1": w1_new.cpu().detach().numpy(),
238
- "w1_initial": w1_new.cpu().detach().numpy(),
239
- },
240
- )
 
1
+ import os
2
+ from argparse import Namespace
3
+ import numpy as np
4
+ import torch
5
+
6
+ from models.StyleGANControler import StyleGANControler
7
+
8
+
9
class Model:
    """Inference wrapper around a ``StyleGANControler`` checkpoint.

    Every public method returns ``(image, latents)`` where ``latents`` is a
    dict with the NumPy arrays ``"w1"`` (current latent) and ``"w1_initial"``
    (latent the edits start from), so the state can be round-tripped through
    the caller between calls.
    """

    def __init__(
        self, checkpoint_path, truncation=0.5, use_average_code_as_input=False
    ):
        """Load the checkpoint at ``checkpoint_path`` and move the network to CUDA.

        Args:
            checkpoint_path: Path of the checkpoint; it is also written into
                the stored options under ``"checkpoint_path"``.
            truncation: Truncation value forwarded to the decoder when
                sampling from random noise.
            use_average_code_as_input: When true, the encoder is fed the
                network's average latent instead of the current latent.
        """
        self.truncation = truncation
        self.use_average_code_as_input = use_average_code_as_input
        # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
        ckpt = torch.load(checkpoint_path, map_location="cpu")
        opts = ckpt["opts"]
        opts["checkpoint_path"] = checkpoint_path
        self.opts = Namespace(**opts)
        self.net = StyleGANControler(self.opts)
        self.net.eval()
        self.net.cuda()
        # Latent layers that the encoder output overwrites / offsets.
        self.target_layers = [0, 1, 2, 3, 4, 5]

    @staticmethod
    def _to_image(x1):
        """Convert the first decoder output in ``x1`` to a NumPy image.

        Moves the first axis of ``x1[0]`` to the end, rescales with
        ``(v + 1) * 127.5`` and reverses the last axis.
        NOTE(review): presumably maps a [-1, 1] CHW tensor to a 0-255 HWC
        image with flipped channel order -- confirm against the decoder.
        """
        return (
            ((x1.detach()[0].permute(1, 2, 0) + 1.0) * 127.5).cpu().numpy()[:, :, ::-1]
        )

    @staticmethod
    def _pack_latents(w1, w1_initial):
        """Return the latent tensors as the NumPy dict handed back to callers."""
        return {
            "w1": w1.cpu().detach().numpy(),
            "w1_initial": w1_initial.cpu().detach().numpy(),
        }

    @staticmethod
    def _normalized_shift(dxy, dz):
        """Build the ``[dx, dy, dz]`` direction vector and its strength.

        The x/y part is scaled to unit length and the strength is the original
        x/y norm divided by 10. A zero x/y shift is left as zeros (strength 0)
        instead of being divided by zero, which used to yield NaN components.

        Returns:
            ``(dxyz, vec_num)``: float32 array of length 3 and the strength.
        """
        dxyz = np.array([dxy[0], dxy[1], dz], dtype=np.float32)
        dxy_norm = np.linalg.norm(dxyz[:2], ord=2)
        if dxy_norm > 0:
            dxyz[:2] = dxyz[:2] / dxy_norm
        vec_num = dxy_norm / 10
        return dxyz, vec_num

    def _apply_edit(self, w1, w1_initial, f1, x, sxsy, stop_points, vec_num):
        """Run encoder + decoder for one edit; shared by ``zoom`` and ``translate``.

        Args:
            w1: Latent to edit (a clone of the initial latent in both callers).
            w1_initial: Initial latent, returned unchanged in the latents dict.
            f1: Feature map from the decoder; resized to 256x256 before sampling.
            x: Direction tensor for the anchor point.
            sxsy: ``(x, y)`` anchor position indexing the resized feature map.
            stop_points: ``(x, y)`` positions that get a zero direction.
            vec_num: Strength passed to the encoder as ``alpha``.
        """
        f1 = torch.nn.functional.interpolate(f1, (256, 256))
        y = f1[:, :, sxsy[1], sxsy[0]].unsqueeze(0)

        if len(stop_points) > 0:
            # Stop points contribute a zero direction plus their own feature vector.
            zeros = torch.zeros(x.shape[0], len(stop_points), x.shape[2]).cuda()
            x = torch.cat([x, zeros], dim=1)
            stop_feats = [f1[:, :, sp[1], sp[0]].unsqueeze(1) for sp in stop_points]
            y = torch.cat([y, torch.cat(stop_feats, dim=1)], dim=1)

        if not self.use_average_code_as_input:
            w_hat = self.net.encoder(
                w1[:, self.target_layers].detach(),
                x.detach(),
                y.detach(),
                alpha=vec_num,
            )
            w1 = w1.clone()
            w1[:, self.target_layers] = w_hat
        else:
            avg = self.net.latent_avg.unsqueeze(0)[:, self.target_layers]
            w_hat = self.net.encoder(
                avg.detach(),
                x.detach(),
                y.detach(),
                alpha=vec_num,
            )
            w1 = w1.clone()
            # Apply the encoder's offset from the average code to the current latent.
            w1[:, self.target_layers] = (
                w1.clone()[:, self.target_layers] + w_hat - avg
            )

        x1, _ = self.net.decoder([w1], input_is_latent=True, randomize_noise=False)
        x1 = self.net.face_pool(x1)
        return self._to_image(x1), self._pack_latents(w1, w1_initial)

    def random_sample(self):
        """Decode a random ``(1, 512)`` noise vector into an image and latents."""
        z1 = torch.randn(1, 512).to("cuda")
        x1, w1, f1 = self.net.decoder(
            [z1],
            input_is_latent=False,
            randomize_noise=False,
            return_feature_map=True,
            return_latents=True,
            truncation=self.truncation,
            truncation_latent=self.net.latent_avg[0],
        )
        w1_initial = w1.clone()
        x1 = self.net.face_pool(x1)
        # Return the latent vectors along with the image.
        return self._to_image(x1), self._pack_latents(w1, w1_initial)

    def latents_to_tensor(self, latents):
        """Move stored latents to CUDA and re-run the decoder on both of them.

        Returns:
            ``(w1, w1_initial, f1)`` where ``f1`` is the feature map of the
            second pass, i.e. the one decoded from ``w1_initial``.
        """
        w1 = torch.tensor(latents["w1"]).to("cuda")
        w1_initial = torch.tensor(latents["w1_initial"]).to("cuda")

        # NOTE(review): only ``w1`` survives from this first pass; its image and
        # feature map are overwritten by the second pass below.
        x1, w1, f1 = self.net.decoder(
            [w1],
            input_is_latent=True,
            randomize_noise=False,
            return_feature_map=True,
            return_latents=True,
        )
        x1, w1_initial, f1 = self.net.decoder(
            [w1_initial],
            input_is_latent=True,
            randomize_noise=False,
            return_feature_map=True,
            return_latents=True,
        )

        return (w1, w1_initial, f1)

    def zoom(self, latents, dz, sxsy=(0, 0), stop_points=()):
        """Apply a zoom edit of signed amount ``dz`` anchored at ``sxsy``.

        The edit always starts from the initial latent. The strength is
        ``abs(dz) / 5`` and only the sign of ``dz`` enters the direction vector.
        """
        _, w1_initial, f1 = self.latents_to_tensor(latents)
        w1 = w1_initial.clone()

        vec_num = abs(dz) / 5
        dz = 100 * np.sign(dz)
        x = torch.from_numpy(np.array([[[1.0, 0, dz]]], dtype=np.float32)).cuda()
        return self._apply_edit(w1, w1_initial, f1, x, sxsy, stop_points, vec_num)

    def translate(
        self, latents, dxy, sxsy=(0, 0), stop_points=(), zoom_in=False, zoom_out=False
    ):
        """Apply a translation edit by ``dxy`` anchored at ``sxsy``.

        The edit always starts from the initial latent. ``zoom_in`` /
        ``zoom_out`` set the z component to -5 / +5 (``zoom_out`` wins when
        both are set). A zero ``dxy`` yields strength 0 and a zero x/y
        direction rather than NaNs.
        """
        _, w1_initial, f1 = self.latents_to_tensor(latents)
        w1 = w1_initial.clone()
        dz = -5.0 if zoom_in else 0.0
        dz = 5.0 if zoom_out else dz

        dxyz, vec_num = self._normalized_shift(dxy, dz)
        x = torch.from_numpy(np.array([[dxyz]], dtype=np.float32)).cuda()
        return self._apply_edit(w1, w1_initial, f1, x, sxsy, stop_points, vec_num)

    def change_style(self, latents):
        """Replace latent layers 6 and above with those of a fresh random sample.

        Both returned latent entries are set to the newly decoded latent.
        NOTE(review): unlike the edit methods, the image is not passed through
        ``face_pool`` here -- confirm that is intended.
        """
        _, w1_initial, f1 = self.latents_to_tensor(latents)
        w1 = w1_initial.clone()

        z1 = torch.randn(1, 512).to("cuda")
        x1, w2 = self.net.decoder(
            [z1],
            input_is_latent=False,
            randomize_noise=False,
            return_latents=True,
            truncation=self.truncation,
            truncation_latent=self.net.latent_avg[0],
        )
        w1[:, 6:] = w2.detach()[:, 0]
        x1, w1_new = self.net.decoder(
            [w1],
            input_is_latent=True,
            randomize_noise=False,
            return_latents=True,
        )
        return self._to_image(x1), self._pack_latents(w1_new, w1_new)

    def reset(self, latents):
        """Re-decode the initial latent and make it the current latent as well.

        NOTE(review): the image is not passed through ``face_pool`` here --
        confirm that is intended.
        """
        _, w1_initial, f1 = self.latents_to_tensor(latents)
        x1, w1_new, f1 = self.net.decoder(
            [w1_initial],
            input_is_latent=True,
            randomize_noise=False,
            return_feature_map=True,
            return_latents=True,
        )
        return self._to_image(x1), self._pack_latents(w1_new, w1_new)