import torch.nn as nn
import torch.nn.functional as F
import torch as th
import datetime
import os
import time
import timeit
import copy
import numpy as np
from torch.nn import ModuleList
from torch.nn import Conv2d
from torch.nn import LeakyReLU
class PixelwiseNorm(th.nn.Module):
    """
    Pixel-wise feature normalization.

    Every spatial position of the input is divided by the root mean square
    of its values taken across dim 1 (the channel dimension).
    """

    def __init__(self):
        super(PixelwiseNorm, self).__init__()

    def forward(self, x, alpha=1e-8):
        """
        forward pass of the module
        :param x: input activations volume
        :param alpha: small number for numerical stability
        :return: y => pixel normalized activations
        """
        # root mean square over dim 1; keepdim so it broadcasts back over x
        rms = x.pow(2.).mean(dim=1, keepdim=True).add(alpha).sqrt()  # [N1HW]
        return x / rms
class MinibatchStdDev(th.nn.Module):
    """
    Minibatch standard deviation layer for the discriminator
    """

    def __init__(self):
        """
        derived class constructor
        """
        super(MinibatchStdDev, self).__init__()

    def forward(self, x, alpha=1e-8):
        """
        forward pass of the layer
        :param x: input activation volume (unpacked as B x C x H x W)
        :param alpha: small number for numerical stability
        :return: y => x appended with standard deviation constant map
                 (B x (C + 1) x H x W)
        """
        batch_size, _, height, width = x.shape

        # [B x C x H x W] Subtract mean over batch.
        y = x - x.mean(dim=0, keepdim=True)

        # [C x H x W] Calc standard deviation over batch
        y = th.sqrt(y.pow(2.).mean(dim=0, keepdim=False) + alpha)

        # [1 x 1 x 1 x 1] Take average over feature_maps and pixels.
        y = y.mean().view(1, 1, 1, 1)

        # [B x 1 x H x W] Replicate over batch and pixels.
        y = y.repeat(batch_size, 1, height, width)

        # [B x (C + 1) x H x W] Append as new feature_map.
        y = th.cat([x, y], 1)

        # return the computed values:
        return y
# ==========================================================
# Equalized learning rate blocks:
# extending Conv2D and Deconv2D layers for equalized learning rate logic
# ==========================================================
class _equalized_conv2d(th.nn.Module):
    """ conv2d with the concept of equalized learning rate
        :param c_in: input channels
        :param c_out: output channels
        :param k_size: kernel size (h, w) should be a tuple or a single integer
        :param stride: stride for conv
        :param pad: padding
        :param bias: whether to use bias or not
    """

    def __init__(self, c_in, c_out, k_size, stride=1, pad=0, bias=True):
        """ constructor for the class """
        from torch.nn.modules.utils import _pair
        from numpy import sqrt, prod

        # Module.__init__ must run before any Parameter is assigned
        super(_equalized_conv2d, self).__init__()

        # define the weight and bias if to be used
        self.weight = th.nn.Parameter(th.nn.init.normal_(
            th.empty(c_out, c_in, *_pair(k_size))
        ))

        self.use_bias = bias
        self.stride = stride
        self.pad = pad

        if self.use_bias:
            self.bias = th.nn.Parameter(th.zeros(c_out))

        fan_in = prod(_pair(k_size)) * c_in  # value of fan_in
        # keep the scale a plain Python float rather than a numpy scalar
        self.scale = float(sqrt(2) / sqrt(fan_in))

    def forward(self, x):
        """
        forward pass of the network
        :param x: input
        :return: y => output
        """
        from torch.nn.functional import conv2d

        return conv2d(input=x,
                      weight=self.weight * self.scale,  # scale the weight on runtime
                      bias=self.bias if self.use_bias else None,
                      stride=self.stride,
                      padding=self.pad)

    def extra_repr(self):
        """ :return: the weight shape as a comma separated string """
        return ", ".join(map(str, self.weight.shape))
class _equalized_deconv2d(th.nn.Module):
    """ Transpose convolution using the equalized learning rate
        :param c_in: input channels
        :param c_out: output channels
        :param k_size: kernel size
        :param stride: stride for convolution transpose
        :param pad: padding
        :param bias: whether to use bias or not
    """

    def __init__(self, c_in, c_out, k_size, stride=1, pad=0, bias=True):
        """ constructor for the class """
        from torch.nn.modules.utils import _pair
        from numpy import sqrt

        # Module.__init__ must run before any Parameter is assigned
        super(_equalized_deconv2d, self).__init__()

        # define the weight and bias if to be used
        # (transpose convolution weights are laid out as c_in x c_out x kh x kw)
        self.weight = th.nn.Parameter(th.nn.init.normal_(
            th.empty(c_in, c_out, *_pair(k_size))
        ))

        self.use_bias = bias
        self.stride = stride
        self.pad = pad

        if self.use_bias:
            self.bias = th.nn.Parameter(th.zeros(c_out))

        fan_in = c_in  # value of fan_in for deconv
        # keep the scale a plain Python float rather than a numpy scalar
        self.scale = float(sqrt(2) / sqrt(fan_in))

    def forward(self, x):
        """
        forward pass of the layer
        :param x: input
        :return: y => output
        """
        from torch.nn.functional import conv_transpose2d

        return conv_transpose2d(input=x,
                                weight=self.weight * self.scale,  # scale the weight on runtime
                                bias=self.bias if self.use_bias else None,
                                stride=self.stride,
                                padding=self.pad)

    def extra_repr(self):
        """ :return: the weight shape as a comma separated string """
        return ", ".join(map(str, self.weight.shape))
# basic block of the encoding part of the generator
# Convolution block: builds the convolution layers, a PixelwiseNorm and a
# LeakyReLU(0.2) in the constructor.
# NOTE(review): this block has lost its indentation and its docstring quotes,
# and the body of forward() is missing -- it must be restored from the original
# file before the class can be used.
class conv_block(nn.Module):
# NOTE(review): the description below says "two convolution layers", but the
# use_eql path defines three (conv_1, conv_2, conv_3) -- confirm which is right.
Convolution Block
with two convolution layers
# in_ch / out_ch: channel counts passed to the convolution layers below
# use_eql: selects the equalized learning rate layers over plain Conv2d
def __init__(self, in_ch, out_ch,use_eql=True):
super(conv_block, self).__init__()
if use_eql:
# equalized learning rate variant: a 1x1 convolution followed by two 3x3 ones
self.conv_1= _equalized_conv2d(in_ch, out_ch, (1, 1),
pad=0, bias=True)
self.conv_2 = _equalized_conv2d(out_ch, out_ch, (3, 3),
pad=1, bias=True)
self.conv_3 = _equalized_conv2d(out_ch, out_ch, (3, 3),
pad=1, bias=True)
# NOTE(review): an `else:` appears to be missing here -- as written, the plain
# Conv2d layers below would overwrite conv_1 / conv_2 assigned above.
# The plain variant uses two 3x3 convolutions and defines no conv_3.
self.conv_1 = Conv2d(in_ch, out_ch, (3, 3),
padding=1, bias=True)
self.conv_2 = Conv2d(out_ch, out_ch, (3, 3),
padding=1, bias=True)
# pixel_wise feature normalizer:
self.pixNorm = PixelwiseNorm()
# leaky_relu:
self.lrelu = LeakyReLU(0.2)
def forward(self, x):
forward pass of the block
:param x: input
:return: y => output
from torch.nn.functional import interpolate
# disabled upsampling step kept from the original source:
#y = interpolate(x, scale_factor=2)
# NOTE(review): `y` is never assigned in the visible code, so this return
# would raise NameError -- the statements computing y from x are missing.
return y
# basic up convolution block of the encoding part of the generator
# Up-convolution block: builds the same layer set as conv_block in the
# constructor; forward() upsamples its input by a factor of 2 (bilinear).
# NOTE(review): this block has lost its indentation and its docstring quotes,
# and the part of forward() that computes the output is missing -- it must be
# restored from the original file before the class can be used.
class up_conv(nn.Module):
Up Convolution Block
# in_ch / out_ch: channel counts passed to the convolution layers below
# use_eql: selects the equalized learning rate layers over plain Conv2d
def __init__(self, in_ch, out_ch,use_eql=True):
super(up_conv, self).__init__()
if use_eql:
# equalized learning rate variant: a 1x1 convolution followed by two 3x3 ones
self.conv_1= _equalized_conv2d(in_ch, out_ch, (1, 1),
pad=0, bias=True)
self.conv_2 = _equalized_conv2d(out_ch, out_ch, (3, 3),
pad=1, bias=True)
self.conv_3 = _equalized_conv2d(out_ch, out_ch, (3, 3),
pad=1, bias=True)
# NOTE(review): an `else:` appears to be missing here -- as written, the plain
# Conv2d layers below would overwrite conv_1 / conv_2 assigned above.
# The plain variant uses two 3x3 convolutions and defines no conv_3.
self.conv_1 = Conv2d(in_ch, out_ch, (3, 3),
padding=1, bias=True)
self.conv_2 = Conv2d(out_ch, out_ch, (3, 3),
padding=1, bias=True)
# pixel_wise feature normalizer:
self.pixNorm = PixelwiseNorm()
# leaky_relu:
self.lrelu = LeakyReLU(0.2)
def forward(self, x):
forward pass of the block
:param x: input
:return: y => output
from torch.nn.functional import interpolate
# double the spatial size of the input with bilinear interpolation
x = interpolate(x, scale_factor=2, mode="bilinear")
# NOTE(review): `y` is never assigned in the visible code, so this return
# would raise NameError -- the statements computing y from the upsampled x
# are missing.
return y
class DisFinalBlock(th.nn.Module):
    """ Final block for the Discriminator """

    def __init__(self, in_channels, use_eql=True):
        """
        constructor of the class
        :param in_channels: number of input channels
        :param use_eql: whether to use equalized learning rate
        """
        from torch.nn import LeakyReLU
        from torch.nn import Conv2d

        super(DisFinalBlock, self).__init__()

        # declare the required modules for forward pass
        # (the minibatch std-dev layer appends one channel, hence in_channels + 1)
        self.batch_discriminator = MinibatchStdDev()

        if use_eql:
            self.conv_1 = _equalized_conv2d(in_channels + 1, in_channels, (3, 3),
                                            pad=1, bias=True)
            self.conv_2 = _equalized_conv2d(in_channels, in_channels, (4, 4),
                                            stride=2, pad=1, bias=True)
            # final layer emulates the fully connected layer
            self.conv_3 = _equalized_conv2d(in_channels, 1, (1, 1), bias=True)
        else:
            # modules required:
            self.conv_1 = Conv2d(in_channels + 1, in_channels, (3, 3), padding=1, bias=True)
            # NOTE(review): unlike the equalized branch, this 4x4 convolution
            # uses the default stride and padding -- confirm this is intended.
            self.conv_2 = Conv2d(in_channels, in_channels, (4, 4), bias=True)
            # final conv layer emulates a fully connected layer
            self.conv_3 = Conv2d(in_channels, 1, (1, 1), bias=True)

        # leaky_relu:
        self.lrelu = LeakyReLU(0.2)

    def forward(self, x):
        """
        forward pass of the FinalBlock
        :param x: input
        :return: y => output (single-channel map of raw discriminator scores)
        """
        # minibatch_std_dev layer
        y = self.batch_discriminator(x)

        # define the computations
        y = self.lrelu(self.conv_1(y))
        y = self.lrelu(self.conv_2(y))

        # fully connected layer
        y = self.conv_3(y)  # This layer has linear activation

        # the raw scores are returned as a map; they are not flattened here
        return y
class DisGeneralConvBlock(th.nn.Module):
    """ General block in the discriminator """

    def __init__(self, in_channels, out_channels, use_eql=True):
        """
        constructor of the class
        :param in_channels: number of input channels
        :param out_channels: number of output channels
        :param use_eql: whether to use equalized learning rate
        """
        from torch.nn import AvgPool2d, LeakyReLU
        from torch.nn import Conv2d

        super(DisGeneralConvBlock, self).__init__()

        if use_eql:
            self.conv_1 = _equalized_conv2d(in_channels, in_channels, (3, 3),
                                            pad=1, bias=True)
            self.conv_2 = _equalized_conv2d(in_channels, out_channels, (3, 3),
                                            pad=1, bias=True)
        else:
            # convolutional modules
            self.conv_1 = Conv2d(in_channels, in_channels, (3, 3),
                                 padding=1, bias=True)
            self.conv_2 = Conv2d(in_channels, out_channels, (3, 3),
                                 padding=1, bias=True)

        self.downSampler = AvgPool2d(2)  # downsampler

        # leaky_relu:
        self.lrelu = LeakyReLU(0.2)

    def forward(self, x):
        """
        forward pass of the module
        :param x: input
        :return: y => output (two conv + leaky relu stages, then 2x average pooling)
        """
        # define the computations
        y = self.lrelu(self.conv_1(x))
        y = self.lrelu(self.conv_2(y))
        y = self.downSampler(y)
        return y
class from_rgb(nn.Module):
    """
    The RGB image is transformed into a multi-channel feature map to be concatenated with
    the feature map with the same number of channels in the network
    """

    def __init__(self, outchannels, use_eql=True):
        """
        constructor of the class
        :param outchannels: number of channels of the produced feature map
        :param use_eql: whether to use equalized learning rate
        """
        super(from_rgb, self).__init__()

        # 1x1 convolution from the 3 input channels to outchannels
        if use_eql:
            self.conv_1 = _equalized_conv2d(3, outchannels, (1, 1), bias=True)
        else:
            self.conv_1 = nn.Conv2d(3, outchannels, (1, 1), bias=True)

        # pixel_wise feature normalizer:
        self.pixNorm = PixelwiseNorm()

        # leaky_relu:
        self.lrelu = LeakyReLU(0.2)

    def forward(self, x):
        """
        forward pass of the block
        :param x: input
        :return: y => output
        """
        # conv -> leaky relu -> pixel-wise normalization
        y = self.pixNorm(self.lrelu(self.conv_1(x)))
        return y
class to_rgb(nn.Module):
    """
    The multi-channel feature map is converted into RGB image for input to the discriminator
    """

    def __init__(self, inchannels, use_eql=True):
        """
        constructor of the class
        :param inchannels: number of channels of the incoming feature map
        :param use_eql: whether to use equalized learning rate
        """
        super(to_rgb, self).__init__()

        # 1x1 convolution from inchannels down to 3 output channels
        if use_eql:
            self.conv_1 = _equalized_conv2d(inchannels, 3, (1, 1), bias=True)
        else:
            self.conv_1 = nn.Conv2d(inchannels, 3, (1, 1), bias=True)

    def forward(self, x):
        """
        forward pass of the block
        :param x: input
        :return: y => output
        """
        # linear activation: the convolution output is returned as is
        y = self.conv_1(x)
        return y
class Flatten(nn.Module):
    """
    Flattens every dimension after the first one into a single dimension
    """

    def forward(self, x):
        """
        forward pass of the layer
        :param x: input
        :return: x reshaped to (x.size(0), -1)
        """
        # reshape (rather than view) also accepts non-contiguous inputs,
        # e.g. the result of a transpose or permute
        return x.reshape(x.size(0), -1)
class CCA(nn.Module):
    """
    CCA Block

    Rescales the channels of x with a sigmoid gate computed from the
    globally average-pooled x and g, then applies a ReLU.
    """

    def __init__(self, F_g, F_x):
        """
        constructor of the class
        :param F_g: number of channels of g
        :param F_x: number of channels of x
        """
        # Module.__init__ must run before sub-modules are assigned
        super(CCA, self).__init__()
        self.mlp_x = nn.Sequential(
            nn.Linear(F_x, F_x))
        self.mlp_g = nn.Sequential(
            nn.Linear(F_g, F_x))
        self.relu = nn.ReLU(inplace=True)

    def forward(self, g, x):
        """
        forward pass of the block
        :param g: gating feature map, F_g channels on dim 1, 4-D
        :param x: feature map to be rescaled, F_x channels on dim 1, 4-D
        :return: out => relu(x * channel gate), same shape as x
        """
        # channel-wise attention
        # global average pooling straight to (B, C): nn.Linear works on the
        # last dimension, so the pooled maps must not stay (B, C, 1, 1)
        avg_pool_x = x.mean(dim=(2, 3))
        channel_att_x = self.mlp_x(avg_pool_x)
        avg_pool_g = g.mean(dim=(2, 3))
        channel_att_g = self.mlp_g(avg_pool_g)
        channel_att_sum = (channel_att_x + channel_att_g) / 2.0
        # (B, F_x) -> (B, F_x, 1, 1) -> broadcast to the shape of x
        scale = th.sigmoid(channel_att_sum).unsqueeze(2).unsqueeze(3).expand_as(x)
        x_after_channel = x * scale
        out = self.relu(x_after_channel)
        return out