|
#include <ATen/Config.h> |
|
#include <ATen/core/DimVector.h> |
|
#include <ATen/cuda/CUDAContext.h> |
|
#include <ATen/native/cuda/CuFFTUtils.h> |
|
#include <ATen/native/utils/ParamsHash.h> |
|
#include <c10/util/accumulate.h> |
|
#include <c10/util/irange.h> |
|
|
|
#include <cufft.h> |
|
#include <cufftXt.h> |
|
|
|
#include <cstdint>
#include <cstring>
#include <functional>
#include <limits>
#include <list>
#include <mutex>
#include <sstream>
#include <stdexcept>
#include <string>
#include <type_traits>
#include <unordered_map>
|
|
|
namespace at { namespace native { namespace detail { |
|
|
|
|
|
// Kind of FFT a plan performs, named by input-to-output element type.
// The explicit values match the implicit numbering (0, 1, 2).
enum class CuFFTTransformType : int8_t {
  C2C = 0,  // complex input, complex output
  R2C = 1,  // real input, complex output
  C2R = 2,  // complex input, real output
};
|
|
|
|
|
|
|
|
|
// Trivial (POD) description of a single FFT plan; used as the plan-cache key.
// Slot 0 of sizes_ / input_strides_ / output_strides_ is the batch dimension and
// the following signal_ndim_ slots are the signal dimensions. Slots past
// signal_ndim_ are left zeroed by the constructor.
struct CuFFTParams
{
  int64_t signal_ndim_;                       // number of signal dims, 1..max_rank
  int64_t sizes_[max_rank + 1];               // batch size, then signal sizes
  int64_t input_strides_[max_rank + 1];       // batch stride, then signal strides
  int64_t output_strides_[max_rank + 1];      // batch stride, then signal strides
  CuFFTTransformType fft_type_;
  ScalarType value_type_;

  CuFFTParams() = default;

  // All three array refs must have the same length: 1 (batch) + signal dims.
  CuFFTParams(IntArrayRef in_strides, IntArrayRef out_strides,
      IntArrayRef signal_sizes, CuFFTTransformType fft_type, ScalarType value_type) {
    // Zero every byte, padding included, so two keys built from equal inputs are
    // byte-for-byte identical.
    // NOTE(review): presumably required because ParamsHash/ParamsEqual work on
    // raw bytes — confirm.
    memset(this, 0, sizeof(*this));

    value_type_ = value_type;
    fft_type_ = fft_type;
    signal_ndim_ = signal_sizes.size() - 1;

    TORCH_INTERNAL_ASSERT(in_strides.size() == signal_sizes.size());
    TORCH_INTERNAL_ASSERT(out_strides.size() == signal_sizes.size());
    TORCH_INTERNAL_ASSERT(1 <= signal_ndim_ && signal_ndim_ <= max_rank);

    // Bounded by the asserts above: at most max_rank + 1 entries each.
    std::copy(signal_sizes.begin(), signal_sizes.end(), sizes_);
    std::copy(in_strides.begin(), in_strides.end(), input_strides_);
    std::copy(out_strides.begin(), out_strides.end(), output_strides_);
  }
};

// The memset in the constructor is only valid for a trivial type.
static_assert(std::is_trivial<CuFFTParams>::value, "");
|
|
|
|
|
// Reports whether a transform of the given type reads complex input:
// true for C2C and C2R, false for R2C. Any other value trips an internal assert.
inline bool cufft_complex_input(CuFFTTransformType type) {
  if (type == CuFFTTransformType::R2C) {
    return false;
  }
  if (type == CuFFTTransformType::C2C || type == CuFFTTransformType::C2R) {
    return true;
  }
  // Only reachable with a value outside the declared enumerators.
  TORCH_INTERNAL_ASSERT(false);
}
|
|
|
|
|
// Reports whether a transform of the given type writes complex output:
// true for C2C and R2C, false for C2R. Any other value trips an internal assert.
inline bool cufft_complex_output(CuFFTTransformType type) {
  if (type == CuFFTTransformType::C2R) {
    return false;
  }
  if (type == CuFFTTransformType::C2C || type == CuFFTTransformType::R2C) {
    return true;
  }
  // Only reachable with a value outside the declared enumerators.
  TORCH_INTERNAL_ASSERT(false);
}
|
|
|
|
|
// Maps (input is complex, output is complex) to the matching transform type.
// Real-to-real is rejected with an internal assert.
inline CuFFTTransformType GetCuFFTTransformType(bool complex_input, bool complex_output) {
  if (complex_input) {
    return complex_output ? CuFFTTransformType::C2C : CuFFTTransformType::C2R;
  }
  if (complex_output) {
    return CuFFTTransformType::R2C;
  }
  TORCH_INTERNAL_ASSERT(false, "Real to real FFTs are not supported");
}
|
|
|
|
|
// Owns one ::cufftHandle. The handle is created in the constructor and passed to
// cufftDestroy in the destructor (destruction is skipped when USE_ROCM is set).
class CuFFTHandle {
  ::cufftHandle handle_;
public:

  CuFFTHandle() {
    CUFFT_CHECK(cufftCreate(&handle_));
  }

  // Ownership is exclusive: an implicit copy would share the raw handle, and both
  // copies' destructors would call cufftDestroy on it. Copying is therefore
  // disabled.
  CuFFTHandle(const CuFFTHandle&) = delete;
  CuFFTHandle& operator=(const CuFFTHandle&) = delete;

  ::cufftHandle & get() { return handle_; }
  const ::cufftHandle & get() const { return handle_; }

  ~CuFFTHandle() {
    // NOTE(review): destruction is skipped under ROCm. The reason is not visible
    // here (possibly a workaround for a double free in the ROCm library) —
    // confirm before changing.
#if !defined(USE_ROCM)
    cufftDestroy(handle_);
#endif
  }
};
|
|
|
// Returns true iff x is a positive power of two (1, 2, 4, 8, ...).
// Zero and negative values are rejected first. The bit trick alone would accept
// 0, and x - 1 would overflow for the most negative int64_t.
__forceinline__
static bool is_pow_of_two(int64_t x) {
  return x > 0 && (x & (x - 1)) == 0;
}
|
|
|
// Integer type of the size/embed arrays handed to the plan-creation calls:
// int under ROCm (hipfftMakePlanMany), long long otherwise (cufftXtMakePlanMany).
#if defined(USE_ROCM)
typedef int cufft_size_type;
#else
typedef long long int cufft_size_type;
#endif

// Small inline vector of cufft_size_type, used for signal sizes and embeddings.
typedef c10::SmallVector<cufft_size_type, at::kDimVectorStaticSize> CuFFTDimVector;
|
|
|
|
|
|
|
// How one operand (input or output) is described to the cuFFT/hipFFT planner.
struct CuFFTDataLayout {
  // One entry per signal dimension; passed as the planner's embed array.
  CuFFTDimVector embed;
  // Stride of the innermost signal dimension.
  cufft_size_type stride;
  // Distance between the starts of consecutive batch entries.
  cufft_size_type dist;
  // True when the original strides could not be represented and the data must
  // be cloned into the contiguous layout this struct then describes.
  bool must_clone;
  // True when the layout is plain contiguous, so nullptr embeds can be used.
  bool simple;
};
|
|
|
|
|
|
|
|
|
// Builds the layout of a contiguous buffer with shape `sizes`
// (sizes[0] = batch, remaining entries = signal sizes).
// If `onesided`, the innermost signal dimension holds n / 2 + 1 elements.
// The result has unit stride, dist equal to the product of its embed extents,
// is marked simple, and does not require a clone.
inline CuFFTDataLayout cufft_simple_embed(IntArrayRef sizes, bool onesided) {
  CuFFTDataLayout result;
  result.embed.assign(sizes.cbegin() + 1, sizes.cend());
  if (onesided) {
    result.embed.back() = sizes.back() / 2 + 1;
  }

  cufft_size_type elements = 1;
  for (const auto extent : result.embed) {
    elements *= extent;
  }

  result.dist = elements;
  result.stride = 1;
  result.must_clone = false;
  result.simple = true;
  return result;
}
|
|
|
|
|
|
|
|
|
// Tries to express an operand's strides as planner embed / stride / dist values.
//
// strides[0] and sizes[0] describe the batch dimension. Entries 1..signal_ndim
// describe the signal dimensions, innermost last. If `onesided`, the innermost
// signal dimension holds sizes[signal_ndim] / 2 + 1 elements.
//
// must_clone is set when the strides cannot be represented. That happens for a
// non-positive innermost stride, a zero batch stride with batch > 1, or a signal
// stride that is not a positive multiple of the last stride accepted further in.
// In that case the returned layout is the contiguous one from
// cufft_simple_embed, still with must_clone == true.
// Otherwise `simple` reports whether the layout is exactly contiguous.
inline CuFFTDataLayout as_cufft_embed(IntArrayRef strides, IntArrayRef sizes, bool onesided) {
  const auto signal_ndim = strides.size() - 1;
  CuFFTDataLayout layout;
  auto last_stride = strides[signal_ndim];
  // A non-positive innermost stride cannot be embedded.
  layout.must_clone = (last_stride <= 0);

  // Element count of one signal in a contiguous (possibly onesided) layout.
  const auto last_dim_size = onesided ?
      sizes[signal_ndim] / 2 + 1 : sizes[signal_ndim];
  const auto signal_numel = c10::multiply_integers(sizes.slice(1, sizes.size() - 2)) * last_dim_size;

  // Batch distance. With a single batch entry the batch stride is irrelevant, so
  // the contiguous signal size is used as a placeholder. A zero batch stride
  // with more than one entry cannot be embedded.
  if (sizes[0] == 1) {
    layout.dist = signal_numel;
  } else if (strides[0] == 0) {
    layout.must_clone = true;
  } else {
    layout.dist = strides[0];
  }

  // Walk signal strides from inner to outer. embed[i] is the ratio of strides[i]
  // to the last accepted inner stride. A size-1 dimension gets embed 1 and
  // leaves last_stride unchanged, so the next ratio spans over it. The loop stops
  // at i == 1; embed[0] is filled in below.
  layout.embed.resize(signal_ndim);
  for (auto i = signal_ndim - 1; !layout.must_clone && i > 0; i--) {
    auto stride = strides[i];
    if (sizes[i] == 1) {
      layout.embed[i] = 1;
    } else if (stride > 0 && stride % last_stride == 0) {
      layout.embed[i] = stride / last_stride;
      last_stride = stride;
    } else {
      layout.must_clone = true;
    }
  }

  if (layout.must_clone) {
    // Describe the contiguous buffer the data is expected to be cloned into.
    layout = cufft_simple_embed(sizes, onesided);
    layout.must_clone = true;
  } else {
    // NOTE(review): embed[0] is set to the outermost signal size rather than
    // derived from strides. Presumably the planner ignores embed[0] for
    // addressing — confirm against the cuFFT advanced-layout docs.
    layout.embed[0] = sizes[1];
    layout.stride = strides[signal_ndim];
    // "Simple" means the layout matches a contiguous buffer exactly:
    // embed[1..] equal the signal sizes (the last possibly onesided), unit
    // innermost stride, and batch distance equal to one signal's element count.
    layout.simple = [&] {
      for (const auto i : c10::irange(1, signal_ndim - 1)) {
        if (layout.embed[i] != sizes[i + 1]) {
          return false;
        }
      }

      return (layout.stride == 1 && layout.dist == signal_numel &&
          layout.embed.back() == last_dim_size);
    }();
  }
  return layout;
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
// Owns a cuFFT (hipFFT under ROCm) plan built for one combination of strides,
// sizes, transform type and dtype. It also exposes what a caller needs to
// execute the plan:
//   - should_clone_input(): whether the input must be cloned first;
//   - workspace_size(): the plan's workspace requirement. Auto-allocation is
//     disabled, so the workspace is not allocated by the library.
class CuFFTConfig {
public:

  // The object owns a native plan handle, so copying is disabled.
  CuFFTConfig(const CuFFTConfig&) = delete;
  CuFFTConfig& operator=(CuFFTConfig const&) = delete;

  // Builds the plan described by a cache key. Only the first signal_ndim_ + 1
  // entries (batch first) of each array are used.
  explicit CuFFTConfig(const CuFFTParams& params):
      CuFFTConfig(
          IntArrayRef(params.input_strides_, params.signal_ndim_ + 1),
          IntArrayRef(params.output_strides_, params.signal_ndim_ + 1),
          IntArrayRef(params.sizes_, params.signal_ndim_ + 1),
          params.fft_type_,
          params.value_type_) {}

  // Builds a batched plan.
  //   in_strides / out_strides: operand strides; index 0 is the batch dimension.
  //   sizes: batch size followed by the full signal sizes. The onesided
  //          halving for the C2R input / R2C output is applied below.
  //   fft_type: C2C, R2C or C2R.
  //   dtype: Float, Double, or Half (CUDA only); anything else raises.
  CuFFTConfig(IntArrayRef in_strides, IntArrayRef out_strides,
      IntArrayRef sizes, CuFFTTransformType fft_type, ScalarType dtype):
    fft_type_(fft_type), value_type_(dtype) {

    // Signal dimensions only (batch dropped).
    CuFFTDimVector signal_sizes(sizes.begin() + 1, sizes.end());

    const int64_t batch = sizes[0];
    const int64_t signal_ndim = sizes.size() - 1;

    // ROCm always clones the input. On CUDA the input is used as-is unless a
    // check below decides otherwise.
    // NOTE(review): the reason for the unconditional ROCm clone is not visible here.
#if defined(USE_ROCM)
    clone_input = true;
#else
    clone_input = false;
#endif

    // Half-precision restrictions: compute capability >= 5.3, power-of-two
    // signal sizes, and a unit innermost input stride (otherwise clone).
    if (dtype == ScalarType::Half) {
      auto dev_prop = at::cuda::getCurrentDeviceProperties();
      TORCH_CHECK(dev_prop->major >= 5 && !(dev_prop->major == 5 && dev_prop->minor < 3),
          "cuFFT doesn't support signals of half type with compute "
          "capability less than SM_53, but the device containing input half "
          "tensor only has SM_", dev_prop->major, dev_prop->minor);
      for (const auto i : c10::irange(signal_ndim)) {
        // The trailing space keeps the printed sizes from being glued to "of".
        TORCH_CHECK(is_pow_of_two(sizes[i + 1]),
            "cuFFT only supports dimensions whose sizes are powers of two when"
            " computing in half precision, but got a signal size of ",
            sizes.slice(1));
      }
      clone_input |= in_strides.back() != 1;
    }

    // A cloned input is contiguous by construction. Otherwise try to express the
    // given strides as an embedding. The C2R input and the R2C output are
    // onesided.
    CuFFTDataLayout in_layout;
    if (clone_input) {
      in_layout = cufft_simple_embed(sizes, fft_type == CuFFTTransformType::C2R);
    } else {
      in_layout = as_cufft_embed(in_strides, sizes, fft_type == CuFFTTransformType::C2R);
    }
    auto out_layout = as_cufft_embed(out_strides, sizes, fft_type == CuFFTTransformType::R2C);
    TORCH_INTERNAL_ASSERT(!out_layout.must_clone, "Out strides cannot be represented as CuFFT embedding");
    clone_input |= in_layout.must_clone;

    // When both operands are contiguous the planner receives nullptr embeddings
    // instead of explicit embed/stride/dist values.
    const bool simple_layout = in_layout.simple && out_layout.simple;

#if defined(USE_ROCM)
    hipfftType exec_type = [&]{
      if (dtype == kFloat) {
        switch (fft_type) {
          case CuFFTTransformType::C2C: return HIPFFT_C2C;
          case CuFFTTransformType::R2C: return HIPFFT_R2C;
          case CuFFTTransformType::C2R: return HIPFFT_C2R;
        }
      } else if (dtype == kDouble) {
        switch (fft_type) {
          case CuFFTTransformType::C2C: return HIPFFT_Z2Z;
          case CuFFTTransformType::R2C: return HIPFFT_D2Z;
          case CuFFTTransformType::C2R: return HIPFFT_Z2D;
        }
      }
      TORCH_CHECK(false, "hipFFT doesn't support transforms of type: ", dtype);
    }();
#else
    // cufftXt takes separate input, output and execution data types.
    cudaDataType itype, otype, exec_type;
    const auto complex_input = cufft_complex_input(fft_type);
    const auto complex_output = cufft_complex_output(fft_type);
    if (dtype == ScalarType::Float) {
      itype = complex_input ? CUDA_C_32F : CUDA_R_32F;
      otype = complex_output ? CUDA_C_32F : CUDA_R_32F;
      exec_type = CUDA_C_32F;
    } else if (dtype == ScalarType::Double) {
      itype = complex_input ? CUDA_C_64F : CUDA_R_64F;
      otype = complex_output ? CUDA_C_64F : CUDA_R_64F;
      exec_type = CUDA_C_64F;
    } else if (dtype == ScalarType::Half) {
      itype = complex_input ? CUDA_C_16F : CUDA_R_16F;
      otype = complex_output ? CUDA_C_16F : CUDA_R_16F;
      exec_type = CUDA_C_16F;
    } else {
      TORCH_CHECK(false, "cuFFT doesn't support tensor of type: ", dtype);
    }
#endif

    // Turn off library-managed workspace allocation. The required size is
    // recorded in ws_size and exposed through workspace_size().
    CUFFT_CHECK(cufftSetAutoAllocation(plan(), 0));

    size_t ws_size_t;

    if (simple_layout) {
      // Contiguous operands: nullptr embeddings with unit stride and dist.
#if defined(USE_ROCM)
      CUFFT_CHECK(hipfftMakePlanMany(plan(), signal_ndim, signal_sizes.data(),
        nullptr, 1, 1,
        nullptr, 1, 1,
        exec_type, batch, &ws_size_t));
#else
      CUFFT_CHECK(cufftXtMakePlanMany(plan(), signal_ndim, signal_sizes.data(),
        nullptr, 1, 1, itype,
        nullptr, 1, 1, otype,
        batch, &ws_size_t, exec_type));
#endif
    } else {
      // Strided operands: pass the computed embeddings explicitly.
#if defined(USE_ROCM)
      CUFFT_CHECK(hipfftMakePlanMany(plan(), signal_ndim, signal_sizes.data(),
        in_layout.embed.data(), in_layout.stride, in_layout.dist,
        out_layout.embed.data(), out_layout.stride, out_layout.dist,
        exec_type, batch, &ws_size_t));
#else
      CUFFT_CHECK(cufftXtMakePlanMany(plan(), signal_ndim, signal_sizes.data(),
        in_layout.embed.data(), in_layout.stride, in_layout.dist, itype,
        out_layout.embed.data(), out_layout.stride, out_layout.dist, otype,
        batch, &ws_size_t, exec_type));
#endif
    }
    ws_size = static_cast<int64_t>(ws_size_t);
  }

  // The underlying native plan handle.
  const cufftHandle &plan() const { return plan_ptr.get(); }

  CuFFTTransformType transform_type() const { return fft_type_; }
  ScalarType data_type() const { return value_type_; }
  // True if the input must be cloned into a contiguous buffer before execution.
  bool should_clone_input() const { return clone_input; }
  // Workspace size reported by the planner.
  int64_t workspace_size() const { return ws_size; }

private:
  CuFFTHandle plan_ptr;
  bool clone_input;
  int64_t ws_size;
  CuFFTTransformType fft_type_;
  ScalarType value_type_;
};
|
|
|
// Upper bound on the plan-cache capacity, and the capacity used by default.
#if (defined(CUDA_VERSION) && CUDA_VERSION < 10000) || defined(USE_ROCM)
// CUDA < 10 and ROCm: capacity is capped at 1023 and the default uses all of it.
// NOTE(review): presumably a library limit on these platforms — confirm.
constexpr int64_t CUFFT_MAX_PLAN_NUM = 1023;
constexpr int64_t CUFFT_DEFAULT_CACHE_SIZE = CUFFT_MAX_PLAN_NUM;
#else
// No practical cap: int64_t max acts as "unbounded".
constexpr int64_t CUFFT_MAX_PLAN_NUM = std::numeric_limits<int64_t>::max();
// Default number of cached plans.
constexpr int64_t CUFFT_DEFAULT_CACHE_SIZE = 4096;
#endif
// The LRU cache stores its capacity as size_t (via static_cast from int64_t), so
// the maximum must be non-negative and representable in size_t. Comparing
// against int64_t max, as before, was always true and checked nothing.
static_assert(0 <= CUFFT_MAX_PLAN_NUM &&
              static_cast<std::uint64_t>(CUFFT_MAX_PLAN_NUM) <= std::numeric_limits<size_t>::max(),
              "CUFFT_MAX_PLAN_NUM not in size_t range");
static_assert(CUFFT_DEFAULT_CACHE_SIZE >= 0 && CUFFT_DEFAULT_CACHE_SIZE <= CUFFT_MAX_PLAN_NUM,
              "CUFFT_DEFAULT_CACHE_SIZE not in [0, CUFFT_MAX_PLAN_NUM] range");
|
|
|
|
|
|
|
|
|
|
|
|
|
// Least-recently-used cache from CuFFTParams keys to CuFFTConfig plans.
//
// _usage_list holds the (key, plan) pairs ordered from most to least recently
// used. _cache_map indexes those pairs through a reference to the key stored in
// each list node, mapping it to that node's iterator. std::list nodes are never
// relocated by splice/insert/erase of other nodes, so those references and
// iterators stay valid.
//
// None of the methods lock `mutex`.
// NOTE(review): presumably callers hold it around every use — confirm at call sites.
class CuFFTParamsLRUCache {
public:

  using kv_t = typename std::pair<CuFFTParams, CuFFTConfig>;
  using map_t = typename std::unordered_map<std::reference_wrapper<CuFFTParams>,
                                            typename std::list<kv_t>::iterator,
                                            ParamsHash<CuFFTParams>,
                                            ParamsEqual<CuFFTParams>>;
  using map_kkv_iter_t = typename map_t::iterator;

  // Capacity defaults to CUFFT_DEFAULT_CACHE_SIZE.
  CuFFTParamsLRUCache() : CuFFTParamsLRUCache(CUFFT_DEFAULT_CACHE_SIZE) {}

  // Throws if max_size is negative or larger than CUFFT_MAX_PLAN_NUM.
  CuFFTParamsLRUCache(int64_t max_size) {
    _set_max_size(max_size);
  }

  // Moves transfer the entries and the capacity. The mutex itself is not moved.
  CuFFTParamsLRUCache(CuFFTParamsLRUCache&& other) noexcept :
      _usage_list(std::move(other._usage_list)),
      _cache_map(std::move(other._cache_map)),
      _max_size(other._max_size) {}

  CuFFTParamsLRUCache& operator=(CuFFTParamsLRUCache&& other) noexcept {
    _usage_list = std::move(other._usage_list);
    _cache_map = std::move(other._cache_map);
    _max_size = other._max_size;
    return *this;
  }

  // Returns the plan for `params`, building and inserting it on a miss.
  // A hit moves the entry to the front. A miss on a full cache first evicts
  // the least recently used entry. Requires a capacity greater than zero.
  const CuFFTConfig &lookup(CuFFTParams params) {
    AT_ASSERT(_max_size > 0);

    auto hit = _cache_map.find(params);
    if (hit != _cache_map.end()) {
      auto node = hit->second;
      _usage_list.splice(_usage_list.begin(), _usage_list, node);
      return node->second;
    }

    if (_usage_list.size() >= _max_size) {
      auto& oldest = _usage_list.back();
      _cache_map.erase(oldest.first);
      _usage_list.pop_back();
    }

    // Build key and plan in place: CuFFTConfig is not copyable.
    _usage_list.emplace_front(std::piecewise_construct,
                              std::forward_as_tuple(params),
                              std::forward_as_tuple(params));
    auto& fresh = _usage_list.front();
    _cache_map.emplace(std::piecewise_construct,
                       std::forward_as_tuple(fresh.first),
                       std::forward_as_tuple(_usage_list.begin()));
    return fresh.second;
  }

  // Drops every cached plan (index first, then the entries it refers to).
  void clear() {
    _cache_map.clear();
    _usage_list.clear();
  }

  // Changes the capacity, evicting least recently used entries that no longer fit.
  void resize(int64_t new_size) {
    _set_max_size(new_size);
    while (_usage_list.size() > _max_size) {
      _cache_map.erase(_usage_list.back().first);
      _usage_list.pop_back();
    }
  }

  size_t size() const { return _cache_map.size(); }

  size_t max_size() const noexcept { return _max_size; }

  std::mutex mutex;

private:
  // Validates and stores a new capacity in [0, CUFFT_MAX_PLAN_NUM].
  void _set_max_size(int64_t new_size) {
    TORCH_CHECK(new_size >= 0,
                "cuFFT plan cache size must be non-negative, but got ", new_size);
    TORCH_CHECK(new_size <= CUFFT_MAX_PLAN_NUM,
                "cuFFT plan cache size can not be larger than ", CUFFT_MAX_PLAN_NUM, ", but got ", new_size);
    _max_size = static_cast<size_t>(new_size);
  }

  std::list<kv_t> _usage_list;
  map_t _cache_map;
  size_t _max_size;
};
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
// Plan-cache control entry points. They are only declared here; the
// definitions live in another translation unit.
// NOTE(review): judging by their names and the device_index parameter, these
// get/set the capacity, get the current size, and clear the plan cache of one
// device — confirm against the definitions.
int64_t cufft_get_plan_cache_max_size_impl(int64_t device_index);
void cufft_set_plan_cache_max_size_impl(int64_t device_index, int64_t max_size);
int64_t cufft_get_plan_cache_size_impl(int64_t device_index);
void cufft_clear_plan_cache_impl(int64_t device_index);
|
|
|
}}} |
|
|