| import torch |
| import torch.nn as nn |
| from torchvision.ops import deform_conv2d |
|
|
|
|
class DeformableConv2d(nn.Module):
    """Modulated deformable 2-D convolution built on ``torchvision.ops.deform_conv2d``.

    Two auxiliary convolutions look at the input and predict, for every output
    location, a 2-D sampling offset per kernel tap (``offset_conv``) and a
    modulation scalar per kernel tap (``modulator_conv``). These are passed to
    ``deform_conv2d`` together with the weight/bias of ``regular_conv``.

    Both auxiliary convolutions are zero-initialised. At construction, the
    offsets are therefore all 0 and the mask is ``2 * sigmoid(0) == 1``, so the
    layer starts out computing the same result as ``regular_conv`` would.

    Args:
        in_channels: Number of input channels.
        out_channels: Number of output channels.
        kernel_size: Kernel size as an int or a 2-element tuple/list ``(kh, kw)``.
        stride: Stride as an int or a 2-element tuple/list.
        padding: Forwarded unchanged to every ``nn.Conv2d`` and to ``deform_conv2d``.
        bias: Whether the main convolution (``regular_conv``) has a bias term.
        dilation: Dilation as an int or a 2-element tuple/list. Defaults to 1,
            which matches the behaviour before this parameter existed.

    Raises:
        TypeError: If ``kernel_size`` is not an int or a tuple/list.
        ValueError: If ``kernel_size`` is a tuple/list without exactly 2 entries.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size=3,
        stride=1,
        padding=1,
        bias=False,
        dilation=1,
    ):
        super().__init__()

        kernel_size = self._kernel_pair(kernel_size)
        self.stride = self._as_pair(stride)
        self.padding = padding
        self.dilation = self._as_pair(dilation)

        # The offset/mask maps must have the same spatial size as the main conv's
        # output, so all three convs share identical geometry.
        conv_kwargs = dict(
            kernel_size=kernel_size,
            stride=self.stride,
            padding=self.padding,
            dilation=self.dilation,
        )
        num_taps = kernel_size[0] * kernel_size[1]

        # Two offset channels (one 2-D displacement) per kernel tap.
        self.offset_conv = nn.Conv2d(in_channels, 2 * num_taps, bias=True, **conv_kwargs)
        nn.init.constant_(self.offset_conv.weight, 0.0)
        nn.init.constant_(self.offset_conv.bias, 0.0)

        # One modulation scalar per kernel tap.
        self.modulator_conv = nn.Conv2d(in_channels, num_taps, bias=True, **conv_kwargs)
        nn.init.constant_(self.modulator_conv.weight, 0.0)
        nn.init.constant_(self.modulator_conv.bias, 0.0)

        # Never called directly in forward(); it only owns the learnable
        # weight/bias that deform_conv2d consumes.
        self.regular_conv = nn.Conv2d(
            in_channels, out_channels=out_channels, bias=bias, **conv_kwargs
        )

    @staticmethod
    def _kernel_pair(kernel_size):
        """Validate ``kernel_size`` and return it as a ``(kh, kw)`` tuple."""
        # bool is a subclass of int; reject it like the original exact-type check did.
        if isinstance(kernel_size, bool) or not isinstance(kernel_size, (int, tuple, list)):
            raise TypeError(
                "kernel_size must be an int or a 2-element tuple/list, "
                f"got {type(kernel_size).__name__}"
            )
        if isinstance(kernel_size, int):
            return (kernel_size, kernel_size)
        if len(kernel_size) != 2:
            raise ValueError(
                f"kernel_size must have exactly 2 elements, got {len(kernel_size)}"
            )
        return tuple(kernel_size)

    @staticmethod
    def _as_pair(value):
        """Return ``value`` as a 2-tuple, broadcasting a scalar to both dimensions."""
        return tuple(value) if isinstance(value, (tuple, list)) else (value, value)

    def forward(self, x):
        """Apply the modulated deformable convolution.

        Args:
            x: Input tensor; ``deform_conv2d`` expects shape ``(N, in_channels, H, W)``.

        Returns:
            The output of ``deform_conv2d``, with ``out_channels`` channels.
        """
        offset = self.offset_conv(x)
        # Mask lies in (0, 2); it is exactly 1 at initialisation (zero-initialised conv).
        modulator = 2.0 * torch.sigmoid(self.modulator_conv(x))
        return deform_conv2d(
            input=x,
            offset=offset,
            weight=self.regular_conv.weight,
            bias=self.regular_conv.bias,
            stride=self.stride,
            padding=self.padding,
            dilation=self.dilation,
            mask=modulator,
        )
|
|