|
from model import common
|
|
import torch.nn as nn
|
|
import torch
|
|
from model.attention import ContextualAttention,NonLocalAttention
|
|
def make_model(args, parent=False):
    """Framework factory hook: build and return an MSSR model from *args*."""
    model = MSSR(args)
    return model
|
|
|
|
class MultisourceProjection(nn.Module):
    """Fuse two attention branches into one 2x up-projected feature map.

    A cross-scale branch (ContextualAttention at scale 2) produces an
    up-projected estimate directly, while an in-scale branch (NonLocalAttention
    followed by a learned 2x transposed-conv upsampler) produces a second
    estimate; a residual encoder turns their disagreement into a correction.
    """

    def __init__(self, in_channel, kernel_size=3, conv=common.default_conv):
        super(MultisourceProjection, self).__init__()
        # Cross-scale attention directly yields a 2x up-projected map.
        self.up_attention = ContextualAttention(scale=2)
        # In-scale non-local attention; its output still needs upsampling.
        self.down_attention = NonLocalAttention()
        # Learned 2x upsampler: 6x6 deconv, stride 2, padding 2.
        self.upsample = nn.Sequential(
            nn.ConvTranspose2d(in_channel, in_channel, 6, stride=2, padding=2),
            nn.PReLU(),
        )
        # Residual block that encodes the branch discrepancy.
        self.encoder = common.ResBlock(
            conv, in_channel, kernel_size, act=nn.PReLU(), res_scale=1
        )

    def forward(self, x):
        """Return the corrected 2x up-projected feature map for *x*."""
        upsampled_nonlocal = self.upsample(self.down_attention(x))
        cross_scale = self.up_attention(x)
        # Encode the disagreement between the two estimates and apply it
        # as a residual correction on the upsampled non-local branch.
        correction = self.encoder(cross_scale - upsampled_nonlocal)
        return upsampled_nonlocal + correction
|
|
|
|
class RecurrentProjection(nn.Module):
    """Two-stage up/down projection unit with back-projection error feedback.

    Stage 1 lifts the input to 2x, projects it back down, and corrects the
    estimate with the encoded reconstruction error; stage 2 repeats the same
    scheme from the refined 2x features up to 4x.  Returns a low-resolution
    feature (for recurrent chaining) together with the 4x estimate.
    """

    def __init__(self, in_channel, kernel_size=3, conv=common.default_conv):
        super(RecurrentProjection, self).__init__()
        # Stage 1 projects to 2x; stage 2 projects the 2x estimate to 4x.
        self.multi_source_projection_1 = MultisourceProjection(
            in_channel, kernel_size=kernel_size, conv=conv)
        self.multi_source_projection_2 = MultisourceProjection(
            in_channel, kernel_size=kernel_size, conv=conv)
        # 2x down-projection used to check the first-stage estimate.
        self.down_sample_1 = nn.Sequential(
            nn.Conv2d(in_channel, in_channel, 6, stride=2, padding=2),
            nn.PReLU())
        # 4x down-projections: one checks the second-stage estimate,
        # the other produces the low-resolution recurrent output.
        self.down_sample_3 = nn.Sequential(
            nn.Conv2d(in_channel, in_channel, 8, stride=4, padding=2),
            nn.PReLU())
        self.down_sample_4 = nn.Sequential(
            nn.Conv2d(in_channel, in_channel, 8, stride=4, padding=2),
            nn.PReLU())
        # Error encoders lift the low-resolution residual back up (2x / 4x).
        self.error_encode_1 = nn.Sequential(
            nn.ConvTranspose2d(in_channel, in_channel, 6, stride=2, padding=2),
            nn.PReLU())
        self.error_encode_2 = nn.Sequential(
            nn.ConvTranspose2d(in_channel, in_channel, 8, stride=4, padding=2),
            nn.PReLU())
        self.post_conv = common.BasicBlock(
            conv, in_channel, in_channel, kernel_size,
            stride=1, bias=True, act=nn.PReLU())

    def forward(self, x):
        """Return ``(low_res_feature, high_res_estimate)`` for input *x*."""
        # --- stage 1: 2x estimate, refined by back-projection error ---
        estimate_2x = self.multi_source_projection_1(x)
        back_projected = self.down_sample_1(estimate_2x)
        refined_2x = estimate_2x + self.error_encode_1(x - back_projected)
        # --- stage 2: 4x estimate from the refined 2x features ---
        estimate_4x = self.multi_source_projection_2(refined_2x)
        back_projected_4x = self.down_sample_3(estimate_4x)
        refined_4x = estimate_4x + self.error_encode_2(x - back_projected_4x)
        # The low-resolution output feeds the next recurrence step.
        low_res = self.post_conv(self.down_sample_4(refined_4x))
        return low_res, refined_4x
|
|
|
|
|
|
|
|
|
|
|
|
class MSSR(nn.Module):
    """Multi-source super-resolution network.

    Pipeline: subtract the DIV2K RGB mean, embed the image into ``n_feats``
    channels, run ``depth`` recurrent projection steps (collecting the
    high-resolution estimate from each step), concatenate the estimates
    along channels, map them back to RGB with a single convolution, and add
    the mean back.
    """

    def __init__(self, args, conv=common.default_conv):
        super(MSSR, self).__init__()

        n_feats = args.n_feats
        # Number of recurrent projection steps; also sizes the tail conv.
        self.depth = args.depth
        kernel_size = 3
        # NOTE(review): the original read args.scale[0] into an unused local;
        # the output scale is fixed at 4x by the projection units, so the
        # dead assignment was removed.

        # DIV2K dataset RGB channel means; std of 1.0 means no rescaling.
        rgb_mean = (0.4488, 0.4371, 0.4040)
        rgb_std = (1.0, 1.0, 1.0)
        self.sub_mean = common.MeanShift(args.rgb_range, rgb_mean, rgb_std)

        # Shallow feature extractor: RGB -> n_feats -> n_feats.
        m_head = [
            common.BasicBlock(conv, args.n_colors, n_feats, kernel_size,
                              stride=1, bias=True, bn=False, act=nn.PReLU()),
            common.BasicBlock(conv, n_feats, n_feats, kernel_size,
                              stride=1, bias=True, bn=False, act=nn.PReLU()),
        ]

        # One projection unit shared (reused) across all depth steps.
        self.body = RecurrentProjection(n_feats)

        # Reconstruction: fuse the `depth` concatenated estimates into RGB.
        m_tail = [
            nn.Conv2d(n_feats * self.depth, args.n_colors, kernel_size,
                      padding=kernel_size // 2),
        ]

        # sign=1 adds the dataset mean back after reconstruction.
        self.add_mean = common.MeanShift(args.rgb_range, rgb_mean, rgb_std, 1)

        self.head = nn.Sequential(*m_head)
        self.tail = nn.Sequential(*m_tail)

    def forward(self, input):
        """Super-resolve *input*; assumes (N, C, H, W) -> (N, C, 4H, 4W)."""
        x = self.sub_mean(input)
        x = self.head(x)
        # Each step yields a low-res feature (fed to the next step) and a
        # high-res estimate (collected for the final reconstruction).
        bag = []
        for _ in range(self.depth):
            x, h_estimate = self.body(x)
            bag.append(h_estimate)
        h_feature = torch.cat(bag, dim=1)
        h_final = self.tail(h_feature)
        return self.add_mean(h_final)
|
|
|