# File: odor-detection/tests/test_ms_deform_attn.py
# Uploaded by mathiaszinnen in commit 3e99b05 ("Initialize app").
# coding=utf-8
# Copyright 2022 The IDEA Authors. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ------------------------------------------------------------------------------------------------
# Deformable DETR
# Copyright (c) 2020 SenseTime. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
# ------------------------------------------------------------------------------------------------
# Modified from:
# https://github.com/fundamentalvision/Deformable-DETR/blob/main/models/ops/functions/ms_deform_attn_func.py
# https://github.com/fundamentalvision/Deformable-DETR/blob/main/models/ops/modules/ms_deform_attn.py
# ------------------------------------------------------------------------------------------------
from __future__ import absolute_import, division, print_function
import unittest
import torch
import torch.nn.functional as F
from torch.autograd import gradcheck
from detrex.layers.multi_scale_deform_attn import MultiScaleDeformableAttnFunction
# Every test input below is created with ``.cuda()``, so the suite can only run
# when a CUDA device is present. The check is done once, at import time.
_HAS_CUDA = torch.cuda.is_available()

# Sizes of the random test tensors: ``value`` is (N, S, M, D) and
# ``sampling_locations`` is (N, Lq, M, L, P, 2). L must equal the number of
# rows in ``shapes`` because the reference implementation indexes the L axis
# once per row.
N, M, D = 1, 2, 2
Lq, L, P = 2, 2, 2

# (height, width) of each feature map; ``value`` stores them flattened and
# concatenated along its second axis.
_LEVEL_SHAPES = [(6, 4), (3, 2)]

# Total number of flattened positions over all feature maps (a plain int).
S = sum(height * width for height, width in _LEVEL_SHAPES)

shapes = torch.as_tensor(_LEVEL_SHAPES, dtype=torch.long)
if _HAS_CUDA:
    # Only touch the GPU when there is one, so that importing this module on a
    # CPU-only machine lets the tests be skipped instead of failing collection.
    shapes = shapes.cuda()

# Offset of each feature map's first element inside the flattened axis.
level_start_index = torch.cat((shapes.new_zeros((1,)), shapes.prod(1).cumsum(0)[:-1]))


@unittest.skipUnless(_HAS_CUDA, "test inputs are created with .cuda(); no CUDA device available")
class TestMsDeformAttn(unittest.TestCase):
    """Check ``MultiScaleDeformableAttnFunction`` against a PyTorch reference.

    The forward output is compared with ``ms_deform_attn_core_pytorch`` and the
    backward pass is verified numerically with ``torch.autograd.gradcheck``.
    """

    def ms_deform_attn_core_pytorch(
        self, value, value_spatial_shapes, sampling_locations, attention_weights
    ):
        """Reference multi-scale deformable attention built on ``F.grid_sample``.

        For debugging and testing only; the CUDA version should be used instead.

        Args:
            value: Tensor of shape (N_, S_, M_, D_), where S_ is the sum of
                ``H_ * W_`` over ``value_spatial_shapes``.
            value_spatial_shapes: Iterable of ``(H_, W_)`` pairs, one per
                feature map (an integer tensor of shape (L_, 2) also works).
            sampling_locations: Tensor of shape (N_, Lq_, M_, L_, P_, 2). It is
                mapped with ``2 * x - 1`` before being given to
                ``F.grid_sample``, i.e. [0, 1] becomes [-1, 1].
            attention_weights: Tensor of shape (N_, Lq_, M_, L_, P_).

        Returns:
            Contiguous tensor of shape (N_, Lq_, M_ * D_).
        """
        N_, _, M_, D_ = value.shape
        _, Lq_, M_, L_, P_, _ = sampling_locations.shape

        # Convert the sizes to Python ints once, so that split/reshape receive
        # plain integers rather than 0-dim (possibly device) tensors.
        level_sizes = [(int(H_), int(W_)) for H_, W_ in value_spatial_shapes]

        value_list = value.split([H_ * W_ for H_, W_ in level_sizes], dim=1)
        sampling_grids = 2 * sampling_locations - 1
        sampling_value_list = []
        for lid_, (H_, W_) in enumerate(level_sizes):
            # N_, H_*W_, M_, D_ -> N_, H_*W_, M_*D_ -> N_, M_*D_, H_*W_ -> N_*M_, D_, H_, W_
            value_l_ = value_list[lid_].flatten(2).transpose(1, 2).reshape(N_ * M_, D_, H_, W_)
            # N_, Lq_, M_, P_, 2 -> N_, M_, Lq_, P_, 2 -> N_*M_, Lq_, P_, 2
            sampling_grid_l_ = sampling_grids[:, :, :, lid_].transpose(1, 2).flatten(0, 1)
            # N_*M_, D_, Lq_, P_
            sampling_value_l_ = F.grid_sample(
                value_l_,
                sampling_grid_l_,
                mode="bilinear",
                padding_mode="zeros",
                align_corners=False,
            )
            sampling_value_list.append(sampling_value_l_)

        # (N_, Lq_, M_, L_, P_) -> (N_, M_, Lq_, L_, P_) -> (N_*M_, 1, Lq_, L_*P_)
        attention_weights = attention_weights.transpose(1, 2).reshape(N_ * M_, 1, Lq_, L_ * P_)
        # Stack to (N_*M_, D_, Lq_, L_, P_), merge the last two axes, weight each
        # sample and sum over all L_*P_ samples.
        output = (
            (torch.stack(sampling_value_list, dim=-2).flatten(-2) * attention_weights)
            .sum(-1)
            .view(N_, M_ * D_, Lq_)
        )
        return output.transpose(1, 2).contiguous()

    def _random_inputs(self, channels):
        """Build random float32 CUDA inputs with ``channels`` as the last value axis.

        Returns:
            Tuple ``(value, sampling_locations, attention_weights)`` where the
            attention weights sum to one over their last two axes.
        """
        value = torch.rand(N, S, M, channels).cuda() * 0.01
        sampling_locations = torch.rand(N, Lq, M, L, P, 2).cuda()
        # The small offset keeps the normalising sum strictly positive.
        attention_weights = torch.rand(N, Lq, M, L, P).cuda() + 1e-5
        attention_weights /= attention_weights.sum(-1, keepdim=True).sum(-2, keepdim=True)
        return value, sampling_locations, attention_weights

    def check_gradient_numerical(
        self, channels=4, grad_value=True, grad_sampling_loc=True, grad_attn_weight=True
    ):
        """Run ``gradcheck`` on the CUDA op for the given channel count.

        Args:
            channels: Size of the last axis of ``value``.
            grad_value: Whether ``value`` requires a gradient.
            grad_sampling_loc: Whether ``sampling_locations`` requires a gradient.
            grad_attn_weight: Whether ``attention_weights`` requires a gradient.

        Returns:
            The value returned by ``torch.autograd.gradcheck``.
        """
        value, sampling_locations, attention_weights = self._random_inputs(channels)
        im2col_step = 2

        # Cast to double first and only then set requires_grad, so that
        # gradcheck receives double-precision leaf tensors.
        value = value.double().requires_grad_(grad_value)
        sampling_locations = sampling_locations.double().requires_grad_(grad_sampling_loc)
        attention_weights = attention_weights.double().requires_grad_(grad_attn_weight)

        return gradcheck(
            MultiScaleDeformableAttnFunction.apply,
            (
                value,
                shapes,
                level_start_index,
                sampling_locations,
                attention_weights,
                im2col_step,
            ),
        )

    @torch.no_grad()
    def test_forward_equal_with_pytorch_double(self):
        """The CUDA forward pass matches the PyTorch reference in float64."""
        value, sampling_locations, attention_weights = (
            tensor.double() for tensor in self._random_inputs(D)
        )
        im2col_step = 2

        output_pytorch = (
            self.ms_deform_attn_core_pytorch(value, shapes, sampling_locations, attention_weights)
            .detach()
            .cpu()
        )
        output_cuda = (
            MultiScaleDeformableAttnFunction.apply(
                value,
                shapes,
                level_start_index,
                sampling_locations,
                attention_weights,
                im2col_step,
            )
            .detach()
            .cpu()
        )

        # Report the size of the mismatch instead of a bare "False is not true".
        max_abs_err = (output_cuda - output_pytorch).abs().max().item()
        self.assertTrue(
            torch.allclose(output_cuda, output_pytorch),
            msg="CUDA and PyTorch outputs differ (max abs error {:.3e})".format(max_abs_err),
        )

    def test_gradient_numerical(self):
        """The analytical gradients pass ``gradcheck`` for several channel counts."""
        for channels in [30, 32, 64, 71, 1025]:
            # subTest names the failing channel count and lets the others still run.
            with self.subTest(channels=channels):
                self.assertTrue(self.check_gradient_numerical(channels, True, True, True))