import torch.nn as nn
import torch.nn.functional as F
import torch as th
from torch.nn import Conv2d
from torch.nn import LeakyReLU
# PixelwiseNorm is used in place of BatchNorm
class PixelwiseNorm(th.nn.Module):
def __init__(self):
super(PixelwiseNorm, self).__init__()
def forward(self, x, alpha=1e-8):
"""
forward pass of the module
:param x: input activations volume
:param alpha: small number for numerical stability
:return: y => pixel normalized activations
"""
y = x.pow(2.).mean(dim=1, keepdim=True).add(alpha).sqrt() # [N1HW]
y = x / y # normalize the input x volume
return y
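# Illustrative shape check for PixelwiseNorm (a minimal sketch; the tensor sizes
# below are arbitrary assumptions, not values used by the original training code).
def _pixelwise_norm_example():
    x = th.randn(2, 8, 16, 16)        # [N x C x H x W]
    y = PixelwiseNorm()(x)            # each pixel's feature vector is divided by its RMS over channels
    return y.shape                    # same shape as x: torch.Size([2, 8, 16, 16])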
class MinibatchStdDev(th.nn.Module):
"""
Minibatch standard deviation layer for the discriminator
"""
def __init__(self):
"""
derived class constructor
"""
super().__init__()
def forward(self, x, alpha=1e-8):
"""
forward pass of the layer
:param x: input activation volume
:param alpha: small number for numerical stability
:return: y => x appended with standard deviation constant map
"""
batch_size, _, height, width = x.shape
# [B x C x H x W] Subtract mean over batch.
y = x - x.mean(dim=0, keepdim=True)
        # [C x H x W] Calc standard deviation over batch (keepdim=False drops the batch dim)
y = th.sqrt(y.pow(2.).mean(dim=0, keepdim=False) + alpha)
# [1] Take average over feature_maps and pixels.
y = y.mean().view(1, 1, 1, 1)
# [B x 1 x H x W] Replicate over group and pixels.
y = y.repeat(batch_size, 1, height, width)
# [B x C x H x W] Append as new feature_map.
y = th.cat([x, y], 1)
# return the computed values:
return y
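# Illustrative shape check for MinibatchStdDev (sketch only; the sizes are assumptions).
def _minibatch_stddev_example():
    x = th.randn(4, 8, 16, 16)
    y = MinibatchStdDev()(x)
    return y.shape                    # one extra constant feature map: torch.Size([4, 9, 16, 16])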
# ==========================================================
# Equalized learning rate blocks:
# extending Conv2D and Deconv2D layers for equalized learning rate logic
# ==========================================================
class _equalized_conv2d(th.nn.Module):
""" conv2d with the concept of equalized learning rate
Args:
:param c_in: input channels
:param c_out: output channels
:param k_size: kernel size (h, w) should be a tuple or a single integer
:param stride: stride for conv
:param pad: padding
:param bias: whether to use bias or not
"""
def __init__(self, c_in, c_out, k_size, stride=1, pad=0, bias=True):
""" constructor for the class """
from torch.nn.modules.utils import _pair
from numpy import sqrt, prod
super().__init__()
# define the weight and bias if to be used
self.weight = th.nn.Parameter(th.nn.init.normal_(
th.empty(c_out, c_in, *_pair(k_size))
))
self.use_bias = bias
self.stride = stride
self.pad = pad
if self.use_bias:
self.bias = th.nn.Parameter(th.FloatTensor(c_out).fill_(0))
fan_in = prod(_pair(k_size)) * c_in # value of fan_in
self.scale = sqrt(2) / sqrt(fan_in)
def forward(self, x):
"""
forward pass of the network
:param x: input
:return: y => output
"""
from torch.nn.functional import conv2d
return conv2d(input=x,
weight=self.weight * self.scale, # scale the weight on runtime
bias=self.bias if self.use_bias else None,
stride=self.stride,
padding=self.pad)
def extra_repr(self):
return ", ".join(map(str, self.weight.shape))
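# Illustrative use of _equalized_conv2d (sketch only; channel counts and sizes are
# assumptions). The stored weights are drawn from a unit normal and rescaled by
# sqrt(2 / fan_in) at every forward pass, which is what equalizes the learning rate.
def _equalized_conv2d_example():
    conv = _equalized_conv2d(8, 16, (3, 3), pad=1, bias=True)
    x = th.randn(2, 8, 32, 32)
    return conv(x).shape              # torch.Size([2, 16, 32, 32])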
class _equalized_deconv2d(th.nn.Module):
""" Transpose convolution using the equalized learning rate
Args:
:param c_in: input channels
:param c_out: output channels
:param k_size: kernel size
:param stride: stride for convolution transpose
:param pad: padding
:param bias: whether to use bias or not
"""
def __init__(self, c_in, c_out, k_size, stride=1, pad=0, bias=True):
""" constructor for the class """
from torch.nn.modules.utils import _pair
from numpy import sqrt
super().__init__()
# define the weight and bias if to be used
self.weight = th.nn.Parameter(th.nn.init.normal_(
th.empty(c_in, c_out, *_pair(k_size))
))
self.use_bias = bias
self.stride = stride
self.pad = pad
if self.use_bias:
self.bias = th.nn.Parameter(th.FloatTensor(c_out).fill_(0))
fan_in = c_in # value of fan_in for deconv
self.scale = sqrt(2) / sqrt(fan_in)
def forward(self, x):
"""
forward pass of the layer
:param x: input
:return: y => output
"""
from torch.nn.functional import conv_transpose2d
return conv_transpose2d(input=x,
weight=self.weight * self.scale, # scale the weight on runtime
bias=self.bias if self.use_bias else None,
stride=self.stride,
padding=self.pad)
def extra_repr(self):
return ", ".join(map(str, self.weight.shape))
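# Illustrative use of _equalized_deconv2d (sketch only; sizes are assumptions).
# With kernel 4, stride 2 and padding 1 the transposed convolution doubles the
# spatial resolution; note that fan_in here is just c_in.
def _equalized_deconv2d_example():
    deconv = _equalized_deconv2d(16, 8, (4, 4), stride=2, pad=1, bias=True)
    x = th.randn(2, 16, 16, 16)
    return deconv(x).shape            # torch.Size([2, 8, 32, 32])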
# Basic convolution block of the encoding part of the generator
class conv_block(nn.Module):
"""
Convolution Block
with two convolution layers
"""
def __init__(self, in_ch, out_ch,use_eql=True):
super(conv_block, self).__init__()
if use_eql:
self.conv_1= _equalized_conv2d(in_ch, out_ch, (1, 1),
pad=0, bias=True)
self.conv_2 = _equalized_conv2d(out_ch, out_ch, (3, 3),
pad=1, bias=True)
self.conv_3 = _equalized_conv2d(out_ch, out_ch, (3, 3),
pad=1, bias=True)
        else:
            # mirror the equalized branch: forward() uses conv_1, conv_2 and conv_3
            self.conv_1 = Conv2d(in_ch, out_ch, (1, 1), padding=0, bias=True)
            self.conv_2 = Conv2d(out_ch, out_ch, (3, 3), padding=1, bias=True)
            self.conv_3 = Conv2d(out_ch, out_ch, (3, 3), padding=1, bias=True)
# pixel_wise feature normalizer:
self.pixNorm = PixelwiseNorm()
# leaky_relu:
self.lrelu = LeakyReLU(0.2)
def forward(self, x):
"""
forward pass of the block
:param x: input
:return: y => output
"""
y=self.conv_1(self.lrelu(self.pixNorm(x)))
residual=y
y=self.conv_2(self.lrelu(self.pixNorm(y)))
y=self.conv_3(self.lrelu(self.pixNorm(y)))
y=y+residual
return y
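# Illustrative shape check for conv_block (sketch only; sizes are assumptions).
# conv_1 is a 1x1 projection to out_ch; the two 3x3 convolutions that follow are
# wrapped in a residual connection, so the spatial size is preserved.
def _conv_block_example():
    block = conv_block(8, 16, use_eql=True)
    x = th.randn(2, 8, 32, 32)
    return block(x).shape             # torch.Size([2, 16, 32, 32])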
# Basic up-convolution block of the decoding part of the generator
class up_conv(nn.Module):
"""
Up Convolution Block
"""
def __init__(self, in_ch, out_ch,use_eql=True):
super(up_conv, self).__init__()
if use_eql:
self.conv_1= _equalized_conv2d(in_ch, out_ch, (1, 1),
pad=0, bias=True)
self.conv_2 = _equalized_conv2d(out_ch, out_ch, (3, 3),
pad=1, bias=True)
self.conv_3 = _equalized_conv2d(out_ch, out_ch, (3, 3),
pad=1, bias=True)
        else:
            # mirror the equalized branch: forward() uses conv_1, conv_2 and conv_3
            self.conv_1 = Conv2d(in_ch, out_ch, (1, 1), padding=0, bias=True)
            self.conv_2 = Conv2d(out_ch, out_ch, (3, 3), padding=1, bias=True)
            self.conv_3 = Conv2d(out_ch, out_ch, (3, 3), padding=1, bias=True)
# pixel_wise feature normalizer:
self.pixNorm = PixelwiseNorm()
# leaky_relu:
self.lrelu = LeakyReLU(0.2)
def forward(self, x):
"""
forward pass of the block
:param x: input
:return: y => output
"""
        from torch.nn.functional import interpolate
        x = interpolate(x, scale_factor=2, mode="bilinear", align_corners=False)
y=self.conv_1(self.lrelu(self.pixNorm(x)))
residual=y
y=self.conv_2(self.lrelu(self.pixNorm(y)))
y=self.conv_3(self.lrelu(self.pixNorm(y)))
y=y+residual
return y
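# Illustrative shape check for up_conv (sketch only; sizes are assumptions).
# The block first bilinearly upsamples by 2x, then applies the same residual
# convolution stack as conv_block.
def _up_conv_example():
    block = up_conv(16, 8, use_eql=True)
    x = th.randn(2, 16, 16, 16)
    return block(x).shape             # torch.Size([2, 8, 32, 32])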
# Final block of the discriminator
class DisFinalBlock(th.nn.Module):
""" Final block for the Discriminator """
def __init__(self, in_channels, use_eql=True):
"""
constructor of the class
:param in_channels: number of input channels
:param use_eql: whether to use equalized learning rate
"""
from torch.nn import LeakyReLU
from torch.nn import Conv2d
super().__init__()
# declare the required modules for forward pass
self.batch_discriminator = MinibatchStdDev()
if use_eql:
self.conv_1 = _equalized_conv2d(in_channels + 1, in_channels, (3, 3),
pad=1, bias=True)
self.conv_2 = _equalized_conv2d(in_channels, in_channels, (4, 4),stride=2,pad=1,
bias=True)
# final layer emulates the fully connected layer
self.conv_3 = _equalized_conv2d(in_channels, 1, (1, 1), bias=True)
else:
# modules required:
self.conv_1 = Conv2d(in_channels + 1, in_channels, (3, 3), padding=1, bias=True)
            self.conv_2 = Conv2d(in_channels, in_channels, (4, 4), stride=2, padding=1, bias=True)
# final conv layer emulates a fully connected layer
self.conv_3 = Conv2d(in_channels, 1, (1, 1), bias=True)
# leaky_relu:
self.lrelu = LeakyReLU(0.2)
def forward(self, x):
"""
forward pass of the FinalBlock
:param x: input
:return: y => output
"""
# minibatch_std_dev layer
y = self.batch_discriminator(x)
# define the computations
y = self.lrelu(self.conv_1(y))
y = self.lrelu(self.conv_2(y))
# fully connected layer
y = self.conv_3(y) # This layer has linear activation
        # return the raw discriminator scores (no flattening is applied here)
return y
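# Illustrative shape check for DisFinalBlock (sketch only; sizes are assumptions).
# MinibatchStdDev adds one channel, conv_2 halves the spatial size (stride 2),
# and the 1x1 conv_3 produces a single-channel map of raw scores.
def _dis_final_block_example():
    block = DisFinalBlock(16, use_eql=True)
    x = th.randn(4, 16, 8, 8)
    return block(x).shape             # torch.Size([4, 1, 4, 4])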
# Basic convolution block of the discriminator
class DisGeneralConvBlock(th.nn.Module):
""" General block in the discriminator """
def __init__(self, in_channels, out_channels, use_eql=True):
"""
constructor of the class
:param in_channels: number of input channels
:param out_channels: number of output channels
:param use_eql: whether to use equalized learning rate
"""
from torch.nn import AvgPool2d, LeakyReLU
from torch.nn import Conv2d
super().__init__()
if use_eql:
self.conv_1 = _equalized_conv2d(in_channels, in_channels, (3, 3),
pad=1, bias=True)
self.conv_2 = _equalized_conv2d(in_channels, out_channels, (3, 3),
pad=1, bias=True)
else:
# convolutional modules
self.conv_1 = Conv2d(in_channels, in_channels, (3, 3),
padding=1, bias=True)
self.conv_2 = Conv2d(in_channels, out_channels, (3, 3),
padding=1, bias=True)
self.downSampler = AvgPool2d(2) # downsampler
# leaky_relu:
self.lrelu = LeakyReLU(0.2)
def forward(self, x):
"""
forward pass of the module
:param x: input
:return: y => output
"""
# define the computations
y = self.lrelu(self.conv_1(x))
y = self.lrelu(self.conv_2(y))
y = self.downSampler(y)
return y
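# Illustrative shape check for DisGeneralConvBlock (sketch only; sizes are assumptions).
def _dis_general_block_example():
    block = DisGeneralConvBlock(16, 32, use_eql=True)
    x = th.randn(2, 16, 32, 32)
    return block(x).shape             # average pooling halves H and W: torch.Size([2, 32, 16, 16])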
class from_rgb(nn.Module):
"""
    Converts an RGB image into a multi-channel feature map so that it can be
    concatenated with a feature map of the same channel count inside the network.
"""
def __init__(self, outchannels, use_eql=True):
super(from_rgb, self).__init__()
if use_eql:
self.conv_1 = _equalized_conv2d(3, outchannels, (1, 1), bias=True)
else:
self.conv_1 = nn.Conv2d(3, outchannels, (1, 1),bias=True)
# pixel_wise feature normalizer:
self.pixNorm = PixelwiseNorm()
# leaky_relu:
self.lrelu = LeakyReLU(0.2)
def forward(self, x):
"""
forward pass of the block
:param x: input
:return: y => output
"""
y = self.pixNorm(self.lrelu(self.conv_1(x)))
return y
class to_rgb(nn.Module):
"""
    Converts a multi-channel feature map into a 3-channel RGB image that can be fed to the discriminator.
"""
def __init__(self, inchannels, use_eql=True):
super(to_rgb, self).__init__()
if use_eql:
self.conv_1 = _equalized_conv2d(inchannels, 3, (1, 1), bias=True)
else:
self.conv_1 = nn.Conv2d(inchannels, 3, (1, 1),bias=True)
def forward(self, x):
"""
forward pass of the block
:param x: input
:return: y => output
"""
y = self.conv_1(x)
return y
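# Illustrative round trip through from_rgb / to_rgb (sketch only; sizes are assumptions).
def _rgb_adapter_example():
    rgb = th.randn(2, 3, 64, 64)
    feats = from_rgb(16)(rgb)         # 3 -> 16 channels via a 1x1 convolution
    back = to_rgb(16)(feats)          # 16 -> 3 channels via a 1x1 convolution
    return feats.shape, back.shape    # torch.Size([2, 16, 64, 64]), torch.Size([2, 3, 64, 64])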
class Flatten(nn.Module):
def forward(self, x):
return x.view(x.size(0), -1)
class CCA(nn.Module):
"""
    CCA block: channel-wise attention that reweights the channels of the skip
    feature x using pooled statistics from both the gating signal g and x.
"""
def __init__(self, F_g, F_x):
super().__init__()
self.mlp_x = nn.Sequential(
Flatten(),
nn.Linear(F_x, F_x))
self.mlp_g = nn.Sequential(
Flatten(),
nn.Linear(F_g, F_x))
self.relu = nn.ReLU(inplace=True)
def forward(self, g, x):
# channel-wise attention
avg_pool_x = F.avg_pool2d( x, (x.size(2), x.size(3)), stride=(x.size(2), x.size(3)))
channel_att_x = self.mlp_x(avg_pool_x)
avg_pool_g = F.avg_pool2d( g, (g.size(2), g.size(3)), stride=(g.size(2), g.size(3)))
channel_att_g = self.mlp_g(avg_pool_g)
channel_att_sum = (channel_att_x + channel_att_g)/2.0
scale = th.sigmoid(channel_att_sum).unsqueeze(2).unsqueeze(3).expand_as(x)
x_after_channel = x * scale
out = self.relu(x_after_channel)
return out
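# Illustrative use of the CCA block (sketch only; channel counts and sizes are
# assumptions). g and x are pooled to per-channel vectors, passed through small
# MLPs, and the averaged attention reweights the channels of x.
def _cca_example():
    g = th.randn(2, 32, 16, 16)       # gating features, F_g = 32
    x = th.randn(2, 16, 32, 32)       # skip-connection features, F_x = 16
    out = CCA(F_g=32, F_x=16)(g, x)
    return out.shape                  # same shape as x: torch.Size([2, 16, 32, 32])
if __name__ == "__main__":
    # Lightweight smoke test of the example sketches above (an assumed entry point,
    # not part of the original training pipeline).
    print(_pixelwise_norm_example())
    print(_conv_block_example())
    print(_cca_example())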