import torch import torch.nn as nn from ....util import default, instantiate_from_config from ..lpips.loss.lpips import LPIPS class LatentLPIPS(nn.Module): def __init__( self, decoder_config, perceptual_weight=1.0, latent_weight=1.0, scale_input_to_tgt_size=False, scale_tgt_to_input_size=False, perceptual_weight_on_inputs=0.0, ): super().__init__() self.scale_input_to_tgt_size = scale_input_to_tgt_size self.scale_tgt_to_input_size = scale_tgt_to_input_size self.init_decoder(decoder_config) self.perceptual_loss = LPIPS().eval() self.perceptual_weight = perceptual_weight self.latent_weight = latent_weight self.perceptual_weight_on_inputs = perceptual_weight_on_inputs def init_decoder(self, config): self.decoder = instantiate_from_config(config) if hasattr(self.decoder, "encoder"): del self.decoder.encoder def forward(self, latent_inputs, latent_predictions, image_inputs, split="train"): log = dict() loss = (latent_inputs - latent_predictions) ** 2 log[f"{split}/latent_l2_loss"] = loss.mean().detach() image_reconstructions = None if self.perceptual_weight > 0.0: image_reconstructions = self.decoder.decode(latent_predictions) image_targets = self.decoder.decode(latent_inputs) perceptual_loss = self.perceptual_loss( image_targets.contiguous(), image_reconstructions.contiguous() ) loss = ( self.latent_weight * loss.mean() + self.perceptual_weight * perceptual_loss.mean() ) log[f"{split}/perceptual_loss"] = perceptual_loss.mean().detach() if self.perceptual_weight_on_inputs > 0.0: image_reconstructions = default( image_reconstructions, self.decoder.decode(latent_predictions) ) if self.scale_input_to_tgt_size: image_inputs = torch.nn.functional.interpolate( image_inputs, image_reconstructions.shape[2:], mode="bicubic", antialias=True, ) elif self.scale_tgt_to_input_size: image_reconstructions = torch.nn.functional.interpolate( image_reconstructions, image_inputs.shape[2:], mode="bicubic", antialias=True, ) perceptual_loss2 = self.perceptual_loss( image_inputs.contiguous(), image_reconstructions.contiguous() ) loss = loss + self.perceptual_weight_on_inputs * perceptual_loss2.mean() log[f"{split}/perceptual_loss_on_inputs"] = perceptual_loss2.mean().detach() return loss, log