File size: 2,379 Bytes
8ec10cf |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 |
import torch
import torch.nn as nn
from . import common
from .ResNet import ResNet
from lambda_network import LambdaLayer
def build_model(args):
    """Factory entry point: build an ``MSResNet`` configured from ``args``.

    Args:
        args: configuration object; the attributes read are those used by
            ``MSResNet.__init__`` (e.g. ``rgb_range``, ``n_scales``).

    Returns:
        A freshly constructed ``MSResNet`` instance.
    """
    model = MSResNet(args)
    return model
class conv_end(nn.Module):
    """Convolution followed by ``nn.PixelShuffle`` upsampling.

    The convolution maps ``in_channels`` to ``out_channels``. ``PixelShuffle``
    then turns those ``out_channels`` into ``out_channels // ratio**2``
    channels, with height and width multiplied by ``ratio``. For example,
    ``conv_end(3, 12)`` with the default ``ratio=2`` yields 3 channels at
    twice the spatial size.

    Args:
        in_channels: channels of the incoming feature map.
        out_channels: channels produced by the convolution. This should be
            divisible by ``ratio ** 2`` for the pixel shuffle.
        kernel_size: kernel size forwarded to ``common.default_conv``.
        ratio: spatial upscaling factor applied by ``nn.PixelShuffle``.
    """

    def __init__(self, in_channels=3, out_channels=3, kernel_size=5, ratio=2):
        super().__init__()
        conv = common.default_conv(in_channels, out_channels, kernel_size)
        shuffle = nn.PixelShuffle(ratio)
        # Attribute name ``uppath`` is part of the state_dict keys; keep it.
        self.uppath = nn.Sequential(conv, shuffle)

    def forward(self, x):
        """Apply the conv + pixel-shuffle path to ``x``."""
        upsampled = self.uppath(x)
        return upsampled
class MSResNet(nn.Module):
    """Multi-scale ResNet that processes an image pyramid coarse-to-fine.

    There is one body network per scale:

    * Scale ``n_scales - 1`` (coarsest) uses ``ResNet(args, 3, 3)``.
    * Every finer scale uses ``ResNet(args, 6, 3)``. Its 6-channel input is
      that scale's own 3-channel image concatenated with the 3-channel,
      pixel-shuffle-upsampled output of the next coarser scale.
      ``conv_end(3, 12)`` feeds 12 channels into ``PixelShuffle(2)``, which
      yields 3 channels at twice the spatial size.

    Args:
        args: configuration object providing ``rgb_range``, ``n_resblocks``,
            ``n_feats``, ``kernel_size`` and ``n_scales``. It is also
            forwarded to ``ResNet``.
    """

    def __init__(self, args):
        super(MSResNet, self).__init__()
        self.rgb_range = args.rgb_range
        # Inputs are shifted by -mean before the bodies and outputs shifted
        # back afterwards (presumably inputs lie in [0, rgb_range] — confirm).
        self.mean = self.rgb_range / 2
        self.n_resblocks = args.n_resblocks
        self.n_feats = args.n_feats
        self.kernel_size = args.kernel_size
        self.n_scales = args.n_scales

        # NOTE: construction order is kept as-is so parameter initialisation
        # (which consumes the global RNG) and state_dict keys stay unchanged.
        # The coarsest-scale body is created first and ends up at index -1.
        self.body_models = nn.ModuleList([
            ResNet(args, 3, 3, mean_shift=False),
        ])
        # NOTE(review): lambda_models is registered (so its parameters appear
        # in state_dict / parameters()) but is never used in forward().
        # It is kept only for checkpoint compatibility.
        self.lambda_models = nn.ModuleList([
            LambdaLayer(
                dim=32,         # channels going in
                dim_out=32,     # channels out
                n=64 * 64,      # number of input pixels (64 x 64 image)
                dim_k=16,       # key dimension
                heads=4,        # number of heads, for multi-query
                dim_u=1,        # 'intra-depth' dimension
            )
        ])
        # Finer scales (indices 0 .. n_scales-2) take 3 image channels plus
        # 3 upsampled channels from the coarser scale.
        for _ in range(1, self.n_scales):
            self.body_models.insert(0, ResNet(args, 6, 3, mean_shift=False))

        # Index 0 (finest scale) has nothing finer to feed, hence the None
        # placeholder, which keeps indices aligned with body_models.
        self.conv_end_models = nn.ModuleList([None])
        for _ in range(1, self.n_scales):
            self.conv_end_models += [conv_end(3, 12)]

    def forward(self, input_pyramid):
        """Run the coarse-to-fine pass over an image pyramid.

        Args:
            input_pyramid: indexable sequence of tensors, one per scale.
                Index 0 is the finest scale and index ``n_scales - 1`` the
                coarsest. Only the first ``n_scales`` entries are read. The
                sequence itself is NOT modified.

        Returns:
            A new list of ``n_scales`` output tensors, indexed like the input
            (0 = finest).
        """
        n_scales = self.n_scales
        # Coarsest scale first: n_scales-1, ..., 1, 0.
        scales = range(n_scales - 1, -1, -1)

        # Mean-shift into a fresh list. The previous version overwrote the
        # caller's entries in place, leaving their pyramid shifted after the
        # call (and failing outright for tuples).
        shifted = [input_pyramid[s] - self.mean for s in range(n_scales)]

        output_pyramid = [None] * n_scales
        # Start from the coarsest configured scale.
        input_s = shifted[n_scales - 1]
        for s in scales:
            output_pyramid[s] = self.body_models[s](input_s)
            if s > 0:
                # Upsample this scale's result and stack it onto the next
                # finer input along the channel dimension.
                up_feat = self.conv_end_models[s](output_pyramid[s])
                input_s = torch.cat((shifted[s - 1], up_feat), 1)

        # Undo the input mean shift on every output.
        return [out + self.mean for out in output_pyramid]
|