# Uploaded by hyliu using huggingface_hub (commit 8cb1339, verified).
# Deep Back-Projection Networks For Super-Resolution
# https://arxiv.org/abs/1803.02735
from model import common
import torch
import torch.nn as nn
def make_model(args, parent=False):
    """Build the DDBPN network described by ``args``.

    Args:
        args: Configuration object forwarded unchanged to ``DDBPN``.
        parent: Accepted but not used by this function.

    Returns:
        A newly constructed ``DDBPN`` instance.
    """
    model = DDBPN(args)
    return model
def projection_conv(in_channels, out_channels, scale, up=True):
    """Create the strided convolution used inside a projection unit.

    Args:
        in_channels: Number of input channels of the layer.
        out_channels: Number of output channels of the layer.
        scale: Resampling factor; must be 2, 4 or 8.
        up: If true, build an ``nn.ConvTranspose2d``; otherwise build an
            ``nn.Conv2d`` with the same kernel size, stride and padding.

    Returns:
        The newly constructed convolution layer.

    Raises:
        KeyError: If ``scale`` is not one of the supported factors.
    """
    # scale -> (kernel_size, stride, padding). Every entry satisfies
    # kernel_size == stride + 2 * padding, so the transposed convolution
    # multiplies the spatial size by ``scale`` and the plain convolution
    # divides it by ``scale`` (exactly, when the size is a multiple of it).
    geometry = {
        2: (6, 2, 2),
        4: (8, 4, 2),
        8: (12, 8, 2),
    }
    if scale not in geometry:
        # Keep the KeyError type callers may already handle, but say what
        # went wrong and which factors are accepted.
        raise KeyError(
            'unsupported scale {!r}; expected one of {}'.format(
                scale, sorted(geometry)
            )
        )
    kernel_size, stride, padding = geometry[scale]

    conv_f = nn.ConvTranspose2d if up else nn.Conv2d
    return conv_f(
        in_channels, out_channels, kernel_size,
        stride=stride, padding=padding
    )
class DenseProjection(nn.Module):
    """Projection unit: resample, project back, and correct with the error.

    The unit resamples its input with ``conv_1``, maps the result back to
    the input resolution with ``conv_2`` (opposite direction), takes the
    difference to the input, resamples that difference with ``conv_3`` and
    adds it to the first estimate.

    Args:
        in_channels: Number of channels of the incoming tensor.
        nr: Number of channels produced by the unit.
        scale: Resampling factor passed to ``projection_conv``.
        up: If true, ``conv_1``/``conv_3`` upsample and ``conv_2``
            downsamples; if false the directions are swapped.
        bottleneck: If true, a 1x1 convolution followed by PReLU first
            maps the input from ``in_channels`` to ``nr`` channels.
    """

    def __init__(self, in_channels, nr, scale, up=True, bottleneck=True):
        super(DenseProjection, self).__init__()

        # Optional 1x1 squeeze; without it the projection stages work
        # directly on the incoming channel count.
        self.bottleneck = (
            nn.Sequential(nn.Conv2d(in_channels, nr, 1), nn.PReLU(nr))
            if bottleneck else None
        )
        width = nr if bottleneck else in_channels

        def stage(c_in, c_out, upsample):
            # One resampling convolution followed by a per-channel PReLU.
            return nn.Sequential(
                projection_conv(c_in, c_out, scale, upsample),
                nn.PReLU(c_out)
            )

        self.conv_1 = stage(width, nr, up)
        self.conv_2 = stage(nr, width, not up)
        self.conv_3 = stage(width, nr, up)

    def forward(self, x):
        """Apply the projection unit to ``x`` and return the refined map."""
        if self.bottleneck is not None:
            x = self.bottleneck(x)

        estimate = self.conv_1(x)
        # Error between the back-projected estimate and the unit's input.
        residual = self.conv_2(estimate) - x
        return estimate + self.conv_3(residual)
class DDBPN(nn.Module):
    """Dense Deep Back-Projection Network built from ``DenseProjection`` units.

    The network alternates up- and down-projection units. Each unit (after
    the first) receives the channel-wise concatenation of all outputs
    produced so far by units of the opposite kind, and the final image is
    reconstructed from the concatenation of every up-projection output.

    Args:
        args: Configuration object. The constructor reads ``args.scale[0]``,
            ``args.rgb_range`` and ``args.n_colors``.
    """

    def __init__(self, args):
        super(DDBPN, self).__init__()
        scale = args.scale[0]

        n0, nr = 128, 32  # widths of the first feature conv and of each unit
        self.depth = 6    # number of up-projection units

        rgb_mean = (0.4488, 0.4371, 0.4040)
        rgb_std = (1.0, 1.0, 1.0)
        # NOTE(review): MeanShift lives in model.common; presumably this
        # instance removes the channel mean from the input -- confirm there.
        self.sub_mean = common.MeanShift(args.rgb_range, rgb_mean, rgb_std)

        self.initial = nn.Sequential(
            nn.Conv2d(args.n_colors, n0, 3, padding=1),
            nn.PReLU(n0),
            nn.Conv2d(n0, nr, 1),
            nn.PReLU(nr)
        )

        # Up unit ``idx`` consumes the concatenation of ``idx`` down outputs
        # (the first one consumes the initial features instead), hence
        # ``nr * max(idx, 1)`` input channels. Only units from the third
        # onwards use a bottleneck.
        self.upmodules = nn.ModuleList([
            DenseProjection(nr * max(idx, 1), nr, scale, True, idx > 1)
            for idx in range(self.depth)
        ])
        # Down unit ``idx`` consumes the concatenation of ``idx + 1`` up
        # outputs; every unit except the first uses a bottleneck.
        self.downmodules = nn.ModuleList([
            DenseProjection(nr * (idx + 1), nr, scale, False, idx != 0)
            for idx in range(self.depth - 1)
        ])

        self.reconstruction = nn.Sequential(
            nn.Conv2d(self.depth * nr, args.n_colors, 3, padding=1)
        )

        # NOTE(review): same MeanShift with a trailing ``1`` argument;
        # presumably the inverse shift -- confirm in model.common.
        self.add_mean = common.MeanShift(args.rgb_range, rgb_mean, rgb_std, 1)

    def forward(self, x):
        """Run the network on ``x`` and return the reconstructed output."""
        features = self.initial(self.sub_mean(x))

        highs = []  # outputs of the up-projection units so far
        lows = []   # outputs of the down-projection units so far
        for idx in range(self.depth - 1):
            # The first up unit has no down outputs to draw from yet.
            low_in = torch.cat(lows, dim=1) if idx != 0 else features
            highs.append(self.upmodules[idx](low_in))
            lows.append(self.downmodules[idx](torch.cat(highs, dim=1)))

        # The last up unit has no matching down unit.
        highs.append(self.upmodules[-1](torch.cat(lows, dim=1)))

        restored = self.reconstruction(torch.cat(highs, dim=1))
        return self.add_mean(restored)