Mehdi Cherti committed
Commit 27911d6
1 Parent(s): ee8f9c5
Files changed (3)
  1. clip_encoder.py +1 -1
  2. run.py → model_configs.py +36 -93
  3. test_ddgan.py +83 -157
clip_encoder.py CHANGED
@@ -16,7 +16,7 @@ class CLIPEncoder(nn.Module):
         self.model, _, _ = open_clip.create_model_and_transforms(model, pretrained=pretrained)
         self.output_size = self.model.transformer.width
 
-    def forward(self, texts, return_only_pooled=True):
+    def forward(self, texts, return_only_pooled=False):
         device = next(self.parameters()).device
         toks = open_clip.tokenize(texts).to(device)
         x = self.model.token_embedding(toks) # [batch_size, n_ctx, d_model]
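The only functional change above flips the default of return_only_pooled, so a bare call now returns the full token-level conditioning output rather than only the pooled embedding. A minimal usage sketch, assuming the constructor keywords seen above; the specific open_clip model/pretrained tags are illustrative and not taken from this commit:

    from clip_encoder import CLIPEncoder

    # Hypothetical model/pretrained tags; any valid open_clip pairing would do.
    enc = CLIPEncoder(model="ViT-H-14", pretrained="laion2b_s32b_b79k")
    # With the new default (return_only_pooled=False), a bare call returns the
    # full conditioning output that downstream code passes along as `cond`.
    cond = enc(["a chair in the form of an avocado"])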
run.py → model_configs.py RENAMED
@@ -10,30 +10,41 @@ def base():
             "n": 8,
         },
         "model":{
-            "dataset" :"wds",
-            "dataset_root": "/p/scratch/ccstdl/cherti1/CC12M/{00000..01099}.tar",
-            "image_size": 256,
+            "dataset": "wds",
+            "seed": 0,
+            "cross_attention": False,
             "num_channels": 3,
+            "centered": True,
+            "use_geometric": False,
+            "beta_min": 0.1,
+            "beta_max": 20.0,
             "num_channels_dae": 128,
-            "ch_mult": "1 1 2 2 4 4",
-            "num_timesteps": 4,
+            "n_mlp": 3,
+            "ch_mult": [1, 1, 2, 2, 4, 4],
             "num_res_blocks": 2,
-            "batch_size": 8,
-            "num_epoch": 1000,
-            "ngf": 64,
+            "attn_resolutions": (16,),
+            "dropout": 0.0,
+            "resamp_with_conv": True,
+            "conditional": True,
+            "fir": True,
+            "fir_kernel": [1, 3, 3, 1],
+            "skip_rescale": True,
+            "resblock_type": "biggan",
+            "progressive": "none",
+            "progressive_input": "residual",
+            "progressive_combine": "sum",
             "embedding_type": "positional",
-            "use_ema": "",
-            "ema_decay": 0.999,
-            "r1_gamma": 1.0,
+            "fourier_scale": 16.0,
+            "not_use_tanh": False,
+            "image_size": 256,
+            "nz": 100,
+            "num_timesteps": 4,
             "z_emb_dim": 256,
-            "lr_d": 1e-4,
-            "lr_g": 1.6e-4,
-            "lazy_reg": 10,
-            "save_content": "",
-            "save_ckpt_every": 1,
-            "masked_mean": "",
-            "resume": "",
-        },
+            "t_emb_dim": 256,
+            "text_encoder": "google/t5-v1_1-base",
+            "masked_mean": True,
+            "cross_attention_block": "basic",
+        }
     }
 def ddgan_cc12m_v2():
     cfg = base()
@@ -72,7 +83,7 @@ def ddgan_cc12m_v11():
     cfg = base()
     cfg['model']['text_encoder'] = "google/t5-v1_1-large"
     cfg['model']['classifier_free_guidance_proba'] = 0.2
-    cfg['model']['cross_attention'] = ""
+    cfg['model']['cross_attention'] = True
     return cfg
 
 def ddgan_cc12m_v12():
@@ -102,7 +113,7 @@ def ddgan_cifar10_cond17():
     cfg['model']['image_size'] = 32
     cfg['model']['classifier_free_guidance_proba'] = 0.2
     cfg['model']['ch_mult'] = "1 2 2 2"
-    cfg['model']['cross_attention'] = ""
+    cfg['model']['cross_attention'] = True
     cfg['model']['dataset'] = "cifar10"
     cfg['model']['n_mlp'] = 4
     return cfg
@@ -276,7 +287,7 @@ def ddgan_ddb_v7():
 
 def ddgan_ddb_v9():
     cfg = ddgan_ddb_v3()
-    cfg['model']['attn_resolutions'] = '4 8 16 32'
+    cfg['model']['attn_resolutions'] = [4, 8, 16, 32]
     return cfg
 
 def ddgan_laion_aesthetic_v15():
@@ -313,6 +324,7 @@ models = [
     ddgan_cc12m_v13, # T5-XL + cross attention + classifier free guidance + random_resized_crop_v1 + cond attn
     ddgan_cc12m_v14, # T5-XL + cross attention + classifier free guidance + random_resized_crop_v1 + 300M model
     ddgan_cc12m_v15, # fine-tune v11 with --mismatch_loss and --grad_penalty_cond
+
     ddgan_laion_aesthetic_v1, # like ddgan_cc12m_v11 but fine-tuned on laion aesthetic
     ddgan_laion_aesthetic_v2, # like ddgan_laion_aesthetic_v1 but trained from scratch with the new cross attn discr
     ddgan_laion_aesthetic_v3, # like ddgan_laion_aesthetic_v1 but trained from scratch with T5-XL (continue from 23aug with mismatch and grad penalty and random_resized_crop_v1)
@@ -352,76 +364,7 @@ models = [
     ddgan_ddb_v12,
 ]
 
-def get_model(model_name):
+def get_model_config(model_name):
     for model in models:
         if model.__name__ == model_name:
-            return model()
-
-
-def test(model_name, *, cond_text="", batch_size:int=None, epoch:int=None, guidance_scale:float=0, fid=False, real_img_dir="", q=0.0, seed=0, nb_images_for_fid=0, scale_factor_h=1, scale_factor_w=1, compute_clip_score=False, eval_name="", scale_method="convolutional", compute_image_reward=False):
-
-    cfg = get_model(model_name)
-    model = cfg['model']
-    if epoch is None:
-        paths = glob('./saved_info/dd_gan/{}/{}/netG_*.pth'.format(model["dataset"], model_name))
-        epoch = max(
-            [int(os.path.basename(path).replace(".pth", "").split("_")[1]) for path in paths]
-        )
-    args = {}
-    args['exp'] = model_name
-    args['image_size'] = model['image_size']
-    args['seed'] = seed
-    args['num_channels'] = model['num_channels']
-    args['dataset'] = model['dataset']
-    args['num_channels_dae'] = model['num_channels_dae']
-    args['ch_mult'] = model['ch_mult']
-    args['num_timesteps'] = model['num_timesteps']
-    args['num_res_blocks'] = model['num_res_blocks']
-    args['batch_size'] = model['batch_size'] if batch_size is None else batch_size
-    args['epoch'] = epoch
-    args['cond_text'] = f'"{cond_text}"'
-    args['text_encoder'] = model.get("text_encoder")
-    args['cross_attention'] = model.get("cross_attention")
-    args['guidance_scale'] = guidance_scale
-    args['masked_mean'] = model.get("masked_mean")
-    args['dynamic_thresholding_quantile'] = q
-    args['scale_factor_h'] = scale_factor_h
-    args['scale_factor_w'] = scale_factor_w
-    args['n_mlp'] = model.get("n_mlp")
-    args['scale_method'] = scale_method
-    args['attn_resolutions'] = model.get("attn_resolutions", "16")
-    if fid:
-        args['compute_fid'] = ''
-        args['real_img_dir'] = real_img_dir
-        args['nb_images_for_fid'] = nb_images_for_fid
-    if compute_clip_score:
-        args['compute_clip_score'] = ""
-
-    if compute_image_reward:
-        args['compute_image_reward'] = ""
-    if eval_name:
-        args["eval_name"] = eval_name
-    cmd = "python -u test_ddgan.py " + " ".join(f"--{k} {v}" for k, v in args.items() if v is not None)
-    print(cmd)
-    call(cmd, shell=True)
-
-def eval_results(model_name):
-    import pandas as pd
-    rows = []
-    cfg = get_model(model_name)
-    model = cfg['model']
-    paths = glob('./saved_info/dd_gan/{}/{}/fid*.json'.format(model["dataset"], model_name))
-    for path in paths:
-        with open(path, "r") as fd:
-            data = json.load(fd)
-        row = {}
-        row['fid'] = data['fid']
-        row['epoch'] = data['epoch_id']
-        rows.append(row)
-    out = './saved_info/dd_gan/{}/{}/fid.csv'.format(model["dataset"], model_name)
-    df = pd.DataFrame(rows)
-    df.to_csv(out, index=False)
-
-if __name__ == "__main__":
-    from clize import run
-    run([test, eval_results])
+            return model()['model']
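With the launcher code removed, the renamed module now only exposes the config functions plus this lookup helper. A minimal sketch of the intended use (the config name is one of the entries registered in models above; the printed values follow from base() and the v11 overrides shown in this diff):

    from model_configs import get_model_config

    # Returns only the flat "model" dict of the chosen config function.
    cfg = get_model_config("ddgan_cc12m_v11")
    print(cfg["text_encoder"])                      # "google/t5-v1_1-large" (v11 override)
    print(cfg["image_size"], cfg["num_timesteps"])  # 256 and 4, inherited from base()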
test_ddgan.py CHANGED
@@ -11,8 +11,32 @@ import time
 import os
 import json
 import torchvision
+import random
+
 from score_sde.models.ncsnpp_generator_adagn import NCSNpp
+from torch.nn.functional import adaptive_avg_pool2d
+
+try:
+    from pytorch_fid.fid_score import calculate_activation_statistics, calculate_fid_given_paths, ImagePathDataset, compute_statistics_of_path, calculate_frechet_distance
+    from pytorch_fid.inception import InceptionV3
+except ImportError:
+    pass
+
+try:
+    import ImageReward as RM
+except ImportError:
+    pass
+
+
+try:
+    import clip
+except ImportError:
+    pass
+
 from encoder import build_encoder
+from clip_encoder import CLIPImageEncoder
+
+from model_configs import get_model_config
 
 #%% Diffusion coefficients
 def var_func_vp(t, beta_min, beta_max):
@@ -138,6 +162,12 @@ def sample_from_model(coefficients, generator, n_time, x_init, T, opt, cond=None
     return x
 
 
+def sample(generator, x_init, cond=None):
+    return sample_from_model(
+        generator.pos_coeff, generator, n_time=generator.config.num_timesteps, x_init=x_init,
+        T=generator.time_schedule, opt=generator.config, cond=cond
+    )
+
 def sample_from_model_classifier_free_guidance(coefficients, generator, n_time, x_init, T, opt, text_encoder, cond=None, guidance_scale=0):
     x = x_init
     null = text_encoder([""] * len(x_init), return_only_pooled=False)
@@ -353,106 +383,84 @@ def get_fold_unfold(x, kernel_size, stride, split_input_params, uf=1, df=1): #
 
     return fold, unfold, normalization, weighting
 
+class ObjectFromDict:
+    def __init__(self, d):
+        self.__dict__ = d
+
+def load_model(config, path, device="cpu"):
+    config = ObjectFromDict(config)
+    text_encoder = build_encoder(name=config.text_encoder, masked_mean=config.masked_mean)
+    config.cond_size = text_encoder.output_size
+    netG = NCSNpp(config)
+    ckpt = torch.load(path, map_location="cpu")
+    for key in list(ckpt.keys()):
+        if key.startswith("module"):
+            ckpt[key[7:]] = ckpt.pop(key)
+    netG.load_state_dict(ckpt)
+    netG.eval()
+    netG.pos_coeff = Posterior_Coefficients(config, device)
+    netG.text_encoder = text_encoder
+    netG.config = config
+    netG.time_schedule = get_time_schedule(config, device)
+    netG = netG.to(device)
+    return netG
 
 
 #%%
-def sample_and_test(args):
-    torch.manual_seed(args.seed)
 
-    device = 'cuda:0'
-    text_encoder =build_encoder(name=args.text_encoder, masked_mean=args.masked_mean).to(device)
-    args.cond_size = text_encoder.output_size
-    if args.dataset == 'cifar10':
-        real_img_dir = 'pytorch_fid/cifar10_train_stat.npy'
-    elif args.dataset == 'celeba_256':
-        real_img_dir = 'pytorch_fid/celeba_256_stat.npy'
-    elif args.dataset == 'lsun':
-        real_img_dir = 'pytorch_fid/lsun_church_stat.npy'
-    else:
-        real_img_dir = args.real_img_dir
-
-    to_range_0_1 = lambda x: (x + 1.) / 2.
 
-    print(vars(args))
-    netG = NCSNpp(args).to(device)
-
+def sample_and_test(args):
+    torch.manual_seed(args.seed)
+    device = 'cuda' if torch.cuda.is_available() else 'cpu'
+    to_range_0_1 = lambda x: (x + 1.) / 2.
     if args.epoch_id == -1:
         epochs = range(1000)
     else:
         epochs = [args.epoch_id]
     if args.compute_image_reward:
-        import ImageReward as RM
         #image_reward = RM.load("ImageReward-v1.0", download_root=".").to(device)
         image_reward = RM.load("ImageReward.pt", download_root=".").to(device)
-
+    cfg = get_model_config(args.name)
    for epoch in epochs:
         args.epoch_id = epoch
-        path = './saved_info/dd_gan/{}/{}/netG_{}.pth'.format(args.dataset, args.exp, args.epoch_id)
-        next_next_path = './saved_info/dd_gan/{}/{}/netG_{}.pth'.format(args.dataset, args.exp, args.epoch_id+2)
+
+        path = './saved_info/dd_gan/{}/{}/netG_{}.pth'.format(cfg['dataset'], args.name, args.epoch_id)
+        next_next_path = './saved_info/dd_gan/{}/{}/netG_{}.pth'.format(cfg['dataset'], args.name, args.epoch_id+2)
+        print(path)
         if not os.path.exists(path):
             continue
         if not os.path.exists(next_next_path):
             break
         print("PATH", path)
-
-        #if not os.path.exists(next_path):
-        #    print(f"STOP at {epoch}")
-        #    break
-        try:
-            ckpt = torch.load(path, map_location=device)
-        except Exception:
-            continue
         suffix = '_' + args.eval_name if args.eval_name else ""
-        dest = './saved_info/dd_gan/{}/{}/eval_{}{}.json'.format(args.dataset, args.exp, args.epoch_id, suffix)
+        dest = './saved_info/dd_gan/{}/{}/eval_{}{}.json'.format(cfg['dataset'],'ddgan', args.epoch_id, suffix)
         if (args.compute_fid or args.compute_clip_score or args.compute_image_reward) and os.path.exists(dest):
             continue
-        print("Eval Epoch", args.epoch_id)
-        #loading weights from ddp in single gpu
-        #print(ckpt.keys())
-        for key in list(ckpt.keys()):
-            if key.startswith("module"):
-                ckpt[key[7:]] = ckpt.pop(key)
-        netG.load_state_dict(ckpt)
-        netG.eval()
-
-
-        T = get_time_schedule(args, device)
-
-        pos_coeff = Posterior_Coefficients(args, device)
-
-
-        save_dir = "./generated_samples/{}".format(args.dataset)
+        print("Load epoch", args.epoch_id, "checkpoint")
+
+        netG = load_model(cfg, path, device=device)
+        save_dir = "./generated_samples/{}".format(cfg['dataset'])
 
         if not os.path.exists(save_dir):
             os.makedirs(save_dir)
 
 
         if args.compute_fid or args.compute_clip_score or args.compute_image_reward:
-            from torch.nn.functional import adaptive_avg_pool2d
-            from pytorch_fid.fid_score import calculate_activation_statistics, calculate_fid_given_paths, ImagePathDataset, compute_statistics_of_path, calculate_frechet_distance
-            from pytorch_fid.inception import InceptionV3
-            import random
+            # Evaluate
             random.seed(args.seed)
             texts = open(args.cond_text).readlines()
             texts = [t.strip() for t in texts]
             if args.nb_images_for_fid:
                 random.shuffle(texts)
                 texts = texts[0:args.nb_images_for_fid]
-            #iters_needed = len(texts) // args.batch_size
-            #texts = list(map(lambda s:s.strip(), texts))
-            #ntimes = max(30000 // len(texts), 1)
-            #texts = texts * ntimes
             print("Text size:", len(texts))
-            #print("Iters:", iters_needed)
             i = 0
-
             if args.compute_fid:
                 dims = 2048
                 block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[dims]
                 inceptionv3 = InceptionV3([block_idx]).to(device)
 
             if args.compute_clip_score:
-                import clip
                 CLIP_MEAN = [0.48145466, 0.4578275, 0.40821073]
                 CLIP_STD = [0.26862954, 0.26130258, 0.27577711]
                 clip_model, preprocess = clip.load(args.clip_model, device)
@@ -481,14 +489,14 @@ def sample_and_test(args):
             for b in range(0, len(texts), args.batch_size):
                 text = texts[b:b+args.batch_size]
                 with torch.no_grad():
-                    cond = text_encoder(text, return_only_pooled=False)
+                    cond = netG.text_encoder(text)
                     bs = len(text)
                     t0 = time.time()
-                    x_t_1 = torch.randn(bs, args.num_channels,args.image_size, args.image_size).to(device)
+                    x_t_1 = torch.randn(bs, cfg['num_channels'], cfg['image_size'], cfg['image_size']).to(device)
                     if args.guidance_scale:
                         fake_sample = sample_from_model_classifier_free_guidance(pos_coeff, netG, args.num_timesteps, x_t_1,T, args, text_encoder, cond=cond, guidance_scale=args.guidance_scale)
                     else:
-                        fake_sample = sample_from_model(pos_coeff, netG, args.num_timesteps, x_t_1,T, args, cond=cond)
+                        fake_sample = sample(generator=model, x_init=x_init, cond=cond)
                     fake_sample = to_range_0_1(fake_sample)
 
                     if args.compute_fid:
@@ -513,8 +521,8 @@ def sample_and_test(args):
                         clip_scores.append(((imf * txtf).sum(dim=1)).cpu())
 
                     if args.compute_image_reward:
-                        for k, sample in enumerate(fake_sample):
-                            img = sample.cpu().numpy().transpose(1,2,0)
+                        for k, img in enumerate(fake_sample):
+                            img = img.cpu().numpy().transpose(1,2,0)
                             img = img * 255
                             img = img.astype(np.uint8)
                             text_k = text[k]
@@ -542,7 +550,8 @@ def sample_and_test(args):
             with open(dest, "w") as fd:
                 json.dump(results, fd)
             print(results)
-        else:
+        else:
+            # just generate some samples
             if args.cond_text.endswith(".txt"):
                 texts = open(args.cond_text).readlines()
                 texts = [t.strip() for t in texts]
@@ -550,7 +559,6 @@ def sample_and_test(args):
                 texts = [args.cond_text] * args.batch_size
             clip_guidance = False
            if clip_guidance:
-                from clip_encoder import CLIPImageEncoder
                 cond = text_encoder(texts, return_only_pooled=False)
                 clip_image_model = CLIPImageEncoder().to(device)
                 x_t_1 = torch.randn(len(texts), args.num_channels,args.image_size*args.scale_factor_h, args.image_size*args.scale_factor_w).to(device)
@@ -559,14 +567,14 @@ def sample_and_test(args):
                 torchvision.utils.save_image(fake_sample, './samples_{}.jpg'.format(args.dataset))
 
             else:
-                cond = text_encoder(texts, return_only_pooled=False)
-                x_t_1 = torch.randn(len(texts), args.num_channels,args.image_size*args.scale_factor_h, args.image_size*args.scale_factor_w).to(device)
+                cond = netG.text_encoder(texts)
+                x_t_1 = torch.randn(len(texts), cfg['num_channels'], cfg['image_size'] * args.scale_factor_h, cfg['image_size'] * args.scale_factor_w).to(device)
                 t0 = time.time()
                 if args.guidance_scale:
                     if args.scale_factor_h > 1 or args.scale_factor_w > 1:
                         if args.scale_method == "convolutional":
                             split_input_params = {
-                                "ks": (args.image_size, args.image_size),
+                                "ks": (cfg['image_size'], cfg['image_size']),
                                 "stride": (150, 150),
                                 "clip_max_tie_weight": 0.5,
                                 "clip_min_tie_weight": 0.01,
@@ -583,22 +591,17 @@ def sample_and_test(args):
                     else:
                         fake_sample = sample_from_model_classifier_free_guidance(pos_coeff, netG, args.num_timesteps, x_t_1,T, args, text_encoder, cond=cond, guidance_scale=args.guidance_scale)
                 else:
-                    fake_sample = sample_from_model(pos_coeff, netG, args.num_timesteps, x_t_1,T, args, cond=cond)
+                    fake_sample = sample(generator=netG, x_init=x_t_1, cond=cond)
 
                 print(time.time() - t0)
                 fake_sample = to_range_0_1(fake_sample)
-                torchvision.utils.save_image(fake_sample, './samples_{}.jpg'.format(args.dataset))
-
-
-
-
-
-
+                torchvision.utils.save_image(fake_sample, 'samples.jpg')
 
 if __name__ == '__main__':
     parser = argparse.ArgumentParser('ddgan parameters')
-    parser.add_argument('--seed', type=int, default=1024,
-                        help='seed used for initialization')
+    parser.add_argument('--name', type=str, default="", help="model config name")
+    parser.add_argument('--batch_size', type=int, default=16)
+    parser.add_argument('--seed', type=int, default=1024, help='seed used for initialization')
     parser.add_argument('--compute_fid', action='store_true', default=False,
                         help='whether or not compute FID')
     parser.add_argument('--compute_clip_score', action='store_true', default=False,
@@ -608,92 +611,15 @@ if __name__ == '__main__':
 
     parser.add_argument('--clip_model', type=str,default="ViT-L/14")
     parser.add_argument('--eval_name', type=str,default="")
-
-    parser.add_argument('--epoch_id', type=int,default=1000)
+    parser.add_argument('--epoch_id', type=int,default=-1)
     parser.add_argument('--guidance_scale', type=float,default=0)
     parser.add_argument('--dynamic_thresholding_quantile', type=float,default=0)
-    parser.add_argument('--cond_text', type=str,default="0")
+    parser.add_argument('--cond_text', type=str,default="a chair in the form of an avocado")
     parser.add_argument('--scale_factor_h', type=int,default=1)
     parser.add_argument('--scale_factor_w', type=int,default=1)
     parser.add_argument('--scale_method', type=str,default="convolutional")
-
-    parser.add_argument('--cross_attention', action='store_true',default=False)
-
-
-    parser.add_argument('--num_channels', type=int, default=3,
-                        help='channel of image')
-    parser.add_argument('--centered', action='store_false', default=True,
-                        help='-1,1 scale')
-    parser.add_argument('--use_geometric', action='store_true',default=False)
-    parser.add_argument('--beta_min', type=float, default= 0.1,
-                        help='beta_min for diffusion')
-    parser.add_argument('--beta_max', type=float, default=20.,
-                        help='beta_max for diffusion')
-
-
-    parser.add_argument('--num_channels_dae', type=int, default=128,
-                        help='number of initial channels in denosing model')
-    parser.add_argument('--n_mlp', type=int, default=3,
-                        help='number of mlp layers for z')
-    parser.add_argument('--ch_mult', nargs='+', type=int,
-                        help='channel multiplier')
-
-    parser.add_argument('--num_res_blocks', type=int, default=2,
-                        help='number of resnet blocks per scale')
-    parser.add_argument('--attn_resolutions', default=(16,), nargs='+', type=int,
-                        help='resolution of applying attention')
-    parser.add_argument('--dropout', type=float, default=0.,
-                        help='drop-out rate')
-    parser.add_argument('--resamp_with_conv', action='store_false', default=True,
-                        help='always up/down sampling with conv')
-    parser.add_argument('--conditional', action='store_false', default=True,
-                        help='noise conditional')
-    parser.add_argument('--fir', action='store_false', default=True,
-                        help='FIR')
-    parser.add_argument('--fir_kernel', default=[1, 3, 3, 1],
-                        help='FIR kernel')
-    parser.add_argument('--skip_rescale', action='store_false', default=True,
-                        help='skip rescale')
-    parser.add_argument('--resblock_type', default='biggan',
-                        help='tyle of resnet block, choice in biggan and ddpm')
-    parser.add_argument('--progressive', type=str, default='none', choices=['none', 'output_skip', 'residual'],
-                        help='progressive type for output')
-    parser.add_argument('--progressive_input', type=str, default='residual', choices=['none', 'input_skip', 'residual'],
-                        help='progressive type for input')
-    parser.add_argument('--progressive_combine', type=str, default='sum', choices=['sum', 'cat'],
-                        help='progressive combine method.')
-
-    parser.add_argument('--embedding_type', type=str, default='positional', choices=['positional', 'fourier'],
-                        help='type of time embedding')
-    parser.add_argument('--fourier_scale', type=float, default=16.,
-                        help='scale of fourier transform')
-    parser.add_argument('--not_use_tanh', action='store_true',default=False)
-
-    #geenrator and training
-    parser.add_argument('--exp', default='experiment_cifar_default', help='name of experiment')
-    parser.add_argument('--real_img_dir', default='./pytorch_fid/cifar10_train_stat.npy', help='directory to real images for FID computation')
-
-    parser.add_argument('--dataset', default='cifar10', help='name of dataset')
-    parser.add_argument('--image_size', type=int, default=32,
-                        help='size of image')
-
-    parser.add_argument('--nz', type=int, default=100)
-    parser.add_argument('--num_timesteps', type=int, default=4)
-
-
-    parser.add_argument('--z_emb_dim', type=int, default=256)
-    parser.add_argument('--t_emb_dim', type=int, default=256)
-    parser.add_argument('--batch_size', type=int, default=200, help='sample generating batch size')
-    parser.add_argument('--text_encoder', type=str, default="google/t5-v1_1-base")
-    parser.add_argument('--masked_mean', action='store_true',default=False)
     parser.add_argument('--nb_images_for_fid', type=int, default=0)
-
-
-
-
-
     args = parser.parse_args()
-
     sample_and_test(args)
 
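The refactor above turns checkpoint loading and sampling into reusable helpers (load_model, sample) driven by a named config instead of a long argparse list. A minimal sketch of calling them from Python, assuming a CUDA device; the epoch number and checkpoint path are placeholders for whatever has actually been trained:

    import torch
    from model_configs import get_model_config
    from test_ddgan import load_model, sample

    cfg = get_model_config("ddgan_cc12m_v11")
    # Placeholder checkpoint: saved_info/dd_gan/<dataset>/<config name>/netG_<epoch>.pth
    netG = load_model(cfg, "./saved_info/dd_gan/wds/ddgan_cc12m_v11/netG_500.pth", device="cuda")

    texts = ["a chair in the form of an avocado"]
    cond = netG.text_encoder(texts)  # conditioning features expected by the generator
    x_init = torch.randn(len(texts), cfg["num_channels"], cfg["image_size"], cfg["image_size"], device="cuda")
    with torch.no_grad():
        images = sample(generator=netG, x_init=x_init, cond=cond)

The equivalent command-line entry point after this commit is python test_ddgan.py --name ddgan_cc12m_v11 --epoch_id 500 --cond_text "a chair in the form of an avocado" (same placeholder epoch).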