patrickvonplaten committed on
Commit
6e970c6
1 Parent(s): 1afa2e1

Upload modeling_ddim.py

Files changed (1)
modeling_ddim.py +24 -36
modeling_ddim.py CHANGED
@@ -34,61 +34,49 @@ class DDIM(DiffusionPipeline):
         inference_step_times = range(0, num_trained_timesteps, num_trained_timesteps // num_inference_steps)
 
         self.unet.to(torch_device)
-
-        # Sample gaussian noise to begin loop
         image = self.noise_scheduler.sample_noise(
             (batch_size, self.unet.in_channels, self.unet.resolution, self.unet.resolution),
             device=torch_device,
             generator=generator,
         )
 
-        # See formulas (9), (10) and (7) of DDIM paper https://arxiv.org/pdf/2010.02502.pdf
-        # Ideally, read DDIM paper in-detail understanding
-
-        # Notation (<variable name> -> <name in paper>
-        # - pred_noise_t -> e_theta(x_t, t)
-        # - pred_original_image -> f_theta(x_t, t) or x_0
-        # - std_dev_t -> sigma_t
-
         for t in tqdm.tqdm(reversed(range(num_inference_steps)), total=num_inference_steps):
-            # 1. predict noise residual
-            with torch.no_grad():
-                pred_noise_t = self.unet(image, inference_step_times[t])
-
-            # 2. get actual t and t-1
+            # get actual t and t-1
             train_step = inference_step_times[t]
             prev_train_step = inference_step_times[t - 1] if t > 0 else -1
 
-            # 3. compute alphas, betas
+            # compute alphas
             alpha_prod_t = self.noise_scheduler.get_alpha_prod(train_step)
             alpha_prod_t_prev = self.noise_scheduler.get_alpha_prod(prev_train_step)
+            alpha_prod_t_rsqrt = 1 / alpha_prod_t.sqrt()
+            alpha_prod_t_prev_rsqrt = 1 / alpha_prod_t_prev.sqrt()
             beta_prod_t_sqrt = (1 - alpha_prod_t).sqrt()
             beta_prod_t_prev_sqrt = (1 - alpha_prod_t_prev).sqrt()
 
-            # 4. Compute predicted previous image from predicted noise
-            # First: compute predicted original image from predicted noise also called
-            # "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
-            pred_original_image = (image - beta_prod_t_sqrt * pred_noise_t) / alpha_prod_t.sqrt()
-            # Second: Clip "predicted x_0"
-            pred_original_image = torch.clamp(pred_original_image, -1, 1)
-            # Third: Compute variance: "sigma_t" -> see
-            # std_dev_t = (1 - alpha_prod_t / alpha_prod_t_prev).sqrt() * beta_prod_t_prev_sqrt / beta_prod_t_sqrt
-            std_dev_t = (1 - alpha_prod_t / alpha_prod_t_prev).sqrt()
-            std_dev_t = std_dev_t * eta
-            # Fourth: Compute "direction pointing to x_t" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
-            pred_image_direction = (1 - alpha_prod_t_prev - std_dev_t**2).sqrt() * pred_noise_t
-
-            # Fourth: Compute outer formula (DDIM formula)
-            pred_prev_image = alpha_prod_t_prev.sqrt() * pred_original_image + pred_image_direction
+            # compute relevant coefficients
+            coeff_1 = (
+                (alpha_prod_t_prev - alpha_prod_t).sqrt()
+                * alpha_prod_t_prev_rsqrt
+                * beta_prod_t_prev_sqrt
+                / beta_prod_t_sqrt
+                * eta
+            )
+            coeff_2 = ((1 - alpha_prod_t_prev) - coeff_1**2).sqrt()
+
+            # model forward
+            with torch.no_grad():
+                noise_residual = self.unet(image, train_step)
+
+            # predict mean of prev image
+            pred_mean = alpha_prod_t_rsqrt * (image - beta_prod_t_sqrt * noise_residual)
+            pred_mean = torch.clamp(pred_mean, -1, 1)
+            pred_mean = (1 / alpha_prod_t_prev_rsqrt) * pred_mean + coeff_2 * noise_residual
 
             # if eta > 0.0 add noise. Note eta = 1.0 essentially corresponds to DDPM
             if eta > 0.0:
                 noise = self.noise_scheduler.sample_noise(image.shape, device=image.device, generator=generator)
-                prev_image = pred_prev_image + std_dev_t * noise
+                image = pred_mean + coeff_1 * noise
             else:
-                prev_image = pred_prev_image
-
-            # Set current image to prev_image: x_t -> x_t-1
-            image = prev_image
+                image = pred_mean
 
         return image
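
For readers comparing the two versions: the removed code set std_dev_t = eta * (1 - alpha_prod_t / alpha_prod_t_prev).sqrt(), dropping the * beta_prod_t_prev_sqrt / beta_prod_t_sqrt factor that its own comment quoted, whereas the new coeff_1 includes that factor, and coeff_2 plays the role of the "direction pointing to x_t" coefficient of formula (12) from https://arxiv.org/pdf/2010.02502.pdf. The following standalone sketch is not part of the commit; it only checks that algebra numerically, using placeholder alpha_prod values (any values in (0, 1) with alpha_prod_t < alpha_prod_t_prev would do).

# Standalone check (not part of the commit): coeff_1 from the new loop equals the
# full std_dev_t formula quoted in the removed comment, which the removed code had
# truncated. Placeholder scalars stand in for the scheduler's cumulative alphas.
import torch

alpha_prod_t = torch.tensor(0.6)       # placeholder cumulative alpha at step t
alpha_prod_t_prev = torch.tensor(0.8)  # placeholder cumulative alpha at step t-1
eta = 1.0

beta_prod_t_sqrt = (1 - alpha_prod_t).sqrt()
beta_prod_t_prev_sqrt = (1 - alpha_prod_t_prev).sqrt()
alpha_prod_t_prev_rsqrt = 1 / alpha_prod_t_prev.sqrt()

# coeff_1 exactly as written in the new version of the loop
coeff_1 = (
    (alpha_prod_t_prev - alpha_prod_t).sqrt()
    * alpha_prod_t_prev_rsqrt
    * beta_prod_t_prev_sqrt
    / beta_prod_t_sqrt
    * eta
)

# std_dev_t written out as in the removed comment (the DDIM sigma_t)
std_dev_t = (1 - alpha_prod_t / alpha_prod_t_prev).sqrt() * beta_prod_t_prev_sqrt / beta_prod_t_sqrt * eta
assert torch.allclose(coeff_1, std_dev_t)

# coeff_2 then matches the "direction pointing to x_t" coefficient of formula (12)
coeff_2 = ((1 - alpha_prod_t_prev) - coeff_1**2).sqrt()
assert torch.allclose(coeff_2, (1 - alpha_prod_t_prev - std_dev_t**2).sqrt())

Under that identity, pred_mean = alpha_prod_t_prev.sqrt() * clamp(predicted x_0) + coeff_2 * noise_residual reproduces the deterministic part of formula (12), and coeff_1 * noise is the stochastic term added only when eta > 0.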