Spaces:
Runtime error
Runtime error
# coding=utf-8
# Copyright 2022 The IDEA Authors. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ------------------------------------------------------------------------------------------------
# Deformable DETR
# Copyright (c) 2020 SenseTime. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
# ------------------------------------------------------------------------------------------------
# Modified from:
# https://github.com/fundamentalvision/Deformable-DETR/blob/main/models/ops/functions/ms_deform_attn_func.py
# https://github.com/fundamentalvision/Deformable-DETR/blob/main/models/ops/modules/ms_deform_attn.py
# ------------------------------------------------------------------------------------------------
from __future__ import absolute_import, division, print_function | |
import unittest | |
import torch | |
import torch.nn.functional as F | |
from torch.autograd import gradcheck | |
from detrex.layers.multi_scale_deform_attn import MultiScaleDeformableAttnFunction | |
# Problem sizes shared by every test in this module. The names follow the
# tensor layouts used below: ``value`` is (N, S, M, D) and
# ``sampling_locations`` is (N, Lq, M, L, P, 2).
N, M, D = 1, 2, 2
Lq, L, P = 2, 2, 2

# Build the shape table on CUDA when it exists and on CPU otherwise, so that
# merely importing this module does not raise on a machine without a GPU.
_DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# One (H, W) row per feature level; there are L rows.
shapes = torch.as_tensor([(6, 4), (3, 2)], dtype=torch.long, device=_DEVICE)

# Offset of each level inside the flattened S axis: [0, H0*W0, H0*W0 + H1*W1, ...].
level_start_index = torch.cat((shapes.new_zeros((1,)), shapes.prod(1).cumsum(0)[:-1]))

# Total number of flattened positions over all levels, as a plain Python int.
S = int(shapes.prod(1).sum())
@unittest.skipUnless(
    torch.cuda.is_available(), "MultiScaleDeformableAttnFunction tests require CUDA"
)
class TestMsDeformAttn(unittest.TestCase):
    """Checks ``MultiScaleDeformableAttnFunction`` against a pure-PyTorch reference.

    The forward result is compared with ``ms_deform_attn_core_pytorch`` and the
    backward pass is validated numerically with ``torch.autograd.gradcheck``.
    Every test moves its inputs to CUDA, so the whole class is skipped when no
    CUDA device is available.
    """

    def ms_deform_attn_core_pytorch(
        self, value, value_spatial_shapes, sampling_locations, attention_weights
    ):
        """Reference multi-scale deformable attention built on ``F.grid_sample``.

        Args:
            value: tensor unpacked as ``(N_, S_, M_, D_)``; axis 1 is the
                concatenation of all levels' flattened ``H_ * W_`` positions.
            value_spatial_shapes: iterable of ``(H_, W_)`` pairs, one per level
                (a 2-column integer tensor or a sequence of int pairs).
            sampling_locations: tensor unpacked as ``(N_, Lq_, M_, L_, P_, 2)``.
                Values are mapped with ``2 * x - 1`` before being handed to
                ``F.grid_sample``.
            attention_weights: tensor reshaped from ``(N_, Lq_, M_, L_, P_)``.

        Returns:
            Contiguous tensor of shape ``(N_, Lq_, M_ * D_)``.
        """
        # for debug and test only,
        # need to use cuda version instead
        N_, _, M_, D_ = value.shape
        _, Lq_, _, L_, P_, _ = sampling_locations.shape

        # Convert to plain ints once: avoids tensor-valued shape arithmetic in
        # split()/reshape() and lets callers pass ordinary tuples as well.
        spatial_shapes = [(int(H_), int(W_)) for H_, W_ in value_spatial_shapes]

        value_list = value.split([H_ * W_ for H_, W_ in spatial_shapes], dim=1)
        # Map [0, 1] locations onto the [-1, 1] range used by F.grid_sample.
        sampling_grids = 2 * sampling_locations - 1

        sampling_value_list = []
        for lid_, (H_, W_) in enumerate(spatial_shapes):
            # N_, H_*W_, M_, D_ -> N_, H_*W_, M_*D_ -> N_, M_*D_, H_*W_ -> N_*M_, D_, H_, W_
            value_l_ = value_list[lid_].flatten(2).transpose(1, 2).reshape(N_ * M_, D_, H_, W_)
            # N_, Lq_, M_, P_, 2 -> N_, M_, Lq_, P_, 2 -> N_*M_, Lq_, P_, 2
            sampling_grid_l_ = sampling_grids[:, :, :, lid_].transpose(1, 2).flatten(0, 1)
            # N_*M_, D_, Lq_, P_
            sampling_value_l_ = F.grid_sample(
                value_l_,
                sampling_grid_l_,
                mode="bilinear",
                padding_mode="zeros",
                align_corners=False,
            )
            sampling_value_list.append(sampling_value_l_)

        # (N_, Lq_, M_, L_, P_) -> (N_, M_, Lq_, L_, P_) -> (N_*M_, 1, Lq_, L_*P_)
        attention_weights = attention_weights.transpose(1, 2).reshape(N_ * M_, 1, Lq_, L_ * P_)
        # Stack levels: (N_*M_, D_, Lq_, L_, P_) -> (N_*M_, D_, Lq_, L_*P_), then
        # take the weighted sum over every sampled point of every level.
        sampled = torch.stack(sampling_value_list, dim=-2).flatten(-2)
        output = (sampled * attention_weights).sum(-1).view(N_, M_ * D_, Lq_)
        return output.transpose(1, 2).contiguous()

    def check_gradient_numerical(
        self, channels=4, grad_value=True, grad_sampling_loc=True, grad_attn_weight=True
    ):
        """Run ``gradcheck`` on the CUDA op with random double-precision inputs.

        Args:
            channels: size of the last axis of ``value``.
            grad_value: whether ``value`` requires grad.
            grad_sampling_loc: whether ``sampling_locations`` requires grad.
            grad_attn_weight: whether ``attention_weights`` requires grad.

        Returns:
            The value returned by ``torch.autograd.gradcheck``.
        """
        value = torch.rand(N, S, M, channels).cuda() * 0.01
        sampling_locations = torch.rand(N, Lq, M, L, P, 2).cuda()
        attention_weights = torch.rand(N, Lq, M, L, P).cuda() + 1e-5
        # Normalise so the weights sum to 1 over the last two (L, P) axes.
        attention_weights /= attention_weights.sum(-1, keepdim=True).sum(-2, keepdim=True)
        im2col_step = 2
        func = MultiScaleDeformableAttnFunction.apply

        value.requires_grad = grad_value
        sampling_locations.requires_grad = grad_sampling_loc
        attention_weights.requires_grad = grad_attn_weight

        # Inputs are cast to double before the numerical gradient comparison.
        gradok = gradcheck(
            func,
            (
                value.double(),
                shapes,
                level_start_index,
                sampling_locations.double(),
                attention_weights.double(),
                im2col_step,
            ),
        )
        return gradok

    def test_forward_equal_with_pytorch_double(self):
        """The CUDA forward pass matches the PyTorch reference in float64."""
        value = torch.rand(N, S, M, D).cuda() * 0.01
        sampling_locations = torch.rand(N, Lq, M, L, P, 2).cuda()
        attention_weights = torch.rand(N, Lq, M, L, P).cuda() + 1e-5
        # Normalise so the weights sum to 1 over the last two (L, P) axes.
        attention_weights /= attention_weights.sum(-1, keepdim=True).sum(-2, keepdim=True)
        im2col_step = 2

        output_pytorch = (
            self.ms_deform_attn_core_pytorch(
                value.double(), shapes, sampling_locations.double(), attention_weights.double()
            )
            .detach()
            .cpu()
        )
        output_cuda = (
            MultiScaleDeformableAttnFunction.apply(
                value.double(),
                shapes,
                level_start_index,
                sampling_locations.double(),
                attention_weights.double(),
                im2col_step,
            )
            .detach()
            .cpu()
        )
        self.assertTrue(torch.allclose(output_cuda, output_pytorch))

    def test_gradient_numerical(self):
        """Gradients pass ``gradcheck`` for several channel counts."""
        for channels in [30, 32, 64, 71, 1025]:
            # subTest makes a failure report name the offending channel count.
            with self.subTest(channels=channels):
                self.assertTrue(self.check_gradient_numerical(channels, True, True, True))