Dhenenjay commited on
Commit
d52c538
·
verified ·
1 Parent(s): 1ab92c9

Upload diffusion.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. diffusion.py +305 -68
diffusion.py CHANGED
@@ -1,12 +1,11 @@
1
- """E3Diff Gaussian Diffusion - exact copy from original with fixed imports."""
2
-
3
  import math
4
  import torch
5
- from torch import nn
6
  import torch.nn.functional as F
7
  from inspect import isfunction
8
  from functools import partial
9
  import numpy as np
 
10
 
11
 
12
  def _warmup_beta(linear_start, linear_end, n_timestep, warmup_frac):
@@ -25,14 +24,19 @@ def make_beta_schedule(schedule, n_timestep, linear_start=1e-4, linear_end=2e-2,
25
  betas = np.linspace(linear_start, linear_end,
26
  n_timestep, dtype=np.float64)
27
  elif schedule == 'warmup10':
28
- betas = _warmup_beta(linear_start, linear_end, n_timestep, 0.1)
 
29
  elif schedule == 'warmup50':
30
- betas = _warmup_beta(linear_start, linear_end, n_timestep, 0.5)
 
31
  elif schedule == 'const':
32
  betas = linear_end * np.ones(n_timestep, dtype=np.float64)
33
- elif schedule == 'jsd':
34
- betas = 1. / np.linspace(n_timestep, 1, n_timestep, dtype=np.float64)
 
35
  elif schedule == "cosine":
 
 
36
  timesteps = (
37
  torch.arange(n_timestep + 1, dtype=torch.float64) /
38
  n_timestep + cosine_s
@@ -47,6 +51,8 @@ def make_beta_schedule(schedule, n_timestep, linear_start=1e-4, linear_end=2e-2,
47
  return betas
48
 
49
 
 
 
50
  def exists(x):
51
  return x is not None
52
 
@@ -67,8 +73,9 @@ class GaussianDiffusion(nn.Module):
67
  conditional=True,
68
  schedule_opt=None,
69
  xT_noise_r=0.1,
70
- seed=1,
71
  opt=None
 
72
  ):
73
  super().__init__()
74
  self.lq_noiselevel_val = schedule_opt["lq_noiselevel"]
@@ -81,6 +88,9 @@ class GaussianDiffusion(nn.Module):
81
  self.ddim = schedule_opt['ddim']
82
  self.xT_noise_r = xT_noise_r
83
  self.seed = seed
 
 
 
84
 
85
  def set_loss(self, device):
86
  if self.loss_type == 'l1':
@@ -88,7 +98,51 @@ class GaussianDiffusion(nn.Module):
88
  elif self.loss_type == 'l2':
89
  self.loss_func = nn.MSELoss(reduction='sum').to(device)
90
  else:
91
- raise NotImplementedError()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
92
 
93
  def set_new_noise_schedule(self, schedule_opt, device, num_train_timesteps=1000):
94
  self.ddim = schedule_opt['ddim']
@@ -96,10 +150,10 @@ class GaussianDiffusion(nn.Module):
96
  to_torch = partial(torch.tensor, dtype=torch.float32, device=device)
97
 
98
  betas = make_beta_schedule(
99
- schedule=schedule_opt['schedule'],
100
- n_timestep=num_train_timesteps,
101
- linear_start=schedule_opt['linear_start'],
102
- linear_end=schedule_opt['linear_end'])
103
  betas = betas.detach().cpu().numpy() if isinstance(
104
  betas, torch.Tensor) else betas
105
  alphas = 1. - betas
@@ -112,28 +166,43 @@ class GaussianDiffusion(nn.Module):
112
  self.num_timesteps = int(timesteps)
113
  self.register_buffer('betas', to_torch(betas))
114
  self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
115
- self.register_buffer('alphas_cumprod_prev', to_torch(alphas_cumprod_prev))
 
116
 
117
  # calculations for diffusion q(x_t | x_{t-1}) and others
118
- self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod)))
119
- self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod)))
120
- self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod)))
121
- self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod)))
122
- self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod - 1)))
 
 
 
 
 
123
 
124
  # calculations for posterior q(x_{t-1} | x_t, x_0)
125
- posterior_variance = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod)
126
- self.register_buffer('posterior_variance', to_torch(posterior_variance))
 
 
 
 
127
  self.register_buffer('posterior_log_variance_clipped', to_torch(
128
  np.log(np.maximum(posterior_variance, 1e-20))))
129
  self.register_buffer('posterior_mean_coef1', to_torch(
130
  betas * np.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod)))
131
  self.register_buffer('posterior_mean_coef2', to_torch(
132
  (1. - alphas_cumprod_prev) * np.sqrt(alphas) / (1. - alphas_cumprod)))
133
-
134
  self.schedule_type = schedule_opt['schedule']
135
- if self.ddim > 0:
 
136
  self.ddim_num_steps = schedule_opt['n_timestep']
 
 
 
 
137
 
138
  def predict_start_from_noise(self, x_t, t, noise):
139
  return self.sqrt_recip_alphas_cumprod[t] * x_t - \
@@ -145,7 +214,7 @@ class GaussianDiffusion(nn.Module):
145
  posterior_log_variance_clipped = self.posterior_log_variance_clipped[t]
146
  return posterior_mean, posterior_log_variance_clipped
147
 
148
- def p_mean_variance(self, x, t, clip_denoised: bool, condition_x=None):
149
  batch_size = x.shape[0]
150
  noise_level = torch.FloatTensor(
151
  [self.sqrt_alphas_cumprod_prev[t+1]]).repeat(batch_size, 1).to(x.device)
@@ -163,73 +232,144 @@ class GaussianDiffusion(nn.Module):
163
  x_start=x_recon, x_t=x, t=t)
164
  return model_mean, posterior_log_variance, x_recon
165
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
166
  def ddim_sample(self, condition_x, img_or_shape, device, seed=1, img_s1=None):
167
- if self.schedule_type == 'linear':
 
 
 
168
  self.ddim_sampling_eta = 0.8
169
- simple_var = False
170
- threshold_x = False
171
- elif self.schedule_type == 'cosine':
172
  self.ddim_sampling_eta = 0.8
173
- simple_var = False
174
- threshold_x = False
175
 
176
- batch, total_timesteps, sampling_timesteps, eta = \
177
- img_or_shape[0], self.num_train_timesteps, \
178
- self.ddim_num_steps, self.ddim_sampling_eta
179
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
180
  noisy_img_s1 = None
181
 
 
 
 
 
 
182
  if simple_var:
183
  eta = 1
184
  ts = torch.linspace(total_timesteps, 0, (sampling_timesteps + 1)).to(device).to(torch.long)
185
 
186
  x = torch.randn(img_or_shape).to(device)
187
  batch_size = x.shape[0]
 
188
  imgs = [x]
189
- img_onestep = [condition_x[:, :self.channels, ...]]
190
-
191
- tbar = range(1, sampling_timesteps + 1)
192
- for i in tbar:
 
 
193
  cur_t = ts[i - 1] - 1
194
  prev_t = ts[i] - 1
195
  noise_level = torch.FloatTensor(
196
- [self.sqrt_alphas_cumprod_prev[cur_t]]).repeat(batch_size, 1).to(x.device)
 
 
197
 
198
  alpha_prod_t = self.alphas_cumprod[cur_t]
199
  alpha_prod_t_prev = self.alphas_cumprod[prev_t] if prev_t >= 0 else 1
200
  beta_prod_t = 1 - alpha_prod_t
201
 
 
 
202
  # pred noise
203
- model_output = self.denoise_fn(torch.cat([condition_x, x], dim=1), noise_level)
204
 
205
  sigma_2 = eta * (1 - alpha_prod_t_prev) / (1 - alpha_prod_t) * (1 - alpha_prod_t / alpha_prod_t_prev)
206
  noise = torch.randn_like(x)
207
 
 
 
 
208
  pred_original_sample = (x - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
209
-
 
210
  if threshold_x:
211
  pred_original_sample = self._threshold_sample(pred_original_sample)
212
  else:
213
  pred_original_sample = pred_original_sample.clamp(-1, 1)
214
-
215
  pred_sample_direction = (1 - alpha_prod_t_prev - sigma_2) ** (0.5) * model_output
 
 
216
 
217
  if simple_var:
218
- third_term = (1 - alpha_prod_t / alpha_prod_t_prev) ** 0.5 * noise
219
  else:
220
- third_term = sigma_2 ** 0.5 * noise
221
-
222
  x = alpha_prod_t_prev ** (0.5) * pred_original_sample + pred_sample_direction + third_term
223
  imgs.append(x)
224
  img_onestep.append(pred_original_sample)
225
 
226
- imgs = torch.concat(imgs, dim=0)
227
- img_onestep = torch.concat(img_onestep, dim=0)
228
 
229
- return imgs, img_onestep
 
 
230
 
231
  @torch.no_grad()
232
- def p_sample(self, x, t, clip_denoised=True, condition_x=None):
233
  model_mean, model_log_variance, x_recon = self.p_mean_variance(
234
  x=x, t=t, clip_denoised=clip_denoised, condition_x=condition_x)
235
  noise = torch.randn_like(x) if t > 0 else torch.zeros_like(x)
@@ -238,85 +378,182 @@ class GaussianDiffusion(nn.Module):
238
  @torch.no_grad()
239
  def p_sample_loop(self, x_in, continous=False, seed=1, img_s1=None):
240
  device = self.betas.device
 
241
  sample_inter = 1
242
-
243
  if not self.conditional:
244
  shape = x_in
245
  img = torch.randn(shape, device=device)
246
  ret_img = img
247
  if not self.ddim:
248
- for i in reversed(range(0, self.num_timesteps)):
249
  img, x_recon = self.p_sample(img, i)
250
  if i % sample_inter == 0:
251
  ret_img = torch.cat([ret_img, img], dim=0)
252
  else:
253
- for i in range(0, len(self.ddim_timesteps)):
254
  ddim_t = self.ddim_timesteps[i]
255
  img = self.ddim_sample(img, ddim_t)
256
  if i % sample_inter == 0:
257
  ret_img = torch.cat([ret_img, img], dim=0)
 
258
  else:
259
  x = x_in
260
  shape = (x.shape[0], self.channels, x.shape[-2], x.shape[-1])
261
 
262
- if self.xT_noise_r > 0:
 
 
 
263
  img0 = torch.randn(shape, device=device)
264
  x_start = x_in[:, 0:1, ...]
265
  continuous_sqrt_alpha_cumprod = torch.FloatTensor(
266
- np.random.uniform(
267
- self.sqrt_alphas_cumprod_prev[self.num_timesteps-1],
268
- self.sqrt_alphas_cumprod_prev[self.num_timesteps],
269
- size=x_start.shape[0]
270
- )).to(x_start.device)
271
  continuous_sqrt_alpha_cumprod = continuous_sqrt_alpha_cumprod.view(x_start.shape[0], -1)
272
-
273
  noise = default(x_start, lambda: torch.randn_like(x_start))
274
  img = self.q_sample(
275
- x_start=x_start, continuous_sqrt_alpha_cumprod=continuous_sqrt_alpha_cumprod.view(-1, 1, 1, 1), noise=noise)
276
- img = self.xT_noise_r * img + (1 - self.xT_noise_r) * img0
 
 
 
277
  else:
278
  img = torch.randn(shape, device=device)
279
 
280
  ret_img = x
281
  img_onestep = x
282
 
283
- if self.opt['stage'] != 2:
284
  if not self.ddim:
285
- for i in reversed(range(0, self.num_timesteps)):
286
  img, x_recon = self.p_sample(img, i, condition_x=x)
287
  if i % sample_inter == 0:
288
- ret_img = torch.cat([ret_img[:, :self.channels, ...], img], dim=0)
289
- if i % sample_inter == 0 or i == self.num_timesteps - 1:
290
- img_onestep = torch.cat([img_onestep[:, :self.channels, ...], x_recon], dim=0)
 
291
  else:
292
  ret_img, img_onestep = self.ddim_sample(condition_x=x, img_or_shape=shape, device=device, seed=seed, img_s1=img_s1)
293
-
 
294
  if continous:
295
  return ret_img, img_onestep
296
  else:
297
  return ret_img[-x_in.shape[0]:], img_onestep
298
  else:
 
299
  self.ddim_num_steps = self.opt['ddim_steps']
300
  ret_img, img_onestep = self.ddim_sample(condition_x=x, img_or_shape=shape, device=device, seed=seed, img_s1=img_s1)
301
 
 
 
 
 
 
302
  if continous:
303
  return ret_img, img_onestep
304
  else:
305
  return ret_img[-x_in.shape[0]:], img_onestep
306
 
 
 
 
 
 
 
 
 
 
 
 
307
  @torch.no_grad()
308
  def sample(self, batch_size=1, continous=False):
309
  image_size = self.image_size
310
  channels = self.channels
311
  return self.p_sample_loop((batch_size, channels, image_size, image_size), continous)
312
 
 
313
  @torch.no_grad()
314
- def super_resolution(self, x_in, continous=False, seed=1, img_s1=None):
 
315
  return self.p_sample_loop(x_in, continous, seed=seed, img_s1=img_s1)
316
 
 
 
 
 
 
317
  def q_sample(self, x_start, continuous_sqrt_alpha_cumprod, noise=None):
318
  noise = default(noise, lambda: torch.randn_like(x_start))
 
 
319
  return (
320
  continuous_sqrt_alpha_cumprod * x_start +
321
- (1 - continuous_sqrt_alpha_cumprod ** 2).sqrt() * noise
322
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import math
2
  import torch
3
+ from torch import device, nn, einsum
4
  import torch.nn.functional as F
5
  from inspect import isfunction
6
  from functools import partial
7
  import numpy as np
8
+ from tqdm import tqdm
9
 
10
 
11
  def _warmup_beta(linear_start, linear_end, n_timestep, warmup_frac):
 
24
  betas = np.linspace(linear_start, linear_end,
25
  n_timestep, dtype=np.float64)
26
  elif schedule == 'warmup10':
27
+ betas = _warmup_beta(linear_start, linear_end,
28
+ n_timestep, 0.1)
29
  elif schedule == 'warmup50':
30
+ betas = _warmup_beta(linear_start, linear_end,
31
+ n_timestep, 0.5)
32
  elif schedule == 'const':
33
  betas = linear_end * np.ones(n_timestep, dtype=np.float64)
34
+ elif schedule == 'jsd': # 1/T, 1/(T-1), 1/(T-2), ..., 1
35
+ betas = 1. / np.linspace(n_timestep,
36
+ 1, n_timestep, dtype=np.float64)
37
  elif schedule == "cosine":
38
+ print('======================adopting cosine scheduler========================')
39
+
40
  timesteps = (
41
  torch.arange(n_timestep + 1, dtype=torch.float64) /
42
  n_timestep + cosine_s
 
51
  return betas
52
 
53
 
54
+ # gaussian diffusion trainer class
55
+
56
  def exists(x):
57
  return x is not None
58
 
 
73
  conditional=True,
74
  schedule_opt=None,
75
  xT_noise_r=0.1,
76
+ seed = 1,
77
  opt=None
78
+
79
  ):
80
  super().__init__()
81
  self.lq_noiselevel_val = schedule_opt["lq_noiselevel"]
 
88
  self.ddim = schedule_opt['ddim']
89
  self.xT_noise_r = xT_noise_r
90
  self.seed = seed
91
+ if schedule_opt is not None:
92
+ pass
93
+ # self.set_new_noise_schedule(schedule_opt)
94
 
95
  def set_loss(self, device):
96
  if self.loss_type == 'l1':
 
98
  elif self.loss_type == 'l2':
99
  self.loss_func = nn.MSELoss(reduction='sum').to(device)
100
  else:
101
+ raise NotImplementedError()
102
+
103
+
104
+
105
+
106
+ def betas_for_alpha_bar(
107
+ num_diffusion_timesteps,
108
+ max_beta=0.999,
109
+ alpha_transform_type="cosine",
110
+ ):
111
+ """
112
+ Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
113
+ (1-beta) over time from t = [0,1].
114
+
115
+ Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up
116
+ to that part of the diffusion process.
117
+
118
+ Args:
119
+ num_diffusion_timesteps (`int`): the number of betas to produce.
120
+ max_beta (`float`): the maximum beta to use; use values lower than 1 to
121
+ prevent singularities.
122
+ alpha_transform_type (`str`, *optional*, default to `cosine`): the type of noise schedule for alpha_bar.
123
+ Choose from `cosine` or `exp`
124
+
125
+ Returns:
126
+ betas (`np.ndarray`): the betas used by the scheduler to step the model outputs
127
+ """
128
+ if alpha_transform_type == "cosine":
129
+ def alpha_bar_fn(t):
130
+ return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2
131
+ elif alpha_transform_type == "exp":
132
+
133
+ def alpha_bar_fn(t):
134
+ return math.exp(t * -12.0)
135
+ else:
136
+ raise ValueError(f"Unsupported alpha_tranform_type: {alpha_transform_type}")
137
+
138
+ betas = []
139
+ for i in range(num_diffusion_timesteps):
140
+ t1 = i / num_diffusion_timesteps
141
+ t2 = (i + 1) / num_diffusion_timesteps
142
+ betas.append(min(1 - alpha_bar_fn(t2) / alpha_bar_fn(t1), max_beta))
143
+ return torch.tensor(betas, dtype=torch.float32)
144
+
145
+
146
 
147
  def set_new_noise_schedule(self, schedule_opt, device, num_train_timesteps=1000):
148
  self.ddim = schedule_opt['ddim']
 
150
  to_torch = partial(torch.tensor, dtype=torch.float32, device=device)
151
 
152
  betas = make_beta_schedule(
153
+ schedule=schedule_opt['schedule'],
154
+ n_timestep=num_train_timesteps,
155
+ linear_start=schedule_opt['linear_start'],
156
+ linear_end=schedule_opt['linear_end'])
157
  betas = betas.detach().cpu().numpy() if isinstance(
158
  betas, torch.Tensor) else betas
159
  alphas = 1. - betas
 
166
  self.num_timesteps = int(timesteps)
167
  self.register_buffer('betas', to_torch(betas))
168
  self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
169
+ self.register_buffer('alphas_cumprod_prev',
170
+ to_torch(alphas_cumprod_prev))
171
 
172
  # calculations for diffusion q(x_t | x_{t-1}) and others
173
+ self.register_buffer('sqrt_alphas_cumprod',
174
+ to_torch(np.sqrt(alphas_cumprod)))
175
+ self.register_buffer('sqrt_one_minus_alphas_cumprod',
176
+ to_torch(np.sqrt(1. - alphas_cumprod)))
177
+ self.register_buffer('log_one_minus_alphas_cumprod',
178
+ to_torch(np.log(1. - alphas_cumprod)))
179
+ self.register_buffer('sqrt_recip_alphas_cumprod',
180
+ to_torch(np.sqrt(1. / alphas_cumprod)))
181
+ self.register_buffer('sqrt_recipm1_alphas_cumprod',
182
+ to_torch(np.sqrt(1. / alphas_cumprod - 1)))
183
 
184
  # calculations for posterior q(x_{t-1} | x_t, x_0)
185
+ posterior_variance = betas * \
186
+ (1. - alphas_cumprod_prev) / (1. - alphas_cumprod)
187
+ # above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t)
188
+ self.register_buffer('posterior_variance',
189
+ to_torch(posterior_variance))
190
+ # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
191
  self.register_buffer('posterior_log_variance_clipped', to_torch(
192
  np.log(np.maximum(posterior_variance, 1e-20))))
193
  self.register_buffer('posterior_mean_coef1', to_torch(
194
  betas * np.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod)))
195
  self.register_buffer('posterior_mean_coef2', to_torch(
196
  (1. - alphas_cumprod_prev) * np.sqrt(alphas) / (1. - alphas_cumprod)))
197
+
198
  self.schedule_type = schedule_opt['schedule']
199
+ if self.ddim>0: # use ddim
200
+ print('================ddim scheduler is adopted===================')
201
  self.ddim_num_steps = schedule_opt['n_timestep']
202
+ print('==========ddim sampling steps: {}==========='.format(self.ddim_num_steps))
203
+
204
+
205
+
206
 
207
  def predict_start_from_noise(self, x_t, t, noise):
208
  return self.sqrt_recip_alphas_cumprod[t] * x_t - \
 
214
  posterior_log_variance_clipped = self.posterior_log_variance_clipped[t]
215
  return posterior_mean, posterior_log_variance_clipped
216
 
217
+ def p_mean_variance(self, x, t, clip_denoised: bool, condition_x=None): # ddpm sample
218
  batch_size = x.shape[0]
219
  noise_level = torch.FloatTensor(
220
  [self.sqrt_alphas_cumprod_prev[t+1]]).repeat(batch_size, 1).to(x.device)
 
232
  x_start=x_recon, x_t=x, t=t)
233
  return model_mean, posterior_log_variance, x_recon
234
 
235
+
236
+
237
+
238
+ # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
239
+ def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor:
240
+ """
241
+ "Dynamic thresholding: At each sampling step we set s to a certain percentile absolute pixel value in xt0 (the
242
+ prediction of x_0 at timestep t), and if s > 1, then we threshold xt0 to the range [-s, s] and then divide by
243
+ s. Dynamic thresholding pushes saturated pixels (those near -1 and 1) inwards, thereby actively preventing
244
+ pixels from saturation at each step. We find that dynamic thresholding results in significantly better
245
+ photorealism as well as better image-text alignment, especially when using very large guidance weights."
246
+
247
+ https://arxiv.org/abs/2205.11487
248
+ """
249
+ dtype = sample.dtype
250
+ batch_size, channels, *remaining_dims = sample.shape
251
+
252
+ if dtype not in (torch.float32, torch.float64):
253
+ sample = sample.float() # upcast for quantile calculation, and clamp not implemented for cpu half
254
+
255
+ # Flatten sample for doing quantile calculation along each image
256
+ sample = sample.reshape(batch_size, channels * np.prod(remaining_dims))
257
+
258
+ abs_sample = sample.abs() # "a certain percentile absolute pixel value"
259
+
260
+ s = torch.quantile(abs_sample, 0.995, dim=1)
261
+ s = torch.clamp(s, min=1, max=1.0) # When clamped to min=1, equivalent to standard clipping to [-1, 1]
262
+ s = s.unsqueeze(1) # (batch_size, 1) because clamp will broadcast along dim=0
263
+ sample = torch.clamp(sample, -s, s) / s # "we threshold xt0 to the range [-s, s] and then divide by s"
264
+
265
+ sample = sample.reshape(batch_size, channels, *remaining_dims)
266
+ sample = sample.to(dtype)
267
+
268
+ return sample
269
+
270
  def ddim_sample(self, condition_x, img_or_shape, device, seed=1, img_s1=None):
271
+ # self.device = torch.device('cuda:0')
272
+ # self.num_train_timesteps = 2000
273
+ # self.ddim_num_steps = 50
274
+ if self.schedule_type=='linear':
275
  self.ddim_sampling_eta = 0.8
276
+ simple_var=False
277
+ threshold_x = False # threshold_x 和 clip_x
278
+ elif self.schedule_type=='cosine':
279
  self.ddim_sampling_eta = 0.8
280
+ simple_var=False
 
281
 
282
+ threshold_x = False
 
 
283
 
284
+ # torch.manual_seed(seed)
285
+ batch, total_timesteps, sampling_timesteps, eta= \
286
+ img_or_shape[0], self.num_train_timesteps, \
287
+ self.ddim_num_steps, self.ddim_sampling_eta
288
+ # ----------------------------------------------------------------
289
+
290
+ #----------------conditioned augmentation------------------
291
+ # max_noise_level = 400
292
+ # b = img_s1.shape[0]
293
+ # low_res_noise = torch.randn_like(img_s1).to(img_s1.device)
294
+ # low_res_timesteps = self.lq_noiselevel_val #
295
+ # lq_noise_level = torch.FloatTensor(
296
+ # [self.sqrt_alphas_cumprod_prev[low_res_timesteps]]).repeat(b, 1).to(img_s1.device)
297
+
298
+ # noisy_img_s1 = self.q_sample(
299
+ # x_start=img_s1, continuous_sqrt_alpha_cumprod=lq_noise_level.view(-1, 1, 1, 1), noise=low_res_noise)
300
  noisy_img_s1 = None
301
 
302
+ #----------------------------------------------------
303
+
304
+
305
+
306
+
307
  if simple_var:
308
  eta = 1
309
  ts = torch.linspace(total_timesteps, 0, (sampling_timesteps + 1)).to(device).to(torch.long)
310
 
311
  x = torch.randn(img_or_shape).to(device)
312
  batch_size = x.shape[0]
313
+ # net = self.denoise_fn
314
  imgs = [x]
315
+ img_onestep = [condition_x[:,:self.channels,...]]
316
+ if self.opt['stage']!=2:
317
+ tbar = tqdm(range(1, sampling_timesteps + 1),f'seed{seed} DDIM sampling ({self.schedule_type}) with eta {eta} simple_var {simple_var}')
318
+ else:
319
+ tbar = range(1, sampling_timesteps + 1)
320
+ for i in tbar:
321
  cur_t = ts[i - 1] - 1
322
  prev_t = ts[i] - 1
323
  noise_level = torch.FloatTensor(
324
+ # [self.sqrt_alphas_cumprod_prev[cur_t+1]]).repeat(batch_size, 1).to(x.device)
325
+ [self.sqrt_alphas_cumprod_prev[cur_t]]).repeat(batch_size, 1).to(x.device)
326
+
327
 
328
  alpha_prod_t = self.alphas_cumprod[cur_t]
329
  alpha_prod_t_prev = self.alphas_cumprod[prev_t] if prev_t >= 0 else 1
330
  beta_prod_t = 1 - alpha_prod_t
331
 
332
+ # t_tensor = torch.tensor([cur_t] * batch_size,
333
+ # dtype=torch.long).to(device).unsqueeze(1)
334
  # pred noise
335
+ model_output = self.denoise_fn(torch.cat([condition_x, x], dim=1), noise_level)
336
 
337
  sigma_2 = eta * (1 - alpha_prod_t_prev) / (1 - alpha_prod_t) * (1 - alpha_prod_t / alpha_prod_t_prev)
338
  noise = torch.randn_like(x)
339
 
340
+ # first_term = (alpha_prod_t_prev / alpha_prod_t)**0.5 * x
341
+ # second_term = ((1 - alpha_prod_t_prev - sigma_2)**0.5 -(alpha_prod_t_prev * (1 - alpha_prod_t) / alpha_prod_t)**0.5) * model_output
342
+ # x_start = first_term - (alpha_prod_t_prev * (1 - alpha_prod_t) / alpha_prod_t)**0.5 * model_output
343
  pred_original_sample = (x - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
344
+
345
+
346
  if threshold_x:
347
  pred_original_sample = self._threshold_sample(pred_original_sample)
348
  else:
349
  pred_original_sample = pred_original_sample.clamp(-1, 1)
350
+
351
  pred_sample_direction = (1 - alpha_prod_t_prev - sigma_2) ** (0.5) * model_output
352
+
353
+
354
 
355
  if simple_var:
356
+ third_term = (1 - alpha_prod_t / alpha_prod_t_prev)**0.5 * noise # var of ddpm
357
  else:
358
+ third_term = sigma_2**0.5 * noise #ddpm
359
+ # x = first_term + second_term + third_term
360
  x = alpha_prod_t_prev ** (0.5) * pred_original_sample + pred_sample_direction + third_term
361
  imgs.append(x)
362
  img_onestep.append(pred_original_sample)
363
 
364
+ imgs = torch.concat(imgs, dim = 0)
365
+ img_onestep = torch.concat(img_onestep, dim = 0)
366
 
367
+ # torch.seed()
368
+ return imgs, img_onestep
369
+
370
 
371
  @torch.no_grad()
372
+ def p_sample(self, x, t, clip_denoised=True, condition_x=None): # sr3 sample
373
  model_mean, model_log_variance, x_recon = self.p_mean_variance(
374
  x=x, t=t, clip_denoised=clip_denoised, condition_x=condition_x)
375
  noise = torch.randn_like(x) if t > 0 else torch.zeros_like(x)
 
378
  @torch.no_grad()
379
  def p_sample_loop(self, x_in, continous=False, seed=1, img_s1=None):
380
  device = self.betas.device
381
+ # sample_inter = (1 | (self.num_timesteps//20))
382
  sample_inter = 1
 
383
  if not self.conditional:
384
  shape = x_in
385
  img = torch.randn(shape, device=device)
386
  ret_img = img
387
  if not self.ddim:
388
+ for i in tqdm(reversed(range(0, self.num_timesteps)), desc='sampling loop time step', total=self.num_timesteps):
389
  img, x_recon = self.p_sample(img, i)
390
  if i % sample_inter == 0:
391
  ret_img = torch.cat([ret_img, img], dim=0)
392
  else:
393
+ for i in tqdm(range(0, len(self.ddim_timesteps)), desc='sampling loop time step', total=len(self.ddim_timesteps)):
394
  ddim_t = self.ddim_timesteps[i]
395
  img = self.ddim_sample(img, ddim_t)
396
  if i % sample_inter == 0:
397
  ret_img = torch.cat([ret_img, img], dim=0)
398
+
399
  else:
400
  x = x_in
401
  shape = (x.shape[0], self.channels, x.shape[-2], x.shape[-1])
402
 
403
+ # ---------ddpm zT as the inital noise------------------------------------
404
+ if self.xT_noise_r>0:
405
+ # ratio = 0.1
406
+ print('adopting ddpm inversion as initial noise, ratio is {}'.format(self.xT_noise_r))
407
  img0 = torch.randn(shape, device=device)
408
  x_start = x_in[:, 0:1, ...]
409
  continuous_sqrt_alpha_cumprod = torch.FloatTensor(
410
+ np.random.uniform(
411
+ self.sqrt_alphas_cumprod_prev[self.num_timesteps-1],
412
+ self.sqrt_alphas_cumprod_prev[self.num_timesteps],
413
+ size=x_start.shape[0]
414
+ )).to(x_start.device)
415
  continuous_sqrt_alpha_cumprod = continuous_sqrt_alpha_cumprod.view(x_start.shape[0], -1)
416
+
417
  noise = default(x_start, lambda: torch.randn_like(x_start))
418
  img = self.q_sample(
419
+ x_start=x_start, continuous_sqrt_alpha_cumprod=continuous_sqrt_alpha_cumprod.view(-1, 1, 1, 1), noise=noise)
420
+
421
+
422
+ img = self.xT_noise_r*img + (1-self.xT_noise_r)*img0
423
+ #-------------------------------------------------------------------------
424
  else:
425
  img = torch.randn(shape, device=device)
426
 
427
  ret_img = x
428
  img_onestep = x
429
 
430
+ if self.opt['stage']!=2:
431
  if not self.ddim:
432
+ for i in tqdm(reversed(range(0, self.num_timesteps)), desc='ddpm sampling loop time step', total=self.num_timesteps):
433
  img, x_recon = self.p_sample(img, i, condition_x=x)
434
  if i % sample_inter == 0:
435
+ ret_img = torch.cat([ret_img[:,:self.channels,...], img], dim=0)
436
+ if i % sample_inter==0 or i==self.num_timesteps-1:
437
+ img_onestep = torch.cat([img_onestep[:,:self.channels,...], x_recon], dim=0)
438
+
439
  else:
440
  ret_img, img_onestep = self.ddim_sample(condition_x=x, img_or_shape=shape, device=device, seed=seed, img_s1=img_s1)
441
+
442
+
443
  if continous:
444
  return ret_img, img_onestep
445
  else:
446
  return ret_img[-x_in.shape[0]:], img_onestep
447
  else:
448
+ # timestep = self.num_timesteps-1
449
  self.ddim_num_steps = self.opt['ddim_steps']
450
  ret_img, img_onestep = self.ddim_sample(condition_x=x, img_or_shape=shape, device=device, seed=seed, img_s1=img_s1)
451
 
452
+
453
+ # img, x_recon = self.p_sample(img, timestep, condition_x=x)
454
+ # ret_img = torch.cat([ret_img[:,:self.channels,...], x_recon], dim=0)
455
+ # img_onestep = torch.cat([img_onestep[:,:self.channels,...], x_recon], dim=0)
456
+
457
  if continous:
458
  return ret_img, img_onestep
459
  else:
460
  return ret_img[-x_in.shape[0]:], img_onestep
461
 
462
+ # for i in tqdm(range(0, len(self.ddim_timesteps)), desc='ddim sampling loop time step', total=len(self.ddim_timesteps)):
463
+ # ddim_t = self.ddim_timesteps[i]
464
+ # img = self.ddim_sample(img, ddim_t, condition_x=x)
465
+ # if i % sample_inter == 0:
466
+ # ret_img = torch.cat([ret_img[:,:self.channels,...], img], dim=0)
467
+
468
+
469
+ # 20, 8, 2hw
470
+
471
+
472
+
473
  @torch.no_grad()
474
  def sample(self, batch_size=1, continous=False):
475
  image_size = self.image_size
476
  channels = self.channels
477
  return self.p_sample_loop((batch_size, channels, image_size, image_size), continous)
478
 
479
+
480
  @torch.no_grad()
481
+ def super_resolution(self, x_in, continous=False, seed=1, img_s1=None): # test
482
+
483
  return self.p_sample_loop(x_in, continous, seed=seed, img_s1=img_s1)
484
 
485
+
486
+
487
+
488
+
489
+
490
  def q_sample(self, x_start, continuous_sqrt_alpha_cumprod, noise=None):
491
  noise = default(noise, lambda: torch.randn_like(x_start))
492
+
493
+ # random gama
494
  return (
495
  continuous_sqrt_alpha_cumprod * x_start +
496
+ (1 - continuous_sqrt_alpha_cumprod**2).sqrt() * noise
497
  )
498
+
499
+ def p_losses(self, x_in, noise=None):
500
+ # x_in {'HR': img_EO[0:1], 'LR': img_s1[0:1], 'condition': img_ppb[0:1], 'SR': img_s1[0:1], 'Index': index, 'filename':filename}
501
+ x_start = x_in['HR']
502
+
503
+
504
+
505
+ [b, c, h, w] = x_start.shape
506
+ if self.opt['stage'] ==2:
507
+ t = 999
508
+ self.ddim_num_steps = self.opt['ddim_steps']
509
+ x = x_in['SR']
510
+ shape = (x.shape[0], self.channels, x.shape[-2], x.shape[-1])
511
+ ret_img, img_onestep = self.ddim_sample(condition_x=x, img_or_shape=shape, device=x.device, seed=self.seed, img_s1=x)
512
+ x_recon = ret_img[-x.shape[0]:]
513
+
514
+
515
+ else:
516
+ t = np.random.randint(1, self.num_timesteps + 1)
517
+
518
+ continuous_sqrt_alpha_cumprod = torch.FloatTensor(
519
+ np.random.uniform(
520
+ self.sqrt_alphas_cumprod_prev[t-1],
521
+ self.sqrt_alphas_cumprod_prev[t],
522
+ size=b
523
+ )).to(x_start.device)
524
+ continuous_sqrt_alpha_cumprod = continuous_sqrt_alpha_cumprod.view(b, -1)
525
+
526
+ #-----------pixel loss-------------
527
+ noise = default(noise, lambda: torch.randn_like(x_start))
528
+ x_noisy = self.q_sample(
529
+ x_start=x_start, continuous_sqrt_alpha_cumprod=continuous_sqrt_alpha_cumprod.view(-1, 1, 1, 1), noise=noise)
530
+
531
+
532
+ ##low_res_timesteps in the paper, they present a new trick where they noise the lowres conditioning image, and at sample time, fix it to a certain level (0.1 or 0.3) - the unets are also made to be conditioned on this noise level
533
+ if not self.conditional:
534
+ x_recon = self.denoise_fn(x_noisy, continuous_sqrt_alpha_cumprod)
535
+ else:
536
+
537
+ x_recon, condition_feats = self.denoise_fn(
538
+ torch.cat([x_in['SR'], x_noisy], dim=1),
539
+ continuous_sqrt_alpha_cumprod,
540
+ # noisy_img_s1,
541
+ # class_label=lq_continuous_sqrt_alpha_cumprod,
542
+ return_condition=True
543
+ )
544
+ if self.opt['stage']==2:
545
+ l_pix = self.loss_func(x_start, x_recon)
546
+
547
+ else:
548
+ l_pix = self.loss_func(noise, x_recon)
549
+
550
+
551
+ x_pred = x_recon
552
+ condition_feats=None
553
+
554
+
555
+ return l_pix, x_start, x_pred, condition_feats, torch.tensor(t, device=l_pix.device)
556
+
557
+
558
+ def forward(self, x, *args, **kwargs):
559
+ return self.p_losses(x, *args, **kwargs)