|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
#pragma once |
|
|
|
#include "cpu/ms_deform_attn_cpu.h" |
|
|
|
#ifdef WITH_CUDA |
|
#include "cuda/ms_deform_attn_cuda.h" |
|
#endif |
|
|
|
|
|
// Device dispatcher for the ms_deform_attn forward pass.
//
// When `value` is a CUDA tensor and the extension was built with WITH_CUDA
// defined, every argument is forwarded unchanged to
// ms_deform_attn_cuda_forward and its result is returned as-is.
//
// Errors (raised through AT_ERROR):
//   - "Not compiled with GPU support": `value` is a CUDA tensor but WITH_CUDA
//     was not defined at build time.
//   - "Not implemented on the CPU": `value` is not a CUDA tensor.
//
// Only `value` is inspected to pick the backend; the other tensors are not
// checked here.
//
// Marked `inline` because the definition lives in a header: without it,
// including this header from more than one translation unit produces
// multiple-definition link errors.
inline at::Tensor
ms_deform_attn_forward(
    const at::Tensor &value,
    const at::Tensor &spatial_shapes,
    const at::Tensor &level_start_index,
    const at::Tensor &sampling_loc,
    const at::Tensor &attn_weight,
    const int im2col_step)
{
    // Query the tensor directly; going through the deprecated
    // Tensor::type() accessor is unnecessary for a device check.
    if (value.is_cuda())
    {
#ifdef WITH_CUDA
        return ms_deform_attn_cuda_forward(
            value, spatial_shapes, level_start_index, sampling_loc, attn_weight, im2col_step);
#else
        AT_ERROR("Not compiled with GPU support");
#endif
    }
    // No CPU path is dispatched from here.
    AT_ERROR("Not implemented on the CPU");
}
|
|
|
std::vector<at::Tensor> |
|
ms_deform_attn_backward( |
|
const at::Tensor &value, |
|
const at::Tensor &spatial_shapes, |
|
const at::Tensor &level_start_index, |
|
const at::Tensor &sampling_loc, |
|
const at::Tensor &attn_weight, |
|
const at::Tensor &grad_output, |
|
const int im2col_step) |
|
{ |
|
if (value.type().is_cuda()) |
|
{ |
|
#ifdef WITH_CUDA |
|
return ms_deform_attn_cuda_backward( |
|
value, spatial_shapes, level_start_index, sampling_loc, attn_weight, grad_output, im2col_step); |
|
#else |
|
AT_ERROR("Not compiled with GPU support"); |
|
#endif |
|
} |
|
AT_ERROR("Not implemented on the CPU"); |
|
} |
|
|
|
|