# --------------------------------------------------------
# InternImage
# Copyright (c) 2022 OpenGVLab
# Licensed under The MIT License [see LICENSE for details]
# --------------------------------------------------------
from __future__ import absolute_import
from __future__ import print_function
from __future__ import division

import warnings

import torch
from torch import nn
import torch.nn.functional as F
from torch.autograd import Function
from torch.autograd.function import once_differentiable
from torch.cuda.amp import custom_bwd, custom_fwd
from torch.nn.init import xavier_uniform_, constant_


class DCNv3Function(Function):
    @staticmethod
    @custom_fwd
    def forward(
            ctx, input, offset, mask,
            kernel_h, kernel_w, stride_h, stride_w,
            pad_h, pad_w, dilation_h, dilation_w,
            group, group_channels, offset_scale, im2col_step):
        ctx.kernel_h = kernel_h
        ctx.kernel_w = kernel_w
        ctx.stride_h = stride_h
        ctx.stride_w = stride_w
        ctx.pad_h = pad_h
        ctx.pad_w = pad_w
        ctx.dilation_h = dilation_h
        ctx.dilation_w = dilation_w
        ctx.group = group
        ctx.group_channels = group_channels
        ctx.offset_scale = offset_scale
        ctx.im2col_step = im2col_step
        output = _C.dcnv3_forward(
            input, offset, mask, kernel_h,
            kernel_w, stride_h, stride_w, pad_h,
            pad_w, dilation_h, dilation_w, group,
            group_channels, offset_scale, ctx.im2col_step)
        ctx.save_for_backward(input, offset, mask)
        return output
    @staticmethod
    @once_differentiable
    @custom_bwd
    def backward(ctx, grad_output):
        input, offset, mask = ctx.saved_tensors
        grad_input, grad_offset, grad_mask = \
            _C.dcnv3_backward(
                input, offset, mask, ctx.kernel_h,
                ctx.kernel_w, ctx.stride_h, ctx.stride_w, ctx.pad_h,
                ctx.pad_w, ctx.dilation_h, ctx.dilation_w, ctx.group,
                ctx.group_channels, ctx.offset_scale, grad_output.contiguous(), ctx.im2col_step)
        # One gradient per forward input: three tensor gradients, then None
        # for each of the twelve non-tensor arguments.
        return grad_input, grad_offset, grad_mask, \
            None, None, None, None, None, None, None, None, None, None, None, None


def _get_reference_points(spatial_shapes, device, kernel_h, kernel_w, dilation_h, dilation_w, pad_h=0, pad_w=0, stride_h=1, stride_w=1):
    _, H_, W_, _ = spatial_shapes
    H_out = (H_ - (dilation_h * (kernel_h - 1) + 1)) // stride_h + 1
    W_out = (W_ - (dilation_w * (kernel_w - 1) + 1)) // stride_w + 1

    ref_y, ref_x = torch.meshgrid(
        torch.linspace(
            # pad_h + 0.5,
            # H_ - pad_h - 0.5,
            (dilation_h * (kernel_h - 1)) // 2 + 0.5,
            (dilation_h * (kernel_h - 1)) // 2 + 0.5 + (H_out - 1) * stride_h,
            H_out,
            dtype=torch.float32,
            device=device),
        torch.linspace(
            # pad_w + 0.5,
            # W_ - pad_w - 0.5,
            (dilation_w * (kernel_w - 1)) // 2 + 0.5,
            (dilation_w * (kernel_w - 1)) // 2 + 0.5 + (W_out - 1) * stride_w,
            W_out,
            dtype=torch.float32,
            device=device))
    ref_y = ref_y.reshape(-1)[None] / H_
    ref_x = ref_x.reshape(-1)[None] / W_
    ref = torch.stack((ref_x, ref_y), -1).reshape(
        1, H_out, W_out, 1, 2)
    return ref


def _generate_dilation_grids(spatial_shapes, kernel_h, kernel_w, dilation_h, dilation_w, group, device):
    _, H_, W_, _ = spatial_shapes
    points_list = []
    x, y = torch.meshgrid(
        torch.linspace(
            -((dilation_w * (kernel_w - 1)) // 2),
            -((dilation_w * (kernel_w - 1)) // 2) +
            (kernel_w - 1) * dilation_w, kernel_w,
            dtype=torch.float32,
            device=device),
        torch.linspace(
            -((dilation_h * (kernel_h - 1)) // 2),
            -((dilation_h * (kernel_h - 1)) // 2) +
            (kernel_h - 1) * dilation_h, kernel_h,
            dtype=torch.float32,
            device=device))
    points_list.extend([x / W_, y / H_])
    grid = torch.stack(points_list, -1).reshape(-1, 1, 2).\
        repeat(1, group, 1).permute(1, 0, 2)
    grid = grid.reshape(1, 1, 1, group * kernel_h * kernel_w, 2)
    return grid
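

def _demo_reference_and_grid_shapes():
    # A minimal sanity-check sketch (editor addition, not part of the
    # original InternImage code): for a 3x3 kernel, dilation 1, stride 1
    # and a single group on an already padded 8x8 input, the helpers above
    # should yield ref of shape (1, H_out, W_out, 1, 2) and grid of shape
    # (1, 1, 1, group * kernel_h * kernel_w, 2).
    spatial_shapes = (1, 8, 8, 16)  # N, H, W, C
    ref = _get_reference_points(spatial_shapes, 'cpu', 3, 3, 1, 1)
    grid = _generate_dilation_grids(spatial_shapes, 3, 3, 1, 1, 1, 'cpu')
    assert ref.shape == (1, 6, 6, 1, 2)   # H_out = W_out = (8 - 3) // 1 + 1
    assert grid.shape == (1, 1, 1, 9, 2)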


def dcnv3_core_pytorch(
        input, offset, mask, kernel_h,
        kernel_w, stride_h, stride_w, pad_h,
        pad_w, dilation_h, dilation_w, group,
        group_channels, offset_scale):
    # for debug and test only,
    # need to use cuda version instead
    # F.pad pads from the last dimension backwards, so for a channels-last
    # (N, H, W, C) tensor the order is [C, C, W, W, H, H].
    input = F.pad(
        input,
        [0, 0, pad_w, pad_w, pad_h, pad_h])
    N_, H_in, W_in, _ = input.shape
    _, H_out, W_out, _ = offset.shape

    ref = _get_reference_points(
        input.shape, input.device, kernel_h, kernel_w, dilation_h, dilation_w, pad_h, pad_w, stride_h, stride_w)
    grid = _generate_dilation_grids(
        input.shape, kernel_h, kernel_w, dilation_h, dilation_w, group, input.device)
    spatial_norm = torch.tensor([W_in, H_in]).reshape(1, 1, 1, 2).\
        repeat(1, 1, 1, group*kernel_h*kernel_w).to(input.device)

    sampling_locations = (ref + grid * offset_scale).repeat(N_, 1, 1, 1, 1).flatten(3, 4) + \
        offset * offset_scale / spatial_norm

    P_ = kernel_h * kernel_w
    sampling_grids = 2 * sampling_locations - 1
    # N_, H_in, W_in, group*group_channels -> N_, H_in*W_in, group*group_channels -> N_, group*group_channels, H_in*W_in -> N_*group, group_channels, H_in, W_in
    input_ = input.view(N_, H_in*W_in, group*group_channels).transpose(1, 2).\
        reshape(N_*group, group_channels, H_in, W_in)
    # N_, H_out, W_out, group*P_*2 -> N_, H_out*W_out, group, P_, 2 -> N_, group, H_out*W_out, P_, 2 -> N_*group, H_out*W_out, P_, 2
    sampling_grid_ = sampling_grids.view(N_, H_out*W_out, group, P_, 2).transpose(1, 2).\
        flatten(0, 1)
    # N_*group, group_channels, H_out*W_out, P_
    sampling_input_ = F.grid_sample(
        input_, sampling_grid_, mode='bilinear', padding_mode='zeros', align_corners=False)
    # (N_, H_out, W_out, group*P_) -> N_, H_out*W_out, group, P_ -> (N_, group, H_out*W_out, P_) -> (N_*group, 1, H_out*W_out, P_)
    mask = mask.view(N_, H_out*W_out, group, P_).transpose(1, 2).\
        reshape(N_*group, 1, H_out*W_out, P_)
    output = (sampling_input_ * mask).sum(-1).view(N_,
                                                   group*group_channels, H_out*W_out)
    return output.transpose(1, 2).reshape(N_, H_out, W_out, -1).contiguous()
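

def _demo_dcnv3_core_pytorch():
    # A hedged usage sketch (editor addition, not part of the original
    # module): runs the pure-PyTorch reference path on random channels-last
    # tensors. With a 3x3 kernel, stride 1, pad 1 and dilation 1 the spatial
    # size is preserved, so offset and mask are laid out for H_out = H and
    # W_out = W. Zero offsets plus uniform mask weights reduce DCNv3 to a
    # plain local aggregation.
    N, H, W, group, group_channels = 2, 8, 8, 4, 16
    channels = group * group_channels
    P = 3 * 3
    x = torch.rand(N, H, W, channels)
    offset = torch.zeros(N, H, W, group * P * 2)
    mask = torch.full((N, H, W, group * P), 1.0 / P)
    out = dcnv3_core_pytorch(
        x, offset, mask,
        3, 3,   # kernel_h, kernel_w
        1, 1,   # stride_h, stride_w
        1, 1,   # pad_h, pad_w
        1, 1,   # dilation_h, dilation_w
        group, group_channels, 1.0)
    assert out.shape == (N, H, W, channels)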


class to_channels_first(nn.Module):

    def __init__(self):
        super().__init__()

    def forward(self, x):
        # (N, H, W, C) -> (N, C, H, W)
        return x.permute(0, 3, 1, 2)


class to_channels_last(nn.Module):

    def __init__(self):
        super().__init__()

    def forward(self, x):
        # (N, C, H, W) -> (N, H, W, C)
        return x.permute(0, 2, 3, 1)


def build_norm_layer(dim,
                     norm_layer,
                     in_format='channels_last',
                     out_format='channels_last',
                     eps=1e-6):
    layers = []
    if norm_layer == 'BN':
        if in_format == 'channels_last':
            layers.append(to_channels_first())
        layers.append(nn.BatchNorm2d(dim))
        if out_format == 'channels_last':
            layers.append(to_channels_last())
    elif norm_layer == 'LN':
        if in_format == 'channels_first':
            layers.append(to_channels_last())
        layers.append(nn.LayerNorm(dim, eps=eps))
        if out_format == 'channels_first':
            layers.append(to_channels_first())
    else:
        raise NotImplementedError(
            f'build_norm_layer does not support {norm_layer}')
    return nn.Sequential(*layers)


def build_act_layer(act_layer):
    if act_layer == 'ReLU':
        return nn.ReLU(inplace=True)
    elif act_layer == 'SiLU':
        return nn.SiLU(inplace=True)
    elif act_layer == 'GELU':
        return nn.GELU()

    raise NotImplementedError(f'build_act_layer does not support {act_layer}')
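

def _demo_build_layers():
    # A hedged usage sketch (editor addition): build_norm_layer can bridge
    # layout conventions, e.g. accept a channels-first feature map and return
    # a channels-last one after LayerNorm, which is how dw_conv uses it below.
    norm = build_norm_layer(64, 'LN', in_format='channels_first',
                            out_format='channels_last')
    act = build_act_layer('GELU')
    x = torch.rand(2, 64, 8, 8)   # N, C, H, W
    y = act(norm(x))              # -> N, H, W, C
    assert y.shape == (2, 8, 8, 64)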


def _is_power_of_2(n):
    if (not isinstance(n, int)) or (n < 0):
        raise ValueError(
            "invalid input for _is_power_of_2: {} (type: {})".format(n, type(n)))

    return (n & (n - 1) == 0) and n != 0
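

# For example (editor note): _is_power_of_2(16) -> True,
# _is_power_of_2(12) -> False, and _is_power_of_2(0) -> False because the
# `n != 0` guard excludes zero, for which n & (n - 1) is also 0.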


class CenterFeatureScaleModule(nn.Module):
    def forward(self,
                query,
                center_feature_scale_proj_weight,
                center_feature_scale_proj_bias):
        center_feature_scale = F.linear(query,
                                        weight=center_feature_scale_proj_weight,
                                        bias=center_feature_scale_proj_bias).sigmoid()
        return center_feature_scale


class DCNv3_pytorch(nn.Module):
    def __init__(
            self,
            channels=64,
            kernel_size=3,
            dw_kernel_size=None,
            stride=1,
            pad=1,
            dilation=1,
            group=4,
            offset_scale=1.0,
            act_layer='GELU',
            norm_layer='LN',
            center_feature_scale=False):
        """
        DCNv3 Module (pure-PyTorch reference implementation)
        :param channels
        :param kernel_size
        :param dw_kernel_size
        :param stride
        :param pad
        :param dilation
        :param group
        :param offset_scale
        :param act_layer
        :param norm_layer
        :param center_feature_scale
        """
        super().__init__()
        if channels % group != 0:
            raise ValueError(
                f'channels must be divisible by group, but got {channels} and {group}')
        _d_per_group = channels // group
        dw_kernel_size = dw_kernel_size if dw_kernel_size is not None else kernel_size
        # _d_per_group should preferably be a power of 2, which is more
        # efficient in the CUDA implementation.
        if not _is_power_of_2(_d_per_group):
            warnings.warn(
                "You'd better set channels in DCNv3 to make the dimension of each attention head a power of 2, "
                "which is more efficient in our CUDA implementation.")

        self.offset_scale = offset_scale
        self.channels = channels
        self.kernel_size = kernel_size
        self.dw_kernel_size = dw_kernel_size
        self.stride = stride
        self.dilation = dilation
        self.pad = pad
        self.group = group
        self.group_channels = channels // group
        self.center_feature_scale = center_feature_scale

        self.dw_conv = nn.Sequential(
            nn.Conv2d(
                channels,
                channels,
                kernel_size=dw_kernel_size,
                stride=1,
                padding=(dw_kernel_size - 1) // 2,
                groups=channels),
            build_norm_layer(
                channels,
                norm_layer,
                'channels_first',
                'channels_last'),
            build_act_layer(act_layer))
        self.offset = nn.Linear(
            channels,
            group * kernel_size * kernel_size * 2)
        self.mask = nn.Linear(
            channels,
            group * kernel_size * kernel_size)
        self.input_proj = nn.Linear(channels, channels)
        self.output_proj = nn.Linear(channels, channels)
        self._reset_parameters()

        if center_feature_scale:
            self.center_feature_scale_proj_weight = nn.Parameter(
                torch.zeros((group, channels), dtype=torch.float))
            self.center_feature_scale_proj_bias = nn.Parameter(
                torch.tensor(0.0, dtype=torch.float).view((1,)).repeat(group, ))
            self.center_feature_scale_module = CenterFeatureScaleModule()

    def _reset_parameters(self):
        constant_(self.offset.weight.data, 0.)
        constant_(self.offset.bias.data, 0.)
        constant_(self.mask.weight.data, 0.)
        constant_(self.mask.bias.data, 0.)
        xavier_uniform_(self.input_proj.weight.data)
        constant_(self.input_proj.bias.data, 0.)
        xavier_uniform_(self.output_proj.weight.data)
        constant_(self.output_proj.bias.data, 0.)

    def forward(self, input):
        """
        :param input (N, H, W, C)
        :return output (N, H, W, C)
        """
        N, H, W, _ = input.shape

        x = self.input_proj(input)
        x_proj = x

        x1 = input.permute(0, 3, 1, 2)
        x1 = self.dw_conv(x1)
        offset = self.offset(x1)
        mask = self.mask(x1).reshape(N, H, W, self.group, -1)
        mask = F.softmax(mask, -1).reshape(N, H, W, -1)

        x = dcnv3_core_pytorch(
            x, offset, mask,
            self.kernel_size, self.kernel_size,
            self.stride, self.stride,
            self.pad, self.pad,
            self.dilation, self.dilation,
            self.group, self.group_channels,
            self.offset_scale)
        if self.center_feature_scale:
            center_feature_scale = self.center_feature_scale_module(
                x1, self.center_feature_scale_proj_weight, self.center_feature_scale_proj_bias)
            # N, H, W, groups -> N, H, W, groups, 1 -> N, H, W, groups, _d_per_group -> N, H, W, channels
            center_feature_scale = center_feature_scale[..., None].repeat(
                1, 1, 1, 1, self.channels // self.group).flatten(-2)
            x = x * (1 - center_feature_scale) + x_proj * center_feature_scale
        x = self.output_proj(x)

        return x
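

def _demo_dcnv3_pytorch_module():
    # A hedged usage sketch (editor addition): the module consumes and
    # produces channels-last tensors. group_channels = 64 // 4 = 16 is a
    # power of 2, so the efficiency warning above is not triggered.
    m = DCNv3_pytorch(channels=64, kernel_size=3, stride=1, pad=1,
                      dilation=1, group=4)
    x = torch.rand(2, 16, 16, 64)   # N, H, W, C
    y = m(x)
    assert y.shape == (2, 16, 16, 64)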


class DCNv3(nn.Module):
    def __init__(
            self,
            channels=64,
            kernel_size=3,
            dw_kernel_size=None,
            stride=1,
            pad=1,
            dilation=1,
            group=4,
            offset_scale=1.0,
            act_layer='GELU',
            norm_layer='LN',
            center_feature_scale=False):
        """
        DCNv3 Module (CUDA implementation; requires the compiled _C extension)
        :param channels
        :param kernel_size
        :param dw_kernel_size
        :param stride
        :param pad
        :param dilation
        :param group
        :param offset_scale
        :param act_layer
        :param norm_layer
        :param center_feature_scale
        """
        super().__init__()
        if channels % group != 0:
            raise ValueError(
                f'channels must be divisible by group, but got {channels} and {group}')
        _d_per_group = channels // group
        dw_kernel_size = dw_kernel_size if dw_kernel_size is not None else kernel_size
        # _d_per_group should preferably be a power of 2, which is more
        # efficient in the CUDA implementation.
        if not _is_power_of_2(_d_per_group):
            warnings.warn(
                "You'd better set channels in DCNv3 to make the dimension of each attention head a power of 2, "
                "which is more efficient in our CUDA implementation.")

        self.offset_scale = offset_scale
        self.channels = channels
        self.kernel_size = kernel_size
        self.dw_kernel_size = dw_kernel_size
        self.stride = stride
        self.dilation = dilation
        self.pad = pad
        self.group = group
        self.group_channels = channels // group
        self.center_feature_scale = center_feature_scale

        self.dw_conv = nn.Sequential(
            nn.Conv2d(
                channels,
                channels,
                kernel_size=dw_kernel_size,
                stride=1,
                padding=(dw_kernel_size - 1) // 2,
                groups=channels),
            build_norm_layer(
                channels,
                norm_layer,
                'channels_first',
                'channels_last'),
            build_act_layer(act_layer))
        self.offset = nn.Linear(
            channels,
            group * kernel_size * kernel_size * 2)
        self.mask = nn.Linear(
            channels,
            group * kernel_size * kernel_size)
        self.input_proj = nn.Linear(channels, channels)
        self.output_proj = nn.Linear(channels, channels)
        self._reset_parameters()

        if center_feature_scale:
            self.center_feature_scale_proj_weight = nn.Parameter(
                torch.zeros((group, channels), dtype=torch.float))
            self.center_feature_scale_proj_bias = nn.Parameter(
                torch.tensor(0.0, dtype=torch.float).view((1,)).repeat(group, ))
            self.center_feature_scale_module = CenterFeatureScaleModule()

    def _reset_parameters(self):
        constant_(self.offset.weight.data, 0.)
        constant_(self.offset.bias.data, 0.)
        constant_(self.mask.weight.data, 0.)
        constant_(self.mask.bias.data, 0.)
        xavier_uniform_(self.input_proj.weight.data)
        constant_(self.input_proj.bias.data, 0.)
        xavier_uniform_(self.output_proj.weight.data)
        constant_(self.output_proj.bias.data, 0.)

    def forward(self, input):
        """
        :param input (N, H, W, C)
        :return output (N, H, W, C)
        """
        N, H, W, _ = input.shape

        x = self.input_proj(input)
        x_proj = x
        dtype = x.dtype

        x1 = input.permute(0, 3, 1, 2)
        x1 = self.dw_conv(x1)
        offset = self.offset(x1)
        mask = self.mask(x1).reshape(N, H, W, self.group, -1)
        mask = F.softmax(mask, -1).reshape(N, H, W, -1).type(dtype)

        x = DCNv3Function.apply(
            x, offset, mask,
            self.kernel_size, self.kernel_size,
            self.stride, self.stride,
            self.pad, self.pad,
            self.dilation, self.dilation,
            self.group, self.group_channels,
            self.offset_scale,
            256)  # im2col_step
        if self.center_feature_scale:
            center_feature_scale = self.center_feature_scale_module(
                x1, self.center_feature_scale_proj_weight, self.center_feature_scale_proj_bias)
            # N, H, W, groups -> N, H, W, groups, 1 -> N, H, W, groups, _d_per_group -> N, H, W, channels
            center_feature_scale = center_feature_scale[..., None].repeat(
                1, 1, 1, 1, self.channels // self.group).flatten(-2)
            x = x * (1 - center_feature_scale) + x_proj * center_feature_scale
        x = self.output_proj(x)

        return x
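

def _demo_dcnv3_cuda_module():
    # A hedged usage sketch (editor addition): this path requires both a CUDA
    # device and the compiled detrex._C extension. If the extension failed to
    # import, DCNv3 is replaced by a dummy class below and construction
    # raises ImportError, which we treat as a skip here.
    if not torch.cuda.is_available():
        return
    try:
        m = DCNv3(channels=64, kernel_size=3, stride=1, pad=1,
                  dilation=1, group=4).cuda()
    except ImportError:
        return
    x = torch.rand(2, 16, 16, 64, device='cuda')
    y = m(x)
    assert y.shape == (2, 16, 16, 64)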


def create_dummy_class(klass, dependency, message=""):
    """
    When a dependency of a class is not available, create a dummy class which throws ImportError
    when used.

    Args:
        klass (str): name of the class.
        dependency (str): name of the dependency.
        message: extra message to print
    Returns:
        class: a class object
    """
    err = "Cannot import '{}', therefore '{}' is not available.".format(dependency, klass)
    if message:
        err = err + " " + message

    class _DummyMetaClass(type):
        # throw error on class attribute access
        def __getattr__(_, __):  # noqa: B902
            raise ImportError(err)

    class _Dummy(object, metaclass=_DummyMetaClass):
        # throw error on constructor
        def __init__(self, *args, **kwargs):
            raise ImportError(err)

    return _Dummy
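

def _demo_create_dummy_class():
    # A hedged usage sketch (editor addition): both instantiation and class
    # attribute access on the returned dummy raise ImportError carrying the
    # message built above. `MyOp` / `some_ext` are illustrative names.
    MyOp = create_dummy_class("MyOp", "some_ext", "Build some_ext first.")
    try:
        MyOp.anything          # class attribute access raises
    except ImportError:
        pass
    try:
        MyOp()                 # instantiation raises
    except ImportError:
        pass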


def create_dummy_func(func, dependency, message=""):
    """
    When a dependency of a function is not available, create a dummy function which throws
    ImportError when used.

    Args:
        func (str): name of the function.
        dependency (str or list[str]): name(s) of the dependency.
        message: extra message to print
    Returns:
        function: a function object
    """
    # Join the dependency list before formatting the message; otherwise the
    # error text would show the raw list repr.
    if isinstance(dependency, (list, tuple)):
        dependency = ",".join(dependency)
    err = "Cannot import '{}', therefore '{}' is not available.".format(dependency, func)
    if message:
        err = err + " " + message

    def _dummy(*args, **kwargs):
        raise ImportError(err)

    return _dummy
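

def _demo_create_dummy_func():
    # A hedged usage sketch (editor addition): with the list join done before
    # formatting, the error message names all missing dependencies, e.g.
    # "Cannot import 'ext_a,ext_b', therefore 'my_op' is not available."
    dummy = create_dummy_func("my_op", ["ext_a", "ext_b"])
    try:
        dummy()
    except ImportError:
        pass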


try:
    from detrex import _C
except ImportError:
    # TODO: register ops natively so there is no need to import _C.
    _msg = "detrex is not compiled successfully, please build following the instructions!"
    _args = ("detrex._C", _msg)
    DCNv3 = create_dummy_class(  # noqa
        "DCNv3", *_args
    )