# cldm/loss.py -- combined image/secret loss (Tu Bui, first commit, 6142a25)
import torch
import torch.nn as nn
from lpips import LPIPS
from kornia import color
# from taming.modules.losses.vqperceptual import *
class ImageSecretLoss(nn.Module):
    """Joint image-reconstruction / secret-recovery loss with a ramped image weight."""

    def __init__(self, recon_type='rgb', recon_weight=1., perceptual_weight=1.0, secret_weight=10.,
                 kl_weight=0.000001, logvar_init=0.0, ramp=100000, max_image_weight_ratio=2.) -> None:
        super().__init__()
        self.recon_type = recon_type
        assert recon_type in ['rgb', 'yuv']
        if recon_type == 'yuv':
            # up-weight chroma (U, V) errors 100x relative to luma (Y)
            self.register_buffer('yuv_scales', torch.tensor([1, 100, 100]).unsqueeze(1).float())  # [3, 1]
        self.recon_weight = recon_weight
        self.perceptual_weight = perceptual_weight
        self.secret_weight = secret_weight
        self.kl_weight = kl_weight
        self.ramp = ramp
        self.max_image_weight = max_image_weight_ratio * secret_weight - 1
        self.register_buffer('ramp_on', torch.tensor(False))
        self.register_buffer('step0', torch.tensor(1e9))  # large sentinel until the ramp is activated
        self.perceptual_loss = LPIPS().eval()
        self.logvar = nn.Parameter(torch.ones(size=()) * logvar_init)
        self.bce = nn.BCEWithLogitsLoss(reduction="none")
    def activate_ramp(self, global_step):
        if not self.ramp_on:  # do not activate ramp twice
            self.step0 = torch.tensor(global_step)
            self.ramp_on = ~self.ramp_on
            print('[TRAINING] Activate ramp for image loss at step', global_step)
    def compute_recon_loss(self, inputs, reconstructions):
        if self.recon_type == 'rgb':
            rec_loss = torch.abs(inputs - reconstructions).mean(dim=[1, 2, 3])
        elif self.recon_type == 'yuv':
            # images arrive in [-1, 1]; map to [0, 1] before the RGB -> YUV conversion
            reconstructions_yuv = color.rgb_to_yuv((reconstructions + 1) / 2)
            inputs_yuv = color.rgb_to_yuv((inputs + 1) / 2)
            yuv_loss = torch.mean((reconstructions_yuv - inputs_yuv) ** 2, dim=[2, 3])  # [B, 3]
            rec_loss = torch.mm(yuv_loss, self.yuv_scales).squeeze(1)  # [B]
        else:
            raise ValueError(f"Unknown recon type {self.recon_type}")
        return rec_loss
    def forward(self, inputs, reconstructions, posteriors, secret_gt, secret_pred, global_step):
        loss_dict = {}
        rec_loss = self.compute_recon_loss(inputs.contiguous(), reconstructions.contiguous())
        loss = rec_loss * self.recon_weight

        if self.perceptual_weight > 0:
            p_loss = self.perceptual_loss(inputs.contiguous(), reconstructions.contiguous()).mean(dim=[1, 2, 3])
            loss += self.perceptual_weight * p_loss
            loss_dict['p_loss'] = p_loss.mean()

        # learned uncertainty weighting via a trainable log-variance
        loss = loss / torch.exp(self.logvar) + self.logvar

        if self.kl_weight > 0:
            kl_loss = posteriors.kl()
            loss += kl_loss * self.kl_weight
            loss_dict['kl_loss'] = kl_loss.mean()

        # ramp the image weight from 1 up to 1 + max_image_weight over `ramp` steps after activation
        image_weight = 1 + min(self.max_image_weight,
                               max(0., self.max_image_weight * (global_step - self.step0.item()) / self.ramp))

        secret_loss = self.bce(secret_pred, secret_gt).mean(dim=1)
        loss = (loss * image_weight + secret_loss * self.secret_weight) / (image_weight + self.secret_weight)

        # loss dict update
        bit_acc = ((secret_pred.detach() > 0).float() == secret_gt).float().mean()
        loss_dict['bit_acc'] = bit_acc
        loss_dict['loss'] = loss.mean()
        loss_dict['img_lw'] = image_weight / self.secret_weight
        loss_dict['rec_loss'] = rec_loss.mean()
        loss_dict['secret_loss'] = secret_loss.mean()
        return loss.mean(), loss_dict
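

if __name__ == "__main__":
    # Minimal smoke test -- a sketch, not part of the original file. It assumes the
    # conventions implied above: images in [-1, 1] with shape [B, 3, H, W] and a
    # 100-bit secret (the bit length is an assumption for illustration). The
    # posterior stand-in below only mimics the .kl() interface expected by forward()
    # (e.g. a DiagonalGaussianDistribution in latent-diffusion codebases).
    class _DummyPosterior:
        def kl(self):
            return torch.zeros(4)  # one KL value per batch element

    loss_fn = ImageSecretLoss(recon_type='yuv')
    images = torch.rand(4, 3, 64, 64) * 2 - 1                         # fake batch in [-1, 1]
    recons = (images + 0.05 * torch.randn_like(images)).clamp(-1, 1)  # lightly perturbed "reconstructions"
    secret_gt = torch.randint(0, 2, (4, 100)).float()                 # ground-truth bits
    secret_pred = torch.randn(4, 100)                                 # raw logits from a secret decoder

    loss, loss_dict = loss_fn(images, recons, _DummyPosterior(),
                              secret_gt, secret_pred, global_step=0)
    print('loss:', loss.item())
    print({k: float(v) for k, v in loss_dict.items()})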