Towsif7's picture
firrst commit
59e40e1
raw
history blame
1.5 kB
"""
Modified by Nikita Selin (OPHoperHPO)[https://github.com/OPHoperHPO].
Source url: https://github.com/MarcoForte/FBA_Matting
License: MIT License
"""
import torch
import torch.nn as nn
from torch.nn import functional as F
class Conv2d(nn.Conv2d):
    """2-D convolution that standardizes its kernel on every forward pass.

    Each output filter (a slice along dim 0 of ``self.weight``) is shifted to
    zero mean and divided by its standard deviation before the convolution is
    applied. The stored parameter itself is never modified; the standardized
    kernel is recomputed from it on each call.

    The constructor arguments are forwarded unchanged to ``torch.nn.Conv2d``.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        stride=1,
        padding=0,
        dilation=1,
        groups=1,
        bias=True,
    ):
        super().__init__(
            in_channels,
            out_channels,
            kernel_size,
            stride,
            padding,
            dilation,
            groups,
            bias,
        )

    def forward(self, x):
        """Convolve ``x`` with the standardized version of ``self.weight``.

        Args:
            x: Input tensor handed to ``F.conv2d`` (presumably laid out as
                ``(N, C_in, H, W)`` -- confirm against callers).

        Returns:
            The result of ``F.conv2d`` using the standardized kernel together
            with this layer's bias, stride, padding, dilation and groups.
        """
        weight = self.weight

        # Per-filter mean over the (in_channels / groups, kH, kW) axes.
        weight_mean = (
            weight.mean(dim=1, keepdim=True)
            .mean(dim=2, keepdim=True)
            .mean(dim=3, keepdim=True)
        )
        weight = weight - weight_mean

        # reshape (not view): the centered kernel is non-contiguous when the
        # module was converted to channels_last, and view() would raise.
        flat = weight.reshape(weight.size(0), -1)

        # The Bessel-corrected variance of a single element is NaN, which
        # would poison the whole output for filters with exactly one weight
        # (e.g. depthwise 1x1). Fall back to the population variance there.
        var = torch.var(flat, dim=1, unbiased=flat.size(1) > 1)

        # 1e-12 keeps sqrt away from zero; 1e-5 keeps the division finite.
        std = torch.sqrt(var + 1e-12).view(-1, 1, 1, 1) + 1e-5

        # std has shape (out_channels, 1, 1, 1) and broadcasts over weight.
        weight = weight / std
        return F.conv2d(
            x, weight, self.bias, self.stride, self.padding, self.dilation, self.groups
        )
def BatchNorm2d(num_features, num_groups=32):
    """Build the normalization layer used in place of batch normalization.

    Despite its name, this factory returns a ``torch.nn.GroupNorm`` layer,
    not a ``torch.nn.BatchNorm2d``.

    Args:
        num_features: Number of channels to normalize; passed to GroupNorm as
            ``num_channels``.
        num_groups: Number of groups the channels are split into. Defaults to
            32, the value that was previously hard-coded.

    Returns:
        A ``torch.nn.GroupNorm`` instance.
    """
    return nn.GroupNorm(num_groups=num_groups, num_channels=num_features)