# coding=utf-8 # Copyright 2022 The IDEA Authors. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # ------------------------------------------------------------------------------------------------ # Deformable DETR # Copyright (c) 2020 SenseTime. All Rights Reserved. # Licensed under the Apache License, Version 2.0 [see LICENSE for details] # ------------------------------------------------------------------------------------------------ # Modified from: # https://github.com/fundamentalvision/Deformable-DETR/blob/main/models/ops/functions/ms_deform_attn_func.py # https://github.com/fundamentalvision/Deformable-DETR/blob/main/models/ops/modules/ms_deform_attn.py # ------------------------------------------------------------------------------------------------ from __future__ import absolute_import, division, print_function import unittest import torch import torch.nn.functional as F from torch.autograd import gradcheck from detrex.layers.multi_scale_deform_attn import MultiScaleDeformableAttnFunction N, M, D = 1, 2, 2 Lq, L, P = 2, 2, 2 shapes = torch.as_tensor([(6, 4), (3, 2)], dtype=torch.long).cuda() level_start_index = torch.cat((shapes.new_zeros((1,)), shapes.prod(1).cumsum(0)[:-1])) S = sum([(H * W).item() for H, W in shapes]) class TestMsDeformAttn(unittest.TestCase): def ms_deform_attn_core_pytorch( self, value, value_spatial_shapes, sampling_locations, attention_weights ): # for debug and test only, # need to use cuda version instead N_, S_, M_, D_ = value.shape _, Lq_, M_, L_, P_, _ = sampling_locations.shape value_list = value.split([H_ * W_ for H_, W_ in value_spatial_shapes], dim=1) sampling_grids = 2 * sampling_locations - 1 sampling_value_list = [] for lid_, (H_, W_) in enumerate(value_spatial_shapes): # N_, H_*W_, M_, D_ -> N_, H_*W_, M_*D_ -> N_, M_*D_, H_*W_ -> N_*M_, D_, H_, W_ value_l_ = value_list[lid_].flatten(2).transpose(1, 2).reshape(N_ * M_, D_, H_, W_) # N_, Lq_, M_, P_, 2 -> N_, M_, Lq_, P_, 2 -> N_*M_, Lq_, P_, 2 sampling_grid_l_ = sampling_grids[:, :, :, lid_].transpose(1, 2).flatten(0, 1) # N_*M_, D_, Lq_, P_ sampling_value_l_ = F.grid_sample( value_l_, sampling_grid_l_, mode="bilinear", padding_mode="zeros", align_corners=False, ) sampling_value_list.append(sampling_value_l_) # (N_, Lq_, M_, L_, P_) -> (N_, M_, Lq_, L_, P_) -> (N_, M_, 1, Lq_, L_*P_) attention_weights = attention_weights.transpose(1, 2).reshape(N_ * M_, 1, Lq_, L_ * P_) output = ( (torch.stack(sampling_value_list, dim=-2).flatten(-2) * attention_weights) .sum(-1) .view(N_, M_ * D_, Lq_) ) return output.transpose(1, 2).contiguous() def check_gradient_numerical( self, channels=4, grad_value=True, grad_sampling_loc=True, grad_attn_weight=True ): value = torch.rand(N, S, M, channels).cuda() * 0.01 sampling_locations = torch.rand(N, Lq, M, L, P, 2).cuda() attention_weights = torch.rand(N, Lq, M, L, P).cuda() + 1e-5 attention_weights /= attention_weights.sum(-1, keepdim=True).sum(-2, keepdim=True) im2col_step = 2 func = MultiScaleDeformableAttnFunction.apply value.requires_grad = grad_value sampling_locations.requires_grad = grad_sampling_loc attention_weights.requires_grad = grad_attn_weight gradok = gradcheck( func, ( value.double(), shapes, level_start_index, sampling_locations.double(), attention_weights.double(), im2col_step, ), ) return gradok @torch.no_grad() def test_forward_equal_with_pytorch_double(self): value = torch.rand(N, S, M, D).cuda() * 0.01 sampling_locations = torch.rand(N, Lq, M, L, P, 2).cuda() attention_weights = torch.rand(N, Lq, M, L, P).cuda() + 1e-5 attention_weights /= attention_weights.sum(-1, keepdim=True).sum(-2, keepdim=True) im2col_step = 2 output_pytorch = ( self.ms_deform_attn_core_pytorch( value.double(), shapes, sampling_locations.double(), attention_weights.double() ) .detach() .cpu() ) output_cuda = ( MultiScaleDeformableAttnFunction.apply( value.double(), shapes, level_start_index, sampling_locations.double(), attention_weights.double(), im2col_step, ) .detach() .cpu() ) self.assertTrue(torch.allclose(output_cuda, output_pytorch)) def test_gradient_numerical(self): for channels in [30, 32, 64, 71, 1025]: self.assertTrue(self.check_gradient_numerical(channels, True, True, True))