""" Modified by Nikita Selin (OPHoperHPO)[https://github.com/OPHoperHPO]. Source url: https://github.com/MarcoForte/FBA_Matting License: MIT License """ import torch import torch.nn as nn from torch.nn import functional as F class Conv2d(nn.Conv2d): def __init__( self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True, ): super(Conv2d, self).__init__( in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias, ) def forward(self, x): # return super(Conv2d, self).forward(x) weight = self.weight weight_mean = ( weight.mean(dim=1, keepdim=True) .mean(dim=2, keepdim=True) .mean(dim=3, keepdim=True) ) weight = weight - weight_mean # std = (weight).view(weight.size(0), -1).std(dim=1).view(-1, 1, 1, 1) + 1e-5 std = ( torch.sqrt(torch.var(weight.view(weight.size(0), -1), dim=1) + 1e-12).view( -1, 1, 1, 1 ) + 1e-5 ) weight = weight / std.expand_as(weight) return F.conv2d( x, weight, self.bias, self.stride, self.padding, self.dilation, self.groups ) def BatchNorm2d(num_features): return nn.GroupNorm(num_channels=num_features, num_groups=32)