# cldm/ae.py
import numpy as np
import einops
import torch
import torch as th
import torch.nn as nn
from torch.nn import functional as thf
import pytorch_lightning as pl
import torchvision
from copy import deepcopy
from ldm.modules.diffusionmodules.util import (
conv_nd,
linear,
zero_module,
timestep_embedding,
)
from contextlib import contextmanager, nullcontext
from einops import rearrange, repeat
from torchvision.utils import make_grid
from ldm.modules.attention import SpatialTransformer
from ldm.modules.diffusionmodules.openaimodel import UNetModel, TimestepEmbedSequential, ResBlock, Downsample, AttentionBlock
from ldm.models.diffusion.ddpm import LatentDiffusion
from ldm.util import log_txt_as_img, exists, instantiate_from_config, default
from ldm.models.diffusion.ddim import DDIMSampler
from ldm.modules.ema import LitEma
from ldm.modules.distributions.distributions import normal_kl, DiagonalGaussianDistribution
from ldm.modules.diffusionmodules.model import Encoder
import lpips
import kornia
from kornia import color
def disabled_train(self, mode=True):
"""Overwrite model.train with this function to make sure train/eval mode
does not change anymore."""
return self
class View(nn.Module):
def __init__(self, *shape):
super().__init__()
self.shape = shape
def forward(self, x):
return x.view(*self.shape)
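# Illustrative example (not from the original file): View wraps Tensor.view so
# a reshape can sit inside an nn.Sequential, e.g.
#   View(-1, 3, 16, 16)(torch.zeros(2, 768)).shape == (2, 3, 16, 16)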
class SecretEncoder3(nn.Module):
def __init__(self, secret_len, base_res=16, resolution=64) -> None:
super().__init__()
log_resolution = int(np.log2(resolution))
log_base = int(np.log2(base_res))
self.secret_len = secret_len
self.secret_scaler = nn.Sequential(
nn.Linear(secret_len, base_res*base_res*3),
nn.SiLU(),
View(-1, 3, base_res, base_res),
            nn.Upsample(scale_factor=(2**(log_resolution-log_base), 2**(log_resolution-log_base))), # ch x base_res x base_res -> ch x resolution x resolution
zero_module(conv_nd(2, 3, 3, 3, padding=1))
) # secret len -> ch x res x res
def copy_encoder_weight(self, ae_model):
# misses, ignores = self.load_state_dict(ae_state_dict, strict=False)
return None
def encode(self, x):
x = self.secret_scaler(x)
return x
def forward(self, x, c):
# x: [B, C, H, W], c: [B, secret_len]
c = self.encode(c)
return c, None
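# Usage sketch (illustrative; shapes assume the defaults base_res=16,
# resolution=64):
#   enc = SecretEncoder3(secret_len=100)
#   secret = torch.bernoulli(torch.full((2, 100), 0.5))
#   delta, _ = enc(torch.randn(2, 3, 64, 64), secret)  # delta: [2, 3, 64, 64]
# The final zero-initialised conv means the residual starts at exactly zero,
# so watermarking begins as an identity mapping.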
class SecretEncoder4(nn.Module):
"""same as SecretEncoder3 but with ch as input"""
def __init__(self, secret_len, ch=3, base_res=16, resolution=64) -> None:
super().__init__()
log_resolution = int(np.log2(resolution))
log_base = int(np.log2(base_res))
self.secret_len = secret_len
self.secret_scaler = nn.Sequential(
nn.Linear(secret_len, base_res*base_res*ch),
nn.SiLU(),
View(-1, ch, base_res, base_res),
            nn.Upsample(scale_factor=(2**(log_resolution-log_base), 2**(log_resolution-log_base))), # ch x base_res x base_res -> ch x resolution x resolution
zero_module(conv_nd(2, ch, ch, 3, padding=1))
) # secret len -> ch x res x res
def copy_encoder_weight(self, ae_model):
# misses, ignores = self.load_state_dict(ae_state_dict, strict=False)
return None
def encode(self, x):
x = self.secret_scaler(x)
return x
def forward(self, x, c):
# x: [B, C, H, W], c: [B, secret_len]
c = self.encode(c)
return c, None
class SecretEncoder6(nn.Module):
"""join img emb with secret emb"""
def __init__(self, secret_len, ch=3, base_res=16, resolution=64, emode='c3') -> None:
super().__init__()
assert emode in ['c3', 'c2', 'm3']
if emode == 'c3': # c3: concat c and x each has ch channels
secret_ch = ch
join_ch = 2*ch
elif emode == 'c2': # c2: concat c (2) and x ave (1)
secret_ch = 2
join_ch = ch
elif emode == 'm3': # m3: multiply c (ch) and x (ch)
secret_ch = ch
join_ch = ch
log_resolution = int(np.log2(resolution))
log_base = int(np.log2(base_res))
self.secret_len = secret_len
self.emode = emode
self.resolution = resolution
self.secret_scaler = nn.Sequential(
nn.Linear(secret_len, base_res*base_res*secret_ch),
nn.SiLU(),
View(-1, secret_ch, base_res, base_res),
            nn.Upsample(scale_factor=(2**(log_resolution-log_base), 2**(log_resolution-log_base))), # ch x base_res x base_res -> ch x resolution x resolution
) # secret len -> ch x res x res
self.join_encoder = nn.Sequential(
conv_nd(2, join_ch, join_ch, 3, padding=1),
nn.SiLU(),
conv_nd(2, join_ch, ch, 3, padding=1),
nn.SiLU(),
conv_nd(2, ch, ch, 3, padding=1),
nn.SiLU()
)
self.out_layer = zero_module(conv_nd(2, ch, ch, 3, padding=1))
def copy_encoder_weight(self, ae_model):
# misses, ignores = self.load_state_dict(ae_state_dict, strict=False)
return None
def encode(self, x):
x = self.secret_scaler(x)
return x
def forward(self, x, c):
# x: [B, C, H, W], c: [B, secret_len]
c = self.encode(c)
if self.emode == 'c3':
x = torch.cat([x, c], dim=1)
elif self.emode == 'c2':
x = torch.cat([x.mean(dim=1, keepdim=True), c], dim=1)
elif self.emode == 'm3':
x = x * c
dx = self.join_encoder(x)
dx = self.out_layer(dx)
return dx, None
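# The three fusion modes differ only in how image and secret features are
# combined before the shared join encoder (channel counts assume ch=3):
#   'c3': cat([x (3 ch), c (3 ch)])           -> join encoder sees 6 channels
#   'c2': cat([x.mean over ch (1), c (2 ch)]) -> join encoder sees 3 channels
#   'm3': x (3 ch) * c (3 ch), elementwise    -> join encoder sees 3 channels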
class SecretEncoder5(nn.Module):
"""same as SecretEncoder3 but with ch as input"""
def __init__(self, secret_len, ch=3, base_res=16, resolution=64, joint=False) -> None:
super().__init__()
log_resolution = int(np.log2(resolution))
log_base = int(np.log2(base_res))
self.secret_len = secret_len
self.joint = joint
self.resolution = resolution
self.secret_scaler = nn.Sequential(
nn.Linear(secret_len, base_res*base_res*ch),
nn.SiLU(),
View(-1, ch, base_res, base_res),
            nn.Upsample(scale_factor=(2**(log_resolution-log_base), 2**(log_resolution-log_base))), # ch x base_res x base_res -> ch x resolution x resolution
) # secret len -> ch x res x res
if joint:
self.join_encoder = nn.Sequential(
conv_nd(2, 2*ch, 2*ch, 3, padding=1),
nn.SiLU(),
conv_nd(2, 2*ch, ch, 3, padding=1),
nn.SiLU()
)
self.out_layer = zero_module(conv_nd(2, ch, ch, 3, padding=1))
def copy_encoder_weight(self, ae_model):
# misses, ignores = self.load_state_dict(ae_state_dict, strict=False)
return None
def encode(self, x):
x = self.secret_scaler(x)
return x
def forward(self, x, c):
# x: [B, C, H, W], c: [B, secret_len]
c = self.encode(c)
if self.joint:
x = thf.interpolate(x, size=(self.resolution, self.resolution), mode="bilinear", align_corners=False, antialias=True)
c = self.join_encoder(torch.cat([x, c], dim=1))
c = self.out_layer(c)
return c, None
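# With joint=True, the image is resampled to `resolution` and concatenated
# channel-wise with the secret map before the join encoder; with joint=False
# the image argument is ignored and the module reduces to SecretEncoder4
# (zero-initialised output conv, so the residual starts at zero).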
class SecretEncoder2(nn.Module):
def __init__(self, secret_len, embed_dim, ddconfig, ckpt_path=None,
ignore_keys=[],
image_key="image",
colorize_nlabels=None,
monitor=None,
ema_decay=None,
learn_logvar=False) -> None:
super().__init__()
log_resolution = int(np.log2(ddconfig.resolution))
self.secret_len = secret_len
self.learn_logvar = learn_logvar
self.image_key = image_key
self.encoder = Encoder(**ddconfig)
self.encoder.conv_out = zero_module(self.encoder.conv_out)
self.embed_dim = embed_dim
if colorize_nlabels is not None:
assert type(colorize_nlabels)==int
self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1))
if monitor is not None:
self.monitor = monitor
self.secret_scaler = nn.Sequential(
nn.Linear(secret_len, 32*32*ddconfig.out_ch),
nn.SiLU(),
View(-1, ddconfig.out_ch, 32, 32),
            nn.Upsample(scale_factor=(2**(log_resolution-5), 2**(log_resolution-5))), # ch x 32 x 32 -> ch x resolution x resolution
# zero_module(conv_nd(2, ddconfig.out_ch, ddconfig.out_ch, 3, padding=1))
) # secret len -> ch x res x res
# out_resolution = ddconfig.resolution//(len(ddconfig.ch_mult)-1)
# self.out_layer = zero_module(conv_nd(2, ddconfig.out_ch, ddconfig.out_ch, 3, padding=1))
self.use_ema = ema_decay is not None
if self.use_ema:
self.ema_decay = ema_decay
assert 0. < ema_decay < 1.
self.model_ema = LitEma(self, decay=ema_decay)
print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
if ckpt_path is not None:
self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
def init_from_ckpt(self, path, ignore_keys=list()):
sd = torch.load(path, map_location="cpu")["state_dict"]
keys = list(sd.keys())
for k in keys:
for ik in ignore_keys:
if k.startswith(ik):
print("Deleting key {} from state_dict.".format(k))
del sd[k]
misses, ignores = self.load_state_dict(sd, strict=False)
print(f"[SecretEncoder] Restored from {path}, misses: {misses}, ignores: {ignores}")
    def copy_encoder_weight(self, ae_model):
        # Weight copying is intentionally disabled for this variant; the calls
        # below are kept for reference but were unreachable after `return None`
        # (and this class has no quant_conv), so they are commented out.
        return None
        # self.encoder.load_state_dict(ae_model.encoder.state_dict())
        # self.quant_conv.load_state_dict(ae_model.quant_conv.state_dict())
@contextmanager
def ema_scope(self, context=None):
if self.use_ema:
self.model_ema.store(self.parameters())
self.model_ema.copy_to(self)
if context is not None:
print(f"{context}: Switched to EMA weights")
try:
yield None
finally:
if self.use_ema:
self.model_ema.restore(self.parameters())
if context is not None:
print(f"{context}: Restored training weights")
def on_train_batch_end(self, *args, **kwargs):
if self.use_ema:
self.model_ema(self)
def encode(self, x):
h = self.encoder(x)
posterior = h
return posterior
def forward(self, x, c):
# x: [B, C, H, W], c: [B, secret_len]
c = self.secret_scaler(c)
x = torch.cat([x, c], dim=1)
z = self.encode(x)
# z = self.out_layer(z)
return z, None
class SecretEncoder7(nn.Module):
def __init__(self, secret_len, ddconfig, ckpt_path=None,
ignore_keys=[],embed_dim=3,
ema_decay=None) -> None:
super().__init__()
log_resolution = int(np.log2(ddconfig.resolution))
self.secret_len = secret_len
self.encoder = Encoder(**ddconfig)
# self.encoder.conv_out = zero_module(self.encoder.conv_out)
self.quant_conv = nn.Conv2d(ddconfig["z_channels"], embed_dim, 1)
self.secret_scaler = nn.Sequential(
nn.Linear(secret_len, 32*32*2),
nn.SiLU(),
View(-1, 2, 32, 32),
# nn.Upsample(scale_factor=(2**(log_resolution-5), 2**(log_resolution-5))), # chx16x16 -> chx256x256
# zero_module(conv_nd(2, ddconfig.out_ch, ddconfig.out_ch, 3, padding=1))
) # secret len -> ch x res x res
# out_resolution = ddconfig.resolution//(len(ddconfig.ch_mult)-1)
# self.out_layer = zero_module(conv_nd(2, ddconfig.out_ch, ddconfig.out_ch, 3, padding=1))
self.use_ema = ema_decay is not None
if self.use_ema:
self.ema_decay = ema_decay
assert 0. < ema_decay < 1.
self.model_ema = LitEma(self, decay=ema_decay)
print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
if ckpt_path is not None:
self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
def init_from_ckpt(self, path, ignore_keys=list()):
sd = torch.load(path, map_location="cpu")["state_dict"]
keys = list(sd.keys())
for k in keys:
for ik in ignore_keys:
if k.startswith(ik):
print("Deleting key {} from state_dict.".format(k))
del sd[k]
misses, ignores = self.load_state_dict(sd, strict=False)
print(f"[SecretEncoder7] Restored from {path}, misses: {len(misses)}, ignores: {len(ignores)}. Do not worry as we are not using the decoder and the secret encoder is a novel module.")
def copy_encoder_weight(self, ae_model):
# misses, ignores = self.load_state_dict(ae_state_dict, strict=False)
# return None
self.encoder.load_state_dict(deepcopy(ae_model.encoder.state_dict()))
self.quant_conv.load_state_dict(deepcopy(ae_model.quant_conv.state_dict()))
@contextmanager
def ema_scope(self, context=None):
if self.use_ema:
self.model_ema.store(self.parameters())
self.model_ema.copy_to(self)
if context is not None:
print(f"{context}: Switched to EMA weights")
try:
yield None
finally:
if self.use_ema:
self.model_ema.restore(self.parameters())
if context is not None:
print(f"{context}: Restored training weights")
def on_train_batch_end(self, *args, **kwargs):
if self.use_ema:
self.model_ema(self)
def encode(self, x):
h = self.encoder(x)
h = self.quant_conv(h)
return h
def forward(self, x, c):
# x: [B, C, H, W], c: [B, secret_len]
c = self.secret_scaler(c) # [B, 2, 32, 32]
# c = thf.interpolate(c, size=x.shape[-2:], mode="bilinear", align_corners=False)
c = thf.interpolate(c, size=x.shape[-2:], mode="nearest")
        x = 0.2125 * x[:, 0, ...] + 0.7154 * x[:, 1, ...] + 0.0721 * x[:, 2, ...]  # RGB -> luminance (Rec. 709 weights)
x = torch.cat([x.unsqueeze(1), c], dim=1)
z = self.encode(x)
# z = self.out_layer(z)
return z, None
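# Usage sketch (illustrative): the image is collapsed to one luminance channel
# and stacked with the 2-channel secret map, so the wrapped Encoder must be
# configured for 3 input channels; the output is a latent-sized residual, e.g.
#   z, _ = enc7(torch.randn(2, 3, 256, 256), secret)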
class SecretEncoder(nn.Module):
def __init__(self, secret_len, embed_dim, ddconfig, ckpt_path=None,
ignore_keys=[],
image_key="image",
colorize_nlabels=None,
monitor=None,
ema_decay=None,
learn_logvar=False) -> None:
super().__init__()
log_resolution = int(np.log2(ddconfig.resolution))
self.secret_len = secret_len
self.learn_logvar = learn_logvar
self.image_key = image_key
self.encoder = Encoder(**ddconfig)
assert ddconfig["double_z"]
self.quant_conv = torch.nn.Conv2d(2*ddconfig["z_channels"], 2*embed_dim, 1)
self.embed_dim = embed_dim
if colorize_nlabels is not None:
assert type(colorize_nlabels)==int
self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1))
if monitor is not None:
self.monitor = monitor
self.use_ema = ema_decay is not None
if self.use_ema:
self.ema_decay = ema_decay
assert 0. < ema_decay < 1.
self.model_ema = LitEma(self, decay=ema_decay)
print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
if ckpt_path is not None:
self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
self.secret_scaler = nn.Sequential(
nn.Linear(secret_len, 32*32*ddconfig.out_ch),
nn.SiLU(),
View(-1, ddconfig.out_ch, 32, 32),
            nn.Upsample(scale_factor=(2**(log_resolution-5), 2**(log_resolution-5))), # ch x 32 x 32 -> ch x resolution x resolution
zero_module(conv_nd(2, ddconfig.out_ch, ddconfig.out_ch, 3, padding=1))
) # secret len -> ch x res x res
# out_resolution = ddconfig.resolution//(len(ddconfig.ch_mult)-1)
self.out_layer = zero_module(conv_nd(2, ddconfig.out_ch, ddconfig.out_ch, 3, padding=1))
def init_from_ckpt(self, path, ignore_keys=list()):
sd = torch.load(path, map_location="cpu")["state_dict"]
keys = list(sd.keys())
for k in keys:
for ik in ignore_keys:
if k.startswith(ik):
print("Deleting key {} from state_dict.".format(k))
del sd[k]
misses, ignores = self.load_state_dict(sd, strict=False)
print(f"[SecretEncoder] Restored from {path}, misses: {misses}, ignores: {ignores}")
def copy_encoder_weight(self, ae_model):
# misses, ignores = self.load_state_dict(ae_state_dict, strict=False)
self.encoder.load_state_dict(ae_model.encoder.state_dict())
self.quant_conv.load_state_dict(ae_model.quant_conv.state_dict())
@contextmanager
def ema_scope(self, context=None):
if self.use_ema:
self.model_ema.store(self.parameters())
self.model_ema.copy_to(self)
if context is not None:
print(f"{context}: Switched to EMA weights")
try:
yield None
finally:
if self.use_ema:
self.model_ema.restore(self.parameters())
if context is not None:
print(f"{context}: Restored training weights")
def on_train_batch_end(self, *args, **kwargs):
if self.use_ema:
self.model_ema(self)
def encode(self, x):
h = self.encoder(x)
moments = self.quant_conv(h)
posterior = DiagonalGaussianDistribution(moments)
return posterior
def forward(self, x, c):
# x: [B, C, H, W], c: [B, secret_len]
c = self.secret_scaler(c)
x = x + c
posterior = self.encode(x)
z = posterior.sample()
z = self.out_layer(z)
return z, posterior
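# Unlike the residual-only variants above, this encoder adds the secret map to
# the image itself, re-encodes the sum through a VAE-style encoder, and samples
# z from the resulting DiagonalGaussianDistribution; out_layer is
# zero-initialised, so the returned residual also starts at zero.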
class ControlAE(pl.LightningModule):
def __init__(self,
first_stage_key,
first_stage_config,
control_key,
control_config,
decoder_config,
loss_config,
noise_config='__none__',
use_ema=False,
secret_warmup=False,
scale_factor=1.,
ckpt_path="__none__",
):
super().__init__()
self.scale_factor = scale_factor
self.control_key = control_key
self.first_stage_key = first_stage_key
self.ae = instantiate_from_config(first_stage_config)
self.control = instantiate_from_config(control_config)
self.decoder = instantiate_from_config(decoder_config)
self.crop = kornia.augmentation.CenterCrop((224, 224), cropping_mode="resample") # early training phase
if noise_config != '__none__':
print('Using noise')
self.noise = instantiate_from_config(noise_config)
# copy weights from first stage
self.control.copy_encoder_weight(self.ae)
# freeze first stage
self.ae.eval()
self.ae.train = disabled_train
for p in self.ae.parameters():
p.requires_grad = False
self.loss_layer = instantiate_from_config(loss_config)
# early training phase
# self.fixed_input = True
self.fixed_x = None
self.fixed_img = None
self.fixed_input_recon = None
self.fixed_control = None
self.register_buffer("fixed_input", torch.tensor(True))
# secret warmup
self.secret_warmup = secret_warmup
self.secret_baselen = 2
self.secret_len = control_config.params.secret_len
if self.secret_warmup:
            assert self.secret_len == 2**(int(np.log2(self.secret_len)))  # power of two, so it splits evenly into warmup repeats
self.use_ema = use_ema
if self.use_ema:
print('Using EMA')
self.control_ema = LitEma(self.control)
self.decoder_ema = LitEma(self.decoder)
print(f"Keeping EMAs of {len(list(self.control_ema.buffers()) + list(self.decoder_ema.buffers()))}.")
if ckpt_path != '__none__':
self.init_from_ckpt(ckpt_path, ignore_keys=[])
def get_warmup_secret(self, old_secret):
# old_secret: [B, secret_len]
# new_secret: [B, secret_len]
if self.secret_warmup:
bsz = old_secret.shape[0]
nrepeats = self.secret_len // self.secret_baselen
new_secret = torch.zeros((bsz, self.secret_baselen), dtype=torch.float).random_(0, 2).repeat_interleave(nrepeats, dim=1)
return new_secret.to(old_secret.device)
else:
return old_secret
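    # Illustrative example: with secret_len=8 and secret_baselen=2, a draw such
    # as [0., 1.] is expanded by repeat_interleave to
    # [0., 0., 0., 0., 1., 1., 1., 1.], so the effective payload during warmup
    # is only secret_baselen bits.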
def init_from_ckpt(self, path, ignore_keys=list()):
sd = torch.load(path, map_location="cpu")["state_dict"]
keys = list(sd.keys())
for k in keys:
for ik in ignore_keys:
if k.startswith(ik):
print("Deleting key {} from state_dict.".format(k))
del sd[k]
self.load_state_dict(sd, strict=False)
print(f"Restored from {path}")
@contextmanager
def ema_scope(self, context=None):
if self.use_ema:
self.control_ema.store(self.control.parameters())
self.decoder_ema.store(self.decoder.parameters())
self.control_ema.copy_to(self.control)
self.decoder_ema.copy_to(self.decoder)
if context is not None:
print(f"{context}: Switched to EMA weights")
try:
yield None
finally:
if self.use_ema:
self.control_ema.restore(self.control.parameters())
self.decoder_ema.restore(self.decoder.parameters())
if context is not None:
print(f"{context}: Restored training weights")
def on_train_batch_end(self, *args, **kwargs):
if self.use_ema:
self.control_ema(self.control)
self.decoder_ema(self.decoder)
    def compute_loss(self, pred, target):
        # NOTE: relies on self.lpips_loss and self.yuv_scales, which are not
        # created in __init__; in practice the configured loss_layer is used
        # for training, and this helper is only valid once those attributes
        # are attached elsewhere.
        # return thf.mse_loss(pred, target, reduction="none").mean(dim=(1, 2, 3))
        lpips_loss = self.lpips_loss(pred, target).mean(dim=[1,2,3])
pred_yuv = color.rgb_to_yuv((pred + 1) / 2)
target_yuv = color.rgb_to_yuv((target + 1) / 2)
yuv_loss = torch.mean((pred_yuv - target_yuv)**2, dim=[2,3])
yuv_loss = 1.5*torch.mm(yuv_loss, self.yuv_scales).squeeze(1)
return lpips_loss + yuv_loss
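    # Sketch of the missing attributes (assumptions, not from this file):
    #   self.lpips_loss = lpips.LPIPS(net='alex')  # perceptual distance
    #   self.yuv_scales = torch.tensor([[1.0], [100.0], [100.0]])  # [3, 1];
    #       hypothetical per-channel weights; mm([B, 3], [3, 1]) -> [B, 1]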
    def forward(self, x, image, c):
        # SecretEncoder6 fuses the secret with the latent x directly; all other
        # control variants condition on the decoded image instead.
        if self.control.__class__.__name__ == 'SecretEncoder6':
eps, posterior = self.control(x, c)
else:
eps, posterior = self.control(image, c)
return x + eps, posterior
@torch.no_grad()
def get_input(self, batch, return_first_stage=False, bs=None):
image = batch[self.first_stage_key]
control = batch[self.control_key]
control = self.get_warmup_secret(control)
if bs is not None:
image = image[:bs]
control = control[:bs]
else:
bs = image.shape[0]
# encode image 1st stage
image = einops.rearrange(image, "b h w c -> b c h w").contiguous()
x = self.encode_first_stage(image).detach()
image_rec = self.decode_first_stage(x).detach()
# check if using fixed input (early training phase)
# if self.training and self.fixed_input:
if self.fixed_input:
if self.fixed_x is None: # first iteration
print('[TRAINING] Warmup - using fixed input image for now!')
self.fixed_x = x.detach().clone()[:bs]
self.fixed_img = image.detach().clone()[:bs]
self.fixed_input_recon = image_rec.detach().clone()[:bs]
self.fixed_control = control.detach().clone()[:bs] # use for log_images with fixed_input option only
x, image, image_rec = self.fixed_x, self.fixed_img, self.fixed_input_recon
out = [x, control]
if return_first_stage:
out.extend([image, image_rec])
return out
def decode_first_stage(self, z):
z = 1./self.scale_factor * z
image_rec = self.ae.decode(z)
return image_rec
def encode_first_stage(self, image):
encoder_posterior = self.ae.encode(image)
if isinstance(encoder_posterior, DiagonalGaussianDistribution):
z = encoder_posterior.sample()
elif isinstance(encoder_posterior, torch.Tensor):
z = encoder_posterior
else:
raise NotImplementedError(f"encoder_posterior of type '{type(encoder_posterior)}' not yet implemented")
return self.scale_factor * z
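    # Round trip: encode_first_stage multiplies the latent by scale_factor and
    # decode_first_stage divides it back out, mirroring LatentDiffusion's
    # latent-scaling convention.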
def shared_step(self, batch):
x, c, img, _ = self.get_input(batch, return_first_stage=True)
# import pdb; pdb.set_trace()
x, posterior = self(x, img, c)
image_rec = self.decode_first_stage(x)
# resize
if img.shape[-1] > 256:
img = thf.interpolate(img, size=(256, 256), mode='bilinear', align_corners=False).detach()
image_rec = thf.interpolate(image_rec, size=(256, 256), mode='bilinear', align_corners=False)
if hasattr(self, 'noise') and self.noise.is_activated():
image_rec_noised = self.noise(image_rec, self.global_step, p=0.9)
else:
image_rec_noised = self.crop(image_rec) # center crop
image_rec_noised = torch.clamp(image_rec_noised, -1, 1)
pred = self.decoder(image_rec_noised)
loss, loss_dict = self.loss_layer(img, image_rec, posterior, c, pred, self.global_step)
bit_acc = loss_dict["bit_acc"]
bit_acc_ = bit_acc.item()
        if (bit_acc_ > 0.98) and (not self.fixed_input) and hasattr(self, 'noise') and self.noise.is_activated():
self.loss_layer.activate_ramp(self.global_step)
if (bit_acc_ > 0.95) and (not self.fixed_input): # ramp up image loss at late training stage
if hasattr(self, 'noise') and (not self.noise.is_activated()):
self.noise.activate(self.global_step)
if (bit_acc_ > 0.9) and self.fixed_input: # execute only once
print(f'[TRAINING] High bit acc ({bit_acc_}) achieved, switch to full image dataset training.')
            self.fixed_input = ~self.fixed_input  # flip the boolean buffer to False
return loss, loss_dict
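    # Curriculum implemented in shared_step above, keyed on bit accuracy:
    #   > 0.90 on the fixed warmup batch -> switch to the full dataset
    #   > 0.95 on real data              -> activate the noise model
    #   > 0.98 with noise active         -> ramp up the image-quality loss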
def training_step(self, batch, batch_idx):
loss, loss_dict = self.shared_step(batch)
loss_dict = {f"train/{key}": val for key, val in loss_dict.items()}
self.log_dict(loss_dict, prog_bar=True,
logger=True, on_step=True, on_epoch=True)
self.log("global_step", self.global_step,
prog_bar=True, logger=True, on_step=True, on_epoch=False)
# if self.use_scheduler:
# lr = self.optimizers().param_groups[0]['lr']
# self.log('lr_abs', lr, prog_bar=True, logger=True, on_step=True, on_epoch=False)
return loss
@torch.no_grad()
def validation_step(self, batch, batch_idx):
_, loss_dict_no_ema = self.shared_step(batch)
loss_dict_no_ema = {f"val/{key}": val for key, val in loss_dict_no_ema.items() if key != 'img_lw'}
with self.ema_scope():
_, loss_dict_ema = self.shared_step(batch)
loss_dict_ema = {'val/' + key + '_ema': loss_dict_ema[key] for key in loss_dict_ema}
self.log_dict(loss_dict_no_ema, prog_bar=False, logger=True, on_step=False, on_epoch=True)
self.log_dict(loss_dict_ema, prog_bar=False, logger=True, on_step=False, on_epoch=True)
@torch.no_grad()
def log_images(self, batch, fixed_input=False, **kwargs):
log = dict()
if fixed_input and self.fixed_img is not None:
x, c, img, img_recon = self.fixed_x, self.fixed_control, self.fixed_img, self.fixed_input_recon
else:
x, c, img, img_recon = self.get_input(batch, return_first_stage=True)
x, _ = self(x, img, c)
image_out = self.decode_first_stage(x)
if hasattr(self, 'noise') and self.noise.is_activated():
img_noise = self.noise(image_out, self.global_step, p=1.0)
log['noised'] = img_noise
log['input'] = img
log['output'] = image_out
log['recon'] = img_recon
return log
def configure_optimizers(self):
lr = self.learning_rate
params = list(self.control.parameters()) + list(self.decoder.parameters())
optimizer = torch.optim.AdamW(params, lr=lr)
return optimizer
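if __name__ == "__main__":
    # Minimal smoke test (illustrative only; not part of the original training
    # entry points). Checks that the standalone secret encoders emit residuals
    # matching the cover resolution; shapes assume the defaults used below.
    secret_len = 100
    secret = torch.bernoulli(torch.full((2, secret_len), 0.5))
    image = torch.randn(2, 3, 64, 64)

    enc3 = SecretEncoder3(secret_len, base_res=16, resolution=64)
    delta3, _ = enc3(image, secret)
    assert delta3.shape == image.shape  # zero at init thanks to zero_module

    enc6 = SecretEncoder6(secret_len, ch=3, base_res=16, resolution=64, emode='m3')
    delta6, _ = enc6(image, secret)
    assert delta6.shape == image.shape
    print('smoke test passed:', tuple(delta3.shape), tuple(delta6.shape))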