# Model architectures: EDSR variants, RCAN, RNAN, RRDB/ESRGAN, SRGAN generator,
# U-Net variants and Fourier-domain networks.
import math | |
import torch | |
import torch.nn as nn | |
import torch.nn.init | |
import torch.nn.functional as F | |
import functools # used by RRDBNet | |
def GetModel(opt):
    """Build and return the network selected by ``opt.model``, moved to ``opt.device``.

    Prints a warning and returns None when the model name is not recognised.
    Wraps the network in nn.DataParallel when ``opt.multigpu`` is set.
    """
    model = opt.model.lower()  # normalise once instead of once per branch
    if model == 'edsr':
        net = EDSR(opt)
    elif model == 'edsr2max':
        net = EDSR2Max(normalization=opt.norm, nch_in=opt.nch_in,
                       nch_out=opt.nch_out, scale=opt.scale)
    elif model == 'edsr3max':
        net = EDSR3Max(normalization=opt.norm, nch_in=opt.nch_in,
                       nch_out=opt.nch_out, scale=opt.scale)
    elif model == 'rcan':
        net = RCAN(opt)
    elif model == 'rnan':
        net = RNAN(opt)
    elif model == 'rrdb':
        net = RRDBNet(opt)
    elif model in ('srresnet', 'srgan'):
        net = Generator(16, opt)
    elif model == 'unet':
        net = UNet(opt.nch_in, opt.nch_out, opt)
    elif model == 'unet_n2n':
        net = UNet_n2n(opt.nch_in, opt.nch_out, opt)
    elif model == 'unet60m':
        net = UNet60M(opt.nch_in, opt.nch_out)
    elif model == 'unetrep':
        net = UNetRep(opt.nch_in, opt.nch_out)
    elif model == 'unetgreedy':
        net = UNetGreedy(opt.nch_in, opt.nch_out)
    elif model == 'mlpnet':
        net = MLPNet()
    elif model == 'ffdnet':
        net = FFDNet(opt.nch_in)
    elif model == 'dncnn':
        net = DNCNN(opt.nch_in)
    elif model == 'fouriernet':
        net = FourierNet()
    elif model == 'fourierconvnet':
        net = FourierConvNet()
    else:
        print("model undefined")
        return None
    net.to(opt.device)
    if opt.multigpu:
        net = nn.DataParallel(net)
    return net
class MeanShift(nn.Conv2d):
    """Fixed 1x1 conv that shifts/scales an RGB image by dataset statistics.

    With sign=-1 (default) it computes (x + sign*rgb_range*mean)/std, i.e.
    normalisation; constructed with reciprocal stats it undoes the shift.
    The layer is frozen and never trained.
    """

    def __init__(
        self, rgb_range,
        rgb_mean, rgb_std, sign=-1):
        super(MeanShift, self).__init__(3, 3, kernel_size=1)
        std = torch.Tensor(rgb_std)
        self.weight.data = torch.eye(3).view(3, 3, 1, 1) / std.view(3, 1, 1, 1)
        self.bias.data = sign * rgb_range * torch.Tensor(rgb_mean) / std
        # Bug fix: assigning ``self.requires_grad = False`` on the Module only
        # sets a plain attribute and does NOT freeze the conv parameters.
        # Freeze them explicitly so the shift can never be trained.
        for p in self.parameters():
            p.requires_grad = False
def normalizationTransforms(normtype):
    """Return a (normalize, unnormalize) MeanShift pair for a known dataset.

    Unknown ``normtype`` values fall through to (None, None).
    """
    # (forward mean/std, inverse mean/std, message) per supported dataset
    presets = {
        'div2k': (([0.4485, 0.4375, 0.4045], [0.2436, 0.2330, 0.2424]),
                  ([-1.8411, -1.8777, -1.6687], [4.1051, 4.2918, 4.1254]),
                  'using div2k normalization'),
        'pcam': (([0.6975, 0.5348, 0.688], [0.2361, 0.2786, 0.2146]),
                 ([-2.9547, -1.9198, -3.20643], [4.2363, 3.58972, 4.66049]),
                 'using pcam normalization'),
        'div2k_std1': (([0.4485, 0.4375, 0.4045], [1, 1, 1]),
                       ([-0.4485, -0.4375, -0.4045], [1, 1, 1]),
                       'using div2k normalization with std 1'),
        'pcam_std1': (([0.6975, 0.5348, 0.688], [1, 1, 1]),
                      ([-0.6975, -0.5348, -0.688], [1, 1, 1]),
                      'using pcam normalization with std 1'),
    }
    entry = presets.get(normtype.lower())
    if entry is None:
        print('not using normalization')
        return None, None
    (fwd_mean, fwd_std), (inv_mean, inv_std), msg = entry
    print(msg)
    return MeanShift(1, fwd_mean, fwd_std), MeanShift(1, inv_mean, inv_std)
def conv(in_channels, out_channels, kernel_size, bias=True):
    """'Same'-padding 2-D convolution used throughout the models."""
    pad = kernel_size // 2  # keeps spatial size for odd kernels
    return nn.Conv2d(in_channels, out_channels, kernel_size,
                     padding=pad, bias=bias)
class BasicBlock(nn.Sequential):
    """conv (+ optional BatchNorm) (+ optional activation) as one Sequential.

    NOTE(review): ``stride`` is accepted but never forwarded to the conv
    factory — kept for interface compatibility.
    """

    def __init__(
        self, conv, in_channels, out_channels, kernel_size, stride=1, bias=False,
        bn=True, act=nn.ReLU(True)):
        layers = [conv(in_channels, out_channels, kernel_size, bias=bias)]
        if bn:
            layers.append(nn.BatchNorm2d(out_channels))
        if act is not None:
            layers.append(act)
        super(BasicBlock, self).__init__(*layers)
class ResBlock(nn.Module):
    """Residual block: conv-act-conv (optionally with BatchNorm), scaled by
    ``res_scale``, plus the identity.

    Bug fix: the ``bn`` and ``act`` arguments were previously accepted but
    ignored — a fresh ReLU was always inserted and BatchNorm never was.  They
    are now honoured; the defaults reproduce the old behaviour exactly.
    """

    def __init__(
        self, conv, n_feats, kernel_size,
        bias=True, bn=False, act=nn.ReLU(True), res_scale=1):
        super(ResBlock, self).__init__()
        m = [conv(n_feats, n_feats, kernel_size, bias=bias)]
        if bn:
            m.append(nn.BatchNorm2d(n_feats))
        m.append(act)
        m.append(conv(n_feats, n_feats, kernel_size, bias=bias))
        if bn:
            m.append(nn.BatchNorm2d(n_feats))
        self.body = nn.Sequential(*m)
        self.res_scale = res_scale

    def forward(self, x):
        res = self.body(x).mul(self.res_scale)
        res += x
        return res
class ResBlock2Max(nn.Module):
    """Residual block with a 2-level max-pool encoder / transpose-conv decoder.

    Input H and W must be divisible by 4 so the decoder restores the original
    size before the residual add.
    """

    def __init__(
        self, conv, n_feats, kernel_size,
        bias=True, bn=False, act=nn.ReLU(True), res_scale=1):
        super(ResBlock2Max, self).__init__()
        self.body = nn.Sequential(
            conv(n_feats, n_feats, kernel_size, bias=bias),
            nn.MaxPool2d(2),
            nn.ReLU(True),
            conv(n_feats, 2 * n_feats, kernel_size, bias=bias),
            nn.MaxPool2d(2),
            nn.ReLU(True),
            conv(2 * n_feats, 4 * n_feats, kernel_size, bias=bias),
            nn.ReLU(True),
            nn.ConvTranspose2d(4 * n_feats, 2 * n_feats, 3,
                               stride=2, padding=1, output_padding=1),
            nn.ConvTranspose2d(2 * n_feats, n_feats, 3,
                               stride=2, padding=1, output_padding=1),
        )
        self.res_scale = res_scale

    def forward(self, x):
        return x + self.body(x).mul(self.res_scale)
class ResBlock3Max(nn.Module):
    """Residual block with a 3-level max-pool encoder / transpose-conv decoder.

    Input H and W must be divisible by 8 so the decoder restores the original
    size before the residual add.
    """

    def __init__(
        self, conv, n_feats, kernel_size,
        bias=True, bn=False, act=nn.ReLU(True), res_scale=1):
        super(ResBlock3Max, self).__init__()
        self.body = nn.Sequential(
            conv(n_feats, 2 * n_feats, kernel_size, bias=bias),
            nn.MaxPool2d(2),
            nn.ReLU(True),
            conv(2 * n_feats, 2 * n_feats, kernel_size, bias=bias),
            nn.MaxPool2d(2),
            nn.ReLU(True),
            conv(2 * n_feats, 4 * n_feats, kernel_size, bias=bias),
            nn.MaxPool2d(2),
            nn.ReLU(True),
            conv(4 * n_feats, 8 * n_feats, kernel_size, bias=bias),
            nn.ReLU(True),
            nn.ConvTranspose2d(8 * n_feats, 4 * n_feats, 3,
                               stride=2, padding=1, output_padding=1),
            nn.ConvTranspose2d(4 * n_feats, 2 * n_feats, 3,
                               stride=2, padding=1, output_padding=1),
            nn.ConvTranspose2d(2 * n_feats, n_feats, 3,
                               stride=2, padding=1, output_padding=1),
        )
        self.res_scale = res_scale

    def forward(self, x):
        return x + self.body(x).mul(self.res_scale)
class Upsampler(nn.Sequential):
    """PixelShuffle upsampling head for scale 2^n or 3.

    Raises NotImplementedError for any other scale.  ``act`` may be False,
    'relu' or 'prelu'.
    """

    def __init__(self, conv, scale, n_feats, bn=False, act=False, bias=True):
        layers = []

        def add_stage(r):
            # one conv + shuffle stage that upsamples by factor r
            layers.append(conv(n_feats, r * r * n_feats, 3, bias))
            layers.append(nn.PixelShuffle(r))
            if bn:
                layers.append(nn.BatchNorm2d(n_feats))
            if act == 'relu':
                layers.append(nn.ReLU(True))
            elif act == 'prelu':
                layers.append(nn.PReLU(n_feats))

        if (scale & (scale - 1)) == 0:  # power of two: stack x2 stages
            for _ in range(int(math.log(scale, 2))):
                add_stage(2)
        elif scale == 3:
            add_stage(3)
        else:
            raise NotImplementedError
        super(Upsampler, self).__init__(*layers)
class EDSR(nn.Module):
    """EDSR super-resolution network: 16 residual blocks, 64 features.

    opt fields read: norm, nch_in, nch_out, scale, task.
    """

    def __init__(self, opt):
        super(EDSR, self).__init__()
        n_resblocks = 16
        n_feats = 64
        kernel_size = 3

        # optional input normalisation / output de-normalisation pair
        if opt.norm is not None:
            self.normalize, self.unnormalize = normalizationTransforms(opt.norm)
        else:
            self.normalize, self.unnormalize = None, None

        # head: lift the input channels to the feature width
        self.head = nn.Sequential(conv(opt.nch_in, n_feats, kernel_size))

        # body: residual blocks followed by one fusing conv
        blocks = [ResBlock(conv, n_feats, kernel_size,
                           act=nn.ReLU(True), res_scale=0.1)
                  for _ in range(n_resblocks)]
        blocks.append(conv(n_feats, n_feats, kernel_size))
        self.body = nn.Sequential(*blocks)

        # tail: upsampler, or a plain projection when scale == 1
        if opt.scale == 1:
            if opt.task == 'segment':
                tail = [nn.Conv2d(n_feats, 2, 1)]
            else:
                tail = [conv(n_feats, opt.nch_out, kernel_size)]
        else:
            tail = [Upsampler(conv, opt.scale, n_feats, act=False),
                    conv(n_feats, opt.nch_out, kernel_size)]
        self.tail = nn.Sequential(*tail)

    def forward(self, x):
        if self.normalize is not None:
            x = self.normalize(x)
        x = self.head(x)
        res = self.body(x) + x  # long skip connection
        x = self.tail(res)
        if self.unnormalize is not None:
            x = self.unnormalize(x)
        return x
class EDSR2Max(nn.Module):
    """EDSR variant whose residual blocks use a 2-level max-pool encoder/decoder.

    Bug fix: __init__ previously tested the undefined name ``opt.norm`` (this
    constructor receives ``normalization``, not ``opt``), raising NameError
    whenever it ran; it now tests ``normalization`` itself.
    """

    def __init__(self, normalization=None, nch_in=3, nch_out=3, scale=4):
        super(EDSR2Max, self).__init__()
        n_resblocks = 16
        n_feats = 64
        kernel_size = 3
        act = nn.ReLU(True)
        # optional input normalisation / output de-normalisation pair
        if normalization is not None:
            self.normalize, self.unnormalize = normalizationTransforms(normalization)
        else:
            self.normalize, self.unnormalize = None, None
        # head: lift input channels to the feature width
        m_head = [conv(nch_in, n_feats, kernel_size)]
        # body: pooled residual blocks followed by one fusing conv
        m_body = [
            ResBlock2Max(
                conv, n_feats, kernel_size, act=act, res_scale=0.1
            ) for _ in range(n_resblocks)
        ]
        m_body.append(conv(n_feats, n_feats, kernel_size))
        # tail: plain projection (no upsampling in this variant)
        m_tail = [
            conv(n_feats, nch_out, kernel_size)
        ]
        self.head = nn.Sequential(*m_head)
        self.body = nn.Sequential(*m_body)
        self.tail = nn.Sequential(*m_tail)

    def forward(self, x):
        if self.normalize is not None:
            x = self.normalize(x)
        x = self.head(x)
        res = self.body(x)
        res += x  # long skip connection
        x = self.tail(res)
        if self.unnormalize is not None:
            x = self.unnormalize(x)
        return x
class EDSR3Max(nn.Module):
    """EDSR variant whose residual blocks use a 3-level max-pool encoder/decoder.

    Bug fix: __init__ previously tested the undefined name ``opt.norm`` (this
    constructor receives ``normalization``, not ``opt``), raising NameError
    whenever it ran; it now tests ``normalization`` itself.
    """

    def __init__(self, normalization=None, nch_in=3, nch_out=3, scale=4):
        super(EDSR3Max, self).__init__()
        n_resblocks = 16
        n_feats = 64
        kernel_size = 3
        act = nn.ReLU(True)
        # optional input normalisation / output de-normalisation pair
        if normalization is not None:
            self.normalize, self.unnormalize = normalizationTransforms(normalization)
        else:
            self.normalize, self.unnormalize = None, None
        # head: lift input channels to the feature width
        m_head = [conv(nch_in, n_feats, kernel_size)]
        # body: pooled residual blocks followed by one fusing conv
        m_body = [
            ResBlock3Max(
                conv, n_feats, kernel_size, act=act, res_scale=0.1
            ) for _ in range(n_resblocks)
        ]
        m_body.append(conv(n_feats, n_feats, kernel_size))
        # tail: plain projection (no upsampling in this variant)
        m_tail = [
            conv(n_feats, nch_out, kernel_size)
        ]
        self.head = nn.Sequential(*m_head)
        self.body = nn.Sequential(*m_body)
        self.tail = nn.Sequential(*m_tail)

    def forward(self, x):
        if self.normalize is not None:
            x = self.normalize(x)
        x = self.head(x)
        res = self.body(x)
        res += x  # long skip connection
        x = self.tail(res)
        if self.unnormalize is not None:
            x = self.unnormalize(x)
        return x
# ----------------------------------- RCAN ------------------------------------------ | |
## Channel Attention (CA) Layer | |
class CALayer(nn.Module):
    """Channel attention: squeeze (global avg-pool) then excite (1x1 convs + sigmoid).

    Each channel of x is rescaled by a learned weight in (0, 1).
    """

    def __init__(self, channel, reduction=16):
        super(CALayer, self).__init__()
        # squeeze each feature map to a single value
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        squeezed = channel // reduction
        # bottleneck MLP (as 1x1 convs) producing per-channel gates
        self.conv_du = nn.Sequential(
            nn.Conv2d(channel, squeezed, 1, padding=0, bias=True),
            nn.ReLU(inplace=True),
            nn.Conv2d(squeezed, channel, 1, padding=0, bias=True),
            nn.Sigmoid(),
        )

    def forward(self, x):
        gates = self.conv_du(self.avg_pool(x))
        return x * gates
## Residual Channel Attention Block (RCAB) | |
class RCAB(nn.Module):
    """Residual Channel Attention Block: conv-act-conv -> CALayer, plus identity.

    NOTE(review): ``res_scale`` is stored but not applied in forward (the
    scaled line is commented out in the original); behaviour preserved.
    """

    def __init__(
        self, conv, n_feat, kernel_size, reduction,
        bias=True, bn=False, act=nn.ReLU(True), res_scale=1):
        super(RCAB, self).__init__()
        layers = [conv(n_feat, n_feat, kernel_size, bias=bias)]
        if bn:
            layers.append(nn.BatchNorm2d(n_feat))
        layers.append(act)
        layers.append(conv(n_feat, n_feat, kernel_size, bias=bias))
        if bn:
            layers.append(nn.BatchNorm2d(n_feat))
        layers.append(CALayer(n_feat, reduction))
        self.body = nn.Sequential(*layers)
        self.res_scale = res_scale

    def forward(self, x):
        return self.body(x) + x
## Residual Group (RG) | |
class ResidualGroup(nn.Module):
    """Group of RCABs followed by a fusing conv, wrapped in a skip connection."""

    def __init__(self, conv, n_feat, kernel_size, reduction, act, res_scale, n_resblocks):
        super(ResidualGroup, self).__init__()
        blocks = [RCAB(conv, n_feat, kernel_size, reduction,
                       bias=True, bn=False, act=nn.ReLU(True), res_scale=1)
                  for _ in range(n_resblocks)]
        blocks.append(conv(n_feat, n_feat, kernel_size))
        self.body = nn.Sequential(*blocks)

    def forward(self, x):
        return self.body(x) + x
## Residual Channel Attention Network (RCAN) | |
class RCAN(nn.Module):
    """Residual Channel Attention Network.

    opt fields read: n_resgroups, n_resblocks, n_feats, reduction, narch,
    norm, nch_in, nch_out, scale, task.
    With narch != 0 the input is sliced into 9 single-channel planes, each
    lifted by its own head conv before fusion -- so the input is assumed to
    have 9 channels in that mode (TODO confirm with callers).
    """

    def __init__(self, opt):
        super(RCAN, self).__init__()
        n_resgroups = opt.n_resgroups
        n_resblocks = opt.n_resblocks
        n_feats = opt.n_feats
        kernel_size = 3
        reduction = opt.reduction
        act = nn.ReLU(True)
        self.narch = opt.narch
        # optional input normalisation / output de-normalisation pair
        if not opt.norm == None:
            self.normalize, self.unnormalize = normalizationTransforms(opt.norm)
        else:
            self.normalize, self.unnormalize = None, None
        # define head module: one shared conv, or one conv per input channel
        if self.narch == 0:
            modules_head = [conv(opt.nch_in, n_feats, kernel_size)]
            self.head = nn.Sequential(*modules_head)
        else:
            self.head0 = conv(1, n_feats, kernel_size)
            self.head1 = conv(1, n_feats, kernel_size)
            self.head2 = conv(1, n_feats, kernel_size)
            self.head3 = conv(1, n_feats, kernel_size)
            self.head4 = conv(1, n_feats, kernel_size)
            self.head5 = conv(1, n_feats, kernel_size)
            self.head6 = conv(1, n_feats, kernel_size)
            self.head7 = conv(1, n_feats, kernel_size)
            self.head8 = conv(1, n_feats, kernel_size)
            self.combineHead = conv(9*n_feats, n_feats, kernel_size)
        # define body module: n_resgroups residual groups + one fusing conv
        modules_body = [
            ResidualGroup(
                conv, n_feats, kernel_size, reduction, act=act, res_scale=1, n_resblocks=n_resblocks) \
            for _ in range(n_resgroups)]
        modules_body.append(conv(n_feats, n_feats, kernel_size))
        # define tail module: upsampler, or plain projection when scale == 1
        if opt.scale == 1:
            if opt.task == 'segment':
                modules_tail = [nn.Conv2d(n_feats, opt.nch_out, 1)]
            else:
                modules_tail = [conv(n_feats, opt.nch_out, kernel_size)]
        else:
            modules_tail = [
                Upsampler(conv, opt.scale, n_feats, act=False),
                conv(n_feats, opt.nch_out, kernel_size)]
        self.body = nn.Sequential(*modules_body)
        self.tail = nn.Sequential(*modules_tail)

    def forward(self, x):
        if not self.normalize == None:
            x = self.normalize(x)
        if self.narch == 0:
            x = self.head(x)
        else:
            # lift each of the 9 input channels separately, then fuse
            x0 = self.head0(x[:,0:0+1,:,:])
            x1 = self.head1(x[:,1:1+1,:,:])
            x2 = self.head2(x[:,2:2+1,:,:])
            x3 = self.head3(x[:,3:3+1,:,:])
            x4 = self.head4(x[:,4:4+1,:,:])
            x5 = self.head5(x[:,5:5+1,:,:])
            x6 = self.head6(x[:,6:6+1,:,:])
            x7 = self.head7(x[:,7:7+1,:,:])
            x8 = self.head8(x[:,8:8+1,:,:])
            x = torch.cat((x0,x1,x2,x3,x4,x5,x6,x7,x8), 1)
            x = self.combineHead(x)
        res = self.body(x)
        res += x  # long skip connection
        x = self.tail(res)
        if not self.unnormalize == None:
            x = self.unnormalize(x)
        return x
# ----------------------------------- RNAN ------------------------------------------ | |
# add NonLocalBlock2D | |
# reference: https://github.com/AlexHex7/Non-local_pytorch/blob/master/lib/non_local_simple_version.py | |
class NonLocalBlock2D(nn.Module):
    """Non-local (self-attention) block over spatial positions with residual output.

    W is zero-initialised so the block starts as an exact identity.
    NOTE(review): softmax is taken over dim=1, matching the RNAN reference
    implementation rather than the usual key dimension.
    """

    def __init__(self, in_channels, inter_channels):
        super(NonLocalBlock2D, self).__init__()
        self.in_channels = in_channels
        self.inter_channels = inter_channels
        self.g = nn.Conv2d(in_channels=self.in_channels, out_channels=self.inter_channels, kernel_size=1, stride=1, padding=0)
        self.W = nn.Conv2d(in_channels=self.inter_channels, out_channels=self.in_channels, kernel_size=1, stride=1, padding=0)
        # zero init => residual branch contributes nothing at the start
        nn.init.constant_(self.W.weight, 0)
        nn.init.constant_(self.W.bias, 0)
        self.theta = nn.Conv2d(in_channels=self.in_channels, out_channels=self.inter_channels, kernel_size=1, stride=1, padding=0)
        self.phi = nn.Conv2d(in_channels=self.in_channels, out_channels=self.inter_channels, kernel_size=1, stride=1, padding=0)

    def forward(self, x):
        batch = x.size(0)

        def as_seq(t):
            # (B, C', H, W) -> (B, C', H*W)
            return t.view(batch, self.inter_channels, -1)

        g_x = as_seq(self.g(x)).permute(0, 2, 1)
        theta_x = as_seq(self.theta(x)).permute(0, 2, 1)
        phi_x = as_seq(self.phi(x))
        attn = F.softmax(torch.matmul(theta_x, phi_x), dim=1)
        y = torch.matmul(attn, g_x).permute(0, 2, 1).contiguous()
        y = y.view(batch, self.inter_channels, *x.size()[2:])
        return self.W(y) + x
## define trunk branch | |
class TrunkBranch(nn.Module):
    """Trunk branch of the residual attention module: two plain ResBlocks."""

    def __init__(
        self, conv, n_feat, kernel_size,
        bias=True, bn=False, act=nn.ReLU(True), res_scale=1):
        super(TrunkBranch, self).__init__()
        self.body = nn.Sequential(
            ResBlock(conv, n_feat, kernel_size, bias=True, bn=False, act=nn.ReLU(True), res_scale=1),
            ResBlock(conv, n_feat, kernel_size, bias=True, bn=False, act=nn.ReLU(True), res_scale=1),
        )

    def forward(self, x):
        return self.body(x)
## define mask branch | |
class MaskBranchDownUp(nn.Module):
    """Mask branch: ResBlock, down-sample, process, up-sample, merge, sigmoid gate.

    Output values lie in (0, 1) and gate the trunk branch elementwise.
    """

    def __init__(
        self, conv, n_feat, kernel_size,
        bias=True, bn=False, act=nn.ReLU(True), res_scale=1):
        super(MaskBranchDownUp, self).__init__()

        def res_block():
            return ResBlock(conv, n_feat, kernel_size, bias=True, bn=False,
                            act=nn.ReLU(True), res_scale=1)

        self.MB_RB1 = nn.Sequential(res_block())
        self.MB_Down = nn.Sequential(nn.Conv2d(n_feat, n_feat, 3, stride=2, padding=1))
        self.MB_RB2 = nn.Sequential(res_block(), res_block())
        self.MB_Up = nn.Sequential(nn.ConvTranspose2d(n_feat, n_feat, 6, stride=2, padding=2))
        self.MB_RB3 = nn.Sequential(res_block())
        self.MB_1x1conv = nn.Sequential(nn.Conv2d(n_feat, n_feat, 1, padding=0, bias=True))
        self.MB_sigmoid = nn.Sequential(nn.Sigmoid())

    def forward(self, x):
        pre = self.MB_RB1(x)
        inner = self.MB_RB2(self.MB_Down(pre))
        merged = pre + self.MB_Up(inner)  # skip over the down/up path
        gated = self.MB_1x1conv(self.MB_RB3(merged))
        return self.MB_sigmoid(gated)
## define nonlocal mask branch | |
class NLMaskBranchDownUp(nn.Module):
    """Mask branch with a leading non-local block, then the same down/up path
    as MaskBranchDownUp; output is a sigmoid gate in (0, 1)."""

    def __init__(
        self, conv, n_feat, kernel_size,
        bias=True, bn=False, act=nn.ReLU(True), res_scale=1):
        super(NLMaskBranchDownUp, self).__init__()

        def res_block():
            return ResBlock(conv, n_feat, kernel_size, bias=True, bn=False,
                            act=nn.ReLU(True), res_scale=1)

        self.MB_RB1 = nn.Sequential(NonLocalBlock2D(n_feat, n_feat // 2), res_block())
        self.MB_Down = nn.Sequential(nn.Conv2d(n_feat, n_feat, 3, stride=2, padding=1))
        self.MB_RB2 = nn.Sequential(res_block(), res_block())
        self.MB_Up = nn.Sequential(nn.ConvTranspose2d(n_feat, n_feat, 6, stride=2, padding=2))
        self.MB_RB3 = nn.Sequential(res_block())
        self.MB_1x1conv = nn.Sequential(nn.Conv2d(n_feat, n_feat, 1, padding=0, bias=True))
        self.MB_sigmoid = nn.Sequential(nn.Sigmoid())

    def forward(self, x):
        pre = self.MB_RB1(x)
        inner = self.MB_RB2(self.MB_Down(pre))
        merged = pre + self.MB_Up(inner)  # skip over the down/up path
        gated = self.MB_1x1conv(self.MB_RB3(merged))
        return self.MB_sigmoid(gated)
## define residual attention module | |
class ResAttModuleDownUpPlus(nn.Module):
    """Residual attention module: trunk * mask + identity, then two tail ResBlocks."""

    def __init__(
        self, conv, n_feat, kernel_size,
        bias=True, bn=False, act=nn.ReLU(True), res_scale=1):
        super(ResAttModuleDownUpPlus, self).__init__()

        def res_block():
            return ResBlock(conv, n_feat, kernel_size, bias=True, bn=False,
                            act=nn.ReLU(True), res_scale=1)

        self.RA_RB1 = nn.Sequential(res_block())
        self.RA_TB = nn.Sequential(TrunkBranch(conv, n_feat, kernel_size, bias=True, bn=False, act=nn.ReLU(True), res_scale=1))
        self.RA_MB = nn.Sequential(MaskBranchDownUp(conv, n_feat, kernel_size, bias=True, bn=False, act=nn.ReLU(True), res_scale=1))
        self.RA_tail = nn.Sequential(res_block(), res_block())

    def forward(self, input):
        base = self.RA_RB1(input)
        gated = self.RA_TB(base) * self.RA_MB(base)  # mask gates the trunk
        return self.RA_tail(gated + base)
## define nonlocal residual attention module | |
class NLResAttModuleDownUpPlus(nn.Module):
    """Non-local residual attention module: trunk * NL-mask + identity, then
    two tail ResBlocks."""

    def __init__(
        self, conv, n_feat, kernel_size,
        bias=True, bn=False, act=nn.ReLU(True), res_scale=1):
        super(NLResAttModuleDownUpPlus, self).__init__()

        def res_block():
            return ResBlock(conv, n_feat, kernel_size, bias=True, bn=False,
                            act=nn.ReLU(True), res_scale=1)

        self.RA_RB1 = nn.Sequential(res_block())
        self.RA_TB = nn.Sequential(TrunkBranch(conv, n_feat, kernel_size, bias=True, bn=False, act=nn.ReLU(True), res_scale=1))
        self.RA_MB = nn.Sequential(NLMaskBranchDownUp(conv, n_feat, kernel_size, bias=True, bn=False, act=nn.ReLU(True), res_scale=1))
        self.RA_tail = nn.Sequential(res_block(), res_block())

    def forward(self, input):
        base = self.RA_RB1(input)
        gated = self.RA_TB(base) * self.RA_MB(base)  # mask gates the trunk
        return self.RA_tail(gated + base)
class _ResGroup(nn.Module):
    """One residual attention module followed by a fusing conv (no skip here)."""

    def __init__(self, conv, n_feats, kernel_size, act, res_scale):
        super(_ResGroup, self).__init__()
        self.body = nn.Sequential(
            ResAttModuleDownUpPlus(conv, n_feats, kernel_size, bias=True,
                                   bn=False, act=nn.ReLU(True), res_scale=1),
            conv(n_feats, n_feats, kernel_size),
        )

    def forward(self, x):
        return self.body(x)
class _NLResGroup(nn.Module):
    """One non-local residual attention module followed by a fusing conv."""

    def __init__(self, conv, n_feats, kernel_size, act, res_scale):
        super(_NLResGroup, self).__init__()
        self.body = nn.Sequential(
            NLResAttModuleDownUpPlus(conv, n_feats, kernel_size, bias=True,
                                     bn=False, act=nn.ReLU(True), res_scale=1),
            conv(n_feats, n_feats, kernel_size),
        )

    def forward(self, x):
        return self.body(x)
class RNAN(nn.Module):
    """Residual Non-local Attention Network: non-local group, plain groups,
    non-local group, upsampling tail.

    opt fields read: n_resgroups, n_feats, reduction, nch_in, nch_out, scale.

    Bug fix: the diagnostic print referenced the undefined names
    ``n_resgroup2`` and ``n_resblock`` (NameError as soon as the constructor
    ran); it now prints the variables that actually exist.
    """

    def __init__(self, opt):
        super(RNAN, self).__init__()
        n_resgroups = opt.n_resgroups
        n_feats = opt.n_feats
        kernel_size = 3
        reduction = opt.reduction
        act = nn.ReLU(True)
        print(n_resgroups, n_feats, kernel_size, reduction, act)
        # define head module
        modules_head = [conv(opt.nch_in, n_feats, kernel_size)]
        # define body module: one non-local group, (n_resgroups - 2) plain
        # groups + fusing conv, then one more non-local group
        modules_body_nl_low = [
            _NLResGroup(
                conv, n_feats, kernel_size, act=act, res_scale=1)]
        modules_body = [
            _ResGroup(
                conv, n_feats, kernel_size, act=act, res_scale=1)
            for _ in range(n_resgroups - 2)]
        modules_body_nl_high = [
            _NLResGroup(
                conv, n_feats, kernel_size, act=act, res_scale=1)]
        modules_body.append(conv(n_feats, n_feats, kernel_size))
        # define tail module
        modules_tail = [
            Upsampler(conv, opt.scale, n_feats, act=False),
            conv(n_feats, opt.nch_out, kernel_size)]
        self.head = nn.Sequential(*modules_head)
        self.body_nl_low = nn.Sequential(*modules_body_nl_low)
        self.body = nn.Sequential(*modules_body)
        self.body_nl_high = nn.Sequential(*modules_body_nl_high)
        self.tail = nn.Sequential(*modules_tail)

    def forward(self, x):
        feats_shallow = self.head(x)
        res = self.body_nl_low(feats_shallow)
        res = self.body(res)
        res = self.body_nl_high(res)
        res += feats_shallow  # long skip connection
        res_main = self.tail(res)
        return res_main
class FourierNet(nn.Module):
    """Single fully-connected layer mapping 9 stacked 85x85 planes to one 85x85 image.

    NOTE: this one Linear layer holds 85*85*9 x 85*85 weights (~470M
    parameters), so instantiation is very memory-hungry.
    """

    def __init__(self):
        super(FourierNet, self).__init__()
        self.inp = nn.Linear(85*85*9, 85*85)

    def forward(self, x):
        flat = x.view(-1, 85*85*9)
        out = self.inp(flat)
        return out.view(-1, 1, 85, 85)
class FourierConvNet(nn.Module):
    """U-Net-style network over an 18-channel input producing 9 output channels,
    returning the log-magnitude of the final conv output.

    NOTE(review): relies on ``inconv``/``down``/``up``/``outconv`` helpers that
    are not defined in this file -- presumably imported or defined elsewhere;
    confirm before use.  The commented-out code records an earlier plain-conv /
    rfft experiment kept for reference.
    """

    def __init__(self):
        super(FourierConvNet, self).__init__()
        # self.inp = nn.Conv2d(18,32,3, stride=1, padding=1)
        # self.lay1 = nn.Conv2d(32,32,3, stride=1, padding=1)
        # self.lay2 = nn.Conv2d(32,32,3, stride=1, padding=1)
        # self.lay3 = nn.Conv2d(32,32,3, stride=1, padding=1)
        # self.pool = nn.MaxPool2d(2,2)
        # self.out = nn.Conv2d(32,1,3, stride=1, padding=1)
        # self.labels = nn.Linear(4096,18)
        self.inc = inconv(18, 64)
        self.down1 = down(64, 128)
        self.down2 = down(128, 256)
        self.down3 = down(256, 512)
        self.down4 = down(512, 512)
        self.up1 = up(1024, 256)
        self.up2 = up(512, 128)
        self.up3 = up(256, 64)
        self.up4 = up(128, 64)
        self.outc = outconv(64, 9)  # two channels for complex

    def forward(self, x):
        # earlier experiment, kept for reference:
        # x = self.inp(x)
        # x = torch.rfft(x,2,onesided=False)
        # # x = torch.log( torch.abs(x) + 1 )
        # x = x.permute(0,1,4,2,3) # put real and imag parts after stack index
        # x = x.contiguous().view(-1,18,256,256)
        # x = F.relu(self.inp(x))
        # x = self.pool(x) # to 128
        # x = F.relu(self.lay2(x))
        # x = self.pool(x) # to 64
        # x = F.relu(self.lay3(x))
        # x = self.out(x)
        # x = x.view(-1,4096)
        # x = self.labels(x)
        # standard U-Net encoder/decoder with skip connections
        x1 = self.inc(x)
        x2 = self.down1(x1)
        x3 = self.down2(x2)
        x4 = self.down3(x3)
        x5 = self.down4(x4)
        x = self.up1(x5, x4)
        x = self.up2(x, x3)
        x = self.up3(x, x2)
        x = self.up4(x, x1)
        x = self.outc(x)
        # log-magnitude; note this is -inf wherever the output is exactly zero
        x = torch.log(torch.abs(x))
        # x = x.permute(0,2,3,1)
        # x = torch.irfft(x,2,onesided=False)
        return x
# super(UNet, self).__init__() | |
# self.inc = inconv(n_channels, 64) | |
# self.down1 = down(64, 128) | |
# self.down2 = down(128, 256) | |
# self.down3 = down(256, 512) | |
# self.down4 = down(512, 512) | |
# self.up1 = up(1024, 256) | |
# self.up2 = up(512, 128) | |
# self.up3 = up(256, 64) | |
# self.up4 = up(128, 64) | |
# if opt.task == 'segment': | |
# self.outc = outconv(64, 2) | |
# else: | |
# self.outc = outconv(64, n_classes) | |
# # Initialize weights | |
# # self._init_weights() | |
# def _init_weights(self): | |
# """Initializes weights using He et al. (2015).""" | |
# for m in self.modules(): | |
# if isinstance(m, nn.ConvTranspose2d) or isinstance(m, nn.Conv2d): | |
# nn.init.kaiming_normal_(m.weight.data) | |
# m.bias.data.zero_() | |
# def forward(self, x): | |
# x1 = self.inc(x) | |
# x2 = self.down1(x1) | |
# x3 = self.down2(x2) | |
# x4 = self.down3(x3) | |
# x5 = self.down4(x4) | |
# x = self.up1(x5, x4) | |
# x = self.up2(x, x3) | |
# x = self.up3(x, x2) | |
# x = self.up4(x, x1) | |
# x = self.outc(x) | |
# return F.sigmoid(x) | |
# ----------------------------------- RRDB (ESRGAN) ------------------------------------------ | |
def initialize_weights(net_l, scale=1):
    """Kaiming-init Conv2d/Linear weights (then scaled), zero biases, unit BatchNorm.

    Accepts a single module or a list of modules.  ``scale`` shrinks conv and
    linear weights after init (used to stabilise residual branches).
    """
    modules = net_l if isinstance(net_l, list) else [net_l]
    for net in modules:
        for m in net.modules():
            if isinstance(m, (nn.Conv2d, nn.Linear)):
                torch.nn.init.kaiming_normal_(m.weight, a=0, mode='fan_in')
                m.weight.data *= scale  # damp residual-branch init
                if m.bias is not None:
                    m.bias.data.zero_()
            elif isinstance(m, nn.BatchNorm2d):
                torch.nn.init.constant_(m.weight, 1)
                torch.nn.init.constant_(m.bias.data, 0.0)
def make_layer(block, n_layers):
    """Stack ``n_layers`` fresh instances produced by the ``block`` factory."""
    return nn.Sequential(*(block() for _ in range(n_layers)))
class ResidualDenseBlock_5C(nn.Module):
    """Dense block of 5 convs; each conv sees all previous feature maps.

    Output is 0.2 * residual + identity (ESRGAN style).
    """

    def __init__(self, nf=64, gc=32, bias=True):
        super(ResidualDenseBlock_5C, self).__init__()
        # gc: growth channels, i.e. features added by each dense layer
        self.conv1 = nn.Conv2d(nf, gc, 3, 1, 1, bias=bias)
        self.conv2 = nn.Conv2d(nf + gc, gc, 3, 1, 1, bias=bias)
        self.conv3 = nn.Conv2d(nf + 2 * gc, gc, 3, 1, 1, bias=bias)
        self.conv4 = nn.Conv2d(nf + 3 * gc, gc, 3, 1, 1, bias=bias)
        self.conv5 = nn.Conv2d(nf + 4 * gc, nf, 3, 1, 1, bias=bias)
        self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
        # small init stabilises the residual branch
        initialize_weights([self.conv1, self.conv2, self.conv3, self.conv4, self.conv5], 0.1)

    def forward(self, x):
        feats = [x]
        for dense_conv in (self.conv1, self.conv2, self.conv3, self.conv4):
            feats.append(self.lrelu(dense_conv(torch.cat(feats, 1))))
        out = self.conv5(torch.cat(feats, 1))
        return out * 0.2 + x
class RRDB(nn.Module):
    '''Residual in Residual Dense Block: three dense blocks, residual output.'''

    def __init__(self, nf, gc=32):
        super(RRDB, self).__init__()
        self.RDB1 = ResidualDenseBlock_5C(nf, gc)
        self.RDB2 = ResidualDenseBlock_5C(nf, gc)
        self.RDB3 = ResidualDenseBlock_5C(nf, gc)

    def forward(self, x):
        out = self.RDB3(self.RDB2(self.RDB1(x)))
        return out * 0.2 + x
class RRDBNet(nn.Module):
    """ESRGAN-style generator: RRDB trunk plus a nearest-neighbour upsample head.

    opt fields read: n_feats, n_resblocks, nch_in, nch_out, scale.
    NOTE(review): ``upconv2`` is constructed but unused — its call is
    commented out in forward; kept so existing checkpoints still load.
    """

    def __init__(self, opt, gc=32):
        super(RRDBNet, self).__init__()
        nf = opt.n_feats
        RRDB_block_f = functools.partial(RRDB, nf=nf, gc=gc)
        self.conv_first = nn.Conv2d(opt.nch_in, nf, 3, 1, 1, bias=True)
        self.RRDB_trunk = make_layer(RRDB_block_f, opt.n_resblocks)
        self.trunk_conv = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
        #### upsampling
        self.upconv1 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
        self.upconv2 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
        self.HRconv = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
        self.conv_last = nn.Conv2d(nf, opt.nch_out, 3, 1, 1, bias=True)
        self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
        self.scale = opt.scale

    def forward(self, x):
        fea = self.conv_first(x)
        fea = fea + self.trunk_conv(self.RRDB_trunk(fea))
        upsampled = F.interpolate(fea, scale_factor=self.scale, mode='nearest')
        fea = self.lrelu(self.upconv1(upsampled))
        return self.conv_last(self.lrelu(self.HRconv(fea)))
# ----------------------------------- SRGAN ------------------------------------------ | |
def swish(x):
    """Swish/SiLU activation: x * sigmoid(x)."""
    return torch.sigmoid(x) * x
class FeatureExtractor(nn.Module):
    """Truncated copy of ``cnn.features`` up to and including ``feature_layer``.

    Typically used to take VGG activations for a perceptual loss.
    """

    def __init__(self, cnn, feature_layer=11):
        super(FeatureExtractor, self).__init__()
        kept = list(cnn.features.children())[:feature_layer + 1]
        self.features = nn.Sequential(*kept)

    def forward(self, x):
        return self.features(x)
class residualBlock(nn.Module):
    """SRResNet residual block: conv-BN-swish-conv-BN plus identity skip."""
    def __init__(self, in_channels=64, k=3, n=64, s=1):
        super(residualBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_channels, n, k, stride=s, padding=1)
        self.bn1 = nn.BatchNorm2d(n)
        self.conv2 = nn.Conv2d(n, n, k, stride=s, padding=1)
        self.bn2 = nn.BatchNorm2d(n)
    def forward(self, x):
        h = self.conv1(x)
        h = swish(self.bn1(h))
        h = self.conv2(h)
        return x + self.bn2(h)
class upsampleBlock(nn.Module):
    """Resize-convolution: 3x3 conv then PixelShuffle(2), i.e. 2x spatial
    upsampling with out_channels/4 resulting channels."""
    def __init__(self, in_channels, out_channels):
        super(upsampleBlock, self).__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, 3, stride=1, padding=1)
        self.shuffler = nn.PixelShuffle(2)
    def forward(self, x):
        h = self.conv(x)
        h = self.shuffler(h)
        return swish(h)
class Generator(nn.Module):
    """SRResNet/SRGAN generator.

    A 9x9 input conv, `n_residual_blocks` residual blocks with a long skip,
    and a 9x9 output conv (or a 1x1 2-channel head when opt.task == 'segment').
    Optional input/output normalization transforms come from opt.norm.
    NOTE: upsampling stages are disabled here, so output is input-sized.
    """
    def __init__(self, n_residual_blocks, opt):
        super(Generator, self).__init__()
        self.n_residual_blocks = n_residual_blocks
        self.upsample_factor = opt.scale
        self.conv1 = nn.Conv2d(opt.nch_in, 64, 9, stride=1, padding=4)
        # Fix: identity comparison with None (PEP 8) instead of `not ... == None`.
        if opt.norm is not None:
            self.normalize, self.unnormalize = normalizationTransforms(opt.norm)
        else:
            self.normalize, self.unnormalize = None, None
        for i in range(self.n_residual_blocks):
            self.add_module('residual_block' + str(i + 1), residualBlock())
        self.conv2 = nn.Conv2d(64, 64, 3, stride=1, padding=1)
        self.bn2 = nn.BatchNorm2d(64)
        if opt.task == 'segment':
            self.conv3 = nn.Conv2d(64, 2, 1)
        else:
            self.conv3 = nn.Conv2d(64, opt.nch_out, 9, stride=1, padding=4)
    def forward(self, x):
        if self.normalize is not None:
            x = self.normalize(x)
        x = swish(self.conv1(x))
        y = x.clone()
        for i in range(self.n_residual_blocks):
            y = self.__getattr__('residual_block' + str(i + 1))(y)
        # long skip connection around the residual trunk
        x = self.bn2(self.conv2(y)) + x
        x = self.conv3(x)
        if self.unnormalize is not None:
            x = self.unnormalize(x)
        return x
class Discriminator(nn.Module):
    """SRGAN discriminator: 8 conv stages (alternating stride 1/2), a 1x1
    score conv instead of the paper's FC layers, global average pooling and
    a sigmoid to produce one real/fake probability per image."""
    def __init__(self, opt):
        super(Discriminator, self).__init__()
        self.conv1 = nn.Conv2d(opt.nch_out, 64, 3, stride=1, padding=1)
        # (in_channels, out_channels, stride) for conv2..conv8, each with a BN
        stage_specs = [(64, 64, 2), (64, 128, 1), (128, 128, 2), (128, 256, 1),
                       (256, 256, 2), (256, 512, 1), (512, 512, 2)]
        for idx, (cin, cout, stride) in enumerate(stage_specs, start=2):
            setattr(self, 'conv%d' % idx, nn.Conv2d(cin, cout, 3, stride=stride, padding=1))
            setattr(self, 'bn%d' % idx, nn.BatchNorm2d(cout))
        # Replaced original paper FC layers with FCN
        self.conv9 = nn.Conv2d(512, 1, 1, stride=1, padding=1)
    def forward(self, x):
        h = swish(self.conv1(x))
        for idx in range(2, 9):
            conv = getattr(self, 'conv%d' % idx)
            bn = getattr(self, 'bn%d' % idx)
            h = swish(bn(conv(h)))
        h = self.conv9(h)
        # global average pool over the remaining spatial extent, then sigmoid
        pooled = F.avg_pool2d(h, h.size()[2:])
        return torch.sigmoid(pooled).view(h.size()[0], -1)
class UNet_n2n(nn.Module):
    """Custom U-Net architecture for Noise2Noise (see Appendix, Table 2).

    Encoder: one two-conv block then four shared-weight single-conv pool
    blocks (input spatial dims must be divisible by 32). Decoder: transposed
    convolutions with skip concatenations; `_block5` weights are shared
    across three decoder levels. Final block maps to `out_channels`
    (or 2 logits when opt.task == 'segment').
    """
    def __init__(self, in_channels=3, out_channels=3, opt=None):
        """Initializes U-Net.

        Fix: the original default `opt = {}` crashed on `opt.task`
        (a dict has no attributes); `opt=None` + getattr makes the
        default constructible while keeping attribute-style opts working.
        """
        super(UNet_n2n, self).__init__()
        # Layers: enc_conv0, enc_conv1, pool1
        self._block1 = nn.Sequential(
            nn.Conv2d(in_channels, 48, 3, stride=1, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(48, 48, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2))
        # Layers: enc_conv(i), pool(i); i=2..5 (reused for all four levels)
        self._block2 = nn.Sequential(
            nn.Conv2d(48, 48, 3, stride=1, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2))
        # Layers: enc_conv6, upsample5
        self._block3 = nn.Sequential(
            nn.Conv2d(48, 48, 3, stride=1, padding=1),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(48, 48, 3, stride=2, padding=1, output_padding=1))
        # Layers: dec_conv5a, dec_conv5b, upsample4
        self._block4 = nn.Sequential(
            nn.Conv2d(96, 96, 3, stride=1, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(96, 96, 3, stride=1, padding=1),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(96, 96, 3, stride=2, padding=1, output_padding=1))
        # Layers: dec_deconv(i)a, dec_deconv(i)b, upsample(i-1); i=4..2 (shared)
        self._block5 = nn.Sequential(
            nn.Conv2d(144, 96, 3, stride=1, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(96, 96, 3, stride=1, padding=1),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(96, 96, 3, stride=2, padding=1, output_padding=1))
        # Layers: dec_conv1a, dec_conv1b, dec_conv1c
        self._block6 = nn.Sequential(
            nn.Conv2d(96 + in_channels, 64, 3, stride=1, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 32, 3, stride=1, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(32, out_channels, 3, stride=1, padding=1),
            nn.LeakyReLU(0.1))
        # Initialize weights
        self._init_weights()
        self.task = getattr(opt, 'task', None)
        if self.task == 'segment':
            # segmentation head: 2 raw logit channels, no final activation
            self._block6 = nn.Sequential(
                nn.Conv2d(96 + in_channels, 64, 3, stride=1, padding=1),
                nn.ReLU(inplace=True),
                nn.Conv2d(64, 32, 3, stride=1, padding=1),
                nn.ReLU(inplace=True),
                nn.Conv2d(32, 2, 1))
    def _init_weights(self):
        """Initializes weights using He et al. (2015)."""
        for m in self.modules():
            if isinstance(m, (nn.ConvTranspose2d, nn.Conv2d)):
                nn.init.kaiming_normal_(m.weight.data)
                m.bias.data.zero_()
    def forward(self, x):
        """Through encoder, then decoder by adding U-skip connections."""
        # Encoder
        pool1 = self._block1(x)
        pool2 = self._block2(pool1)
        pool3 = self._block2(pool2)
        pool4 = self._block2(pool3)
        pool5 = self._block2(pool4)
        # Decoder
        upsample5 = self._block3(pool5)
        concat5 = torch.cat((upsample5, pool4), dim=1)
        upsample4 = self._block4(concat5)
        concat4 = torch.cat((upsample4, pool3), dim=1)
        upsample3 = self._block5(concat4)
        concat3 = torch.cat((upsample3, pool2), dim=1)
        upsample2 = self._block5(concat3)
        concat2 = torch.cat((upsample2, pool1), dim=1)
        upsample1 = self._block5(concat2)
        concat1 = torch.cat((upsample1, x), dim=1)
        # Final activation
        return self._block6(concat1)
# ------------------ Alternative UNet implementation (batchnorm. outcommented) | |
class double_conv(nn.Module):
    '''(conv => ReLU) * 2 — the BatchNorm layers of the classic recipe
    are intentionally left out in this variant.'''
    def __init__(self, in_ch, out_ch):
        super(double_conv, self).__init__()
        layers = [
            nn.Conv2d(in_ch, out_ch, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_ch, out_ch, 3, padding=1),
            nn.ReLU(inplace=True),
        ]
        self.conv = nn.Sequential(*layers)
    def forward(self, x):
        return self.conv(x)
class inconv(nn.Module):
    """Input stem: a double_conv with no downsampling."""
    def __init__(self, in_ch, out_ch):
        super(inconv, self).__init__()
        self.conv = double_conv(in_ch, out_ch)
    def forward(self, x):
        return self.conv(x)
class down(nn.Module):
    """Encoder stage: 2x2 max-pool followed by a double_conv."""
    def __init__(self, in_ch, out_ch):
        super(down, self).__init__()
        self.mpconv = nn.Sequential(
            nn.MaxPool2d(2),
            double_conv(in_ch, out_ch),
        )
    def forward(self, x):
        return self.mpconv(x)
class up(nn.Module):
    """Decoder stage: upsample x1 (transposed conv or bilinear), pad it to
    x2's spatial size, concatenate along channels, then double_conv."""
    def __init__(self, in_ch, out_ch, bilinear=False):
        super(up, self).__init__()
        if bilinear:
            self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        else:
            self.up = nn.ConvTranspose2d(in_ch//2, in_ch//2, 2, stride=2)
        self.conv = double_conv(in_ch, out_ch)
    def forward(self, x1, x2):
        upped = self.up(x1)
        # pad (or crop, when the diff is negative) to match the skip tensor
        # https://github.com/HaiyongJiang/U-Net-Pytorch-Unstructured-Buggy/commit/0e854509c2cea854e247a9c615f175f76fbb2e3a
        # https://github.com/xiaopeng-liao/Pytorch-UNet/commit/8ebac70e633bac59fc22bb5195e513d5832fb3bd
        dh = x2.size()[2] - upped.size()[2]
        dw = x2.size()[3] - upped.size()[3]
        upped = F.pad(upped, (dw // 2, dw - dw // 2, dh // 2, dh - dh // 2))
        merged = torch.cat([x2, upped], dim=1)
        return self.conv(merged)
class outconv(nn.Module):
    """Output head: a 1x1 conv projecting to the desired channel count."""
    def __init__(self, in_ch, out_ch):
        super(outconv, self).__init__()
        self.conv = nn.Conv2d(in_ch, out_ch, 1)
    def forward(self, x):
        return self.conv(x)
class UNet(nn.Module):
    """Standard 4-level U-Net; output passes through a sigmoid.

    When opt.task == 'segment' the head produces 2 channels, otherwise
    n_classes channels.
    """
    def __init__(self, n_channels, n_classes, opt):
        super(UNet, self).__init__()
        self.inc = inconv(n_channels, 64)
        self.down1 = down(64, 128)
        self.down2 = down(128, 256)
        self.down3 = down(256, 512)
        self.down4 = down(512, 512)
        self.up1 = up(1024, 256)
        self.up2 = up(512, 128)
        self.up3 = up(256, 64)
        self.up4 = up(128, 64)
        if opt.task == 'segment':
            self.outc = outconv(64, 2)
        else:
            self.outc = outconv(64, n_classes)
        # Initialize weights
        # self._init_weights()
    def _init_weights(self):
        """Initializes weights using He et al. (2015)."""
        for m in self.modules():
            if isinstance(m, (nn.ConvTranspose2d, nn.Conv2d)):
                nn.init.kaiming_normal_(m.weight.data)
                m.bias.data.zero_()
    def forward(self, x):
        x1 = self.inc(x)
        x2 = self.down1(x1)
        x3 = self.down2(x2)
        x4 = self.down3(x3)
        x5 = self.down4(x4)
        x = self.up1(x5, x4)
        x = self.up2(x, x3)
        x = self.up3(x, x2)
        x = self.up4(x, x1)
        x = self.outc(x)
        # Fix: F.sigmoid is deprecated (removed in recent PyTorch); use torch.sigmoid
        return torch.sigmoid(x)
class UNet60M(nn.Module):
    """Deeper (5-level, ~60M parameter) U-Net variant; sigmoid output."""
    def __init__(self, n_channels, n_classes):
        super(UNet60M, self).__init__()
        self.inc = inconv(n_channels, 64)
        self.down1 = down(64, 128)
        self.down2 = down(128, 256)
        self.down3 = down(256, 512)
        self.down4 = down(512, 1024)
        self.down5 = down(1024, 1024)
        self.up1 = up(2048, 512)
        self.up2 = up(1024, 256)
        self.up3 = up(512, 128)
        self.up4 = up(256, 64)
        self.up5 = up(128, 64)
        self.outc = outconv(64, n_classes)
    def forward(self, x):
        x1 = self.inc(x)
        x2 = self.down1(x1)
        x3 = self.down2(x2)
        x4 = self.down3(x3)
        x5 = self.down4(x4)
        x6 = self.down5(x5)
        x = self.up1(x6, x5)
        x = self.up2(x, x4)
        x = self.up3(x, x3)
        x = self.up4(x, x2)
        x = self.up5(x, x1)
        x = self.outc(x)
        # Fix: F.sigmoid is deprecated (removed in recent PyTorch); use torch.sigmoid
        return torch.sigmoid(x)
class UNetRep(nn.Module):
    """U-Net variant that runs a small down/up cycle three times with shared
    weights before the output head; sigmoid output."""
    def __init__(self, n_channels, n_classes):
        super(UNetRep, self).__init__()
        self.inc = inconv(n_channels, 64)
        self.down1 = down(64, 128)
        self.down2 = down(128, 128)
        self.up1 = up1(256, 128, 128)
        self.up2 = up1(192, 64, 128)
        self.outc = outconv(64, n_classes)
    def forward(self, x):
        x1 = self.inc(x)
        # three refinement passes through the same (weight-shared) blocks
        for _ in range(3):
            x2 = self.down1(x1)
            x3 = self.down2(x2)
            x = self.up1(x3, x2)
            x1 = self.up2(x, x1)
        x = self.outc(x1)
        # Fix: F.sigmoid is deprecated (removed in recent PyTorch); use torch.sigmoid
        return torch.sigmoid(x)
# ------------------- UNet Noise2noise implementation | |
class single_conv(nn.Module):
    '''conv => ReLU (BatchNorm intentionally left out in this variant)'''
    def __init__(self, in_ch, out_ch):
        super(single_conv, self).__init__()
        layers = [
            nn.Conv2d(in_ch, out_ch, 3, padding=1),
            nn.ReLU(inplace=True),
        ]
        self.conv = nn.Sequential(*layers)
    def forward(self, x):
        return self.conv(x)
class outconv2(nn.Module):
    """Output head using a 3x3 (same-padded) conv instead of outconv's 1x1."""
    def __init__(self, in_ch, out_ch):
        super(outconv2, self).__init__()
        self.conv = nn.Conv2d(in_ch, out_ch, 3, padding=1)
    def forward(self, x):
        return self.conv(x)
class up1(nn.Module):
    """Decoder stage: upsample x1 with a ConvTranspose2d over `convtr`
    channels (or bilinear), pad to x2's size, concat, then double_conv."""
    def __init__(self, in_ch, out_ch, convtr, bilinear=False):
        super(up1, self).__init__()
        if bilinear:
            self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        else:
            self.up = nn.ConvTranspose2d(convtr, convtr, 3, stride=2)
        self.conv = double_conv(in_ch, out_ch)
    def forward(self, x1, x2):
        upped = self.up(x1)
        # pad (or crop, when the diff is negative) to match the skip tensor
        # https://github.com/HaiyongJiang/U-Net-Pytorch-Unstructured-Buggy/commit/0e854509c2cea854e247a9c615f175f76fbb2e3a
        # https://github.com/xiaopeng-liao/Pytorch-UNet/commit/8ebac70e633bac59fc22bb5195e513d5832fb3bd
        dh = x2.size()[2] - upped.size()[2]
        dw = x2.size()[3] - upped.size()[3]
        upped = F.pad(upped, (dw // 2, dw - dw // 2, dh // 2, dh - dh // 2))
        return self.conv(torch.cat([x2, upped], dim=1))
class up2(nn.Module):
    """Final decoder stage: upsample x1, pad to x2's size, concat with the
    skip, then two plain 3x3 convs (no activations)."""
    def __init__(self, in_ch, in_ch2, out_ch, out_ch2, convtr, bilinear=False):
        super(up2, self).__init__()
        if bilinear:
            self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        else:
            self.up = nn.ConvTranspose2d(convtr, convtr, 3, stride=2)
        self.conv = nn.Conv2d(in_ch + in_ch2, out_ch, 3, padding=1)
        self.conv2 = nn.Conv2d(out_ch, out_ch2, 3, padding=1)
    def forward(self, x1, x2):
        upped = self.up(x1)
        # pad (or crop, when the diff is negative) to match the skip tensor
        # https://github.com/HaiyongJiang/U-Net-Pytorch-Unstructured-Buggy/commit/0e854509c2cea854e247a9c615f175f76fbb2e3a
        # https://github.com/xiaopeng-liao/Pytorch-UNet/commit/8ebac70e633bac59fc22bb5195e513d5832fb3bd
        dh = x2.size()[2] - upped.size()[2]
        dw = x2.size()[3] - upped.size()[3]
        upped = F.pad(upped, (dw // 2, dw - dw // 2, dh // 2, dh - dh // 2))
        merged = torch.cat([x2, upped], dim=1)
        return self.conv2(self.conv(merged))
class down2(nn.Module):
    """Encoder stage using a learned stride-2 conv for downsampling
    (instead of max-pooling), followed by a single_conv."""
    def __init__(self, in_ch, out_ch):
        super(down2, self).__init__()
        self.mpconv = nn.Sequential(
            nn.Conv2d(in_ch, in_ch, 2, stride=2),
            single_conv(in_ch, out_ch),
        )
    def forward(self, x):
        return self.mpconv(x)
class UNetGreedy(nn.Module):
    """Wide (144-channel) U-Net variant with learned strided-conv
    downsampling; sigmoid output."""
    def __init__(self, n_channels, n_classes):
        super(UNetGreedy, self).__init__()
        self.inc = inconv(n_channels, 144)
        self.down1 = down(144, 144)
        self.down2 = down2(144, 144)
        self.down3 = down2(144, 144)
        self.down4 = down2(144, 144)
        self.down5 = down2(144, 144)
        self.up1 = up1(288, 288, 144)
        self.up2 = up1(432, 288, 288)
        self.up3 = up1(432, 288, 288)
        self.up4 = up1(432, 288, 288)
        self.up5 = up2(288, n_channels, 64, 32, 288)
        self.outc = outconv2(32, n_classes)
    def forward(self, x0):
        x1 = self.inc(x0)
        x2 = self.down1(x1)
        x3 = self.down2(x2)
        x4 = self.down3(x3)
        x5 = self.down4(x4)
        x6 = self.down5(x5)
        x = self.up1(x6, x5)
        x = self.up2(x, x4)
        x = self.up3(x, x3)
        x = self.up4(x, x2)
        # last stage skips straight to the network input x0
        x = self.up5(x, x0)
        x = self.outc(x)
        # Fix: F.sigmoid is deprecated (removed in recent PyTorch); use torch.sigmoid
        return torch.sigmoid(x)
class UNet2(nn.Module):
    """Slim (48-channel) twin of UNetGreedy; sigmoid output."""
    def __init__(self, n_channels, n_classes):
        super(UNet2, self).__init__()
        self.inc = inconv(n_channels, 48)
        self.down1 = down(48, 48)
        self.down2 = down2(48, 48)
        self.down3 = down2(48, 48)
        self.down4 = down2(48, 48)
        self.down5 = down2(48, 48)
        self.up1 = up1(96, 96, 48)
        self.up2 = up1(144, 96, 96)
        self.up3 = up1(144, 96, 96)
        self.up4 = up1(144, 96, 96)
        self.up5 = up2(96, n_channels, 64, 32, 96)
        self.outc = outconv2(32, n_classes)
    def forward(self, x0):
        x1 = self.inc(x0)
        x2 = self.down1(x1)
        x3 = self.down2(x2)
        x4 = self.down3(x3)
        x5 = self.down4(x4)
        x6 = self.down5(x5)
        x = self.up1(x6, x5)
        x = self.up2(x, x4)
        x = self.up3(x, x3)
        x = self.up4(x, x2)
        # last stage skips straight to the network input x0
        x = self.up5(x, x0)
        x = self.outc(x)
        # Fix: F.sigmoid is deprecated (removed in recent PyTorch); use torch.sigmoid
        return torch.sigmoid(x)
class MLPNet(nn.Module):
    """Conv feature extractor + small MLP regressor producing one scalar
    per sample. The fc input size (6*6*32) pins the expected input to
    3x256x256 images."""
    def __init__(self):
        super(MLPNet, self).__init__()
        self.conv1 = nn.Conv2d(3, 64, 3, padding=1)
        self.conv12 = nn.Conv2d(64, 64, 3, padding=1)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(64, 96, 3, padding=1)
        # conv22 is not used by forward() (its call is disabled below);
        # kept so existing checkpoints still load.
        self.conv22 = nn.Conv2d(96, 128, 3, padding=1)
        self.conv3 = nn.Conv2d(96, 128, 3, padding=1)
        self.conv4 = nn.Conv2d(128, 128, 3, padding=1)
        self.conv5 = nn.Conv2d(128, 64, 3, padding=1)
        self.conv6 = nn.Conv2d(64, 32, 5)
        self.fc = nn.Sequential(
            nn.Linear(6*6*32, 100),
            nn.ReLU(),
            # Fix: use Dropout, not Dropout2d — the tensor here is 2-D
            # (N, features); Dropout2d expects a channel dimension and
            # warns/misbehaves on 2-D input.
            nn.Dropout(p=0.2),
            nn.Linear(100, 1),
            nn.ReLU()
        )
    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv12(x))
        x = self.pool(x)
        x = F.relu(self.conv2(x))
        # x = F.relu(self.conv22(x))
        x = self.pool(x)
        x = F.relu(self.conv3(x))
        x = F.relu(self.conv4(x))
        x = self.pool(x)
        x = F.relu(self.conv5(x))
        x = self.pool(x)
        x = F.relu(self.conv6(x))
        x = self.pool(x)
        # flatten to the fixed fc input size (requires 256x256 input)
        x = x.view(-1, 6*6*32)
        return self.fc(x)
# --------------------- FFDNet | |
from torch.autograd import Function, Variable | |
def concatenate_input_noise_map(input, noise_sigma):
    r"""Implements the first layer of FFDNet. Returns the concatenation of
    the downsampled input image and the noise map. Each image of the batch
    of size CxHxW gets converted to an array of size 4*CxH/2xW/2: each pixel
    of the non-overlapped 2x2 patches of the input image is placed in the
    new array along the first dimension.

    Args:
        input: batch containing CxHxW images
        noise_sigma: the value of the pixels of the CxH/2xW/2 noise map
            (one value per batch element)

    Fix: the original allocated via `torch.cuda.FloatTensor` /
    `torch.FloatTensor` chosen by a string test on `input.type()`, which
    forced float32 and broke for other dtypes/devices. `new_zeros` matches
    the input's dtype and device automatically.
    """
    N, C, H, W = input.size()
    sca = 2
    sca2 = sca * sca
    Cout = sca2 * C
    Hout = H // sca
    Wout = W // sca
    # offsets of the four pixels inside each 2x2 patch
    idxL = [[0, 0], [0, 1], [1, 0], [1, 1]]
    # Fill the downsampled image with zeros (same dtype/device as input)
    downsampledfeatures = input.new_zeros(N, Cout, Hout, Wout)
    # Build the CxH/2xW/2 noise map
    noise_map = noise_sigma.view(N, 1, 1, 1).repeat(1, C, Hout, Wout)
    # Populate output: de-interleave the 2x2 mosaic into channels
    for idx in range(sca2):
        downsampledfeatures[:, idx:Cout:sca2, :, :] = \
            input[:, :, idxL[idx][0]::sca, idxL[idx][1]::sca]
    # concatenate de-interleaved mosaic with noise map
    return torch.cat((noise_map, downsampledfeatures), 1)
class UpSampleFeaturesFunction(Function):
    r"""Extends PyTorch's modules by implementing a torch.autograd.Function.
    Implements the forward and backward of the last layer of FFDNet: the
    inverse of concatenate_input_noise_map(), converting each image of a
    batch from CxH/2xW/2 to C/4xHxW.

    Fix: modern PyTorch requires custom Functions to use @staticmethod
    forward/backward — the original legacy (non-static) style raises
    "Legacy autograd function with non-static forward method is deprecated"
    at runtime. Also uses new_zeros instead of .type(dtype) and no longer
    touches the removed Variable/.data API in backward.
    """
    @staticmethod
    def forward(ctx, input):
        N, Cin, Hin, Win = input.size()
        sca = 2
        sca2 = sca * sca
        assert (Cin % sca2 == 0), \
            'Invalid input dimensions: number of channels should be divisible by 4'
        Cout = Cin // sca2
        Hout = Hin * sca
        Wout = Win * sca
        idxL = [[0, 0], [0, 1], [1, 0], [1, 1]]
        result = input.new_zeros(N, Cout, Hout, Wout)
        # re-interleave channel groups back into 2x2 spatial patches
        for idx in range(sca2):
            result[:, :, idxL[idx][0]::sca, idxL[idx][1]::sca] = \
                input[:, idx:Cin:sca2, :, :]
        return result
    @staticmethod
    def backward(ctx, grad_output):
        N, Cg_out, Hg_out, Wg_out = grad_output.size()
        sca = 2
        sca2 = sca * sca
        Cg_in = sca2 * Cg_out
        Hg_in = Hg_out // sca
        Wg_in = Wg_out // sca
        idxL = [[0, 0], [0, 1], [1, 0], [1, 1]]
        # Build output: de-interleave the gradient, mirroring forward
        grad_input = grad_output.new_zeros(N, Cg_in, Hg_in, Wg_in)
        for idx in range(sca2):
            grad_input[:, idx:Cg_in:sca2, :, :] = \
                grad_output[:, :, idxL[idx][0]::sca, idxL[idx][1]::sca]
        return grad_input
# Alias functions
upsamplefeatures = UpSampleFeaturesFunction.apply
class UpSampleFeatures(nn.Module):
    r"""nn.Module wrapper around the FFDNet re-interleaving last layer."""
    def __init__(self):
        super(UpSampleFeatures, self).__init__()
    def forward(self, x):
        return upsamplefeatures(x)
class IntermediateDnCNN(nn.Module):
    r"""The middle part of the FFDNet architecture — essentially a DnCNN:
    conv+ReLU, (num_conv_layers-2) middle conv+ReLU blocks, and a final
    conv, all 3x3 same-padded and bias-free."""
    def __init__(self, input_features, middle_features, num_conv_layers):
        super(IntermediateDnCNN, self).__init__()
        self.kernel_size = 3
        self.padding = 1
        self.input_features = input_features
        self.num_conv_layers = num_conv_layers
        self.middle_features = middle_features
        # output channel count is tied to the FFDNet mosaic layout
        if self.input_features == 5:
            self.output_features = 4    # grayscale image
        elif self.input_features == 15:
            self.output_features = 12   # RGB image
        else:
            self.output_features = 3
        layers = [
            nn.Conv2d(in_channels=self.input_features,
                      out_channels=self.middle_features,
                      kernel_size=self.kernel_size,
                      padding=self.padding,
                      bias=False),
            nn.ReLU(inplace=True),
        ]
        for _ in range(self.num_conv_layers - 2):
            layers.append(nn.Conv2d(in_channels=self.middle_features,
                                    out_channels=self.middle_features,
                                    kernel_size=self.kernel_size,
                                    padding=self.padding,
                                    bias=False))
            layers.append(nn.ReLU(inplace=True))
        layers.append(nn.Conv2d(in_channels=self.middle_features,
                                out_channels=self.output_features,
                                kernel_size=self.kernel_size,
                                padding=self.padding,
                                bias=False))
        self.itermediate_dncnn = nn.Sequential(*layers)
    def forward(self, x):
        return self.itermediate_dncnn(x)
class FFDNet(nn.Module):
    r"""Implements the FFDNet architecture.

    The input is de-interleaved into 4 downsampled sub-images concatenated
    with a noise map, denoised by a DnCNN core, and re-interleaved back
    to full resolution; the network predicts the noise.
    """
    def __init__(self, num_input_channels, test_mode=False):
        super(FFDNet, self).__init__()
        self.num_input_channels = num_input_channels
        self.test_mode = test_mode
        if self.num_input_channels == 1:
            # Grayscale image
            self.num_feature_maps = 64
            self.num_conv_layers = 15
            self.downsampled_channels = 5
            self.output_features = 4
        elif self.num_input_channels == 3:
            # RGB image
            self.num_feature_maps = 96
            self.num_conv_layers = 12
            self.downsampled_channels = 15
            self.output_features = 12
        else:
            raise Exception('Invalid number of input features')
        self.intermediate_dncnn = IntermediateDnCNN(
            input_features=self.downsampled_channels,
            middle_features=self.num_feature_maps,
            num_conv_layers=self.num_conv_layers)
        self.upsamplefeatures = UpSampleFeatures()
    def forward(self, x, noise_sigma):
        """Returns the predicted noise for `x` given per-image `noise_sigma`.

        Fix: the original wrapped the concatenated input in
        `Variable(..., volatile=True)`; `Variable` is deprecated and
        `volatile` is ignored/removed in modern PyTorch. `.detach()`
        reproduces the original `.data` semantics (the noise-map input is
        cut off from the autograd graph) on current versions.
        """
        concat_noise_x = concatenate_input_noise_map(x.detach(), noise_sigma.detach())
        h_dncnn = self.intermediate_dncnn(concat_noise_x)
        pred_noise = self.upsamplefeatures(h_dncnn)
        return pred_noise
class DNCNN(nn.Module):
    r"""Implements the DNCNNNet architecture: a plain DnCNN operating
    directly on the input channels (no FFDNet mosaic/noise map)."""
    def __init__(self, num_input_channels, test_mode=False):
        super(DNCNN, self).__init__()
        self.num_input_channels = num_input_channels
        self.test_mode = test_mode
        if self.num_input_channels == 1:
            # Grayscale configuration
            self.num_feature_maps, self.num_conv_layers = 64, 15
            self.downsampled_channels, self.output_features = 5, 4
        elif self.num_input_channels == 3:
            # RGB configuration
            self.num_feature_maps, self.num_conv_layers = 96, 12
            self.downsampled_channels, self.output_features = 15, 12
        else:
            raise Exception('Invalid number of input features')
        self.intermediate_dncnn = IntermediateDnCNN(
            input_features=self.num_input_channels,
            middle_features=self.num_feature_maps,
            num_conv_layers=self.num_conv_layers)
    def forward(self, x):
        return self.intermediate_dncnn(x)