venite's picture
initial
f670afc
# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# This work is made available under the Nvidia Source Code License-NC.
# To view a copy of this license, check out LICENSE.md
from torch import nn
from imaginaire.discriminators.multires_patch import MultiResPatchDiscriminator
from imaginaire.discriminators.residual import ResDiscriminator
class Discriminator(nn.Module):
    r"""Pair of per-domain discriminators for MUNIT.

    One discriminator is created for domain A and another for domain B.
    Both are of the same kind: a multi-resolution patch discriminator
    (the default, as in the original implementation) or a global residual
    discriminator, chosen by ``dis_cfg.patch_wise``.

    Args:
        dis_cfg (obj): Discriminator definition part of the yaml config
            file. All of its attributes are forwarded as keyword arguments
            to the selected discriminator class.
        data_cfg (obj): Data definition part of the yaml config file. It is
            accepted for interface compatibility and is not read here.
    """

    def __init__(self, dis_cfg, data_cfg):
        super().__init__()
        # A missing ``patch_wise`` attribute means "patch-wise".
        patch_wise = getattr(dis_cfg, 'patch_wise', True)
        if patch_wise:
            # Multi-resolution patch discriminator: works better for scene
            # images and when pixel-wise correspondence should be preserved
            # during translation.
            dis_cls = MultiResPatchDiscriminator
        else:
            # Global residual discriminator: works better if images have a
            # single centered object (e.g., animal faces, shoes).
            dis_cls = ResDiscriminator
        # Domain A is built before domain B, each from its own expansion of
        # the config attributes.
        self.discriminator_a = dis_cls(**vars(dis_cfg))
        self.discriminator_b = dis_cls(**vars(dis_cfg))

    def forward(self, data, net_G_output, gan_recon=False, real=True):
        r"""Run the domain discriminators and collect their outputs.

        Images that live in domain A (``images_ba``, ``images_a``,
        ``images_aa``) are scored by ``discriminator_a``; images that live
        in domain B (``images_ab``, ``images_b``, ``images_bb``) are scored
        by ``discriminator_b``. Each discriminator call yields three values;
        the first two are kept as ``out_*`` and ``fea_*`` and the third is
        discarded.

        Args:
            data (dict): Only read when ``real`` is ``True``.
              - images_a (tensor) : Images in domain A.
              - images_b (tensor) : Images in domain B.
            net_G_output (dict):
              - images_ab (tensor) : Images translated from domain A to B
                by the generator.
              - images_ba (tensor) : Images translated from domain B to A
                by the generator.
              - images_aa (tensor) : Reconstructed images in domain A.
                Only read when ``gan_recon`` is ``True``.
              - images_bb (tensor) : Reconstructed images in domain B.
                Only read when ``gan_recon`` is ``True``.
            gan_recon (bool): If ``True``, also classifies reconstructed
                images.
            real (bool): If ``True``, also classifies real images.
                Otherwise only generated images are classified, which
                saves computation during the generator update.

        Returns:
            (dict):
              - out_ba (tensor): Output of the discriminator for images
                translated from domain B to A by the generator.
              - out_ab (tensor): Output of the discriminator for images
                translated from domain A to B by the generator.
              - fea_ba (tensor): Intermediate features of the discriminator
                for images translated from domain B to A.
              - fea_ab (tensor): Intermediate features of the discriminator
                for images translated from domain A to B.
              - out_a, out_b, fea_a, fea_b: Outputs and intermediate
                features for real images in domains A and B. Present only
                when ``real`` is ``True``.
              - out_aa, out_bb, fea_aa, fea_bb: Outputs and intermediate
                features for reconstructed images in domains A and B.
                Present only when ``gan_recon`` is ``True``.
        """
        # Translated images are always classified. A->B results are judged
        # by the domain-B discriminator and B->A by the domain-A one.
        score_ab, feat_ab, _ = self.discriminator_b(net_G_output['images_ab'])
        score_ba, feat_ba, _ = self.discriminator_a(net_G_output['images_ba'])
        output = {
            'out_ba': score_ba,
            'out_ab': score_ab,
            'fea_ba': feat_ba,
            'fea_ab': feat_ab,
        }

        if real:
            # Real samples from each domain.
            score_a, feat_a, _ = self.discriminator_a(data['images_a'])
            score_b, feat_b, _ = self.discriminator_b(data['images_b'])
            output['out_a'] = score_a
            output['out_b'] = score_b
            output['fea_a'] = feat_a
            output['fea_b'] = feat_b

        if gan_recon:
            # Within-domain reconstructions.
            score_aa, feat_aa, _ = \
                self.discriminator_a(net_G_output['images_aa'])
            score_bb, feat_bb, _ = \
                self.discriminator_b(net_G_output['images_bb'])
            output['out_aa'] = score_aa
            output['out_bb'] = score_bb
            output['fea_aa'] = feat_aa
            output['fea_bb'] = feat_bb

        return output