# cldm/ae.py
import numpy as np
import einops
import torch
import torch.nn as nn
from torch.nn import functional as thf
import pytorch_lightning as pl
from copy import deepcopy
from contextlib import contextmanager

import kornia
from kornia import color

from ldm.modules.diffusionmodules.util import conv_nd, zero_module
from ldm.modules.diffusionmodules.model import Encoder
from ldm.modules.distributions.distributions import DiagonalGaussianDistribution
from ldm.modules.ema import LitEma
from ldm.util import instantiate_from_config
def disabled_train(self, mode=True):
"""Overwrite model.train with this function to make sure train/eval mode
does not change anymore."""
return self
class View(nn.Module):
def __init__(self, *shape):
super().__init__()
self.shape = shape
def forward(self, x):
return x.view(*self.shape)
class SecretEncoder3(nn.Module):
    """Maps a secret bit-string to a zero-initialised residual map."""
    def __init__(self, secret_len, base_res=16, resolution=64) -> None:
super().__init__()
log_resolution = int(np.log2(resolution))
log_base = int(np.log2(base_res))
self.secret_len = secret_len
        self.secret_scaler = nn.Sequential(
            nn.Linear(secret_len, base_res * base_res * 3),
            nn.SiLU(),
            View(-1, 3, base_res, base_res),
            nn.Upsample(scale_factor=(2**(log_resolution - log_base), 2**(log_resolution - log_base))),  # 3 x base_res x base_res -> 3 x resolution x resolution
            zero_module(conv_nd(2, 3, 3, 3, padding=1)),
        )  # secret_len -> 3 x resolution x resolution
    def copy_encoder_weight(self, ae_model):
        # no-op: this encoder does not reuse autoencoder weights
        return None
def encode(self, x):
x = self.secret_scaler(x)
return x
def forward(self, x, c):
# x: [B, C, H, W], c: [B, secret_len]
c = self.encode(c)
return c, None
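# Minimal shape-check sketch for SecretEncoder3, assuming the default
# base_res=16 and resolution=64 and an (assumed) 100-bit payload; illustrative
# only, not part of the training pipeline. It shows that the secret maps to a
# residual at the target resolution that starts at zero thanks to zero_module.
def _demo_secret_encoder3():
    enc = SecretEncoder3(secret_len=100, base_res=16, resolution=64)
    secret = torch.randint(0, 2, (4, 100)).float()
    delta, _ = enc(None, secret)  # the image argument x is unused by this variant
    assert delta.shape == (4, 3, 64, 64)
    assert delta.abs().sum() == 0  # output conv is zero-initialised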
class SecretEncoder4(nn.Module):
    """Same as SecretEncoder3, but with a configurable channel count ch."""
def __init__(self, secret_len, ch=3, base_res=16, resolution=64) -> None:
super().__init__()
log_resolution = int(np.log2(resolution))
log_base = int(np.log2(base_res))
self.secret_len = secret_len
        self.secret_scaler = nn.Sequential(
            nn.Linear(secret_len, base_res * base_res * ch),
            nn.SiLU(),
            View(-1, ch, base_res, base_res),
            nn.Upsample(scale_factor=(2**(log_resolution - log_base), 2**(log_resolution - log_base))),  # ch x base_res x base_res -> ch x resolution x resolution
            zero_module(conv_nd(2, ch, ch, 3, padding=1)),
        )  # secret_len -> ch x resolution x resolution
    def copy_encoder_weight(self, ae_model):
        # no-op: this encoder does not reuse autoencoder weights
        return None
def encode(self, x):
x = self.secret_scaler(x)
return x
def forward(self, x, c):
# x: [B, C, H, W], c: [B, secret_len]
c = self.encode(c)
return c, None
class SecretEncoder6(nn.Module):
    """Joins the image embedding with the secret embedding."""
    def __init__(self, secret_len, ch=3, base_res=16, resolution=64, emode='c3') -> None:
        super().__init__()
        assert emode in ['c3', 'c2', 'm3']
        if emode == 'c3':  # concat secret (ch channels) with x (ch channels)
            secret_ch = ch
            join_ch = 2 * ch
        elif emode == 'c2':  # concat secret (2 channels) with the channel mean of x (1 channel)
            secret_ch = 2
            join_ch = ch
        elif emode == 'm3':  # multiply secret (ch channels) with x (ch channels)
            secret_ch = ch
            join_ch = ch
log_resolution = int(np.log2(resolution))
log_base = int(np.log2(base_res))
self.secret_len = secret_len
self.emode = emode
self.resolution = resolution
        self.secret_scaler = nn.Sequential(
            nn.Linear(secret_len, base_res * base_res * secret_ch),
            nn.SiLU(),
            View(-1, secret_ch, base_res, base_res),
            nn.Upsample(scale_factor=(2**(log_resolution - log_base), 2**(log_resolution - log_base))),  # secret_ch x base_res x base_res -> secret_ch x resolution x resolution
        )  # secret_len -> secret_ch x resolution x resolution
self.join_encoder = nn.Sequential(
conv_nd(2, join_ch, join_ch, 3, padding=1),
nn.SiLU(),
conv_nd(2, join_ch, ch, 3, padding=1),
nn.SiLU(),
conv_nd(2, ch, ch, 3, padding=1),
nn.SiLU()
)
self.out_layer = zero_module(conv_nd(2, ch, ch, 3, padding=1))
    def copy_encoder_weight(self, ae_model):
        # no-op: this encoder does not reuse autoencoder weights
        return None
def encode(self, x):
x = self.secret_scaler(x)
return x
def forward(self, x, c):
# x: [B, C, H, W], c: [B, secret_len]
c = self.encode(c)
if self.emode == 'c3':
x = torch.cat([x, c], dim=1)
elif self.emode == 'c2':
x = torch.cat([x.mean(dim=1, keepdim=True), c], dim=1)
elif self.emode == 'm3':
x = x * c
dx = self.join_encoder(x)
dx = self.out_layer(dx)
return dx, None
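# Minimal sketch, assuming ch=3, resolution=64, and an (assumed) 100-bit
# payload: all three join modes of SecretEncoder6 yield a residual with the
# same shape as the input latent. Illustrative only.
def _demo_secret_encoder6_modes():
    x = torch.rand(2, 3, 64, 64)
    secret = torch.randint(0, 2, (2, 100)).float()
    for emode in ('c3', 'c2', 'm3'):
        enc = SecretEncoder6(secret_len=100, ch=3, base_res=16, resolution=64, emode=emode)
        dx, _ = enc(x, secret)
        assert dx.shape == x.shape  # residual matches the input latent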
class SecretEncoder5(nn.Module):
    """Same as SecretEncoder4, with an optional joint image-secret encoder."""
    def __init__(self, secret_len, ch=3, base_res=16, resolution=64, joint=False) -> None:
def __init__(self, secret_len, ch=3, base_res=16, resolution=64, joint=False) -> None:
super().__init__()
log_resolution = int(np.log2(resolution))
log_base = int(np.log2(base_res))
self.secret_len = secret_len
self.joint = joint
self.resolution = resolution
        self.secret_scaler = nn.Sequential(
            nn.Linear(secret_len, base_res * base_res * ch),
            nn.SiLU(),
            View(-1, ch, base_res, base_res),
            nn.Upsample(scale_factor=(2**(log_resolution - log_base), 2**(log_resolution - log_base))),  # ch x base_res x base_res -> ch x resolution x resolution
        )  # secret_len -> ch x resolution x resolution
if joint:
self.join_encoder = nn.Sequential(
conv_nd(2, 2*ch, 2*ch, 3, padding=1),
nn.SiLU(),
conv_nd(2, 2*ch, ch, 3, padding=1),
nn.SiLU()
)
self.out_layer = zero_module(conv_nd(2, ch, ch, 3, padding=1))
    def copy_encoder_weight(self, ae_model):
        # no-op: this encoder does not reuse autoencoder weights
        return None
def encode(self, x):
x = self.secret_scaler(x)
return x
def forward(self, x, c):
# x: [B, C, H, W], c: [B, secret_len]
c = self.encode(c)
if self.joint:
x = thf.interpolate(x, size=(self.resolution, self.resolution), mode="bilinear", align_corners=False, antialias=True)
c = self.join_encoder(torch.cat([x, c], dim=1))
c = self.out_layer(c)
return c, None
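# Minimal sketch, assuming joint=True and resolution=64: SecretEncoder5
# resizes the cover latent to `resolution` before joining, so the output size
# follows the secret branch rather than the input. Illustrative only.
def _demo_secret_encoder5_joint():
    enc = SecretEncoder5(secret_len=100, ch=3, base_res=16, resolution=64, joint=True)
    x = torch.rand(2, 3, 32, 32)  # deliberately smaller than `resolution`
    secret = torch.randint(0, 2, (2, 100)).float()
    c, _ = enc(x, secret)
    assert c.shape == (2, 3, 64, 64)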
class SecretEncoder2(nn.Module):
def __init__(self, secret_len, embed_dim, ddconfig, ckpt_path=None,
ignore_keys=[],
image_key="image",
colorize_nlabels=None,
monitor=None,
ema_decay=None,
learn_logvar=False) -> None:
super().__init__()
log_resolution = int(np.log2(ddconfig.resolution))
self.secret_len = secret_len
self.learn_logvar = learn_logvar
self.image_key = image_key
self.encoder = Encoder(**ddconfig)
self.encoder.conv_out = zero_module(self.encoder.conv_out)
self.embed_dim = embed_dim
if colorize_nlabels is not None:
assert type(colorize_nlabels)==int
self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1))
if monitor is not None:
self.monitor = monitor
        self.secret_scaler = nn.Sequential(
            nn.Linear(secret_len, 32 * 32 * ddconfig.out_ch),
            nn.SiLU(),
            View(-1, ddconfig.out_ch, 32, 32),
            nn.Upsample(scale_factor=(2**(log_resolution - 5), 2**(log_resolution - 5))),  # out_ch x 32 x 32 -> out_ch x resolution x resolution
        )  # secret_len -> out_ch x resolution x resolution
self.use_ema = ema_decay is not None
if self.use_ema:
self.ema_decay = ema_decay
assert 0. < ema_decay < 1.
self.model_ema = LitEma(self, decay=ema_decay)
print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
if ckpt_path is not None:
self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
def init_from_ckpt(self, path, ignore_keys=list()):
sd = torch.load(path, map_location="cpu")["state_dict"]
keys = list(sd.keys())
for k in keys:
for ik in ignore_keys:
if k.startswith(ik):
print("Deleting key {} from state_dict.".format(k))
del sd[k]
misses, ignores = self.load_state_dict(sd, strict=False)
print(f"[SecretEncoder] Restored from {path}, misses: {misses}, ignores: {ignores}")
    def copy_encoder_weight(self, ae_model):
        # weight copying is intentionally disabled for this variant; the
        # zero-initialised encoder is left untouched
        return None
@contextmanager
def ema_scope(self, context=None):
if self.use_ema:
self.model_ema.store(self.parameters())
self.model_ema.copy_to(self)
if context is not None:
print(f"{context}: Switched to EMA weights")
try:
yield None
finally:
if self.use_ema:
self.model_ema.restore(self.parameters())
if context is not None:
print(f"{context}: Restored training weights")
def on_train_batch_end(self, *args, **kwargs):
if self.use_ema:
self.model_ema(self)
    def encode(self, x):
        # returns the raw encoder features; no posterior is formed for this variant
        return self.encoder(x)
def forward(self, x, c):
# x: [B, C, H, W], c: [B, secret_len]
c = self.secret_scaler(c)
x = torch.cat([x, c], dim=1)
z = self.encode(x)
# z = self.out_layer(z)
return z, None
class SecretEncoder7(nn.Module):
def __init__(self, secret_len, ddconfig, ckpt_path=None,
ignore_keys=[],embed_dim=3,
ema_decay=None) -> None:
super().__init__()
log_resolution = int(np.log2(ddconfig.resolution))
self.secret_len = secret_len
self.encoder = Encoder(**ddconfig)
# self.encoder.conv_out = zero_module(self.encoder.conv_out)
self.quant_conv = nn.Conv2d(ddconfig["z_channels"], embed_dim, 1)
        self.secret_scaler = nn.Sequential(
            nn.Linear(secret_len, 32 * 32 * 2),
            nn.SiLU(),
            View(-1, 2, 32, 32),
        )  # secret_len -> 2 x 32 x 32
self.use_ema = ema_decay is not None
if self.use_ema:
self.ema_decay = ema_decay
assert 0. < ema_decay < 1.
self.model_ema = LitEma(self, decay=ema_decay)
print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
if ckpt_path is not None:
self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
def init_from_ckpt(self, path, ignore_keys=list()):
sd = torch.load(path, map_location="cpu")["state_dict"]
keys = list(sd.keys())
for k in keys:
for ik in ignore_keys:
if k.startswith(ik):
print("Deleting key {} from state_dict.".format(k))
del sd[k]
        misses, ignores = self.load_state_dict(sd, strict=False)
        print(f"[SecretEncoder7] Restored from {path}, missing: {len(misses)}, unexpected: {len(ignores)}. Mismatches are expected: the decoder weights are unused and the secret encoder is a new module.")
    def copy_encoder_weight(self, ae_model):
        # initialise encoder and quant_conv from the pretrained autoencoder
        self.encoder.load_state_dict(deepcopy(ae_model.encoder.state_dict()))
        self.quant_conv.load_state_dict(deepcopy(ae_model.quant_conv.state_dict()))
@contextmanager
def ema_scope(self, context=None):
if self.use_ema:
self.model_ema.store(self.parameters())
self.model_ema.copy_to(self)
if context is not None:
print(f"{context}: Switched to EMA weights")
try:
yield None
finally:
if self.use_ema:
self.model_ema.restore(self.parameters())
if context is not None:
print(f"{context}: Restored training weights")
def on_train_batch_end(self, *args, **kwargs):
if self.use_ema:
self.model_ema(self)
def encode(self, x):
h = self.encoder(x)
h = self.quant_conv(h)
return h
    def forward(self, x, c):
        # x: [B, C, H, W], c: [B, secret_len]
        c = self.secret_scaler(c)  # [B, 2, 32, 32]
        c = thf.interpolate(c, size=x.shape[-2:], mode="nearest")
        # collapse the cover image to one channel with BT.709 luma weights
        x = 0.2125 * x[:, 0, ...] + 0.7154 * x[:, 1, ...] + 0.0721 * x[:, 2, ...]
        x = torch.cat([x.unsqueeze(1), c], dim=1)
        z = self.encode(x)
        return z, None
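# Minimal sketch of the input assembled by SecretEncoder7.forward, assuming a
# 256x256 cover image; the random c stands in for the secret_scaler output.
# It shows the BT.709 luma channel concatenated with the nearest-neighbour
# upsampled 2-channel secret map, giving a 3-channel encoder input.
def _demo_secret_encoder7_input():
    x = torch.rand(2, 3, 256, 256)
    c = torch.rand(2, 2, 32, 32)  # stand-in for the secret_scaler output
    c_up = thf.interpolate(c, size=x.shape[-2:], mode="nearest")
    luma = 0.2125 * x[:, 0] + 0.7154 * x[:, 1] + 0.0721 * x[:, 2]
    joint = torch.cat([luma.unsqueeze(1), c_up], dim=1)
    assert joint.shape == (2, 3, 256, 256)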
class SecretEncoder(nn.Module):
def __init__(self, secret_len, embed_dim, ddconfig, ckpt_path=None,
ignore_keys=[],
image_key="image",
colorize_nlabels=None,
monitor=None,
ema_decay=None,
learn_logvar=False) -> None:
super().__init__()
log_resolution = int(np.log2(ddconfig.resolution))
self.secret_len = secret_len
self.learn_logvar = learn_logvar
self.image_key = image_key
self.encoder = Encoder(**ddconfig)
assert ddconfig["double_z"]
self.quant_conv = torch.nn.Conv2d(2*ddconfig["z_channels"], 2*embed_dim, 1)
self.embed_dim = embed_dim
if colorize_nlabels is not None:
assert type(colorize_nlabels)==int
self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1))
if monitor is not None:
self.monitor = monitor
self.use_ema = ema_decay is not None
if self.use_ema:
self.ema_decay = ema_decay
assert 0. < ema_decay < 1.
self.model_ema = LitEma(self, decay=ema_decay)
print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
if ckpt_path is not None:
self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
        self.secret_scaler = nn.Sequential(
            nn.Linear(secret_len, 32 * 32 * ddconfig.out_ch),
            nn.SiLU(),
            View(-1, ddconfig.out_ch, 32, 32),
            nn.Upsample(scale_factor=(2**(log_resolution - 5), 2**(log_resolution - 5))),  # out_ch x 32 x 32 -> out_ch x resolution x resolution
            zero_module(conv_nd(2, ddconfig.out_ch, ddconfig.out_ch, 3, padding=1)),
        )  # secret_len -> out_ch x resolution x resolution
        self.out_layer = zero_module(conv_nd(2, ddconfig.out_ch, ddconfig.out_ch, 3, padding=1))
def init_from_ckpt(self, path, ignore_keys=list()):
sd = torch.load(path, map_location="cpu")["state_dict"]
keys = list(sd.keys())
for k in keys:
for ik in ignore_keys:
if k.startswith(ik):
print("Deleting key {} from state_dict.".format(k))
del sd[k]
misses, ignores = self.load_state_dict(sd, strict=False)
print(f"[SecretEncoder] Restored from {path}, misses: {misses}, ignores: {ignores}")
def copy_encoder_weight(self, ae_model):
# misses, ignores = self.load_state_dict(ae_state_dict, strict=False)
self.encoder.load_state_dict(ae_model.encoder.state_dict())
self.quant_conv.load_state_dict(ae_model.quant_conv.state_dict())
@contextmanager
def ema_scope(self, context=None):
if self.use_ema:
self.model_ema.store(self.parameters())
self.model_ema.copy_to(self)
if context is not None:
print(f"{context}: Switched to EMA weights")
try:
yield None
finally:
if self.use_ema:
self.model_ema.restore(self.parameters())
if context is not None:
print(f"{context}: Restored training weights")
def on_train_batch_end(self, *args, **kwargs):
if self.use_ema:
self.model_ema(self)
def encode(self, x):
h = self.encoder(x)
moments = self.quant_conv(h)
posterior = DiagonalGaussianDistribution(moments)
return posterior
def forward(self, x, c):
# x: [B, C, H, W], c: [B, secret_len]
c = self.secret_scaler(c)
x = x + c
posterior = self.encode(x)
z = posterior.sample()
z = self.out_layer(z)
return z, posterior
class ControlAE(pl.LightningModule):
    """Secret embedding around a frozen first-stage autoencoder: `control`
    predicts a latent offset carrying the secret, and `decoder` recovers the
    bits from the (optionally noised) reconstruction."""
def __init__(self,
first_stage_key,
first_stage_config,
control_key,
control_config,
decoder_config,
loss_config,
noise_config='__none__',
use_ema=False,
secret_warmup=False,
scale_factor=1.,
ckpt_path="__none__",
):
super().__init__()
self.scale_factor = scale_factor
self.control_key = control_key
self.first_stage_key = first_stage_key
self.ae = instantiate_from_config(first_stage_config)
self.control = instantiate_from_config(control_config)
self.decoder = instantiate_from_config(decoder_config)
self.crop = kornia.augmentation.CenterCrop((224, 224), cropping_mode="resample") # early training phase
if noise_config != '__none__':
print('Using noise')
self.noise = instantiate_from_config(noise_config)
# copy weights from first stage
self.control.copy_encoder_weight(self.ae)
# freeze first stage
self.ae.eval()
self.ae.train = disabled_train
for p in self.ae.parameters():
p.requires_grad = False
self.loss_layer = instantiate_from_config(loss_config)
# early training phase
# self.fixed_input = True
self.fixed_x = None
self.fixed_img = None
self.fixed_input_recon = None
self.fixed_control = None
self.register_buffer("fixed_input", torch.tensor(True))
# secret warmup
self.secret_warmup = secret_warmup
self.secret_baselen = 2
self.secret_len = control_config.params.secret_len
if self.secret_warmup:
assert self.secret_len == 2**(int(np.log2(self.secret_len)))
self.use_ema = use_ema
if self.use_ema:
print('Using EMA')
self.control_ema = LitEma(self.control)
self.decoder_ema = LitEma(self.decoder)
print(f"Keeping EMAs of {len(list(self.control_ema.buffers()) + list(self.decoder_ema.buffers()))}.")
if ckpt_path != '__none__':
self.init_from_ckpt(ckpt_path, ignore_keys=[])
def get_warmup_secret(self, old_secret):
# old_secret: [B, secret_len]
# new_secret: [B, secret_len]
if self.secret_warmup:
bsz = old_secret.shape[0]
nrepeats = self.secret_len // self.secret_baselen
new_secret = torch.zeros((bsz, self.secret_baselen), dtype=torch.float).random_(0, 2).repeat_interleave(nrepeats, dim=1)
return new_secret.to(old_secret.device)
else:
return old_secret
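    # Minimal sketch of the warmup pattern above, assuming secret_baselen=2
    # and a 128-bit payload (the power-of-two assert in __init__ requires
    # this): each base bit is tiled over 64 consecutive positions, so early
    # training sees at most 4 distinct messages. Illustrative only.
    @staticmethod
    def _demo_warmup_secret():
        base = torch.zeros((4, 2), dtype=torch.float).random_(0, 2)
        tiled = base.repeat_interleave(128 // 2, dim=1)
        assert tiled.shape == (4, 128)
        assert torch.equal(tiled[:, :64], base[:, :1].repeat(1, 64))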
def init_from_ckpt(self, path, ignore_keys=list()):
sd = torch.load(path, map_location="cpu")["state_dict"]
keys = list(sd.keys())
for k in keys:
for ik in ignore_keys:
if k.startswith(ik):
print("Deleting key {} from state_dict.".format(k))
del sd[k]
self.load_state_dict(sd, strict=False)
print(f"Restored from {path}")
@contextmanager
def ema_scope(self, context=None):
if self.use_ema:
self.control_ema.store(self.control.parameters())
self.decoder_ema.store(self.decoder.parameters())
self.control_ema.copy_to(self.control)
self.decoder_ema.copy_to(self.decoder)
if context is not None:
print(f"{context}: Switched to EMA weights")
try:
yield None
finally:
if self.use_ema:
self.control_ema.restore(self.control.parameters())
self.decoder_ema.restore(self.decoder.parameters())
if context is not None:
print(f"{context}: Restored training weights")
def on_train_batch_end(self, *args, **kwargs):
if self.use_ema:
self.control_ema(self.control)
self.decoder_ema(self.decoder)
    def compute_loss(self, pred, target):
        # NOTE: assumes self.lpips_loss and self.yuv_scales are set elsewhere;
        # neither is defined in this file.
        lpips_loss = self.lpips_loss(pred, target).mean(dim=[1, 2, 3])
        pred_yuv = color.rgb_to_yuv((pred + 1) / 2)
        target_yuv = color.rgb_to_yuv((target + 1) / 2)
        yuv_loss = torch.mean((pred_yuv - target_yuv) ** 2, dim=[2, 3])
        yuv_loss = 1.5 * torch.mm(yuv_loss, self.yuv_scales).squeeze(1)
        return lpips_loss + yuv_loss
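    # Minimal sketch of the YUV-weighted term in compute_loss above. The
    # yuv_scales values below are hypothetical stand-ins for self.yuv_scales,
    # which is not defined in this file. Illustrative only.
    @staticmethod
    def _demo_yuv_weighting():
        pred = torch.rand(2, 3, 64, 64) * 2 - 1    # images in [-1, 1] as in training
        target = torch.rand(2, 3, 64, 64) * 2 - 1
        yuv_scales = torch.tensor([[1.0], [100.0], [100.0]])  # hypothetical [3, 1] weights
        pred_yuv = color.rgb_to_yuv((pred + 1) / 2)
        target_yuv = color.rgb_to_yuv((target + 1) / 2)
        per_channel = torch.mean((pred_yuv - target_yuv) ** 2, dim=[2, 3])  # [B, 3]
        loss = 1.5 * torch.mm(per_channel, yuv_scales).squeeze(1)           # [B]
        assert loss.shape == (2,)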
    def forward(self, x, image, c):
        # SecretEncoder6 operates on the latent x; the other variants on the image
        if isinstance(self.control, SecretEncoder6):
            eps, posterior = self.control(x, c)
        else:
            eps, posterior = self.control(image, c)
        return x + eps, posterior
@torch.no_grad()
def get_input(self, batch, return_first_stage=False, bs=None):
image = batch[self.first_stage_key]
control = batch[self.control_key]
control = self.get_warmup_secret(control)
if bs is not None:
image = image[:bs]
control = control[:bs]
else:
bs = image.shape[0]
# encode image 1st stage
image = einops.rearrange(image, "b h w c -> b c h w").contiguous()
x = self.encode_first_stage(image).detach()
image_rec = self.decode_first_stage(x).detach()
# check if using fixed input (early training phase)
# if self.training and self.fixed_input:
if self.fixed_input:
if self.fixed_x is None: # first iteration
print('[TRAINING] Warmup - using fixed input image for now!')
self.fixed_x = x.detach().clone()[:bs]
self.fixed_img = image.detach().clone()[:bs]
self.fixed_input_recon = image_rec.detach().clone()[:bs]
self.fixed_control = control.detach().clone()[:bs] # use for log_images with fixed_input option only
x, image, image_rec = self.fixed_x, self.fixed_img, self.fixed_input_recon
out = [x, control]
if return_first_stage:
out.extend([image, image_rec])
return out
def decode_first_stage(self, z):
z = 1./self.scale_factor * z
image_rec = self.ae.decode(z)
return image_rec
def encode_first_stage(self, image):
encoder_posterior = self.ae.encode(image)
if isinstance(encoder_posterior, DiagonalGaussianDistribution):
z = encoder_posterior.sample()
elif isinstance(encoder_posterior, torch.Tensor):
z = encoder_posterior
else:
raise NotImplementedError(f"encoder_posterior of type '{type(encoder_posterior)}' not yet implemented")
return self.scale_factor * z
def shared_step(self, batch):
x, c, img, _ = self.get_input(batch, return_first_stage=True)
x, posterior = self(x, img, c)
image_rec = self.decode_first_stage(x)
# resize
if img.shape[-1] > 256:
img = thf.interpolate(img, size=(256, 256), mode='bilinear', align_corners=False).detach()
image_rec = thf.interpolate(image_rec, size=(256, 256), mode='bilinear', align_corners=False)
if hasattr(self, 'noise') and self.noise.is_activated():
image_rec_noised = self.noise(image_rec, self.global_step, p=0.9)
else:
image_rec_noised = self.crop(image_rec) # center crop
image_rec_noised = torch.clamp(image_rec_noised, -1, 1)
pred = self.decoder(image_rec_noised)
loss, loss_dict = self.loss_layer(img, image_rec, posterior, c, pred, self.global_step)
        bit_acc = loss_dict["bit_acc"].item()
        if (bit_acc > 0.98) and (not self.fixed_input) and hasattr(self, 'noise') and self.noise.is_activated():
            self.loss_layer.activate_ramp(self.global_step)  # ramp up image loss
        if (bit_acc > 0.95) and (not self.fixed_input):  # turn on noise at late training stage
            if hasattr(self, 'noise') and (not self.noise.is_activated()):
                self.noise.activate(self.global_step)
        if (bit_acc > 0.9) and self.fixed_input:  # executes only once
            print(f'[TRAINING] High bit acc ({bit_acc}) achieved, switching to full image dataset training.')
            self.fixed_input = ~self.fixed_input
return loss, loss_dict
def training_step(self, batch, batch_idx):
loss, loss_dict = self.shared_step(batch)
loss_dict = {f"train/{key}": val for key, val in loss_dict.items()}
self.log_dict(loss_dict, prog_bar=True,
logger=True, on_step=True, on_epoch=True)
self.log("global_step", self.global_step,
prog_bar=True, logger=True, on_step=True, on_epoch=False)
# if self.use_scheduler:
# lr = self.optimizers().param_groups[0]['lr']
# self.log('lr_abs', lr, prog_bar=True, logger=True, on_step=True, on_epoch=False)
return loss
@torch.no_grad()
def validation_step(self, batch, batch_idx):
_, loss_dict_no_ema = self.shared_step(batch)
loss_dict_no_ema = {f"val/{key}": val for key, val in loss_dict_no_ema.items() if key != 'img_lw'}
with self.ema_scope():
_, loss_dict_ema = self.shared_step(batch)
loss_dict_ema = {'val/' + key + '_ema': loss_dict_ema[key] for key in loss_dict_ema}
self.log_dict(loss_dict_no_ema, prog_bar=False, logger=True, on_step=False, on_epoch=True)
self.log_dict(loss_dict_ema, prog_bar=False, logger=True, on_step=False, on_epoch=True)
@torch.no_grad()
def log_images(self, batch, fixed_input=False, **kwargs):
log = dict()
if fixed_input and self.fixed_img is not None:
x, c, img, img_recon = self.fixed_x, self.fixed_control, self.fixed_img, self.fixed_input_recon
else:
x, c, img, img_recon = self.get_input(batch, return_first_stage=True)
x, _ = self(x, img, c)
image_out = self.decode_first_stage(x)
if hasattr(self, 'noise') and self.noise.is_activated():
img_noise = self.noise(image_out, self.global_step, p=1.0)
log['noised'] = img_noise
log['input'] = img
log['output'] = image_out
log['recon'] = img_recon
return log
def configure_optimizers(self):
lr = self.learning_rate
params = list(self.control.parameters()) + list(self.decoder.parameters())
optimizer = torch.optim.AdamW(params, lr=lr)
return optimizer
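
if __name__ == "__main__":
    # Run the illustrative shape checks defined above; a CPU-only sketch that
    # needs no checkpoints, configs, or data.
    _demo_secret_encoder3()
    _demo_secret_encoder6_modes()
    _demo_secret_encoder5_joint()
    _demo_secret_encoder7_input()
    ControlAE._demo_warmup_secret()
    ControlAE._demo_yuv_weighting()
    print("all demo shape checks passed")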