|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| """
|
| Created in September 2020
|
| @author: davide.cozzolino
|
| """
|
|
|
import math

import torch
import torch.nn as nn
|
|
|
def conv_with_padding(in_planes, out_planes, kernelsize, stride=1, dilation=1, bias=False, padding=None):
    r"""
    Build an ``nn.Conv2d`` whose default padding keeps the spatial size.

    :param in_planes: number of input channels
    :param out_planes: number of output channels
    :param kernelsize: integer kernel size (same along both spatial axes)
    :param stride: convolution stride
    :param dilation: dilation factor of the kernel
    :param bias: whether the convolution has a learnable bias
    :param padding: explicit padding; if None, ``dilation * (kernelsize // 2)``
        is used, which preserves height/width for odd kernels at stride 1
    :return: the configured ``nn.Conv2d`` module
    """
    if padding is None:
        # A dilated kernel covers dilation*(kernelsize-1)+1 pixels, so the
        # "same" padding has to grow with the dilation. For dilation=1 this
        # equals the previous kernelsize//2, so undilated layers are unchanged.
        padding = dilation * (kernelsize // 2)
    return nn.Conv2d(in_planes, out_planes, kernel_size=kernelsize, stride=stride,
                     dilation=dilation, padding=padding, bias=bias)
|
|
|
def conv_init(conv, act='linear'):
    r"""
    Re-initialise ``conv.weight`` in place, reproducing the DnCNN conv init.

    Weights are redrawn from a normal distribution with mean 0 and standard
    deviation ``sqrt(2 / fan_out)``, where
    ``fan_out = kernel_h * kernel_w * out_channels``. The bias is not touched.

    :param conv: convolution module whose weights are overwritten
    :param act: accepted for call-site symmetry; not used by this function
    """
    kernel_h, kernel_w = conv.kernel_size[0], conv.kernel_size[1]
    fan_out = kernel_h * kernel_w * conv.out_channels
    std = math.sqrt(2. / fan_out)
    conv.weight.data.normal_(0, std)
|
|
|
def batchnorm_init(m, kernelsize=3):
    r"""
    Re-initialise a batchnorm layer in place, reproducing the DnCNN BN init.

    The scale (``m.weight``) is redrawn from a normal distribution with mean 0
    and standard deviation ``sqrt(2 / (kernelsize**2 * num_features))``; the
    shift (``m.bias``) is set to zero.

    :param m: batchnorm module with affine ``weight`` and ``bias``
    :param kernelsize: kernel size of the preceding convolution, used in the
        standard-deviation formula
    """
    fan = m.num_features * kernelsize ** 2
    m.weight.data.normal_(0, math.sqrt(2. / fan))
    m.bias.data.zero_()
|
|
|
def make_activation(act):
    r"""
    Build the activation module named by ``act``.

    :param act: one of ``'relu'``, ``'tanh'``, ``'leaky_relu'``, ``'softmax'``,
        ``'linear'`` or ``None``
    :return: a new ``nn.Module`` instance, or ``None`` for ``'linear'``/``None``
        (meaning "no activation layer")
    :raises ValueError: if ``act`` is not one of the recognised names
    """
    if act is None or act == 'linear':
        return None
    if act == 'relu':
        return nn.ReLU(inplace=True)
    if act == 'tanh':
        return nn.Tanh()
    if act == 'leaky_relu':
        return nn.LeakyReLU(inplace=True)
    if act == 'softmax':
        # NOTE(review): no ``dim`` is passed, so PyTorch chooses the softmax
        # dimension implicitly; consider passing an explicit channel dim.
        return nn.Softmax()
    # An explicit raise (instead of ``assert False``) still fires under
    # ``python -O``, rather than silently returning "no activation".
    raise ValueError("unknown activation: {!r}".format(act))
|
|
|
def make_net(nplanes_in, kernels, features, bns, acts, dilats, bn_momentum=0.1, padding=None):
    r"""
    Assemble a plain stack of convolution layers as an ``nn.Sequential``.

    Layer ``i`` consists of, in order:

    - a convolution from the previous layer's channels (``nplanes_in`` for the
      first layer) to ``features[i]`` channels, with a bias only when
      ``bns[i]`` is falsy;
    - a ``BatchNorm2d`` when ``bns[i]`` is truthy;
    - the activation built from ``acts[i]``, skipped when that yields None.

    :param nplanes_in: number of input feature channels
    :param kernels: list of kernel sizes, one per convolution layer
    :param features: list of output channels, one per convolution layer
    :param bns: list of flags telling whether to add batchnorm after each conv
    :param acts: list of activation names, one per layer
    :param dilats: list of dilation factors, one per layer
    :param bn_momentum: momentum of batchnorm
    :param padding: integer for padding (None lets ``conv_with_padding`` choose)
    :return: ``nn.Sequential`` containing the layers in order
    """
    depth = len(features)
    assert len(features) == len(kernels)

    modules = []
    channels_in = nplanes_in
    for idx in range(depth):
        channels_out = features[idx]

        conv = conv_with_padding(
            channels_in,
            channels_out,
            kernelsize=kernels[idx],
            dilation=dilats[idx],
            padding=padding,
            bias=not bns[idx],
        )
        conv_init(conv, act=acts[idx])
        modules.append(conv)

        if bns[idx]:
            # The batchnorm init scale depends on the conv's kernel size.
            norm = nn.BatchNorm2d(channels_out, momentum=bn_momentum)
            batchnorm_init(norm, kernelsize=kernels[idx])
            modules.append(norm)

        activation = make_activation(acts[idx])
        if activation is not None:
            modules.append(activation)

        # The next layer consumes this layer's output channels.
        channels_in = channels_out

    return nn.Sequential(*modules)
|
|
|
class DnCNN(nn.Module):
    r"""
    Implements a DnCNN network: a stack of ``depth`` convolution layers with
    optional batchnorm on the hidden layers and an optional residual
    connection from the input to the output.
    """
    def __init__(self, nplanes_in, nplanes_out, features, kernel, depth, activation, residual, bn, lastact=None, bn_momentum=0.10, padding=None):
        r"""
        :param nplanes_in: number of input feature channels
        :param nplanes_out: number of output feature channels
        :param features: number of hidden layer feature channels
        :param kernel: kernel size of convolution layers
        :param depth: number of convolution layers (minimum 2)
        :param activation: activation name used after every layer but the last
        :param residual: whether to add a residual connection from input to output
        :param bn: whether to add batchnorm layers on the hidden layers
            (the first and last layers never get one)
        :param lastact: activation name used after the last layer
            (None for a linear output)
        :param bn_momentum: momentum of batchnorm
        :param padding: integer for padding (None lets ``conv_with_padding`` choose)
        """
        super(DnCNN, self).__init__()

        self.residual = residual
        self.nplanes_out = nplanes_out
        self.nplanes_in = nplanes_in

        kernels = [kernel, ] * depth
        features = [features, ] * (depth - 1) + [nplanes_out, ]
        # No batchnorm on the first and last layers.
        bns = [False, ] + [bn, ] * (depth - 2) + [False, ]
        dilats = [1, ] * depth
        acts = [activation, ] * (depth - 1) + [lastact, ]
        self.layers = make_net(nplanes_in, kernels, features, bns, acts, dilats=dilats, bn_momentum=bn_momentum, padding=padding)

    def forward(self, x):
        r"""
        Run the layer stack on ``x``.

        When ``residual`` is set, the first ``min(nplanes_in, nplanes_out)``
        input channels are added to the corresponding output channels; any
        remaining output channels pass through unchanged.
        """
        shortcut = x

        x = self.layers(x)

        if self.residual:
            nshortcut = min(self.nplanes_in, self.nplanes_out)
            # Build the result out of place. Writing into ``x`` in place breaks
            # autograd whenever the last layer saved its output for backward
            # (e.g. lastact='tanh', or the in-place ReLU).
            head = x[:, :nshortcut, :, :] + shortcut[:, :nshortcut, :, :]
            if nshortcut < self.nplanes_out:
                x = torch.cat((head, x[:, nshortcut:, :, :]), dim=1)
            else:
                x = head

        return x
|
|
|
|
|
def add_commandline_networkparams(parser, name, features, depth, kernel, activation, bn):
    r"""
    Register network hyper-parameters on ``parser`` under the ``--<name>.`` prefix.

    Adds ``--<name>.features``, ``--<name>.depth``, ``--<name>.kernel`` (ints)
    and ``--<name>.activation`` (str), plus the flag pair ``--<name>.bn`` /
    ``--<name>.no-bn`` that both write the attribute ``"<name>.bn"``.

    :param parser: ``argparse.ArgumentParser`` to extend
    :param name: option prefix / namespace prefix
    :param features: default for ``--<name>.features``
    :param depth: default for ``--<name>.depth``
    :param kernel: default for ``--<name>.kernel``
    :param activation: default for ``--<name>.activation``
    :param bn: default value of ``"<name>.bn"``
    """
    valued_options = (
        ("features", int, features),
        ("depth", int, depth),
        ("kernel", int, kernel),
        ("activation", str, activation),
    )
    for field, field_type, field_default in valued_options:
        parser.add_argument(f"--{name}.{field}", type=field_type, default=field_default)

    # Both flags target the same destination, so the last one given wins.
    bn_dest = f"{name}.bn"
    parser.add_argument(f"--{bn_dest}", action="store_true", dest=bn_dest)
    parser.add_argument(f"--{name}.no-bn", action="store_false", dest=bn_dest)
    parser.set_defaults(**{bn_dest: bn})
|
|
|
|
|