Spaces:
Running
on
Zero
Running
on
Zero
import torch | |
import torch.nn as nn | |
from ....util import default, instantiate_from_config | |
from ..lpips.loss.lpips import LPIPS | |
class LatentLPIPS(nn.Module):
    """Latent-space squared error combined with LPIPS on decoded images.

    Both the predicted and the target latents are pushed through
    ``self.decoder.decode`` and compared with an ``LPIPS`` module. Optionally
    the decoded prediction is also compared against the supplied input
    images (``perceptual_weight_on_inputs``).
    """

    def __init__(
        self,
        decoder_config,
        perceptual_weight=1.0,
        latent_weight=1.0,
        scale_input_to_tgt_size=False,
        scale_tgt_to_input_size=False,
        perceptual_weight_on_inputs=0.0,
    ):
        """Build the decoder and the perceptual loss module.

        Args:
            decoder_config: Passed to ``instantiate_from_config`` to create
                the decoder (see ``init_decoder``).
            perceptual_weight: Weight of the LPIPS term between the decoded
                target latents and the decoded predicted latents. The term
                is skipped unless this is ``> 0``.
            latent_weight: Weight of the mean squared latent error. Only
                applied when ``perceptual_weight > 0`` (see ``forward``).
            scale_input_to_tgt_size: If true, ``image_inputs`` are resized to
                the decoded prediction's size before the input-image LPIPS
                term. Takes precedence over ``scale_tgt_to_input_size``.
            scale_tgt_to_input_size: If true (and the flag above is false),
                the decoded prediction is resized to ``image_inputs``' size.
            perceptual_weight_on_inputs: Weight of the LPIPS term between
                ``image_inputs`` and the decoded prediction. The term is
                skipped unless this is ``> 0``.
        """
        super().__init__()
        self.scale_input_to_tgt_size = scale_input_to_tgt_size
        self.scale_tgt_to_input_size = scale_tgt_to_input_size
        self.init_decoder(decoder_config)
        # The perceptual network is switched to eval mode at construction.
        self.perceptual_loss = LPIPS().eval()
        self.perceptual_weight = perceptual_weight
        self.latent_weight = latent_weight
        self.perceptual_weight_on_inputs = perceptual_weight_on_inputs

    def init_decoder(self, config):
        """Instantiate ``self.decoder`` from *config* and drop its encoder.

        Only ``decoder.decode`` is used by this loss, so an ``encoder``
        attribute, if the instantiated object has one, is deleted.
        """
        self.decoder = instantiate_from_config(config)
        if hasattr(self.decoder, "encoder"):
            del self.decoder.encoder

    @staticmethod
    def _resize_bicubic(images, size):
        """Resize *images* to *size* with antialiased bicubic interpolation."""
        return torch.nn.functional.interpolate(
            images,
            size,
            mode="bicubic",
            antialias=True,
        )

    def forward(self, latent_inputs, latent_predictions, image_inputs, split="train"):
        """Compute the combined loss and a dict of detached log values.

        Args:
            latent_inputs: Target latents.
            latent_predictions: Predicted latents; must be subtractable from
                ``latent_inputs``.
            image_inputs: Reference images, only read when
                ``perceptual_weight_on_inputs > 0``.
            split: Prefix used for the keys of the returned log dict.

        Returns:
            ``(loss, log)`` where ``log`` maps ``"{split}/latent_l2_loss"``,
            and, when the corresponding term is enabled,
            ``"{split}/perceptual_loss"`` and
            ``"{split}/perceptual_loss_on_inputs"`` to detached means.
        """
        log = {}

        loss = (latent_inputs - latent_predictions) ** 2
        log[f"{split}/latent_l2_loss"] = loss.mean().detach()

        image_reconstructions = None

        # NOTE(review): when perceptual_weight <= 0 the latent term below is
        # never averaged nor scaled by latent_weight, so `loss` keeps the
        # element-wise shape. Preserved as-is for backward compatibility;
        # confirm whether that is intended.
        if self.perceptual_weight > 0.0:
            image_reconstructions = self.decoder.decode(latent_predictions)
            image_targets = self.decoder.decode(latent_inputs)
            perceptual_term = self.perceptual_loss(
                image_targets.contiguous(), image_reconstructions.contiguous()
            ).mean()
            loss = (
                self.latent_weight * loss.mean()
                + self.perceptual_weight * perceptual_term
            )
            log[f"{split}/perceptual_loss"] = perceptual_term.detach()

        if self.perceptual_weight_on_inputs > 0.0:
            # Decode lazily: reuse the reconstruction from the branch above
            # instead of running the decoder a second time on the same
            # latents and discarding the result.
            if image_reconstructions is None:
                image_reconstructions = self.decoder.decode(latent_predictions)

            # shape[2:] is used as the target size — presumably (N, C, H, W)
            # image tensors; verify against callers.
            if self.scale_input_to_tgt_size:
                image_inputs = self._resize_bicubic(
                    image_inputs, image_reconstructions.shape[2:]
                )
            elif self.scale_tgt_to_input_size:
                image_reconstructions = self._resize_bicubic(
                    image_reconstructions, image_inputs.shape[2:]
                )

            input_perceptual_term = self.perceptual_loss(
                image_inputs.contiguous(), image_reconstructions.contiguous()
            ).mean()
            loss = loss + self.perceptual_weight_on_inputs * input_perceptual_term
            log[f"{split}/perceptual_loss_on_inputs"] = input_perceptual_term.detach()

        return loss, log