Fabrice-TIERCELIN committed
Commit 87c1b97
1 parent: f616433

Upload 2 files

Files changed (2)
  1. sgm/models/autoencoder.py +335 -0
  2. sgm/models/diffusion.py +320 -0
sgm/models/autoencoder.py ADDED
@@ -0,0 +1,335 @@
+ import re
+ from abc import abstractmethod
+ from contextlib import contextmanager
+ from typing import Any, Dict, Tuple, Union
+
+ import pytorch_lightning as pl
+ import torch
+ from omegaconf import ListConfig
+ from packaging import version
+ from safetensors.torch import load_file as load_safetensors
+
+ from ..modules.diffusionmodules.model import Decoder, Encoder
+ from ..modules.distributions.distributions import DiagonalGaussianDistribution
+ from ..modules.ema import LitEma
+ from ..util import default, get_obj_from_str, instantiate_from_config
+
+
+ class AbstractAutoencoder(pl.LightningModule):
+     """
+     This is the base class for all autoencoders, including image autoencoders, image autoencoders with discriminators,
+     unCLIP models, etc. Hence, it is fairly general, and specific features
+     (e.g. discriminator training, encoding, decoding) must be implemented in subclasses.
+     """
+
+     def __init__(
+         self,
+         ema_decay: Union[None, float] = None,
+         monitor: Union[None, str] = None,
+         input_key: str = "jpg",
+         ckpt_path: Union[None, str] = None,
+         ignore_keys: Union[Tuple, list, ListConfig] = (),
+     ):
+         super().__init__()
+         self.input_key = input_key
+         self.use_ema = ema_decay is not None
+         if monitor is not None:
+             self.monitor = monitor
+
+         if self.use_ema:
+             self.model_ema = LitEma(self, decay=ema_decay)
+             print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
+
+         if ckpt_path is not None:
+             self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
+
+         if version.parse(torch.__version__) >= version.parse("2.0.0"):
+             self.automatic_optimization = False
+
+     def init_from_ckpt(
+         self, path: str, ignore_keys: Union[Tuple, list, ListConfig] = tuple()
+     ) -> None:
+         if path.endswith("ckpt"):
+             sd = torch.load(path, map_location="cpu")["state_dict"]
+         elif path.endswith("safetensors"):
+             sd = load_safetensors(path)
+         else:
+             raise NotImplementedError
+
+         keys = list(sd.keys())
+         for k in keys:
+             for ik in ignore_keys:
+                 if re.match(ik, k):
+                     print("Deleting key {} from state_dict.".format(k))
+                     del sd[k]
+         missing, unexpected = self.load_state_dict(sd, strict=False)
+         print(
+             f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys"
+         )
+         if len(missing) > 0:
+             print(f"Missing Keys: {missing}")
+         if len(unexpected) > 0:
+             print(f"Unexpected Keys: {unexpected}")
+
+     @abstractmethod
+     def get_input(self, batch) -> Any:
+         raise NotImplementedError()
+
+     def on_train_batch_end(self, *args, **kwargs):
+         # for EMA computation
+         if self.use_ema:
+             self.model_ema(self)
+
+     @contextmanager
+     def ema_scope(self, context=None):
+         if self.use_ema:
+             self.model_ema.store(self.parameters())
+             self.model_ema.copy_to(self)
+             if context is not None:
+                 print(f"{context}: Switched to EMA weights")
+         try:
+             yield None
+         finally:
+             if self.use_ema:
+                 self.model_ema.restore(self.parameters())
+                 if context is not None:
+                     print(f"{context}: Restored training weights")
+
+     @abstractmethod
+     def encode(self, *args, **kwargs) -> torch.Tensor:
+         raise NotImplementedError("encode()-method of abstract base class called")
+
+     @abstractmethod
+     def decode(self, *args, **kwargs) -> torch.Tensor:
+         raise NotImplementedError("decode()-method of abstract base class called")
+
+     def instantiate_optimizer_from_config(self, params, lr, cfg):
+         print(f"loading >>> {cfg['target']} <<< optimizer from config")
+         return get_obj_from_str(cfg["target"])(
+             params, lr=lr, **cfg.get("params", dict())
+         )
+
+     def configure_optimizers(self) -> Any:
+         raise NotImplementedError()
+
+
+ class AutoencodingEngine(AbstractAutoencoder):
+     """
+     Base class for all image autoencoders that we train, like VQGAN or AutoencoderKL
+     (we also restore them explicitly as special cases for legacy reasons).
+     Regularizations such as KL or VQ are moved to the regularizer class.
+     """
+
+     def __init__(
+         self,
+         *args,
+         encoder_config: Dict,
+         decoder_config: Dict,
+         loss_config: Dict,
+         regularizer_config: Dict,
+         optimizer_config: Union[Dict, None] = None,
+         lr_g_factor: float = 1.0,
+         **kwargs,
+     ):
+         super().__init__(*args, **kwargs)
+         # todo: add options to freeze encoder/decoder
+         self.encoder = instantiate_from_config(encoder_config)
+         self.decoder = instantiate_from_config(decoder_config)
+         self.loss = instantiate_from_config(loss_config)
+         self.regularization = instantiate_from_config(regularizer_config)
+         self.optimizer_config = default(
+             optimizer_config, {"target": "torch.optim.Adam"}
+         )
+         self.lr_g_factor = lr_g_factor
+
+     def get_input(self, batch: Dict) -> torch.Tensor:
+         # assuming unified data format, dataloader returns a dict.
+         # image tensors should be scaled to -1 ... 1 and in channels-first format (e.g., bchw instead of bhwc)
+         return batch[self.input_key]
+
+     def get_autoencoder_params(self) -> list:
+         params = (
+             list(self.encoder.parameters())
+             + list(self.decoder.parameters())
+             + list(self.regularization.get_trainable_parameters())
+             + list(self.loss.get_trainable_autoencoder_parameters())
+         )
+         return params
+
+     def get_discriminator_params(self) -> list:
+         params = list(self.loss.get_trainable_parameters())  # e.g., discriminator
+         return params
+
+     def get_last_layer(self):
+         return self.decoder.get_last_layer()
+
+     def encode(self, x: Any, return_reg_log: bool = False) -> Any:
+         z = self.encoder(x)
+         z, reg_log = self.regularization(z)
+         if return_reg_log:
+             return z, reg_log
+         return z
+
+     def decode(self, z: Any) -> torch.Tensor:
+         x = self.decoder(z)
+         return x
+
+     def forward(self, x: Any) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+         z, reg_log = self.encode(x, return_reg_log=True)
+         dec = self.decode(z)
+         return z, dec, reg_log
+
+     def training_step(self, batch, batch_idx, optimizer_idx) -> Any:
+         x = self.get_input(batch)
+         z, xrec, regularization_log = self(x)
+
+         if optimizer_idx == 0:
+             # autoencode
+             aeloss, log_dict_ae = self.loss(
+                 regularization_log,
+                 x,
+                 xrec,
+                 optimizer_idx,
+                 self.global_step,
+                 last_layer=self.get_last_layer(),
+                 split="train",
+             )
+
+             self.log_dict(
+                 log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=True
+             )
+             return aeloss
+
+         if optimizer_idx == 1:
+             # discriminator
+             discloss, log_dict_disc = self.loss(
+                 regularization_log,
+                 x,
+                 xrec,
+                 optimizer_idx,
+                 self.global_step,
+                 last_layer=self.get_last_layer(),
+                 split="train",
+             )
+             self.log_dict(
+                 log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=True
+             )
+             return discloss
+
+     def validation_step(self, batch, batch_idx) -> Dict:
+         log_dict = self._validation_step(batch, batch_idx)
+         with self.ema_scope():
+             log_dict_ema = self._validation_step(batch, batch_idx, postfix="_ema")
+             log_dict.update(log_dict_ema)
+         return log_dict
+
+     def _validation_step(self, batch, batch_idx, postfix="") -> Dict:
+         x = self.get_input(batch)
+
+         z, xrec, regularization_log = self(x)
+         aeloss, log_dict_ae = self.loss(
+             regularization_log,
+             x,
+             xrec,
+             0,
+             self.global_step,
+             last_layer=self.get_last_layer(),
+             split="val" + postfix,
+         )
+
+         discloss, log_dict_disc = self.loss(
+             regularization_log,
+             x,
+             xrec,
+             1,
+             self.global_step,
+             last_layer=self.get_last_layer(),
+             split="val" + postfix,
+         )
+         self.log(f"val{postfix}/rec_loss", log_dict_ae[f"val{postfix}/rec_loss"])
+         log_dict_ae.update(log_dict_disc)
+         self.log_dict(log_dict_ae)
+         return log_dict_ae
+
+     def configure_optimizers(self) -> Any:
+         ae_params = self.get_autoencoder_params()
+         disc_params = self.get_discriminator_params()
+
+         opt_ae = self.instantiate_optimizer_from_config(
+             ae_params,
+             default(self.lr_g_factor, 1.0) * self.learning_rate,
+             self.optimizer_config,
+         )
+         opt_disc = self.instantiate_optimizer_from_config(
+             disc_params, self.learning_rate, self.optimizer_config
+         )
+
+         return [opt_ae, opt_disc], []
+
+     @torch.no_grad()
+     def log_images(self, batch: Dict, **kwargs) -> Dict:
+         log = dict()
+         x = self.get_input(batch)
+         _, xrec, _ = self(x)
+         log["inputs"] = x
+         log["reconstructions"] = xrec
+         with self.ema_scope():
+             _, xrec_ema, _ = self(x)
+             log["reconstructions_ema"] = xrec_ema
+         return log
+
+
+ class AutoencoderKL(AutoencodingEngine):
+     def __init__(self, embed_dim: int, **kwargs):
+         ddconfig = kwargs.pop("ddconfig")
+         ckpt_path = kwargs.pop("ckpt_path", None)
+         ignore_keys = kwargs.pop("ignore_keys", ())
+         super().__init__(
+             encoder_config={"target": "torch.nn.Identity"},
+             decoder_config={"target": "torch.nn.Identity"},
+             regularizer_config={"target": "torch.nn.Identity"},
+             loss_config=kwargs.pop("lossconfig"),
+             **kwargs,
+         )
+         assert ddconfig["double_z"]
+         self.encoder = Encoder(**ddconfig)
+         self.decoder = Decoder(**ddconfig)
+         self.quant_conv = torch.nn.Conv2d(2 * ddconfig["z_channels"], 2 * embed_dim, 1)
+         self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
+         self.embed_dim = embed_dim
+
+         if ckpt_path is not None:
+             self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
+
+     def encode(self, x):
+         assert (
+             not self.training
+         ), f"{self.__class__.__name__} only supports inference currently"
+         h = self.encoder(x)
+         moments = self.quant_conv(h)
+         posterior = DiagonalGaussianDistribution(moments)
+         return posterior
+
+     def decode(self, z, **decoder_kwargs):
+         z = self.post_quant_conv(z)
+         dec = self.decoder(z, **decoder_kwargs)
+         return dec
+
+
+ class AutoencoderKLInferenceWrapper(AutoencoderKL):
+     def encode(self, x):
+         return super().encode(x).sample()
+
+
+ class IdentityFirstStage(AbstractAutoencoder):
+     def __init__(self, *args, **kwargs):
+         super().__init__(*args, **kwargs)
+
+     def get_input(self, x: Any) -> Any:
+         return x
+
+     def encode(self, x: Any, *args, **kwargs) -> Any:
+         return x
+
+     def decode(self, x: Any, *args, **kwargs) -> Any:
+         return x
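
For orientation, here is a minimal usage sketch of the classes added above (not part of the commit): it instantiates the KL autoencoder for inference and round-trips a tensor through encode/decode. The ddconfig values mirror the usual SD-VAE layout and are illustrative assumptions, not values taken from this repository's configs.

import torch

from sgm.models.autoencoder import AutoencoderKLInferenceWrapper

# Assumed, SD-VAE-like hyperparameters; adjust them to whatever checkpoint you load.
ddconfig = dict(
    double_z=True, z_channels=4, resolution=256, in_channels=3, out_ch=3,
    ch=128, ch_mult=[1, 2, 4, 4], num_res_blocks=2, attn_resolutions=[], dropout=0.0,
)

vae = AutoencoderKLInferenceWrapper(
    embed_dim=4,
    ddconfig=ddconfig,
    lossconfig={"target": "torch.nn.Identity"},  # no loss needed for pure inference
    # ckpt_path="path/to/vae.safetensors",       # optionally restore pretrained weights
).eval()  # AutoencoderKL.encode asserts the module is not in training mode

with torch.no_grad():
    x = torch.randn(1, 3, 256, 256)  # stand-in for an image batch scaled to [-1, 1], BCHW
    z = vae.encode(x)                # sampled latent, here of shape (1, 4, 32, 32)
    rec = vae.decode(z)              # reconstruction of shape (1, 3, 256, 256)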
sgm/models/diffusion.py ADDED
@@ -0,0 +1,320 @@
+ from contextlib import contextmanager
+ from typing import Any, Dict, List, Tuple, Union
+
+ import pytorch_lightning as pl
+ import torch
+ from omegaconf import ListConfig, OmegaConf
+ from safetensors.torch import load_file as load_safetensors
+ from torch.optim.lr_scheduler import LambdaLR
+
+ from ..modules import UNCONDITIONAL_CONFIG
+ from ..modules.diffusionmodules.wrappers import OPENAIUNETWRAPPER
+ from ..modules.ema import LitEma
+ from ..util import (
+     default,
+     disabled_train,
+     get_obj_from_str,
+     instantiate_from_config,
+     log_txt_as_img,
+ )
+
+
+ class DiffusionEngine(pl.LightningModule):
+     def __init__(
+         self,
+         network_config,
+         denoiser_config,
+         first_stage_config,
+         conditioner_config: Union[None, Dict, ListConfig, OmegaConf] = None,
+         sampler_config: Union[None, Dict, ListConfig, OmegaConf] = None,
+         optimizer_config: Union[None, Dict, ListConfig, OmegaConf] = None,
+         scheduler_config: Union[None, Dict, ListConfig, OmegaConf] = None,
+         loss_fn_config: Union[None, Dict, ListConfig, OmegaConf] = None,
+         network_wrapper: Union[None, str] = None,
+         ckpt_path: Union[None, str] = None,
+         use_ema: bool = False,
+         ema_decay_rate: float = 0.9999,
+         scale_factor: float = 1.0,
+         disable_first_stage_autocast=False,
+         input_key: str = "jpg",
+         log_keys: Union[List, None] = None,
+         no_cond_log: bool = False,
+         compile_model: bool = False,
+     ):
+         super().__init__()
+         self.log_keys = log_keys
+         self.input_key = input_key
+         self.optimizer_config = default(
+             optimizer_config, {"target": "torch.optim.AdamW"}
+         )
+         model = instantiate_from_config(network_config)
+         self.model = get_obj_from_str(default(network_wrapper, OPENAIUNETWRAPPER))(
+             model, compile_model=compile_model
+         )
+
+         self.denoiser = instantiate_from_config(denoiser_config)
+         self.sampler = (
+             instantiate_from_config(sampler_config)
+             if sampler_config is not None
+             else None
+         )
+         self.conditioner = instantiate_from_config(
+             default(conditioner_config, UNCONDITIONAL_CONFIG)
+         )
+         self.scheduler_config = scheduler_config
+         self._init_first_stage(first_stage_config)
+
+         self.loss_fn = (
+             instantiate_from_config(loss_fn_config)
+             if loss_fn_config is not None
+             else None
+         )
+
+         self.use_ema = use_ema
+         if self.use_ema:
+             self.model_ema = LitEma(self.model, decay=ema_decay_rate)
+             print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
+
+         self.scale_factor = scale_factor
+         self.disable_first_stage_autocast = disable_first_stage_autocast
+         self.no_cond_log = no_cond_log
+
+         if ckpt_path is not None:
+             self.init_from_ckpt(ckpt_path)
+
+     def init_from_ckpt(
+         self,
+         path: str,
+     ) -> None:
+         if path.endswith("ckpt"):
+             sd = torch.load(path, map_location="cpu")["state_dict"]
+         elif path.endswith("safetensors"):
+             sd = load_safetensors(path)
+         else:
+             raise NotImplementedError
+
+         missing, unexpected = self.load_state_dict(sd, strict=False)
+         print(
+             f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys"
+         )
+         if len(missing) > 0:
+             print(f"Missing Keys: {missing}")
+         if len(unexpected) > 0:
+             print(f"Unexpected Keys: {unexpected}")
+
+     def _init_first_stage(self, config):
+         model = instantiate_from_config(config).eval()
+         model.train = disabled_train
+         for param in model.parameters():
+             param.requires_grad = False
+         self.first_stage_model = model
+
+     def get_input(self, batch):
+         # assuming unified data format, dataloader returns a dict.
+         # image tensors should be scaled to -1 ... 1 and in bchw format
+         return batch[self.input_key]
+
+     @torch.no_grad()
+     def decode_first_stage(self, z):
+         z = 1.0 / self.scale_factor * z
+         with torch.autocast("cuda", enabled=not self.disable_first_stage_autocast):
+             out = self.first_stage_model.decode(z)
+         return out
+
+     @torch.no_grad()
+     def encode_first_stage(self, x):
+         with torch.autocast("cuda", enabled=not self.disable_first_stage_autocast):
+             z = self.first_stage_model.encode(x)
+         z = self.scale_factor * z
+         return z
+
+     def forward(self, x, batch):
+         loss = self.loss_fn(self.model, self.denoiser, self.conditioner, x, batch)
+         loss_mean = loss.mean()
+         loss_dict = {"loss": loss_mean}
+         return loss_mean, loss_dict
+
+     def shared_step(self, batch: Dict) -> Any:
+         x = self.get_input(batch)
+         x = self.encode_first_stage(x)
+         batch["global_step"] = self.global_step
+         loss, loss_dict = self(x, batch)
+         return loss, loss_dict
+
+     def training_step(self, batch, batch_idx):
+         loss, loss_dict = self.shared_step(batch)
+
+         self.log_dict(
+             loss_dict, prog_bar=True, logger=True, on_step=True, on_epoch=False
+         )
+
+         self.log(
+             "global_step",
+             self.global_step,
+             prog_bar=True,
+             logger=True,
+             on_step=True,
+             on_epoch=False,
+         )
+
+         # if self.scheduler_config is not None:
+         lr = self.optimizers().param_groups[0]["lr"]
+         self.log(
+             "lr_abs", lr, prog_bar=True, logger=True, on_step=True, on_epoch=False
+         )
+
+         return loss
+
+     def on_train_start(self, *args, **kwargs):
+         if self.sampler is None or self.loss_fn is None:
+             raise ValueError("Sampler and loss function need to be set for training.")
+
+     def on_train_batch_end(self, *args, **kwargs):
+         if self.use_ema:
+             self.model_ema(self.model)
+
+     @contextmanager
+     def ema_scope(self, context=None):
+         if self.use_ema:
+             self.model_ema.store(self.model.parameters())
+             self.model_ema.copy_to(self.model)
+             if context is not None:
+                 print(f"{context}: Switched to EMA weights")
+         try:
+             yield None
+         finally:
+             if self.use_ema:
+                 self.model_ema.restore(self.model.parameters())
+                 if context is not None:
+                     print(f"{context}: Restored training weights")
+
+     def instantiate_optimizer_from_config(self, params, lr, cfg):
+         return get_obj_from_str(cfg["target"])(
+             params, lr=lr, **cfg.get("params", dict())
+         )
+
+     def configure_optimizers(self):
+         lr = self.learning_rate
+         params = list(self.model.parameters())
+         for embedder in self.conditioner.embedders:
+             if embedder.is_trainable:
+                 params = params + list(embedder.parameters())
+         opt = self.instantiate_optimizer_from_config(params, lr, self.optimizer_config)
+         if self.scheduler_config is not None:
+             scheduler = instantiate_from_config(self.scheduler_config)
+             print("Setting up LambdaLR scheduler...")
+             scheduler = [
+                 {
+                     "scheduler": LambdaLR(opt, lr_lambda=scheduler.schedule),
+                     "interval": "step",
+                     "frequency": 1,
+                 }
+             ]
+             return [opt], scheduler
+         return opt
+
+     @torch.no_grad()
+     def sample(
+         self,
+         cond: Dict,
+         uc: Union[Dict, None] = None,
+         batch_size: int = 16,
+         shape: Union[None, Tuple, List] = None,
+         **kwargs,
+     ):
+         randn = torch.randn(batch_size, *shape).to(self.device)
+
+         denoiser = lambda input, sigma, c: self.denoiser(
+             self.model, input, sigma, c, **kwargs
+         )
+         samples = self.sampler(denoiser, randn, cond, uc=uc)
+         return samples
+
+     @torch.no_grad()
+     def log_conditionings(self, batch: Dict, n: int) -> Dict:
+         """
+         Defines heuristics to log different conditionings.
+         These can be lists of strings (text-to-image), tensors, ints, ...
+         """
+         image_h, image_w = batch[self.input_key].shape[2:]
+         log = dict()
+
+         for embedder in self.conditioner.embedders:
+             if (
+                 (self.log_keys is None) or (embedder.input_key in self.log_keys)
+             ) and not self.no_cond_log:
+                 x = batch[embedder.input_key][:n]
+                 if isinstance(x, torch.Tensor):
+                     if x.dim() == 1:
+                         # class-conditional, convert integer to string
+                         x = [str(x[i].item()) for i in range(x.shape[0])]
+                         xc = log_txt_as_img((image_h, image_w), x, size=image_h // 4)
+                     elif x.dim() == 2:
+                         # size and crop cond and the like
+                         x = [
+                             "x".join([str(xx) for xx in x[i].tolist()])
+                             for i in range(x.shape[0])
+                         ]
+                         xc = log_txt_as_img((image_h, image_w), x, size=image_h // 20)
+                     else:
+                         raise NotImplementedError()
+                 elif isinstance(x, (List, ListConfig)):
+                     if isinstance(x[0], str):
+                         # strings
+                         xc = log_txt_as_img((image_h, image_w), x, size=image_h // 20)
+                     else:
+                         raise NotImplementedError()
+                 else:
+                     raise NotImplementedError()
+                 log[embedder.input_key] = xc
+         return log
+
+     @torch.no_grad()
+     def log_images(
+         self,
+         batch: Dict,
+         N: int = 8,
+         sample: bool = True,
+         ucg_keys: List[str] = None,
+         **kwargs,
+     ) -> Dict:
+         conditioner_input_keys = [e.input_key for e in self.conditioner.embedders]
+         if ucg_keys:
+             assert all(map(lambda x: x in conditioner_input_keys, ucg_keys)), (
+                 "Each defined ucg key for sampling must be in the provided conditioner input keys,"
+                 f"but we have {ucg_keys} vs. {conditioner_input_keys}"
+             )
+         else:
+             ucg_keys = conditioner_input_keys
+         log = dict()
+
+         x = self.get_input(batch)
+
+         c, uc = self.conditioner.get_unconditional_conditioning(
+             batch,
+             force_uc_zero_embeddings=ucg_keys
+             if len(self.conditioner.embedders) > 0
+             else [],
+         )
+
+         sampling_kwargs = {}
+
+         N = min(x.shape[0], N)
+         x = x.to(self.device)[:N]
+         log["inputs"] = x
+         z = self.encode_first_stage(x)
+         log["reconstructions"] = self.decode_first_stage(z)
+         log.update(self.log_conditionings(batch, N))
+
+         for k in c:
+             if isinstance(c[k], torch.Tensor):
+                 c[k], uc[k] = map(lambda y: y[k][:N].to(self.device), (c, uc))
+
+         if sample:
+             with self.ema_scope("Plotting"):
+                 samples = self.sample(
+                     c, shape=z.shape[1:], uc=uc, batch_size=N, **sampling_kwargs
+                 )
+             samples = self.decode_first_stage(samples)
+             log["samples"] = samples
+         return log
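
Similarly, a hedged sampling sketch for DiffusionEngine (also not part of the commit). It assumes a model YAML in the style of the sgm configs that defines a sampler and a text conditioner; the file paths, the "txt" conditioning key, and the 4x64x64 latent shape are placeholders, not values from this repository.

import torch
from omegaconf import OmegaConf

from sgm.util import instantiate_from_config

config = OmegaConf.load("configs/example.yaml")                   # placeholder path
config.model.params.ckpt_path = "checkpoints/model.safetensors"   # placeholder path
engine = instantiate_from_config(config.model).to("cuda").eval()

# Conditioning batch; the keys must match the configured conditioner embedders ("txt" assumed).
batch = {"txt": ["a photo of a cat"]}
c, uc = engine.conditioner.get_unconditional_conditioning(
    batch, force_uc_zero_embeddings=["txt"]
)

with torch.no_grad(), engine.ema_scope("Sampling"):
    # shape is the latent (C, H, W); 4x64x64 assumes an 8x-downsampling first stage,
    # i.e. 512x512 output images.
    latents = engine.sample(c, uc=uc, batch_size=1, shape=(4, 64, 64))
    images = engine.decode_first_stage(latents)  # decoded back to [-1, 1] RGB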