import numpy as np
import einops
import torch
import torch.nn as nn
from torch.nn import functional as thf
import pytorch_lightning as pl
from copy import deepcopy
from contextlib import contextmanager

import lpips
import kornia
from kornia import color

from ldm.modules.diffusionmodules.util import conv_nd, zero_module
from ldm.modules.diffusionmodules.model import Encoder
from ldm.modules.distributions.distributions import DiagonalGaussianDistribution
from ldm.modules.ema import LitEma
from ldm.util import instantiate_from_config


def disabled_train(self, mode=True):
    """Overwrite model.train with this function to make sure train/eval mode
    does not change anymore."""
    return self


class View(nn.Module):
    def __init__(self, *shape):
        super().__init__()
        self.shape = shape

    def forward(self, x):
        return x.view(*self.shape)


class SecretEncoder3(nn.Module):
    def __init__(self, secret_len, base_res=16, resolution=64) -> None:
        super().__init__()
        log_resolution = int(np.log2(resolution))
        log_base = int(np.log2(base_res))
        self.secret_len = secret_len
        self.secret_scaler = nn.Sequential(
            nn.Linear(secret_len, base_res * base_res * 3),
            nn.SiLU(),
            View(-1, 3, base_res, base_res),
            nn.Upsample(scale_factor=(2 ** (log_resolution - log_base), 2 ** (log_resolution - log_base))),  # 3 x base_res x base_res -> 3 x resolution x resolution
            zero_module(conv_nd(2, 3, 3, 3, padding=1)),
        )  # secret_len -> 3 x resolution x resolution

    def copy_encoder_weight(self, ae_model):
        # no autoencoder weights to copy for this variant
        return None

    def encode(self, x):
        x = self.secret_scaler(x)
        return x

    def forward(self, x, c):
        # x: [B, C, H, W], c: [B, secret_len]
        c = self.encode(c)
        return c, None


class SecretEncoder4(nn.Module):
    """Same as SecretEncoder3 but with a configurable channel count."""

    def __init__(self, secret_len, ch=3, base_res=16, resolution=64) -> None:
        super().__init__()
        log_resolution = int(np.log2(resolution))
        log_base = int(np.log2(base_res))
        self.secret_len = secret_len
        self.secret_scaler = nn.Sequential(
            nn.Linear(secret_len, base_res * base_res * ch),
            nn.SiLU(),
            View(-1, ch, base_res, base_res),
            nn.Upsample(scale_factor=(2 ** (log_resolution - log_base), 2 ** (log_resolution - log_base))),  # ch x base_res x base_res -> ch x resolution x resolution
            zero_module(conv_nd(2, ch, ch, 3, padding=1)),
        )  # secret_len -> ch x resolution x resolution

    def copy_encoder_weight(self, ae_model):
        # no autoencoder weights to copy for this variant
        return None

    def encode(self, x):
        x = self.secret_scaler(x)
        return x

    def forward(self, x, c):
        # x: [B, C, H, W], c: [B, secret_len]
        c = self.encode(c)
        return c, None
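# --- Usage sketch (added; not part of the original training code) ------------
# A minimal shape check for the residual-style secret encoders above; the
# batch size, secret length and resolution are illustrative values only.
def _demo_secret_encoder34():
    enc = SecretEncoder4(secret_len=100, ch=3, base_res=16, resolution=64)
    secret = torch.randint(0, 2, (2, 100)).float()  # [B, secret_len] bits
    image = torch.randn(2, 3, 64, 64)               # ignored by this variant
    delta, _ = enc(image, secret)
    assert delta.shape == (2, 3, 64, 64)
    # the output conv is wrapped in zero_module, so the residual is exactly
    # zero at initialisation and only grows as training updates the weights
    assert torch.all(delta == 0)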
class SecretEncoder6(nn.Module):
    """Join the image embedding with the secret embedding."""

    def __init__(self, secret_len, ch=3, base_res=16, resolution=64, emode='c3') -> None:
        super().__init__()
        assert emode in ['c3', 'c2', 'm3']
        if emode == 'c3':  # c3: concat c and x, each with ch channels
            secret_ch = ch
            join_ch = 2 * ch
        elif emode == 'c2':  # c2: concat c (2 channels) and the channel-average of x (1 channel)
            secret_ch = 2
            join_ch = ch
        elif emode == 'm3':  # m3: multiply c (ch channels) with x (ch channels)
            secret_ch = ch
            join_ch = ch
        log_resolution = int(np.log2(resolution))
        log_base = int(np.log2(base_res))
        self.secret_len = secret_len
        self.emode = emode
        self.resolution = resolution
        self.secret_scaler = nn.Sequential(
            nn.Linear(secret_len, base_res * base_res * secret_ch),
            nn.SiLU(),
            View(-1, secret_ch, base_res, base_res),
            nn.Upsample(scale_factor=(2 ** (log_resolution - log_base), 2 ** (log_resolution - log_base))),  # secret_ch x base_res x base_res -> secret_ch x resolution x resolution
        )  # secret_len -> secret_ch x resolution x resolution
        self.join_encoder = nn.Sequential(
            conv_nd(2, join_ch, join_ch, 3, padding=1),
            nn.SiLU(),
            conv_nd(2, join_ch, ch, 3, padding=1),
            nn.SiLU(),
            conv_nd(2, ch, ch, 3, padding=1),
            nn.SiLU(),
        )
        self.out_layer = zero_module(conv_nd(2, ch, ch, 3, padding=1))

    def copy_encoder_weight(self, ae_model):
        # no autoencoder weights to copy for this variant
        return None

    def encode(self, x):
        x = self.secret_scaler(x)
        return x

    def forward(self, x, c):
        # x: [B, C, H, W], c: [B, secret_len]
        c = self.encode(c)
        if self.emode == 'c3':
            x = torch.cat([x, c], dim=1)
        elif self.emode == 'c2':
            x = torch.cat([x.mean(dim=1, keepdim=True), c], dim=1)
        elif self.emode == 'm3':
            x = x * c
        dx = self.join_encoder(x)
        dx = self.out_layer(dx)
        return dx, None


class SecretEncoder5(nn.Module):
    """Same as SecretEncoder3 but with a configurable channel count and an
    optional joint image/secret encoder."""

    def __init__(self, secret_len, ch=3, base_res=16, resolution=64, joint=False) -> None:
        super().__init__()
        log_resolution = int(np.log2(resolution))
        log_base = int(np.log2(base_res))
        self.secret_len = secret_len
        self.joint = joint
        self.resolution = resolution
        self.secret_scaler = nn.Sequential(
            nn.Linear(secret_len, base_res * base_res * ch),
            nn.SiLU(),
            View(-1, ch, base_res, base_res),
            nn.Upsample(scale_factor=(2 ** (log_resolution - log_base), 2 ** (log_resolution - log_base))),  # ch x base_res x base_res -> ch x resolution x resolution
        )  # secret_len -> ch x resolution x resolution
        if joint:
            self.join_encoder = nn.Sequential(
                conv_nd(2, 2 * ch, 2 * ch, 3, padding=1),
                nn.SiLU(),
                conv_nd(2, 2 * ch, ch, 3, padding=1),
                nn.SiLU(),
            )
        self.out_layer = zero_module(conv_nd(2, ch, ch, 3, padding=1))

    def copy_encoder_weight(self, ae_model):
        # no autoencoder weights to copy for this variant
        return None

    def encode(self, x):
        x = self.secret_scaler(x)
        return x

    def forward(self, x, c):
        # x: [B, C, H, W], c: [B, secret_len]
        c = self.encode(c)
        if self.joint:
            x = thf.interpolate(x, size=(self.resolution, self.resolution), mode="bilinear",
                                align_corners=False, antialias=True)
            c = self.join_encoder(torch.cat([x, c], dim=1))
        c = self.out_layer(c)
        return c, None
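# --- Usage sketch (added; illustrative values) --------------------------------
# The three join modes of SecretEncoder6 differ only in how the secret map is
# fused with the input before the shared convolutional trunk: 'c3' concatenates
# ch+ch channels, 'c2' concatenates the channel-mean of x with a 2-channel
# secret map, and 'm3' multiplies elementwise.
def _demo_secret_encoder6_modes():
    secret = torch.randint(0, 2, (2, 100)).float()
    x = torch.randn(2, 3, 64, 64)  # spatial size must match `resolution`
    for emode in ('c3', 'c2', 'm3'):
        enc = SecretEncoder6(secret_len=100, ch=3, base_res=16, resolution=64, emode=emode)
        dx, _ = enc(x, secret)
        assert dx.shape == x.shape  # residual has the image shape
        assert torch.all(dx == 0)   # zero_module output conv => zero at init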
class SecretEncoder2(nn.Module):
    def __init__(self, secret_len, embed_dim, ddconfig, ckpt_path=None, ignore_keys=[],
                 image_key="image", colorize_nlabels=None, monitor=None, ema_decay=None,
                 learn_logvar=False) -> None:
        super().__init__()
        log_resolution = int(np.log2(ddconfig.resolution))
        self.secret_len = secret_len
        self.learn_logvar = learn_logvar
        self.image_key = image_key
        self.encoder = Encoder(**ddconfig)
        self.encoder.conv_out = zero_module(self.encoder.conv_out)
        self.embed_dim = embed_dim
        if colorize_nlabels is not None:
            assert type(colorize_nlabels) == int
            self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1))
        if monitor is not None:
            self.monitor = monitor
        self.secret_scaler = nn.Sequential(
            nn.Linear(secret_len, 32 * 32 * ddconfig.out_ch),
            nn.SiLU(),
            View(-1, ddconfig.out_ch, 32, 32),
            nn.Upsample(scale_factor=(2 ** (log_resolution - 5), 2 ** (log_resolution - 5))),  # ch x 32 x 32 -> ch x resolution x resolution
            # zero_module(conv_nd(2, ddconfig.out_ch, ddconfig.out_ch, 3, padding=1))
        )  # secret_len -> ch x resolution x resolution
        # out_resolution = ddconfig.resolution // (len(ddconfig.ch_mult) - 1)
        # self.out_layer = zero_module(conv_nd(2, ddconfig.out_ch, ddconfig.out_ch, 3, padding=1))
        self.use_ema = ema_decay is not None
        if self.use_ema:
            self.ema_decay = ema_decay
            assert 0. < ema_decay < 1.
            self.model_ema = LitEma(self, decay=ema_decay)
            print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
        if ckpt_path is not None:
            self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)

    def init_from_ckpt(self, path, ignore_keys=list()):
        sd = torch.load(path, map_location="cpu")["state_dict"]
        keys = list(sd.keys())
        for k in keys:
            for ik in ignore_keys:
                if k.startswith(ik):
                    print("Deleting key {} from state_dict.".format(k))
                    del sd[k]
        misses, ignores = self.load_state_dict(sd, strict=False)
        print(f"[SecretEncoder] Restored from {path}, misses: {misses}, ignores: {ignores}")

    def copy_encoder_weight(self, ae_model):
        # intentionally a no-op: this variant trains its encoder from scratch
        return None

    @contextmanager
    def ema_scope(self, context=None):
        if self.use_ema:
            self.model_ema.store(self.parameters())
            self.model_ema.copy_to(self)
            if context is not None:
                print(f"{context}: Switched to EMA weights")
        try:
            yield None
        finally:
            if self.use_ema:
                self.model_ema.restore(self.parameters())
                if context is not None:
                    print(f"{context}: Restored training weights")

    def on_train_batch_end(self, *args, **kwargs):
        if self.use_ema:
            self.model_ema(self)

    def encode(self, x):
        h = self.encoder(x)
        posterior = h
        return posterior

    def forward(self, x, c):
        # x: [B, C, H, W], c: [B, secret_len]
        c = self.secret_scaler(c)
        x = torch.cat([x, c], dim=1)
        z = self.encode(x)
        # z = self.out_layer(z)
        return z, None
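# --- Example config (added; illustrative, not taken from the project YAMLs) ---
# The ddconfig consumed by SecretEncoder2/SecretEncoder7/SecretEncoder is a
# standard LDM first-stage Encoder config; in the real project it comes from
# the model YAML. The values below are a small made-up stand-in for quick CPU
# checks. Note that `in_channels` must match what forward() feeds the encoder:
# 1 grayscale + 2 secret channels = 3 for SecretEncoder7, while SecretEncoder2
# would need image channels + out_ch because it concatenates the secret map.
def _example_ddconfig(double_z=False):
    from omegaconf import OmegaConf  # attribute-style access (ddconfig.resolution) is required
    return OmegaConf.create(dict(
        double_z=double_z,
        z_channels=3,
        resolution=256,
        in_channels=3,
        out_ch=3,
        ch=32,              # small for a quick sanity check; real configs use larger widths
        ch_mult=[1, 2, 4],  # two downsamplings: 256 -> 64
        num_res_blocks=1,
        attn_resolutions=[],
        dropout=0.0,
    ))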
class SecretEncoder7(nn.Module):
    def __init__(self, secret_len, ddconfig, ckpt_path=None, ignore_keys=[], embed_dim=3, ema_decay=None) -> None:
        super().__init__()
        log_resolution = int(np.log2(ddconfig.resolution))
        self.secret_len = secret_len
        self.encoder = Encoder(**ddconfig)
        # self.encoder.conv_out = zero_module(self.encoder.conv_out)
        self.quant_conv = nn.Conv2d(ddconfig["z_channels"], embed_dim, 1)
        self.secret_scaler = nn.Sequential(
            nn.Linear(secret_len, 32 * 32 * 2),
            nn.SiLU(),
            View(-1, 2, 32, 32),
            # nn.Upsample(scale_factor=(2 ** (log_resolution - 5), 2 ** (log_resolution - 5))),
            # zero_module(conv_nd(2, ddconfig.out_ch, ddconfig.out_ch, 3, padding=1))
        )  # secret_len -> 2 x 32 x 32
        # out_resolution = ddconfig.resolution // (len(ddconfig.ch_mult) - 1)
        # self.out_layer = zero_module(conv_nd(2, ddconfig.out_ch, ddconfig.out_ch, 3, padding=1))
        self.use_ema = ema_decay is not None
        if self.use_ema:
            self.ema_decay = ema_decay
            assert 0. < ema_decay < 1.
            self.model_ema = LitEma(self, decay=ema_decay)
            print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
        if ckpt_path is not None:
            self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)

    def init_from_ckpt(self, path, ignore_keys=list()):
        sd = torch.load(path, map_location="cpu")["state_dict"]
        keys = list(sd.keys())
        for k in keys:
            for ik in ignore_keys:
                if k.startswith(ik):
                    print("Deleting key {} from state_dict.".format(k))
                    del sd[k]
        misses, ignores = self.load_state_dict(sd, strict=False)
        print(f"[SecretEncoder7] Restored from {path}, misses: {len(misses)}, ignores: {len(ignores)}. "
              "Do not worry: we do not use the decoder, and the secret encoder is a novel module.")

    def copy_encoder_weight(self, ae_model):
        self.encoder.load_state_dict(deepcopy(ae_model.encoder.state_dict()))
        self.quant_conv.load_state_dict(deepcopy(ae_model.quant_conv.state_dict()))

    @contextmanager
    def ema_scope(self, context=None):
        if self.use_ema:
            self.model_ema.store(self.parameters())
            self.model_ema.copy_to(self)
            if context is not None:
                print(f"{context}: Switched to EMA weights")
        try:
            yield None
        finally:
            if self.use_ema:
                self.model_ema.restore(self.parameters())
                if context is not None:
                    print(f"{context}: Restored training weights")

    def on_train_batch_end(self, *args, **kwargs):
        if self.use_ema:
            self.model_ema(self)

    def encode(self, x):
        h = self.encoder(x)
        h = self.quant_conv(h)
        return h

    def forward(self, x, c):
        # x: [B, C, H, W], c: [B, secret_len]
        c = self.secret_scaler(c)  # [B, 2, 32, 32]
        # c = thf.interpolate(c, size=x.shape[-2:], mode="bilinear", align_corners=False)
        c = thf.interpolate(c, size=x.shape[-2:], mode="nearest")
        # ITU-R BT.709 luma weights: collapse the image to grayscale
        x = 0.2125 * x[:, 0, ...] + 0.7154 * x[:, 1, ...] + 0.0721 * x[:, 2, ...]
        x = torch.cat([x.unsqueeze(1), c], dim=1)
        z = self.encode(x)
        # z = self.out_layer(z)
        return z, None
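# --- Usage sketch (added; builds on the illustrative _example_ddconfig) -------
# SecretEncoder7 converts the cover image to grayscale, stacks the 2-channel
# secret map on top and re-encodes the 3-channel result into a latent offset.
def _demo_secret_encoder7():
    enc = SecretEncoder7(secret_len=100, ddconfig=_example_ddconfig(double_z=False))
    secret = torch.randint(0, 2, (1, 100)).float()
    image = torch.randn(1, 3, 256, 256)
    z, _ = enc(image, secret)
    assert z.shape == (1, 3, 64, 64)  # embed_dim=3, 256 / 2**2 = 64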
class SecretEncoder(nn.Module):
    def __init__(self, secret_len, embed_dim, ddconfig, ckpt_path=None, ignore_keys=[],
                 image_key="image", colorize_nlabels=None, monitor=None, ema_decay=None,
                 learn_logvar=False) -> None:
        super().__init__()
        log_resolution = int(np.log2(ddconfig.resolution))
        self.secret_len = secret_len
        self.learn_logvar = learn_logvar
        self.image_key = image_key
        self.encoder = Encoder(**ddconfig)
        assert ddconfig["double_z"]
        self.quant_conv = torch.nn.Conv2d(2 * ddconfig["z_channels"], 2 * embed_dim, 1)
        self.embed_dim = embed_dim
        if colorize_nlabels is not None:
            assert type(colorize_nlabels) == int
            self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1))
        if monitor is not None:
            self.monitor = monitor
        self.use_ema = ema_decay is not None
        if self.use_ema:
            self.ema_decay = ema_decay
            assert 0. < ema_decay < 1.
            self.model_ema = LitEma(self, decay=ema_decay)
            print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
        if ckpt_path is not None:
            self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
        self.secret_scaler = nn.Sequential(
            nn.Linear(secret_len, 32 * 32 * ddconfig.out_ch),
            nn.SiLU(),
            View(-1, ddconfig.out_ch, 32, 32),
            nn.Upsample(scale_factor=(2 ** (log_resolution - 5), 2 ** (log_resolution - 5))),  # ch x 32 x 32 -> ch x resolution x resolution
            zero_module(conv_nd(2, ddconfig.out_ch, ddconfig.out_ch, 3, padding=1)),
        )  # secret_len -> ch x resolution x resolution
        # out_resolution = ddconfig.resolution // (len(ddconfig.ch_mult) - 1)
        self.out_layer = zero_module(conv_nd(2, ddconfig.out_ch, ddconfig.out_ch, 3, padding=1))

    def init_from_ckpt(self, path, ignore_keys=list()):
        sd = torch.load(path, map_location="cpu")["state_dict"]
        keys = list(sd.keys())
        for k in keys:
            for ik in ignore_keys:
                if k.startswith(ik):
                    print("Deleting key {} from state_dict.".format(k))
                    del sd[k]
        misses, ignores = self.load_state_dict(sd, strict=False)
        print(f"[SecretEncoder] Restored from {path}, misses: {misses}, ignores: {ignores}")

    def copy_encoder_weight(self, ae_model):
        self.encoder.load_state_dict(ae_model.encoder.state_dict())
        self.quant_conv.load_state_dict(ae_model.quant_conv.state_dict())

    @contextmanager
    def ema_scope(self, context=None):
        if self.use_ema:
            self.model_ema.store(self.parameters())
            self.model_ema.copy_to(self)
            if context is not None:
                print(f"{context}: Switched to EMA weights")
        try:
            yield None
        finally:
            if self.use_ema:
                self.model_ema.restore(self.parameters())
                if context is not None:
                    print(f"{context}: Restored training weights")

    def on_train_batch_end(self, *args, **kwargs):
        if self.use_ema:
            self.model_ema(self)

    def encode(self, x):
        h = self.encoder(x)
        moments = self.quant_conv(h)
        posterior = DiagonalGaussianDistribution(moments)
        return posterior

    def forward(self, x, c):
        # x: [B, C, H, W], c: [B, secret_len]
        c = self.secret_scaler(c)
        x = x + c
        posterior = self.encode(x)
        z = posterior.sample()
        z = self.out_layer(z)
        return z, posterior
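# --- Usage sketch (added; illustrative) ----------------------------------------
# The KL variant adds the secret map to the image in pixel space, encodes it
# with a double_z encoder and returns a sampled latent together with the
# posterior, so a KL term can be attached to the loss.
def _demo_secret_encoder_posterior():
    enc = SecretEncoder(secret_len=100, embed_dim=3, ddconfig=_example_ddconfig(double_z=True))
    secret = torch.randint(0, 2, (1, 100)).float()
    image = torch.randn(1, 3, 256, 256)
    z, posterior = enc(image, secret)
    assert z.shape == (1, 3, 64, 64)
    assert isinstance(posterior, DiagonalGaussianDistribution)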
class ControlAE(pl.LightningModule):
    def __init__(self,
                 first_stage_key,
                 first_stage_config,
                 control_key,
                 control_config,
                 decoder_config,
                 loss_config,
                 noise_config='__none__',
                 use_ema=False,
                 secret_warmup=False,
                 scale_factor=1.,
                 ckpt_path="__none__",
                 ):
        super().__init__()
        self.scale_factor = scale_factor
        self.control_key = control_key
        self.first_stage_key = first_stage_key
        self.ae = instantiate_from_config(first_stage_config)
        self.control = instantiate_from_config(control_config)
        self.decoder = instantiate_from_config(decoder_config)
        self.crop = kornia.augmentation.CenterCrop((224, 224), cropping_mode="resample")
        if noise_config != '__none__':
            print('Using noise')
            self.noise = instantiate_from_config(noise_config)
        # copy weights from the first stage
        self.control.copy_encoder_weight(self.ae)
        # freeze the first stage
        self.ae.eval()
        self.ae.train = disabled_train
        for p in self.ae.parameters():
            p.requires_grad = False
        self.loss_layer = instantiate_from_config(loss_config)

        # early training phase: train against a single fixed batch first
        # self.fixed_input = True
        self.fixed_x = None
        self.fixed_img = None
        self.fixed_input_recon = None
        self.fixed_control = None
        self.register_buffer("fixed_input", torch.tensor(True))

        # secret warmup: start from a short base secret repeated to full length
        self.secret_warmup = secret_warmup
        self.secret_baselen = 2
        self.secret_len = control_config.params.secret_len
        if self.secret_warmup:
            assert self.secret_len == 2 ** (int(np.log2(self.secret_len)))  # must be a power of two

        self.use_ema = use_ema
        if self.use_ema:
            print('Using EMA')
            self.control_ema = LitEma(self.control)
            self.decoder_ema = LitEma(self.decoder)
            print(f"Keeping EMAs of {len(list(self.control_ema.buffers()) + list(self.decoder_ema.buffers()))}.")

        if ckpt_path != '__none__':
            self.init_from_ckpt(ckpt_path, ignore_keys=[])

    def get_warmup_secret(self, old_secret):
        # old_secret, new_secret: [B, secret_len]
        if self.secret_warmup:
            bsz = old_secret.shape[0]
            nrepeats = self.secret_len // self.secret_baselen
            new_secret = torch.zeros((bsz, self.secret_baselen), dtype=torch.float).random_(0, 2)
            new_secret = new_secret.repeat_interleave(nrepeats, dim=1)
            return new_secret.to(old_secret.device)
        return old_secret

    def init_from_ckpt(self, path, ignore_keys=list()):
        sd = torch.load(path, map_location="cpu")["state_dict"]
        keys = list(sd.keys())
        for k in keys:
            for ik in ignore_keys:
                if k.startswith(ik):
                    print("Deleting key {} from state_dict.".format(k))
                    del sd[k]
        self.load_state_dict(sd, strict=False)
        print(f"Restored from {path}")

    @contextmanager
    def ema_scope(self, context=None):
        if self.use_ema:
            self.control_ema.store(self.control.parameters())
            self.decoder_ema.store(self.decoder.parameters())
            self.control_ema.copy_to(self.control)
            self.decoder_ema.copy_to(self.decoder)
            if context is not None:
                print(f"{context}: Switched to EMA weights")
        try:
            yield None
        finally:
            if self.use_ema:
                self.control_ema.restore(self.control.parameters())
                self.decoder_ema.restore(self.decoder.parameters())
                if context is not None:
                    print(f"{context}: Restored training weights")

    def on_train_batch_end(self, *args, **kwargs):
        if self.use_ema:
            self.control_ema(self.control)
            self.decoder_ema(self.decoder)

    def compute_loss(self, pred, target):
        # NOTE: relies on self.lpips_loss (e.g. an lpips.LPIPS instance) and
        # self.yuv_scales being set elsewhere; they are not created in __init__.
        # return thf.mse_loss(pred, target, reduction="none").mean(dim=(1, 2, 3))
        lpips_loss = self.lpips_loss(pred, target).mean(dim=[1, 2, 3])
        pred_yuv = color.rgb_to_yuv((pred + 1) / 2)
        target_yuv = color.rgb_to_yuv((target + 1) / 2)
        yuv_loss = torch.mean((pred_yuv - target_yuv) ** 2, dim=[2, 3])
        yuv_loss = 1.5 * torch.mm(yuv_loss, self.yuv_scales).squeeze(1)
        return lpips_loss + yuv_loss

    def forward(self, x, image, c):
        # SecretEncoder6 conditions on the latent x; the other variants
        # condition on the pixel-space image.
        if self.control.__class__.__name__ == 'SecretEncoder6':
            eps, posterior = self.control(x, c)
        else:
            eps, posterior = self.control(image, c)
        return x + eps, posterior

    @torch.no_grad()
    def get_input(self, batch, return_first_stage=False, bs=None):
        image = batch[self.first_stage_key]
        control = batch[self.control_key]
        control = self.get_warmup_secret(control)
        if bs is not None:
            image = image[:bs]
            control = control[:bs]
        else:
            bs = image.shape[0]
        # encode the image with the first stage
        image = einops.rearrange(image, "b h w c -> b c h w").contiguous()
        x = self.encode_first_stage(image).detach()
        image_rec = self.decode_first_stage(x).detach()
        # check if using the fixed input (early training phase)
        # if self.training and self.fixed_input:
        if self.fixed_input:
            if self.fixed_x is None:  # first iteration
                print('[TRAINING] Warmup - using a fixed input image for now!')
                self.fixed_x = x.detach().clone()[:bs]
                self.fixed_img = image.detach().clone()[:bs]
                self.fixed_input_recon = image_rec.detach().clone()[:bs]
                self.fixed_control = control.detach().clone()[:bs]  # used by log_images with the fixed_input option only
            x, image, image_rec = self.fixed_x, self.fixed_img, self.fixed_input_recon
        out = [x, control]
        if return_first_stage:
            out.extend([image, image_rec])
        return out

    def decode_first_stage(self, z):
        z = 1. / self.scale_factor * z
        image_rec = self.ae.decode(z)
        return image_rec
    def encode_first_stage(self, image):
        encoder_posterior = self.ae.encode(image)
        if isinstance(encoder_posterior, DiagonalGaussianDistribution):
            z = encoder_posterior.sample()
        elif isinstance(encoder_posterior, torch.Tensor):
            z = encoder_posterior
        else:
            raise NotImplementedError(f"encoder_posterior of type '{type(encoder_posterior)}' not yet implemented")
        return self.scale_factor * z

    def shared_step(self, batch):
        x, c, img, _ = self.get_input(batch, return_first_stage=True)
        x, posterior = self(x, img, c)
        image_rec = self.decode_first_stage(x)
        # resize
        if img.shape[-1] > 256:
            img = thf.interpolate(img, size=(256, 256), mode='bilinear', align_corners=False).detach()
            image_rec = thf.interpolate(image_rec, size=(256, 256), mode='bilinear', align_corners=False)
        if hasattr(self, 'noise') and self.noise.is_activated():
            image_rec_noised = self.noise(image_rec, self.global_step, p=0.9)
        else:
            image_rec_noised = self.crop(image_rec)  # center crop
        image_rec_noised = torch.clamp(image_rec_noised, -1, 1)
        pred = self.decoder(image_rec_noised)
        loss, loss_dict = self.loss_layer(img, image_rec, posterior, c, pred, self.global_step)
        bit_acc = loss_dict["bit_acc"]

        bit_acc_ = bit_acc.item()
        if (bit_acc_ > 0.98) and (not self.fixed_input) and hasattr(self, 'noise') and self.noise.is_activated():
            self.loss_layer.activate_ramp(self.global_step)
        if (bit_acc_ > 0.95) and (not self.fixed_input):
            # ramp up the image loss at the late training stage
            if hasattr(self, 'noise') and (not self.noise.is_activated()):
                self.noise.activate(self.global_step)
        if (bit_acc_ > 0.9) and self.fixed_input:  # executes only once
            print(f'[TRAINING] High bit acc ({bit_acc_}) achieved, switching to full image dataset training.')
            self.fixed_input = ~self.fixed_input
        return loss, loss_dict

    def training_step(self, batch, batch_idx):
        loss, loss_dict = self.shared_step(batch)
        loss_dict = {f"train/{key}": val for key, val in loss_dict.items()}
        self.log_dict(loss_dict, prog_bar=True, logger=True, on_step=True, on_epoch=True)
        self.log("global_step", self.global_step, prog_bar=True, logger=True, on_step=True, on_epoch=False)
        # if self.use_scheduler:
        #     lr = self.optimizers().param_groups[0]['lr']
        #     self.log('lr_abs', lr, prog_bar=True, logger=True, on_step=True, on_epoch=False)
        return loss

    @torch.no_grad()
    def validation_step(self, batch, batch_idx):
        _, loss_dict_no_ema = self.shared_step(batch)
        loss_dict_no_ema = {f"val/{key}": val for key, val in loss_dict_no_ema.items() if key != 'img_lw'}
        with self.ema_scope():
            _, loss_dict_ema = self.shared_step(batch)
            loss_dict_ema = {'val/' + key + '_ema': loss_dict_ema[key] for key in loss_dict_ema}
        self.log_dict(loss_dict_no_ema, prog_bar=False, logger=True, on_step=False, on_epoch=True)
        self.log_dict(loss_dict_ema, prog_bar=False, logger=True, on_step=False, on_epoch=True)

    @torch.no_grad()
    def log_images(self, batch, fixed_input=False, **kwargs):
        log = dict()
        if fixed_input and self.fixed_img is not None:
            x, c, img, img_recon = self.fixed_x, self.fixed_control, self.fixed_img, self.fixed_input_recon
        else:
            x, c, img, img_recon = self.get_input(batch, return_first_stage=True)
        x, _ = self(x, img, c)
        image_out = self.decode_first_stage(x)
        if hasattr(self, 'noise') and self.noise.is_activated():
            img_noise = self.noise(image_out, self.global_step, p=1.0)
            log['noised'] = img_noise
        log['input'] = img
        log['output'] = image_out
        log['recon'] = img_recon
        return log

    def configure_optimizers(self):
        lr = self.learning_rate
        params = list(self.control.parameters()) + list(self.decoder.parameters())
        optimizer = torch.optim.AdamW(params, lr=lr)
        return optimizer
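# --- Smoke test (added) --------------------------------------------------------
# Runs the illustrative sketches defined above. ControlAE itself is not
# exercised here because it needs the project first-stage/decoder/loss
# configs, which live in the training YAML files.
if __name__ == "__main__":
    _demo_secret_encoder34()
    _demo_secret_encoder6_modes()
    _demo_secret_encoder7()
    _demo_secret_encoder_posterior()
    print("all secret-encoder sketches ran")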