# Rerender / gmflow_module / gmflow / trident_conv.py
# Anonymous-sub's picture
# merge (#1)
# 251e479
# Copyright (c) Facebook, Inc. and its affiliates.
# https://github.com/facebookresearch/detectron2/blob/main/projects/TridentNet/tridentnet/trident_conv.py
import torch
from torch import nn
from torch.nn import functional as F
from torch.nn.modules.utils import _pair
class MultiScaleTridentConv(nn.Module):
    """2D convolution whose single weight is shared across several branches.

    Every branch convolves its own input tensor with the same ``weight`` and
    ``bias`` but with a branch-specific stride and padding.

    Args:
        in_channels: Number of channels of each input tensor.
        out_channels: Number of channels produced by the convolution.
        kernel_size: Kernel size, an int or a pair.
        stride: Stored as ``self.stride``; ``forward`` does not read it (the
            per-branch ``strides`` are used instead).
        strides: Per-branch stride. An int is repeated for every branch.
        paddings: Per-branch padding. An int is repeated for every branch.
        dilations: Stored as ``self.dilations``; ``forward`` does not read it.
        dilation: Dilation passed to the convolution of every branch.
        groups: Number of convolution groups.
        num_branch: Number of branches.
        test_branch_idx: When the module is not in training mode, ``-1`` runs
            all branches; any other value runs only the branch with that
            index on a single input.
        bias: Whether to create a learnable bias (initialized to zero).
        norm: Optional callable applied to every branch output.
        activation: Optional callable applied to every branch output after
            ``norm``.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        stride=1,
        strides=1,
        paddings=0,
        dilations=1,
        dilation=1,
        groups=1,
        num_branch=1,
        test_branch_idx=-1,
        bias=False,
        norm=None,
        activation=None,
    ):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = _pair(kernel_size)
        self.num_branch = num_branch
        self.stride = _pair(stride)
        self.groups = groups
        self.with_bias = bias
        self.dilation = dilation

        # A scalar setting is broadcast to every branch.
        if isinstance(paddings, int):
            paddings = [paddings] * self.num_branch
        if isinstance(dilations, int):
            dilations = [dilations] * self.num_branch
        if isinstance(strides, int):
            strides = [strides] * self.num_branch
        self.paddings = [_pair(p) for p in paddings]
        self.dilations = [_pair(d) for d in dilations]
        self.strides = [_pair(s) for s in strides]

        self.test_branch_idx = test_branch_idx
        self.norm = norm
        self.activation = activation

        assert len({self.num_branch, len(self.paddings), len(self.strides)}) == 1, (
            "num_branch, len(paddings) and len(strides) must all be equal, got "
            f"{self.num_branch}, {len(self.paddings)} and {len(self.strides)}"
        )

        # Allocated uninitialized; filled in by the initializers below.
        self.weight = nn.Parameter(
            torch.empty(out_channels, in_channels // groups, *self.kernel_size)
        )
        if bias:
            self.bias = nn.Parameter(torch.empty(out_channels))
        else:
            self.bias = None

        nn.init.kaiming_uniform_(self.weight, nonlinearity="relu")
        if self.bias is not None:
            nn.init.constant_(self.bias, 0)

    def forward(self, inputs):
        """Convolve each input with the shared weight.

        Args:
            inputs: Sequence of input tensors. It must hold ``num_branch``
                tensors when all branches run (training mode, or
                ``test_branch_idx == -1``) and exactly one tensor otherwise.

        Returns:
            A list with one output tensor per executed branch.
        """
        run_all_branches = self.training or self.test_branch_idx == -1
        expected_inputs = self.num_branch if run_all_branches else 1
        assert len(inputs) == expected_inputs, (
            f"expected {expected_inputs} input(s), got {len(inputs)}"
        )

        if run_all_branches:
            outputs = [
                F.conv2d(x, self.weight, self.bias, stride, padding, self.dilation, self.groups)
                for x, stride, padding in zip(inputs, self.strides, self.paddings)
            ]
        else:
            # Use the settings of the selected branch. The previous code
            # guarded this lookup with `test_branch_idx == -1`, which is never
            # true on this path, so it always fell back to the last branch.
            branch = self.test_branch_idx
            outputs = [
                F.conv2d(
                    inputs[0],
                    self.weight,
                    self.bias,
                    self.strides[branch],
                    self.paddings[branch],
                    self.dilation,
                    self.groups,
                )
            ]

        if self.norm is not None:
            outputs = [self.norm(x) for x in outputs]
        if self.activation is not None:
            outputs = [self.activation(x) for x in outputs]
        return outputs