Spaces:
Build error
Build error
import torch.nn as nn | |
import torch.nn.functional as F | |
import torch as th | |
import datetime | |
import os | |
import time | |
import timeit | |
import copy | |
import numpy as np | |
from torch.nn import ModuleList | |
from torch.nn import Conv2d | |
from torch.nn import LeakyReLU | |
# PixelwiseNorm is used in place of BatchNorm
class PixelwiseNorm(th.nn.Module):
    """Pixel-wise feature normalization.

    Every spatial position of the input is divided by the root-mean-square
    of its values taken along dimension 1. The layer has no parameters.
    """

    def __init__(self):
        super().__init__()

    def forward(self, x, alpha=1e-8):
        """
        forward pass of the module
        :param x: input activations volume
        :param alpha: small number for numerical stability
        :return: x divided by its RMS along dim 1 (same shape as x)
        """
        # mean of squares along dim 1; keepdim so the result broadcasts over x
        mean_square = (x ** 2.).mean(dim=1, keepdim=True)
        root_mean_square = (mean_square + alpha).sqrt()
        return x / root_mean_square
class MinibatchStdDev(th.nn.Module):
    """
    Minibatch standard deviation layer for the discriminator.

    Appends one extra feature map holding a single scalar: the standard
    deviation over the batch dimension, averaged over all remaining
    positions of the input.
    """

    def __init__(self):
        """ derived class constructor """
        super().__init__()

    def forward(self, x, alpha=1e-8):
        """
        forward pass of the layer
        :param x: input activation volume (unpacked as 4 dimensions)
        :param alpha: small number for numerical stability
        :return: x with one extra constant map concatenated along dim 1
        """
        n_samples, _, rows, cols = x.shape

        # deviation of every sample from the mean over dim 0
        centered = x - x.mean(dim=0, keepdim=True)

        # standard deviation over dim 0 (that dimension is dropped)
        std_over_batch = (centered.pow(2.).mean(dim=0) + alpha).sqrt()

        # collapse to one scalar, shaped so it can be broadcast
        statistic = std_over_batch.mean().view(1, 1, 1, 1)

        # one constant map per sample; concatenation materializes the values
        statistic_map = statistic.expand(n_samples, 1, rows, cols)

        return th.cat([x, statistic_map], 1)
# ========================================================== | |
# Equalized learning rate blocks: | |
# extending Conv2D and Deconv2D layers for equalized learning rate logic | |
# ========================================================== | |
class _equalized_conv2d(th.nn.Module):
    """ conv2d with the concept of equalized learning rate

    The weight is drawn from a standard normal distribution and multiplied
    by ``sqrt(2) / sqrt(fan_in)`` on every forward pass, where
    ``fan_in = kernel_height * kernel_width * c_in``.

    Args:
        :param c_in: input channels
        :param c_out: output channels
        :param k_size: kernel size (h, w) should be a tuple or a single integer
        :param stride: stride for conv
        :param pad: padding
        :param bias: whether to use bias or not
    """

    def __init__(self, c_in, c_out, k_size, stride=1, pad=0, bias=True):
        """ constructor for the class """
        from torch.nn.modules.utils import _pair
        from numpy import sqrt, prod

        super().__init__()

        kernel = _pair(k_size)

        # define the weight and bias if to be used
        self.weight = th.nn.Parameter(
            th.nn.init.normal_(th.empty(c_out, c_in, *kernel))
        )

        self.use_bias = bias
        self.stride = stride
        self.pad = pad

        if self.use_bias:
            # th.zeros follows the default dtype exactly like th.empty above,
            # so weight and bias can never end up with different dtypes
            # (the legacy th.FloatTensor constructor is always float32).
            self.bias = th.nn.Parameter(th.zeros(c_out))

        fan_in = prod(kernel) * c_in  # value of fan_in
        self.scale = sqrt(2) / sqrt(fan_in)

    def forward(self, x):
        """
        forward pass of the network
        :param x: input
        :return: y => output
        """
        from torch.nn.functional import conv2d

        return conv2d(input=x,
                      weight=self.weight * self.scale,  # scale the weight on runtime
                      bias=self.bias if self.use_bias else None,
                      stride=self.stride,
                      padding=self.pad)

    def extra_repr(self):
        """Return the weight shape as a comma-separated string."""
        return ", ".join(map(str, self.weight.shape))
class _equalized_deconv2d(th.nn.Module):
    """ Transpose convolution using the equalized learning rate

    The weight is drawn from a standard normal distribution and multiplied
    by ``sqrt(2) / sqrt(fan_in)`` on every forward pass; here ``fan_in`` is
    taken to be ``c_in`` only (the kernel size is not included).

    Args:
        :param c_in: input channels
        :param c_out: output channels
        :param k_size: kernel size
        :param stride: stride for convolution transpose
        :param pad: padding
        :param bias: whether to use bias or not
    """

    def __init__(self, c_in, c_out, k_size, stride=1, pad=0, bias=True):
        """ constructor for the class """
        from torch.nn.modules.utils import _pair
        from numpy import sqrt

        super().__init__()

        # define the weight and bias if to be used
        # (transposed-convolution layout: in-channels first)
        self.weight = th.nn.Parameter(
            th.nn.init.normal_(th.empty(c_in, c_out, *_pair(k_size)))
        )

        self.use_bias = bias
        self.stride = stride
        self.pad = pad

        if self.use_bias:
            # th.zeros follows the default dtype exactly like th.empty above,
            # so weight and bias can never end up with different dtypes
            # (the legacy th.FloatTensor constructor is always float32).
            self.bias = th.nn.Parameter(th.zeros(c_out))

        fan_in = c_in  # value of fan_in for deconv
        self.scale = sqrt(2) / sqrt(fan_in)

    def forward(self, x):
        """
        forward pass of the layer
        :param x: input
        :return: y => output
        """
        from torch.nn.functional import conv_transpose2d

        return conv_transpose2d(input=x,
                                weight=self.weight * self.scale,  # scale the weight on runtime
                                bias=self.bias if self.use_bias else None,
                                stride=self.stride,
                                padding=self.pad)

    def extra_repr(self):
        """Return the weight shape as a comma-separated string."""
        return ", ".join(map(str, self.weight.shape))
# Basic convolution block of the encoding part of the generator
class conv_block(nn.Module):
    """
    Convolution Block

    Three convolutions, each preceded by pixel-wise normalization and a
    LeakyReLU(0.2). The output of the first convolution is added back to
    the output of the third one (residual connection).

    :param in_ch: number of input channels
    :param out_ch: number of output channels
    :param use_eql: whether to use equalized learning rate layers
    """

    def __init__(self, in_ch, out_ch, use_eql=True):
        super().__init__()

        if use_eql:
            self.conv_1 = _equalized_conv2d(in_ch, out_ch, (1, 1),
                                            pad=0, bias=True)
            self.conv_2 = _equalized_conv2d(out_ch, out_ch, (3, 3),
                                            pad=1, bias=True)
            self.conv_3 = _equalized_conv2d(out_ch, out_ch, (3, 3),
                                            pad=1, bias=True)
        else:
            # NOTE: conv_1 is 3x3 here but 1x1 in the equalized branch; both
            # keep the spatial size, so the existing layout is left as it was.
            self.conv_1 = Conv2d(in_ch, out_ch, (3, 3),
                                 padding=1, bias=True)
            self.conv_2 = Conv2d(out_ch, out_ch, (3, 3),
                                 padding=1, bias=True)
            # forward() always calls conv_3, so it has to exist in this
            # branch as well (it used to be missing -> AttributeError).
            self.conv_3 = Conv2d(out_ch, out_ch, (3, 3),
                                 padding=1, bias=True)

        # pixel_wise feature normalizer:
        self.pixNorm = PixelwiseNorm()

        # leaky_relu:
        self.lrelu = LeakyReLU(0.2)

    def forward(self, x):
        """
        forward pass of the block
        :param x: input
        :return: y => output
        """
        y = self.conv_1(self.lrelu(self.pixNorm(x)))
        residual = y  # skip connection starts after the first convolution
        y = self.conv_2(self.lrelu(self.pixNorm(y)))
        y = self.conv_3(self.lrelu(self.pixNorm(y)))
        return y + residual
# Basic up-convolution block of the generator (the original note says
# "encoding part"; the block upsamples its input by a factor of 2)
class up_conv(nn.Module):
    """
    Up Convolution Block

    Upsamples the input by a factor of 2 (bilinear) and then applies three
    convolutions, each preceded by pixel-wise normalization and a
    LeakyReLU(0.2). The output of the first convolution is added back to
    the output of the third one (residual connection).

    :param in_ch: number of input channels
    :param out_ch: number of output channels
    :param use_eql: whether to use equalized learning rate layers
    """

    def __init__(self, in_ch, out_ch, use_eql=True):
        super().__init__()

        if use_eql:
            self.conv_1 = _equalized_conv2d(in_ch, out_ch, (1, 1),
                                            pad=0, bias=True)
            self.conv_2 = _equalized_conv2d(out_ch, out_ch, (3, 3),
                                            pad=1, bias=True)
            self.conv_3 = _equalized_conv2d(out_ch, out_ch, (3, 3),
                                            pad=1, bias=True)
        else:
            # NOTE: conv_1 is 3x3 here but 1x1 in the equalized branch; both
            # keep the spatial size, so the existing layout is left as it was.
            self.conv_1 = Conv2d(in_ch, out_ch, (3, 3),
                                 padding=1, bias=True)
            self.conv_2 = Conv2d(out_ch, out_ch, (3, 3),
                                 padding=1, bias=True)
            # forward() always calls conv_3, so it has to exist in this
            # branch as well (it used to be missing -> AttributeError).
            self.conv_3 = Conv2d(out_ch, out_ch, (3, 3),
                                 padding=1, bias=True)

        # pixel_wise feature normalizer:
        self.pixNorm = PixelwiseNorm()

        # leaky_relu:
        self.lrelu = LeakyReLU(0.2)

    def forward(self, x):
        """
        forward pass of the block
        :param x: input
        :return: y => output
        """
        from torch.nn.functional import interpolate

        # align_corners=False is what the default already means for bilinear;
        # spelling it out avoids the ambiguity warning some versions emit.
        x = interpolate(x, scale_factor=2, mode="bilinear",
                        align_corners=False)
        y = self.conv_1(self.lrelu(self.pixNorm(x)))
        residual = y  # skip connection starts after the first convolution
        y = self.conv_2(self.lrelu(self.pixNorm(y)))
        y = self.conv_3(self.lrelu(self.pixNorm(y)))
        return y + residual
# Final block of the discriminator
class DisFinalBlock(th.nn.Module):
    """ Final block for the Discriminator

    Appends the minibatch standard deviation map to the input and runs
    three convolutions; the last one (1x1, single output channel) produces
    the raw scores with no activation applied.
    """

    def __init__(self, in_channels, use_eql=True):
        """
        constructor of the class
        :param in_channels: number of input channels
        :param use_eql: whether to use equalized learning rate
        """
        from torch.nn import Conv2d, LeakyReLU

        super().__init__()

        # the std-dev layer adds one channel, hence in_channels + 1 below
        self.batch_discriminator = MinibatchStdDev()

        # NOTE(review): the two branches build different conv_2 layers --
        # the equalized one uses stride 2 and padding 1, the plain one uses
        # neither -- so their output sizes differ. Confirm this is intended.
        if not use_eql:
            self.conv_1 = Conv2d(in_channels + 1, in_channels, (3, 3), padding=1, bias=True)
            self.conv_2 = Conv2d(in_channels, in_channels, (4, 4), bias=True)
            # 1x1 convolution standing in for a fully connected layer
            self.conv_3 = Conv2d(in_channels, 1, (1, 1), bias=True)
        else:
            self.conv_1 = _equalized_conv2d(in_channels + 1, in_channels, (3, 3),
                                            pad=1, bias=True)
            self.conv_2 = _equalized_conv2d(in_channels, in_channels, (4, 4),
                                            stride=2, pad=1, bias=True)
            # 1x1 convolution standing in for a fully connected layer
            self.conv_3 = _equalized_conv2d(in_channels, 1, (1, 1), bias=True)

        # leaky_relu:
        self.lrelu = LeakyReLU(0.2)

    def forward(self, x):
        """
        forward pass of the FinalBlock
        :param x: input
        :return: raw discriminator scores (returned as produced by the last
                 convolution; no flattening is done here)
        """
        features = self.batch_discriminator(x)
        features = self.lrelu(self.conv_1(features))
        features = self.lrelu(self.conv_2(features))
        # linear activation on the last layer
        return self.conv_3(features)
# Basic convolution block of the discriminator
class DisGeneralConvBlock(th.nn.Module):
    """ General block in the discriminator

    Two 3x3 convolutions (padding 1), each followed by LeakyReLU(0.2),
    then a 2x2 average pooling.
    """

    def __init__(self, in_channels, out_channels, use_eql=True):
        """
        constructor of the class
        :param in_channels: number of input channels
        :param out_channels: number of output channels
        :param use_eql: whether to use equalized learning rate
        """
        from torch.nn import AvgPool2d, Conv2d, LeakyReLU

        super().__init__()

        if use_eql:
            first = _equalized_conv2d(in_channels, in_channels, (3, 3),
                                      pad=1, bias=True)
            second = _equalized_conv2d(in_channels, out_channels, (3, 3),
                                       pad=1, bias=True)
        else:
            first = Conv2d(in_channels, in_channels, (3, 3),
                           padding=1, bias=True)
            second = Conv2d(in_channels, out_channels, (3, 3),
                            padding=1, bias=True)

        # the channel count only changes in the second convolution
        self.conv_1 = first
        self.conv_2 = second

        self.downSampler = AvgPool2d(2)  # downsampler

        # leaky_relu:
        self.lrelu = LeakyReLU(0.2)

    def forward(self, x):
        """
        forward pass of the module
        :param x: input
        :return: y => output
        """
        hidden = self.lrelu(self.conv_1(x))
        hidden = self.lrelu(self.conv_2(hidden))
        return self.downSampler(hidden)
class from_rgb(nn.Module):
    """
    The RGB image is transformed into a multi-channel feature map to be concatenated with
    the feature map with the same number of channels in the network

    :param outchannels: number of channels of the produced feature map
    :param use_eql: whether to use equalized learning rate
    """

    def __init__(self, outchannels, use_eql=True):
        super().__init__()

        # 1x1 convolution from the 3 image channels to `outchannels`
        self.conv_1 = (
            _equalized_conv2d(3, outchannels, (1, 1), bias=True)
            if use_eql
            else nn.Conv2d(3, outchannels, (1, 1), bias=True)
        )

        # pixel_wise feature normalizer:
        self.pixNorm = PixelwiseNorm()

        # leaky_relu:
        self.lrelu = LeakyReLU(0.2)

    def forward(self, x):
        """
        forward pass of the block
        :param x: input
        :return: y => output
        """
        # order here is conv -> LeakyReLU -> pixel norm
        activated = self.lrelu(self.conv_1(x))
        return self.pixNorm(activated)
class to_rgb(nn.Module):
    """
    The multi-channel feature map is converted into RGB image for input to the discriminator

    :param inchannels: number of channels of the incoming feature map
    :param use_eql: whether to use equalized learning rate
    """

    def __init__(self, inchannels, use_eql=True):
        super().__init__()

        # 1x1 convolution down to 3 channels; no activation follows it
        self.conv_1 = (
            _equalized_conv2d(inchannels, 3, (1, 1), bias=True)
            if use_eql
            else nn.Conv2d(inchannels, 3, (1, 1), bias=True)
        )

    def forward(self, x):
        """
        forward pass of the block
        :param x: input
        :return: y => output
        """
        return self.conv_1(x)
class Flatten(nn.Module):
    """Collapse every dimension after the first into a single one."""

    def forward(self, x):
        """
        :param x: input tensor
        :return: a view of x with shape (x.size(0), -1)
        """
        leading = x.shape[0]
        return x.view(leading, -1)
class CCA(nn.Module):
    """
    CCA Block

    Channel-wise attention: `x` is rescaled per channel by a sigmoid gate
    computed from the spatially averaged `x` and `g`, then passed through
    ReLU.

    :param F_g: number of channels of `g` (input size of its linear layer)
    :param F_x: number of channels of `x`
    """

    def __init__(self, F_g, F_x):
        super().__init__()
        self.mlp_x = nn.Sequential(
            Flatten(),
            nn.Linear(F_x, F_x))
        self.mlp_g = nn.Sequential(
            Flatten(),
            nn.Linear(F_g, F_x))
        self.relu = nn.ReLU(inplace=True)

    def forward(self, g, x):
        """
        forward pass of the block
        :param g: tensor that contributes to the gate (4 dims are indexed)
        :param x: tensor that gets rescaled (4 dims are indexed)
        :return: ReLU(x * gate), same shape as x
        """
        # channel-wise attention
        # pooling window == full spatial extent -> one value per channel
        x_window = (x.size(2), x.size(3))
        pooled_x = F.avg_pool2d(x, x_window, stride=x_window)
        att_from_x = self.mlp_x(pooled_x)

        g_window = (g.size(2), g.size(3))
        pooled_g = F.avg_pool2d(g, g_window, stride=g_window)
        att_from_g = self.mlp_g(pooled_g)

        # average both descriptors, squash to (0, 1) and broadcast over x
        att_mean = (att_from_x + att_from_g) / 2.0
        gate = th.sigmoid(att_mean)
        gate = gate.unsqueeze(2).unsqueeze(3).expand_as(x)

        gated = x * gate
        return self.relu(gated)