#!/usr/bin/env python3 # -*- coding: utf-8 -*- from __future__ import annotations from collections import OrderedDict try: from typing import Literal except ImportError: from typing_extensions import Literal import torch import torch.nn as nn #################### # Basic blocks #################### def act(act_type: str, inplace=True, neg_slope=0.2, n_prelu=1): # helper selecting activation # neg_slope: for leakyrelu and init of prelu # n_prelu: for p_relu num_parameters act_type = act_type.lower() if act_type == "relu": layer = nn.ReLU(inplace) elif act_type == "leakyrelu": layer = nn.LeakyReLU(neg_slope, inplace) elif act_type == "prelu": layer = nn.PReLU(num_parameters=n_prelu, init=neg_slope) else: raise NotImplementedError( "activation layer [{:s}] is not found".format(act_type) ) return layer def norm(norm_type: str, nc: int): # helper selecting normalization layer norm_type = norm_type.lower() if norm_type == "batch": layer = nn.BatchNorm2d(nc, affine=True) elif norm_type == "instance": layer = nn.InstanceNorm2d(nc, affine=False) else: raise NotImplementedError( "normalization layer [{:s}] is not found".format(norm_type) ) return layer def pad(pad_type: str, padding): # helper selecting padding layer # if padding is 'zero', do by conv layers pad_type = pad_type.lower() if padding == 0: return None if pad_type == "reflect": layer = nn.ReflectionPad2d(padding) elif pad_type == "replicate": layer = nn.ReplicationPad2d(padding) else: raise NotImplementedError( "padding layer [{:s}] is not implemented".format(pad_type) ) return layer def get_valid_padding(kernel_size, dilation): kernel_size = kernel_size + (kernel_size - 1) * (dilation - 1) padding = (kernel_size - 1) // 2 return padding class ConcatBlock(nn.Module): # Concat the output of a submodule to its input def __init__(self, submodule): super(ConcatBlock, self).__init__() self.sub = submodule def forward(self, x): output = torch.cat((x, self.sub(x)), dim=1) return output def __repr__(self): tmpstr = "Identity .. \n|" modstr = self.sub.__repr__().replace("\n", "\n|") tmpstr = tmpstr + modstr return tmpstr class ShortcutBlock(nn.Module): # Elementwise sum the output of a submodule to its input def __init__(self, submodule): super(ShortcutBlock, self).__init__() self.sub = submodule def forward(self, x): output = x + self.sub(x) return output def __repr__(self): tmpstr = "Identity + \n|" modstr = self.sub.__repr__().replace("\n", "\n|") tmpstr = tmpstr + modstr return tmpstr class ShortcutBlockSPSR(nn.Module): # Elementwise sum the output of a submodule to its input def __init__(self, submodule): super(ShortcutBlockSPSR, self).__init__() self.sub = submodule def forward(self, x): return x, self.sub def __repr__(self): tmpstr = "Identity + \n|" modstr = self.sub.__repr__().replace("\n", "\n|") tmpstr = tmpstr + modstr return tmpstr def sequential(*args): # Flatten Sequential. It unwraps nn.Sequential. if len(args) == 1: if isinstance(args[0], OrderedDict): raise NotImplementedError("sequential does not support OrderedDict input.") return args[0] # No sequential is needed. modules = [] for module in args: if isinstance(module, nn.Sequential): for submodule in module.children(): modules.append(submodule) elif isinstance(module, nn.Module): modules.append(module) return nn.Sequential(*modules) ConvMode = Literal["CNA", "NAC", "CNAC"] # 2x2x2 Conv Block def conv_block_2c2( in_nc, out_nc, act_type="relu", ): return sequential( nn.Conv2d(in_nc, out_nc, kernel_size=2, padding=1), nn.Conv2d(out_nc, out_nc, kernel_size=2, padding=0), act(act_type) if act_type else None, ) def conv_block( in_nc: int, out_nc: int, kernel_size, stride=1, dilation=1, groups=1, bias=True, pad_type="zero", norm_type: str | None = None, act_type: str | None = "relu", mode: ConvMode = "CNA", c2x2=False, ): """ Conv layer with padding, normalization, activation mode: CNA --> Conv -> Norm -> Act NAC --> Norm -> Act --> Conv (Identity Mappings in Deep Residual Networks, ECCV16) """ if c2x2: return conv_block_2c2(in_nc, out_nc, act_type=act_type) assert mode in ("CNA", "NAC", "CNAC"), "Wrong conv mode [{:s}]".format(mode) padding = get_valid_padding(kernel_size, dilation) p = pad(pad_type, padding) if pad_type and pad_type != "zero" else None padding = padding if pad_type == "zero" else 0 c = nn.Conv2d( in_nc, out_nc, kernel_size=kernel_size, stride=stride, padding=padding, dilation=dilation, bias=bias, groups=groups, ) a = act(act_type) if act_type else None if mode in ("CNA", "CNAC"): n = norm(norm_type, out_nc) if norm_type else None return sequential(p, c, n, a) elif mode == "NAC": if norm_type is None and act_type is not None: a = act(act_type, inplace=False) # Important! # input----ReLU(inplace)----Conv--+----output # |________________________| # inplace ReLU will modify the input, therefore wrong output n = norm(norm_type, in_nc) if norm_type else None return sequential(n, a, p, c) else: assert False, f"Invalid conv mode {mode}" #################### # Useful blocks #################### class ResNetBlock(nn.Module): """ ResNet Block, 3-3 style with extra residual scaling used in EDSR (Enhanced Deep Residual Networks for Single Image Super-Resolution, CVPRW 17) """ def __init__( self, in_nc, mid_nc, out_nc, kernel_size=3, stride=1, dilation=1, groups=1, bias=True, pad_type="zero", norm_type=None, act_type="relu", mode: ConvMode = "CNA", res_scale=1, ): super(ResNetBlock, self).__init__() conv0 = conv_block( in_nc, mid_nc, kernel_size, stride, dilation, groups, bias, pad_type, norm_type, act_type, mode, ) if mode == "CNA": act_type = None if mode == "CNAC": # Residual path: |-CNAC-| act_type = None norm_type = None conv1 = conv_block( mid_nc, out_nc, kernel_size, stride, dilation, groups, bias, pad_type, norm_type, act_type, mode, ) # if in_nc != out_nc: # self.project = conv_block(in_nc, out_nc, 1, stride, dilation, 1, bias, pad_type, \ # None, None) # print('Need a projecter in ResNetBlock.') # else: # self.project = lambda x:x self.res = sequential(conv0, conv1) self.res_scale = res_scale def forward(self, x): res = self.res(x).mul(self.res_scale) return x + res class RRDB(nn.Module): """ Residual in Residual Dense Block (ESRGAN: Enhanced Super-Resolution Generative Adversarial Networks) """ def __init__( self, nf, kernel_size=3, gc=32, stride=1, bias: bool = True, pad_type="zero", norm_type=None, act_type="leakyrelu", mode: ConvMode = "CNA", _convtype="Conv2D", _spectral_norm=False, plus=False, c2x2=False, ): super(RRDB, self).__init__() self.RDB1 = ResidualDenseBlock_5C( nf, kernel_size, gc, stride, bias, pad_type, norm_type, act_type, mode, plus=plus, c2x2=c2x2, ) self.RDB2 = ResidualDenseBlock_5C( nf, kernel_size, gc, stride, bias, pad_type, norm_type, act_type, mode, plus=plus, c2x2=c2x2, ) self.RDB3 = ResidualDenseBlock_5C( nf, kernel_size, gc, stride, bias, pad_type, norm_type, act_type, mode, plus=plus, c2x2=c2x2, ) def forward(self, x): out = self.RDB1(x) out = self.RDB2(out) out = self.RDB3(out) return out * 0.2 + x class ResidualDenseBlock_5C(nn.Module): """ Residual Dense Block style: 5 convs The core module of paper: (Residual Dense Network for Image Super-Resolution, CVPR 18) Modified options that can be used: - "Partial Convolution based Padding" arXiv:1811.11718 - "Spectral normalization" arXiv:1802.05957 - "ICASSP 2020 - ESRGAN+ : Further Improving ESRGAN" N. C. {Rakotonirina} and A. {Rasoanaivo} Args: nf (int): Channel number of intermediate features (num_feat). gc (int): Channels for each growth (num_grow_ch: growth channel, i.e. intermediate channels). convtype (str): the type of convolution to use. Default: 'Conv2D' gaussian_noise (bool): enable the ESRGAN+ gaussian noise (no new trainable parameters) plus (bool): enable the additional residual paths from ESRGAN+ (adds trainable parameters) """ def __init__( self, nf=64, kernel_size=3, gc=32, stride=1, bias: bool = True, pad_type="zero", norm_type=None, act_type="leakyrelu", mode: ConvMode = "CNA", plus=False, c2x2=False, ): super(ResidualDenseBlock_5C, self).__init__() ## + self.conv1x1 = conv1x1(nf, gc) if plus else None ## + self.conv1 = conv_block( nf, gc, kernel_size, stride, bias=bias, pad_type=pad_type, norm_type=norm_type, act_type=act_type, mode=mode, c2x2=c2x2, ) self.conv2 = conv_block( nf + gc, gc, kernel_size, stride, bias=bias, pad_type=pad_type, norm_type=norm_type, act_type=act_type, mode=mode, c2x2=c2x2, ) self.conv3 = conv_block( nf + 2 * gc, gc, kernel_size, stride, bias=bias, pad_type=pad_type, norm_type=norm_type, act_type=act_type, mode=mode, c2x2=c2x2, ) self.conv4 = conv_block( nf + 3 * gc, gc, kernel_size, stride, bias=bias, pad_type=pad_type, norm_type=norm_type, act_type=act_type, mode=mode, c2x2=c2x2, ) if mode == "CNA": last_act = None else: last_act = act_type self.conv5 = conv_block( nf + 4 * gc, nf, 3, stride, bias=bias, pad_type=pad_type, norm_type=norm_type, act_type=last_act, mode=mode, c2x2=c2x2, ) def forward(self, x): x1 = self.conv1(x) x2 = self.conv2(torch.cat((x, x1), 1)) if self.conv1x1: # pylint: disable=not-callable x2 = x2 + self.conv1x1(x) # + x3 = self.conv3(torch.cat((x, x1, x2), 1)) x4 = self.conv4(torch.cat((x, x1, x2, x3), 1)) if self.conv1x1: x4 = x4 + x2 # + x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1)) return x5 * 0.2 + x def conv1x1(in_planes, out_planes, stride=1): return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) #################### # Upsampler #################### def pixelshuffle_block( in_nc: int, out_nc: int, upscale_factor=2, kernel_size=3, stride=1, bias=True, pad_type="zero", norm_type: str | None = None, act_type="relu", ): """ Pixel shuffle layer (Real-Time Single Image and Video Super-Resolution Using an Efficient Sub-Pixel Convolutional Neural Network, CVPR17) """ conv = conv_block( in_nc, out_nc * (upscale_factor**2), kernel_size, stride, bias=bias, pad_type=pad_type, norm_type=None, act_type=None, ) pixel_shuffle = nn.PixelShuffle(upscale_factor) n = norm(norm_type, out_nc) if norm_type else None a = act(act_type) if act_type else None return sequential(conv, pixel_shuffle, n, a) def upconv_block( in_nc: int, out_nc: int, upscale_factor=2, kernel_size=3, stride=1, bias=True, pad_type="zero", norm_type: str | None = None, act_type="relu", mode="nearest", c2x2=False, ): # Up conv # described in https://distill.pub/2016/deconv-checkerboard/ upsample = nn.Upsample(scale_factor=upscale_factor, mode=mode) conv = conv_block( in_nc, out_nc, kernel_size, stride, bias=bias, pad_type=pad_type, norm_type=norm_type, act_type=act_type, c2x2=c2x2, ) return sequential(upsample, conv)