"""Modified from https://github.com/chaofengc/PSFRGAN """ import numpy as np import torch.nn as nn from torch.nn import functional as F class NormLayer(nn.Module): """Normalization Layers. Args: channels: input channels, for batch norm and instance norm. input_size: input shape without batch size, for layer norm. """ def __init__(self, channels, normalize_shape=None, norm_type='bn'): super(NormLayer, self).__init__() norm_type = norm_type.lower() self.norm_type = norm_type if norm_type == 'bn': self.norm = nn.BatchNorm2d(channels, affine=True) elif norm_type == 'in': self.norm = nn.InstanceNorm2d(channels, affine=False) elif norm_type == 'gn': self.norm = nn.GroupNorm(32, channels, affine=True) elif norm_type == 'pixel': self.norm = lambda x: F.normalize(x, p=2, dim=1) elif norm_type == 'layer': self.norm = nn.LayerNorm(normalize_shape) elif norm_type == 'none': self.norm = lambda x: x * 1.0 else: assert 1 == 0, f'Norm type {norm_type} not support.' def forward(self, x, ref=None): if self.norm_type == 'spade': return self.norm(x, ref) else: return self.norm(x) class ReluLayer(nn.Module): """Relu Layer. Args: relu type: type of relu layer, candidates are - ReLU - LeakyReLU: default relu slope 0.2 - PRelu - SELU - none: direct pass """ def __init__(self, channels, relu_type='relu'): super(ReluLayer, self).__init__() relu_type = relu_type.lower() if relu_type == 'relu': self.func = nn.ReLU(True) elif relu_type == 'leakyrelu': self.func = nn.LeakyReLU(0.2, inplace=True) elif relu_type == 'prelu': self.func = nn.PReLU(channels) elif relu_type == 'selu': self.func = nn.SELU(True) elif relu_type == 'none': self.func = lambda x: x * 1.0 else: assert 1 == 0, f'Relu type {relu_type} not support.' def forward(self, x): return self.func(x) class ConvLayer(nn.Module): def __init__(self, in_channels, out_channels, kernel_size=3, scale='none', norm_type='none', relu_type='none', use_pad=True, bias=True): super(ConvLayer, self).__init__() self.use_pad = use_pad self.norm_type = norm_type if norm_type in ['bn']: bias = False stride = 2 if scale == 'down' else 1 self.scale_func = lambda x: x if scale == 'up': self.scale_func = lambda x: nn.functional.interpolate(x, scale_factor=2, mode='nearest') self.reflection_pad = nn.ReflectionPad2d(int(np.ceil((kernel_size - 1.) 
/ 2))) self.conv2d = nn.Conv2d(in_channels, out_channels, kernel_size, stride, bias=bias) self.relu = ReluLayer(out_channels, relu_type) self.norm = NormLayer(out_channels, norm_type=norm_type) def forward(self, x): out = self.scale_func(x) if self.use_pad: out = self.reflection_pad(out) out = self.conv2d(out) out = self.norm(out) out = self.relu(out) return out class ResidualBlock(nn.Module): """ Residual block recommended in: http://torch.ch/blog/2016/02/04/resnets.html """ def __init__(self, c_in, c_out, relu_type='prelu', norm_type='bn', scale='none'): super(ResidualBlock, self).__init__() if scale == 'none' and c_in == c_out: self.shortcut_func = lambda x: x else: self.shortcut_func = ConvLayer(c_in, c_out, 3, scale) scale_config_dict = {'down': ['none', 'down'], 'up': ['up', 'none'], 'none': ['none', 'none']} scale_conf = scale_config_dict[scale] self.conv1 = ConvLayer(c_in, c_out, 3, scale_conf[0], norm_type=norm_type, relu_type=relu_type) self.conv2 = ConvLayer(c_out, c_out, 3, scale_conf[1], norm_type=norm_type, relu_type='none') def forward(self, x): identity = self.shortcut_func(x) res = self.conv1(x) res = self.conv2(res) return identity + res class ParseNet(nn.Module): def __init__(self, in_size=128, out_size=128, min_feat_size=32, base_ch=64, parsing_ch=19, res_depth=10, relu_type='LeakyReLU', norm_type='bn', ch_range=[32, 256]): super().__init__() self.res_depth = res_depth act_args = {'norm_type': norm_type, 'relu_type': relu_type} min_ch, max_ch = ch_range ch_clip = lambda x: max(min_ch, min(x, max_ch)) # noqa: E731 min_feat_size = min(in_size, min_feat_size) down_steps = int(np.log2(in_size // min_feat_size)) up_steps = int(np.log2(out_size // min_feat_size)) # =============== define encoder-body-decoder ==================== self.encoder = [] self.encoder.append(ConvLayer(3, base_ch, 3, 1)) head_ch = base_ch for i in range(down_steps): cin, cout = ch_clip(head_ch), ch_clip(head_ch * 2) self.encoder.append(ResidualBlock(cin, cout, scale='down', **act_args)) head_ch = head_ch * 2 self.body = [] for i in range(res_depth): self.body.append(ResidualBlock(ch_clip(head_ch), ch_clip(head_ch), **act_args)) self.decoder = [] for i in range(up_steps): cin, cout = ch_clip(head_ch), ch_clip(head_ch // 2) self.decoder.append(ResidualBlock(cin, cout, scale='up', **act_args)) head_ch = head_ch // 2 self.encoder = nn.Sequential(*self.encoder) self.body = nn.Sequential(*self.body) self.decoder = nn.Sequential(*self.decoder) self.out_img_conv = ConvLayer(ch_clip(head_ch), 3) self.out_mask_conv = ConvLayer(ch_clip(head_ch), parsing_ch) def forward(self, x): feat = self.encoder(x) x = feat + self.body(feat) x = self.decoder(x) out_img = self.out_img_conv(x) out_mask = self.out_mask_conv(x) return out_mask, out_img
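

# --- Usage sketch (added; not part of the original module) -----------------
# A minimal smoke test, assuming the default constructor arguments above:
# a 128x128 RGB batch should yield a 19-channel parsing mask and a 3-channel
# coarse image at the same resolution.
if __name__ == '__main__':
    import torch

    net = ParseNet(in_size=128, out_size=128, parsing_ch=19)
    net.eval()
    dummy = torch.randn(1, 3, 128, 128)
    with torch.no_grad():
        out_mask, out_img = net(dummy)
    print(out_mask.shape)  # expected: torch.Size([1, 19, 128, 128])
    print(out_img.shape)   # expected: torch.Size([1, 3, 128, 128])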