|
import math |
|
|
|
import torch |
|
import torch.nn as nn |
|
|
|
def default_conv(in_channels, out_channels, kernel_size, bias=True, groups=1):
    """Build a 2-D convolution padded by half the kernel size.

    With stride 1 and an odd kernel size this keeps the spatial size of the
    input unchanged.

    Args:
        in_channels: number of input channels.
        out_channels: number of output channels.
        kernel_size: an int, or a ``(kh, kw)`` tuple/list for rectangular
            kernels; each axis is padded by ``k // 2``.
        bias: whether the convolution has a learnable bias.
        groups: number of channel groups (``groups == in_channels`` gives a
            depthwise convolution).

    Returns:
        The configured ``nn.Conv2d`` module.
    """
    if isinstance(kernel_size, (tuple, list)):
        # Rectangular kernel: pad every axis independently.
        padding = tuple(k // 2 for k in kernel_size)
    else:
        padding = kernel_size // 2

    return nn.Conv2d(
        in_channels, out_channels, kernel_size,
        padding=padding, bias=bias, groups=groups)
|
|
|
def default_norm(n_feats):
    """Return the default normalization layer: 2-D batch norm over ``n_feats`` channels."""
    batch_norm = nn.BatchNorm2d(num_features=n_feats)
    return batch_norm
|
|
|
def default_act():
    """Return the default activation: an in-place ReLU."""
    return nn.ReLU(inplace=True)
|
|
|
def empty_h(x, n_feats):
    '''
    Create an all-zero hidden state matching the dtype and device of x.

    Only the first dimension (batch) and the last two dimensions (height,
    width) of x are read.

    input
        x: documented by the original author as B x T x 3 x H x W
        n_feats: number of channels C of the hidden state

    output
        h: B x C x H/4 x W/4 (integer division)
    '''
    batch = x.size(0)
    height, width = x.shape[-2:]
    hidden_shape = (batch, n_feats, height // 4, width // 4)
    return x.new_zeros(hidden_shape)
|
|
|
class Normalization(nn.Conv2d):
    """Normalize input tensor values with a frozen 1x1 convolutional layer.

    The layer computes ``(x - mean) / std`` per channel: the weight is a
    diagonal matrix holding ``1 / std`` and the bias is ``-mean / std``.
    All parameters have ``requires_grad`` set to False.

    Args:
        mean: per-channel means (sequence of numbers).
        std: per-channel standard deviations (sequence of numbers).

    The number of channels is taken from the longer of ``mean`` and ``std``
    (a length-1 sequence broadcasts against the other), so the layer is not
    restricted to 3-channel input. The defaults give a 3-channel identity.
    """

    def __init__(self, mean=(0, 0, 0), std=(1, 1, 1)):
        n_channels = max(len(mean), len(std))
        super(Normalization, self).__init__(
            n_channels, n_channels, kernel_size=1)

        tensor_mean = torch.Tensor(mean)
        tensor_inv_std = torch.Tensor(std).reciprocal()

        # Diagonal weight: output channel i only sees input channel i,
        # scaled by 1 / std[i].
        scale = torch.eye(n_channels).mul(tensor_inv_std)
        self.weight.data = scale.view(n_channels, n_channels, 1, 1)
        # x / std - mean / std == (x - mean) / std
        self.bias.data = -tensor_mean.mul(tensor_inv_std)

        for params in self.parameters():
            params.requires_grad = False
|
|
|
class BasicBlock(nn.Sequential):
    """Convolution layer, then optional normalization and activation layers.

    Args:
        in_channels, out_channels, kernel_size, bias: forwarded to ``conv``.
        conv: factory called as
            ``conv(in_channels, out_channels, kernel_size, bias=bias)``.
        norm: falsy to skip, otherwise a factory called as
            ``norm(out_channels)``.
        act: falsy to skip, otherwise a zero-argument factory for the
            activation module.
    """

    def __init__(
            self, in_channels, out_channels, kernel_size, bias=True,
            conv=default_conv, norm=False, act=default_act):

        layers = [conv(in_channels, out_channels, kernel_size, bias=bias)]
        if norm:
            layers.append(norm(out_channels))
        if act:
            layers.append(act())

        super(BasicBlock, self).__init__(*layers)
|
|
|
class ResBlock(nn.Module):
    """Residual block: two convolutions with a skip connection.

    The body is ``conv -> [norm] -> [act] -> conv -> [norm]``; the
    activation is only inserted after the first convolution. The input is
    added to the body output in ``forward``.

    Args:
        n_feats: channel count, used for both input and output.
        kernel_size, bias: forwarded to ``conv``.
        conv: factory called as
            ``conv(n_feats, n_feats, kernel_size, bias=bias)``.
        norm: falsy to skip, otherwise a factory called as ``norm(n_feats)``.
        act: falsy to skip, otherwise a zero-argument activation factory.
    """

    def __init__(
            self, n_feats, kernel_size, bias=True,
            conv=default_conv, norm=False, act=default_act):

        super(ResBlock, self).__init__()

        layers = []
        for stage in (0, 1):
            is_first = stage == 0
            layers.append(conv(n_feats, n_feats, kernel_size, bias=bias))
            if norm:
                layers.append(norm(n_feats))
            if act and is_first:
                layers.append(act())

        self.body = nn.Sequential(*layers)

    def forward(self, x):
        """Return ``body(x) + x``."""
        out = self.body(x)
        # In-place add onto the freshly computed body output.
        out += x
        return out
|
|
|
class ResBlock_mobile(nn.Module):
    """Residual block built from depthwise + pointwise convolutions.

    Each of the two stages is a grouped convolution with
    ``groups=n_feats`` followed by a 1x1 convolution, both created with
    ``bias=False``. After the first stage an optional ``nn.Dropout2d`` and
    activation are inserted; the optional normalization follows both stages.
    The input is added to the body output in ``forward``.

    Args:
        n_feats: channel count, used for both input and output.
        kernel_size: kernel size of the grouped convolution.
        bias: accepted for signature compatibility; the convolutions are
            always built with ``bias=False``.
        conv: factory accepting ``(in, out, kernel_size, bias=..., groups=...)``.
        norm: falsy to skip, otherwise a factory called as ``norm(n_feats)``.
        act: falsy to skip, otherwise a zero-argument activation factory.
        dropout: falsy to skip, otherwise passed as the first argument of
            ``nn.Dropout2d``.
    """

    def __init__(
            self, n_feats, kernel_size, bias=True,
            conv=default_conv, norm=False, act=default_act, dropout=False):

        super(ResBlock_mobile, self).__init__()

        layers = []
        for stage in (0, 1):
            is_first = stage == 0
            depthwise = conv(
                n_feats, n_feats, kernel_size, bias=False, groups=n_feats)
            layers.append(depthwise)
            pointwise = conv(n_feats, n_feats, 1, bias=False)
            layers.append(pointwise)
            if dropout and is_first:
                layers.append(nn.Dropout2d(dropout))
            if norm:
                layers.append(norm(n_feats))
            if act and is_first:
                layers.append(act())

        self.body = nn.Sequential(*layers)

    def forward(self, x):
        """Return ``body(x) + x``."""
        out = self.body(x)
        # In-place add onto the freshly computed body output.
        out += x
        return out
|
|
|
class Upsampler(nn.Sequential):
    """Sub-pixel upsampler built from convolution + ``nn.PixelShuffle`` stages.

    A power-of-two ``scale`` is realized as ``log2(scale)`` stages of x2
    upsampling (``scale == 1`` yields an empty, pass-through Sequential);
    ``scale == 3`` is a single x3 stage. Each stage expands the channels by
    the squared stage factor with a 3x3 ``conv`` and then applies
    ``nn.PixelShuffle``, followed by the optional ``norm`` and ``act``.

    Args:
        scale: integer upscaling factor; a positive power of two, or 3.
        n_feats: channel count, identical before and after every stage.
        bias: passed positionally as the fourth argument of ``conv``.
        conv: convolution factory.
        norm: falsy to skip, otherwise a factory called as ``norm(n_feats)``.
        act: falsy to skip, otherwise a zero-argument activation factory.

    Raises:
        NotImplementedError: for any other scale, including 0 and negatives.
    """

    def __init__(
            self, scale, n_feats, bias=True,
            conv=default_conv, norm=False, act=False):

        # `scale > 0` keeps 0 out of the power-of-two branch (0 & -1 == 0).
        if scale > 0 and (scale & (scale - 1)) == 0:
            # Exact integer log2; avoids float rounding in math.log(scale, 2).
            factors = [2] * (int(scale).bit_length() - 1)
        elif scale == 3:
            factors = [3]
        else:
            raise NotImplementedError

        modules = []
        for factor in factors:
            modules.append(conv(n_feats, factor * factor * n_feats, 3, bias))
            modules.append(nn.PixelShuffle(factor))
            if norm:
                modules.append(norm(n_feats))
            if act:
                modules.append(act())

        super(Upsampler, self).__init__(*modules)
|
|
|
|
|
class PixelSort(nn.Module):
    """The inverse operation of ``nn.PixelShuffle(2)`` (space-to-depth).

    Reduces the spatial resolution by 2 in each direction and multiplies the
    number of channels by 4: ``(B, C, H, W) -> (B, 4C, H/2, W/2)``.
    Output channel ``4*c + 2*i + j`` holds input pixels
    ``x[:, c, 2*r + i, 2*s + j]``, which is the channel layout consumed by
    ``nn.PixelShuffle(2)``, so ``PixelSort()(nn.PixelShuffle(2)(x)) == x``.

    Currently, scale 0.5 is supported only: ``upscale_factor`` is stored on
    the module but the factor of 2 is fixed in ``forward``. H and W must be
    even, otherwise the reshape fails.

    NOTE: earlier revisions reshaped with the sub-pixel axes in the wrong
    position, which did not invert PixelShuffle; weights trained against
    that layout see a different channel/pixel arrangement with this version.

    Newer PyTorch releases provide ``torch.nn.PixelUnshuffle`` for the same
    operation.
    """

    def __init__(self, upscale_factor=0.5):
        super(PixelSort, self).__init__()
        self.upscale_factor = upscale_factor

    def forward(self, x):
        """Rearrange ``(B, C, H, W)`` into ``(B, 4C, H/2, W/2)``."""
        b, c, h, w = x.size()
        # Split each spatial axis into (coarse position, offset inside 2x2).
        x = x.view(b, c, h // 2, 2, w // 2, 2)
        # Move both offsets next to the channel axis: (b, c, i, j, r, s).
        x = x.permute(0, 1, 3, 5, 2, 4).contiguous()
        x = x.view(b, 4 * c, h // 2, w // 2)

        return x
|
|
|
class Downsampler(nn.Sequential):
    """Halve the spatial resolution with ``PixelSort`` followed by a 3x3 conv.

    ``PixelSort`` turns ``n_feats`` channels into ``4 * n_feats`` at half the
    height and width; the convolution maps them back to ``n_feats``. The
    optional ``norm`` and ``act`` modules follow.

    Args:
        scale: must equal 0.5.
        n_feats: channel count before and after downsampling.
        bias: passed positionally as the fourth argument of ``conv``.
        conv: convolution factory.
        norm: falsy to skip, otherwise a factory called as ``norm(n_feats)``.
        act: falsy to skip, otherwise a zero-argument activation factory.

    Raises:
        NotImplementedError: if ``scale`` is not 0.5.
    """

    def __init__(
            self, scale, n_feats, bias=True,
            conv=default_conv, norm=False, act=False):

        if scale != 0.5:
            raise NotImplementedError

        layers = [PixelSort()]
        layers.append(conv(4 * n_feats, n_feats, 3, bias))
        if norm:
            layers.append(norm(n_feats))
        if act:
            layers.append(act())

        super(Downsampler, self).__init__(*layers)
|
|
|
|