Spaces: Running on A10G
merge with main
Browse files
- viz/capture_widget.py +3 -2
- viz/drag_widget.py +1 -0
- viz/renderer.py +84 -108
viz/capture_widget.py
CHANGED
@@ -31,7 +31,7 @@ class CaptureWidget:
|
|
31 |
viz = self.viz
|
32 |
try:
|
33 |
_height, _width, channels = image.shape
|
34 |
-
|
35 |
assert image.dtype == np.uint8
|
36 |
os.makedirs(self.path, exist_ok=True)
|
37 |
file_id = 0
|
@@ -43,8 +43,9 @@ class CaptureWidget:
|
|
43 |
if channels == 1:
|
44 |
pil_image = PIL.Image.fromarray(image[:, :, 0], 'L')
|
45 |
else:
|
46 |
-
pil_image = PIL.Image.fromarray(image, 'RGB')
|
47 |
pil_image.save(os.path.join(self.path, f'{file_id:05d}.png'))
|
|
|
48 |
except:
|
49 |
viz.result.error = renderer.CapturedException()
|
50 |
|
|
|
31 |
viz = self.viz
|
32 |
try:
|
33 |
_height, _width, channels = image.shape
|
34 |
+
print(viz.result)
|
35 |
assert image.dtype == np.uint8
|
36 |
os.makedirs(self.path, exist_ok=True)
|
37 |
file_id = 0
|
|
|
43 |
if channels == 1:
|
44 |
pil_image = PIL.Image.fromarray(image[:, :, 0], 'L')
|
45 |
else:
|
46 |
+
pil_image = PIL.Image.fromarray(image[:, :, :3], 'RGB')
|
47 |
pil_image.save(os.path.join(self.path, f'{file_id:05d}.png'))
|
48 |
+
np.save(os.path.join(self.path, f'{file_id:05d}.npy'), viz.result.w)
|
49 |
except:
|
50 |
viz.result.error = renderer.CapturedException()
|
51 |
|
viz/drag_widget.py
CHANGED
@@ -90,6 +90,7 @@ class DragWidget:
|
|
90 |
@imgui_utils.scoped_by_object_id
|
91 |
def __call__(self, show=True):
|
92 |
viz = self.viz
|
|
|
93 |
if show:
|
94 |
with imgui_utils.grayed_out(self.disabled_time != 0):
|
95 |
imgui.text('Drag')
|
|
|
90 |
@imgui_utils.scoped_by_object_id
|
91 |
def __call__(self, show=True):
|
92 |
viz = self.viz
|
93 |
+
reset = False
|
94 |
if show:
|
95 |
with imgui_utils.grayed_out(self.disabled_time != 0):
|
96 |
imgui.text('Drag')
|
viz/renderer.py
CHANGED
@@ -20,10 +20,9 @@ import torch.nn.functional as F
|
|
20 |
import matplotlib.cm
|
21 |
import dnnlib
|
22 |
from torch_utils.ops import upfirdn2d
|
23 |
-
import legacy
|
24 |
-
|
25 |
-
# ----------------------------------------------------------------------------
|
26 |
|
|
|
27 |
|
28 |
class CapturedException(Exception):
|
29 |
def __init__(self, msg=None):
|
@@ -37,16 +36,14 @@ class CapturedException(Exception):
|
|
37 |
assert isinstance(msg, str)
|
38 |
super().__init__(msg)
|
39 |
|
40 |
-
|
41 |
-
|
42 |
|
43 |
class CaptureSuccess(Exception):
|
44 |
def __init__(self, out):
|
45 |
super().__init__()
|
46 |
self.out = out
|
47 |
|
48 |
-
|
49 |
-
|
50 |
|
51 |
def add_watermark_np(input_image_array, watermark_text="AI Generated"):
|
52 |
image = Image.fromarray(np.uint8(input_image_array)).convert("RGBA")
|
@@ -57,10 +54,8 @@ def add_watermark_np(input_image_array, watermark_text="AI Generated"):
|
|
57 |
d = ImageDraw.Draw(txt)
|
58 |
|
59 |
text_width, text_height = font.getsize(watermark_text)
|
60 |
-
text_position = (image.size[0] - text_width -
|
61 |
-
|
62 |
-
# white color with the alpha channel set to semi-transparent
|
63 |
-
text_color = (255, 255, 255, 128)
|
64 |
|
65 |
# Draw the text onto the text canvas
|
66 |
d.text(text_position, watermark_text, font=font, fill=text_color)
|
@@ -70,22 +65,22 @@ def add_watermark_np(input_image_array, watermark_text="AI Generated"):
|
|
70 |
watermarked_array = np.array(watermarked)
|
71 |
return watermarked_array
|
72 |
|
73 |
-
|
74 |
-
|
75 |
|
76 |
class Renderer:
|
77 |
def __init__(self, disable_timing=False):
|
78 |
-
self._device
|
79 |
-
self.
|
80 |
-
self.
|
81 |
-
self.
|
82 |
-
self.
|
83 |
-
self.
|
|
|
84 |
if not disable_timing:
|
85 |
-
self._start_event
|
86 |
-
self._end_event
|
87 |
self._disable_timing = disable_timing
|
88 |
-
self._net_layers
|
89 |
|
90 |
def render(self, **args):
|
91 |
if self._disable_timing:
|
@@ -101,6 +96,9 @@ class Renderer:
|
|
101 |
if hasattr(self, 'pkl'):
|
102 |
if self.pkl != args['pkl']:
|
103 |
init_net = True
|
|
|
|
|
|
|
104 |
if hasattr(self, 'w0_seed'):
|
105 |
if self.w0_seed != args['w0_seed']:
|
106 |
init_net = True
|
@@ -128,8 +126,7 @@ class Renderer:
|
|
128 |
|
129 |
if self._is_timing and not self._disable_timing:
|
130 |
self._end_event.synchronize()
|
131 |
-
res.render_time = self._start_event.elapsed_time(
|
132 |
-
self._end_event) * 1e-3
|
133 |
self._is_timing = False
|
134 |
return res
|
135 |
|
@@ -150,8 +147,7 @@ class Renderer:
|
|
150 |
raise data
|
151 |
|
152 |
orig_net = data[key]
|
153 |
-
cache_key = (orig_net, self._device, tuple(
|
154 |
-
sorted(tweak_kwargs.items())))
|
155 |
net = self._networks.get(cache_key, None)
|
156 |
if net is None:
|
157 |
try:
|
@@ -167,11 +163,9 @@ class Renderer:
|
|
167 |
print(data[key].init_args)
|
168 |
print(data[key].init_kwargs)
|
169 |
if 'stylegan_human' in pkl:
|
170 |
-
net = Generator(
|
171 |
-
*data[key].init_args, **data[key].init_kwargs, square=False, padding=True)
|
172 |
else:
|
173 |
-
net = Generator(*data[key].init_args,
|
174 |
-
**data[key].init_kwargs)
|
175 |
net.load_state_dict(data[key].state_dict())
|
176 |
net.to(self._device)
|
177 |
except:
|
@@ -212,27 +206,25 @@ class Renderer:
|
|
212 |
return x
|
213 |
|
214 |
def init_network(self, res,
|
215 |
-
|
216 |
-
|
217 |
-
|
218 |
-
|
219 |
-
|
220 |
-
|
221 |
-
|
222 |
-
|
223 |
-
|
224 |
-
|
225 |
-
|
226 |
# Dig up network details.
|
227 |
self.pkl = pkl
|
228 |
G = self.get_network(pkl, 'G_ema')
|
229 |
self.G = G
|
230 |
res.img_resolution = G.img_resolution
|
231 |
res.num_ws = G.num_ws
|
232 |
-
res.has_noise = any('noise_const' in name for name,
|
233 |
-
|
234 |
-
res.has_input_transform = (
|
235 |
-
hasattr(G.synthesis, 'input') and hasattr(G.synthesis.input, 'transform'))
|
236 |
|
237 |
# Set input transform.
|
238 |
if res.has_input_transform:
|
@@ -250,13 +242,11 @@ class Renderer:
|
|
250 |
|
251 |
if self.w_load is None:
|
252 |
# Generate random latents.
|
253 |
-
z = torch.from_numpy(np.random.RandomState(
|
254 |
-
w0_seed).randn(1, 512)).to(self._device).float()
|
255 |
|
256 |
# Run mapping network.
|
257 |
label = torch.zeros([1, G.c_dim], device=self._device)
|
258 |
-
w = G.mapping(z, label, truncation_psi=trunc_psi,
|
259 |
-
truncation_cutoff=trunc_cutoff)
|
260 |
else:
|
261 |
w = self.w_load.clone().to(self._device)
|
262 |
|
@@ -280,34 +270,34 @@ class Renderer:
|
|
280 |
print(' Remain feat_refs and points0_pt')
|
281 |
|
282 |
def _render_drag_impl(self, res,
|
283 |
-
|
284 |
-
|
285 |
-
|
286 |
-
|
287 |
-
|
288 |
-
|
289 |
-
|
290 |
-
|
291 |
-
|
292 |
-
|
293 |
-
|
294 |
-
|
295 |
-
|
296 |
-
|
297 |
-
|
298 |
-
|
299 |
-
|
300 |
-
|
301 |
-
|
302 |
-
|
303 |
-
|
304 |
-
|
305 |
-
|
306 |
G = self.G
|
307 |
ws = self.w
|
308 |
if ws.dim() == 2:
|
309 |
-
ws = ws.unsqueeze(1).repeat(1,
|
310 |
-
ws = torch.cat([ws[
|
311 |
if hasattr(self, 'points'):
|
312 |
if len(points) != len(self.points):
|
313 |
reset = True
|
@@ -318,8 +308,7 @@ class Renderer:
|
|
318 |
|
319 |
# Run synthesis network.
|
320 |
label = torch.zeros([1, G.c_dim], device=self._device)
|
321 |
-
img, feat = G(ws, label, truncation_psi=trunc_psi,
|
322 |
-
noise_mode=noise_mode, input_is_w=True, return_feature=True)
|
323 |
|
324 |
h, w = G.img_resolution, G.img_resolution
|
325 |
|
@@ -327,17 +316,14 @@ class Renderer:
|
|
327 |
X = torch.linspace(0, h, h)
|
328 |
Y = torch.linspace(0, w, w)
|
329 |
xx, yy = torch.meshgrid(X, Y)
|
330 |
-
feat_resize = F.interpolate(
|
331 |
-
feat[feature_idx], [h, w], mode='bilinear')
|
332 |
if self.feat_refs is None:
|
333 |
-
self.feat0_resize = F.interpolate(
|
334 |
-
feat[feature_idx].detach(), [h, w], mode='bilinear')
|
335 |
self.feat_refs = []
|
336 |
for point in points:
|
337 |
py, px = round(point[0]), round(point[1])
|
338 |
-
self.feat_refs.append(self.feat0_resize[
|
339 |
-
self.points0_pt = torch.Tensor(points).unsqueeze(
|
340 |
-
0).to(self._device) # 1, N, 2
|
341 |
|
342 |
# Point tracking with feature matching
|
343 |
with torch.no_grad():
|
@@ -347,13 +333,11 @@ class Renderer:
|
|
347 |
down = min(point[0] + r + 1, h)
|
348 |
left = max(point[1] - r, 0)
|
349 |
right = min(point[1] + r + 1, w)
|
350 |
-
feat_patch = feat_resize[
|
351 |
-
L2 = torch.linalg.norm(
|
352 |
-
|
353 |
-
_, idx = torch.min(L2.view(1, -1), -1)
|
354 |
width = right - left
|
355 |
-
point = [idx.item() // width + up, idx.item() %
|
356 |
-
width + left]
|
357 |
points[j] = point
|
358 |
|
359 |
res.points = [[point[0], point[1]] for point in points]
|
@@ -362,31 +346,24 @@ class Renderer:
|
|
362 |
loss_motion = 0
|
363 |
res.stop = True
|
364 |
for j, point in enumerate(points):
|
365 |
-
direction = torch.Tensor(
|
366 |
-
[targets[j][1] - point[1], targets[j][0] - point[0]])
|
367 |
if torch.linalg.norm(direction) > max(2 / 512 * h, 2):
|
368 |
res.stop = False
|
369 |
if torch.linalg.norm(direction) > 1:
|
370 |
-
distance = (
|
371 |
-
(xx.to(self._device) - point[0])**2 + (yy.to(self._device) - point[1])**2)**0.5
|
372 |
relis, reljs = torch.where(distance < round(r1 / 512 * h))
|
373 |
-
direction = direction /
|
374 |
-
(torch.linalg.norm(direction) + 1e-7)
|
375 |
gridh = (relis-direction[1]) / (h-1) * 2 - 1
|
376 |
gridw = (reljs-direction[0]) / (w-1) * 2 - 1
|
377 |
-
grid = torch.stack(
|
378 |
-
|
379 |
-
|
380 |
-
feat_resize.float(), grid, align_corners=True).squeeze(2)
|
381 |
-
loss_motion += F.l1_loss(
|
382 |
-
feat_resize[:, :, relis, reljs], target.detach())
|
383 |
|
384 |
loss = loss_motion
|
385 |
if mask is not None:
|
386 |
if mask.min() == 0 and mask.max() == 1:
|
387 |
mask_usq = mask.to(self._device).unsqueeze(0).unsqueeze(0)
|
388 |
-
loss_fix = F.l1_loss(
|
389 |
-
feat_resize * mask_usq, self.feat0_resize * mask_usq)
|
390 |
loss += lambda_mask * loss_fix
|
391 |
|
392 |
loss += reg * F.l1_loss(ws, self.w0) # latent code regularization
|
@@ -398,15 +375,14 @@ class Renderer:
|
|
398 |
# Scale and convert to uint8.
|
399 |
img = img[0]
|
400 |
if img_normalize:
|
401 |
-
img = img / img.norm(float('inf'),
|
402 |
-
dim=[1, 2], keepdim=True).clip(1e-8, 1e8)
|
403 |
img = img * (10 ** (img_scale_db / 20))
|
404 |
-
img = (img * 127.5 + 128).clamp(0,
|
405 |
-
255).to(torch.uint8).permute(1, 2, 0)
|
406 |
if to_pil:
|
407 |
from PIL import Image
|
408 |
img = img.cpu().numpy()
|
409 |
img = Image.fromarray(img)
|
410 |
res.image = img
|
|
|
411 |
|
412 |
-
|
|
|
20 |
import matplotlib.cm
|
21 |
import dnnlib
|
22 |
from torch_utils.ops import upfirdn2d
|
23 |
+
import legacy # pylint: disable=import-error
|
|
|
|
|
24 |
|
25 |
+
#----------------------------------------------------------------------------
|
26 |
|
27 |
class CapturedException(Exception):
|
28 |
def __init__(self, msg=None):
|
|
|
36 |
assert isinstance(msg, str)
|
37 |
super().__init__(msg)
|
38 |
|
39 |
+
#----------------------------------------------------------------------------
|
|
|
40 |
|
41 |
class CaptureSuccess(Exception):
|
42 |
def __init__(self, out):
|
43 |
super().__init__()
|
44 |
self.out = out
|
45 |
|
46 |
+
#----------------------------------------------------------------------------
|
|
|
47 |
|
48 |
def add_watermark_np(input_image_array, watermark_text="AI Generated"):
|
49 |
image = Image.fromarray(np.uint8(input_image_array)).convert("RGBA")
|
|
|
54 |
d = ImageDraw.Draw(txt)
|
55 |
|
56 |
text_width, text_height = font.getsize(watermark_text)
|
57 |
+
text_position = (image.size[0] - text_width - 10, image.size[1] - text_height - 10)
|
58 |
+
text_color = (255, 255, 255, 128) # white color with the alpha channel set to semi-transparent
|
|
|
|
|
59 |
|
60 |
# Draw the text onto the text canvas
|
61 |
d.text(text_position, watermark_text, font=font, fill=text_color)
|
|
|
65 |
watermarked_array = np.array(watermarked)
|
66 |
return watermarked_array
|
67 |
|
68 |
+
#----------------------------------------------------------------------------
|
|
|
69 |
|
70 |
class Renderer:
|
71 |
def __init__(self, disable_timing=False):
|
72 |
+
self._device = torch.device('cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu')
|
73 |
+
self._dtype = torch.float32 if self._device.type == 'mps' else torch.float64
|
74 |
+
self._pkl_data = dict() # {pkl: dict | CapturedException, ...}
|
75 |
+
self._networks = dict() # {cache_key: torch.nn.Module, ...}
|
76 |
+
self._pinned_bufs = dict() # {(shape, dtype): torch.Tensor, ...}
|
77 |
+
self._cmaps = dict() # {name: torch.Tensor, ...}
|
78 |
+
self._is_timing = False
|
79 |
if not disable_timing:
|
80 |
+
self._start_event = torch.cuda.Event(enable_timing=True)
|
81 |
+
self._end_event = torch.cuda.Event(enable_timing=True)
|
82 |
self._disable_timing = disable_timing
|
83 |
+
self._net_layers = dict() # {cache_key: [dnnlib.EasyDict, ...], ...}
|
84 |
|
85 |
def render(self, **args):
|
86 |
if self._disable_timing:
|
|
|
96 |
if hasattr(self, 'pkl'):
|
97 |
if self.pkl != args['pkl']:
|
98 |
init_net = True
|
99 |
+
if hasattr(self, 'w_load'):
|
100 |
+
if self.w_load is not args['w_load']:
|
101 |
+
init_net = True
|
102 |
if hasattr(self, 'w0_seed'):
|
103 |
if self.w0_seed != args['w0_seed']:
|
104 |
init_net = True
|
|
|
126 |
|
127 |
if self._is_timing and not self._disable_timing:
|
128 |
self._end_event.synchronize()
|
129 |
+
res.render_time = self._start_event.elapsed_time(self._end_event) * 1e-3
|
|
|
130 |
self._is_timing = False
|
131 |
return res
|
132 |
|
|
|
147 |
raise data
|
148 |
|
149 |
orig_net = data[key]
|
150 |
+
cache_key = (orig_net, self._device, tuple(sorted(tweak_kwargs.items())))
|
|
|
151 |
net = self._networks.get(cache_key, None)
|
152 |
if net is None:
|
153 |
try:
|
|
|
163 |
print(data[key].init_args)
|
164 |
print(data[key].init_kwargs)
|
165 |
if 'stylegan_human' in pkl:
|
166 |
+
net = Generator(*data[key].init_args, **data[key].init_kwargs, square=False, padding=True)
|
|
|
167 |
else:
|
168 |
+
net = Generator(*data[key].init_args, **data[key].init_kwargs)
|
|
|
169 |
net.load_state_dict(data[key].state_dict())
|
170 |
net.to(self._device)
|
171 |
except:
|
|
|
206 |
return x
|
207 |
|
208 |
def init_network(self, res,
|
209 |
+
pkl = None,
|
210 |
+
w0_seed = 0,
|
211 |
+
w_load = None,
|
212 |
+
w_plus = True,
|
213 |
+
noise_mode = 'const',
|
214 |
+
trunc_psi = 0.7,
|
215 |
+
trunc_cutoff = None,
|
216 |
+
input_transform = None,
|
217 |
+
lr = 0.001,
|
218 |
+
**kwargs
|
219 |
+
):
|
220 |
# Dig up network details.
|
221 |
self.pkl = pkl
|
222 |
G = self.get_network(pkl, 'G_ema')
|
223 |
self.G = G
|
224 |
res.img_resolution = G.img_resolution
|
225 |
res.num_ws = G.num_ws
|
226 |
+
res.has_noise = any('noise_const' in name for name, _buf in G.synthesis.named_buffers())
|
227 |
+
res.has_input_transform = (hasattr(G.synthesis, 'input') and hasattr(G.synthesis.input, 'transform'))
|
|
|
|
|
228 |
|
229 |
# Set input transform.
|
230 |
if res.has_input_transform:
|
|
|
242 |
|
243 |
if self.w_load is None:
|
244 |
# Generate random latents.
|
245 |
+
z = torch.from_numpy(np.random.RandomState(w0_seed).randn(1, 512)).to(self._device, dtype=self._dtype)
|
|
|
246 |
|
247 |
# Run mapping network.
|
248 |
label = torch.zeros([1, G.c_dim], device=self._device)
|
249 |
+
w = G.mapping(z, label, truncation_psi=trunc_psi, truncation_cutoff=trunc_cutoff)
|
|
|
250 |
else:
|
251 |
w = self.w_load.clone().to(self._device)
|
252 |
|
|
|
270 |
print(' Remain feat_refs and points0_pt')
|
271 |
|
272 |
def _render_drag_impl(self, res,
|
273 |
+
points = [],
|
274 |
+
targets = [],
|
275 |
+
mask = None,
|
276 |
+
lambda_mask = 10,
|
277 |
+
reg = 0,
|
278 |
+
feature_idx = 5,
|
279 |
+
r1 = 3,
|
280 |
+
r2 = 12,
|
281 |
+
random_seed = 0,
|
282 |
+
noise_mode = 'const',
|
283 |
+
trunc_psi = 0.7,
|
284 |
+
force_fp32 = False,
|
285 |
+
layer_name = None,
|
286 |
+
sel_channels = 3,
|
287 |
+
base_channel = 0,
|
288 |
+
img_scale_db = 0,
|
289 |
+
img_normalize = False,
|
290 |
+
untransform = False,
|
291 |
+
is_drag = False,
|
292 |
+
reset = False,
|
293 |
+
to_pil = False,
|
294 |
+
**kwargs
|
295 |
+
):
|
296 |
G = self.G
|
297 |
ws = self.w
|
298 |
if ws.dim() == 2:
|
299 |
+
ws = ws.unsqueeze(1).repeat(1,6,1)
|
300 |
+
ws = torch.cat([ws[:,:6,:], self.w0[:,6:,:]], dim=1)
|
301 |
if hasattr(self, 'points'):
|
302 |
if len(points) != len(self.points):
|
303 |
reset = True
|
|
|
308 |
|
309 |
# Run synthesis network.
|
310 |
label = torch.zeros([1, G.c_dim], device=self._device)
|
311 |
+
img, feat = G(ws, label, truncation_psi=trunc_psi, noise_mode=noise_mode, input_is_w=True, return_feature=True)
|
|
|
312 |
|
313 |
h, w = G.img_resolution, G.img_resolution
|
314 |
|
|
|
316 |
X = torch.linspace(0, h, h)
|
317 |
Y = torch.linspace(0, w, w)
|
318 |
xx, yy = torch.meshgrid(X, Y)
|
319 |
+
feat_resize = F.interpolate(feat[feature_idx], [h, w], mode='bilinear')
|
|
|
320 |
if self.feat_refs is None:
|
321 |
+
self.feat0_resize = F.interpolate(feat[feature_idx].detach(), [h, w], mode='bilinear')
|
|
|
322 |
self.feat_refs = []
|
323 |
for point in points:
|
324 |
py, px = round(point[0]), round(point[1])
|
325 |
+
self.feat_refs.append(self.feat0_resize[:,:,py,px])
|
326 |
+
self.points0_pt = torch.Tensor(points).unsqueeze(0).to(self._device) # 1, N, 2
|
|
|
327 |
|
328 |
# Point tracking with feature matching
|
329 |
with torch.no_grad():
|
|
|
333 |
down = min(point[0] + r + 1, h)
|
334 |
left = max(point[1] - r, 0)
|
335 |
right = min(point[1] + r + 1, w)
|
336 |
+
feat_patch = feat_resize[:,:,up:down,left:right]
|
337 |
+
L2 = torch.linalg.norm(feat_patch - self.feat_refs[j].reshape(1,-1,1,1), dim=1)
|
338 |
+
_, idx = torch.min(L2.view(1,-1), -1)
|
|
|
339 |
width = right - left
|
340 |
+
point = [idx.item() // width + up, idx.item() % width + left]
|
|
|
341 |
points[j] = point
|
342 |
|
343 |
res.points = [[point[0], point[1]] for point in points]
|
|
|
346 |
loss_motion = 0
|
347 |
res.stop = True
|
348 |
for j, point in enumerate(points):
|
349 |
+
direction = torch.Tensor([targets[j][1] - point[1], targets[j][0] - point[0]])
|
|
|
350 |
if torch.linalg.norm(direction) > max(2 / 512 * h, 2):
|
351 |
res.stop = False
|
352 |
if torch.linalg.norm(direction) > 1:
|
353 |
+
distance = ((xx.to(self._device) - point[0])**2 + (yy.to(self._device) - point[1])**2)**0.5
|
|
|
354 |
relis, reljs = torch.where(distance < round(r1 / 512 * h))
|
355 |
+
direction = direction / (torch.linalg.norm(direction) + 1e-7)
|
|
|
356 |
gridh = (relis-direction[1]) / (h-1) * 2 - 1
|
357 |
gridw = (reljs-direction[0]) / (w-1) * 2 - 1
|
358 |
+
grid = torch.stack([gridw,gridh], dim=-1).unsqueeze(0).unsqueeze(0)
|
359 |
+
target = F.grid_sample(feat_resize.float(), grid, align_corners=True).squeeze(2)
|
360 |
+
loss_motion += F.l1_loss(feat_resize[:,:,relis,reljs], target.detach())
|
|
|
|
|
|
|
361 |
|
362 |
loss = loss_motion
|
363 |
if mask is not None:
|
364 |
if mask.min() == 0 and mask.max() == 1:
|
365 |
mask_usq = mask.to(self._device).unsqueeze(0).unsqueeze(0)
|
366 |
+
loss_fix = F.l1_loss(feat_resize * mask_usq, self.feat0_resize * mask_usq)
|
|
|
367 |
loss += lambda_mask * loss_fix
|
368 |
|
369 |
loss += reg * F.l1_loss(ws, self.w0) # latent code regularization
|
|
|
375 |
# Scale and convert to uint8.
|
376 |
img = img[0]
|
377 |
if img_normalize:
|
378 |
+
img = img / img.norm(float('inf'), dim=[1,2], keepdim=True).clip(1e-8, 1e8)
|
|
|
379 |
img = img * (10 ** (img_scale_db / 20))
|
380 |
+
img = (img * 127.5 + 128).clamp(0, 255).to(torch.uint8).permute(1, 2, 0)
|
|
|
381 |
if to_pil:
|
382 |
from PIL import Image
|
383 |
img = img.cpu().numpy()
|
384 |
img = Image.fromarray(img)
|
385 |
res.image = img
|
386 |
+
res.w = ws.detach().cpu().numpy()
|
387 |
|
388 |
+
#----------------------------------------------------------------------------
|