|
#pragma once |
|
|
|
#include <ATen/Config.h>
#include <c10/util/Exception.h>

#include <string>
#include <stdexcept>
#include <sstream>

#include <cufft.h>
#include <cufftXt.h>
|
|
|
namespace at { namespace native { |
|
|
|
|
|
|
|
// Largest transform rank (number of FFT dimensions) this code works with.
// NOTE(review): presumably mirrors cuFFT's plan limit, since cufftPlanMany
// accepts rank 1, 2, or 3. Confirm how callers use it (e.g. to split
// higher-rank FFTs into multiple passes).
constexpr int max_rank{3};
|
|
|
static inline std::string _cudaGetErrorEnum(cufftResult error) |
|
{ |
|
switch (error) |
|
{ |
|
case CUFFT_SUCCESS: |
|
return "CUFFT_SUCCESS"; |
|
case CUFFT_INVALID_PLAN: |
|
return "CUFFT_INVALID_PLAN"; |
|
case CUFFT_ALLOC_FAILED: |
|
return "CUFFT_ALLOC_FAILED"; |
|
case CUFFT_INVALID_TYPE: |
|
return "CUFFT_INVALID_TYPE"; |
|
case CUFFT_INVALID_VALUE: |
|
return "CUFFT_INVALID_VALUE"; |
|
case CUFFT_INTERNAL_ERROR: |
|
return "CUFFT_INTERNAL_ERROR"; |
|
case CUFFT_EXEC_FAILED: |
|
return "CUFFT_EXEC_FAILED"; |
|
case CUFFT_SETUP_FAILED: |
|
return "CUFFT_SETUP_FAILED"; |
|
case CUFFT_INVALID_SIZE: |
|
return "CUFFT_INVALID_SIZE"; |
|
case CUFFT_UNALIGNED_DATA: |
|
return "CUFFT_UNALIGNED_DATA"; |
|
case CUFFT_INCOMPLETE_PARAMETER_LIST: |
|
return "CUFFT_INCOMPLETE_PARAMETER_LIST"; |
|
case CUFFT_INVALID_DEVICE: |
|
return "CUFFT_INVALID_DEVICE"; |
|
case CUFFT_PARSE_ERROR: |
|
return "CUFFT_PARSE_ERROR"; |
|
case CUFFT_NO_WORKSPACE: |
|
return "CUFFT_NO_WORKSPACE"; |
|
case CUFFT_NOT_IMPLEMENTED: |
|
return "CUFFT_NOT_IMPLEMENTED"; |
|
#if !defined(USE_ROCM) |
|
case CUFFT_LICENSE_ERROR: |
|
return "CUFFT_LICENSE_ERROR"; |
|
#endif |
|
case CUFFT_NOT_SUPPORTED: |
|
return "CUFFT_NOT_SUPPORTED"; |
|
default: |
|
std::ostringstream ss; |
|
ss << "unknown error " << error; |
|
return ss.str(); |
|
} |
|
} |
|
|
|
static inline void CUFFT_CHECK(cufftResult error) |
|
{ |
|
if (error != CUFFT_SUCCESS) { |
|
std::ostringstream ss; |
|
ss << "cuFFT error: " << _cudaGetErrorEnum(error); |
|
AT_ERROR(ss.str()); |
|
} |
|
} |
|
|
|
}} |
|
|