sat3density / imaginaire /losses /feature_matching.py
venite's picture
initial
f670afc
# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# This work is made available under the Nvidia Source Code License-NC.
# To view a copy of this license, check out LICENSE.md
import torch.nn as nn
class FeatureMatchingLoss(nn.Module):
    r"""Compute feature matching loss.

    The loss compares discriminator features of fake images against the
    corresponding features of real images with an element-wise criterion,
    sums the per-feature losses, and weights every discriminator by
    ``1 / number_of_discriminators``.

    Args:
        criterion (str): Distance used between features. ``'l1'`` selects
            ``nn.L1Loss``; ``'l2'`` or ``'mse'`` selects ``nn.MSELoss``.

    Raises:
        ValueError: If ``criterion`` is not one of the recognized names.
    """

    def __init__(self, criterion='l1'):
        super(FeatureMatchingLoss, self).__init__()
        if criterion == 'l1':
            self.criterion = nn.L1Loss()
        elif criterion in ('l2', 'mse'):
            self.criterion = nn.MSELoss()
        else:
            raise ValueError('Criterion %s is not recognized' % criterion)

    def forward(self, fake_features, real_features):
        r"""Return the feature matching loss between fake and real features.

        Args:
            fake_features (list of lists): Discriminator features of fake
                images. The outer list is indexed by discriminator, the
                inner list by feature.
            real_features (list of lists): Discriminator features of real
                images, laid out like ``fake_features``. They are detached
                before use, so no gradient flows back through them.
        Returns:
            (tensor): Scalar loss value, created with the dtype and device
            of the first available fake feature.
        Raises:
            ValueError: If ``fake_features`` holds no feature tensor at all,
                or if ``real_features`` has fewer entries than
                ``fake_features`` (per discriminator or per feature).
        """
        num_d = len(fake_features)
        if len(real_features) < num_d:
            raise ValueError(
                'real_features has %d discriminator entries, expected at '
                'least %d' % (len(real_features), num_d))

        # Any fake feature works as a dtype/device template for the
        # accumulator; take the first one that exists instead of assuming
        # that fake_features[0][0] is present.
        reference = next(
            (feats[0] for feats in fake_features if len(feats) > 0), None)
        if reference is None:
            # Also covers num_d == 0, which would otherwise divide by zero.
            raise ValueError(
                'fake_features must contain at least one feature tensor')

        dis_weight = 1.0 / num_d
        loss = reference.new_tensor(0)
        for idx, (fake_feats, real_feats) in enumerate(
                zip(fake_features, real_features)):
            if len(real_feats) < len(fake_feats):
                raise ValueError(
                    'real_features[%d] has %d features, expected at least '
                    '%d' % (idx, len(real_feats), len(fake_feats)))
            for fake, real in zip(fake_feats, real_feats):
                # Real features act as fixed targets: detach them so only
                # the fake branch receives gradients.
                loss += dis_weight * self.criterion(fake, real.detach())
        return loss