# hyliu's picture
# Upload folder using huggingface_hub
# d6ec83b verified
from model import common
import torch.nn as nn
import torch
from model.attention import ContextualAttention,NonLocalAttention
def make_model(args, parent=False):
    """Build and return an ``MSSR`` network from the option object ``args``.

    Args:
        args: Option namespace forwarded unchanged to ``MSSR``.
        parent: Accepted but not used by this function.

    Returns:
        A newly constructed ``MSSR`` instance.
    """
    model = MSSR(args)
    return model
class MultisourceProjection(nn.Module):
    """Combine two attention branches into one upsampled feature map.

    One branch runs ``NonLocalAttention`` followed by a stride-2 transposed
    convolution; the other runs ``ContextualAttention(scale=2)``. Their
    difference is passed through a ``common.ResBlock`` and the result is
    added back onto the first branch.

    Args:
        in_channel: Input/output channel count of the transposed convolution
            and the channel count handed to ``common.ResBlock``.
        kernel_size: Kernel size handed to ``common.ResBlock``.
        conv: Convolution factory handed to ``common.ResBlock``.
    """

    def __init__(self, in_channel, kernel_size=3, conv=common.default_conv):
        super().__init__()
        self.up_attention = ContextualAttention(scale=2)
        self.down_attention = NonLocalAttention()

        # kernel 6 / stride 2 / padding 2 doubles the spatial size exactly:
        # (n - 1) * 2 - 2 * 2 + 6 == 2 * n.
        deconv = nn.ConvTranspose2d(
            in_channel, in_channel, 6, stride=2, padding=2
        )
        self.upsample = nn.Sequential(deconv, nn.PReLU())

        self.encoder = common.ResBlock(
            conv, in_channel, kernel_size, act=nn.PReLU(), res_scale=1
        )

    def forward(self, x):
        """Return the refined upsampled map for feature tensor ``x``."""
        attended = self.down_attention(x)
        down_map = self.upsample(attended)

        # NOTE(review): the subtraction below requires up_map to be
        # shape-compatible with down_map, so ContextualAttention(scale=2)
        # presumably also returns a 2x map - confirm in model.attention.
        up_map = self.up_attention(x)

        residual = self.encoder(up_map - down_map)
        return down_map + residual
class RecurrentProjection(nn.Module):
    """Two chained ``MultisourceProjection`` stages with error feedback.

    Each stage upsamples, projects the result back down with a strided
    convolution, and uses the difference to the input ``x`` (upsampled by a
    transposed convolution) to correct the stage output.

    Args:
        in_channel: Channel count used by every layer built here.
        kernel_size: Kernel size forwarded to both projections and to
            ``common.BasicBlock``.
        conv: Convolution factory forwarded to both projections and to
            ``common.BasicBlock``.
    """

    def __init__(self, in_channel, kernel_size=3, conv=common.default_conv):
        super().__init__()

        def resample(layer_cls, kernel, stride):
            # Every resampling stack is <strided conv, PReLU> with padding 2
            # and an unchanged channel count.
            layer = layer_cls(
                in_channel, in_channel, kernel, stride=stride, padding=2
            )
            return nn.Sequential(layer, nn.PReLU())

        self.multi_source_projection_1 = MultisourceProjection(
            in_channel, kernel_size=kernel_size, conv=conv
        )
        self.multi_source_projection_2 = MultisourceProjection(
            in_channel, kernel_size=kernel_size, conv=conv
        )

        # There is deliberately no ``down_sample_2``; the attribute names are
        # kept unchanged so parameter names in a state_dict stay the same.
        # kernel 6 / stride 2 halves even sizes; kernel 8 / stride 4 quarters
        # sizes divisible by four.
        self.down_sample_1 = resample(nn.Conv2d, 6, 2)
        self.down_sample_3 = resample(nn.Conv2d, 8, 4)
        self.down_sample_4 = resample(nn.Conv2d, 8, 4)

        # Transposed counterparts: exact 2x and 4x enlargement.
        self.error_encode_1 = resample(nn.ConvTranspose2d, 6, 2)
        self.error_encode_2 = resample(nn.ConvTranspose2d, 8, 4)

        self.post_conv = common.BasicBlock(
            conv, in_channel, in_channel, kernel_size,
            stride=1, bias=True, act=nn.PReLU()
        )

    def forward(self, x):
        """Run both stages on ``x``.

        Returns:
            A tuple ``(x_final, h_estimate_2)``: the second-stage estimate
            brought back down by ``down_sample_4`` and ``post_conv``, and the
            second-stage estimate itself.
        """
        # Stage 1: project up, project back, correct with the residual.
        first_up = self.multi_source_projection_1(x)
        first_back = self.down_sample_1(first_up)
        h_estimate_1 = first_up + self.error_encode_1(x - first_back)

        # Stage 2: same scheme on top of the stage-1 estimate; the residual
        # is again taken against the original input ``x``.
        second_up = self.multi_source_projection_2(h_estimate_1)
        second_back = self.down_sample_3(second_up)
        h_estimate_2 = second_up + self.error_encode_2(x - second_back)

        x_final = self.post_conv(self.down_sample_4(h_estimate_2))
        return x_final, h_estimate_2
class MSSR(nn.Module):
    """Network that applies one ``RecurrentProjection`` ``depth`` times.

    The input goes through ``sub_mean`` and a two-block head. The shared
    body is then called ``depth`` times; each call yields a new feature map
    for the next call plus a high-resolution estimate. All estimates are
    concatenated along dim 1, reduced to ``args.n_colors`` channels by a
    single convolution, and passed through ``add_mean``.

    Args:
        args: Option namespace; reads ``n_feats``, ``depth``, ``scale``,
            ``rgb_range`` and ``n_colors``.
        conv: Convolution factory used for the head blocks.
    """

    def __init__(self, args, conv=common.default_conv):
        super().__init__()

        n_feats = args.n_feats
        self.depth = args.depth
        kernel_size = 3
        # NOTE(review): the scale is read but never used below; the strides
        # inside RecurrentProjection are fixed regardless of this value.
        _scale = args.scale[0]

        rgb_mean = (0.4488, 0.4371, 0.4040)
        rgb_std = (1.0, 1.0, 1.0)
        self.sub_mean = common.MeanShift(args.rgb_range, rgb_mean, rgb_std)

        # Head: n_colors -> n_feats -> n_feats, both blocks without BN.
        head_layers = [
            common.BasicBlock(
                conv, args.n_colors, n_feats, kernel_size,
                stride=1, bias=True, bn=False, act=nn.PReLU()
            ),
            common.BasicBlock(
                conv, n_feats, n_feats, kernel_size,
                stride=1, bias=True, bn=False, act=nn.PReLU()
            ),
        ]

        # Body: one module whose weights are reused on every iteration.
        # NOTE(review): ``conv`` is not forwarded here, so the body always
        # uses RecurrentProjection's default convolution factory.
        self.body = RecurrentProjection(n_feats)

        # Tail: consumes the ``depth`` concatenated estimates at once.
        tail_conv = nn.Conv2d(
            n_feats * self.depth, args.n_colors, kernel_size,
            padding=kernel_size // 2
        )

        # NOTE(review): the trailing 1 is presumably the sign that makes
        # this shift the inverse of ``sub_mean`` - confirm in model.common.
        self.add_mean = common.MeanShift(args.rgb_range, rgb_mean, rgb_std, 1)

        self.head = nn.Sequential(*head_layers)
        self.tail = nn.Sequential(tail_conv)

    def forward(self, input):
        """Return the network output for the image batch ``input``."""
        feats = self.head(self.sub_mean(input))

        estimates = []
        for _ in range(self.depth):
            feats, h_estimate = self.body(feats)
            estimates.append(h_estimate)

        stacked = torch.cat(estimates, dim=1)
        return self.add_mean(self.tail(stacked))