Mehdi Cherti commited on
Commit
9e9d0ce
1 Parent(s): 27911d6
Files changed (1) hide show
  1. test_ddgan.py +16 -16
test_ddgan.py CHANGED
@@ -433,7 +433,7 @@ def sample_and_test(args):
433
  break
434
  print("PATH", path)
435
  suffix = '_' + args.eval_name if args.eval_name else ""
436
- dest = './saved_info/dd_gan/{}/{}/eval_{}{}.json'.format(cfg['dataset'],'ddgan', args.epoch_id, suffix)
437
  if (args.compute_fid or args.compute_clip_score or args.compute_image_reward) and os.path.exists(dest):
438
  continue
439
  print("Load epoch", args.epoch_id, "checkpoint")
@@ -496,7 +496,7 @@ def sample_and_test(args):
496
  if args.guidance_scale:
497
  fake_sample = sample_from_model_classifier_free_guidance(pos_coeff, netG, args.num_timesteps, x_t_1,T, args, text_encoder, cond=cond, guidance_scale=args.guidance_scale)
498
  else:
499
- fake_sample = sample(generator=model, x_init=x_init, cond=cond)
500
  fake_sample = to_range_0_1(fake_sample)
501
 
502
  if args.compute_fid:
@@ -600,25 +600,25 @@ def sample_and_test(args):
600
  if __name__ == '__main__':
601
  parser = argparse.ArgumentParser('ddgan parameters')
602
  parser.add_argument('--name', type=str, default="", help="model config name")
603
- parser.add_argument('--batch_size', type=int, default=16)
604
  parser.add_argument('--seed', type=int, default=1024, help='seed used for initialization')
605
- parser.add_argument('--compute_fid', action='store_true', default=False,
606
  help='whether or not compute FID')
607
- parser.add_argument('--compute_clip_score', action='store_true', default=False,
608
  help='whether or not compute CLIP score')
609
- parser.add_argument('--compute_image_reward', action='store_true', default=False,
610
  help='whether or not compute CLIP score')
611
 
612
- parser.add_argument('--clip_model', type=str,default="ViT-L/14")
613
- parser.add_argument('--eval_name', type=str,default="")
614
- parser.add_argument('--epoch_id', type=int,default=-1)
615
- parser.add_argument('--guidance_scale', type=float,default=0)
616
- parser.add_argument('--dynamic_thresholding_quantile', type=float,default=0)
617
- parser.add_argument('--cond_text', type=str,default="a chair in the form of an avocado")
618
- parser.add_argument('--scale_factor_h', type=int,default=1)
619
- parser.add_argument('--scale_factor_w', type=int,default=1)
620
- parser.add_argument('--scale_method', type=str,default="convolutional")
621
- parser.add_argument('--nb_images_for_fid', type=int, default=0)
622
  args = parser.parse_args()
623
  sample_and_test(args)
624
 
 
433
  break
434
  print("PATH", path)
435
  suffix = '_' + args.eval_name if args.eval_name else ""
436
+ dest = './saved_info/dd_gan/{}/{}/eval_{}{}.json'.format(cfg['dataset'], args.name, args.epoch_id, suffix)
437
  if (args.compute_fid or args.compute_clip_score or args.compute_image_reward) and os.path.exists(dest):
438
  continue
439
  print("Load epoch", args.epoch_id, "checkpoint")
 
496
  if args.guidance_scale:
497
  fake_sample = sample_from_model_classifier_free_guidance(pos_coeff, netG, args.num_timesteps, x_t_1,T, args, text_encoder, cond=cond, guidance_scale=args.guidance_scale)
498
  else:
499
+ fake_sample = sample(generator=netG, x_init=x_init, cond=cond)
500
  fake_sample = to_range_0_1(fake_sample)
501
 
502
  if args.compute_fid:
 
600
  if __name__ == '__main__':
601
  parser = argparse.ArgumentParser('ddgan parameters')
602
  parser.add_argument('--name', type=str, default="", help="model config name")
603
+ parser.add_argument('--batch-size', type=int, default=16)
604
  parser.add_argument('--seed', type=int, default=1024, help='seed used for initialization')
605
+ parser.add_argument('--compute-fid', action='store_true', default=False,
606
  help='whether or not compute FID')
607
+ parser.add_argument('--compute-clip-score', action='store_true', default=False,
608
  help='whether or not compute CLIP score')
609
+ parser.add_argument('--compute-image-reward', action='store_true', default=False,
610
  help='whether or not compute CLIP score')
611
 
612
+ parser.add_argument('--clip-model', type=str,default="ViT-L/14")
613
+ parser.add_argument('--eval-name', type=str,default="")
614
+ parser.add_argument('--epoch-id', type=int,default=-1)
615
+ parser.add_argument('--guidance-scale', type=float,default=0)
616
+ parser.add_argument('--dynamic-thresholding-quantile', type=float,default=0)
617
+ parser.add_argument('--cond-text', type=str,default="a chair in the form of an avocado")
618
+ parser.add_argument('--scale-factor-h', type=int,default=1)
619
+ parser.add_argument('--scale-factor-w', type=int,default=1)
620
+ parser.add_argument('--scale-method', type=str,default="convolutional")
621
+ parser.add_argument('--nb-images-for-fid', type=int, default=0)
622
  args = parser.parse_args()
623
  sample_and_test(args)
624