Spaces:
Runtime error
Runtime error
# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# This work is made available under the Nvidia Source Code License-NC.
# To view a copy of this license, check out LICENSE.md
from torch import nn | |
from imaginaire.discriminators.multires_patch import MultiResPatchDiscriminator | |
from imaginaire.discriminators.residual import ResDiscriminator | |
class Discriminator(nn.Module):
    r"""MUNIT discriminator.

    Holds one discriminator per image domain (``discriminator_a`` and
    ``discriminator_b``). Both are either multi-resolution patch
    discriminators, like in the original implementation, or global residual
    discriminators.

    Args:
        dis_cfg (obj): Discriminator definition part of the yaml config file.
            All of its attributes are forwarded as keyword arguments to the
            selected discriminator class. Its optional ``patch_wise``
            attribute (treated as ``True`` when absent) selects the
            multi-resolution patch discriminator.
        data_cfg (obj): Data definition part of the yaml config file. It is
            not read by this class.
    """

    def __init__(self, dis_cfg, data_cfg):
        super().__init__()
        if getattr(dis_cfg, 'patch_wise', True):
            # Use the multi-resolution patch discriminator. It works better
            # for scene images and when you want to preserve pixel-wise
            # correspondence during translation.
            dis_cls = MultiResPatchDiscriminator
        else:
            # Use the global residual discriminator. It works better if
            # images have a single centered object (e.g., animal faces,
            # shoes).
            dis_cls = ResDiscriminator
        # Build A before B so the construction order matches the original.
        self.discriminator_a = dis_cls(**vars(dis_cfg))
        self.discriminator_b = dis_cls(**vars(dis_cfg))

    def forward(self, data, net_G_output, gan_recon=False, real=True):
        r"""Returns the output of the discriminator.

        Each image batch is scored by the discriminator of the domain the
        images belong to: ``discriminator_b`` scores ``images_ab``,
        ``images_b`` and ``images_bb``; ``discriminator_a`` scores
        ``images_ba``, ``images_a`` and ``images_aa``. Every discriminator
        call returns three values, of which the third is discarded.

        Args:
            data (dict):
              - images_a (tensor) : Images in domain A.
              - images_b (tensor) : Images in domain B.
            net_G_output (dict):
              - images_ab (tensor) : Images translated from domain A to B by
                the generator.
              - images_ba (tensor) : Images translated from domain B to A by
                the generator.
              - images_aa (tensor) : Reconstructed images in domain A.
              - images_bb (tensor) : Reconstructed images in domain B.
            gan_recon (bool): If ``True``, also classifies reconstructed
                images.
            real (bool): If ``True``, also classifies real images. Otherwise
                it only classifies generated images to save computation
                during the generator update.
        Returns:
            (dict):
              - out_ab (tensor): Output of the discriminator for images
                translated from domain A to B by the generator.
              - out_ba (tensor): Output of the discriminator for images
                translated from domain B to A by the generator.
              - fea_ab (tensor): Intermediate features of the discriminator
                for images translated from domain A to B by the generator.
              - fea_ba (tensor): Intermediate features of the discriminator
                for images translated from domain B to A by the generator.
              - out_a (tensor): Output of the discriminator for images
                in domain A. Only present if ``real`` is ``True``.
              - out_b (tensor): Output of the discriminator for images
                in domain B. Only present if ``real`` is ``True``.
              - fea_a (tensor): Intermediate features of the discriminator
                for images in domain A. Only present if ``real`` is ``True``.
              - fea_b (tensor): Intermediate features of the discriminator
                for images in domain B. Only present if ``real`` is ``True``.
              - out_aa (tensor): Output of the discriminator for
                reconstructed images in domain A. Only present if
                ``gan_recon`` is ``True``.
              - out_bb (tensor): Output of the discriminator for
                reconstructed images in domain B. Only present if
                ``gan_recon`` is ``True``.
              - fea_aa (tensor): Intermediate features of the discriminator
                for reconstructed images in domain A. Only present if
                ``gan_recon`` is ``True``.
              - fea_bb (tensor): Intermediate features of the discriminator
                for reconstructed images in domain B. Only present if
                ``gan_recon`` is ``True``.
        """
        # Translated images are always classified.
        out_ab, fea_ab, _ = self.discriminator_b(net_G_output['images_ab'])
        out_ba, fea_ba, _ = self.discriminator_a(net_G_output['images_ba'])
        output = dict(out_ba=out_ba, out_ab=out_ab,
                      fea_ba=fea_ba, fea_ab=fea_ab)
        if real:
            out_a, fea_a, _ = self.discriminator_a(data['images_a'])
            out_b, fea_b, _ = self.discriminator_b(data['images_b'])
            output.update(out_a=out_a, out_b=out_b,
                          fea_a=fea_a, fea_b=fea_b)
        if gan_recon:
            out_aa, fea_aa, _ = \
                self.discriminator_a(net_G_output['images_aa'])
            out_bb, fea_bb, _ = \
                self.discriminator_b(net_G_output['images_bb'])
            output.update(out_aa=out_aa, out_bb=out_bb,
                          fea_aa=fea_aa, fea_bb=fea_bb)
        return output