import math

import torch
import torch.nn as nn
import torch.nn.init as init
import torch.nn.functional as F


def init_weights(modules):
    """Weight-initialisation hook; currently a deliberate no-op.

    Args:
        modules: iterable of ``nn.Module`` objects. The blocks in this file
            pass ``self.modules()``.
    """
    pass


class MeanShift(nn.Module):
    """Add or subtract a fixed per-channel RGB mean using a frozen 1x1 conv.

    Args:
        mean_rgb: indexable holding the three channel means (r, g, b).
        sub: if truthy the mean is subtracted from the input, otherwise it
            is added.
    """

    def __init__(self, mean_rgb, sub):
        super(MeanShift, self).__init__()

        sign = -1 if sub else 1
        r = mean_rgb[0] * sign
        g = mean_rgb[1] * sign
        b = mean_rgb[2] * sign

        # Positional arguments: in_channels=3, out_channels=3, kernel_size=1,
        # stride=1, padding=0.
        self.shifter = nn.Conv2d(3, 3, 1, 1, 0)
        # A 3x3 identity reshaped to (3, 3, 1, 1) makes the convolution copy
        # each channel through unchanged, so only the bias alters the input.
        self.shifter.weight.data = torch.eye(3).view(3, 3, 1, 1)
        self.shifter.bias.data = torch.Tensor([r, g, b])

        # Freeze the mean shift layer
        for params in self.shifter.parameters():
            params.requires_grad = False

    def forward(self, x):
        """Return ``x`` with the signed channel means added."""
        x = self.shifter(x)
        return x


class BasicBlock(nn.Module):
    """A single convolution followed by an in-place ReLU.

    Args:
        in_channels: number of input channels of the convolution.
        out_channels: number of output channels of the convolution.
        ksize: kernel size (default 3).
        stride: convolution stride (default 1).
        pad: zero padding on each side (default 1).
    """

    def __init__(self, in_channels, out_channels, ksize=3, stride=1, pad=1):
        super(BasicBlock, self).__init__()

        self.body = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, ksize, stride, pad),
            nn.ReLU(inplace=True),
        )

        # Call modules() so the hook receives the sub-modules themselves,
        # not the bound method object.
        init_weights(self.modules())

    def forward(self, x):
        """Apply conv + ReLU to ``x``."""
        out = self.body(x)
        return out


class ResidualBlock(nn.Module):
    """Two 3x3 convolutions with a skip connection and a final ReLU.

    The skip connection adds the block input to the conv output, so the two
    must have broadcast-compatible shapes (in practice
    ``in_channels == out_channels``).

    Args:
        in_channels: number of input channels of the first convolution.
        out_channels: number of output channels of both convolutions.
    """

    def __init__(self, in_channels, out_channels):
        super(ResidualBlock, self).__init__()

        self.body = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 3, 1, 1),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, 3, 1, 1),
        )

        # Call modules() so the hook receives the sub-modules themselves,
        # not the bound method object.
        init_weights(self.modules())

    def forward(self, x):
        """Return ``relu(body(x) + x)``."""
        out = self.body(x)
        out = F.relu(out + x)
        return out


class EResidualBlock(nn.Module):
    """Residual block of two grouped 3x3 convs followed by a 1x1 conv.

    The skip connection adds the block input to the conv output, so the two
    must have broadcast-compatible shapes (in practice
    ``in_channels == out_channels``).

    Args:
        in_channels: number of input channels of the first convolution.
        out_channels: number of output channels of every convolution.
        group: ``groups`` value for the two 3x3 convolutions (default 1);
            the final 1x1 convolution is not grouped.
    """

    def __init__(self, in_channels, out_channels, group=1):
        super(EResidualBlock, self).__init__()

        self.body = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 3, 1, 1, groups=group),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, 3, 1, 1, groups=group),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, 1, 1, 0),
        )

        # Call modules() so the hook receives the sub-modules themselves,
        # not the bound method object.
        init_weights(self.modules())

    def forward(self, x):
        """Return ``relu(body(x) + x)``."""
        out = self.body(x)
        out = F.relu(out + x)
        return out


class UpsampleBlock(nn.Module):
    """Upsampler that is either fixed-scale or switchable between x2/x3/x4.

    Args:
        n_channels: channel count of the input (and of the output).
        scale: upsampling factor of the single branch; only used when
            ``multi_scale`` is falsy.
        multi_scale: if truthy, build separate x2, x3 and x4 branches and
            pick one per call in ``forward``.
        group: ``groups`` value forwarded to the branch convolutions.
    """

    def __init__(self, n_channels, scale, multi_scale, group=1):
        super(UpsampleBlock, self).__init__()

        if multi_scale:
            self.up2 = _UpsampleBlock(n_channels, scale=2, group=group)
            self.up3 = _UpsampleBlock(n_channels, scale=3, group=group)
            self.up4 = _UpsampleBlock(n_channels, scale=4, group=group)
        else:
            self.up = _UpsampleBlock(n_channels, scale=scale, group=group)

        self.multi_scale = multi_scale

    def forward(self, x, scale):
        """Upsample ``x``.

        Args:
            x: input tensor.
            scale: branch selector (2, 3 or 4) in multi-scale mode; ignored
                otherwise.

        Raises:
            ValueError: in multi-scale mode when ``scale`` is not 2, 3 or 4.
        """
        if not self.multi_scale:
            return self.up(x)

        if scale == 2:
            return self.up2(x)
        if scale == 3:
            return self.up3(x)
        if scale == 4:
            return self.up4(x)

        # Fail loudly instead of silently returning None to the caller.
        raise ValueError(
            "unsupported scale for multi-scale upsampling: {!r} "
            "(expected 2, 3 or 4)".format(scale)
        )


class _UpsampleBlock(nn.Module):
    """Conv + ReLU + PixelShuffle upsampler for a single scale factor.

    For scale 2, 4 or 8 the body stacks log2(scale) stages, each expanding
    the channels 4x and pixel-shuffling by 2. For scale 3 there is one stage
    expanding the channels 9x and pixel-shuffling by 3. For any other scale
    no layers are added, so the body is an empty ``nn.Sequential`` that
    returns its input unchanged.

    Args:
        n_channels: channel count of the input (and of the output).
        scale: upsampling factor.
        group: ``groups`` value for the convolutions (default 1).
    """

    def __init__(self, n_channels, scale, group=1):
        super(_UpsampleBlock, self).__init__()

        modules = []
        if scale == 2 or scale == 4 or scale == 8:
            # scale is 2, 4 or 8 here, giving 1, 2 or 3 doubling stages.
            for _ in range(int(math.log(scale, 2))):
                modules += [
                    nn.Conv2d(n_channels, 4 * n_channels, 3, 1, 1, groups=group),
                    nn.ReLU(inplace=True),
                ]
                modules += [nn.PixelShuffle(2)]
        elif scale == 3:
            modules += [
                nn.Conv2d(n_channels, 9 * n_channels, 3, 1, 1, groups=group),
                nn.ReLU(inplace=True),
            ]
            modules += [nn.PixelShuffle(3)]

        self.body = nn.Sequential(*modules)

        # Call modules() so the hook receives the sub-modules themselves,
        # not the bound method object.
        init_weights(self.modules())

    def forward(self, x):
        """Run ``x`` through the upsampling stages."""
        out = self.body(x)
        return out