import math
import functools  # used by RRDBNet

import torch
import torch.nn as nn
import torch.nn.init
import torch.nn.functional as F


def GetModel(opt):
    """Build the network selected by ``opt.model``, move it to ``opt.device``
    and optionally wrap it in ``nn.DataParallel``.

    Prints a message and returns ``None`` for an unknown model name.
    """
    name = opt.model.lower()  # hoisted: the original re-lowered per branch
    if name == 'edsr':
        net = EDSR(opt)
    elif name == 'edsr2max':
        net = EDSR2Max(normalization=opt.norm, nch_in=opt.nch_in,
                       nch_out=opt.nch_out, scale=opt.scale)
    elif name == 'edsr3max':
        net = EDSR3Max(normalization=opt.norm, nch_in=opt.nch_in,
                       nch_out=opt.nch_out, scale=opt.scale)
    elif name == 'rcan':
        net = RCAN(opt)
    elif name == 'rnan':
        net = RNAN(opt)
    elif name == 'rrdb':
        net = RRDBNet(opt)
    elif name == 'srresnet' or name == 'srgan':
        net = Generator(16, opt)
    elif name == 'unet':
        net = UNet(opt.nch_in, opt.nch_out, opt)
    elif name == 'unet_n2n':
        net = UNet_n2n(opt.nch_in, opt.nch_out, opt)
    elif name == 'unet60m':
        net = UNet60M(opt.nch_in, opt.nch_out)
    elif name == 'unetrep':
        net = UNetRep(opt.nch_in, opt.nch_out)
    elif name == 'unetgreedy':
        net = UNetGreedy(opt.nch_in, opt.nch_out)
    elif name == 'mlpnet':
        net = MLPNet()
    elif name == 'ffdnet':
        net = FFDNet(opt.nch_in)
    elif name == 'dncnn':
        net = DNCNN(opt.nch_in)
    elif name == 'fouriernet':
        net = FourierNet()
    elif name == 'fourierconvnet':
        net = FourierConvNet()
    else:
        print("model undefined")
        return None

    net.to(opt.device)
    if opt.multigpu:
        net = nn.DataParallel(net)

    return net


class MeanShift(nn.Conv2d):
    """Fixed 1x1 convolution that subtracts (sign=-1) or adds (sign=1) a
    per-channel mean, scaled by ``rgb_range``, and divides by a std."""

    def __init__(self, rgb_range, rgb_mean, rgb_std, sign=-1):
        super(MeanShift, self).__init__(3, 3, kernel_size=1)
        std = torch.Tensor(rgb_std)
        self.weight.data = torch.eye(3).view(3, 3, 1, 1) / std.view(3, 1, 1, 1)
        self.bias.data = sign * rgb_range * torch.Tensor(rgb_mean) / std
        # Fix: the original set ``self.requires_grad = False``, which only
        # creates a plain attribute on the Module and does NOT freeze the
        # weight/bias parameters. Freeze them explicitly.
        for p in self.parameters():
            p.requires_grad = False


def normalizationTransforms(normtype):
    """Return a ``(normalize, unnormalize)`` pair of MeanShift modules for the
    named dataset statistics, or ``(None, None)`` for an unknown name."""
    key = normtype.lower()  # hoisted: the original lowered once per branch
    if key == 'div2k':
        normalize = MeanShift(1, [0.4485, 0.4375, 0.4045], [0.2436, 0.2330, 0.2424])
        unnormalize = MeanShift(1, [-1.8411, -1.8777, -1.6687], [4.1051, 4.2918, 4.1254])
        print('using div2k normalization')
    elif key == 'pcam':
        normalize = MeanShift(1, [0.6975, 0.5348, 0.688], [0.2361, 0.2786, 0.2146])
        unnormalize = MeanShift(1, [-2.9547, -1.9198, -3.20643], [4.2363, 3.58972, 4.66049])
        print('using pcam normalization')
    elif key == 'div2k_std1':
        normalize = MeanShift(1, [0.4485, 0.4375, 0.4045], [1, 1, 1])
        unnormalize = MeanShift(1, [-0.4485, -0.4375, -0.4045], [1, 1, 1])
        print('using div2k normalization with std 1')
    elif key == 'pcam_std1':
        normalize = MeanShift(1, [0.6975, 0.5348, 0.688], [1, 1, 1])
        unnormalize = MeanShift(1, [-0.6975, -0.5348, -0.688], [1, 1, 1])
        print('using pcam normalization with std 1')
    else:
        print('not using normalization')
        return None, None
    return normalize, unnormalize


def conv(in_channels, out_channels, kernel_size, bias=True):
    """'Same'-padded convolution used throughout the EDSR/RCAN models."""
    return nn.Conv2d(
        in_channels, out_channels, kernel_size,
        padding=(kernel_size // 2), bias=bias)


class BasicBlock(nn.Sequential):
    """conv -> (optional BN) -> (optional activation).

    NOTE: the ``stride`` parameter is accepted but never forwarded to the
    convolution (preserved for interface compatibility).
    """

    def __init__(
        self, conv, in_channels, out_channels, kernel_size, stride=1,
            bias=False, bn=True, act=nn.ReLU(True)):
        m = [conv(in_channels, out_channels, kernel_size, bias=bias)]
        if bn:
            m.append(nn.BatchNorm2d(out_channels))
        if act is not None:
            m.append(act)
        super(BasicBlock, self).__init__(*m)


class ResBlock(nn.Module):
    """Plain EDSR residual block: conv-ReLU-conv with a scaled skip.

    NOTE: ``bn`` and ``act`` are accepted for interface compatibility but the
    body is hard-wired to no BN and ``nn.ReLU(True)``.
    """

    def __init__(
        self, conv, n_feats, kernel_size,
            bias=True, bn=False, act=nn.ReLU(True), res_scale=1):
        super(ResBlock, self).__init__()
        m = []
        m.append(conv(n_feats, n_feats, kernel_size, bias=bias))
        m.append(nn.ReLU(True))
        m.append(conv(n_feats, n_feats, kernel_size, bias=bias))
        self.body = nn.Sequential(*m)
        self.res_scale = res_scale

    def forward(self, x):
        res = self.body(x).mul(self.res_scale)
        res += x
        return res


class ResBlock2Max(nn.Module):
    """Residual block with a 2-level max-pool encoder / transposed-conv
    decoder inside the branch. Input spatial size must be divisible by 4."""

    def __init__(
        self, conv, n_feats, kernel_size,
            bias=True, bn=False, act=nn.ReLU(True), res_scale=1):
        super(ResBlock2Max, self).__init__()
        m = []
        m.append(conv(n_feats, n_feats, kernel_size, bias=bias))
        m.append(nn.MaxPool2d(2))
        m.append(nn.ReLU(True))
        m.append(conv(n_feats, 2 * n_feats, kernel_size, bias=bias))
        m.append(nn.MaxPool2d(2))
        m.append(nn.ReLU(True))
        m.append(conv(2 * n_feats, 4 * n_feats, kernel_size, bias=bias))
        m.append(nn.ReLU(True))
        m.append(nn.ConvTranspose2d(4 * n_feats, 2 * n_feats, 3, stride=2,
                                    padding=1, output_padding=1))
        m.append(nn.ConvTranspose2d(2 * n_feats, n_feats, 3, stride=2,
                                    padding=1, output_padding=1))
        self.body = nn.Sequential(*m)
        self.res_scale = res_scale

    def forward(self, x):
        res = self.body(x).mul(self.res_scale)
        res += x
        return res


class ResBlock3Max(nn.Module):
    """Residual block with a 3-level max-pool encoder / transposed-conv
    decoder inside the branch. Input spatial size must be divisible by 8."""

    def __init__(
        self, conv, n_feats, kernel_size,
            bias=True, bn=False, act=nn.ReLU(True), res_scale=1):
        super(ResBlock3Max, self).__init__()
        m = []
        m.append(conv(n_feats, 2 * n_feats, kernel_size, bias=bias))
        m.append(nn.MaxPool2d(2))
        m.append(nn.ReLU(True))
        m.append(conv(2 * n_feats, 2 * n_feats, kernel_size, bias=bias))
        m.append(nn.MaxPool2d(2))
        m.append(nn.ReLU(True))
        m.append(conv(2 * n_feats, 4 * n_feats, kernel_size, bias=bias))
        m.append(nn.MaxPool2d(2))
        m.append(nn.ReLU(True))
        m.append(conv(4 * n_feats, 8 * n_feats, kernel_size, bias=bias))
        m.append(nn.ReLU(True))
        m.append(nn.ConvTranspose2d(8 * n_feats, 4 * n_feats, 3, stride=2,
                                    padding=1, output_padding=1))
        m.append(nn.ConvTranspose2d(4 * n_feats, 2 * n_feats, 3, stride=2,
                                    padding=1, output_padding=1))
        m.append(nn.ConvTranspose2d(2 * n_feats, n_feats, 3, stride=2,
                                    padding=1, output_padding=1))
        self.body = nn.Sequential(*m)
        self.res_scale = res_scale

    def forward(self, x):
        res = self.body(x).mul(self.res_scale)
        res += x
        return res


class Upsampler(nn.Sequential):
    """Sub-pixel (PixelShuffle) upsampler for scale 2^n or 3."""

    def __init__(self, conv, scale, n_feats, bn=False, act=False, bias=True):
        m = []
        if (scale & (scale - 1)) == 0:    # Is scale = 2^n?
            for _ in range(int(math.log(scale, 2))):
                m.append(conv(n_feats, 4 * n_feats, 3, bias))
                m.append(nn.PixelShuffle(2))
                if bn:
                    m.append(nn.BatchNorm2d(n_feats))
                if act == 'relu':
                    m.append(nn.ReLU(True))
                elif act == 'prelu':
                    m.append(nn.PReLU(n_feats))
        elif scale == 3:
            m.append(conv(n_feats, 9 * n_feats, 3, bias))
            m.append(nn.PixelShuffle(3))
            if bn:
                m.append(nn.BatchNorm2d(n_feats))
            if act == 'relu':
                m.append(nn.ReLU(True))
            elif act == 'prelu':
                m.append(nn.PReLU(n_feats))
        else:
            raise NotImplementedError
        super(Upsampler, self).__init__(*m)


class EDSR(nn.Module):
    """Enhanced Deep Super-Resolution network (head / residual body / tail)."""

    def __init__(self, opt):
        super(EDSR, self).__init__()

        n_resblocks = 16
        n_feats = 64
        kernel_size = 3
        act = nn.ReLU(True)

        if opt.norm is not None:
            self.normalize, self.unnormalize = normalizationTransforms(opt.norm)
        else:
            self.normalize, self.unnormalize = None, None

        # define head module
        m_head = [conv(opt.nch_in, n_feats, kernel_size)]

        # define body module
        m_body = [
            ResBlock(conv, n_feats, kernel_size, act=act, res_scale=0.1)
            for _ in range(n_resblocks)
        ]
        m_body.append(conv(n_feats, n_feats, kernel_size))

        # define tail module: plain conv at scale 1, sub-pixel upsampler otherwise
        if opt.scale == 1:
            if opt.task == 'segment':
                m_tail = [nn.Conv2d(n_feats, 2, 1)]
            else:
                m_tail = [conv(n_feats, opt.nch_out, kernel_size)]
        else:
            m_tail = [
                Upsampler(conv, opt.scale, n_feats, act=False),
                conv(n_feats, opt.nch_out, kernel_size)]

        self.head = nn.Sequential(*m_head)
        self.body = nn.Sequential(*m_body)
        self.tail = nn.Sequential(*m_tail)

    def forward(self, x):
        if self.normalize is not None:
            x = self.normalize(x)
        x = self.head(x)
        res = self.body(x)
        res += x
        x = self.tail(res)
        if self.unnormalize is not None:
            x = self.unnormalize(x)
        return x


class EDSR2Max(nn.Module):
    """EDSR variant whose residual blocks contain a 2-level pooled encoder.

    NOTE: ``scale`` is accepted for interface compatibility but unused — the
    tail is a plain convolution (no upsampler).
    """

    def __init__(self, normalization=None, nch_in=3, nch_out=3, scale=4):
        super(EDSR2Max, self).__init__()

        n_resblocks = 16
        n_feats = 64
        kernel_size = 3
        act = nn.ReLU(True)

        # Fix: the original tested ``opt.norm`` here, but ``opt`` is not in
        # scope in this constructor (NameError); use the parameter instead.
        if normalization is not None:
            self.normalize, self.unnormalize = normalizationTransforms(normalization)
        else:
            self.normalize, self.unnormalize = None, None

        # define head module
        m_head = [conv(nch_in, n_feats, kernel_size)]

        # define body module
        m_body = [
            ResBlock2Max(conv, n_feats, kernel_size, act=act, res_scale=0.1)
            for _ in range(n_resblocks)
        ]
        m_body.append(conv(n_feats, n_feats, kernel_size))

        # define tail module
        m_tail = [conv(n_feats, nch_out, kernel_size)]

        self.head = nn.Sequential(*m_head)
        self.body = nn.Sequential(*m_body)
        self.tail = nn.Sequential(*m_tail)

    def forward(self, x):
        if self.normalize is not None:
            x = self.normalize(x)
        x = self.head(x)
        res = self.body(x)
        res += x
        x = self.tail(res)
        if self.unnormalize is not None:
            x = self.unnormalize(x)
        return x


class EDSR3Max(nn.Module):
    """EDSR variant whose residual blocks contain a 3-level pooled encoder.

    NOTE: ``scale`` is accepted for interface compatibility but unused — the
    tail is a plain convolution (no upsampler).
    """

    def __init__(self, normalization=None, nch_in=3, nch_out=3, scale=4):
        super(EDSR3Max, self).__init__()

        n_resblocks = 16
        n_feats = 64
        kernel_size = 3
        act = nn.ReLU(True)

        # Fix: same NameError as EDSR2Max — the original referenced
        # ``opt.norm`` with no ``opt`` in scope.
        if normalization is not None:
            self.normalize, self.unnormalize = normalizationTransforms(normalization)
        else:
            self.normalize, self.unnormalize = None, None

        # define head module
        m_head = [conv(nch_in, n_feats, kernel_size)]

        # define body module
        m_body = [
            ResBlock3Max(conv, n_feats, kernel_size, act=act, res_scale=0.1)
            for _ in range(n_resblocks)
        ]
        m_body.append(conv(n_feats, n_feats, kernel_size))

        # define tail module
        m_tail = [conv(n_feats, nch_out, kernel_size)]

        self.head = nn.Sequential(*m_head)
        self.body = nn.Sequential(*m_body)
        self.tail = nn.Sequential(*m_tail)

    def forward(self, x):
        if self.normalize is not None:
            x = self.normalize(x)
        x = self.head(x)
        res = self.body(x)
        res += x
        x = self.tail(res)
        if self.unnormalize is not None:
            x = self.unnormalize(x)
        return x


# ----------------------------------- RCAN ------------------------------------------

## Channel Attention (CA) Layer
class CALayer(nn.Module):
    """Squeeze-and-excitation style channel attention."""

    def __init__(self, channel, reduction=16):
        super(CALayer, self).__init__()
        # global average pooling: feature --> point
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        # feature channel downscale and upscale --> channel weight
        self.conv_du = nn.Sequential(
            nn.Conv2d(channel, channel // reduction, 1, padding=0, bias=True),
            nn.ReLU(inplace=True),
            nn.Conv2d(channel // reduction, channel, 1, padding=0, bias=True),
            nn.Sigmoid()
        )

    def forward(self, x):
        y = self.avg_pool(x)
        y = self.conv_du(y)
        return x * y


## Residual Channel Attention Block (RCAB)
class RCAB(nn.Module):
    """conv-act-conv followed by channel attention, with identity skip."""

    def __init__(
        self, conv, n_feat, kernel_size, reduction,
            bias=True, bn=False, act=nn.ReLU(True), res_scale=1):
        super(RCAB, self).__init__()
        modules_body = []
        for i in range(2):
            modules_body.append(conv(n_feat, n_feat, kernel_size, bias=bias))
            if bn:
                modules_body.append(nn.BatchNorm2d(n_feat))
            if i == 0:
                modules_body.append(act)
        modules_body.append(CALayer(n_feat, reduction))
        self.body = nn.Sequential(*modules_body)
        self.res_scale = res_scale

    def forward(self, x):
        res = self.body(x)
        # res = self.body(x).mul(self.res_scale)
        res += x
        return res


## Residual Group (RG)
class ResidualGroup(nn.Module):
    """A chain of RCABs plus a trailing conv, with a group-level skip."""

    def __init__(self, conv, n_feat, kernel_size, reduction, act, res_scale, n_resblocks):
        super(ResidualGroup, self).__init__()
        modules_body = [
            RCAB(conv, n_feat, kernel_size, reduction,
                 bias=True, bn=False, act=nn.ReLU(True), res_scale=1)
            for _ in range(n_resblocks)]
        modules_body.append(conv(n_feat, n_feat, kernel_size))
        self.body = nn.Sequential(*modules_body)

    def forward(self, x):
        res = self.body(x)
        res += x
        return res


## Residual Channel Attention Network (RCAN)
class RCAN(nn.Module):
    """Residual Channel Attention Network.

    ``opt.narch == 0`` uses a single head conv; otherwise each of 9 input
    channels gets its own head conv before a combining conv.
    """

    def __init__(self, opt):
        super(RCAN, self).__init__()

        n_resgroups = opt.n_resgroups
        n_resblocks = opt.n_resblocks
        n_feats = opt.n_feats
        kernel_size = 3
        reduction = opt.reduction
        act = nn.ReLU(True)
        self.narch = opt.narch

        if opt.norm is not None:
            self.normalize, self.unnormalize = normalizationTransforms(opt.norm)
        else:
            self.normalize, self.unnormalize = None, None

        # define head module
        if self.narch == 0:
            modules_head = [conv(opt.nch_in, n_feats, kernel_size)]
            self.head = nn.Sequential(*modules_head)
        else:
            # one independent head per input channel (9 channels assumed)
            self.head0 = conv(1, n_feats, kernel_size)
            self.head1 = conv(1, n_feats, kernel_size)
            self.head2 = conv(1, n_feats, kernel_size)
            self.head3 = conv(1, n_feats, kernel_size)
            self.head4 = conv(1, n_feats, kernel_size)
            self.head5 = conv(1, n_feats, kernel_size)
            self.head6 = conv(1, n_feats, kernel_size)
            self.head7 = conv(1, n_feats, kernel_size)
            self.head8 = conv(1, n_feats, kernel_size)
            self.combineHead = conv(9 * n_feats, n_feats, kernel_size)

        # define body module
        modules_body = [
            ResidualGroup(conv, n_feats, kernel_size, reduction,
                          act=act, res_scale=1, n_resblocks=n_resblocks)
            for _ in range(n_resgroups)]
        modules_body.append(conv(n_feats, n_feats, kernel_size))

        # define tail module
        if opt.scale == 1:
            if opt.task == 'segment':
                modules_tail = [nn.Conv2d(n_feats, opt.nch_out, 1)]
            else:
                modules_tail = [conv(n_feats, opt.nch_out, kernel_size)]
        else:
            modules_tail = [
                Upsampler(conv, opt.scale, n_feats, act=False),
                conv(n_feats, opt.nch_out, kernel_size)]

        self.body = nn.Sequential(*modules_body)
        self.tail = nn.Sequential(*modules_tail)

    def forward(self, x):
        if self.normalize is not None:
            x = self.normalize(x)

        if self.narch == 0:
            x = self.head(x)
        else:
            x0 = self.head0(x[:, 0:0 + 1, :, :])
            x1 = self.head1(x[:, 1:1 + 1, :, :])
            x2 = self.head2(x[:, 2:2 + 1, :, :])
            x3 = self.head3(x[:, 3:3 + 1, :, :])
            x4 = self.head4(x[:, 4:4 + 1, :, :])
            x5 = self.head5(x[:, 5:5 + 1, :, :])
            x6 = self.head6(x[:, 6:6 + 1, :, :])
            x7 = self.head7(x[:, 7:7 + 1, :, :])
            x8 = self.head8(x[:, 8:8 + 1, :, :])
            x = torch.cat((x0, x1, x2, x3, x4, x5, x6, x7, x8), 1)
            x = self.combineHead(x)

        res = self.body(x)
        res += x

        x = self.tail(res)
        if self.unnormalize is not None:
            x = self.unnormalize(x)

        return x


# ----------------------------------- RNAN ------------------------------------------

# add NonLocalBlock2D
# reference: https://github.com/AlexHex7/Non-local_pytorch/blob/master/lib/non_local_simple_version.py
class NonLocalBlock2D(nn.Module):
    """Embedded-Gaussian non-local block (theta/phi/g projections + W)."""

    def __init__(self, in_channels, inter_channels):
        super(NonLocalBlock2D, self).__init__()

        self.in_channels = in_channels
        self.inter_channels = inter_channels

        self.g = nn.Conv2d(in_channels=self.in_channels, out_channels=self.inter_channels,
                           kernel_size=1, stride=1, padding=0)

        self.W = nn.Conv2d(in_channels=self.inter_channels, out_channels=self.in_channels,
                           kernel_size=1, stride=1, padding=0)
        # for pytorch 0.3.1
        # nn.init.constant(self.W.weight, 0)
        # nn.init.constant(self.W.bias, 0)
        # for pytorch 0.4.0: W starts at zero so the block is initially identity
        nn.init.constant_(self.W.weight, 0)
        nn.init.constant_(self.W.bias, 0)

        self.theta = nn.Conv2d(in_channels=self.in_channels, out_channels=self.inter_channels,
                               kernel_size=1, stride=1, padding=0)
        self.phi = nn.Conv2d(in_channels=self.in_channels, out_channels=self.inter_channels,
                             kernel_size=1, stride=1, padding=0)

    def forward(self, x):
        batch_size = x.size(0)

        g_x = self.g(x).view(batch_size, self.inter_channels, -1)
        g_x = g_x.permute(0, 2, 1)

        theta_x = self.theta(x).view(batch_size, self.inter_channels, -1)
        theta_x = theta_x.permute(0, 2, 1)

        phi_x = self.phi(x).view(batch_size, self.inter_channels, -1)

        f = torch.matmul(theta_x, phi_x)
        # NOTE(review): softmax over dim=1 normalises over *query* positions;
        # the usual non-local formulation normalises over dim=-1. Preserved
        # as-is since changing it alters trained-model behavior — confirm.
        f_div_C = F.softmax(f, dim=1)

        y = torch.matmul(f_div_C, g_x)
        y = y.permute(0, 2, 1).contiguous()
        y = y.view(batch_size, self.inter_channels, *x.size()[2:])
        W_y = self.W(y)
        z = W_y + x

        return z


## define trunk branch
class TrunkBranch(nn.Module):
    """Two plain ResBlocks."""

    def __init__(
        self, conv, n_feat, kernel_size,
            bias=True, bn=False, act=nn.ReLU(True), res_scale=1):
        super(TrunkBranch, self).__init__()
        modules_body = []
        for i in range(2):
            modules_body.append(ResBlock(conv, n_feat, kernel_size,
                                         bias=True, bn=False, act=nn.ReLU(True), res_scale=1))
        self.body = nn.Sequential(*modules_body)

    def forward(self, x):
        tx = self.body(x)
        return tx


## define mask branch
class MaskBranchDownUp(nn.Module):
    """Down/up attention-mask branch ending in a sigmoid gate."""

    def __init__(
        self, conv, n_feat, kernel_size,
            bias=True, bn=False, act=nn.ReLU(True), res_scale=1):
        super(MaskBranchDownUp, self).__init__()

        MB_RB1 = []
        MB_RB1.append(ResBlock(conv, n_feat, kernel_size,
                               bias=True, bn=False, act=nn.ReLU(True), res_scale=1))

        MB_Down = []
        MB_Down.append(nn.Conv2d(n_feat, n_feat, 3, stride=2, padding=1))

        MB_RB2 = []
        for i in range(2):
            MB_RB2.append(ResBlock(conv, n_feat, kernel_size,
                                   bias=True, bn=False, act=nn.ReLU(True), res_scale=1))

        MB_Up = []
        MB_Up.append(nn.ConvTranspose2d(n_feat, n_feat, 6, stride=2, padding=2))

        MB_RB3 = []
        MB_RB3.append(ResBlock(conv, n_feat, kernel_size,
                               bias=True, bn=False, act=nn.ReLU(True), res_scale=1))

        MB_1x1conv = []
        MB_1x1conv.append(nn.Conv2d(n_feat, n_feat, 1, padding=0, bias=True))

        MB_sigmoid = []
        MB_sigmoid.append(nn.Sigmoid())

        self.MB_RB1 = nn.Sequential(*MB_RB1)
        self.MB_Down = nn.Sequential(*MB_Down)
        self.MB_RB2 = nn.Sequential(*MB_RB2)
        self.MB_Up = nn.Sequential(*MB_Up)
        self.MB_RB3 = nn.Sequential(*MB_RB3)
        self.MB_1x1conv = nn.Sequential(*MB_1x1conv)
        self.MB_sigmoid = nn.Sequential(*MB_sigmoid)

    def forward(self, x):
        x_RB1 = self.MB_RB1(x)
        x_Down = self.MB_Down(x_RB1)
        x_RB2 = self.MB_RB2(x_Down)
        x_Up = self.MB_Up(x_RB2)
        x_preRB3 = x_RB1 + x_Up
        x_RB3 = self.MB_RB3(x_preRB3)
        x_1x1 = self.MB_1x1conv(x_RB3)
        mx = self.MB_sigmoid(x_1x1)
        return mx


## define nonlocal mask branch
class NLMaskBranchDownUp(nn.Module):
    """Mask branch as above, prefixed with a NonLocalBlock2D."""

    def __init__(
        self, conv, n_feat, kernel_size,
            bias=True, bn=False, act=nn.ReLU(True), res_scale=1):
        super(NLMaskBranchDownUp, self).__init__()

        MB_RB1 = []
        MB_RB1.append(NonLocalBlock2D(n_feat, n_feat // 2))
        MB_RB1.append(ResBlock(conv, n_feat, kernel_size,
                               bias=True, bn=False, act=nn.ReLU(True), res_scale=1))

        MB_Down = []
        MB_Down.append(nn.Conv2d(n_feat, n_feat, 3, stride=2, padding=1))

        MB_RB2 = []
        for i in range(2):
            MB_RB2.append(ResBlock(conv, n_feat, kernel_size,
                                   bias=True, bn=False, act=nn.ReLU(True), res_scale=1))

        MB_Up = []
        MB_Up.append(nn.ConvTranspose2d(n_feat, n_feat, 6, stride=2, padding=2))

        MB_RB3 = []
        MB_RB3.append(ResBlock(conv, n_feat, kernel_size,
                               bias=True, bn=False, act=nn.ReLU(True), res_scale=1))

        MB_1x1conv = []
        MB_1x1conv.append(nn.Conv2d(n_feat, n_feat, 1, padding=0, bias=True))

        MB_sigmoid = []
        MB_sigmoid.append(nn.Sigmoid())

        self.MB_RB1 = nn.Sequential(*MB_RB1)
        self.MB_Down = nn.Sequential(*MB_Down)
        self.MB_RB2 = nn.Sequential(*MB_RB2)
        self.MB_Up = nn.Sequential(*MB_Up)
        self.MB_RB3 = nn.Sequential(*MB_RB3)
        self.MB_1x1conv = nn.Sequential(*MB_1x1conv)
        self.MB_sigmoid = nn.Sequential(*MB_sigmoid)

    def forward(self, x):
        x_RB1 = self.MB_RB1(x)
        x_Down = self.MB_Down(x_RB1)
        x_RB2 = self.MB_RB2(x_Down)
        x_Up = self.MB_Up(x_RB2)
        x_preRB3 = x_RB1 + x_Up
        x_RB3 = self.MB_RB3(x_preRB3)
        x_1x1 = self.MB_1x1conv(x_RB3)
        mx = self.MB_sigmoid(x_1x1)
        return mx


## define residual attention module
class ResAttModuleDownUpPlus(nn.Module):
    """Residual attention: trunk * mask + skip, then two tail ResBlocks."""

    def __init__(
        self, conv, n_feat, kernel_size,
            bias=True, bn=False, act=nn.ReLU(True), res_scale=1):
        super(ResAttModuleDownUpPlus, self).__init__()

        RA_RB1 = []
        RA_RB1.append(ResBlock(conv, n_feat, kernel_size,
                               bias=True, bn=False, act=nn.ReLU(True), res_scale=1))
        RA_TB = []
        RA_TB.append(TrunkBranch(conv, n_feat, kernel_size,
                                 bias=True, bn=False, act=nn.ReLU(True), res_scale=1))
        RA_MB = []
        RA_MB.append(MaskBranchDownUp(conv, n_feat, kernel_size,
                                      bias=True, bn=False, act=nn.ReLU(True), res_scale=1))
        RA_tail = []
        for i in range(2):
            RA_tail.append(ResBlock(conv, n_feat, kernel_size,
                                    bias=True, bn=False, act=nn.ReLU(True), res_scale=1))

        self.RA_RB1 = nn.Sequential(*RA_RB1)
        self.RA_TB = nn.Sequential(*RA_TB)
        self.RA_MB = nn.Sequential(*RA_MB)
        self.RA_tail = nn.Sequential(*RA_tail)

    def forward(self, input):
        RA_RB1_x = self.RA_RB1(input)
        tx = self.RA_TB(RA_RB1_x)
        mx = self.RA_MB(RA_RB1_x)
        txmx = tx * mx
        hx = txmx + RA_RB1_x
        hx = self.RA_tail(hx)
        return hx


## define nonlocal residual attention module
class NLResAttModuleDownUpPlus(nn.Module):
    """Residual attention with a non-local mask branch."""

    def __init__(
        self, conv, n_feat, kernel_size,
            bias=True, bn=False, act=nn.ReLU(True), res_scale=1):
        super(NLResAttModuleDownUpPlus, self).__init__()

        RA_RB1 = []
        RA_RB1.append(ResBlock(conv, n_feat, kernel_size,
                               bias=True, bn=False, act=nn.ReLU(True), res_scale=1))
        RA_TB = []
        RA_TB.append(TrunkBranch(conv, n_feat, kernel_size,
                                 bias=True, bn=False, act=nn.ReLU(True), res_scale=1))
        RA_MB = []
        RA_MB.append(NLMaskBranchDownUp(conv, n_feat, kernel_size,
                                        bias=True, bn=False, act=nn.ReLU(True), res_scale=1))
        RA_tail = []
        for i in range(2):
            RA_tail.append(ResBlock(conv, n_feat, kernel_size,
                                    bias=True, bn=False, act=nn.ReLU(True), res_scale=1))

        self.RA_RB1 = nn.Sequential(*RA_RB1)
        self.RA_TB = nn.Sequential(*RA_TB)
        self.RA_MB = nn.Sequential(*RA_MB)
        self.RA_tail = nn.Sequential(*RA_tail)

    def forward(self, input):
        RA_RB1_x = self.RA_RB1(input)
        tx = self.RA_TB(RA_RB1_x)
        mx = self.RA_MB(RA_RB1_x)
        txmx = tx * mx
        hx = txmx + RA_RB1_x
        hx = self.RA_tail(hx)
        return hx


class _ResGroup(nn.Module):
    """One residual-attention module plus a trailing conv (no group skip)."""

    def __init__(self, conv, n_feats, kernel_size, act, res_scale):
        super(_ResGroup, self).__init__()
        modules_body = []
        modules_body.append(ResAttModuleDownUpPlus(conv, n_feats, kernel_size,
                                                   bias=True, bn=False, act=nn.ReLU(True), res_scale=1))
        modules_body.append(conv(n_feats, n_feats, kernel_size))
        self.body = nn.Sequential(*modules_body)

    def forward(self, x):
        res = self.body(x)
        return res


class _NLResGroup(nn.Module):
    """One non-local residual-attention module plus a trailing conv."""

    def __init__(self, conv, n_feats, kernel_size, act, res_scale):
        super(_NLResGroup, self).__init__()
        modules_body = []
        modules_body.append(NLResAttModuleDownUpPlus(conv, n_feats, kernel_size,
                                                     bias=True, bn=False, act=nn.ReLU(True), res_scale=1))
        modules_body.append(conv(n_feats, n_feats, kernel_size))
        self.body = nn.Sequential(*modules_body)

    def forward(self, x):
        res = self.body(x)
        return res


class RNAN(nn.Module):
    """Residual Non-local Attention Network: non-local groups at the shallow
    and deep ends with plain residual-attention groups in between."""

    def __init__(self, opt):
        super(RNAN, self).__init__()

        n_resgroups = opt.n_resgroups
        n_feats = opt.n_feats
        kernel_size = 3
        reduction = opt.reduction
        act = nn.ReLU(True)

        # Fix: the original printed the undefined names ``n_resgroup2`` and
        # ``n_resblock`` (NameError on construction); log the config that
        # actually exists here.
        print(n_resgroups, n_feats, kernel_size, reduction, act)

        # RGB mean for DIV2K 1-800
        # rgb_mean = (0.4488, 0.4371, 0.4040)
        # rgb_std = (1.0, 1.0, 1.0)
        # self.sub_mean = MeanShift(args.rgb_range, rgb_mean, rgb_std)

        # define head module
        modules_head = [conv(opt.nch_in, n_feats, kernel_size)]

        # define body module
        modules_body_nl_low = [
            _NLResGroup(conv, n_feats, kernel_size, act=act, res_scale=1)]
        modules_body = [
            _ResGroup(conv, n_feats, kernel_size, act=act, res_scale=1)
            for _ in range(n_resgroups - 2)]
        modules_body_nl_high = [
            _NLResGroup(conv, n_feats, kernel_size, act=act, res_scale=1)]
        modules_body.append(conv(n_feats, n_feats, kernel_size))

        # define tail module
        modules_tail = [
            Upsampler(conv, opt.scale, n_feats, act=False),
            conv(n_feats, opt.nch_out, kernel_size)]

        # self.add_mean = MeanShift(args.rgb_range, rgb_mean, rgb_std, 1)

        self.head = nn.Sequential(*modules_head)
        self.body_nl_low = nn.Sequential(*modules_body_nl_low)
        self.body = nn.Sequential(*modules_body)
        self.body_nl_high = nn.Sequential(*modules_body_nl_high)
        self.tail = nn.Sequential(*modules_tail)

    def forward(self, x):
        # x = self.sub_mean(x)
        feats_shallow = self.head(x)

        res = self.body_nl_low(feats_shallow)
        res = self.body(res)
        res = self.body_nl_high(res)
        res += feats_shallow

        res_main = self.tail(res)

        # res_main = self.add_mean(res_main)

        return res_main


class FourierNet(nn.Module):
    """Single linear layer mapping 9 stacked 85x85 planes to one 85x85 plane."""

    def __init__(self):
        super(FourierNet, self).__init__()
        self.inp = nn.Linear(85 * 85 * 9, 85 * 85)

    def forward(self, x):
        x = x.view(-1, 85 * 85 * 9)
        x = (self.inp(x))
        # x = (self.lay1(x))
        x = x.view(-1, 1, 85, 85)
        return x


class FourierConvNet(nn.Module):
    """UNet-style encoder/decoder over 18 input planes producing 9 outputs,
    followed by log(|x|). Relies on inconv/down/up/outconv defined later in
    this file."""

    def __init__(self):
        super(FourierConvNet, self).__init__()

        # self.inp = nn.Conv2d(18,32,3, stride=1, padding=1)
        # self.lay1 = nn.Conv2d(32,32,3, stride=1, padding=1)
        # self.lay2 = nn.Conv2d(32,32,3, stride=1, padding=1)
        # self.lay3 = nn.Conv2d(32,32,3, stride=1, padding=1)
        # self.pool = nn.MaxPool2d(2,2)
        # self.out = nn.Conv2d(32,1,3, stride=1, padding=1)
        # self.labels = nn.Linear(4096,18)

        self.inc = inconv(18, 64)
        self.down1 = down(64, 128)
        self.down2 = down(128, 256)
        self.down3 = down(256, 512)
        self.down4 = down(512, 512)
        self.up1 = up(1024, 256)
        self.up2 = up(512, 128)
        self.up3 = up(256, 64)
        self.up4 = up(128, 64)
        self.outc = outconv(64, 9)  # two channels for complex

    def forward(self, x):
        # x = self.inp(x)
        # x = torch.rfft(x,2,onesided=False)
        # # x = torch.log( torch.abs(x) + 1 )
        # x = x.permute(0,1,4,2,3) # put real and imag parts after stack index
        # x = x.contiguous().view(-1,18,256,256)
        # x = F.relu(self.inp(x))
        # x = self.pool(x) # to 128
        # x = F.relu(self.lay2(x))
        # x = self.pool(x) # to 64
        # x = F.relu(self.lay3(x))
        # x = self.out(x)
        # x = x.view(-1,4096)
        # x = self.labels(x)
        x1 = self.inc(x)
        x2 = self.down1(x1)
        x3 = self.down2(x2)
        x4 = self.down3(x3)
        x5 = self.down4(x4)
        x = self.up1(x5, x4)
        x = self.up2(x, x3)
        x = self.up3(x, x2)
        x = self.up4(x, x1)
        x = self.outc(x)
        x = torch.log(torch.abs(x))
        # x = x.permute(0,2,3,1)
        # x = torch.irfft(x,2,onesided=False)
        return x


# Dead code below: remnant of an alternative UNet implementation, preserved
# commented-out as in the original source.
# super(UNet, self).__init__()
# self.inc = inconv(n_channels, 64)
# self.down1 = down(64, 128)
# self.down2 = down(128, 256)
# self.down3 = down(256, 512)
# self.down4 = down(512, 512)
# self.up1 = up(1024, 256)
# self.up2 = up(512, 128)
# self.up3 = up(256, 64)
# self.up4 = up(128, 64)
# if opt.task == 'segment':
#     self.outc = outconv(64, 2)
# else:
#     self.outc = outconv(64, n_classes)
# # Initialize weights
# # self._init_weights()

# def _init_weights(self):
#     """Initializes weights using He et al.
# (2015)."""
#     for m in self.modules():
#         if isinstance(m, nn.ConvTranspose2d) or isinstance(m, nn.Conv2d):
#             nn.init.kaiming_normal_(m.weight.data)
#             m.bias.data.zero_()

# def forward(self, x):
#     x1 = self.inc(x)
#     x2 = self.down1(x1)
#     x3 = self.down2(x2)
#     x4 = self.down3(x3)
#     x5 = self.down4(x4)
#     x = self.up1(x5, x4)
#     x = self.up2(x, x3)
#     x = self.up3(x, x2)
#     x = self.up4(x, x1)
#     x = self.outc(x)
#     return F.sigmoid(x)


# ----------------------------------- RRDB (ESRGAN) ------------------------------------------


def initialize_weights(net_l, scale=1):
    """Kaiming-initialise every Conv2d/Linear (and reset BatchNorm2d) in the
    given network or list of networks, scaling conv/linear weights by
    ``scale`` (ESRGAN uses 0.1 for residual branches)."""
    if not isinstance(net_l, list):
        net_l = [net_l]
    for net in net_l:
        for m in net.modules():
            if isinstance(m, nn.Conv2d):
                torch.nn.init.kaiming_normal_(m.weight, a=0, mode='fan_in')
                m.weight.data *= scale  # for residual block
                if m.bias is not None:
                    m.bias.data.zero_()
            elif isinstance(m, nn.Linear):
                torch.nn.init.kaiming_normal_(m.weight, a=0, mode='fan_in')
                m.weight.data *= scale
                if m.bias is not None:
                    m.bias.data.zero_()
            elif isinstance(m, nn.BatchNorm2d):
                torch.nn.init.constant_(m.weight, 1)
                torch.nn.init.constant_(m.bias.data, 0.0)


def make_layer(block, n_layers):
    """Return ``nn.Sequential`` of ``n_layers`` fresh instances of ``block``."""
    layers = []
    for _ in range(n_layers):
        layers.append(block())
    return nn.Sequential(*layers)


class ResidualDenseBlock_5C(nn.Module):
    # Dense block of 5 convs; each conv sees the concat of all previous
    # feature maps. Output is residual-scaled by 0.2.
    def __init__(self, nf=64, gc=32, bias=True):
        super(ResidualDenseBlock_5C, self).__init__()
        # gc: growth channel, i.e. intermediate channels
        self.conv1 = nn.Conv2d(nf, gc, 3, 1, 1, bias=bias)
        self.conv2 = nn.Conv2d(nf + gc, gc, 3, 1, 1, bias=bias)
        self.conv3 = nn.Conv2d(nf + 2 * gc, gc, 3, 1, 1, bias=bias)
        self.conv4 = nn.Conv2d(nf + 3 * gc, gc, 3, 1, 1, bias=bias)
        self.conv5 = nn.Conv2d(nf + 4 * gc, nf, 3, 1, 1, bias=bias)
        self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)

        # initialization (0.1 scaling, as in ESRGAN, for residual stability)
        initialize_weights([self.conv1, self.conv2, self.conv3, self.conv4, self.conv5], 0.1)

    def forward(self, x):
        x1 = self.lrelu(self.conv1(x))
        x2 = self.lrelu(self.conv2(torch.cat((x, x1), 1)))
        x3 = self.lrelu(self.conv3(torch.cat((x, x1, x2), 1)))
        x4 = self.lrelu(self.conv4(torch.cat((x, x1, x2, x3), 1)))
        x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1))
        return x5 * 0.2 + x


class RRDB(nn.Module):
    '''Residual in Residual Dense Block'''

    def __init__(self, nf, gc=32):
        super(RRDB, self).__init__()
        self.RDB1 = ResidualDenseBlock_5C(nf, gc)
        self.RDB2 = ResidualDenseBlock_5C(nf, gc)
        self.RDB3 = ResidualDenseBlock_5C(nf, gc)

    def forward(self, x):
        out = self.RDB1(x)
        out = self.RDB2(out)
        out = self.RDB3(out)
        return out * 0.2 + x


class RRDBNet(nn.Module):
    # ESRGAN generator: first conv -> RRDB trunk -> single nearest-neighbour
    # upsample by ``opt.scale`` -> HR conv -> last conv. The second upsample
    # stage (upconv2) is constructed but its use is commented out below.
    def __init__(self, opt, gc=32):
        super(RRDBNet, self).__init__()
        RRDB_block_f = functools.partial(RRDB, nf=opt.n_feats, gc=gc)

        self.conv_first = nn.Conv2d(opt.nch_in, opt.n_feats, 3, 1, 1, bias=True)
        self.RRDB_trunk = make_layer(RRDB_block_f, opt.n_resblocks)
        self.trunk_conv = nn.Conv2d(opt.n_feats, opt.n_feats, 3, 1, 1, bias=True)
        #### upsampling
        self.upconv1 = nn.Conv2d(opt.n_feats, opt.n_feats, 3, 1, 1, bias=True)
        self.upconv2 = nn.Conv2d(opt.n_feats, opt.n_feats, 3, 1, 1, bias=True)
        self.HRconv = nn.Conv2d(opt.n_feats, opt.n_feats, 3, 1, 1, bias=True)
        self.conv_last = nn.Conv2d(opt.n_feats, opt.nch_out, 3, 1, 1, bias=True)

        self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
        self.scale = opt.scale

    def forward(self, x):
        fea = self.conv_first(x)
        trunk = self.trunk_conv(self.RRDB_trunk(fea))
        fea = fea + trunk  # global residual connection around the trunk

        fea = self.lrelu(self.upconv1(F.interpolate(fea, scale_factor=self.scale, mode='nearest')))
        # fea = self.lrelu(self.upconv2(F.interpolate(fea, scale_factor=self.scale, mode='nearest')))
        out = self.conv_last(self.lrelu(self.HRconv(fea)))

        return out


# ----------------------------------- SRGAN ------------------------------------------


def swish(x):
    """Swish activation: x * sigmoid(x)."""
    return x * torch.sigmoid(x)


class FeatureExtractor(nn.Module):
    # Truncated pretrained CNN (e.g. VGG) used for perceptual loss.
    def __init__(self, cnn, feature_layer=11):
        super(FeatureExtractor, self).__init__()
        self.features = nn.Sequential(*list(cnn.features.children())[:(feature_layer + 1)])

    def forward(self, x):
        return self.features(x)


class residualBlock(nn.Module):
    # SRGAN residual block: conv-BN-swish-conv-BN with identity skip.
    def __init__(self, in_channels=64, k=3, n=64, s=1):
        super(residualBlock, self).__init__()

        self.conv1 = nn.Conv2d(in_channels, n, k, stride=s, padding=1)
        self.bn1 = nn.BatchNorm2d(n)
        self.conv2 = nn.Conv2d(n, n, k, stride=s, padding=1)
        self.bn2 = nn.BatchNorm2d(n)

    def forward(self, x):
        y = swish(self.bn1(self.conv1(x)))
        return self.bn2(self.conv2(y)) + x


class upsampleBlock(nn.Module):
    # Implements resize-convolution
    def __init__(self, in_channels, out_channels):
        super(upsampleBlock, self).__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, 3, stride=1, padding=1)
        self.shuffler = nn.PixelShuffle(2)

    def forward(self, x):
        return swish(self.shuffler(self.conv(x)))


class Generator(nn.Module):
    # SRResNet/SRGAN generator. NOTE: the PixelShuffle upsample stages are
    # commented out, so the network currently runs at input resolution.
    def __init__(self, n_residual_blocks, opt):
        super(Generator, self).__init__()
        self.n_residual_blocks = n_residual_blocks
        self.upsample_factor = opt.scale

        self.conv1 = nn.Conv2d(opt.nch_in, 64, 9, stride=1, padding=4)

        if not opt.norm == None:
            self.normalize, self.unnormalize = normalizationTransforms(opt.norm)
        else:
            self.normalize, self.unnormalize = None, None

        for i in range(self.n_residual_blocks):
            self.add_module('residual_block' + str(i + 1), residualBlock())

        self.conv2 = nn.Conv2d(64, 64, 3, stride=1, padding=1)
        self.bn2 = nn.BatchNorm2d(64)

        # for i in range(int(self.upsample_factor/2)):
        #     self.add_module('upsample' + str(i+1), upsampleBlock(64, 256))

        if opt.task == 'segment':
            self.conv3 = nn.Conv2d(64, 2, 1)
        else:
            self.conv3 = nn.Conv2d(64, opt.nch_out, 9, stride=1, padding=4)

    def forward(self, x):
        if not self.normalize == None:
            x = self.normalize(x)

        x = swish(self.conv1(x))

        y = x.clone()
        for i in range(self.n_residual_blocks):
            y = self.__getattr__('residual_block' + str(i + 1))(y)

        # long skip connection around the residual trunk
        x = self.bn2(self.conv2(y)) + x

        # for i in range(int(self.upsample_factor/2)):
        #     x = self.__getattr__('upsample' + str(i+1))(x)

        x = self.conv3(x)
        if not self.unnormalize == None:
            x = self.unnormalize(x)
        return x


class Discriminator(nn.Module):
    # SRGAN discriminator; FC head replaced by a fully-convolutional 1x1 conv
    # followed by global average pooling.
    def __init__(self, opt):
        super(Discriminator, self).__init__()
        self.conv1 = nn.Conv2d(opt.nch_out, 64, 3, stride=1, padding=1)

        self.conv2 = nn.Conv2d(64, 64, 3, stride=2, padding=1)
        self.bn2 = nn.BatchNorm2d(64)
        self.conv3 = nn.Conv2d(64, 128, 3, stride=1, padding=1)
        self.bn3 = nn.BatchNorm2d(128)
        self.conv4 = nn.Conv2d(128, 128, 3, stride=2, padding=1)
        self.bn4 = nn.BatchNorm2d(128)
        self.conv5 = nn.Conv2d(128, 256, 3, stride=1, padding=1)
        self.bn5 = nn.BatchNorm2d(256)
        self.conv6 = nn.Conv2d(256, 256, 3, stride=2, padding=1)
        self.bn6 = nn.BatchNorm2d(256)
        self.conv7 = nn.Conv2d(256, 512, 3, stride=1, padding=1)
        self.bn7 = nn.BatchNorm2d(512)
        self.conv8 = nn.Conv2d(512, 512, 3, stride=2, padding=1)
        self.bn8 = nn.BatchNorm2d(512)

        # Replaced original paper FC layers with FCN
        # NOTE(review): padding=1 on a 1x1 conv pads the score map with
        # border responses before pooling — looks unintentional; confirm.
        self.conv9 = nn.Conv2d(512, 1, 1, stride=1, padding=1)

    def forward(self, x):
        x = swish(self.conv1(x))

        x = swish(self.bn2(self.conv2(x)))
        x = swish(self.bn3(self.conv3(x)))
        x = swish(self.bn4(self.conv4(x)))
        x = swish(self.bn5(self.conv5(x)))
        x = swish(self.bn6(self.conv6(x)))
        x = swish(self.bn7(self.conv7(x)))
        x = swish(self.bn8(self.conv8(x)))

        x = self.conv9(x)
        # global average pool of the patch scores -> one logit per image
        return torch.sigmoid(F.avg_pool2d(x, x.size()[2:])).view(x.size()[0], -1)


class UNet_n2n(nn.Module):
    """Custom U-Net architecture for Noise2Noise (see Appendix, Table 2)."""

    # NOTE(review): the ``opt = {}`` default is a mutable default argument and
    # a plain dict has no ``.task`` attribute, so constructing this class
    # without a real options object would raise AttributeError — confirm all
    # call sites pass ``opt``.
    def __init__(self, in_channels=3, out_channels=3, opt = {}):
        """Initializes U-Net."""

        super(UNet_n2n, self).__init__()

        # Layers: enc_conv0, enc_conv1, pool1
        self._block1 = nn.Sequential(
            nn.Conv2d(in_channels, 48, 3, stride=1, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(48, 48, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2))

        # Layers: enc_conv(i), pool(i); i=2..5
        self._block2 = nn.Sequential(
            nn.Conv2d(48, 48, 3, stride=1, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2))

        # Layers: enc_conv6, upsample5
        self._block3 = nn.Sequential(
            nn.Conv2d(48, 48, 3, stride=1, padding=1),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(48, 48, 3, stride=2, padding=1, output_padding=1))
            #nn.Upsample(scale_factor=2, mode='nearest'))

        # Layers: dec_conv5a, dec_conv5b, upsample4
        self._block4 = nn.Sequential(
            nn.Conv2d(96, 96, 3, stride=1, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(96, 96, 3, stride=1, padding=1),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(96, 96, 3, stride=2, padding=1, output_padding=1))
            #nn.Upsample(scale_factor=2, mode='nearest'))

        # Layers: dec_deconv(i)a, dec_deconv(i)b, upsample(i-1); i=4..2
        self._block5 = nn.Sequential(
            nn.Conv2d(144, 96, 3, stride=1, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(96, 96, 3, stride=1, padding=1),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(96, 96, 3, stride=2, padding=1, output_padding=1))
            #nn.Upsample(scale_factor=2, mode='nearest'))

        # Layers: dec_conv1a, dec_conv1b, dec_conv1c,
        self._block6 = nn.Sequential(
            nn.Conv2d(96 + in_channels, 64, 3, stride=1, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 32, 3, stride=1, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(32, out_channels, 3, stride=1, padding=1),
            nn.LeakyReLU(0.1))

        # Initialize weights
        self._init_weights()

        self.task = opt.task
        # segmentation variant: replace the final block with a 2-class head
        if opt.task == 'segment':
            self._block6 = nn.Sequential(
                nn.Conv2d(96 + in_channels, 64, 3, stride=1, padding=1),
                nn.ReLU(inplace=True),
                nn.Conv2d(64, 32, 3, stride=1, padding=1),
                nn.ReLU(inplace=True),
                nn.Conv2d(32, 2, 1))

    def _init_weights(self):
        """Initializes weights using He et al. (2015)."""

        for m in self.modules():
            if isinstance(m, nn.ConvTranspose2d) or isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight.data)
                m.bias.data.zero_()

    def forward(self, x):
        """Through encoder, then decoder by adding U-skip connections. """

        # Encoder
        pool1 = self._block1(x)
        pool2 = self._block2(pool1)
        pool3 = self._block2(pool2)
        pool4 = self._block2(pool3)
        pool5 = self._block2(pool4)

        # Decoder
        upsample5 = self._block3(pool5)
        concat5 = torch.cat((upsample5, pool4), dim=1)
        upsample4 = self._block4(concat5)
        concat4 = torch.cat((upsample4, pool3), dim=1)
        upsample3 = self._block5(concat4)
        concat3 = torch.cat((upsample3, pool2), dim=1)
        upsample2 = self._block5(concat3)
        concat2 = torch.cat((upsample2, pool1), dim=1)
        upsample1 = self._block5(concat2)
        concat1 = torch.cat((upsample1, x), dim=1)

        # Final activation
        return self._block6(concat1)


# ------------------ Alternative UNet implementation (batchnorm. outcommented)


class double_conv(nn.Module):
    '''(conv => BN => ReLU) * 2'''

    def __init__(self, in_ch, out_ch):
        super(double_conv, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, 3, padding=1),
            # nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_ch, out_ch, 3, padding=1),
            # nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        x = self.conv(x)
        return x


class inconv(nn.Module):
    # Entry block: a double_conv without pooling.
    def __init__(self, in_ch, out_ch):
        super(inconv, self).__init__()
        self.conv = double_conv(in_ch, out_ch)

    def forward(self, x):
        x = self.conv(x)
        return x


class down(nn.Module):
    # Downscale by max-pool, then double_conv.
    def __init__(self, in_ch, out_ch):
        super(down, self).__init__()
        self.mpconv = nn.Sequential(
            nn.MaxPool2d(2),
            # nn.Conv2d(in_ch,in_ch, 2, stride=2),
            double_conv(in_ch, out_ch)
        )

    def forward(self, x):
        x = self.mpconv(x)
        return x


class up(nn.Module):
    # NOTE(review): this class continues beyond the end of this chunk; only
    # the visible part is reproduced here.
    def __init__(self, in_ch, out_ch, bilinear=False):
        super(up, self).__init__()

        # would be a nice idea if the upsampling could be learned too,
        # but my machine does not have enough memory to handle all those weights
        if bilinear:
self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True) else: self.up = nn.ConvTranspose2d(in_ch//2, in_ch//2, 2, stride=2) self.conv = double_conv(in_ch, out_ch) def forward(self, x1, x2): x1 = self.up(x1) # input is CHW diffY = x2.size()[2] - x1.size()[2] diffX = x2.size()[3] - x1.size()[3] x1 = F.pad(x1, (diffX // 2, diffX - diffX//2, diffY // 2, diffY - diffY//2)) # for padding issues, see # https://github.com/HaiyongJiang/U-Net-Pytorch-Unstructured-Buggy/commit/0e854509c2cea854e247a9c615f175f76fbb2e3a # https://github.com/xiaopeng-liao/Pytorch-UNet/commit/8ebac70e633bac59fc22bb5195e513d5832fb3bd x = torch.cat([x2, x1], dim=1) x = self.conv(x) return x class outconv(nn.Module): def __init__(self, in_ch, out_ch): super(outconv, self).__init__() self.conv = nn.Conv2d(in_ch, out_ch, 1) def forward(self, x): x = self.conv(x) return x class UNet(nn.Module): def __init__(self, n_channels, n_classes,opt): super(UNet, self).__init__() self.inc = inconv(n_channels, 64) self.down1 = down(64, 128) self.down2 = down(128, 256) self.down3 = down(256, 512) self.down4 = down(512, 512) self.up1 = up(1024, 256) self.up2 = up(512, 128) self.up3 = up(256, 64) self.up4 = up(128, 64) if opt.task == 'segment': self.outc = outconv(64, 2) else: self.outc = outconv(64, n_classes) # Initialize weights # self._init_weights() def _init_weights(self): """Initializes weights using He et al. 
(2015).""" for m in self.modules(): if isinstance(m, nn.ConvTranspose2d) or isinstance(m, nn.Conv2d): nn.init.kaiming_normal_(m.weight.data) m.bias.data.zero_() def forward(self, x): x1 = self.inc(x) x2 = self.down1(x1) x3 = self.down2(x2) x4 = self.down3(x3) x5 = self.down4(x4) x = self.up1(x5, x4) x = self.up2(x, x3) x = self.up3(x, x2) x = self.up4(x, x1) x = self.outc(x) return F.sigmoid(x) class UNet60M(nn.Module): def __init__(self, n_channels, n_classes): super(UNet60M, self).__init__() self.inc = inconv(n_channels, 64) self.down1 = down(64, 128) self.down2 = down(128, 256) self.down3 = down(256, 512) self.down4 = down(512, 1024) self.down5 = down(1024, 1024) self.up1 = up(2048, 512) self.up2 = up(1024, 256) self.up3 = up(512, 128) self.up4 = up(256, 64) self.up5 = up(128, 64) self.outc = outconv(64, n_classes) def forward(self, x): x1 = self.inc(x) x2 = self.down1(x1) x3 = self.down2(x2) x4 = self.down3(x3) x5 = self.down4(x4) x6 = self.down5(x5) x = self.up1(x6, x5) x = self.up2(x, x4) x = self.up3(x, x3) x = self.up4(x, x2) x = self.up5(x, x1) x = self.outc(x) return F.sigmoid(x) class UNetRep(nn.Module): def __init__(self, n_channels, n_classes): super(UNetRep, self).__init__() self.inc = inconv(n_channels, 64) self.down1 = down(64, 128) self.down2 = down(128, 128) self.up1 = up1(256, 128, 128) self.up2 = up1(192, 64, 128) self.outc = outconv(64, n_classes) def forward(self, x): x1 = self.inc(x) for _ in range(3): x2 = self.down1(x1) x3 = self.down2(x2) x = self.up1(x3,x2) x1 = self.up2(x,x1) # x6 = self.down5(x5) # x = self.up1(x6, x5) # x = self.up2(x, x4) # x = self.up3(x, x3) # x = self.up4(x, x2) # x = self.up5(x, x1) x = self.outc(x1) return F.sigmoid(x) # ------------------- UNet Noise2noise implementation class single_conv(nn.Module): '''(conv => BN => ReLU) * 2''' def __init__(self, in_ch, out_ch): super(single_conv, self).__init__() self.conv = nn.Sequential( nn.Conv2d(in_ch, out_ch, 3, padding=1), # nn.BatchNorm2d(out_ch), 
nn.ReLU(inplace=True), ) def forward(self, x): x = self.conv(x) return x class outconv2(nn.Module): def __init__(self, in_ch, out_ch): super(outconv2, self).__init__() self.conv = nn.Conv2d(in_ch, out_ch, 3, padding=1) def forward(self, x): x = self.conv(x) return x class up1(nn.Module): def __init__(self, in_ch, out_ch, convtr, bilinear=False): super(up1, self).__init__() # would be a nice idea if the upsampling could be learned too, # but my machine do not have enough memory to handle all those weights if bilinear: self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True) else: self.up = nn.ConvTranspose2d(convtr, convtr, 3, stride=2) self.conv = double_conv(in_ch, out_ch) def forward(self, x1, x2): x1 = self.up(x1) # input is CHW diffY = x2.size()[2] - x1.size()[2] diffX = x2.size()[3] - x1.size()[3] x1 = F.pad(x1, (diffX // 2, diffX - diffX//2, diffY // 2, diffY - diffY//2)) # for padding issues, see # https://github.com/HaiyongJiang/U-Net-Pytorch-Unstructured-Buggy/commit/0e854509c2cea854e247a9c615f175f76fbb2e3a # https://github.com/xiaopeng-liao/Pytorch-UNet/commit/8ebac70e633bac59fc22bb5195e513d5832fb3bd x = torch.cat([x2, x1], dim=1) x = self.conv(x) return x class up2(nn.Module): def __init__(self, in_ch, in_ch2, out_ch,out_ch2,convtr, bilinear=False): super(up2, self).__init__() # would be a nice idea if the upsampling could be learned too, # but my machine do not have enough memory to handle all those weights if bilinear: self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True) else: self.up = nn.ConvTranspose2d(convtr, convtr, 3, stride=2) # self.conv = double_conv(in_ch, out_ch) self.conv = nn.Conv2d(in_ch + in_ch2, out_ch, 3, padding=1) self.conv2 = nn.Conv2d(out_ch, out_ch2, 3, padding=1) def forward(self, x1, x2): x1 = self.up(x1) # input is CHW diffY = x2.size()[2] - x1.size()[2] diffX = x2.size()[3] - x1.size()[3] x1 = F.pad(x1, (diffX // 2, diffX - diffX//2, diffY // 2, diffY - diffY//2)) # for padding issues, 
see # https://github.com/HaiyongJiang/U-Net-Pytorch-Unstructured-Buggy/commit/0e854509c2cea854e247a9c615f175f76fbb2e3a # https://github.com/xiaopeng-liao/Pytorch-UNet/commit/8ebac70e633bac59fc22bb5195e513d5832fb3bd x = torch.cat([x2, x1], dim=1) x = self.conv(x) x = self.conv2(x) return x class down2(nn.Module): def __init__(self, in_ch, out_ch): super(down2, self).__init__() self.mpconv = nn.Sequential( # nn.MaxPool2d(2), nn.Conv2d(in_ch,in_ch, 2, stride=2), single_conv(in_ch, out_ch) ) def forward(self, x): x = self.mpconv(x) return x class UNetGreedy(nn.Module): def __init__(self, n_channels, n_classes): super(UNetGreedy, self).__init__() self.inc = inconv(n_channels, 144) self.down1 = down(144, 144) self.down2 = down2(144, 144) self.down3 = down2(144, 144) self.down4 = down2(144, 144) self.down5 = down2(144, 144) self.up1 = up1(288, 288,144) self.up2 = up1(432, 288,288) self.up3 = up1(432, 288,288) self.up4 = up1(432, 288,288) self.up5 = up2(288, n_channels, 64, 32,288) self.outc = outconv2(32, n_classes) def forward(self, x0): x1 = self.inc(x0) x2 = self.down1(x1) x3 = self.down2(x2) x4 = self.down3(x3) x5 = self.down4(x4) x6 = self.down5(x5) x = self.up1(x6, x5) x = self.up2(x, x4) x = self.up3(x, x3) x = self.up4(x, x2) x = self.up5(x, x0) x = self.outc(x) return F.sigmoid(x) class UNet2(nn.Module): def __init__(self, n_channels, n_classes): super(UNet2, self).__init__() self.inc = inconv(n_channels, 48) self.down1 = down(48, 48) self.down2 = down2(48, 48) self.down3 = down2(48, 48) self.down4 = down2(48, 48) self.down5 = down2(48, 48) self.up1 = up1(96, 96,48) self.up2 = up1(144, 96,96) self.up3 = up1(144, 96,96) self.up4 = up1(144, 96,96) self.up5 = up2(96, n_channels, 64, 32,96) self.outc = outconv2(32, n_classes) def forward(self, x0): x1 = self.inc(x0) x2 = self.down1(x1) x3 = self.down2(x2) x4 = self.down3(x3) x5 = self.down4(x4) x6 = self.down5(x5) x = self.up1(x6, x5) x = self.up2(x, x4) x = self.up3(x, x3) x = self.up4(x, x2) x = self.up5(x, x0) x = 
self.outc(x) return F.sigmoid(x) class MLPNet(nn.Module): def __init__(self): super(MLPNet, self).__init__() # 1 input image channel, 6 output channels, 5x5 square convolution kernel self.conv1 = nn.Conv2d(3, 64, 3, padding=1) self.conv12 = nn.Conv2d(64, 64, 3, padding=1) self.pool = nn.MaxPool2d(2,2) self.conv2 = nn.Conv2d(64, 96, 3, padding=1) self.conv22 = nn.Conv2d(96, 128, 3, padding=1) self.conv3 = nn.Conv2d(96, 128, 3, padding=1) self.conv4 = nn.Conv2d(128, 128, 3, padding=1) self.conv5 = nn.Conv2d(128, 64, 3, padding=1) self.conv6 = nn.Conv2d(64, 32, 5) # self.conv3 = nn.Conv2d(24, 48, 3, padding=1) self.fc = nn.Sequential( nn.Linear(6*6*32, 100), nn.ReLU(), nn.Dropout2d(p=0.2), nn.Linear(100, 1), nn.ReLU() ) # self.fc2 = nn.Linear(100,50) # self.fc3 = nn.Linear(50,20) # self.fc4 = nn.Linear(20,1) def forward(self, x): x = F.relu(self.conv1(x)) x = F.relu(self.conv12(x)) x = self.pool(x) x = F.relu(self.conv2(x)) # x = F.relu(self.conv22(x)) x = self.pool(x) x = F.relu(self.conv3(x)) x = F.relu(self.conv4(x)) x = self.pool(x) x = F.relu(self.conv5(x)) x = self.pool(x) x = F.relu(self.conv6(x)) x = self.pool(x) x = x.view(-1, 6*6*32) x = self.fc(x) # x = F.relu(self.fc1(x)) # x = F.relu(self.fc2(x)) return x # --------------------- FFDNet from torch.autograd import Function, Variable def concatenate_input_noise_map(input, noise_sigma): r"""Implements the first layer of FFDNet. This function returns a torch.autograd.Variable composed of the concatenation of the downsampled input image and the noise map. Each image of the batch of size CxHxW gets converted to an array of size 4*CxH/2xW/2. Each of the pixels of the non-overlapped 2x2 patches of the input image are placed in the new array along the first dimension. 
Args: input: batch containing CxHxW images noise_sigma: the value of the pixels of the CxH/2xW/2 noise map """ # noise_sigma is a list of length batch_size N, C, H, W = input.size() dtype = input.type() sca = 2 sca2 = sca*sca Cout = sca2*C Hout = H//sca Wout = W//sca idxL = [[0, 0], [0, 1], [1, 0], [1, 1]] # Fill the downsampled image with zeros if 'cuda' in dtype: downsampledfeatures = torch.cuda.FloatTensor(N, Cout, Hout, Wout).fill_(0) else: downsampledfeatures = torch.FloatTensor(N, Cout, Hout, Wout).fill_(0) # Build the CxH/2xW/2 noise map noise_map = noise_sigma.view(N, 1, 1, 1).repeat(1, C, Hout, Wout) # Populate output for idx in range(sca2): downsampledfeatures[:, idx:Cout:sca2, :, :] = \ input[:, :, idxL[idx][0]::sca, idxL[idx][1]::sca] # concatenate de-interleaved mosaic with noise map return torch.cat((noise_map, downsampledfeatures), 1) class UpSampleFeaturesFunction(Function): r"""Extends PyTorch's modules by implementing a torch.autograd.Function. This class implements the forward and backward methods of the last layer of FFDNet. 
It basically performs the inverse of concatenate_input_noise_map(): it converts each of the images of a batch of size CxH/2xW/2 to images of size C/4xHxW """ @staticmethod def forward(ctx, input): N, Cin, Hin, Win = input.size() dtype = input.type() sca = 2 sca2 = sca*sca Cout = Cin//sca2 Hout = Hin*sca Wout = Win*sca idxL = [[0, 0], [0, 1], [1, 0], [1, 1]] assert (Cin%sca2 == 0), \ 'Invalid input dimensions: number of channels should be divisible by 4' result = torch.zeros((N, Cout, Hout, Wout)).type(dtype) for idx in range(sca2): result[:, :, idxL[idx][0]::sca, idxL[idx][1]::sca] = \ input[:, idx:Cin:sca2, :, :] return result @staticmethod def backward(ctx, grad_output): N, Cg_out, Hg_out, Wg_out = grad_output.size() dtype = grad_output.data.type() sca = 2 sca2 = sca*sca Cg_in = sca2*Cg_out Hg_in = Hg_out//sca Wg_in = Wg_out//sca idxL = [[0, 0], [0, 1], [1, 0], [1, 1]] # Build output grad_input = torch.zeros((N, Cg_in, Hg_in, Wg_in)).type(dtype) # Populate output for idx in range(sca2): grad_input[:, idx:Cg_in:sca2, :, :] = \ grad_output.data[:, :, idxL[idx][0]::sca, idxL[idx][1]::sca] return Variable(grad_input) # Alias functions upsamplefeatures = UpSampleFeaturesFunction.apply class UpSampleFeatures(nn.Module): r"""Implements the last layer of FFDNet """ def __init__(self): super(UpSampleFeatures, self).__init__() def forward(self, x): return upsamplefeatures(x) class IntermediateDnCNN(nn.Module): r"""Implements the middel part of the FFDNet architecture, which is basically a DnCNN net """ def __init__(self, input_features, middle_features, num_conv_layers): super(IntermediateDnCNN, self).__init__() self.kernel_size = 3 self.padding = 1 self.input_features = input_features self.num_conv_layers = num_conv_layers self.middle_features = middle_features if self.input_features == 5: self.output_features = 4 #Grayscale image elif self.input_features == 15: self.output_features = 12 #RGB image else: self.output_features = 3 # raise Exception('Invalid number of input 
features') layers = [] layers.append(nn.Conv2d(in_channels=self.input_features,\ out_channels=self.middle_features,\ kernel_size=self.kernel_size,\ padding=self.padding,\ bias=False)) layers.append(nn.ReLU(inplace=True)) for _ in range(self.num_conv_layers-2): layers.append(nn.Conv2d(in_channels=self.middle_features,\ out_channels=self.middle_features,\ kernel_size=self.kernel_size,\ padding=self.padding,\ bias=False)) # layers.append(nn.BatchNorm2d(self.middle_features)) layers.append(nn.ReLU(inplace=True)) layers.append(nn.Conv2d(in_channels=self.middle_features,\ out_channels=self.output_features,\ kernel_size=self.kernel_size,\ padding=self.padding,\ bias=False)) self.itermediate_dncnn = nn.Sequential(*layers) def forward(self, x): out = self.itermediate_dncnn(x) return out class FFDNet(nn.Module): r"""Implements the FFDNet architecture """ def __init__(self, num_input_channels, test_mode=False): super(FFDNet, self).__init__() self.num_input_channels = num_input_channels self.test_mode = test_mode if self.num_input_channels == 1: # Grayscale image self.num_feature_maps = 64 self.num_conv_layers = 15 self.downsampled_channels = 5 self.output_features = 4 elif self.num_input_channels == 3: # RGB image self.num_feature_maps = 96 self.num_conv_layers = 12 self.downsampled_channels = 15 self.output_features = 12 else: raise Exception('Invalid number of input features') self.intermediate_dncnn = IntermediateDnCNN(\ input_features=self.downsampled_channels,\ middle_features=self.num_feature_maps,\ num_conv_layers=self.num_conv_layers) self.upsamplefeatures = UpSampleFeatures() def forward(self, x, noise_sigma): concat_noise_x = concatenate_input_noise_map(\ x.data, noise_sigma.data) if self.test_mode: concat_noise_x = Variable(concat_noise_x, volatile=True) else: concat_noise_x = Variable(concat_noise_x) h_dncnn = self.intermediate_dncnn(concat_noise_x) pred_noise = self.upsamplefeatures(h_dncnn) return pred_noise class DNCNN(nn.Module): r"""Implements the DNCNNNet 
architecture """ def __init__(self, num_input_channels, test_mode=False): super(DNCNN, self).__init__() self.num_input_channels = num_input_channels self.test_mode = test_mode if self.num_input_channels == 1: # Grayscale image self.num_feature_maps = 64 self.num_conv_layers = 15 self.downsampled_channels = 5 self.output_features = 4 elif self.num_input_channels == 3: # RGB image self.num_feature_maps = 96 self.num_conv_layers = 12 self.downsampled_channels = 15 self.output_features = 12 else: raise Exception('Invalid number of input features') self.intermediate_dncnn = IntermediateDnCNN(\ input_features=self.num_input_channels,\ middle_features=self.num_feature_maps,\ num_conv_layers=self.num_conv_layers) def forward(self, x): dncnn = self.intermediate_dncnn(x) return dncnn