Mehdi Cherti committed
Commit 8d2bdec
Parent: 23d6920

support fid eval on several epochs

Files changed (1):
  1. test_ddgan.py +139 -108
test_ddgan.py CHANGED
@@ -130,14 +130,18 @@ def sample_from_model(coefficients, generator, n_time, x_init, T, opt, cond=None
 def sample_from_model_classifier_free_guidance(coefficients, generator, n_time, x_init, T, opt, text_encoder, cond=None, guidance_scale=0):
     x = x_init
     null = text_encoder([""] * len(x_init), return_only_pooled=False)
+    latent_z = torch.randn(x.size(0), opt.nz, device=x.device)
     with torch.no_grad():
         for i in reversed(range(n_time)):
             t = torch.full((x.size(0),), i, dtype=torch.int64).to(x.device)
-
             t_time = t
-            latent_z = torch.randn(x.size(0), opt.nz, device=x.device)
+
+            #latent_z = torch.randn(x.size(0), opt.nz, device=x.device)
 
             x_0_uncond = generator(x, t_time, latent_z, cond=null)
+
+            #latent_z = torch.randn(x.size(0), opt.nz, device=x.device)
+
             x_0_cond = generator(x, t_time, latent_z, cond=cond)
 
             eps_uncond = (x - torch.sqrt(coefficients.alphas_cumprod[i]) * x_0_uncond) / torch.sqrt(1 - coefficients.alphas_cumprod[i])
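Note on the hunk above: `latent_z` is now drawn once, before the denoising loop, so the unconditional and conditional generator passes (and every timestep) share the same latent; the commented-out lines keep per-step resampling available. Only `eps_uncond` is visible in the hunk; the usual classifier-free guidance mix of the two predictions looks like the sketch below, which is an assumption about the elided lines, not necessarily this file's exact code:

    def guided_eps(eps_uncond, eps_cond, guidance_scale):
        # guidance_scale = 0 recovers the unconditional prediction,
        # guidance_scale = 1 the conditional one; larger values push
        # the sample further toward the text condition.
        return eps_uncond + guidance_scale * (eps_cond - eps_uncond)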
@@ -149,8 +153,8 @@ def sample_from_model_classifier_free_guidance(coefficients, generator, n_time,
 
 
             # Dynamic thresholding
-            q = args.dynamic_thresholding_percentile
-            print("Before", x_0.min(), x_0.max())
+            q = args.dynamic_thresholding_quantile
+            #print("Before", x_0.min(), x_0.max())
             if q:
                 shape = x_0.shape
                 x_0_v = x_0.view(shape[0], -1)
@@ -158,7 +162,7 @@ def sample_from_model_classifier_free_guidance(coefficients, generator, n_time,
                 d.clamp_(min=1)
                 x_0_v = x_0_v.clamp(-d, d) / d
                 x_0 = x_0_v.view(shape)
-            print("After", x_0.min(), x_0.max())
+            #print("After", x_0.min(), x_0.max())
 
             x_new = sample_posterior(coefficients, x_0, x, t)
 
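Note on the two hunks above: `args.dynamic_thresholding_percentile` is renamed to `args.dynamic_thresholding_quantile`, which matches how the value is used, and the debug prints are commented out rather than removed. The threshold `d` is assigned on an elided line; assuming it is a per-sample quantile of `|x_0|` computed with `torch.quantile`, the whole operation reads as this sketch:

    import torch

    def dynamic_threshold(x_0, q):
        # Imagen-style dynamic thresholding: clamp each sample to its
        # q-quantile of absolute values, then rescale back into [-1, 1].
        shape = x_0.shape
        x_0_v = x_0.view(shape[0], -1)
        d = torch.quantile(x_0_v.abs(), q, dim=1, keepdim=True)
        d.clamp_(min=1)  # never shrink samples already inside [-1, 1]
        x_0_v = x_0_v.clamp(-d, d) / d
        return x_0_v.view(shape)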
@@ -197,112 +201,138 @@ def sample_and_test(args):
 
 
     netG = NCSNpp(args).to(device)
-    ckpt = torch.load('./saved_info/dd_gan/{}/{}/netG_{}.pth'.format(args.dataset, args.exp, args.epoch_id), map_location=device)
-
-    #loading weights from ddp in single gpu
-    for key in list(ckpt.keys()):
-        ckpt[key[7:]] = ckpt.pop(key)
-    netG.load_state_dict(ckpt)
-    netG.eval()
-
-
-    T = get_time_schedule(args, device)
+
+
+    if args.epoch_id == -1:
+        epochs = range(1000)
+    else:
+        epochs = [args.epoch_id]
 
-    pos_coeff = Posterior_Coefficients(args, device)
+    for epoch in epochs:
+        args.epoch_id = epoch
+        path = './saved_info/dd_gan/{}/{}/netG_{}.pth'.format(args.dataset, args.exp, args.epoch_id)
+        if not os.path.exists(path):
+            continue
+        ckpt = torch.load(path, map_location=device)
+        dest = './saved_info/dd_gan/{}/{}/fid_{}.json'.format(args.dataset, args.exp, args.epoch_id)
+
+        if args.compute_fid and os.path.exists(dest):
+            continue
+        print("Eval Epoch", args.epoch_id)
+        #loading weights from ddp in single gpu
+        for key in list(ckpt.keys()):
+            ckpt[key[7:]] = ckpt.pop(key)
+        netG.load_state_dict(ckpt)
+        netG.eval()
 
-
-    save_dir = "./generated_samples/{}".format(args.dataset)
-
-    if not os.path.exists(save_dir):
-        os.makedirs(save_dir)
-
-    if args.compute_fid:
-        from torch.nn.functional import adaptive_avg_pool2d
-        from pytorch_fid.fid_score import calculate_activation_statistics, calculate_fid_given_paths, ImagePathDataset, compute_statistics_of_path, calculate_frechet_distance
-        from pytorch_fid.inception import InceptionV3
-
-        texts = open(args.cond_text).readlines()
-        #iters_needed = len(texts) // args.batch_size
-        #texts = list(map(lambda s:s.strip(), texts))
-        #ntimes = max(30000 // len(texts), 1)
-        #texts = texts * ntimes
-        print("Text size:", len(texts))
-        #print("Iters:", iters_needed)
-        i = 0
-        dims = 2048
-        block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[dims]
-        inceptionv3 = InceptionV3([block_idx]).to(device)
-
-        if not args.real_img_dir.endswith("npz"):
-            real_mu, real_sigma = compute_statistics_of_path(
-                args.real_img_dir, inceptionv3, args.batch_size, dims, device,
-                resize=args.image_size,
-            )
-            np.savez("inception_statistics.npz", mu=real_mu, sigma=real_sigma)
-        else:
-            stats = np.load(args.real_img_dir)
-            real_mu = stats['mu']
-            real_sigma = stats['sigma']
-
-        fake_features = []
-        for b in range(0, len(texts), args.batch_size):
-            text = texts[b:b+args.batch_size]
-            with torch.no_grad():
-                cond = text_encoder(text, return_only_pooled=False)
-                bs = len(text)
-                t0 = time.time()
-                x_t_1 = torch.randn(bs, args.num_channels,args.image_size, args.image_size).to(device)
-                if args.guidance_scale:
-                    fake_sample = sample_from_model_classifier_free_guidance(pos_coeff, netG, args.num_timesteps, x_t_1,T, args, text_encoder, cond=cond, guidance_scale=args.guidance_scale)
-                else:
-                    fake_sample = sample_from_model(pos_coeff, netG, args.num_timesteps, x_t_1,T, args, cond=cond)
-                fake_sample = to_range_0_1(fake_sample)
-                """
-                for j, x in enumerate(fake_sample):
-                    index = i * args.batch_size + j
-                    torchvision.utils.save_image(x, './generated_samples/{}/{}.jpg'.format(args.dataset, index))
-                """
+
+        T = get_time_schedule(args, device)
+
+        pos_coeff = Posterior_Coefficients(args, device)
+
+
+        save_dir = "./generated_samples/{}".format(args.dataset)
+
+        if not os.path.exists(save_dir):
+            os.makedirs(save_dir)
+
+        if args.compute_fid:
+            from torch.nn.functional import adaptive_avg_pool2d
+            from pytorch_fid.fid_score import calculate_activation_statistics, calculate_fid_given_paths, ImagePathDataset, compute_statistics_of_path, calculate_frechet_distance
+            from pytorch_fid.inception import InceptionV3
+            import random
+            random.seed(args.seed)
+            texts = open(args.cond_text).readlines()
+            texts = [t.strip() for t in texts]
+            if args.nb_images_for_fid:
+                random.shuffle(texts)
+                texts = texts[0:args.nb_images_for_fid]
+            #iters_needed = len(texts) // args.batch_size
+            #texts = list(map(lambda s:s.strip(), texts))
+            #ntimes = max(30000 // len(texts), 1)
+            #texts = texts * ntimes
+            print("Text size:", len(texts))
+            #print("Iters:", iters_needed)
+            i = 0
+            dims = 2048
+            block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[dims]
+            inceptionv3 = InceptionV3([block_idx]).to(device)
+
+            if not args.real_img_dir.endswith("npz"):
+                real_mu, real_sigma = compute_statistics_of_path(
+                    args.real_img_dir, inceptionv3, args.batch_size, dims, device,
+                    resize=args.image_size,
+                )
+                np.savez("inception_statistics.npz", mu=real_mu, sigma=real_sigma)
+            else:
+                stats = np.load(args.real_img_dir)
+                real_mu = stats['mu']
+                real_sigma = stats['sigma']
+
+            fake_features = []
+            for b in range(0, len(texts), args.batch_size):
+                text = texts[b:b+args.batch_size]
                 with torch.no_grad():
-                pred = inceptionv3(fake_sample)[0]
-                # If model output is not scalar, apply global spatial average pooling.
-                # This happens if you choose a dimensionality not equal 2048.
-                if pred.size(2) != 1 or pred.size(3) != 1:
-                    pred = adaptive_avg_pool2d(pred, output_size=(1, 1))
-                pred = pred.squeeze(3).squeeze(2).cpu().numpy()
-                fake_features.append(pred)
-            if i % 10 == 0:
-                print('generating batch ', i, time.time() - t0)
-            """
-            if i % 10 == 0:
-                ff = np.concatenate(fake_features)
-                fake_mu = np.mean(ff, axis=0)
-                fake_sigma = np.cov(ff, rowvar=False)
-                fid = calculate_frechet_distance(real_mu, real_sigma, fake_mu, fake_sigma)
-                print("FID", fid)
-            """
-            i += 1
-
-        fake_features = np.concatenate(fake_features)
-        fake_mu = np.mean(fake_features, axis=0)
-        fake_sigma = np.cov(fake_features, rowvar=False)
-        fid = calculate_frechet_distance(real_mu, real_sigma, fake_mu, fake_sigma)
-        dest = './saved_info/dd_gan/{}/{}/fid_{}.json'.format(args.dataset, args.exp, args.epoch_id)
-        results = {
-            "fid": fid,
-        }
-        results.update(vars(args))
-        with open(dest, "w") as fd:
-            json.dump(results, fd)
-        print('FID = {}'.format(fid))
-    else:
-        cond = text_encoder([args.cond_text] * args.batch_size, return_only_pooled=False)
-        x_t_1 = torch.randn(args.batch_size, args.num_channels,args.image_size, args.image_size).to(device)
-        if args.guidance_scale:
-            fake_sample = sample_from_model_classifier_free_guidance(pos_coeff, netG, args.num_timesteps, x_t_1,T, args, text_encoder, cond=cond, guidance_scale=args.guidance_scale)
+                    cond = text_encoder(text, return_only_pooled=False)
+                    bs = len(text)
+                    t0 = time.time()
+                    x_t_1 = torch.randn(bs, args.num_channels,args.image_size, args.image_size).to(device)
+                    if args.guidance_scale:
+                        fake_sample = sample_from_model_classifier_free_guidance(pos_coeff, netG, args.num_timesteps, x_t_1,T, args, text_encoder, cond=cond, guidance_scale=args.guidance_scale)
+                    else:
+                        fake_sample = sample_from_model(pos_coeff, netG, args.num_timesteps, x_t_1,T, args, cond=cond)
+                    fake_sample = to_range_0_1(fake_sample)
+                    """
+                    for j, x in enumerate(fake_sample):
+                        index = i * args.batch_size + j
+                        torchvision.utils.save_image(x, './generated_samples/{}/{}.jpg'.format(args.dataset, index))
+                    """
+                    with torch.no_grad():
+                        pred = inceptionv3(fake_sample)[0]
+                        # If model output is not scalar, apply global spatial average pooling.
+                        # This happens if you choose a dimensionality not equal 2048.
+                        if pred.size(2) != 1 or pred.size(3) != 1:
+                            pred = adaptive_avg_pool2d(pred, output_size=(1, 1))
+                        pred = pred.squeeze(3).squeeze(2).cpu().numpy()
+                        fake_features.append(pred)
+                if i % 10 == 0:
+                    print('generating batch ', i, time.time() - t0)
+                """
+                if i % 10 == 0:
+                    ff = np.concatenate(fake_features)
+                    fake_mu = np.mean(ff, axis=0)
+                    fake_sigma = np.cov(ff, rowvar=False)
+                    fid = calculate_frechet_distance(real_mu, real_sigma, fake_mu, fake_sigma)
+                    print("FID", fid)
+                """
+                i += 1
+
+            fake_features = np.concatenate(fake_features)
+            fake_mu = np.mean(fake_features, axis=0)
+            fake_sigma = np.cov(fake_features, rowvar=False)
+            fid = calculate_frechet_distance(real_mu, real_sigma, fake_mu, fake_sigma)
+            dest = './saved_info/dd_gan/{}/{}/fid_{}.json'.format(args.dataset, args.exp, args.epoch_id)
+            results = {
+                "fid": fid,
+            }
+            results.update(vars(args))
+            with open(dest, "w") as fd:
+                json.dump(results, fd)
+            print('FID = {}'.format(fid))
         else:
-            fake_sample = sample_from_model(pos_coeff, netG, args.num_timesteps, x_t_1,T, args, cond=cond)
-        fake_sample = to_range_0_1(fake_sample)
-        torchvision.utils.save_image(fake_sample, './samples_{}.jpg'.format(args.dataset))
+            if args.cond_text.endswith(".txt"):
+                texts = open(args.cond_text).readlines()
+                texts = [t.strip() for t in texts]
+            else:
+                texts = [args.cond_text] * args.batch_size
+            cond = text_encoder(texts, return_only_pooled=False)
+            x_t_1 = torch.randn(len(texts), args.num_channels,args.image_size, args.image_size).to(device)
+            if args.guidance_scale:
+                fake_sample = sample_from_model_classifier_free_guidance(pos_coeff, netG, args.num_timesteps, x_t_1,T, args, text_encoder, cond=cond, guidance_scale=args.guidance_scale)
+            else:
+                fake_sample = sample_from_model(pos_coeff, netG, args.num_timesteps, x_t_1,T, args, cond=cond)
+            fake_sample = to_range_0_1(fake_sample)
+            torchvision.utils.save_image(fake_sample, './samples_{}.jpg'.format(args.dataset))
 
 
 
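Two details of the new loop are worth noting. With `--epoch_id -1`, epochs 0-999 are scanned and missing checkpoint files are skipped silently, and the early `continue` when `fid_{epoch}.json` already exists makes the sweep resumable: the script can be re-run as training writes new checkpoints without re-scoring old ones. The `ckpt[key[7:]] = ckpt.pop(key)` loop strips the 7-character `module.` prefix that `DistributedDataParallel` adds to every parameter name, so the weights load into a plain single-GPU module; a slightly more defensive variant of the same idea (hypothetical, not in this commit) only strips keys that actually carry the prefix:

    # Strip the DDP "module." prefix (7 characters) only where present.
    prefix = "module."
    ckpt = {k[len(prefix):] if k.startswith(prefix) else k: v
            for k, v in ckpt.items()}
    netG.load_state_dict(ckpt)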
@@ -316,7 +346,7 @@ if __name__ == '__main__':
                         help='whether or not compute FID')
     parser.add_argument('--epoch_id', type=int,default=1000)
     parser.add_argument('--guidance_scale', type=float,default=0)
-    parser.add_argument('--dynamic_thresholding_percentile', type=float,default=0)
+    parser.add_argument('--dynamic_thresholding_quantile', type=float,default=0)
     parser.add_argument('--cond_text', type=str,default="0")
 
     parser.add_argument('--cross_attention', action='store_true',default=False)
@@ -388,6 +418,7 @@ if __name__ == '__main__':
     parser.add_argument('--batch_size', type=int, default=200, help='sample generating batch size')
     parser.add_argument('--text_encoder', type=str, default="google/t5-v1_1-base")
     parser.add_argument('--masked_mean', action='store_true',default=False)
+    parser.add_argument('--nb_images_for_fid', type=int, default=0)
 
 
 
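With the new flags, an FID sweep over all saved checkpoints might be launched as below; every value here is a placeholder example, and flag spellings such as `--compute_fid`, `--dataset`, and `--exp` are inferred from the code rather than shown in this diff:

    python test_ddgan.py --dataset my_dataset --exp my_exp --epoch_id -1 \
        --compute_fid --real_img_dir real_stats.npz --cond_text captions.txt \
        --nb_images_for_fid 5000 --guidance_scale 2.0 \
        --dynamic_thresholding_quantile 0.95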
 
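For reference, `calculate_frechet_distance` from `pytorch_fid` computes the Fréchet distance between the Gaussians fitted to the real and fake Inception features: FID = ||mu1 - mu2||^2 + Tr(sigma1 + sigma2 - 2*sqrtm(sigma1 @ sigma2)). A condensed sketch of that computation (mirroring pytorch_fid's implementation, not code from this commit):

    import numpy as np
    from scipy import linalg

    def frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6):
        diff = mu1 - mu2
        covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False)
        if not np.isfinite(covmean).all():
            # Regularize a near-singular product with a small ridge.
            offset = np.eye(sigma1.shape[0]) * eps
            covmean, _ = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset), disp=False)
        covmean = covmean.real  # discard tiny imaginary parts from sqrtm
        return diff.dot(diff) + np.trace(sigma1) + np.trace(sigma2) - 2 * np.trace(covmean)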