# --------------------------------------------------------
# InternImage
# Copyright (c) 2022 OpenGVLab
# Licensed under The MIT License [see LICENSE for details]
# --------------------------------------------------------

from __future__ import absolute_import
from __future__ import print_function
from __future__ import division

import warnings

from torch import nn
import torch.nn.functional as F
from torch.nn.init import xavier_uniform_, constant_

from .dcnv3_func import dcnv3_core_pytorch


class to_channels_first(nn.Module):

    def __init__(self):
        super().__init__()

    def forward(self, x):
        return x.permute(0, 3, 1, 2)


class to_channels_last(nn.Module):

    def __init__(self):
        super().__init__()

    def forward(self, x):
        return x.permute(0, 2, 3, 1)


def build_norm_layer(dim,
                     norm_layer,
                     in_format='channels_last',
                     out_format='channels_last',
                     eps=1e-6):
    layers = []
    if norm_layer == 'BN':
        if in_format == 'channels_last':
            layers.append(to_channels_first())
        layers.append(nn.BatchNorm2d(dim))
        if out_format == 'channels_last':
            layers.append(to_channels_last())
    elif norm_layer == 'LN':
        if in_format == 'channels_first':
            layers.append(to_channels_last())
        layers.append(nn.LayerNorm(dim, eps=eps))
        if out_format == 'channels_first':
            layers.append(to_channels_first())
    else:
        raise NotImplementedError(
            f'build_norm_layer does not support {norm_layer}')
    return nn.Sequential(*layers)


def build_act_layer(act_layer):
    if act_layer == 'ReLU':
        return nn.ReLU(inplace=True)
    elif act_layer == 'SiLU':
        return nn.SiLU(inplace=True)
    elif act_layer == 'GELU':
        return nn.GELU()
    raise NotImplementedError(f'build_act_layer does not support {act_layer}')


def _is_power_of_2(n):
    if (not isinstance(n, int)) or (n < 0):
        raise ValueError(
            'invalid input for _is_power_of_2: {} (type: {})'.format(n, type(n)))
    return (n & (n - 1) == 0) and n != 0


class DCNv3_pytorch(nn.Module):

    def __init__(self,
                 channels=64,
                 kernel_size=3,
                 stride=1,
                 pad=1,
                 dilation=1,
                 group=4,
                 offset_scale=1.0,
                 act_layer='GELU',
                 norm_layer='LN'):
        """
        DCNv3 Module
        :param channels      number of input/output channels
        :param kernel_size   size of the sampling kernel
        :param stride        sampling stride
        :param pad           zero padding
        :param dilation      sampling dilation
        :param group         number of deformable groups
        :param offset_scale  scale factor applied to the predicted offsets
        :param act_layer     activation in the depthwise branch
        :param norm_layer    normalization in the depthwise branch
        """
        super().__init__()
        if channels % group != 0:
            raise ValueError(
                f'channels must be divisible by group, '
                f'but got {channels} and {group}')
        _d_per_group = channels // group
        # _d_per_group should be a power of 2, which is more efficient in the
        # CUDA implementation
        if not _is_power_of_2(_d_per_group):
            warnings.warn(
                "You'd better set channels in DCNv3 to make the dimension of "
                'each attention head a power of 2, which is more efficient in '
                'our CUDA implementation.')

        self.offset_scale = offset_scale
        self.channels = channels
        self.kernel_size = kernel_size
        self.stride = stride
        self.dilation = dilation
        self.pad = pad
        self.group = group
        self.group_channels = channels // group

        # depthwise conv + norm + activation branch that produces the
        # features from which offsets and masks are predicted
        self.dw_conv = nn.Sequential(
            nn.Conv2d(
                channels,
                channels,
                kernel_size=kernel_size,
                stride=1,
                padding=(kernel_size - 1) // 2,
                groups=channels),
            build_norm_layer(
                channels, norm_layer, 'channels_first', 'channels_last'),
            build_act_layer(act_layer))
        self.offset = nn.Linear(
            channels, group * kernel_size * kernel_size * 2)
        self.mask = nn.Linear(
            channels, group * kernel_size * kernel_size)
        self.input_proj = nn.Linear(channels, channels)
        self.output_proj = nn.Linear(channels, channels)
        self._reset_parameters()

    def _reset_parameters(self):
        # offsets and masks start at zero so the module initially behaves
        # like a regular (non-deformed) sampling grid
        constant_(self.offset.weight.data, 0.)
        constant_(self.offset.bias.data, 0.)
        constant_(self.mask.weight.data, 0.)
        constant_(self.mask.bias.data, 0.)
        xavier_uniform_(self.input_proj.weight.data)
        constant_(self.input_proj.bias.data, 0.)
        xavier_uniform_(self.output_proj.weight.data)
        constant_(self.output_proj.bias.data, 0.)

    def forward(self, input):
        """
        :param input   (N, H, W, C)
        :return output (N, H, W, C)
        """
        N, H, W, _ = input.shape

        x = self.input_proj(input)

        # dw_conv takes channels-first and returns channels-last, so the
        # offset/mask linear layers apply per spatial location
        x1 = input.permute(0, 3, 1, 2)
        x1 = self.dw_conv(x1)
        offset = self.offset(x1)
        mask = self.mask(x1).reshape(N, H, W, self.group, -1)
        # normalize the modulation mask over the kernel positions of each group
        mask = F.softmax(mask, -1).reshape(N, H, W, -1)

        x = dcnv3_core_pytorch(
            x, offset, mask,
            self.kernel_size, self.kernel_size,
            self.stride, self.stride,
            self.pad, self.pad,
            self.dilation, self.dilation,
            self.group, self.group_channels,
            self.offset_scale)
        x = self.output_proj(x)

        return x
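
# Minimal smoke test (a sketch, not part of the upstream module): it builds a
# small DCNv3_pytorch block and checks the channels-last forward contract.
# The hyper-parameter values are illustrative; with kernel_size=3, stride=1,
# pad=1, dilation=1 the spatial size is preserved. Because of the relative
# import of dcnv3_core_pytorch above, run this as a module from the package
# root (e.g. `python -m <your_package>.dcnv3`; the module path is hypothetical).
if __name__ == '__main__':
    import torch

    m = DCNv3_pytorch(channels=64, kernel_size=3, stride=1, pad=1,
                      dilation=1, group=4)
    x = torch.randn(2, 14, 14, 64)  # (N, H, W, C), channels-last
    y = m(x)
    print(y.shape)  # expected: torch.Size([2, 14, 14, 64])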