|
|
|
|
|
|
|
|
|
|
|
|
|
from __future__ import absolute_import |
|
from __future__ import print_function |
|
from __future__ import division |
|
|
|
import torch |
|
import torch.nn.functional as F |
|
|
|
|
|
def _get_reference_points(spatial_shapes, device, kernel_h, kernel_w, dilation_h, dilation_w, pad_h=0, pad_w=0, stride_h=1, stride_w=1): |
|
_, H_, W_, _ = spatial_shapes |
|
H_out = (H_ - (dilation_h * (kernel_h - 1) + 1)) // stride_h + 1 |
|
W_out = (W_ - (dilation_w * (kernel_w - 1) + 1)) // stride_w + 1 |
|
|
|
ref_y, ref_x = torch.meshgrid( |
|
torch.linspace( |
|
|
|
|
|
(dilation_h * (kernel_h - 1)) // 2 + 0.5, |
|
(dilation_h * (kernel_h - 1)) // 2 + 0.5 + (H_out - 1) * stride_h, |
|
H_out, |
|
dtype=torch.float32, |
|
device=device), |
|
torch.linspace( |
|
|
|
|
|
(dilation_w * (kernel_w - 1)) // 2 + 0.5, |
|
(dilation_w * (kernel_w - 1)) // 2 + 0.5 + (W_out - 1) * stride_w, |
|
W_out, |
|
dtype=torch.float32, |
|
device=device)) |
|
ref_y = ref_y.reshape(-1)[None] / H_ |
|
ref_x = ref_x.reshape(-1)[None] / W_ |
|
|
|
ref = torch.stack((ref_x, ref_y), -1).reshape( |
|
1, H_out, W_out, 1, 2) |
|
|
|
return ref |
|
|
|
|
|
def _generate_dilation_grids(spatial_shapes, kernel_h, kernel_w, dilation_h, dilation_w, group, device): |
|
_, H_, W_, _ = spatial_shapes |
|
points_list = [] |
|
x, y = torch.meshgrid( |
|
torch.linspace( |
|
-((dilation_w * (kernel_w - 1)) // 2), |
|
-((dilation_w * (kernel_w - 1)) // 2) + |
|
(kernel_w - 1) * dilation_w, kernel_w, |
|
dtype=torch.float32, |
|
device=device), |
|
torch.linspace( |
|
-((dilation_h * (kernel_h - 1)) // 2), |
|
-((dilation_h * (kernel_h - 1)) // 2) + |
|
(kernel_h - 1) * dilation_h, kernel_h, |
|
dtype=torch.float32, |
|
device=device)) |
|
|
|
points_list.extend([x / W_, y / H_]) |
|
grid = torch.stack(points_list, -1).reshape(-1, 1, 2).\ |
|
repeat(1, group, 1).permute(1, 0, 2) |
|
grid = grid.reshape(1, 1, 1, group * kernel_h * kernel_w, 2) |
|
|
|
return grid |
|
|
|
|
|
def dcnv3_core_pytorch(
        input, offset, mask, kernel_h,
        kernel_w, stride_h, stride_w, pad_h,
        pad_w, dilation_h, dilation_w, group,
        group_channels, offset_scale):
    """Pure-PyTorch DCNv3 core: sample at offset kernel taps, then mask-sum.

    Args:
        input: Channels-last tensor ``(N, H, W, group * group_channels)``.
        offset: Tensor ``(N, H_out, W_out, ...)`` added to the flattened
            sampling locations; presumably the last axis is
            ``group * kernel_h * kernel_w * 2`` with interleaved (x, y)
            pixel offsets -- confirm against the caller.
        mask: Per-tap weights holding
            ``N * H_out * W_out * group * kernel_h * kernel_w`` elements.
        kernel_h, kernel_w: Kernel extent along height / width.
        stride_h, stride_w: Step between neighbouring windows.
        pad_h, pad_w: Zero padding added on both sides of height / width.
        dilation_h, dilation_w: Spacing between kernel taps.
        group: Number of groups the channels are split into.
        group_channels: Channels per group.
        offset_scale: Scalar applied to both the tap grid and ``offset``.

    Returns:
        Contiguous tensor of shape
        ``(N, H_out, W_out, group * group_channels)``.
    """
    # F.pad consumes (before, after) pairs starting from the LAST dimension.
    # For an (N, H, W, C) tensor that is: channels, then width, then height.
    # The width pair must therefore precede the height pair.
    input = F.pad(
        input,
        [0, 0, pad_w, pad_w, pad_h, pad_h])
    N_, H_in, W_in, _ = input.shape
    _, H_out, W_out, _ = offset.shape
    P_ = kernel_h * kernel_w

    ref = _get_reference_points(
        input.shape, input.device, kernel_h, kernel_w,
        dilation_h, dilation_w, pad_h, pad_w, stride_h, stride_w)
    grid = _generate_dilation_grids(
        input.shape, kernel_h, kernel_w, dilation_h, dilation_w,
        group, input.device)

    # (W, H) repeated per tap so pixel offsets become fractions of the
    # padded map, matching the (x, y) interleaving of the flattened grid.
    spatial_norm = torch.tensor([W_in, H_in], device=input.device).\
        reshape(1, 1, 1, 2).repeat(1, 1, 1, group * P_)

    sampling_locations = \
        (ref + grid * offset_scale).repeat(N_, 1, 1, 1, 1).flatten(3, 4) + \
        offset * offset_scale / spatial_norm

    # grid_sample expects coordinates in [-1, 1] rather than [0, 1].
    sampling_grids = 2 * sampling_locations - 1

    # reshape (not view) so non-contiguous caller tensors are accepted.
    # (N, H*W, G*C) -> (N, G*C, H*W) -> (N*G, C, H, W)
    input_ = input.reshape(N_, H_in * W_in, group * group_channels).\
        transpose(1, 2).\
        reshape(N_ * group, group_channels, H_in, W_in)

    # (N, H_out*W_out, G, P, 2) -> (N*G, H_out*W_out, P, 2)
    sampling_grid_ = sampling_grids.reshape(N_, H_out * W_out, group, P_, 2).\
        transpose(1, 2).flatten(0, 1)

    # -> (N*G, C, H_out*W_out, P)
    sampling_input_ = F.grid_sample(
        input_, sampling_grid_,
        mode='bilinear', padding_mode='zeros', align_corners=False)

    # (N, H_out*W_out, G, P) -> (N*G, 1, H_out*W_out, P); the singleton
    # axis broadcasts the weights over the channel axis.
    mask = mask.reshape(N_, H_out * W_out, group, P_).transpose(1, 2).\
        reshape(N_ * group, 1, H_out * W_out, P_)

    # Weighted sum over the P kernel taps.
    output = (sampling_input_ * mask).sum(-1).\
        reshape(N_, group * group_channels, H_out * W_out)

    return output.transpose(1, 2).reshape(N_, H_out, W_out, -1).contiguous()
|
|