text2-video-zero / annotator /uniformer /mmcv /ops /modulated_deform_conv.py
camenduru's picture
thanks to Text2Video-Zero team ❤
b944fa1
raw
history blame
10.6 kB
# Copyright (c) OpenMMLab. All rights reserved.
import math
import torch
import torch.nn as nn
from torch.autograd import Function
from torch.autograd.function import once_differentiable
from torch.nn.modules.utils import _pair, _single
from annotator.uniformer.mmcv.utils import deprecated_api_warning
from ..cnn import CONV_LAYERS
from ..utils import ext_loader, print_log
# Handle to the compiled ``_ext`` extension, restricted to the two kernels
# that ModulatedDeformConv2dFunction calls in forward/backward.
ext_module = ext_loader.load_ext('_ext', [
    'modulated_deform_conv_forward',
    'modulated_deform_conv_backward',
])
class ModulatedDeformConv2dFunction(Function):
    """Autograd function for modulated deformable 2D convolution (DCNv2).

    Forward and backward are delegated to the
    ``modulated_deform_conv_forward`` / ``modulated_deform_conv_backward``
    kernels of the compiled ``_ext`` extension; ``symbolic`` exports the op
    to ONNX as the custom ``mmcv::MMCVModulatedDeformConv2d`` node.
    """

    @staticmethod
    def symbolic(g, input, offset, mask, weight, bias, stride, padding,
                 dilation, groups, deform_groups):
        """Emit the custom ``mmcv::MMCVModulatedDeformConv2d`` graph node.

        ``bias`` is appended to the node inputs only when it is not None.
        """
        input_tensors = [input, offset, mask, weight]
        if bias is not None:
            input_tensors.append(bias)
        return g.op(
            'mmcv::MMCVModulatedDeformConv2d',
            *input_tensors,
            stride_i=stride,
            padding_i=padding,
            dilation_i=dilation,
            groups_i=groups,
            deform_groups_i=deform_groups)

    @staticmethod
    def forward(ctx,
                input,
                offset,
                mask,
                weight,
                bias=None,
                stride=1,
                padding=0,
                dilation=1,
                groups=1,
                deform_groups=1):
        """Run the modulated deformable convolution.

        Args:
            ctx: Autograd context; receives the hyper-parameters, the saved
                tensors and the extension's scratch buffers for ``backward``.
            input (Tensor): Input feature map; must be 4D.
            offset (Tensor): Sampling offsets. Its dtype decides the compute
                dtype (see the AMP note below).
            mask (Tensor): Modulation mask.
            weight (Tensor): Convolution kernel; ``weight.size(2)`` and
                ``weight.size(3)`` are used as kernel height/width.
            bias (Tensor, optional): Bias, or None for no bias.
            stride (int | tuple[int]): Normalised with ``_pair``.
            padding (int | tuple[int]): Normalised with ``_pair``.
            dilation (int | tuple[int]): Normalised with ``_pair``.
            groups (int): Convolution groups.
            deform_groups (int): Deformable groups.

        Returns:
            Tensor: Output tensor with the shape from ``_output_size``.

        Raises:
            ValueError: If ``input`` is not 4D, or if any output dimension
                would be non-positive.
        """
        if input is not None and input.dim() != 4:
            # Implicit string concatenation: a backslash continuation inside
            # the literal would embed the next line's indentation in the
            # message.
            raise ValueError(
                f'Expected 4D tensor as input, got {input.dim()}D tensor '
                'instead.')
        ctx.stride = _pair(stride)
        ctx.padding = _pair(padding)
        ctx.dilation = _pair(dilation)
        ctx.groups = groups
        ctx.deform_groups = deform_groups
        ctx.with_bias = bias is not None
        # When pytorch version >= 1.6.0, amp is adopted for fp16 mode;
        # amp won't cast the type of model (float32), but "offset" is cast
        # to float16 by nn.Conv2d automatically, leading to the type
        # mismatch with input (when it is float32) or weight.
        # The flag for whether to use fp16 or amp is the type of "offset",
        # so every tensor handed to the kernel (input, weight, mask and
        # bias) is cast to that dtype to support fp16 and amp whatever the
        # pytorch version is.
        input = input.type_as(offset)
        weight = weight.type_as(input)
        mask = mask.type_as(input)
        if ctx.with_bias:
            bias = bias.type_as(input)
        else:
            # The kernel takes a bias tensor positionally (``with_bias``
            # tells it whether to use it); create the empty placeholder
            # after the cast so its dtype matches ``input``.
            bias = input.new_empty(0)
        ctx.save_for_backward(input, offset, mask, weight, bias)
        output = input.new_empty(
            ModulatedDeformConv2dFunction._output_size(ctx, input, weight))
        # Two empty scratch tensors handed to the extension in forward and
        # again in backward (presumably kernel workspace - verify in _ext).
        ctx._bufs = [input.new_empty(0), input.new_empty(0)]
        ext_module.modulated_deform_conv_forward(
            input,
            weight,
            bias,
            ctx._bufs[0],
            offset,
            mask,
            output,
            ctx._bufs[1],
            kernel_h=weight.size(2),
            kernel_w=weight.size(3),
            stride_h=ctx.stride[0],
            stride_w=ctx.stride[1],
            pad_h=ctx.padding[0],
            pad_w=ctx.padding[1],
            dilation_h=ctx.dilation[0],
            dilation_w=ctx.dilation[1],
            group=ctx.groups,
            deformable_group=ctx.deform_groups,
            with_bias=ctx.with_bias)
        return output

    @staticmethod
    @once_differentiable
    def backward(ctx, grad_output):
        """Compute gradients with the extension's backward kernel.

        Returns gradients for ``input``, ``offset``, ``mask``, ``weight`` and
        ``bias`` (None when no bias was given), followed by None for each
        non-tensor argument of ``forward``.
        """
        input, offset, mask, weight, bias = ctx.saved_tensors
        grad_input = torch.zeros_like(input)
        grad_offset = torch.zeros_like(offset)
        grad_mask = torch.zeros_like(mask)
        grad_weight = torch.zeros_like(weight)
        grad_bias = torch.zeros_like(bias)
        grad_output = grad_output.contiguous()
        ext_module.modulated_deform_conv_backward(
            input,
            weight,
            bias,
            ctx._bufs[0],
            offset,
            mask,
            ctx._bufs[1],
            grad_input,
            grad_weight,
            grad_bias,
            grad_offset,
            grad_mask,
            grad_output,
            kernel_h=weight.size(2),
            kernel_w=weight.size(3),
            stride_h=ctx.stride[0],
            stride_w=ctx.stride[1],
            pad_h=ctx.padding[0],
            pad_w=ctx.padding[1],
            dilation_h=ctx.dilation[0],
            dilation_w=ctx.dilation[1],
            group=ctx.groups,
            deformable_group=ctx.deform_groups,
            with_bias=ctx.with_bias)
        if not ctx.with_bias:
            grad_bias = None
        return (grad_input, grad_offset, grad_mask, grad_weight, grad_bias,
                None, None, None, None, None)

    @staticmethod
    def _output_size(ctx, input, weight):
        """Return the convolution output shape as a tuple.

        The shape is ``(N, C_out, *spatial)``, where each spatial size is
        ``(in + 2 * pad - (dilation * (k - 1) + 1)) // stride + 1``.

        Raises:
            ValueError: If any resulting dimension is not positive.
        """
        channels = weight.size(0)
        output_size = (input.size(0), channels)
        for d in range(input.dim() - 2):
            in_size = input.size(d + 2)
            pad = ctx.padding[d]
            kernel = ctx.dilation[d] * (weight.size(d + 2) - 1) + 1
            stride_ = ctx.stride[d]
            output_size += ((in_size + (2 * pad) - kernel) // stride_ + 1, )
        if not all(s > 0 for s in output_size):
            raise ValueError(
                'convolution input is too small (output would be ' +
                'x'.join(map(str, output_size)) + ')')
        return output_size


# Functional entry point used by the nn.Module wrappers below.
modulated_deform_conv2d = ModulatedDeformConv2dFunction.apply
class ModulatedDeformConv2d(nn.Module):
    """Modulated deformable 2D convolution taking offsets and mask as inputs.

    Args:
        in_channels (int): Number of input channels.
        out_channels (int): Number of output channels.
        kernel_size (int | tuple[int]): Kernel size (normalised with
            ``_pair``).
        stride (int | tuple[int]): Stride. Default: 1.
        padding (int | tuple[int]): Padding. Default: 0.
        dilation (int | tuple[int]): Dilation. Default: 1.
        groups (int): Convolution groups. Default: 1.
        deform_groups (int): Deformable groups; the deprecated keyword
            ``deformable_groups`` is still accepted. Default: 1.
        bias (bool): Whether to create a learnable bias. Default: True.

    Raises:
        ValueError: If ``in_channels`` or ``out_channels`` is not divisible
            by ``groups``.
    """

    @deprecated_api_warning({'deformable_groups': 'deform_groups'},
                            cls_name='ModulatedDeformConv2d')
    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 stride=1,
                 padding=0,
                 dilation=1,
                 groups=1,
                 deform_groups=1,
                 bias=True):
        super(ModulatedDeformConv2d, self).__init__()
        # The weight is shaped (out, in // groups, kh, kw); a non-divisible
        # channel count would silently truncate it (nn.Conv2d rejects this).
        if in_channels % groups != 0:
            raise ValueError(
                f'in_channels {in_channels} is not divisible by groups '
                f'{groups}')
        if out_channels % groups != 0:
            raise ValueError(
                f'out_channels {out_channels} is not divisible by groups '
                f'{groups}')
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = _pair(kernel_size)
        self.stride = _pair(stride)
        self.padding = _pair(padding)
        self.dilation = _pair(dilation)
        self.groups = groups
        self.deform_groups = deform_groups
        # enable compatibility with nn.Conv2d
        self.transposed = False
        self.output_padding = _single(0)
        # Allocated uninitialised; filled by init_weights() below.
        self.weight = nn.Parameter(
            torch.empty(out_channels, in_channels // groups,
                        *self.kernel_size))
        if bias:
            self.bias = nn.Parameter(torch.empty(out_channels))
        else:
            self.register_parameter('bias', None)
        self.init_weights()

    def init_weights(self):
        """Fill weight from U(-1/sqrt(n), 1/sqrt(n)) and zero the bias.

        ``n`` is ``in_channels * kh * kw``.
        NOTE(review): ``groups`` is not part of this fan-in.
        """
        n = self.in_channels
        for k in self.kernel_size:
            n *= k
        stdv = 1. / math.sqrt(n)
        self.weight.data.uniform_(-stdv, stdv)
        if self.bias is not None:
            self.bias.data.zero_()

    def forward(self, x, offset, mask):
        """Apply the convolution with caller-supplied ``offset`` and ``mask``.

        ``x`` must be 4D (enforced by ModulatedDeformConv2dFunction).
        Presumably ``offset`` has 2 * deform_groups * kh * kw channels and
        ``mask`` has deform_groups * kh * kw channels, as produced by
        ModulatedDeformConv2dPack. Confirm against the kernel.
        """
        return modulated_deform_conv2d(x, offset, mask, self.weight, self.bias,
                                       self.stride, self.padding,
                                       self.dilation, self.groups,
                                       self.deform_groups)
@CONV_LAYERS.register_module('DCNv2')
class ModulatedDeformConv2dPack(ModulatedDeformConv2d):
    """Modulated deformable convolution usable as a drop-in conv layer.

    Unlike the parent class, offsets and mask are not passed in. An
    internal ``conv_offset`` (an ``nn.Conv2d`` with the same kernel size,
    stride, padding and dilation) predicts ``3 * deform_groups * kh * kw``
    channels from the input:
    - the first two thirds are concatenated into the offset;
    - the last third goes through a sigmoid to become the mask.

    ``conv_offset`` is zero-initialised, so training starts with zero offsets
    and a mask of sigmoid(0) = 0.5. The class is registered in
    ``CONV_LAYERS`` under the name ``'DCNv2'``.

    Args:
        in_channels (int): Same as nn.Conv2d.
        out_channels (int): Same as nn.Conv2d.
        kernel_size (int or tuple[int]): Same as nn.Conv2d.
        stride (int): Same as nn.Conv2d.
        padding (int): Same as nn.Conv2d.
        dilation (int): Same as nn.Conv2d.
        groups (int): Same as nn.Conv2d.
        deform_groups (int): Number of deformable groups.
        bias (bool or str): Only its truthiness is used here (a truthy value
            creates a bias). NOTE(review): a value such as ``'auto'`` is
            truthy, so any norm_cfg-based decision must be resolved by the
            caller.
    """

    # State-dict format version; checkpoints older than 2 use legacy keys.
    _version = 2

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        kernel_h, kernel_w = self.kernel_size
        self.conv_offset = nn.Conv2d(
            self.in_channels,
            self.deform_groups * 3 * kernel_h * kernel_w,
            kernel_size=self.kernel_size,
            stride=self.stride,
            padding=self.padding,
            dilation=self.dilation,
            bias=True)
        # Run again now that conv_offset exists, so it gets zeroed too.
        self.init_weights()

    def init_weights(self):
        super().init_weights()
        # The parent constructor calls this before conv_offset is created.
        if not hasattr(self, 'conv_offset'):
            return
        self.conv_offset.weight.data.zero_()
        self.conv_offset.bias.data.zero_()

    def forward(self, x):
        offset_and_mask = self.conv_offset(x)
        offset_a, offset_b, mask_logits = torch.chunk(
            offset_and_mask, 3, dim=1)
        offset = torch.cat((offset_a, offset_b), dim=1)
        mask = mask_logits.sigmoid()
        return modulated_deform_conv2d(x, offset, mask, self.weight, self.bias,
                                       self.stride, self.padding,
                                       self.dilation, self.groups,
                                       self.deform_groups)

    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
                              missing_keys, unexpected_keys, error_msgs):
        version = local_metadata.get('version', None)

        if version is None or version < 2:
            # Early (benchmark) checkpoints store the offset conv as
            # "<module>_offset.*" instead of "<module>.conv_offset.*";
            # rename in place unless the new key is already present.
            legacy_renames = (('conv_offset.weight', '_offset.weight'),
                              ('conv_offset.bias', '_offset.bias'))
            for new_suffix, old_suffix in legacy_renames:
                new_key = prefix + new_suffix
                old_key = prefix[:-1] + old_suffix
                if new_key not in state_dict and old_key in state_dict:
                    state_dict[new_key] = state_dict.pop(old_key)

        # NOTE(review): this fires for checkpoints that are already at
        # version >= 2, although the message says "upgraded" (kept as-is).
        if version is not None and version > 1:
            print_log(
                f'ModulatedDeformConvPack {prefix.rstrip(".")} is upgraded to '
                'version 2.',
                logger='root')

        super()._load_from_state_dict(state_dict, prefix, local_metadata,
                                      strict, missing_keys, unexpected_keys,
                                      error_msgs)