Mehdi Cherti committed on
Commit 1a02524
1 Parent(s): 572f947

cond grad penalty: use only cond embedding to compute grad

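In brief, the commit wires the text conditioning into the discriminator's R1 gradient penalty: with the new --grad_penalty_cond flag, the penalty's gradient is taken with respect to the conditioning embedding cond (together with the noisy input x_t for the cross-attention discriminators) rather than x_t alone. Below is a minimal, self-contained sketch of that kind of penalty; the tensor shapes and the toy scalar score are placeholders standing in for the repository's netD, not the exact code path in the diff.

import torch

def r1_penalty(d_real, inputs, r1_gamma=0.05):
    # R1-style penalty: squared L2 norm of the gradient of the real logits
    # with respect to every tensor in `inputs`, averaged over the batch.
    grads = torch.autograd.grad(outputs=d_real.sum(), inputs=inputs, create_graph=True)
    flat = torch.cat([g.reshape(g.size(0), -1) for g in grads], dim=1)
    return r1_gamma / 2 * (flat.norm(2, dim=1) ** 2).mean()

# Placeholder tensors standing in for the noisy image x_t and the text embedding cond.
x_t = torch.randn(4, 3, 32, 32, requires_grad=True)
cond = torch.randn(4, 77, 768, requires_grad=True)
d_real = x_t.mean(dim=(1, 2, 3)) + cond.mean(dim=(1, 2))  # toy stand-in for netD(...)
r1_penalty(d_real, (x_t, cond)).backward()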
score_sde/models/ncsnpp_generator_adagn.py CHANGED
@@ -325,9 +325,12 @@ class NCSNpp(nn.Module):
 
         hs = [modules[m_idx](x)]
         m_idx += 1
+        #print(self.attn_resolutions)
+        #self.attn_resolutions = (32,)
         for i_level in range(self.num_resolutions):
             # Residual blocks for this resolution
             for i_block in range(self.num_res_blocks):
+                #print(hs[-1].shape, temb.shape, zemb.shape, type(modules[m_idx]))
                 h = modules[m_idx](hs[-1], temb, zemb)
                 m_idx += 1
                 if h.shape[-1] in self.attn_resolutions:
train_ddgan.py CHANGED
@@ -28,7 +28,10 @@ from torch.multiprocessing import Process
 import torch.distributed as dist
 import shutil
 import logging
-import t5
+from encoder import build_encoder
+from utils import ResampledShards2
+
+
 def log_and_continue(exn):
     logging.warning(f'Handling webdataset error ({repr(exn)}). Ignoring.')
     return True
@@ -192,7 +195,11 @@ def sample_from_model(coefficients, generator, n_time, x_init, T, opt, cond=None
     return x
 
 
-from utils import ResampledShards2
+
+def filter_no_caption(sample):
+    return 'txt' in sample
+
+
 
 def train(rank, gpu, args):
     from score_sde.models.discriminator import Discriminator_small, Discriminator_large, CondAttnDiscriminator, SmallCondAttnDiscriminator
@@ -278,6 +285,7 @@ def train(rank, gpu, args):
             ),
         ])
         pipeline.extend([
+            wds.select(filter_no_caption),
             wds.decode("pilrgb", handler=log_and_continue),
             wds.rename(image="jpg;png"),
             wds.map_dict(image=train_transform),
@@ -307,7 +315,7 @@ def train(rank, gpu, args):
         pin_memory=True,
         sampler=train_sampler,
     )
-    text_encoder = t5.T5Encoder(name=args.text_encoder, masked_mean=args.masked_mean).to(device)
+    text_encoder = build_encoder(name=args.text_encoder, masked_mean=args.masked_mean).to(device)
    args.cond_size = text_encoder.output_size
    netG = NCSNpp(args).to(device)
    nb_params = 0
@@ -387,7 +395,7 @@ def train(rank, gpu, args):
               .format(checkpoint['epoch']))
    else:
        global_step, epoch, init_epoch = 0, 0, 0
-
+    use_cond_attn_discr = args.discr_type in ("large_cond_attn", "small_cond_attn")
    for epoch in range(init_epoch, args.num_epoch+1):
        if args.dataset == "wds":
            os.environ["WDS_EPOCH"] = str(epoch)
@@ -419,45 +427,71 @@ def train(rank, gpu, args):
            x_t, x_tp1 = q_sample_pairs(coeff, real_data, t)
            x_t.requires_grad = True
 
-            cond_for_discr = (cond_pooled, cond, cond_mask) if args.discr_type in ("large_cond_attn", "small_cond_attn") else cond_pooled
+            cond_for_discr = (cond_pooled, cond, cond_mask) if use_cond_attn_discr else cond_pooled
+            if args.grad_penalty_cond:
+                if use_cond_attn_discr:
+                    #cond_pooled.requires_grad = True
+                    cond.requires_grad = True
+                    #cond_mask.requires_grad = True
+                else:
+                    cond_for_discr.requires_grad = True
 
            # train with real
            D_real = netD(x_t, t, x_tp1.detach(), cond=cond_for_discr).view(-1)
 
            errD_real = F.softplus(-D_real)
            errD_real = errD_real.mean()
+
 
            errD_real.backward(retain_graph=True)
 
 
            if args.lazy_reg is None:
-                grad_real = torch.autograd.grad(
-                    outputs=D_real.sum(), inputs=x_t, create_graph=True
-                )[0]
-                grad_penalty = (
-                    grad_real.view(grad_real.size(0), -1).norm(2, dim=1) ** 2
-                ).mean()
-
-
-                grad_penalty = args.r1_gamma / 2 * grad_penalty
-                grad_penalty.backward()
-            else:
-                if global_step % args.lazy_reg == 0:
-                    grad_real = torch.autograd.grad(
-                        outputs=D_real.sum(), inputs=x_t, create_graph=True
-                    )[0]
-                    grad_penalty = (
-                        grad_real.view(grad_real.size(0), -1).norm(2, dim=1) ** 2
-                    ).mean()
-
-
-                    grad_penalty = args.r1_gamma / 2 * grad_penalty
-                    grad_penalty.backward()
+                if args.grad_penalty_cond:
+                    inputs = (x_t,) + (cond,) if use_cond_attn_discr else (cond_for_discr,)
+                    grad_real = torch.autograd.grad(
+                        outputs=D_real.sum(), inputs=inputs, create_graph=True
+                    )[0]
+                    grad_real = torch.cat([g.view(g.size(0), -1) for g in grad_real])
+                    grad_penalty = (grad_real.norm(2, dim=1) ** 2).mean()
+                    grad_penalty = args.r1_gamma / 2 * grad_penalty
+                    grad_penalty.backward()
+                else:
+                    grad_real = torch.autograd.grad(
+                        outputs=D_real.sum(), inputs=x_t, create_graph=True
+                    )[0]
+                    grad_penalty = (
+                        grad_real.view(grad_real.size(0), -1).norm(2, dim=1) ** 2
+                    ).mean()
+
+
+                    grad_penalty = args.r1_gamma / 2 * grad_penalty
+                    grad_penalty.backward()
+            else:
+                if global_step % args.lazy_reg == 0:
+                    if args.grad_penalty_cond:
+                        inputs = (x_t,) + (cond,) if use_cond_attn_discr else (cond_for_discr,)
+                        grad_real = torch.autograd.grad(
+                            outputs=D_real.sum(), inputs=inputs, create_graph=True
+                        )[0]
+                        grad_real = torch.cat([g.view(g.size(0), -1) for g in grad_real])
+                        grad_penalty = (grad_real.norm(2, dim=1) ** 2).mean()
+                        grad_penalty = args.r1_gamma / 2 * grad_penalty
+                        grad_penalty.backward()
+                    else:
+                        grad_real = torch.autograd.grad(
+                            outputs=D_real.sum(), inputs=x_t, create_graph=True
+                        )[0]
+                        grad_penalty = (
+                            grad_real.view(grad_real.size(0), -1).norm(2, dim=1) ** 2
+                        ).mean()
+
+                        grad_penalty = args.r1_gamma / 2 * grad_penalty
+                        grad_penalty.backward()
 
            # train with fake
            latent_z = torch.randn(batch_size, nz, device=device)
 
-
            x_0_predict = netG(x_tp1.detach(), t, latent_z, cond=(cond_pooled, cond, cond_mask))
            x_pos_sample = sample_posterior(pos_coeff, x_0_predict, x_tp1, t)
 
@@ -466,6 +500,18 @@ def train(rank, gpu, args):
 
            errD_fake = F.softplus(output)
            errD_fake = errD_fake.mean()
+
+            if args.mismatch_loss:
+                # following https://github.com/tobran/DF-GAN/blob/bc38a4f795c294b09b4ef5579cd4ff78807e5b96/code/lib/modules.py,
+                # we add a discr loss for (real image, non matching text)
+                #inds = torch.flip(torch.arange(len(x_t)), dims=(0,))
+                inds = torch.cat([torch.arange(1,len(x_t)),torch.arange(1)])
+                cond_for_discr_mis = (cond_pooled[inds], cond[inds], cond_mask[inds]) if use_cond_attn_discr else cond_pooled[inds]
+                D_real_mis = netD(x_t, t, x_tp1.detach(), cond=cond_for_discr_mis).view(-1)
+                errD_real_mis = F.softplus(D_real_mis)
+                errD_real_mis = errD_real_mis.mean()
+                errD_fake = errD_fake * 0.5 + errD_real_mis * 0.5
+
            errD_fake.backward()
 
 
@@ -592,6 +638,7 @@ if __name__ == '__main__':
 
    parser.add_argument('--resume', action='store_true',default=False)
    parser.add_argument('--masked_mean', action='store_true',default=False)
+    parser.add_argument('--mismatch_loss', action='store_true',default=False)
    parser.add_argument('--text_encoder', type=str, default="google/t5-v1_1-base")
    parser.add_argument('--cross_attention', action='store_true',default=False)
 
@@ -616,7 +663,7 @@ if __name__ == '__main__':
                        help='channel multiplier')
    parser.add_argument('--num_res_blocks', type=int, default=2,
                        help='number of resnet blocks per scale')
-    parser.add_argument('--attn_resolutions', default=(16,),
+    parser.add_argument('--attn_resolutions', default=(16,), nargs='+', type=int,
                        help='resolution of applying attention')
    parser.add_argument('--dropout', type=float, default=0.,
                        help='drop-out rate')
@@ -665,12 +712,14 @@ if __name__ == '__main__':
    parser.add_argument('--beta2', type=float, default=0.9,
                        help='beta2 for adam')
    parser.add_argument('--no_lr_decay',action='store_true', default=False)
-
+    parser.add_argument('--grad_penalty_cond', action='store_true',default=False)
+
    parser.add_argument('--use_ema', action='store_true', default=False,
                        help='use EMA or not')
    parser.add_argument('--ema_decay', type=float, default=0.9999, help='decay rate for EMA')
 
    parser.add_argument('--r1_gamma', type=float, default=0.05, help='coef for r1 reg')
+
    parser.add_argument('--lazy_reg', type=int, default=None,
                        help='lazy regulariation.')
 
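A note on the --mismatch_loss option added above: it pairs each real image with the caption of the next sample in the batch (a roll-by-one of the batch indices) and scores those (real image, non-matching text) pairs as fakes, following the DF-GAN code linked in the diff; that term is then averaged 50/50 with the fake-image loss before errD_fake.backward(). A tiny sketch of the index roll, with a made-up batch size:

import torch

batch_size = 4  # hypothetical value, just for illustration
# Shift the batch by one position so sample i is paired with caption (i + 1) % batch_size.
inds = torch.cat([torch.arange(1, batch_size), torch.arange(1)])
print(inds.tolist())  # [1, 2, 3, 0] -> every image now sees a non-matching caption

Separately, since --attn_resolutions now uses nargs='+' with type=int, multiple resolutions can be passed on the command line, e.g. --attn_resolutions 16 32.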