import torch
import torch.nn as nn
from torch.nn import init as init
from torch.nn.modules.utils import _pair, _single
import math


class ModulatedDeformConv2d(nn.Module):
    """Parameter container for a modulated deformable 2-D convolution.

    This class only stores the layer configuration and owns the learnable
    ``weight`` (and optional ``bias``) parameters. Its ``forward`` is an
    empty stub that returns ``None``.
    NOTE(review): presumably subclasses supply the real ``forward`` and a
    ``conv_offset`` layer -- confirm against the rest of the project.

    Args:
        in_channels: Number of input channels.
        out_channels: Number of output channels.
        kernel_size: Kernel size; normalized with ``_pair`` before storing.
        stride: Stored on the instance exactly as given.
        padding: Stored on the instance exactly as given.
        dilation: Stored on the instance exactly as given.
        groups: The weight's second dimension is ``in_channels // groups``.
        deform_groups: Stored on the instance exactly as given.
        bias: If truthy, a bias parameter of length ``out_channels`` is
            created; otherwise ``bias`` is registered as ``None``.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 stride=1,
                 padding=0,
                 dilation=1,
                 groups=1,
                 deform_groups=1,
                 bias=True):
        super().__init__()

        # Layer configuration. Only kernel_size is normalized to a tuple;
        # the remaining settings keep whatever form the caller passed.
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = _pair(kernel_size)
        self.stride = stride
        self.padding = padding
        self.dilation = dilation
        self.groups = groups
        self.deform_groups = deform_groups
        self.with_bias = bias

        # enable compatibility with nn.Conv2d
        self.transposed = False
        self.output_padding = _single(0)

        # Parameters are allocated uninitialized here and filled in by
        # init_weights() below. "weight" is registered before "bias".
        weight_shape = (out_channels, in_channels // groups) + tuple(self.kernel_size)
        self.weight = nn.Parameter(torch.Tensor(*weight_shape))
        self.register_parameter(
            'bias', nn.Parameter(torch.Tensor(out_channels)) if bias else None)

        self.init_weights()

    def init_weights(self):
        """Initialize the parameters in place.

        The weight is drawn uniformly from ``[-stdv, stdv]`` with
        ``stdv = 1 / sqrt(in_channels * prod(kernel_size))``. The bias, when
        present, is set to zero. If the instance has a ``conv_offset``
        attribute, its weight and bias are set to zero as well.
        """
        # The fan-in is computed from the full in_channels (not divided by
        # groups) multiplied by every kernel extent.
        fan_in = self.in_channels
        for extent in self.kernel_size:
            fan_in *= extent
        stdv = 1. / math.sqrt(fan_in)

        self.weight.data.uniform_(-stdv, stdv)
        if self.bias is not None:
            self.bias.data.zero_()

        # conv_offset is not created by this class; nothing more to do when
        # the attribute is absent.
        if not hasattr(self, 'conv_offset'):
            return
        self.conv_offset.weight.data.zero_()
        self.conv_offset.bias.data.zero_()

    def forward(self, x, offset, mask):
        """Do nothing and return ``None``.

        Args:
            x: Unused here.
            offset: Unused here.
            mask: Unused here.
        """
        return None