radames (HF staff) committed
Commit e722a6c • 1 Parent(s): 1294e7a

merge with main

Files changed (3)
  1. viz/capture_widget.py +3 -2
  2. viz/drag_widget.py +1 -0
  3. viz/renderer.py +84 -108
viz/capture_widget.py CHANGED
@@ -31,7 +31,7 @@ class CaptureWidget:
         viz = self.viz
         try:
             _height, _width, channels = image.shape
-            assert channels in [1, 3]
+            print(viz.result)
             assert image.dtype == np.uint8
             os.makedirs(self.path, exist_ok=True)
             file_id = 0
@@ -43,8 +43,9 @@ class CaptureWidget:
             if channels == 1:
                 pil_image = PIL.Image.fromarray(image[:, :, 0], 'L')
             else:
-                pil_image = PIL.Image.fromarray(image, 'RGB')
+                pil_image = PIL.Image.fromarray(image[:, :, :3], 'RGB')
             pil_image.save(os.path.join(self.path, f'{file_id:05d}.png'))
+            np.save(os.path.join(self.path, f'{file_id:05d}.npy'), viz.result.w)
         except:
             viz.result.error = renderer.CapturedException()
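Net effect: each capture now writes a matched pair of files — the frame as a PNG (sliced to `image[:, :, :3]` so an RGBA buffer no longer trips the removed channel assertion) and the current latent `viz.result.w` as a `.npy`. A minimal sketch (not part of the commit) of reloading such a pair; the directory name 'captures' and file id 00000 are placeholders:

import os
import numpy as np
import PIL.Image

path = 'captures'  # placeholder for the widget's self.path
frame = np.asarray(PIL.Image.open(os.path.join(path, '00000.png')))  # uint8, H x W x C
w = np.load(os.path.join(path, '00000.npy'))                         # latent saved from viz.result.w
print(frame.shape, frame.dtype, w.shape)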
viz/drag_widget.py CHANGED
@@ -90,6 +90,7 @@ class DragWidget:
     @imgui_utils.scoped_by_object_id
     def __call__(self, show=True):
         viz = self.viz
+        reset = False
         if show:
             with imgui_utils.grayed_out(self.disabled_time != 0):
                 imgui.text('Drag')
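The single added line matters because `reset` is presumably read later in `__call__` regardless of which UI branch runs; without an unconditional default, any path that never assigns it would raise `UnboundLocalError`. A toy illustration of the failure mode (hypothetical function, not the widget code):

def handler(show=True):
    if show:
        reset = True
    return reset  # UnboundLocalError if show is False

def handler_fixed(show=True):
    reset = False  # unconditional default, as in the drag widget change
    if show:
        reset = True
    return reset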
viz/renderer.py CHANGED
@@ -20,10 +20,9 @@ import torch.nn.functional as F
 import matplotlib.cm
 import dnnlib
 from torch_utils.ops import upfirdn2d
-import legacy  # pylint: disable=import-error
-
-# ----------------------------------------------------------------------------
+import legacy # pylint: disable=import-error
 
+#----------------------------------------------------------------------------
 
 class CapturedException(Exception):
     def __init__(self, msg=None):
@@ -37,16 +36,14 @@ class CapturedException(Exception):
         assert isinstance(msg, str)
         super().__init__(msg)
 
-# ----------------------------------------------------------------------------
-
+#----------------------------------------------------------------------------
 
 class CaptureSuccess(Exception):
     def __init__(self, out):
         super().__init__()
         self.out = out
 
-# ----------------------------------------------------------------------------
-
+#----------------------------------------------------------------------------
 
 def add_watermark_np(input_image_array, watermark_text="AI Generated"):
     image = Image.fromarray(np.uint8(input_image_array)).convert("RGBA")
@@ -57,10 +54,8 @@ def add_watermark_np(input_image_array, watermark_text="AI Generated"):
     d = ImageDraw.Draw(txt)
 
     text_width, text_height = font.getsize(watermark_text)
-    text_position = (image.size[0] - text_width -
-                     10, image.size[1] - text_height - 10)
-    # white color with the alpha channel set to semi-transparent
-    text_color = (255, 255, 255, 128)
+    text_position = (image.size[0] - text_width - 10, image.size[1] - text_height - 10)
+    text_color = (255, 255, 255, 128) # white color with the alpha channel set to semi-transparent
 
     # Draw the text onto the text canvas
     d.text(text_position, watermark_text, font=font, fill=text_color)
@@ -70,22 +65,22 @@ def add_watermark_np(input_image_array, watermark_text="AI Generated"):
     watermarked_array = np.array(watermarked)
     return watermarked_array
 
-# ----------------------------------------------------------------------------
-
+#----------------------------------------------------------------------------
 
 class Renderer:
     def __init__(self, disable_timing=False):
-        self._device = torch.device('cuda')
-        self._pkl_data = dict()  # {pkl: dict | CapturedException, ...}
-        self._networks = dict()  # {cache_key: torch.nn.Module, ...}
-        self._pinned_bufs = dict()  # {(shape, dtype): torch.Tensor, ...}
-        self._cmaps = dict()  # {name: torch.Tensor, ...}
-        self._is_timing = False
+        self._device = torch.device('cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu')
+        self._dtype = torch.float32 if self._device.type == 'mps' else torch.float64
+        self._pkl_data = dict() # {pkl: dict | CapturedException, ...}
+        self._networks = dict() # {cache_key: torch.nn.Module, ...}
+        self._pinned_bufs = dict() # {(shape, dtype): torch.Tensor, ...}
+        self._cmaps = dict() # {name: torch.Tensor, ...}
+        self._is_timing = False
         if not disable_timing:
-            self._start_event = torch.cuda.Event(enable_timing=True)
-            self._end_event = torch.cuda.Event(enable_timing=True)
+            self._start_event = torch.cuda.Event(enable_timing=True)
+            self._end_event = torch.cuda.Event(enable_timing=True)
         self._disable_timing = disable_timing
-        self._net_layers = dict()  # {cache_key: [dnnlib.EasyDict, ...], ...}
+        self._net_layers = dict() # {cache_key: [dnnlib.EasyDict, ...], ...}
 
     def render(self, **args):
         if self._disable_timing:
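The key functional change in `__init__` is device selection: instead of hard-coding CUDA, the renderer now falls back to Apple's MPS backend and then to CPU, with `float32` on MPS since that backend has no `float64` support. A standalone sketch of the same pattern:

import torch

# Prefer CUDA, then Apple Metal (MPS), then CPU; MPS cannot host float64 tensors.
device = torch.device('cuda' if torch.cuda.is_available()
                      else 'mps' if torch.backends.mps.is_available()
                      else 'cpu')
dtype = torch.float32 if device.type == 'mps' else torch.float64
x = torch.randn(1, 512, device=device, dtype=dtype)
print(device, x.dtype)

Note that the timing events a few lines down are still `torch.cuda.Event`s, so on MPS or CPU the caller presumably constructs the renderer with `disable_timing=True`.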
@@ -101,6 +96,9 @@ class Renderer:
         if hasattr(self, 'pkl'):
             if self.pkl != args['pkl']:
                 init_net = True
+        if hasattr(self, 'w_load'):
+            if self.w_load is not args['w_load']:
+                init_net = True
         if hasattr(self, 'w0_seed'):
             if self.w0_seed != args['w0_seed']:
                 init_net = True
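The new `w_load` check uses `is not` (object identity) rather than `!=`: for tensors, `!=` is elementwise and yields a tensor whose truth value is ambiguous inside `if`, whereas identity comparison is cheap and triggers re-initialization whenever a different tensor object is passed, even one with equal values. For example:

import torch

a = torch.zeros(3)
b = torch.zeros(3)
print(a is not b)  # True: distinct objects, even though the values match
print(a != b)      # tensor([False, False, False]) -- not usable in `if` directly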
@@ -128,8 +126,7 @@ class Renderer:
 
         if self._is_timing and not self._disable_timing:
             self._end_event.synchronize()
-            res.render_time = self._start_event.elapsed_time(
-                self._end_event) * 1e-3
+            res.render_time = self._start_event.elapsed_time(self._end_event) * 1e-3
             self._is_timing = False
         return res
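The render-time measurement is the standard CUDA event pattern; `elapsed_time` returns milliseconds, hence the `* 1e-3` to report seconds. A minimal sketch, CUDA only:

import torch

if torch.cuda.is_available():
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    a = torch.randn(1024, 1024, device='cuda')
    b = a @ a                      # some GPU work to time
    end.record()
    end.synchronize()              # wait for the GPU before reading the timer
    print(start.elapsed_time(end) * 1e-3, 'seconds')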
@@ -150,8 +147,7 @@ class Renderer:
             raise data
 
         orig_net = data[key]
-        cache_key = (orig_net, self._device, tuple(
-            sorted(tweak_kwargs.items())))
+        cache_key = (orig_net, self._device, tuple(sorted(tweak_kwargs.items())))
         net = self._networks.get(cache_key, None)
         if net is None:
             try:
@@ -167,11 +163,9 @@ class Renderer:
                 print(data[key].init_args)
                 print(data[key].init_kwargs)
                 if 'stylegan_human' in pkl:
-                    net = Generator(
-                        *data[key].init_args, **data[key].init_kwargs, square=False, padding=True)
+                    net = Generator(*data[key].init_args, **data[key].init_kwargs, square=False, padding=True)
                 else:
-                    net = Generator(*data[key].init_args,
-                                    **data[key].init_kwargs)
+                    net = Generator(*data[key].init_args, **data[key].init_kwargs)
                 net.load_state_dict(data[key].state_dict())
                 net.to(self._device)
             except:
@@ -212,27 +206,25 @@ class Renderer:
         return x
 
     def init_network(self, res,
-                     pkl=None,
-                     w0_seed=0,
-                     w_load=None,
-                     w_plus=True,
-                     noise_mode='const',
-                     trunc_psi=0.7,
-                     trunc_cutoff=None,
-                     input_transform=None,
-                     lr=0.001,
-                     **kwargs
-                     ):
+        pkl = None,
+        w0_seed = 0,
+        w_load = None,
+        w_plus = True,
+        noise_mode = 'const',
+        trunc_psi = 0.7,
+        trunc_cutoff = None,
+        input_transform = None,
+        lr = 0.001,
+        **kwargs
+    ):
         # Dig up network details.
         self.pkl = pkl
         G = self.get_network(pkl, 'G_ema')
         self.G = G
         res.img_resolution = G.img_resolution
         res.num_ws = G.num_ws
-        res.has_noise = any('noise_const' in name for name,
-                            _buf in G.synthesis.named_buffers())
-        res.has_input_transform = (
-            hasattr(G.synthesis, 'input') and hasattr(G.synthesis.input, 'transform'))
+        res.has_noise = any('noise_const' in name for name, _buf in G.synthesis.named_buffers())
+        res.has_input_transform = (hasattr(G.synthesis, 'input') and hasattr(G.synthesis.input, 'transform'))
 
         # Set input transform.
         if res.has_input_transform:
@@ -250,13 +242,11 @@ class Renderer:
 
         if self.w_load is None:
             # Generate random latents.
-            z = torch.from_numpy(np.random.RandomState(
-                w0_seed).randn(1, 512)).to(self._device).float()
+            z = torch.from_numpy(np.random.RandomState(w0_seed).randn(1, 512)).to(self._device, dtype=self._dtype)
 
             # Run mapping network.
             label = torch.zeros([1, G.c_dim], device=self._device)
-            w = G.mapping(z, label, truncation_psi=trunc_psi,
-                          truncation_cutoff=trunc_cutoff)
+            w = G.mapping(z, label, truncation_psi=trunc_psi, truncation_cutoff=trunc_cutoff)
         else:
             w = self.w_load.clone().to(self._device)
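Latent initialization stays deterministic: a fixed `w0_seed` feeds `np.random.RandomState`, so the same seed always produces the same `z` (now created directly in `self._dtype` instead of via a hard-coded `.float()`). For instance:

import numpy as np
import torch

w0_seed = 0
z1 = torch.from_numpy(np.random.RandomState(w0_seed).randn(1, 512))
z2 = torch.from_numpy(np.random.RandomState(w0_seed).randn(1, 512))
assert torch.equal(z1, z2)  # same seed, same latent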
@@ -280,34 +270,34 @@ class Renderer:
         print(' Remain feat_refs and points0_pt')
 
     def _render_drag_impl(self, res,
-                          points=[],
-                          targets=[],
-                          mask=None,
-                          lambda_mask=10,
-                          reg=0,
-                          feature_idx=5,
-                          r1=3,
-                          r2=12,
-                          random_seed=0,
-                          noise_mode='const',
-                          trunc_psi=0.7,
-                          force_fp32=False,
-                          layer_name=None,
-                          sel_channels=3,
-                          base_channel=0,
-                          img_scale_db=0,
-                          img_normalize=False,
-                          untransform=False,
-                          is_drag=False,
-                          reset=False,
-                          to_pil=False,
-                          **kwargs
-                          ):
+        points = [],
+        targets = [],
+        mask = None,
+        lambda_mask = 10,
+        reg = 0,
+        feature_idx = 5,
+        r1 = 3,
+        r2 = 12,
+        random_seed = 0,
+        noise_mode = 'const',
+        trunc_psi = 0.7,
+        force_fp32 = False,
+        layer_name = None,
+        sel_channels = 3,
+        base_channel = 0,
+        img_scale_db = 0,
+        img_normalize = False,
+        untransform = False,
+        is_drag = False,
+        reset = False,
+        to_pil = False,
+        **kwargs
+    ):
         G = self.G
         ws = self.w
         if ws.dim() == 2:
-            ws = ws.unsqueeze(1).repeat(1, 6, 1)
-            ws = torch.cat([ws[:, :6, :], self.w0[:, 6:, :]], dim=1)
+            ws = ws.unsqueeze(1).repeat(1,6,1)
+            ws = torch.cat([ws[:,:6,:], self.w0[:,6:,:]], dim=1)
         if hasattr(self, 'points'):
             if len(points) != len(self.points):
                 reset = True
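When the stored latent is a single 512-dim vector, `_render_drag_impl` expands it to W+ form and splices it with the frozen initial code: only the first 6 style layers follow the optimized latent, while the remaining layers stay pinned to `self.w0`. Shape-wise (18 layers here is illustrative; the real count comes from `G.num_ws`):

import torch

num_ws = 18                                  # illustrative; actual value is G.num_ws
w = torch.randn(1, 512)                      # optimized latent
w0 = torch.randn(1, num_ws, 512)             # frozen initial W+ code
ws = w.unsqueeze(1).repeat(1, 6, 1)          # [1, 6, 512]
ws = torch.cat([ws[:, :6, :], w0[:, 6:, :]], dim=1)
print(ws.shape)                              # torch.Size([1, 18, 512])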
@@ -318,8 +308,7 @@ class Renderer:
 
         # Run synthesis network.
         label = torch.zeros([1, G.c_dim], device=self._device)
-        img, feat = G(ws, label, truncation_psi=trunc_psi,
-                      noise_mode=noise_mode, input_is_w=True, return_feature=True)
+        img, feat = G(ws, label, truncation_psi=trunc_psi, noise_mode=noise_mode, input_is_w=True, return_feature=True)
 
         h, w = G.img_resolution, G.img_resolution
 
@@ -327,17 +316,14 @@ class Renderer:
         X = torch.linspace(0, h, h)
         Y = torch.linspace(0, w, w)
         xx, yy = torch.meshgrid(X, Y)
-        feat_resize = F.interpolate(
-            feat[feature_idx], [h, w], mode='bilinear')
+        feat_resize = F.interpolate(feat[feature_idx], [h, w], mode='bilinear')
         if self.feat_refs is None:
-            self.feat0_resize = F.interpolate(
-                feat[feature_idx].detach(), [h, w], mode='bilinear')
+            self.feat0_resize = F.interpolate(feat[feature_idx].detach(), [h, w], mode='bilinear')
             self.feat_refs = []
             for point in points:
                 py, px = round(point[0]), round(point[1])
-                self.feat_refs.append(self.feat0_resize[:, :, py, px])
-            self.points0_pt = torch.Tensor(points).unsqueeze(
-                0).to(self._device)  # 1, N, 2
+                self.feat_refs.append(self.feat0_resize[:,:,py,px])
+            self.points0_pt = torch.Tensor(points).unsqueeze(0).to(self._device) # 1, N, 2
 
         # Point tracking with feature matching
        with torch.no_grad():
347
  down = min(point[0] + r + 1, h)
348
  left = max(point[1] - r, 0)
349
  right = min(point[1] + r + 1, w)
350
- feat_patch = feat_resize[:, :, up:down, left:right]
351
- L2 = torch.linalg.norm(
352
- feat_patch - self.feat_refs[j].reshape(1, -1, 1, 1), dim=1)
353
- _, idx = torch.min(L2.view(1, -1), -1)
354
  width = right - left
355
- point = [idx.item() // width + up, idx.item() %
356
- width + left]
357
  points[j] = point
358
 
359
  res.points = [[point[0], point[1]] for point in points]
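Point tracking here is a nearest-neighbour search: within a (2r+1)-sized window around each handle, pick the pixel whose feature vector is closest in L2 to the reference feature captured at initialization. A self-contained toy version of the same computation:

import torch

C, H, W = 8, 32, 32
feat_resize = torch.randn(1, C, H, W)
ref = torch.randn(1, C)                      # stands in for self.feat_refs[j]
point, r = [16, 16], 3
up, down = max(point[0] - r, 0), min(point[0] + r + 1, H)
left, right = max(point[1] - r, 0), min(point[1] + r + 1, W)
patch = feat_resize[:, :, up:down, left:right]
L2 = torch.linalg.norm(patch - ref.reshape(1, -1, 1, 1), dim=1)
_, idx = torch.min(L2.view(1, -1), -1)
width = right - left
print([idx.item() // width + up, idx.item() % width + left])  # tracked point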
@@ -362,31 +346,24 @@ class Renderer:
         loss_motion = 0
         res.stop = True
         for j, point in enumerate(points):
-            direction = torch.Tensor(
-                [targets[j][1] - point[1], targets[j][0] - point[0]])
+            direction = torch.Tensor([targets[j][1] - point[1], targets[j][0] - point[0]])
             if torch.linalg.norm(direction) > max(2 / 512 * h, 2):
                 res.stop = False
             if torch.linalg.norm(direction) > 1:
-                distance = (
-                    (xx.to(self._device) - point[0])**2 + (yy.to(self._device) - point[1])**2)**0.5
+                distance = ((xx.to(self._device) - point[0])**2 + (yy.to(self._device) - point[1])**2)**0.5
                 relis, reljs = torch.where(distance < round(r1 / 512 * h))
-                direction = direction / \
-                    (torch.linalg.norm(direction) + 1e-7)
+                direction = direction / (torch.linalg.norm(direction) + 1e-7)
                 gridh = (relis-direction[1]) / (h-1) * 2 - 1
                 gridw = (reljs-direction[0]) / (w-1) * 2 - 1
-                grid = torch.stack(
-                    [gridw, gridh], dim=-1).unsqueeze(0).unsqueeze(0)
-                target = F.grid_sample(
-                    feat_resize.float(), grid, align_corners=True).squeeze(2)
-                loss_motion += F.l1_loss(
-                    feat_resize[:, :, relis, reljs], target.detach())
+                grid = torch.stack([gridw,gridh], dim=-1).unsqueeze(0).unsqueeze(0)
+                target = F.grid_sample(feat_resize.float(), grid, align_corners=True).squeeze(2)
+                loss_motion += F.l1_loss(feat_resize[:,:,relis,reljs], target.detach())
 
         loss = loss_motion
         if mask is not None:
             if mask.min() == 0 and mask.max() == 1:
                 mask_usq = mask.to(self._device).unsqueeze(0).unsqueeze(0)
-                loss_fix = F.l1_loss(
-                    feat_resize * mask_usq, self.feat0_resize * mask_usq)
+                loss_fix = F.l1_loss(feat_resize * mask_usq, self.feat0_resize * mask_usq)
                 loss += lambda_mask * loss_fix
 
         loss += reg * F.l1_loss(ws, self.w0) # latent code regularization
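The motion term pulls features within radius `r1` of each handle one normalized step along the drag direction: `F.grid_sample` expects coordinates in [-1, 1], which is what the `gridh`/`gridw` lines compute. A tiny self-contained check of that coordinate convention (sampling one unshifted pixel):

import torch
import torch.nn.functional as F

H = W = 8
feat = torch.arange(H * W, dtype=torch.float32).reshape(1, 1, H, W)
ys, xs = torch.tensor([3.0]), torch.tensor([5.0])   # pixel coordinates to sample
gridh = ys / (H - 1) * 2 - 1
gridw = xs / (W - 1) * 2 - 1
grid = torch.stack([gridw, gridh], dim=-1).unsqueeze(0).unsqueeze(0)  # [1, 1, N, 2]
out = F.grid_sample(feat, grid, align_corners=True)
print(out)  # 29.0 == feat[0, 0, 3, 5]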
@@ -398,15 +375,14 @@ class Renderer:
         # Scale and convert to uint8.
         img = img[0]
         if img_normalize:
-            img = img / img.norm(float('inf'),
-                                 dim=[1, 2], keepdim=True).clip(1e-8, 1e8)
+            img = img / img.norm(float('inf'), dim=[1,2], keepdim=True).clip(1e-8, 1e8)
         img = img * (10 ** (img_scale_db / 20))
-        img = (img * 127.5 + 128).clamp(0,
-                                        255).to(torch.uint8).permute(1, 2, 0)
+        img = (img * 127.5 + 128).clamp(0, 255).to(torch.uint8).permute(1, 2, 0)
         if to_pil:
             from PIL import Image
             img = img.cpu().numpy()
             img = Image.fromarray(img)
         res.image = img
+        res.w = ws.detach().cpu().numpy()
 
-# ----------------------------------------------------------------------------
+#----------------------------------------------------------------------------
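The uint8 conversion assumes synthesis output roughly in [-1, 1] and maps it onto [0, 255]; the new `res.w` export is what the capture widget's `np.save` call above picks up. The mapping in isolation:

import torch

img = torch.tensor([-1.0, 0.0, 1.0])
print((img * 127.5 + 128).clamp(0, 255).to(torch.uint8))  # tensor([  0, 128, 255], dtype=torch.uint8)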
 