raannakasturi committed on
Commit e91364a
1 Parent(s): 8cbc4b1

Delete ddpm.py

Files changed (1)
  1. ddpm.py +0 -1873
ddpm.py DELETED
@@ -1,1873 +0,0 @@
"""
wild mixture of
https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py
https://github.com/openai/improved-diffusion/blob/e94489283bb876ac1477d5dd7709bbbd2d9902ce/improved_diffusion/gaussian_diffusion.py
https://github.com/CompVis/taming-transformers
-- merci
"""

import torch
import torch.nn as nn
import numpy as np
import pytorch_lightning as pl
from torch.optim.lr_scheduler import LambdaLR
from einops import rearrange, repeat
from contextlib import contextmanager, nullcontext
from functools import partial
import itertools
from tqdm import tqdm
from torchvision.utils import make_grid
from pytorch_lightning.utilities.rank_zero import rank_zero_only
from omegaconf import ListConfig

from ldm.util import log_txt_as_img, exists, default, ismap, isimage, mean_flat, count_params, instantiate_from_config
from ldm.modules.ema import LitEma
from ldm.modules.distributions.distributions import normal_kl, DiagonalGaussianDistribution
from ldm.models.autoencoder import IdentityFirstStage, AutoencoderKL
from ldm.modules.diffusionmodules.util import make_beta_schedule, extract_into_tensor, noise_like
from ldm.models.diffusion.ddim import DDIMSampler


__conditioning_keys__ = {'concat': 'c_concat',
                         'crossattn': 'c_crossattn',
                         'adm': 'y'}


def disabled_train(self, mode=True):
    """Overwrite model.train with this function to make sure train/eval mode
    does not change anymore."""
    return self


def uniform_on_device(r1, r2, shape, device):
    return (r1 - r2) * torch.rand(*shape, device=device) + r2
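# Note: torch.rand samples in [0, 1), so this draws uniformly from [r2, r1) directly on `device`.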


class DDPM(pl.LightningModule):
    # classic DDPM with Gaussian diffusion, in image space
    def __init__(self,
                 unet_config,
                 timesteps=1000,
                 beta_schedule="linear",
                 loss_type="l2",
                 ckpt_path=None,
                 ignore_keys=[],
                 load_only_unet=False,
                 monitor="val/loss",
                 use_ema=True,
                 first_stage_key="image",
                 image_size=256,
                 channels=3,
                 log_every_t=100,
                 clip_denoised=True,
                 linear_start=1e-4,
                 linear_end=2e-2,
                 cosine_s=8e-3,
                 given_betas=None,
                 original_elbo_weight=0.,
                 v_posterior=0.,  # weight for choosing posterior variance as sigma = (1-v) * beta_tilde + v * beta
                 l_simple_weight=1.,
                 conditioning_key=None,
                 parameterization="eps",  # all assuming fixed variance schedules
                 scheduler_config=None,
                 use_positional_encodings=False,
                 learn_logvar=False,
                 logvar_init=0.,
                 make_it_fit=False,
                 ucg_training=None,
                 reset_ema=False,
                 reset_num_ema_updates=False,
                 ):
        super().__init__()
        assert parameterization in ["eps", "x0", "v"], 'currently only supporting "eps" and "x0" and "v"'
        self.parameterization = parameterization
        print(f"{self.__class__.__name__}: Running in {self.parameterization}-prediction mode")
        self.cond_stage_model = None
        self.clip_denoised = clip_denoised
        self.log_every_t = log_every_t
        self.first_stage_key = first_stage_key
        self.image_size = image_size  # try conv?
        self.channels = channels
        self.use_positional_encodings = use_positional_encodings
        self.model = DiffusionWrapper(unet_config, conditioning_key)
        count_params(self.model, verbose=True)
        self.use_ema = use_ema
        if self.use_ema:
            self.model_ema = LitEma(self.model)
            print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")

        self.use_scheduler = scheduler_config is not None
        if self.use_scheduler:
            self.scheduler_config = scheduler_config

        self.v_posterior = v_posterior
        self.original_elbo_weight = original_elbo_weight
        self.l_simple_weight = l_simple_weight

        if monitor is not None:
            self.monitor = monitor
        self.make_it_fit = make_it_fit
        if reset_ema: assert exists(ckpt_path)
        if ckpt_path is not None:
            self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys, only_model=load_only_unet)
            if reset_ema:
                assert self.use_ema
                print(f"Resetting ema to pure model weights. This is useful when restoring from an ema-only checkpoint.")
                self.model_ema = LitEma(self.model)
        if reset_num_ema_updates:
            print(" +++++++++++ WARNING: RESETTING NUM_EMA UPDATES TO ZERO +++++++++++ ")
            assert self.use_ema
            self.model_ema.reset_num_updates()

        self.register_schedule(given_betas=given_betas, beta_schedule=beta_schedule, timesteps=timesteps,
                               linear_start=linear_start, linear_end=linear_end, cosine_s=cosine_s)

        self.loss_type = loss_type

        self.learn_logvar = learn_logvar
        self.logvar = torch.full(fill_value=logvar_init, size=(self.num_timesteps,))
        if self.learn_logvar:
            self.logvar = nn.Parameter(self.logvar, requires_grad=True)

        self.ucg_training = ucg_training or dict()
        if self.ucg_training:
            self.ucg_prng = np.random.RandomState()

    def register_schedule(self, given_betas=None, beta_schedule="linear", timesteps=1000,
                          linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
        if exists(given_betas):
            betas = given_betas
        else:
            betas = make_beta_schedule(beta_schedule, timesteps, linear_start=linear_start, linear_end=linear_end,
                                       cosine_s=cosine_s)
        alphas = 1. - betas
        alphas_cumprod = np.cumprod(alphas, axis=0)
        alphas_cumprod_prev = np.append(1., alphas_cumprod[:-1])

        timesteps, = betas.shape
        self.num_timesteps = int(timesteps)
        self.linear_start = linear_start
        self.linear_end = linear_end
        assert alphas_cumprod.shape[0] == self.num_timesteps, 'alphas have to be defined for each timestep'

        to_torch = partial(torch.tensor, dtype=torch.float32)

        self.register_buffer('betas', to_torch(betas))
        self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
        self.register_buffer('alphas_cumprod_prev', to_torch(alphas_cumprod_prev))

        # calculations for diffusion q(x_t | x_{t-1}) and others
        self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod)))
        self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod)))
        self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod)))
        self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod)))
        self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod - 1)))

        # calculations for posterior q(x_{t-1} | x_t, x_0)
        posterior_variance = (1 - self.v_posterior) * betas * (1. - alphas_cumprod_prev) / (
                1. - alphas_cumprod) + self.v_posterior * betas
        # above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t)
        self.register_buffer('posterior_variance', to_torch(posterior_variance))
        # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
        self.register_buffer('posterior_log_variance_clipped', to_torch(np.log(np.maximum(posterior_variance, 1e-20))))
        self.register_buffer('posterior_mean_coef1', to_torch(
            betas * np.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod)))
        self.register_buffer('posterior_mean_coef2', to_torch(
            (1. - alphas_cumprod_prev) * np.sqrt(alphas) / (1. - alphas_cumprod)))

        if self.parameterization == "eps":
            lvlb_weights = self.betas ** 2 / (
                    2 * self.posterior_variance * to_torch(alphas) * (1 - self.alphas_cumprod))
        elif self.parameterization == "x0":
            lvlb_weights = 0.5 * np.sqrt(torch.Tensor(alphas_cumprod)) / (2. * 1 - torch.Tensor(alphas_cumprod))
        elif self.parameterization == "v":
            lvlb_weights = torch.ones_like(self.betas ** 2 / (
                    2 * self.posterior_variance * to_torch(alphas) * (1 - self.alphas_cumprod)))
        else:
            raise NotImplementedError("mu not supported")
        lvlb_weights[0] = lvlb_weights[1]
        self.register_buffer('lvlb_weights', lvlb_weights, persistent=False)
        assert not torch.isnan(self.lvlb_weights).all()
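        # For reference: with the "eps" parameterization these weights follow Ho et al. (2020),
        #   lvlb_weights[t] = beta_t^2 / (2 * sigma_t^2 * alpha_t * (1 - alphabar_t)),
        # i.e. the factor that turns the per-step eps-MSE into the true ELBO term.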

    @contextmanager
    def ema_scope(self, context=None):
        if self.use_ema:
            self.model_ema.store(self.model.parameters())
            self.model_ema.copy_to(self.model)
            if context is not None:
                print(f"{context}: Switched to EMA weights")
        try:
            yield None
        finally:
            if self.use_ema:
                self.model_ema.restore(self.model.parameters())
                if context is not None:
                    print(f"{context}: Restored training weights")

    @torch.no_grad()
    def init_from_ckpt(self, path, ignore_keys=list(), only_model=False):
        sd = torch.load(path, map_location="cpu")
        if "state_dict" in list(sd.keys()):
            sd = sd["state_dict"]
        keys = list(sd.keys())
        for k in keys:
            for ik in ignore_keys:
                if k.startswith(ik):
                    print("Deleting key {} from state_dict.".format(k))
                    del sd[k]
        if self.make_it_fit:
            n_params = len([name for name, _ in
                            itertools.chain(self.named_parameters(),
                                            self.named_buffers())])
            for name, param in tqdm(
                    itertools.chain(self.named_parameters(),
                                    self.named_buffers()),
                    desc="Fitting old weights to new weights",
                    total=n_params
            ):
                if name not in sd:
                    continue
                old_shape = sd[name].shape
                new_shape = param.shape
                assert len(old_shape) == len(new_shape)
                if len(new_shape) > 2:
                    # we only modify first two axes
                    assert new_shape[2:] == old_shape[2:]
                # assumes first axis corresponds to output dim
                if not new_shape == old_shape:
                    new_param = param.clone()
                    old_param = sd[name]
                    if len(new_shape) == 1:
                        for i in range(new_param.shape[0]):
                            new_param[i] = old_param[i % old_shape[0]]
                    elif len(new_shape) >= 2:
                        for i in range(new_param.shape[0]):
                            for j in range(new_param.shape[1]):
                                new_param[i, j] = old_param[i % old_shape[0], j % old_shape[1]]

                        n_used_old = torch.ones(old_shape[1])
                        for j in range(new_param.shape[1]):
                            n_used_old[j % old_shape[1]] += 1
                        n_used_new = torch.zeros(new_shape[1])
                        for j in range(new_param.shape[1]):
                            n_used_new[j] = n_used_old[j % old_shape[1]]

                        n_used_new = n_used_new[None, :]
                        while len(n_used_new.shape) < len(new_shape):
                            n_used_new = n_used_new.unsqueeze(-1)
                        new_param /= n_used_new

                    sd[name] = new_param

        missing, unexpected = self.load_state_dict(sd, strict=False) if not only_model else self.model.load_state_dict(
            sd, strict=False)
        print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys")
        if len(missing) > 0:
            print(f"Missing Keys:\n {missing}")
        if len(unexpected) > 0:
            print(f"\nUnexpected Keys:\n {unexpected}")

    def q_mean_variance(self, x_start, t):
        """
        Get the distribution q(x_t | x_0).
        :param x_start: the [N x C x ...] tensor of noiseless inputs.
        :param t: the number of diffusion steps (minus 1). Here, 0 means one step.
        :return: A tuple (mean, variance, log_variance), all of x_start's shape.
        """
        mean = (extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start)
        variance = extract_into_tensor(1.0 - self.alphas_cumprod, t, x_start.shape)
        log_variance = extract_into_tensor(self.log_one_minus_alphas_cumprod, t, x_start.shape)
        return mean, variance, log_variance

    def predict_start_from_noise(self, x_t, t, noise):
        return (
                extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t -
                extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * noise
        )

    def predict_start_from_z_and_v(self, x_t, t, v):
        # self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod)))
        # self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod)))
        return (
                extract_into_tensor(self.sqrt_alphas_cumprod, t, x_t.shape) * x_t -
                extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_t.shape) * v
        )

    def predict_eps_from_z_and_v(self, x_t, t, v):
        return (
                extract_into_tensor(self.sqrt_alphas_cumprod, t, x_t.shape) * v +
                extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_t.shape) * x_t
        )
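    # Inverting the v-parameterization at (x_t, t):
    #   x0  = sqrt(alphabar_t) * x_t - sqrt(1 - alphabar_t) * v
    #   eps = sqrt(alphabar_t) * v  + sqrt(1 - alphabar_t) * x_t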

    def q_posterior(self, x_start, x_t, t):
        posterior_mean = (
                extract_into_tensor(self.posterior_mean_coef1, t, x_t.shape) * x_start +
                extract_into_tensor(self.posterior_mean_coef2, t, x_t.shape) * x_t
        )
        posterior_variance = extract_into_tensor(self.posterior_variance, t, x_t.shape)
        posterior_log_variance_clipped = extract_into_tensor(self.posterior_log_variance_clipped, t, x_t.shape)
        return posterior_mean, posterior_variance, posterior_log_variance_clipped

    def p_mean_variance(self, x, t, clip_denoised: bool):
        model_out = self.model(x, t)
        if self.parameterization == "eps":
            x_recon = self.predict_start_from_noise(x, t=t, noise=model_out)
        elif self.parameterization == "x0":
            x_recon = model_out
        if clip_denoised:
            x_recon.clamp_(-1., 1.)

        model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t)
        return model_mean, posterior_variance, posterior_log_variance

    @torch.no_grad()
    def p_sample(self, x, t, clip_denoised=True, repeat_noise=False):
        b, *_, device = *x.shape, x.device
        model_mean, _, model_log_variance = self.p_mean_variance(x=x, t=t, clip_denoised=clip_denoised)
        noise = noise_like(x.shape, device, repeat_noise)
        # no noise when t == 0
        nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))
        return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise

    @torch.no_grad()
    def p_sample_loop(self, shape, return_intermediates=False):
        device = self.betas.device
        b = shape[0]
        img = torch.randn(shape, device=device)
        intermediates = [img]
        for i in tqdm(reversed(range(0, self.num_timesteps)), desc='Sampling t', total=self.num_timesteps):
            img = self.p_sample(img, torch.full((b,), i, device=device, dtype=torch.long),
                                clip_denoised=self.clip_denoised)
            if i % self.log_every_t == 0 or i == self.num_timesteps - 1:
                intermediates.append(img)
        if return_intermediates:
            return img, intermediates
        return img

    @torch.no_grad()
    def sample(self, batch_size=16, return_intermediates=False):
        image_size = self.image_size
        channels = self.channels
        return self.p_sample_loop((batch_size, channels, image_size, image_size),
                                  return_intermediates=return_intermediates)

    def q_sample(self, x_start, t, noise=None):
        noise = default(noise, lambda: torch.randn_like(x_start))
        return (extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start +
                extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise)
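    # q_sample draws from the closed-form marginal q(x_t | x_0):
    #   x_t = sqrt(alphabar_t) * x_0 + sqrt(1 - alphabar_t) * eps,  eps ~ N(0, I).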

    def get_v(self, x, noise, t):
        return (
                extract_into_tensor(self.sqrt_alphas_cumprod, t, x.shape) * noise -
                extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x.shape) * x
        )
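    # "v" regression target from progressive distillation (Salimans & Ho, 2022):
    #   v = sqrt(alphabar_t) * eps - sqrt(1 - alphabar_t) * x_0.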

    def get_loss(self, pred, target, mean=True):
        if self.loss_type == 'l1':
            loss = (target - pred).abs()
            if mean:
                loss = loss.mean()
        elif self.loss_type == 'l2':
            if mean:
                loss = torch.nn.functional.mse_loss(target, pred)
            else:
                loss = torch.nn.functional.mse_loss(target, pred, reduction='none')
        else:
            raise NotImplementedError(f"unknown loss type '{self.loss_type}'")

        return loss

    def p_losses(self, x_start, t, noise=None):
        noise = default(noise, lambda: torch.randn_like(x_start))
        x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise)
        model_out = self.model(x_noisy, t)

        loss_dict = {}
        if self.parameterization == "eps":
            target = noise
        elif self.parameterization == "x0":
            target = x_start
        elif self.parameterization == "v":
            target = self.get_v(x_start, noise, t)
        else:
            raise NotImplementedError(f"Parameterization {self.parameterization} not yet supported")

        loss = self.get_loss(model_out, target, mean=False).mean(dim=[1, 2, 3])

        log_prefix = 'train' if self.training else 'val'

        loss_dict.update({f'{log_prefix}/loss_simple': loss.mean()})
        loss_simple = loss.mean() * self.l_simple_weight

        loss_vlb = (self.lvlb_weights[t] * loss).mean()
        loss_dict.update({f'{log_prefix}/loss_vlb': loss_vlb})

        loss = loss_simple + self.original_elbo_weight * loss_vlb

        loss_dict.update({f'{log_prefix}/loss': loss})

        return loss, loss_dict

    def forward(self, x, *args, **kwargs):
        # b, c, h, w, device, img_size, = *x.shape, x.device, self.image_size
        # assert h == img_size and w == img_size, f'height and width of image must be {img_size}'
        t = torch.randint(0, self.num_timesteps, (x.shape[0],), device=self.device).long()
        return self.p_losses(x, t, *args, **kwargs)

    def get_input(self, batch, k):
        x = batch[k]
        if len(x.shape) == 3:
            x = x[..., None]
        x = rearrange(x, 'b h w c -> b c h w')
        x = x.to(memory_format=torch.contiguous_format).float()
        return x

    def shared_step(self, batch):
        x = self.get_input(batch, self.first_stage_key)
        loss, loss_dict = self(x)
        return loss, loss_dict

    def training_step(self, batch, batch_idx):
        for k in self.ucg_training:
            p = self.ucg_training[k]["p"]
            val = self.ucg_training[k]["val"]
            if val is None:
                val = ""
            for i in range(len(batch[k])):
                if self.ucg_prng.choice(2, p=[1 - p, p]):
                    batch[k][i] = val

        loss, loss_dict = self.shared_step(batch)

        self.log_dict(loss_dict, prog_bar=True,
                      logger=True, on_step=True, on_epoch=True)

        self.log("global_step", self.global_step,
                 prog_bar=True, logger=True, on_step=True, on_epoch=False)

        if self.use_scheduler:
            lr = self.optimizers().param_groups[0]['lr']
            self.log('lr_abs', lr, prog_bar=True, logger=True, on_step=True, on_epoch=False)

        return loss

    @torch.no_grad()
    def validation_step(self, batch, batch_idx):
        _, loss_dict_no_ema = self.shared_step(batch)
        with self.ema_scope():
            _, loss_dict_ema = self.shared_step(batch)
            loss_dict_ema = {key + '_ema': loss_dict_ema[key] for key in loss_dict_ema}
        self.log_dict(loss_dict_no_ema, prog_bar=False, logger=True, on_step=False, on_epoch=True)
        self.log_dict(loss_dict_ema, prog_bar=False, logger=True, on_step=False, on_epoch=True)

    def on_train_batch_end(self, *args, **kwargs):
        if self.use_ema:
            self.model_ema(self.model)

    def _get_rows_from_list(self, samples):
        n_imgs_per_row = len(samples)
        denoise_grid = rearrange(samples, 'n b c h w -> b n c h w')
        denoise_grid = rearrange(denoise_grid, 'b n c h w -> (b n) c h w')
        denoise_grid = make_grid(denoise_grid, nrow=n_imgs_per_row)
        return denoise_grid

    @torch.no_grad()
    def log_images(self, batch, N=8, n_row=2, sample=True, return_keys=None, **kwargs):
        log = dict()
        x = self.get_input(batch, self.first_stage_key)
        N = min(x.shape[0], N)
        n_row = min(x.shape[0], n_row)
        x = x.to(self.device)[:N]
        log["inputs"] = x

        # get diffusion row
        diffusion_row = list()
        x_start = x[:n_row]

        for t in range(self.num_timesteps):
            if t % self.log_every_t == 0 or t == self.num_timesteps - 1:
                t = repeat(torch.tensor([t]), '1 -> b', b=n_row)
                t = t.to(self.device).long()
                noise = torch.randn_like(x_start)
                x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise)
                diffusion_row.append(x_noisy)

        log["diffusion_row"] = self._get_rows_from_list(diffusion_row)

        if sample:
            # get denoise row
            with self.ema_scope("Plotting"):
                samples, denoise_row = self.sample(batch_size=N, return_intermediates=True)

            log["samples"] = samples
            log["denoise_row"] = self._get_rows_from_list(denoise_row)

        if return_keys:
            if np.intersect1d(list(log.keys()), return_keys).shape[0] == 0:
                return log
            else:
                return {key: log[key] for key in return_keys}
        return log

    def configure_optimizers(self):
        lr = self.learning_rate
        params = list(self.model.parameters())
        if self.learn_logvar:
            params = params + [self.logvar]
        opt = torch.optim.AdamW(params, lr=lr)
        return opt
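# Illustrative usage sketch (editor's note, not part of the original file; assumes
# `unet_config` points at a UNet config as in the LDM repo):
#   model = DDPM(unet_config, timesteps=1000, beta_schedule="linear")
#   loss, loss_dict = model(torch.randn(4, 3, 256, 256))  # training objective on a batch
#   samples = model.sample(batch_size=4)                   # full ancestral sampling loop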


class LatentDiffusion(DDPM):
    """main class"""

    def __init__(self,
                 first_stage_config,
                 cond_stage_config,
                 num_timesteps_cond=None,
                 cond_stage_key="image",
                 cond_stage_trainable=False,
                 concat_mode=True,
                 cond_stage_forward=None,
                 conditioning_key=None,
                 scale_factor=1.0,
                 scale_by_std=False,
                 force_null_conditioning=False,
                 *args, **kwargs):
        self.force_null_conditioning = force_null_conditioning
        self.num_timesteps_cond = default(num_timesteps_cond, 1)
        self.scale_by_std = scale_by_std
        assert self.num_timesteps_cond <= kwargs['timesteps']
        # for backwards compatibility after implementation of DiffusionWrapper
        if conditioning_key is None:
            conditioning_key = 'concat' if concat_mode else 'crossattn'
        if cond_stage_config == '__is_unconditional__' and not self.force_null_conditioning:
            conditioning_key = None
        ckpt_path = kwargs.pop("ckpt_path", None)
        reset_ema = kwargs.pop("reset_ema", False)
        reset_num_ema_updates = kwargs.pop("reset_num_ema_updates", False)
        ignore_keys = kwargs.pop("ignore_keys", [])
        super().__init__(conditioning_key=conditioning_key, *args, **kwargs)
        self.concat_mode = concat_mode
        self.cond_stage_trainable = cond_stage_trainable
        self.cond_stage_key = cond_stage_key
        try:
            self.num_downs = len(first_stage_config.params.ddconfig.ch_mult) - 1
        except Exception:
            self.num_downs = 0
        if not scale_by_std:
            self.scale_factor = scale_factor
        else:
            self.register_buffer('scale_factor', torch.tensor(scale_factor))
        self.instantiate_first_stage(first_stage_config)
        self.instantiate_cond_stage(cond_stage_config)
        self.cond_stage_forward = cond_stage_forward
        self.clip_denoised = False
        self.bbox_tokenizer = None

        self.restarted_from_ckpt = False
        if ckpt_path is not None:
            self.init_from_ckpt(ckpt_path, ignore_keys)
            self.restarted_from_ckpt = True
            if reset_ema:
                assert self.use_ema
                print(
                    f"Resetting ema to pure model weights. This is useful when restoring from an ema-only checkpoint.")
                self.model_ema = LitEma(self.model)
        if reset_num_ema_updates:
            print(" +++++++++++ WARNING: RESETTING NUM_EMA UPDATES TO ZERO +++++++++++ ")
            assert self.use_ema
            self.model_ema.reset_num_updates()

    def make_cond_schedule(self, ):
        self.cond_ids = torch.full(size=(self.num_timesteps,), fill_value=self.num_timesteps - 1, dtype=torch.long)
        ids = torch.round(torch.linspace(0, self.num_timesteps - 1, self.num_timesteps_cond)).long()
        self.cond_ids[:self.num_timesteps_cond] = ids

    @rank_zero_only
    @torch.no_grad()
    def on_train_batch_start(self, batch, batch_idx, dataloader_idx):
        # only for very first batch
        if self.scale_by_std and self.current_epoch == 0 and self.global_step == 0 and batch_idx == 0 and not self.restarted_from_ckpt:
            assert self.scale_factor == 1., 'rather not use custom rescaling and std-rescaling simultaneously'
            # set rescale weight to 1./std of encodings
            print("### USING STD-RESCALING ###")
            x = super().get_input(batch, self.first_stage_key)
            x = x.to(self.device)
            encoder_posterior = self.encode_first_stage(x)
            z = self.get_first_stage_encoding(encoder_posterior).detach()
            del self.scale_factor
            self.register_buffer('scale_factor', 1. / z.flatten().std())
            print(f"setting self.scale_factor to {self.scale_factor}")
            print("### USING STD-RESCALING ###")

    def register_schedule(self,
                          given_betas=None, beta_schedule="linear", timesteps=1000,
                          linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
        super().register_schedule(given_betas, beta_schedule, timesteps, linear_start, linear_end, cosine_s)

        self.shorten_cond_schedule = self.num_timesteps_cond > 1
        if self.shorten_cond_schedule:
            self.make_cond_schedule()

    def instantiate_first_stage(self, config):
        model = instantiate_from_config(config)
        self.first_stage_model = model.eval()
        self.first_stage_model.train = disabled_train
        for param in self.first_stage_model.parameters():
            param.requires_grad = False

    def instantiate_cond_stage(self, config):
        if not self.cond_stage_trainable:
            if config == "__is_first_stage__":
                print("Using first stage also as cond stage.")
                self.cond_stage_model = self.first_stage_model
            elif config == "__is_unconditional__":
                print(f"Training {self.__class__.__name__} as an unconditional model.")
                self.cond_stage_model = None
                # self.be_unconditional = True
            else:
                model = instantiate_from_config(config)
                self.cond_stage_model = model.eval()
                self.cond_stage_model.train = disabled_train
                for param in self.cond_stage_model.parameters():
                    param.requires_grad = False
        else:
            assert config != '__is_first_stage__'
            assert config != '__is_unconditional__'
            model = instantiate_from_config(config)
            self.cond_stage_model = model

    def _get_denoise_row_from_list(self, samples, desc='', force_no_decoder_quantization=False):
        denoise_row = []
        for zd in tqdm(samples, desc=desc):
            denoise_row.append(self.decode_first_stage(zd.to(self.device),
                                                       force_not_quantize=force_no_decoder_quantization))
        n_imgs_per_row = len(denoise_row)
        denoise_row = torch.stack(denoise_row)  # n_log_step, n_row, C, H, W
        denoise_grid = rearrange(denoise_row, 'n b c h w -> b n c h w')
        denoise_grid = rearrange(denoise_grid, 'b n c h w -> (b n) c h w')
        denoise_grid = make_grid(denoise_grid, nrow=n_imgs_per_row)
        return denoise_grid

    def get_first_stage_encoding(self, encoder_posterior):
        if isinstance(encoder_posterior, DiagonalGaussianDistribution):
            z = encoder_posterior.sample()
        elif isinstance(encoder_posterior, torch.Tensor):
            z = encoder_posterior
        else:
            raise NotImplementedError(f"encoder_posterior of type '{type(encoder_posterior)}' not yet implemented")
        return self.scale_factor * z
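    # The latent is rescaled by `scale_factor` so its per-dimension std is roughly 1;
    # with scale_by_std=True the factor is set to 1/std(z) on the very first training
    # batch (see on_train_batch_start above).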

    def get_learned_conditioning(self, c):
        if self.cond_stage_forward is None:
            if hasattr(self.cond_stage_model, 'encode') and callable(self.cond_stage_model.encode):
                c = self.cond_stage_model.encode(c)
                if isinstance(c, DiagonalGaussianDistribution):
                    c = c.mode()
            else:
                c = self.cond_stage_model(c)
        else:
            assert hasattr(self.cond_stage_model, self.cond_stage_forward)
            c = getattr(self.cond_stage_model, self.cond_stage_forward)(c)
        return c

    def meshgrid(self, h, w):
        y = torch.arange(0, h).view(h, 1, 1).repeat(1, w, 1)
        x = torch.arange(0, w).view(1, w, 1).repeat(h, 1, 1)

        arr = torch.cat([y, x], dim=-1)
        return arr

    def delta_border(self, h, w):
        """
        :param h: height
        :param w: width
        :return: normalized distance to image border,
         with min distance = 0 at border and max dist = 0.5 at image center
        """
        lower_right_corner = torch.tensor([h - 1, w - 1]).view(1, 1, 2)
        arr = self.meshgrid(h, w) / lower_right_corner
        dist_left_up = torch.min(arr, dim=-1, keepdims=True)[0]
        dist_right_down = torch.min(1 - arr, dim=-1, keepdims=True)[0]
        edge_dist = torch.min(torch.cat([dist_left_up, dist_right_down], dim=-1), dim=-1)[0]
        return edge_dist

    def get_weighting(self, h, w, Ly, Lx, device):
        weighting = self.delta_border(h, w)
        weighting = torch.clip(weighting, self.split_input_params["clip_min_weight"],
                               self.split_input_params["clip_max_weight"], )
        weighting = weighting.view(1, h * w, 1).repeat(1, 1, Ly * Lx).to(device)

        if self.split_input_params["tie_braker"]:
            L_weighting = self.delta_border(Ly, Lx)
            L_weighting = torch.clip(L_weighting,
                                     self.split_input_params["clip_min_tie_weight"],
                                     self.split_input_params["clip_max_tie_weight"])

            L_weighting = L_weighting.view(1, 1, Ly * Lx).to(device)
            weighting = weighting * L_weighting
        return weighting

    def get_fold_unfold(self, x, kernel_size, stride, uf=1, df=1):  # todo load once not every time, shorten code
        """
        :param x: img of size (bs, c, h, w)
        :return: n img crops of size (n, bs, c, kernel_size[0], kernel_size[1])
        """
        bs, nc, h, w = x.shape

        # number of crops in image
        Ly = (h - kernel_size[0]) // stride[0] + 1
        Lx = (w - kernel_size[1]) // stride[1] + 1

        if uf == 1 and df == 1:
            fold_params = dict(kernel_size=kernel_size, dilation=1, padding=0, stride=stride)
            unfold = torch.nn.Unfold(**fold_params)

            fold = torch.nn.Fold(output_size=x.shape[2:], **fold_params)

            weighting = self.get_weighting(kernel_size[0], kernel_size[1], Ly, Lx, x.device).to(x.dtype)
            normalization = fold(weighting).view(1, 1, h, w)  # normalizes the overlap
            weighting = weighting.view((1, 1, kernel_size[0], kernel_size[1], Ly * Lx))

        elif uf > 1 and df == 1:
            fold_params = dict(kernel_size=kernel_size, dilation=1, padding=0, stride=stride)
            unfold = torch.nn.Unfold(**fold_params)

            fold_params2 = dict(kernel_size=(kernel_size[0] * uf, kernel_size[0] * uf),
                                dilation=1, padding=0,
                                stride=(stride[0] * uf, stride[1] * uf))
            fold = torch.nn.Fold(output_size=(x.shape[2] * uf, x.shape[3] * uf), **fold_params2)

            weighting = self.get_weighting(kernel_size[0] * uf, kernel_size[1] * uf, Ly, Lx, x.device).to(x.dtype)
            normalization = fold(weighting).view(1, 1, h * uf, w * uf)  # normalizes the overlap
            weighting = weighting.view((1, 1, kernel_size[0] * uf, kernel_size[1] * uf, Ly * Lx))

        elif df > 1 and uf == 1:
            fold_params = dict(kernel_size=kernel_size, dilation=1, padding=0, stride=stride)
            unfold = torch.nn.Unfold(**fold_params)

            fold_params2 = dict(kernel_size=(kernel_size[0] // df, kernel_size[0] // df),
                                dilation=1, padding=0,
                                stride=(stride[0] // df, stride[1] // df))
            fold = torch.nn.Fold(output_size=(x.shape[2] // df, x.shape[3] // df), **fold_params2)

            weighting = self.get_weighting(kernel_size[0] // df, kernel_size[1] // df, Ly, Lx, x.device).to(x.dtype)
            normalization = fold(weighting).view(1, 1, h // df, w // df)  # normalizes the overlap
            weighting = weighting.view((1, 1, kernel_size[0] // df, kernel_size[1] // df, Ly * Lx))

        else:
            raise NotImplementedError

        return fold, unfold, normalization, weighting
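    # fold/unfold enable patched (sliding-window) application of the first stage:
    # `unfold` cuts overlapping crops, `fold` re-assembles them, and the
    # `weighting`/`normalization` pair blends the overlap regions smoothly
    # (uf/df account for up-/down-sampled output resolution).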

    @torch.no_grad()
    def get_input(self, batch, k, return_first_stage_outputs=False, force_c_encode=False,
                  cond_key=None, return_original_cond=False, bs=None, return_x=False):
        x = super().get_input(batch, k)
        if bs is not None:
            x = x[:bs]
        x = x.to(self.device)
        encoder_posterior = self.encode_first_stage(x)
        z = self.get_first_stage_encoding(encoder_posterior).detach()

        if self.model.conditioning_key is not None and not self.force_null_conditioning:
            if cond_key is None:
                cond_key = self.cond_stage_key
            if cond_key != self.first_stage_key:
                if cond_key in ['caption', 'coordinates_bbox', "txt"]:
                    xc = batch[cond_key]
                elif cond_key in ['class_label', 'cls']:
                    xc = batch
                else:
                    xc = super().get_input(batch, cond_key).to(self.device)
            else:
                xc = x
            if not self.cond_stage_trainable or force_c_encode:
                if isinstance(xc, dict) or isinstance(xc, list):
                    c = self.get_learned_conditioning(xc)
                else:
                    c = self.get_learned_conditioning(xc.to(self.device))
            else:
                c = xc
            if bs is not None:
                c = c[:bs]

            if self.use_positional_encodings:
                pos_x, pos_y = self.compute_latent_shifts(batch)
                ckey = __conditioning_keys__[self.model.conditioning_key]
                c = {ckey: c, 'pos_x': pos_x, 'pos_y': pos_y}

        else:
            c = None
            xc = None
            if self.use_positional_encodings:
                pos_x, pos_y = self.compute_latent_shifts(batch)
                c = {'pos_x': pos_x, 'pos_y': pos_y}
        out = [z, c]
        if return_first_stage_outputs:
            xrec = self.decode_first_stage(z)
            out.extend([x, xrec])
        if return_x:
            out.extend([x])
        if return_original_cond:
            out.append(xc)
        return out

    @torch.no_grad()
    def decode_first_stage(self, z, predict_cids=False, force_not_quantize=False):
        if predict_cids:
            if z.dim() == 4:
                z = torch.argmax(z.exp(), dim=1).long()
            z = self.first_stage_model.quantize.get_codebook_entry(z, shape=None)
            z = rearrange(z, 'b h w c -> b c h w').contiguous()

        z = 1. / self.scale_factor * z
        return self.first_stage_model.decode(z)

    @torch.no_grad()
    def encode_first_stage(self, x):
        return self.first_stage_model.encode(x)

    def shared_step(self, batch, **kwargs):
        x, c = self.get_input(batch, self.first_stage_key)
        loss = self(x, c)
        return loss

    def forward(self, x, c, *args, **kwargs):
        t = torch.randint(0, self.num_timesteps, (x.shape[0],), device=self.device).long()
        if self.model.conditioning_key is not None:
            assert c is not None
            if self.cond_stage_trainable:
                c = self.get_learned_conditioning(c)
            if self.shorten_cond_schedule:  # TODO: drop this option
                tc = self.cond_ids[t].to(self.device)
                c = self.q_sample(x_start=c, t=tc, noise=torch.randn_like(c.float()))
        return self.p_losses(x, c, t, *args, **kwargs)

    def apply_model(self, x_noisy, t, cond, return_ids=False):
        if isinstance(cond, dict):
            # hybrid case, cond is expected to be a dict
            pass
        else:
            if not isinstance(cond, list):
                cond = [cond]
            key = 'c_concat' if self.model.conditioning_key == 'concat' else 'c_crossattn'
            cond = {key: cond}

        x_recon = self.model(x_noisy, t, **cond)

        if isinstance(x_recon, tuple) and not return_ids:
            return x_recon[0]
        else:
            return x_recon
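    # apply_model normalizes `cond` into the {'c_concat': [...]} / {'c_crossattn': [...]}
    # keyword format consumed by DiffusionWrapper.forward (defined below).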

    def _predict_eps_from_xstart(self, x_t, t, pred_xstart):
        return (extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - pred_xstart) / \
               extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape)

    def _prior_bpd(self, x_start):
        """
        Get the prior KL term for the variational lower-bound, measured in
        bits-per-dim.
        This term can't be optimized, as it only depends on the encoder.
        :param x_start: the [N x C x ...] tensor of inputs.
        :return: a batch of [N] KL values (in bits), one per batch element.
        """
        batch_size = x_start.shape[0]
        t = torch.tensor([self.num_timesteps - 1] * batch_size, device=x_start.device)
        qt_mean, _, qt_log_variance = self.q_mean_variance(x_start, t)
        kl_prior = normal_kl(mean1=qt_mean, logvar1=qt_log_variance, mean2=0.0, logvar2=0.0)
        return mean_flat(kl_prior) / np.log(2.0)
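    # Dividing by ln(2) converts the KL from nats to bits, hence "bits per dim".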

    def p_losses(self, x_start, cond, t, noise=None):
        noise = default(noise, lambda: torch.randn_like(x_start))
        x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise)
        model_output = self.apply_model(x_noisy, t, cond)

        loss_dict = {}
        prefix = 'train' if self.training else 'val'

        if self.parameterization == "x0":
            target = x_start
        elif self.parameterization == "eps":
            target = noise
        elif self.parameterization == "v":
            target = self.get_v(x_start, noise, t)
        else:
            raise NotImplementedError()

        loss_simple = self.get_loss(model_output, target, mean=False).mean([1, 2, 3])
        loss_dict.update({f'{prefix}/loss_simple': loss_simple.mean()})

        logvar_t = self.logvar[t].to(self.device)
        loss = loss_simple / torch.exp(logvar_t) + logvar_t
        # loss = loss_simple / torch.exp(self.logvar) + self.logvar
        if self.learn_logvar:
            loss_dict.update({f'{prefix}/loss_gamma': loss.mean()})
            loss_dict.update({'logvar': self.logvar.data.mean()})

        loss = self.l_simple_weight * loss.mean()

        loss_vlb = self.get_loss(model_output, target, mean=False).mean(dim=(1, 2, 3))
        loss_vlb = (self.lvlb_weights[t] * loss_vlb).mean()
        loss_dict.update({f'{prefix}/loss_vlb': loss_vlb})
        loss += (self.original_elbo_weight * loss_vlb)
        loss_dict.update({f'{prefix}/loss': loss})

        return loss, loss_dict
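    # Total objective:
    #   l_simple_weight * mean(loss_simple / exp(logvar_t) + logvar_t)
    #   + original_elbo_weight * mean(lvlb_weights[t] * loss_vlb)
    # With the default original_elbo_weight = 0, only the (re-weighted) simple loss trains.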

    def p_mean_variance(self, x, c, t, clip_denoised: bool, return_codebook_ids=False, quantize_denoised=False,
                        return_x0=False, score_corrector=None, corrector_kwargs=None):
        t_in = t
        model_out = self.apply_model(x, t_in, c, return_ids=return_codebook_ids)

        if score_corrector is not None:
            assert self.parameterization == "eps"
            model_out = score_corrector.modify_score(self, model_out, x, t, c, **corrector_kwargs)

        if return_codebook_ids:
            model_out, logits = model_out

        if self.parameterization == "eps":
            x_recon = self.predict_start_from_noise(x, t=t, noise=model_out)
        elif self.parameterization == "x0":
            x_recon = model_out
        else:
            raise NotImplementedError()

        if clip_denoised:
            x_recon.clamp_(-1., 1.)
        if quantize_denoised:
            x_recon, _, [_, _, indices] = self.first_stage_model.quantize(x_recon)
        model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t)
        if return_codebook_ids:
            return model_mean, posterior_variance, posterior_log_variance, logits
        elif return_x0:
            return model_mean, posterior_variance, posterior_log_variance, x_recon
        else:
            return model_mean, posterior_variance, posterior_log_variance

    @torch.no_grad()
    def p_sample(self, x, c, t, clip_denoised=False, repeat_noise=False,
                 return_codebook_ids=False, quantize_denoised=False, return_x0=False,
                 temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None):
        b, *_, device = *x.shape, x.device
        outputs = self.p_mean_variance(x=x, c=c, t=t, clip_denoised=clip_denoised,
                                       return_codebook_ids=return_codebook_ids,
                                       quantize_denoised=quantize_denoised,
                                       return_x0=return_x0,
                                       score_corrector=score_corrector, corrector_kwargs=corrector_kwargs)
        if return_codebook_ids:
            raise DeprecationWarning("Support dropped.")
            model_mean, _, model_log_variance, logits = outputs
        elif return_x0:
            model_mean, _, model_log_variance, x0 = outputs
        else:
            model_mean, _, model_log_variance = outputs

        noise = noise_like(x.shape, device, repeat_noise) * temperature
        if noise_dropout > 0.:
            noise = torch.nn.functional.dropout(noise, p=noise_dropout)
        # no noise when t == 0
        nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))

        if return_codebook_ids:
            return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise, logits.argmax(dim=1)
        if return_x0:
            return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise, x0
        else:
            return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise

    @torch.no_grad()
    def progressive_denoising(self, cond, shape, verbose=True, callback=None, quantize_denoised=False,
                              img_callback=None, mask=None, x0=None, temperature=1., noise_dropout=0.,
                              score_corrector=None, corrector_kwargs=None, batch_size=None, x_T=None, start_T=None,
                              log_every_t=None):
        if not log_every_t:
            log_every_t = self.log_every_t
        timesteps = self.num_timesteps
        if batch_size is not None:
            b = batch_size if batch_size is not None else shape[0]
            shape = [batch_size] + list(shape)
        else:
            b = batch_size = shape[0]
        if x_T is None:
            img = torch.randn(shape, device=self.device)
        else:
            img = x_T
        intermediates = []
        if cond is not None:
            if isinstance(cond, dict):
                cond = {key: cond[key][:batch_size] if not isinstance(cond[key], list) else
                list(map(lambda x: x[:batch_size], cond[key])) for key in cond}
            else:
                cond = [c[:batch_size] for c in cond] if isinstance(cond, list) else cond[:batch_size]

        if start_T is not None:
            timesteps = min(timesteps, start_T)
        iterator = tqdm(reversed(range(0, timesteps)), desc='Progressive Generation',
                        total=timesteps) if verbose else reversed(
            range(0, timesteps))
        if type(temperature) == float:
            temperature = [temperature] * timesteps

        for i in iterator:
            ts = torch.full((b,), i, device=self.device, dtype=torch.long)
            if self.shorten_cond_schedule:
                assert self.model.conditioning_key != 'hybrid'
                tc = self.cond_ids[ts].to(cond.device)
                cond = self.q_sample(x_start=cond, t=tc, noise=torch.randn_like(cond))

            img, x0_partial = self.p_sample(img, cond, ts,
                                            clip_denoised=self.clip_denoised,
                                            quantize_denoised=quantize_denoised, return_x0=True,
                                            temperature=temperature[i], noise_dropout=noise_dropout,
                                            score_corrector=score_corrector, corrector_kwargs=corrector_kwargs)
            if mask is not None:
                assert x0 is not None
                img_orig = self.q_sample(x0, ts)
                img = img_orig * mask + (1. - mask) * img

            if i % log_every_t == 0 or i == timesteps - 1:
                intermediates.append(x0_partial)
            if callback: callback(i)
            if img_callback: img_callback(img, i)
        return img, intermediates

    @torch.no_grad()
    def p_sample_loop(self, cond, shape, return_intermediates=False,
                      x_T=None, verbose=True, callback=None, timesteps=None, quantize_denoised=False,
                      mask=None, x0=None, img_callback=None, start_T=None,
                      log_every_t=None):

        if not log_every_t:
            log_every_t = self.log_every_t
        device = self.betas.device
        b = shape[0]
        if x_T is None:
            img = torch.randn(shape, device=device)
        else:
            img = x_T

        intermediates = [img]
        if timesteps is None:
            timesteps = self.num_timesteps

        if start_T is not None:
            timesteps = min(timesteps, start_T)
        iterator = tqdm(reversed(range(0, timesteps)), desc='Sampling t', total=timesteps) if verbose else reversed(
            range(0, timesteps))

        if mask is not None:
            assert x0 is not None
            assert x0.shape[2:3] == mask.shape[2:3]  # spatial size has to match

        for i in iterator:
            ts = torch.full((b,), i, device=device, dtype=torch.long)
            if self.shorten_cond_schedule:
                assert self.model.conditioning_key != 'hybrid'
                tc = self.cond_ids[ts].to(cond.device)
                cond = self.q_sample(x_start=cond, t=tc, noise=torch.randn_like(cond))

            img = self.p_sample(img, cond, ts,
                                clip_denoised=self.clip_denoised,
                                quantize_denoised=quantize_denoised)
            if mask is not None:
                img_orig = self.q_sample(x0, ts)
                img = img_orig * mask + (1. - mask) * img

            if i % log_every_t == 0 or i == timesteps - 1:
                intermediates.append(img)
            if callback: callback(i)
            if img_callback: img_callback(img, i)

        if return_intermediates:
            return img, intermediates
        return img

    @torch.no_grad()
    def sample(self, cond, batch_size=16, return_intermediates=False, x_T=None,
               verbose=True, timesteps=None, quantize_denoised=False,
               mask=None, x0=None, shape=None, **kwargs):
        if shape is None:
            shape = (batch_size, self.channels, self.image_size, self.image_size)
        if cond is not None:
            if isinstance(cond, dict):
                cond = {key: cond[key][:batch_size] if not isinstance(cond[key], list) else
                list(map(lambda x: x[:batch_size], cond[key])) for key in cond}
            else:
                cond = [c[:batch_size] for c in cond] if isinstance(cond, list) else cond[:batch_size]
        return self.p_sample_loop(cond,
                                  shape,
                                  return_intermediates=return_intermediates, x_T=x_T,
                                  verbose=verbose, timesteps=timesteps, quantize_denoised=quantize_denoised,
                                  mask=mask, x0=x0)

    @torch.no_grad()
    def sample_log(self, cond, batch_size, ddim, ddim_steps, **kwargs):
        if ddim:
            ddim_sampler = DDIMSampler(self)
            shape = (self.channels, self.image_size, self.image_size)
            samples, intermediates = ddim_sampler.sample(ddim_steps, batch_size,
                                                         shape, cond, verbose=False, **kwargs)

        else:
            samples, intermediates = self.sample(cond=cond, batch_size=batch_size,
                                                 return_intermediates=True, **kwargs)

        return samples, intermediates

    @torch.no_grad()
    def get_unconditional_conditioning(self, batch_size, null_label=None):
        if null_label is not None:
            xc = null_label
            if isinstance(xc, ListConfig):
                xc = list(xc)
            if isinstance(xc, dict) or isinstance(xc, list):
                c = self.get_learned_conditioning(xc)
            else:
                if hasattr(xc, "to"):
                    xc = xc.to(self.device)
                c = self.get_learned_conditioning(xc)
        else:
            if self.cond_stage_key in ["class_label", "cls"]:
                xc = self.cond_stage_model.get_unconditional_conditioning(batch_size, device=self.device)
                return self.get_learned_conditioning(xc)
            else:
                raise NotImplementedError("todo")
        if isinstance(c, list):  # in case the encoder gives us a list
            for i in range(len(c)):
                c[i] = repeat(c[i], '1 ... -> b ...', b=batch_size).to(self.device)
        else:
            c = repeat(c, '1 ... -> b ...', b=batch_size).to(self.device)
        return c

    @torch.no_grad()
    def log_images(self, batch, N=8, n_row=4, sample=True, ddim_steps=50, ddim_eta=0., return_keys=None,
                   quantize_denoised=True, inpaint=True, plot_denoise_rows=False, plot_progressive_rows=True,
                   plot_diffusion_rows=True, unconditional_guidance_scale=1., unconditional_guidance_label=None,
                   use_ema_scope=True,
                   **kwargs):
        ema_scope = self.ema_scope if use_ema_scope else nullcontext
        use_ddim = ddim_steps is not None

        log = dict()
        z, c, x, xrec, xc = self.get_input(batch, self.first_stage_key,
                                           return_first_stage_outputs=True,
                                           force_c_encode=True,
                                           return_original_cond=True,
                                           bs=N)
        N = min(x.shape[0], N)
        n_row = min(x.shape[0], n_row)
        log["inputs"] = x
        log["reconstruction"] = xrec
        if self.model.conditioning_key is not None:
            if hasattr(self.cond_stage_model, "decode"):
                xc = self.cond_stage_model.decode(c)
                log["conditioning"] = xc
            elif self.cond_stage_key in ["caption", "txt"]:
                xc = log_txt_as_img((x.shape[2], x.shape[3]), batch[self.cond_stage_key], size=x.shape[2] // 25)
                log["conditioning"] = xc
            elif self.cond_stage_key in ['class_label', "cls"]:
                try:
                    xc = log_txt_as_img((x.shape[2], x.shape[3]), batch["human_label"], size=x.shape[2] // 25)
                    log['conditioning'] = xc
                except KeyError:
                    # probably no "human_label" in batch
                    pass
            elif isimage(xc):
                log["conditioning"] = xc
            if ismap(xc):
                log["original_conditioning"] = self.to_rgb(xc)

        if plot_diffusion_rows:
            # get diffusion row
            diffusion_row = list()
            z_start = z[:n_row]
            for t in range(self.num_timesteps):
                if t % self.log_every_t == 0 or t == self.num_timesteps - 1:
                    t = repeat(torch.tensor([t]), '1 -> b', b=n_row)
                    t = t.to(self.device).long()
                    noise = torch.randn_like(z_start)
                    z_noisy = self.q_sample(x_start=z_start, t=t, noise=noise)
                    diffusion_row.append(self.decode_first_stage(z_noisy))

            diffusion_row = torch.stack(diffusion_row)  # n_log_step, n_row, C, H, W
            diffusion_grid = rearrange(diffusion_row, 'n b c h w -> b n c h w')
            diffusion_grid = rearrange(diffusion_grid, 'b n c h w -> (b n) c h w')
            diffusion_grid = make_grid(diffusion_grid, nrow=diffusion_row.shape[0])
            log["diffusion_row"] = diffusion_grid

        if sample:
            # get denoise row
            with ema_scope("Sampling"):
                samples, z_denoise_row = self.sample_log(cond=c, batch_size=N, ddim=use_ddim,
                                                         ddim_steps=ddim_steps, eta=ddim_eta)
                # samples, z_denoise_row = self.sample(cond=c, batch_size=N, return_intermediates=True)
            x_samples = self.decode_first_stage(samples)
            log["samples"] = x_samples
            if plot_denoise_rows:
                denoise_grid = self._get_denoise_row_from_list(z_denoise_row)
                log["denoise_row"] = denoise_grid

            if quantize_denoised and not isinstance(self.first_stage_model, AutoencoderKL) and not isinstance(
                    self.first_stage_model, IdentityFirstStage):
                # also display when quantizing x0 while sampling
                with ema_scope("Plotting Quantized Denoised"):
                    samples, z_denoise_row = self.sample_log(cond=c, batch_size=N, ddim=use_ddim,
                                                             ddim_steps=ddim_steps, eta=ddim_eta,
                                                             quantize_denoised=True)
                    # samples, z_denoise_row = self.sample(cond=c, batch_size=N, return_intermediates=True,
                    #                                      quantize_denoised=True)
                x_samples = self.decode_first_stage(samples.to(self.device))
                log["samples_x0_quantized"] = x_samples

        if unconditional_guidance_scale > 1.0:
            uc = self.get_unconditional_conditioning(N, unconditional_guidance_label)
            if self.model.conditioning_key == "crossattn-adm":
                uc = {"c_crossattn": [uc], "c_adm": c["c_adm"]}
            with ema_scope("Sampling with classifier-free guidance"):
                samples_cfg, _ = self.sample_log(cond=c, batch_size=N, ddim=use_ddim,
                                                 ddim_steps=ddim_steps, eta=ddim_eta,
                                                 unconditional_guidance_scale=unconditional_guidance_scale,
                                                 unconditional_conditioning=uc,
                                                 )
                x_samples_cfg = self.decode_first_stage(samples_cfg)
                log[f"samples_cfg_scale_{unconditional_guidance_scale:.2f}"] = x_samples_cfg

        if inpaint:
            # make a simple center square
            b, h, w = z.shape[0], z.shape[2], z.shape[3]
            mask = torch.ones(N, h, w).to(self.device)
            # zeros will be filled in
            mask[:, h // 4:3 * h // 4, w // 4:3 * w // 4] = 0.
            mask = mask[:, None, ...]
            with ema_scope("Plotting Inpaint"):
                samples, _ = self.sample_log(cond=c, batch_size=N, ddim=use_ddim, eta=ddim_eta,
                                             ddim_steps=ddim_steps, x0=z[:N], mask=mask)
            x_samples = self.decode_first_stage(samples.to(self.device))
            log["samples_inpainting"] = x_samples
            log["mask"] = mask

            # outpaint
            mask = 1. - mask
            with ema_scope("Plotting Outpaint"):
                samples, _ = self.sample_log(cond=c, batch_size=N, ddim=use_ddim, eta=ddim_eta,
                                             ddim_steps=ddim_steps, x0=z[:N], mask=mask)
            x_samples = self.decode_first_stage(samples.to(self.device))
            log["samples_outpainting"] = x_samples

        if plot_progressive_rows:
            with ema_scope("Plotting Progressives"):
                img, progressives = self.progressive_denoising(c,
                                                               shape=(self.channels, self.image_size, self.image_size),
                                                               batch_size=N)
            prog_row = self._get_denoise_row_from_list(progressives, desc="Progressive Generation")
            log["progressive_row"] = prog_row

        if return_keys:
            if np.intersect1d(list(log.keys()), return_keys).shape[0] == 0:
                return log
            else:
                return {key: log[key] for key in return_keys}
        return log

    def configure_optimizers(self):
        lr = self.learning_rate
        params = list(self.model.parameters())
        if self.cond_stage_trainable:
            print(f"{self.__class__.__name__}: Also optimizing conditioner params!")
            params = params + list(self.cond_stage_model.parameters())
        if self.learn_logvar:
            print('Diffusion model optimizing logvar')
            params.append(self.logvar)
        opt = torch.optim.AdamW(params, lr=lr)
        if self.use_scheduler:
            assert 'target' in self.scheduler_config
            scheduler = instantiate_from_config(self.scheduler_config)

            print("Setting up LambdaLR scheduler...")
            scheduler = [
                {
                    'scheduler': LambdaLR(opt, lr_lambda=scheduler.schedule),
                    'interval': 'step',
                    'frequency': 1
                }]
            return [opt], scheduler
        return opt

    @torch.no_grad()
    def to_rgb(self, x):
        x = x.float()
        if not hasattr(self, "colorize"):
            self.colorize = torch.randn(3, x.shape[1], 1, 1).to(x)
        x = nn.functional.conv2d(x, weight=self.colorize)
        x = 2. * (x - x.min()) / (x.max() - x.min()) - 1.
        return x


class DiffusionWrapper(pl.LightningModule):
    def __init__(self, diff_model_config, conditioning_key):
        super().__init__()
        self.sequential_cross_attn = diff_model_config.pop("sequential_crossattn", False)
        self.diffusion_model = instantiate_from_config(diff_model_config)
        self.conditioning_key = conditioning_key
        assert self.conditioning_key in [None, 'concat', 'crossattn', 'hybrid', 'adm', 'hybrid-adm', 'crossattn-adm']

    def forward(self, x, t, c_concat: list = None, c_crossattn: list = None, c_adm=None):
        if self.conditioning_key is None:
            out = self.diffusion_model(x, t)
        elif self.conditioning_key == 'concat':
            xc = torch.cat([x] + c_concat, dim=1)
            out = self.diffusion_model(xc, t)
        elif self.conditioning_key == 'crossattn':
            if not self.sequential_cross_attn:
                cc = torch.cat(c_crossattn, 1)
            else:
                cc = c_crossattn
            if hasattr(self, "scripted_diffusion_model"):
                # TorchScript changes names of the arguments
                # with argument cc defined as context=cc scripted model will produce
                # an error: RuntimeError: forward() is missing value for argument 'argument_3'.
                out = self.scripted_diffusion_model(x, t, cc)
            else:
                out = self.diffusion_model(x, t, context=cc)
        elif self.conditioning_key == 'hybrid':
            xc = torch.cat([x] + c_concat, dim=1)
            cc = torch.cat(c_crossattn, 1)
            out = self.diffusion_model(xc, t, context=cc)
        elif self.conditioning_key == 'hybrid-adm':
            assert c_adm is not None
            xc = torch.cat([x] + c_concat, dim=1)
            cc = torch.cat(c_crossattn, 1)
            out = self.diffusion_model(xc, t, context=cc, y=c_adm)
        elif self.conditioning_key == 'crossattn-adm':
            assert c_adm is not None
            cc = torch.cat(c_crossattn, 1)
            out = self.diffusion_model(x, t, context=cc, y=c_adm)
        elif self.conditioning_key == 'adm':
            cc = c_crossattn[0]
            out = self.diffusion_model(x, t, y=cc)
        else:
            raise NotImplementedError()

        return out
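    # Example call shapes (sketch): for cross-attention conditioning,
    #   wrapper(x_noisy, t, c_crossattn=[text_emb])
    # for channel-concat conditioning,
    #   wrapper(x_noisy, t, c_concat=[cond_image])
    # and for 'adm' the class embedding is passed as c_crossattn=[y] and forwarded via y=cc.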
1356
-
1357
-
1358
- class LatentUpscaleDiffusion(LatentDiffusion):
-     def __init__(self, *args, low_scale_config, low_scale_key="LR", noise_level_key=None, **kwargs):
-         super().__init__(*args, **kwargs)
-         # assumes that neither the cond_stage nor the low_scale_model contain trainable params
-         assert not self.cond_stage_trainable
-         self.instantiate_low_stage(low_scale_config)
-         self.low_scale_key = low_scale_key
-         self.noise_level_key = noise_level_key
-
-     def instantiate_low_stage(self, config):
-         model = instantiate_from_config(config)
-         self.low_scale_model = model.eval()
-         self.low_scale_model.train = disabled_train
-         for param in self.low_scale_model.parameters():
-             param.requires_grad = False
-
-     @torch.no_grad()
-     def get_input(self, batch, k, cond_key=None, bs=None, log_mode=False):
-         if not log_mode:
-             z, c = super().get_input(batch, k, force_c_encode=True, bs=bs)
-         else:
-             z, c, x, xrec, xc = super().get_input(batch, self.first_stage_key, return_first_stage_outputs=True,
-                                                   force_c_encode=True, return_original_cond=True, bs=bs)
-         x_low = batch[self.low_scale_key][:bs]
-         x_low = rearrange(x_low, 'b h w c -> b c h w')
-         x_low = x_low.to(memory_format=torch.contiguous_format).float()
-         zx, noise_level = self.low_scale_model(x_low)
-         if self.noise_level_key is not None:
-             # get the noise level from the batch instead, e.g. when extracting a custom noise level for bsr
-             raise NotImplementedError('TODO')
-
-         all_conds = {"c_concat": [zx], "c_crossattn": [c], "c_adm": noise_level}
-         if log_mode:
-             # TODO: maybe disable if too expensive
-             x_low_rec = self.low_scale_model.decode(zx)
-             return z, all_conds, x, xrec, xc, x_low, x_low_rec, noise_level
-         return z, all_conds
-
-     @torch.no_grad()
-     def log_images(self, batch, N=8, n_row=4, sample=True, ddim_steps=200, ddim_eta=1., return_keys=None,
-                    plot_denoise_rows=False, plot_progressive_rows=True, plot_diffusion_rows=True,
-                    unconditional_guidance_scale=1., unconditional_guidance_label=None, use_ema_scope=True,
-                    **kwargs):
-         ema_scope = self.ema_scope if use_ema_scope else nullcontext
-         use_ddim = ddim_steps is not None
-
-         log = dict()
-         z, c, x, xrec, xc, x_low, x_low_rec, noise_level = self.get_input(batch, self.first_stage_key, bs=N,
-                                                                           log_mode=True)
-         N = min(x.shape[0], N)
-         n_row = min(x.shape[0], n_row)
-         log["inputs"] = x
-         log["reconstruction"] = xrec
-         log["x_lr"] = x_low
-         log[f"x_lr_rec_@noise_levels{'-'.join(map(str, noise_level.cpu().numpy()))}"] = x_low_rec
-         if self.model.conditioning_key is not None:
-             if hasattr(self.cond_stage_model, "decode"):
-                 xc = self.cond_stage_model.decode(c)
-                 log["conditioning"] = xc
-             elif self.cond_stage_key in ["caption", "txt"]:
-                 xc = log_txt_as_img((x.shape[2], x.shape[3]), batch[self.cond_stage_key], size=x.shape[2] // 25)
-                 log["conditioning"] = xc
-             elif self.cond_stage_key in ['class_label', 'cls']:
-                 xc = log_txt_as_img((x.shape[2], x.shape[3]), batch["human_label"], size=x.shape[2] // 25)
-                 log['conditioning'] = xc
-             elif isimage(xc):
-                 log["conditioning"] = xc
-             if ismap(xc):
-                 log["original_conditioning"] = self.to_rgb(xc)
-
-         if plot_diffusion_rows:
-             # get diffusion row
-             diffusion_row = list()
-             z_start = z[:n_row]
-             for t in range(self.num_timesteps):
-                 if t % self.log_every_t == 0 or t == self.num_timesteps - 1:
-                     t = repeat(torch.tensor([t]), '1 -> b', b=n_row)
-                     t = t.to(self.device).long()
-                     noise = torch.randn_like(z_start)
-                     z_noisy = self.q_sample(x_start=z_start, t=t, noise=noise)
-                     diffusion_row.append(self.decode_first_stage(z_noisy))
-
-             diffusion_row = torch.stack(diffusion_row)  # n_log_step, n_row, C, H, W
-             diffusion_grid = rearrange(diffusion_row, 'n b c h w -> b n c h w')
-             diffusion_grid = rearrange(diffusion_grid, 'b n c h w -> (b n) c h w')
-             diffusion_grid = make_grid(diffusion_grid, nrow=diffusion_row.shape[0])
-             log["diffusion_row"] = diffusion_grid
-
-         if sample:
-             # get denoise row
-             with ema_scope("Sampling"):
-                 samples, z_denoise_row = self.sample_log(cond=c, batch_size=N, ddim=use_ddim,
-                                                          ddim_steps=ddim_steps, eta=ddim_eta)
-             # samples, z_denoise_row = self.sample(cond=c, batch_size=N, return_intermediates=True)
-             x_samples = self.decode_first_stage(samples)
-             log["samples"] = x_samples
-             if plot_denoise_rows:
-                 denoise_grid = self._get_denoise_row_from_list(z_denoise_row)
-                 log["denoise_row"] = denoise_grid
-
-         if unconditional_guidance_scale > 1.0:
-             uc_tmp = self.get_unconditional_conditioning(N, unconditional_guidance_label)
-             # TODO explore better "unconditional" choices for the other keys
-             # maybe guide away from empty text label and highest noise level and maximally degraded zx?
-             uc = dict()
-             for k in c:
-                 if k == "c_crossattn":
-                     assert isinstance(c[k], list) and len(c[k]) == 1
-                     uc[k] = [uc_tmp]
-                 elif k == "c_adm":  # todo: only run with text-based guidance?
-                     assert isinstance(c[k], torch.Tensor)
-                     # uc[k] = torch.ones_like(c[k]) * self.low_scale_model.max_noise_level
-                     uc[k] = c[k]
-                 elif isinstance(c[k], list):
-                     uc[k] = [c[k][i] for i in range(len(c[k]))]
-                 else:
-                     uc[k] = c[k]
-
-             with ema_scope("Sampling with classifier-free guidance"):
-                 samples_cfg, _ = self.sample_log(cond=c, batch_size=N, ddim=use_ddim,
-                                                  ddim_steps=ddim_steps, eta=ddim_eta,
-                                                  unconditional_guidance_scale=unconditional_guidance_scale,
-                                                  unconditional_conditioning=uc,
-                                                  )
-                 x_samples_cfg = self.decode_first_stage(samples_cfg)
-                 log[f"samples_cfg_scale_{unconditional_guidance_scale:.2f}"] = x_samples_cfg
-
-         if plot_progressive_rows:
-             with ema_scope("Plotting Progressives"):
-                 img, progressives = self.progressive_denoising(c,
-                                                                shape=(self.channels, self.image_size, self.image_size),
-                                                                batch_size=N)
-             prog_row = self._get_denoise_row_from_list(progressives, desc="Progressive Generation")
-             log["progressive_row"] = prog_row
-
-         return log
-
-
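The unconditional dict built above is consumed downstream by the DDIM sampler; as a sketch of standard classifier-free guidance, the model is evaluated once per conditioning and the two noise predictions are extrapolated:

import torch

def cfg_combine(e_uc: torch.Tensor, e_c: torch.Tensor, scale: float) -> torch.Tensor:
    # classifier-free guidance: extrapolate away from the unconditional prediction;
    # scale = 1.0 recovers the conditional prediction unchanged
    return e_uc + scale * (e_c - e_uc)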
- class LatentFinetuneDiffusion(LatentDiffusion):
-     """
-     Basis for different finetunes, such as inpainting or depth2image.
-     To disable finetuning mode, set finetune_keys to None.
-     """
-
-     def __init__(self,
-                  concat_keys: tuple,
-                  # note: LitEma stores parameter names with the dots stripped, hence the second key
-                  finetune_keys=("model.diffusion_model.input_blocks.0.0.weight",
-                                 "model_ema.diffusion_modelinput_blocks00weight"
-                                 ),
-                  keep_finetune_dims=4,
-                  # if model was trained without concat mode before and we would like to keep these channels
-                  c_concat_log_start=None,  # to log reconstruction of c_concat codes
-                  c_concat_log_end=None,
-                  *args, **kwargs
-                  ):
-         ckpt_path = kwargs.pop("ckpt_path", None)
-         ignore_keys = kwargs.pop("ignore_keys", list())
-         super().__init__(*args, **kwargs)
-         self.finetune_keys = finetune_keys
-         self.concat_keys = concat_keys
-         self.keep_dims = keep_finetune_dims
-         self.c_concat_log_start = c_concat_log_start
-         self.c_concat_log_end = c_concat_log_end
-         if exists(self.finetune_keys):
-             assert exists(ckpt_path), 'can only finetune from a given checkpoint'
-         if exists(ckpt_path):
-             self.init_from_ckpt(ckpt_path, ignore_keys)
-
-     def init_from_ckpt(self, path, ignore_keys=list(), only_model=False):
-         sd = torch.load(path, map_location="cpu")
-         if "state_dict" in list(sd.keys()):
-             sd = sd["state_dict"]
-         keys = list(sd.keys())
-         for k in keys:
-             for ik in ignore_keys:
-                 if k.startswith(ik):
-                     print("Deleting key {} from state_dict.".format(k))
-                     del sd[k]
-
-             # make it explicit: finetune by including extra input channels
-             if exists(self.finetune_keys) and k in self.finetune_keys:
-                 new_entry = None
-                 for name, param in self.named_parameters():
-                     if name in self.finetune_keys:
-                         print(
-                             f"modifying key '{name}' and keeping only its first {self.keep_dims} (channel) dimensions")
-                         new_entry = torch.zeros_like(param)  # zero init
-                 assert exists(new_entry), 'did not find matching parameter to modify'
-                 new_entry[:, :self.keep_dims, ...] = sd[k]
-                 sd[k] = new_entry
-
-         missing, unexpected = self.load_state_dict(sd, strict=False) if not only_model else self.model.load_state_dict(
-             sd, strict=False)
-         print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys")
-         if len(missing) > 0:
-             print(f"Missing Keys: {missing}")
-         if len(unexpected) > 0:
-             print(f"Unexpected Keys: {unexpected}")
-
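A worked sketch of the weight surgery above: a checkpointed input conv that saw 4 latent channels is widened to, say, 9 (4 latent + 4 masked-image latent + 1 mask, as in inpainting), and the new channels start at zero so the finetuned model initially ignores them. The shapes here are illustrative:

import torch

old_w = torch.randn(320, 4, 3, 3)  # stand-in for sd["model.diffusion_model.input_blocks.0.0.weight"]
new_w = torch.zeros(320, 9, 3, 3)  # zero init, like torch.zeros_like(param) above
new_w[:, :4, ...] = old_w          # keep_finetune_dims = 4 original input channels
assert torch.equal(new_w[:, :4], old_w) and new_w[:, 4:].abs().sum() == 0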
-     @torch.no_grad()
-     def log_images(self, batch, N=8, n_row=4, sample=True, ddim_steps=200, ddim_eta=1., return_keys=None,
-                    quantize_denoised=True, inpaint=True, plot_denoise_rows=False, plot_progressive_rows=True,
-                    plot_diffusion_rows=True, unconditional_guidance_scale=1., unconditional_guidance_label=None,
-                    use_ema_scope=True,
-                    **kwargs):
-         ema_scope = self.ema_scope if use_ema_scope else nullcontext
-         use_ddim = ddim_steps is not None
-
-         log = dict()
-         z, c, x, xrec, xc = self.get_input(batch, self.first_stage_key, bs=N, return_first_stage_outputs=True)
-         c_cat, c = c["c_concat"][0], c["c_crossattn"][0]
-         N = min(x.shape[0], N)
-         n_row = min(x.shape[0], n_row)
-         log["inputs"] = x
-         log["reconstruction"] = xrec
-         if self.model.conditioning_key is not None:
-             if hasattr(self.cond_stage_model, "decode"):
-                 xc = self.cond_stage_model.decode(c)
-                 log["conditioning"] = xc
-             elif self.cond_stage_key in ["caption", "txt"]:
-                 xc = log_txt_as_img((x.shape[2], x.shape[3]), batch[self.cond_stage_key], size=x.shape[2] // 25)
-                 log["conditioning"] = xc
-             elif self.cond_stage_key in ['class_label', 'cls']:
-                 xc = log_txt_as_img((x.shape[2], x.shape[3]), batch["human_label"], size=x.shape[2] // 25)
-                 log['conditioning'] = xc
-             elif isimage(xc):
-                 log["conditioning"] = xc
-             if ismap(xc):
-                 log["original_conditioning"] = self.to_rgb(xc)
-
-         if not (self.c_concat_log_start is None and self.c_concat_log_end is None):
-             log["c_concat_decoded"] = self.decode_first_stage(c_cat[:, self.c_concat_log_start:self.c_concat_log_end])
-
-         if plot_diffusion_rows:
-             # get diffusion row
-             diffusion_row = list()
-             z_start = z[:n_row]
-             for t in range(self.num_timesteps):
-                 if t % self.log_every_t == 0 or t == self.num_timesteps - 1:
-                     t = repeat(torch.tensor([t]), '1 -> b', b=n_row)
-                     t = t.to(self.device).long()
-                     noise = torch.randn_like(z_start)
-                     z_noisy = self.q_sample(x_start=z_start, t=t, noise=noise)
-                     diffusion_row.append(self.decode_first_stage(z_noisy))
-
-             diffusion_row = torch.stack(diffusion_row)  # n_log_step, n_row, C, H, W
-             diffusion_grid = rearrange(diffusion_row, 'n b c h w -> b n c h w')
-             diffusion_grid = rearrange(diffusion_grid, 'b n c h w -> (b n) c h w')
-             diffusion_grid = make_grid(diffusion_grid, nrow=diffusion_row.shape[0])
-             log["diffusion_row"] = diffusion_grid
-
-         if sample:
-             # get denoise row
-             with ema_scope("Sampling"):
-                 samples, z_denoise_row = self.sample_log(cond={"c_concat": [c_cat], "c_crossattn": [c]},
-                                                          batch_size=N, ddim=use_ddim,
-                                                          ddim_steps=ddim_steps, eta=ddim_eta)
-             # samples, z_denoise_row = self.sample(cond=c, batch_size=N, return_intermediates=True)
-             x_samples = self.decode_first_stage(samples)
-             log["samples"] = x_samples
-             if plot_denoise_rows:
-                 denoise_grid = self._get_denoise_row_from_list(z_denoise_row)
-                 log["denoise_row"] = denoise_grid
-
-         if unconditional_guidance_scale > 1.0:
-             uc_cross = self.get_unconditional_conditioning(N, unconditional_guidance_label)
-             uc_cat = c_cat
-             uc_full = {"c_concat": [uc_cat], "c_crossattn": [uc_cross]}
-             with ema_scope("Sampling with classifier-free guidance"):
-                 samples_cfg, _ = self.sample_log(cond={"c_concat": [c_cat], "c_crossattn": [c]},
-                                                  batch_size=N, ddim=use_ddim,
-                                                  ddim_steps=ddim_steps, eta=ddim_eta,
-                                                  unconditional_guidance_scale=unconditional_guidance_scale,
-                                                  unconditional_conditioning=uc_full,
-                                                  )
-                 x_samples_cfg = self.decode_first_stage(samples_cfg)
-                 log[f"samples_cfg_scale_{unconditional_guidance_scale:.2f}"] = x_samples_cfg
-
-         return log
-
-
- class LatentInpaintDiffusion(LatentFinetuneDiffusion):
-     """
-     Can either run as a pure inpainting model (only concat mode) or with mixed conditionings,
-     e.g. mask as concat and text via cross-attn.
-     To disable finetuning mode, set finetune_keys to None.
-     """
-
-     def __init__(self,
-                  concat_keys=("mask", "masked_image"),
-                  masked_image_key="masked_image",
-                  *args, **kwargs
-                  ):
-         super().__init__(concat_keys, *args, **kwargs)
-         self.masked_image_key = masked_image_key
-         assert self.masked_image_key in concat_keys
-
-     @torch.no_grad()
-     def get_input(self, batch, k, cond_key=None, bs=None, return_first_stage_outputs=False):
-         # note: restricted to non-trainable encoders currently
-         assert not self.cond_stage_trainable, 'trainable cond stages not yet supported for inpainting'
-         z, c, x, xrec, xc = super().get_input(batch, self.first_stage_key, return_first_stage_outputs=True,
-                                               force_c_encode=True, return_original_cond=True, bs=bs)
-
-         assert exists(self.concat_keys)
-         c_cat = list()
-         for ck in self.concat_keys:
-             cc = rearrange(batch[ck], 'b h w c -> b c h w').to(memory_format=torch.contiguous_format).float()
-             if bs is not None:
-                 cc = cc[:bs]
-                 cc = cc.to(self.device)
-             bchw = z.shape
-             if ck != self.masked_image_key:
-                 cc = torch.nn.functional.interpolate(cc, size=bchw[-2:])
-             else:
-                 cc = self.get_first_stage_encoding(self.encode_first_stage(cc))
-             c_cat.append(cc)
-         c_cat = torch.cat(c_cat, dim=1)
-         all_conds = {"c_concat": [c_cat], "c_crossattn": [c]}
-         if return_first_stage_outputs:
-             return z, all_conds, x, xrec, xc
-         return z, all_conds
-
-     @torch.no_grad()
-     def log_images(self, *args, **kwargs):
-         log = super().log_images(*args, **kwargs)
-         log["masked_image"] = rearrange(args[0]["masked_image"],
-                                         'b h w c -> b c h w').to(memory_format=torch.contiguous_format).float()
-         return log
-
-
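A sketch of the conditioning that `get_input` assembles above, with illustrative sizes (512px images and an 8x-downsampling first stage; neither is fixed by this file): the mask is resized to the latent grid while the masked image is encoded by the first stage, and the two are concatenated channel-wise.

import torch
import torch.nn.functional as F

B = 2
mask = torch.rand(B, 1, 512, 512)                 # after the 'b h w c -> b c h w' rearrange
mask_lat = F.interpolate(mask, size=(64, 64))     # resized to the latent resolution
masked_lat = torch.randn(B, 4, 64, 64)            # stand-in for the encoded masked image
c_cat = torch.cat([mask_lat, masked_lat], dim=1)  # (B, 5, 64, 64): the 4+1 extra input channels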
- class LatentDepth2ImageDiffusion(LatentFinetuneDiffusion):
-     """
-     Condition on monocular depth estimation.
-     """
-
-     def __init__(self, depth_stage_config, concat_keys=("midas_in",), *args, **kwargs):
-         super().__init__(concat_keys=concat_keys, *args, **kwargs)
-         self.depth_model = instantiate_from_config(depth_stage_config)
-         self.depth_stage_key = concat_keys[0]
-
-     @torch.no_grad()
-     def get_input(self, batch, k, cond_key=None, bs=None, return_first_stage_outputs=False):
-         # note: restricted to non-trainable encoders currently
-         assert not self.cond_stage_trainable, 'trainable cond stages not yet supported for depth2img'
-         z, c, x, xrec, xc = super().get_input(batch, self.first_stage_key, return_first_stage_outputs=True,
-                                               force_c_encode=True, return_original_cond=True, bs=bs)
-
-         assert exists(self.concat_keys)
-         assert len(self.concat_keys) == 1
-         c_cat = list()
-         for ck in self.concat_keys:
-             cc = batch[ck]
-             if bs is not None:
-                 cc = cc[:bs]
-                 cc = cc.to(self.device)
-             cc = self.depth_model(cc)
-             cc = torch.nn.functional.interpolate(
-                 cc,
-                 size=z.shape[2:],
-                 mode="bicubic",
-                 align_corners=False,
-             )
-
-             depth_min, depth_max = torch.amin(cc, dim=[1, 2, 3], keepdim=True), torch.amax(cc, dim=[1, 2, 3],
-                                                                                            keepdim=True)
-             cc = 2. * (cc - depth_min) / (depth_max - depth_min + 0.001) - 1.
-             c_cat.append(cc)
-         c_cat = torch.cat(c_cat, dim=1)
-         all_conds = {"c_concat": [c_cat], "c_crossattn": [c]}
-         if return_first_stage_outputs:
-             return z, all_conds, x, xrec, xc
-         return z, all_conds
-
-     @torch.no_grad()
-     def log_images(self, *args, **kwargs):
-         log = super().log_images(*args, **kwargs)
-         depth = self.depth_model(args[0][self.depth_stage_key])
-         depth_min, depth_max = torch.amin(depth, dim=[1, 2, 3], keepdim=True), \
-             torch.amax(depth, dim=[1, 2, 3], keepdim=True)
-         log["depth"] = 2. * (depth - depth_min) / (depth_max - depth_min) - 1.
-         return log
-
-
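The per-sample normalization above maps each depth map to [-1, 1] using its own min/max (the 0.001 term guards against a constant map). A small numeric check:

import torch

d = torch.tensor([[[[1.0, 3.0], [5.0, 9.0]]]])      # (1, 1, 2, 2) depth map
d_min = torch.amin(d, dim=[1, 2, 3], keepdim=True)  # 1.0
d_max = torch.amax(d, dim=[1, 2, 3], keepdim=True)  # 9.0
norm = 2. * (d - d_min) / (d_max - d_min + 0.001) - 1.
# norm ~= [[-1.0, -0.5], [0.0, 1.0]] (up to the 0.001 stabilizer)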
- class LatentUpscaleFinetuneDiffusion(LatentFinetuneDiffusion):
-     """
-     Condition on a low-res image (and optionally on some spatial noise augmentation).
-     """
-
-     def __init__(self, concat_keys=("lr",), reshuffle_patch_size=None,
-                  low_scale_config=None, low_scale_key=None, *args, **kwargs):
-         super().__init__(concat_keys=concat_keys, *args, **kwargs)
-         self.reshuffle_patch_size = reshuffle_patch_size
-         self.low_scale_model = None
-         if low_scale_config is not None:
-             print("Initializing a low-scale model")
-             assert exists(low_scale_key)
-             self.instantiate_low_stage(low_scale_config)
-             self.low_scale_key = low_scale_key
-
-     def instantiate_low_stage(self, config):
-         model = instantiate_from_config(config)
-         self.low_scale_model = model.eval()
-         self.low_scale_model.train = disabled_train
-         for param in self.low_scale_model.parameters():
-             param.requires_grad = False
-
-     @torch.no_grad()
-     def get_input(self, batch, k, cond_key=None, bs=None, return_first_stage_outputs=False):
-         # note: restricted to non-trainable encoders currently
-         assert not self.cond_stage_trainable, 'trainable cond stages not yet supported for upscaling-ft'
-         z, c, x, xrec, xc = super().get_input(batch, self.first_stage_key, return_first_stage_outputs=True,
-                                               force_c_encode=True, return_original_cond=True, bs=bs)
-
-         assert exists(self.concat_keys)
-         assert len(self.concat_keys) == 1
-         # optionally make spatial noise_level here
-         c_cat = list()
-         noise_level = None
-         for ck in self.concat_keys:
-             cc = batch[ck]
-             cc = rearrange(cc, 'b h w c -> b c h w')
-             if exists(self.reshuffle_patch_size):
-                 assert isinstance(self.reshuffle_patch_size, int)
-                 cc = rearrange(cc, 'b c (p1 h) (p2 w) -> b (p1 p2 c) h w',
-                                p1=self.reshuffle_patch_size, p2=self.reshuffle_patch_size)
-             if bs is not None:
-                 cc = cc[:bs]
-                 cc = cc.to(self.device)
-             if exists(self.low_scale_model) and ck == self.low_scale_key:
-                 cc, noise_level = self.low_scale_model(cc)
-             c_cat.append(cc)
-         c_cat = torch.cat(c_cat, dim=1)
-         if exists(noise_level):
-             all_conds = {"c_concat": [c_cat], "c_crossattn": [c], "c_adm": noise_level}
-         else:
-             all_conds = {"c_concat": [c_cat], "c_crossattn": [c]}
-         if return_first_stage_outputs:
-             return z, all_conds, x, xrec, xc
-         return z, all_conds
-
-     @torch.no_grad()
-     def log_images(self, *args, **kwargs):
-         log = super().log_images(*args, **kwargs)
-         log["lr"] = rearrange(args[0]["lr"], 'b h w c -> b c h w')
-         return log
-
-
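The reshuffle rearrange above folds spatial blocks of the LR image into channels, shrinking H and W by the patch size and multiplying the channel count by p1*p2, so the concat conditioning can match a smaller latent grid. A shape check with patch size 2 (sizes are illustrative):

import torch
from einops import rearrange

lr = torch.randn(2, 3, 128, 128)
folded = rearrange(lr, 'b c (p1 h) (p2 w) -> b (p1 p2 c) h w', p1=2, p2=2)
assert folded.shape == (2, 12, 64, 64)  # channels: 2 * 2 * 3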
- class ImageEmbeddingConditionedLatentDiffusion(LatentDiffusion):
-     def __init__(self, embedder_config, embedding_key="jpg", embedding_dropout=0.5,
-                  freeze_embedder=True, noise_aug_config=None, *args, **kwargs):
-         super().__init__(*args, **kwargs)
-         self.embed_key = embedding_key
-         self.embedding_dropout = embedding_dropout
-         self._init_embedder(embedder_config, freeze_embedder)
-         self._init_noise_aug(noise_aug_config)
-
-     def _init_embedder(self, config, freeze=True):
-         # assign unconditionally so a non-frozen embedder is also registered
-         self.embedder = instantiate_from_config(config)
-         if freeze:
-             self.embedder = self.embedder.eval()
-             self.embedder.train = disabled_train
-             for param in self.embedder.parameters():
-                 param.requires_grad = False
-
-     def _init_noise_aug(self, config):
-         if config is not None:
-             # use the KARLO schedule for noise augmentation on CLIP image embeddings
-             noise_augmentor = instantiate_from_config(config)
-             assert isinstance(noise_augmentor, nn.Module)
-             noise_augmentor = noise_augmentor.eval()
-             noise_augmentor.train = disabled_train
-             self.noise_augmentor = noise_augmentor
-         else:
-             self.noise_augmentor = None
-
-     def get_input(self, batch, k, cond_key=None, bs=None, **kwargs):
-         outputs = LatentDiffusion.get_input(self, batch, k, bs=bs, **kwargs)
-         z, c = outputs[0], outputs[1]
-         img = batch[self.embed_key][:bs]
-         img = rearrange(img, 'b h w c -> b c h w')
-         c_adm = self.embedder(img)
-         if self.noise_augmentor is not None:
-             c_adm, noise_level_emb = self.noise_augmentor(c_adm)
-             # assume this gives embeddings of noise levels
-             c_adm = torch.cat((c_adm, noise_level_emb), 1)
-         if self.training:
-             c_adm = torch.bernoulli((1. - self.embedding_dropout) * torch.ones(c_adm.shape[0],
-                                                                               device=c_adm.device)[:, None]) * c_adm
-         all_conds = {"c_crossattn": [c], "c_adm": c_adm}
-         noutputs = [z, all_conds]
-         noutputs.extend(outputs[2:])
-         return noutputs
-
-     @torch.no_grad()
-     def log_images(self, batch, N=8, n_row=4, **kwargs):
-         log = dict()
-         z, c, x, xrec, xc = self.get_input(batch, self.first_stage_key, bs=N, return_first_stage_outputs=True,
-                                            return_original_cond=True)
-         log["inputs"] = x
-         log["reconstruction"] = xrec
-         assert self.model.conditioning_key is not None
-         assert self.cond_stage_key in ["caption", "txt"]
-         xc = log_txt_as_img((x.shape[2], x.shape[3]), batch[self.cond_stage_key], size=x.shape[2] // 25)
-         log["conditioning"] = xc
-         uc = self.get_unconditional_conditioning(N, kwargs.get('unconditional_guidance_label', ''))
-         unconditional_guidance_scale = kwargs.get('unconditional_guidance_scale', 5.)
-
-         uc_ = {"c_crossattn": [uc], "c_adm": c["c_adm"]}
-         ema_scope = self.ema_scope if kwargs.get('use_ema_scope', True) else nullcontext
-         with ema_scope("Sampling"):
-             samples_cfg, _ = self.sample_log(cond=c, batch_size=N, ddim=True,
-                                              ddim_steps=kwargs.get('ddim_steps', 50), eta=kwargs.get('ddim_eta', 0.),
-                                              unconditional_guidance_scale=unconditional_guidance_scale,
-                                              unconditional_conditioning=uc_, )
-             x_samples_cfg = self.decode_first_stage(samples_cfg)
-             log[f"samples_cfg_scale_{unconditional_guidance_scale:.2f}"] = x_samples_cfg
-         return log
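The training-time embedding dropout in `get_input` above zeroes whole image embeddings with probability `embedding_dropout`, in the classifier-free-guidance style. A standalone sketch with illustrative sizes:

import torch

B, D = 4, 768                 # hypothetical batch and embedding dimensions
c_adm = torch.randn(B, D)
keep = torch.bernoulli(0.5 * torch.ones(B))[:, None]  # keep prob = 1 - embedding_dropout
c_adm = keep * c_adm          # rows where keep == 0 become all-zero embeddings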