radames (HF staff) committed
Commit f4724fc • 1 Parent(s): c760a5e

restart OOM
Files changed (2)
  1. visualizer_drag_gradio.py +1 -1
  2. viz/renderer.py +114 -105
visualizer_drag_gradio.py CHANGED
@@ -915,5 +915,5 @@ with gr.Blocks() as app:
 
 print("SHAReD: Start app", parser.parse_args())
 gr.close_all()
-app.queue(concurrency_count=3, max_size=200, api_open=False)
+app.queue(concurrency_count=2, max_size=200, api_open=False)
 app.launch(share=args.share, show_api=False)
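
For context on the one-line change above: in Gradio 3.x, the `concurrency_count` argument of `queue()` caps how many queued jobs run at the same time, so it directly bounds peak GPU memory when every job holds its own CUDA activations. Dropping it from 3 to 2 trades throughput for memory headroom, matching the OOM guard added to the renderer below. A minimal sketch of the same wiring, assuming Gradio 3.x; `heavy_render` and the Blocks layout are hypothetical stand-ins, only the `queue()` call mirrors the commit:

import gradio as gr


def heavy_render(prompt):
    # Hypothetical stand-in for a GPU-heavy synthesis call. Each concurrent
    # worker holds its own activations, so peak VRAM grows roughly linearly
    # with concurrency_count.
    return f"rendered: {prompt}"


with gr.Blocks() as demo:
    inp = gr.Textbox(label="prompt")
    out = gr.Textbox(label="result")
    inp.submit(heavy_render, inputs=inp, outputs=out)

# At most 2 jobs run at once; up to 200 requests wait in the queue; the
# queue's REST endpoints stay closed to direct API callers.
demo.queue(concurrency_count=2, max_size=200, api_open=False)
demo.launch()
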
viz/renderer.py CHANGED
@@ -308,111 +308,120 @@ class Renderer:
         to_pil=False,
         **kwargs
     ):
-        G = self.G
-        ws = self.w
-        if ws.dim() == 2:
-            ws = ws.unsqueeze(1).repeat(1, 6, 1)
-        ws = torch.cat([ws[:, :6, :], self.w0[:, 6:, :]], dim=1)
-        if hasattr(self, 'points'):
-            if len(points) != len(self.points):
-                reset = True
-        if reset:
-            self.feat_refs = None
-            self.points0_pt = None
-        self.points = points
-
-        # Run synthesis network.
-        label = torch.zeros([1, G.c_dim], device=self._device)
-        img, feat = G(ws, label, truncation_psi=trunc_psi,
-                      noise_mode=noise_mode, input_is_w=True, return_feature=True)
-
-        h, w = G.img_resolution, G.img_resolution
-
-        if is_drag:
-            X = torch.linspace(0, h, h)
-            Y = torch.linspace(0, w, w)
-            xx, yy = torch.meshgrid(X, Y)
-            feat_resize = F.interpolate(
-                feat[feature_idx], [h, w], mode='bilinear')
-            if self.feat_refs is None:
-                self.feat0_resize = F.interpolate(
-                    feat[feature_idx].detach(), [h, w], mode='bilinear')
-                self.feat_refs = []
-                for point in points:
-                    py, px = round(point[0]), round(point[1])
-                    self.feat_refs.append(self.feat0_resize[:, :, py, px])
-                self.points0_pt = torch.Tensor(points).unsqueeze(
-                    0).to(self._device)  # 1, N, 2
-
-            # Point tracking with feature matching
-            with torch.no_grad():
+        try:
+            G = self.G
+            ws = self.w
+            if ws.dim() == 2:
+                ws = ws.unsqueeze(1).repeat(1, 6, 1)
+            ws = torch.cat([ws[:, :6, :], self.w0[:, 6:, :]], dim=1)
+            if hasattr(self, 'points'):
+                if len(points) != len(self.points):
+                    reset = True
+            if reset:
+                self.feat_refs = None
+                self.points0_pt = None
+            self.points = points
+
+            # Run synthesis network.
+            label = torch.zeros([1, G.c_dim], device=self._device)
+            img, feat = G(ws, label, truncation_psi=trunc_psi,
+                          noise_mode=noise_mode, input_is_w=True, return_feature=True)
+
+            h, w = G.img_resolution, G.img_resolution
+
+            if is_drag:
+                X = torch.linspace(0, h, h)
+                Y = torch.linspace(0, w, w)
+                xx, yy = torch.meshgrid(X, Y)
+                feat_resize = F.interpolate(
+                    feat[feature_idx], [h, w], mode='bilinear')
+                if self.feat_refs is None:
+                    self.feat0_resize = F.interpolate(
+                        feat[feature_idx].detach(), [h, w], mode='bilinear')
+                    self.feat_refs = []
+                    for point in points:
+                        py, px = round(point[0]), round(point[1])
+                        self.feat_refs.append(self.feat0_resize[:, :, py, px])
+                    self.points0_pt = torch.Tensor(points).unsqueeze(
+                        0).to(self._device)  # 1, N, 2
+
+                # Point tracking with feature matching
+                with torch.no_grad():
+                    for j, point in enumerate(points):
+                        r = round(r2 / 512 * h)
+                        up = max(point[0] - r, 0)
+                        down = min(point[0] + r + 1, h)
+                        left = max(point[1] - r, 0)
+                        right = min(point[1] + r + 1, w)
+                        feat_patch = feat_resize[:, :, up:down, left:right]
+                        L2 = torch.linalg.norm(
+                            feat_patch - self.feat_refs[j].reshape(1, -1, 1, 1), dim=1)
+                        _, idx = torch.min(L2.view(1, -1), -1)
+                        width = right - left
+                        point = [idx.item() // width + up, idx.item() %
+                                 width + left]
+                        points[j] = point
+
+                res.points = [[point[0], point[1]] for point in points]
+
+                # Motion supervision
+                loss_motion = 0
+                res.stop = True
                 for j, point in enumerate(points):
-                    r = round(r2 / 512 * h)
-                    up = max(point[0] - r, 0)
-                    down = min(point[0] + r + 1, h)
-                    left = max(point[1] - r, 0)
-                    right = min(point[1] + r + 1, w)
-                    feat_patch = feat_resize[:, :, up:down, left:right]
-                    L2 = torch.linalg.norm(
-                        feat_patch - self.feat_refs[j].reshape(1, -1, 1, 1), dim=1)
-                    _, idx = torch.min(L2.view(1, -1), -1)
-                    width = right - left
-                    point = [idx.item() // width + up, idx.item() %
-                             width + left]
-                    points[j] = point
-
-            res.points = [[point[0], point[1]] for point in points]
-
-            # Motion supervision
-            loss_motion = 0
-            res.stop = True
-            for j, point in enumerate(points):
-                direction = torch.Tensor(
-                    [targets[j][1] - point[1], targets[j][0] - point[0]])
-                if torch.linalg.norm(direction) > max(2 / 512 * h, 2):
-                    res.stop = False
-                if torch.linalg.norm(direction) > 1:
-                    distance = (
-                        (xx.to(self._device) - point[0])**2 + (yy.to(self._device) - point[1])**2)**0.5
-                    relis, reljs = torch.where(distance < round(r1 / 512 * h))
-                    direction = direction / \
-                        (torch.linalg.norm(direction) + 1e-7)
-                    gridh = (relis-direction[1]) / (h-1) * 2 - 1
-                    gridw = (reljs-direction[0]) / (w-1) * 2 - 1
-                    grid = torch.stack(
-                        [gridw, gridh], dim=-1).unsqueeze(0).unsqueeze(0)
-                    target = F.grid_sample(
-                        feat_resize.float(), grid, align_corners=True).squeeze(2)
-                    loss_motion += F.l1_loss(
-                        feat_resize[:, :, relis, reljs], target.detach())
-
-            loss = loss_motion
-            if mask is not None:
-                if mask.min() == 0 and mask.max() == 1:
-                    mask_usq = mask.to(self._device).unsqueeze(0).unsqueeze(0)
-                    loss_fix = F.l1_loss(
-                        feat_resize * mask_usq, self.feat0_resize * mask_usq)
-                    loss += lambda_mask * loss_fix
-
-            loss += reg * F.l1_loss(ws, self.w0)  # latent code regularization
-            if not res.stop:
-                self.w_optim.zero_grad()
-                loss.backward()
-                self.w_optim.step()
-
-        # Scale and convert to uint8.
-        img = img[0]
-        if img_normalize:
-            img = img / img.norm(float('inf'),
-                                 dim=[1, 2], keepdim=True).clip(1e-8, 1e8)
-        img = img * (10 ** (img_scale_db / 20))
-        img = (img * 127.5 + 128).clamp(0,
-                                        255).to(torch.uint8).permute(1, 2, 0)
-        if to_pil:
-            from PIL import Image
-            img = img.cpu().numpy()
-            img = Image.fromarray(img)
-        res.image = img
-        res.w = ws.detach().cpu().numpy()
+                    direction = torch.Tensor(
+                        [targets[j][1] - point[1], targets[j][0] - point[0]])
+                    if torch.linalg.norm(direction) > max(2 / 512 * h, 2):
+                        res.stop = False
+                    if torch.linalg.norm(direction) > 1:
+                        distance = (
+                            (xx.to(self._device) - point[0])**2 + (yy.to(self._device) - point[1])**2)**0.5
+                        relis, reljs = torch.where(
+                            distance < round(r1 / 512 * h))
+                        direction = direction / \
+                            (torch.linalg.norm(direction) + 1e-7)
+                        gridh = (relis-direction[1]) / (h-1) * 2 - 1
+                        gridw = (reljs-direction[0]) / (w-1) * 2 - 1
+                        grid = torch.stack(
+                            [gridw, gridh], dim=-1).unsqueeze(0).unsqueeze(0)
+                        target = F.grid_sample(
+                            feat_resize.float(), grid, align_corners=True).squeeze(2)
+                        loss_motion += F.l1_loss(
+                            feat_resize[:, :, relis, reljs], target.detach())
+
+                loss = loss_motion
+                if mask is not None:
+                    if mask.min() == 0 and mask.max() == 1:
+                        mask_usq = mask.to(
+                            self._device).unsqueeze(0).unsqueeze(0)
+                        loss_fix = F.l1_loss(
+                            feat_resize * mask_usq, self.feat0_resize * mask_usq)
+                        loss += lambda_mask * loss_fix
+
+                # latent code regularization
+                loss += reg * F.l1_loss(ws, self.w0)
+                if not res.stop:
+                    self.w_optim.zero_grad()
+                    loss.backward()
+                    self.w_optim.step()
+
+            # Scale and convert to uint8.
+            img = img[0]
+            if img_normalize:
+                img = img / img.norm(float('inf'),
+                                     dim=[1, 2], keepdim=True).clip(1e-8, 1e8)
+            img = img * (10 ** (img_scale_db / 20))
+            img = (img * 127.5 + 128).clamp(0,
+                                            255).to(torch.uint8).permute(1, 2, 0)
+            if to_pil:
+                from PIL import Image
+                img = img.cpu().numpy()
+                img = Image.fromarray(img)
+            res.image = img
+            res.w = ws.detach().cpu().numpy()
+        except Exception as e:
+            import os
+            print(f'Renderer error: {e}')
+            print("Out of memory error occurred. Restarting the app...")
+            os.execv(sys.executable, ['python'] + sys.argv)
 
 # ----------------------------------------------------------------------------
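
One caveat worth noting: the new except branch calls `sys.executable` and `sys.argv` but only imports `os` locally, so it relies on `sys` already being imported at module level in `viz/renderer.py`; otherwise the handler itself would raise a `NameError`. A self-contained sketch of the same restart-on-OOM pattern, assuming PyTorch >= 1.13 (which exposes `torch.cuda.OutOfMemoryError`); `render_step` is a hypothetical stand-in for one drag-optimization iteration:

import os
import sys

import torch


def render_step():
    # Hypothetical stand-in for one optimization step; under concurrent
    # load it may raise torch.cuda.OutOfMemoryError.
    raise torch.cuda.OutOfMemoryError


def main():
    try:
        render_step()
    except torch.cuda.OutOfMemoryError as e:
        print(f'Renderer error: {e}')
        print("Out of memory error occurred. Restarting the app...")
        # os.execv replaces the current process image in place, so leaked
        # CUDA allocations are reclaimed by the OS; passing sys.argv
        # re-launches the app with the same command line.
        os.execv(sys.executable, ['python'] + sys.argv)


if __name__ == '__main__':
    main()

Catching the narrower `torch.cuda.OutOfMemoryError` instead of bare `Exception` avoids restarting on unrelated renderer errors; a softer fallback could call `torch.cuda.empty_cache()` and report the failure instead of re-executing.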