LuChengTHU committed
Commit 183c3a1
1 Parent(s): 531ea40

add a stabilizing trick for steps < 15

Former-commit-id: bf3b8783543bdbfc31721479091e35696baadd13
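The "stabilizing trick" is the new `lower_order_final` branch in the multistep loop of `dpm_solver.py` below: when `steps < 15`, the last updates fall back to lower-order solvers. A minimal standalone sketch of the per-step orders that rule produces (the `multistep_orders` helper is hypothetical; only the `min(order, steps + 1 - step)` rule comes from the diff):

```python
# Standalone illustration of the `lower_order_final` rule added in this commit:
# for steps < 15, the final multistep updates drop to lower order.
def multistep_orders(steps: int, order: int, lower_order_final: bool = True):
    orders = []
    for step in range(order, steps + 1):
        if lower_order_final and steps < 15:
            orders.append(min(order, steps + 1 - step))
        else:
            orders.append(order)
    return orders

print(multistep_orders(10, 2))   # [2, 2, 2, 2, 2, 2, 2, 2, 1]
print(multistep_orders(20, 2))   # [2, 2, ..., 2] -- unchanged for steps >= 15
```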

ldm/models/diffusion/dpm_solver/dpm_solver.py CHANGED
@@ -394,8 +394,8 @@ class DPM_Solver:
         if self.thresholding:
             p = 0.995 # A hyperparameter in the paper of "Imagen" [1].
             s = torch.quantile(torch.abs(x0).reshape((x0.shape[0], -1)), p, dim=1)
-            s = expand_dims(torch.maximum(s, torch.ones_like(s).to(s.device)), dims)
-            x0 = torch.clamp(x0, -s, s) / (s / self.max_val)
+            s = expand_dims(torch.maximum(s, self.max_val * torch.ones_like(s).to(s.device)), dims)
+            x0 = torch.clamp(x0, -s, s) / s
         return x0
 
     def model_fn(self, x, t):
@@ -436,7 +436,7 @@ class DPM_Solver:
         else:
             raise ValueError("Unsupported skip_type {}, need to be 'logSNR' or 'time_uniform' or 'time_quadratic'".format(skip_type))
 
-    def get_orders_for_singlestep_solver(self, steps, order):
+    def get_orders_and_timesteps_for_singlestep_solver(self, steps, order, skip_type, t_T, t_0, device):
         """
         Get the order of each step for sampling by the singlestep DPM-Solver.
 
@@ -458,6 +458,13 @@ class DPM_Solver:
         Args:
             order: A `int`. The max order for the solver (2 or 3).
            steps: A `int`. The total number of function evaluations (NFE).
+            skip_type: A `str`. The type for the spacing of the time steps. We support three types:
+                - 'logSNR': uniform logSNR for the time steps.
+                - 'time_uniform': uniform time for the time steps. (**Recommended for high-resolutional data**.)
+                - 'time_quadratic': quadratic time for the time steps. (Used in DDIM for low-resolutional data.)
+            t_T: A `float`. The starting time of the sampling (default is T).
+            t_0: A `float`. The ending time of the sampling (default is epsilon).
+            device: A torch device.
         Returns:
             orders: A list of the solver order of each step.
         """
@@ -469,20 +476,26 @@ class DPM_Solver:
                 orders = [3,] * (K - 1) + [1]
             else:
                 orders = [3,] * (K - 1) + [2]
-            return orders
         elif order == 2:
-            K = steps // 2
             if steps % 2 == 0:
+                K = steps // 2
                 orders = [2,] * K
             else:
-                orders = [2,] * K + [1]
-            return orders
+                K = steps // 2 + 1
+                orders = [2,] * (K - 1) + [1]
         elif order == 1:
-            return [1,] * steps
+            K = 1
+            orders = [1,] * steps
         else:
             raise ValueError("'order' must be '1' or '2' or '3'.")
+        if skip_type == 'logSNR':
+            # To reproduce the results in DPM-Solver paper
+            timesteps_outer = self.get_time_steps(skip_type, t_T, t_0, K, device)
+        else:
+            timesteps_outer = self.get_time_steps(skip_type, t_T, t_0, steps, device)[torch.cumsum(torch.tensor([0,] + orders)).to(device)]
+        return timesteps_outer, orders
 
-    def denoise_fn(self, x, s):
+    def denoise_to_zero_fn(self, x, s):
         """
         Denoise at the final step, which is equivalent to solve the ODE from lambda_s to infty by first-order discretization.
         """
@@ -950,8 +963,8 @@ class DPM_Solver:
         return x
 
     def sample(self, x, steps=20, t_start=None, t_end=None, order=3, skip_type='time_uniform',
-        method='singlestep', denoise=False, solver_type='dpm_solver', atol=0.0078,
-        rtol=0.05,
+        method='singlestep', lower_order_final=True, denoise_to_zero=False, solver_type='dpm_solver',
+        atol=0.0078, rtol=0.05,
     ):
         """
         Compute the sample at time `t_end` by DPM-Solver, given the initial `x` at time `t_start`.
@@ -1035,8 +1048,19 @@ class DPM_Solver:
             order: A `int`. The order of DPM-Solver.
             skip_type: A `str`. The type for the spacing of the time steps. 'time_uniform' or 'logSNR' or 'time_quadratic'.
             method: A `str`. The method for sampling. 'singlestep' or 'multistep' or 'singlestep_fixed' or 'adaptive'.
-            denoise: A `bool`. Whether to denoise at the final step. Default is False.
-                If `denoise` is True, the total NFE is (`steps` + 1).
+            denoise_to_zero: A `bool`. Whether to denoise to time 0 at the final step.
+                Default is `False`. If `denoise_to_zero` is `True`, the total NFE is (`steps` + 1).
+
+                This trick is firstly proposed by DDPM (https://arxiv.org/abs/2006.11239) and
+                score_sde (https://arxiv.org/abs/2011.13456). Such trick can improve the FID
+                for diffusion models sampling by diffusion SDEs for low-resolutional images
+                (such as CIFAR-10). However, we observed that such trick does not matter for
+                high-resolutional images. As it needs an additional NFE, we do not recommend
+                it for high-resolutional images.
+            lower_order_final: A `bool`. Whether to use lower order solvers at the final steps.
+                Only valid for `method=multistep` and `steps < 15`. We empirically find that
+                this trick is a key to stabilizing the sampling by DPM-Solver with very few steps
+                (especially for steps <= 10). So we recommend to set it to be `True`.
             solver_type: A `str`. The taylor expansion type for the solver. `dpm_solver` or `taylor`. We recommend `dpm_solver`.
             atol: A `float`. The absolute tolerance of the adaptive step size solver. Valid when `method` == 'adaptive'.
             rtol: A `float`. The relative tolerance of the adaptive step size solver. Valid when `method` == 'adaptive'.
@@ -1067,7 +1091,11 @@ class DPM_Solver:
                 # Compute the remaining values by `order`-th order multistep DPM-Solver.
                 for step in range(order, steps + 1):
                     vec_t = timesteps[step].expand(x.shape[0])
-                    x = self.multistep_dpm_solver_update(x, model_prev_list, t_prev_list, vec_t, order, solver_type=solver_type)
+                    if lower_order_final and steps < 15:
+                        step_order = min(order, steps + 1 - step)
+                    else:
+                        step_order = order
+                    x = self.multistep_dpm_solver_update(x, model_prev_list, t_prev_list, vec_t, step_order, solver_type=solver_type)
                     for i in range(order - 1):
                         t_prev_list[i] = t_prev_list[i + 1]
                         model_prev_list[i] = model_prev_list[i + 1]
@@ -1077,23 +1105,22 @@ class DPM_Solver:
                         model_prev_list[-1] = self.model_fn(x, vec_t)
         elif method in ['singlestep', 'singlestep_fixed']:
             if method == 'singlestep':
-                orders = self.get_orders_for_singlestep_solver(steps=steps, order=order)
-                timesteps = self.get_time_steps(skip_type=skip_type, t_T=t_T, t_0=t_0, N=steps, device=device)
+                timesteps_outer, orders = self.get_orders_and_timesteps_for_singlestep_solver(steps=steps, order=order, skip_type=skip_type, t_T=t_T, t_0=t_0, device=device)
             elif method == 'singlestep_fixed':
                 K = steps // order
                 orders = [order,] * K
-                timesteps = self.get_time_steps(skip_type=skip_type, t_T=t_T, t_0=t_0, N=(K * order), device=device)
-            with torch.no_grad():
-                i = 0
-                for order in orders:
-                    vec_s, vec_t = timesteps[i].expand(x.shape[0]), timesteps[i + order].expand(x.shape[0])
-                    h = self.noise_schedule.marginal_lambda(timesteps[i + order]) - self.noise_schedule.marginal_lambda(timesteps[i])
-                    r1 = None if order <= 1 else (self.noise_schedule.marginal_lambda(timesteps[i + 1]) - self.noise_schedule.marginal_lambda(timesteps[i])) / h
-                    r2 = None if order <= 2 else (self.noise_schedule.marginal_lambda(timesteps[i + 2]) - self.noise_schedule.marginal_lambda(timesteps[i])) / h
-                    x = self.singlestep_dpm_solver_update(x, vec_s, vec_t, order, solver_type=solver_type, r1=r1, r2=r2)
-                    i += order
-        if denoise:
-            x = self.denoise_fn(x, torch.ones((x.shape[0],)).to(device) * t_0)
+                timesteps_outer = self.get_time_steps(skip_type=skip_type, t_T=t_T, t_0=t_0, N=K, device=device)
+            for i, order in enumerate(orders):
+                t_T_inner, t_0_inner = timesteps_outer[i], timesteps_outer[i + 1]
+                timesteps_inner = self.get_time_steps(skip_type=skip_type, t_T=t_T_inner.item(), t_0=t_0_inner.item(), N=order, device=device)
+                lambda_inner = self.noise_schedule.marginal_lambda(timesteps_inner)
+                vec_s, vec_t = t_T_inner.tile(x.shape[0]), t_0_inner.tile(x.shape[0])
+                h = lambda_inner[-1] - lambda_inner[0]
+                r1 = None if order <= 1 else (lambda_inner[1] - lambda_inner[0]) / h
+                r2 = None if order <= 2 else (lambda_inner[2] - lambda_inner[0]) / h
+                x = self.singlestep_dpm_solver_update(x, vec_s, vec_t, order, solver_type=solver_type, r1=r1, r2=r2)
+        if denoise_to_zero:
+            x = self.denoise_to_zero_fn(x, torch.ones((x.shape[0],)).to(device) * t_0)
         return x
 
 
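Two of the hunks above are easier to read in isolation. First, the updated thresholding branch is Imagen-style dynamic thresholding; a self-contained sketch of the new behaviour (`dynamic_thresholding` is my own helper, written for the common `max_val=1.0` case, with the file's `expand_dims` reshape inlined):

```python
import torch

def dynamic_thresholding(x0: torch.Tensor, p: float = 0.995, max_val: float = 1.0) -> torch.Tensor:
    """Sketch of the updated thresholding branch: clamp each sample of x0 to the
    p-quantile of its absolute values (never below max_val) and rescale into [-1, 1]."""
    s = torch.quantile(torch.abs(x0).reshape((x0.shape[0], -1)), p, dim=1)
    s = torch.maximum(s, max_val * torch.ones_like(s))
    s = s.reshape((-1,) + (1,) * (x0.dim() - 1))   # inlined expand_dims: broadcast back to x0
    return torch.clamp(x0, -s, s) / s

x0 = 3.0 * torch.randn(4, 3, 8, 8)                 # toy batch of out-of-range predictions
print(dynamic_thresholding(x0).abs().max() <= 1.0) # tensor(True)
```

Second, in the non-logSNR case the new `get_orders_and_timesteps_for_singlestep_solver` picks the outer step boundaries out of a `steps + 1`-point fine grid at the cumulative order offsets. A toy sketch of that indexing, where a linear grid stands in for `get_time_steps` (note that `torch.cumsum` needs an explicit `dim` argument):

```python
import torch

steps, orders = 5, [2, 2, 1]                       # order=2, steps=5 schedule from the hunk above
fine_grid = torch.linspace(1.0, 1e-3, steps + 1)   # stand-in for self.get_time_steps(...)

idx = torch.cumsum(torch.tensor([0] + orders), dim=0)   # tensor([0, 2, 4, 5])
timesteps_outer = fine_grid[idx]                        # boundaries of the 3 single-step updates
print(timesteps_outer)
```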
 
ldm/models/diffusion/dpm_solver/sampler.py CHANGED
@@ -77,6 +77,6 @@ class DPMSolverSampler(object):
         )
 
         dpm_solver = DPM_Solver(model_fn, ns, predict_x0=True, thresholding=False)
-        x = dpm_solver.sample(img, steps=S, skip_type="time_uniform", method="multistep", order=2)
+        x = dpm_solver.sample(img, steps=S, skip_type="time_uniform", method="multistep", order=2, lower_order_final=True)
 
         return x.to(device), None
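For completeness, a hedged sketch of how the updated `sample` signature is intended to be called in the few-step regime; `model_fn`, `ns`, and `img` are assumed to be set up exactly as in the `sampler.py` code above, so this is not a runnable standalone snippet:

```python
# Sketch only: model_fn, ns and img are assumed to come from the surrounding
# DPMSolverSampler.sample code shown above.
dpm_solver = DPM_Solver(model_fn, ns, predict_x0=True, thresholding=False)

x = dpm_solver.sample(
    img,
    steps=10,                  # few-step regime where the new trick matters
    skip_type="time_uniform",
    method="multistep",
    order=2,
    lower_order_final=True,    # new: lower-order updates on the last steps when steps < 15
    denoise_to_zero=False,     # renamed from the old `denoise` flag; keeps NFE == steps
)
```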