from contextlib import contextmanager from math import sqrt, log import torch import torch.nn as nn # import warnings # warnings.simplefilter('ignore') class BaseModule(nn.Module): def __init__(self): self.act_fn = None super(BaseModule, self).__init__() def selu_init_params(self): for m in self.modules(): if isinstance(m, nn.Conv2d) and m.weight.requires_grad: m.weight.data.normal_(0.0, 1.0 / sqrt(m.weight.numel())) if m.bias is not None: m.bias.data.fill_(0) elif isinstance(m, nn.BatchNorm2d) and m.weight.requires_grad: m.weight.data.fill_(1) m.bias.data.zero_() elif isinstance(m, nn.Linear) and m.weight.requires_grad: m.weight.data.normal_(0, 1.0 / sqrt(m.weight.numel())) m.bias.data.zero_() def initialize_weights_xavier_uniform(self): for m in self.modules(): if isinstance(m, nn.Conv2d) and m.weight.requires_grad: # nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='leaky_relu') nn.init.xavier_uniform_(m.weight) if m.bias is not None: m.bias.data.zero_() elif isinstance(m, nn.BatchNorm2d) and m.weight.requires_grad: m.weight.data.fill_(1) m.bias.data.zero_() def load_state_dict(self, state_dict, strict=True, self_state=False): own_state = self_state if self_state else self.state_dict() for name, param in state_dict.items(): if name in own_state: try: own_state[name].copy_(param.data) except Exception as e: print("Parameter {} fails to load.".format(name)) print("-----------------------------------------") print(e) else: print("Parameter {} is not in the model. ".format(name)) @contextmanager def set_activation_inplace(self): if hasattr(self, 'act_fn') and hasattr(self.act_fn, 'inplace'): # save memory self.act_fn.inplace = True yield self.act_fn.inplace = False else: yield def total_parameters(self): total = sum([i.numel() for i in self.parameters()]) trainable = sum([i.numel() for i in self.parameters() if i.requires_grad]) print("Total parameters : {}. Trainable parameters : {}".format(total, trainable)) return total def forward(self, *x): raise NotImplementedError class ResidualFixBlock(BaseModule): def __init__(self, in_channels, out_channels, kernel_size=3, padding=1, dilation=1, groups=1, activation=nn.SELU(), conv=nn.Conv2d): super(ResidualFixBlock, self).__init__() self.act_fn = activation self.m = nn.Sequential( conv(in_channels, out_channels, kernel_size, padding=padding, dilation=dilation, groups=groups), activation, # conv(out_channels, out_channels, kernel_size, padding=(kernel_size - 1) // 2, dilation=1, groups=groups), conv(in_channels, out_channels, kernel_size, padding=padding, dilation=dilation, groups=groups), ) def forward(self, x): out = self.m(x) return self.act_fn(out + x) class ConvBlock(BaseModule): def __init__(self, in_channels, out_channels, kernel_size=3, padding=1, dilation=1, groups=1, activation=nn.SELU(), conv=nn.Conv2d): super(ConvBlock, self).__init__() self.m = nn.Sequential(conv(in_channels, out_channels, kernel_size, padding=padding, dilation=dilation, groups=groups), activation) def forward(self, x): return self.m(x) class UpSampleBlock(BaseModule): def __init__(self, channels, scale, activation, atrous_rate=1, conv=nn.Conv2d): assert scale in [2, 4, 8], "Currently UpSampleBlock supports 2, 4, 8 scaling" super(UpSampleBlock, self).__init__() m = nn.Sequential( conv(channels, 4 * channels, kernel_size=3, padding=atrous_rate, dilation=atrous_rate), activation, nn.PixelShuffle(2) ) self.m = nn.Sequential(*[m for _ in range(int(log(scale, 2)))]) def forward(self, x): return self.m(x) class SpatialChannelSqueezeExcitation(BaseModule): # https://arxiv.org/abs/1709.01507 # https://arxiv.org/pdf/1803.02579v1.pdf def __init__(self, in_channel, reduction=16, activation=nn.ReLU()): super(SpatialChannelSqueezeExcitation, self).__init__() linear_nodes = max(in_channel // reduction, 4) # avoid only 1 node case self.avg_pool = nn.AdaptiveAvgPool2d(1) self.channel_excite = nn.Sequential( # check the paper for the number 16 in reduction. It is selected by experiment. nn.Linear(in_channel, linear_nodes), activation, nn.Linear(linear_nodes, in_channel), nn.Sigmoid() ) self.spatial_excite = nn.Sequential( nn.Conv2d(in_channel, 1, kernel_size=1, stride=1, padding=0, bias=False), nn.Sigmoid() ) def forward(self, x): b, c, h, w = x.size() # channel = self.avg_pool(x).view(b, c) # channel = F.avg_pool2d(x, kernel_size=(h,w)).view(b,c) # used for porting to other frameworks cSE = self.channel_excite(channel).view(b, c, 1, 1) x_cSE = torch.mul(x, cSE) # spatial sSE = self.spatial_excite(x) x_sSE = torch.mul(x, sSE) # return x_sSE return torch.add(x_cSE, x_sSE) class PartialConv(nn.Module): # reference: # Image Inpainting for Irregular Holes Using Partial Convolutions # http://masc.cs.gmu.edu/wiki/partialconv/show?time=2018-05-24+21%3A41%3A10 # https://github.com/naoto0804/pytorch-inpainting-with-partial-conv/blob/master/net.py # https://github.com/SeitaroShinagawa/chainer-partial_convolution_image_inpainting/blob/master/common/net.py # partial based padding # https: // github.com / NVIDIA / partialconv / blob / master / models / pd_resnet.py def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True): super(PartialConv, self).__init__() self.feature_conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias) self.mask_conv = nn.Conv2d(1, 1, kernel_size, stride, padding, dilation, groups, bias=False) self.window_size = self.mask_conv.kernel_size[0] * self.mask_conv.kernel_size[1] torch.nn.init.constant_(self.mask_conv.weight, 1.0) for param in self.mask_conv.parameters(): param.requires_grad = False def forward(self, x): output = self.feature_conv(x) if self.feature_conv.bias is not None: output_bias = self.feature_conv.bias.view(1, -1, 1, 1).expand_as(output) else: output_bias = torch.zeros_like(output, device=x.device) with torch.no_grad(): ones = torch.ones(1, 1, x.size(2), x.size(3), device=x.device) output_mask = self.mask_conv(ones) output_mask = self.window_size / output_mask output = (output - output_bias) * output_mask + output_bias return output