Spaces:
Runtime error
Runtime error
# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved. | |
# | |
# This work is made available under the Nvidia Source Code License-NC. | |
# To view a copy of this license, check out LICENSE.md | |
import torch.nn as nn | |
class FeatureMatchingLoss(nn.Module):
    r"""Compute feature matching loss.

    Compares discriminator features extracted from fake images against the
    corresponding features extracted from real images using an element-wise
    criterion. Per-feature losses are summed within each discriminator and
    the total is divided by the number of discriminators.

    Args:
        criterion (str): ``'l1'`` selects :class:`torch.nn.L1Loss`;
            ``'l2'`` or ``'mse'`` selects :class:`torch.nn.MSELoss`.

    Raises:
        ValueError: If ``criterion`` is not one of the names above.
    """

    def __init__(self, criterion='l1'):
        super().__init__()
        if criterion == 'l1':
            self.criterion = nn.L1Loss()
        elif criterion in ('l2', 'mse'):
            self.criterion = nn.MSELoss()
        else:
            raise ValueError('Criterion %s is not recognized' % criterion)

    def forward(self, fake_features, real_features):
        r"""Return the feature matching loss between fake and real features.

        Args:
            fake_features (list of lists): Discriminator features of fake
                images; ``fake_features[i][j]`` is the ``j``-th feature of
                the ``i``-th discriminator.
            real_features (list of lists): Discriminator features of real
                images, indexed the same way as ``fake_features``. They are
                detached, so no gradient flows through the real branch.
        Returns:
            (tensor): Sum of ``criterion(fake, real)`` over all feature pairs,
            each term weighted by ``1 / len(fake_features)``.
        Raises:
            ValueError: If ``fake_features`` is empty.
        """
        num_d = len(fake_features)
        if num_d == 0:
            # Nothing to average over, and no reference tensor to build the
            # accumulator from; fail with a clear message instead of a
            # ZeroDivisionError.
            raise ValueError('fake_features must not be empty')
        dis_weight = 1.0 / num_d
        # new_tensor gives the accumulator the same dtype/device as the
        # features without needing a separate torch import.
        loss = fake_features[0][0].new_tensor(0)
        for i, fake_feats in enumerate(fake_features):
            for j, fake_feat in enumerate(fake_feats):
                # Detach the real features: only the generator side
                # (fake features) should receive gradients from this loss.
                tmp_loss = self.criterion(fake_feat,
                                          real_features[i][j].detach())
                loss += dis_weight * tmp_loss
        return loss