import logging

import torch
# import pytorch_lightning as pl
import torch.nn.functional as F
from contextlib import contextmanager
from typing import Any, Dict, List, Optional, Tuple, Union

from ldm_patched.ldm.modules.distributions.distributions import DiagonalGaussianDistribution

from ldm_patched.ldm.util import instantiate_from_config
from ldm_patched.ldm.modules.ema import LitEma
import ldm_patched.modules.ops

# Module-level logger. The classes below refer to it as ``logpy``.
logpy = logging.getLogger(__name__)


class DiagonalGaussianRegularizer(torch.nn.Module):
    """Regularizer that interprets its input as a diagonal Gaussian posterior.

    Args:
        sample: If True, ``forward`` returns ``posterior.sample()``;
            otherwise it returns ``posterior.mode()``.
    """

    def __init__(self, sample: bool = True):
        super().__init__()
        self.sample = sample

    def get_trainable_parameters(self) -> Any:
        """Yield nothing: this regularizer exposes no trainable parameters."""
        yield from ()

    def forward(self, z: torch.Tensor) -> Tuple[torch.Tensor, dict]:
        """Regularize ``z`` and report the KL term.

        Args:
            z: Tensor handed to ``DiagonalGaussianDistribution``.

        Returns:
            A tuple ``(z, log)`` where ``z`` is the sample or mode of the
            posterior and ``log["kl_loss"]`` is the sum of ``posterior.kl()``
            divided by the size of its first dimension.
        """
        log = dict()
        posterior = DiagonalGaussianDistribution(z)
        if self.sample:
            z = posterior.sample()
        else:
            z = posterior.mode()
        kl_loss = posterior.kl()
        # Normalise by the first dimension (presumably the batch -- confirm
        # against DiagonalGaussianDistribution.kl).
        kl_loss = torch.sum(kl_loss) / kl_loss.shape[0]
        log["kl_loss"] = kl_loss
        return z, log


class AbstractAutoencoder(torch.nn.Module):
    """
    This is the base class for all autoencoders, including image autoencoders, image autoencoders with discriminators,
    unCLIP models, etc. Hence, it is fairly general, and specific features
    (e.g. discriminator training, encoding, decoding) must be implemented in subclasses.
    """

    def __init__(
        self,
        ema_decay: Union[None, float] = None,
        monitor: Union[None, str] = None,
        input_key: str = "jpg",
        **kwargs,
    ):
        """Set up the input key, the optional monitor name and optional EMA.

        Args:
            ema_decay: When not None, an EMA helper (``LitEma``) is created
                with this decay and ``use_ema`` is set to True.
            monitor: When not None, stored as ``self.monitor``; otherwise the
                attribute is not created.
            input_key: Stored as ``self.input_key``.
            **kwargs: Accepted and ignored.
        """
        super().__init__()

        self.input_key = input_key
        self.use_ema = ema_decay is not None
        if monitor is not None:
            self.monitor = monitor

        if self.use_ema:
            # NOTE(review): this runs before subclasses register their own
            # sub-modules; what LitEma tracks at this point depends on its
            # implementation -- confirm.
            self.model_ema = LitEma(self, decay=ema_decay)
            logpy.info(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")

    def get_input(self, batch) -> Any:
        """Extract the model input from ``batch``; must be overridden."""
        raise NotImplementedError()

    def on_train_batch_end(self, *args, **kwargs):
        """Update the EMA helper after a training batch when EMA is enabled."""
        # for EMA computation
        if self.use_ema:
            self.model_ema(self)

    @contextmanager
    def ema_scope(self, context=None):
        """Context manager that swaps in EMA weights for the enclosed code.

        When EMA is enabled, the current parameters are stored and the EMA
        weights are copied into the model on entry; the stored parameters are
        restored on exit, even if the body raises. With EMA disabled this is
        a no-op.

        Args:
            context: Optional label; when given, the switch and the restore
                are logged with it as a prefix.
        """
        if self.use_ema:
            self.model_ema.store(self.parameters())
            self.model_ema.copy_to(self)
            if context is not None:
                logpy.info(f"{context}: Switched to EMA weights")
        try:
            yield None
        finally:
            if self.use_ema:
                self.model_ema.restore(self.parameters())
                if context is not None:
                    logpy.info(f"{context}: Restored training weights")

    def encode(self, *args, **kwargs) -> torch.Tensor:
        """Encode an input; must be overridden by subclasses."""
        raise NotImplementedError("encode()-method of abstract base class called")

    def decode(self, *args, **kwargs) -> torch.Tensor:
        """Decode a latent; must be overridden by subclasses."""
        raise NotImplementedError("decode()-method of abstract base class called")

    def instantiate_optimizer_from_config(self, params, lr, cfg):
        """Build an optimizer from a ``{"target": ..., "params": ...}`` config.

        Args:
            params: Passed positionally to the optimizer constructor.
            lr: Passed as the ``lr`` keyword.
            cfg: Mapping with a ``"target"`` entry that is resolved with
                ``get_obj_from_str`` and an optional ``"params"`` mapping of
                extra keyword arguments.

        Returns:
            The object returned by calling the resolved target.
        """
        # Imported here because the module never brought this helper into
        # scope. NOTE(review): assumes ldm_patched.ldm.util exports
        # get_obj_from_str alongside instantiate_from_config -- confirm.
        from ldm_patched.ldm.util import get_obj_from_str

        logpy.info(f"loading >>> {cfg['target']} <<< optimizer from config")
        return get_obj_from_str(cfg["target"])(
            params, lr=lr, **cfg.get("params", dict())
        )

    def configure_optimizers(self) -> Any:
        """Return the optimizer setup; must be overridden by subclasses."""
        raise NotImplementedError()


class AutoencodingEngine(AbstractAutoencoder):
    """
    Base class for all image autoencoders that we train, like VQGAN or AutoencoderKL
    (we also restore them explicitly as special cases for legacy reasons).
    Regularizations such as KL or VQ are moved to the regularizer class.
    """

    def __init__(
        self,
        *args,
        encoder_config: Dict,
        decoder_config: Dict,
        regularizer_config: Dict,
        **kwargs,
    ):
        """Instantiate encoder, decoder and regularizer from their configs.

        Args:
            *args: Forwarded to ``AbstractAutoencoder.__init__``.
            encoder_config: Config passed to ``instantiate_from_config`` to
                build ``self.encoder``.
            decoder_config: Config passed to ``instantiate_from_config`` to
                build ``self.decoder``.
            regularizer_config: Config passed to ``instantiate_from_config``
                to build ``self.regularization``.
            **kwargs: Forwarded to ``AbstractAutoencoder.__init__``.
        """
        super().__init__(*args, **kwargs)

        self.encoder: torch.nn.Module = instantiate_from_config(encoder_config)
        self.decoder: torch.nn.Module = instantiate_from_config(decoder_config)
        # Originally annotated with ``AbstractRegularizer``, a name this
        # module neither defines nor imports.
        self.regularization: torch.nn.Module = instantiate_from_config(
            regularizer_config
        )

    def get_last_layer(self):
        """Return whatever the decoder's ``get_last_layer()`` returns."""
        return self.decoder.get_last_layer()

    def encode(
        self,
        x: torch.Tensor,
        return_reg_log: bool = False,
        unregularized: bool = False,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, dict]]:
        """Encode ``x`` and, unless disabled, apply the regularizer.

        Args:
            x: Input passed to the encoder.
            return_reg_log: If True, also return the regularizer's log dict.
            unregularized: If True, skip the regularizer and return the raw
                encoder output together with an empty dict (regardless of
                ``return_reg_log``).

        Returns:
            ``z`` alone, or ``(z, log)`` when ``return_reg_log`` or
            ``unregularized`` is set.
        """
        z = self.encoder(x)
        if unregularized:
            return z, dict()
        z, reg_log = self.regularization(z)
        if return_reg_log:
            return z, reg_log
        return z

    def decode(self, z: torch.Tensor, **kwargs) -> torch.Tensor:
        """Run the decoder on ``z``, forwarding extra keyword arguments."""
        x = self.decoder(z, **kwargs)
        return x

    def forward(
        self, x: torch.Tensor, **additional_decode_kwargs
    ) -> Tuple[torch.Tensor, torch.Tensor, dict]:
        """Encode then decode ``x``.

        Returns:
            ``(z, dec, reg_log)``: the regularized latent, the decoder output
            and the regularizer's log dict.
        """
        z, reg_log = self.encode(x, return_reg_log=True)
        dec = self.decode(z, **additional_decode_kwargs)
        return z, dec, reg_log


class AutoencodingEngineLegacy(AutoencodingEngine):
    """Autoencoder built from a single ``ddconfig`` with 1x1 quant convolutions.

    The encoder output goes through ``quant_conv`` before regularization and
    latents go through ``post_quant_conv`` before decoding. When
    ``max_batch_size`` is set, encoding and decoding are performed in slices
    of at most that many items along dimension 0.
    """

    def __init__(self, embed_dim: int, **kwargs):
        """Build encoder/decoder from ``ddconfig`` and add the quant convs.

        Args:
            embed_dim: Channel count of the latent fed to ``post_quant_conv``;
                also scales the output channels of ``quant_conv``.
            **kwargs: Must contain ``ddconfig`` (used as the params of both
                the encoder and the decoder, and read for ``double_z`` and
                ``z_channels``). May contain ``max_batch_size``. The rest is
                forwarded to ``AutoencodingEngine.__init__``.
        """
        # Both keys are popped before delegating so the parent never sees them.
        self.max_batch_size = kwargs.pop("max_batch_size", None)
        ddconfig = kwargs.pop("ddconfig")
        super().__init__(
            encoder_config={
                "target": "ldm_patched.ldm.modules.diffusionmodules.model.Encoder",
                "params": ddconfig,
            },
            decoder_config={
                "target": "ldm_patched.ldm.modules.diffusionmodules.model.Decoder",
                "params": ddconfig,
            },
            **kwargs,
        )
        # ``double_z`` doubles both channel counts when truthy.
        self.quant_conv = ldm_patched.modules.ops.disable_weight_init.Conv2d(
            (1 + ddconfig["double_z"]) * ddconfig["z_channels"],
            (1 + ddconfig["double_z"]) * embed_dim,
            1,
        )
        self.post_quant_conv = ldm_patched.modules.ops.disable_weight_init.Conv2d(
            embed_dim, ddconfig["z_channels"], 1
        )
        self.embed_dim = embed_dim

    def get_autoencoder_params(self) -> list:
        """Return the parent's autoencoder parameters.

        NOTE(review): no base class visible in this module defines
        ``get_autoencoder_params``, so this call raises ``AttributeError``
        unless it is provided elsewhere -- confirm the intended behaviour.
        """
        params = super().get_autoencoder_params()
        return params

    def _apply_in_batches(self, fn, t: torch.Tensor) -> torch.Tensor:
        """Apply ``fn`` to ``t``, slicing along dim 0 if a batch cap is set.

        With ``max_batch_size`` unset, ``fn`` is applied to the whole tensor.
        Otherwise ``fn`` is applied to consecutive slices of at most
        ``max_batch_size`` items and the results are concatenated on dim 0.
        """
        bs = self.max_batch_size
        if bs is None:
            return fn(t)
        # A stepped range visits the same slices as ceil(N / bs) iterations.
        pieces = [fn(t[start : start + bs]) for start in range(0, t.shape[0], bs)]
        return torch.cat(pieces, 0)

    def encode(
        self, x: torch.Tensor, return_reg_log: bool = False
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, dict]]:
        """Encode ``x`` through encoder, ``quant_conv`` and the regularizer.

        Args:
            x: Input tensor; sliced along dim 0 when ``max_batch_size`` is set.
            return_reg_log: If True, also return the regularizer's log dict.

        Returns:
            ``z`` alone, or ``(z, reg_log)`` when ``return_reg_log`` is True.
        """
        z = self._apply_in_batches(
            lambda x_batch: self.quant_conv(self.encoder(x_batch)), x
        )

        # The regularizer always sees the full (re-assembled) tensor.
        z, reg_log = self.regularization(z)
        if return_reg_log:
            return z, reg_log
        return z

    def decode(self, z: torch.Tensor, **decoder_kwargs) -> torch.Tensor:
        """Decode ``z`` through ``post_quant_conv`` and the decoder.

        Args:
            z: Latent tensor; sliced along dim 0 when ``max_batch_size`` is set.
            **decoder_kwargs: Forwarded to every decoder call.

        Returns:
            The decoder output (concatenated along dim 0 when batched).
        """
        return self._apply_in_batches(
            lambda z_batch: self.decoder(
                self.post_quant_conv(z_batch), **decoder_kwargs
            ),
            z,
        )


class AutoencoderKL(AutoencodingEngineLegacy):
    """Legacy autoencoder wired to ``DiagonalGaussianRegularizer``."""

    def __init__(self, **kwargs):
        """Forward ``kwargs`` to the legacy engine with the KL regularizer.

        A ``lossconfig`` keyword, if present, is renamed to ``loss_config``
        before being forwarded.
        """
        if "lossconfig" in kwargs:
            kwargs["loss_config"] = kwargs.pop("lossconfig")
        super().__init__(
            regularizer_config={
                "target": (
                    "ldm_patched.ldm.models.autoencoder.DiagonalGaussianRegularizer"
                )
            },
            **kwargs,
        )