|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
import torch |
|
|
|
|
|
class Discriminator(nn.Module):
    """Convolutional discriminator over three channel-stacked images.

    The three tensors passed to :meth:`forward` are concatenated along
    dim 1, so the first convolution expects ``in_channels * 3`` channels.
    The network is four stride-2, kernel-4 convolution stages (the first
    without InstanceNorm, all followed by LeakyReLU(0.2)), then a
    left/top zero-pad and a final bias-free convolution down to a single
    channel. No activation is applied to the output.

    For inputs whose height and width are multiples of 16, the output shape
    is ``(N, 1, H // 16, W // 16)``.

    Args:
        in_channels: channel count of *each* of the three input images.
    """

    def __init__(self, in_channels=3):
        super().__init__()

        # (input channels, output channels, apply InstanceNorm) per stage.
        # The order here determines both the Sequential indices (and hence
        # state_dict keys) and the order in which parameters are initialised.
        stages = (
            (in_channels * 3, 64, False),
            (64, 128, True),
            (128, 256, True),
            (256, 512, True),
        )

        layers = []
        for n_in, n_out, use_norm in stages:
            # Kernel 4, stride 2, padding 1 halves the spatial size.
            layers.append(nn.Conv2d(n_in, n_out, 4, stride=2, padding=1))
            if use_norm:
                layers.append(nn.InstanceNorm2d(n_out))
            layers.append(nn.LeakyReLU(0.2, inplace=True))

        # Pad one pixel on the left and top so the final kernel-4,
        # padding-1 convolution preserves the last stage's spatial size.
        layers += [
            nn.ZeroPad2d((1, 0, 1, 0)),
            nn.Conv2d(512, 1, 4, padding=1, bias=False),
        ]

        self.model = nn.Sequential(*layers)

    def forward(self, img_out, img_l, img_ref):
        """Score the channel-wise concatenation of the three images.

        Args:
            img_out, img_l, img_ref: image tensors with ``in_channels``
                channels each, concatenated on dim 1 in this order.
                NOTE(review): assumes batched (N, C, H, W) tensors sharing
                N, H and W -- confirm against callers.

        Returns:
            The single-channel, un-activated output of ``self.model``.
        """
        stacked = torch.cat([img_out, img_l, img_ref], dim=1)
        return self.model(stacked)
|
|
|
|