|
#pragma once |
|
|
|
#include <complex> |
|
|
|
#include <c10/macros/Macros.h> |
|
|
|
#if defined(__CUDACC__) || defined(__HIPCC__) |
|
#include <thrust/complex.h> |
|
#endif |
|
|
|
C10_CLANG_DIAGNOSTIC_PUSH() |
|
#if C10_CLANG_HAS_WARNING("-Wimplicit-float-conversion") |
|
C10_CLANG_DIAGNOSTIC_IGNORE("-Wimplicit-float-conversion") |
|
#endif |
|
#if C10_CLANG_HAS_WARNING("-Wfloat-conversion") |
|
C10_CLANG_DIAGNOSTIC_IGNORE("-Wfloat-conversion") |
|
#endif |
|
|
|
namespace c10 { |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
template <typename T> |
|
struct alignas(sizeof(T) * 2) complex { |
|
using value_type = T; |
|
|
|
T real_ = T(0); |
|
T imag_ = T(0); |
|
|
|
constexpr complex() = default; |
|
C10_HOST_DEVICE constexpr complex(const T& re, const T& im = T()) |
|
: real_(re), imag_(im) {} |
|
template <typename U> |
|
explicit constexpr complex(const std::complex<U>& other) |
|
: complex(other.real(), other.imag()) {} |
|
#if defined(__CUDACC__) || defined(__HIPCC__) |
|
template <typename U> |
|
explicit C10_HOST_DEVICE complex(const thrust::complex<U>& other) |
|
: real_(other.real()), imag_(other.imag()) {} |
|
|
|
|
|
|
|
#endif |
|
|
|
|
|
|
|
template <typename U = T> |
|
C10_HOST_DEVICE explicit constexpr complex( |
|
const std::enable_if_t<std::is_same<U, float>::value, complex<double>>& |
|
other) |
|
: real_(other.real_), imag_(other.imag_) {} |
|
template <typename U = T> |
|
C10_HOST_DEVICE constexpr complex( |
|
const std::enable_if_t<std::is_same<U, double>::value, complex<float>>& |
|
other) |
|
: real_(other.real_), imag_(other.imag_) {} |
|
|
|
constexpr complex<T>& operator=(T re) { |
|
real_ = re; |
|
imag_ = 0; |
|
return *this; |
|
} |
|
|
|
constexpr complex<T>& operator+=(T re) { |
|
real_ += re; |
|
return *this; |
|
} |
|
|
|
constexpr complex<T>& operator-=(T re) { |
|
real_ -= re; |
|
return *this; |
|
} |
|
|
|
constexpr complex<T>& operator*=(T re) { |
|
real_ *= re; |
|
imag_ *= re; |
|
return *this; |
|
} |
|
|
|
constexpr complex<T>& operator/=(T re) { |
|
real_ /= re; |
|
imag_ /= re; |
|
return *this; |
|
} |
|
|
|
template <typename U> |
|
constexpr complex<T>& operator=(const complex<U>& rhs) { |
|
real_ = rhs.real(); |
|
imag_ = rhs.imag(); |
|
return *this; |
|
} |
|
|
|
template <typename U> |
|
constexpr complex<T>& operator+=(const complex<U>& rhs) { |
|
real_ += rhs.real(); |
|
imag_ += rhs.imag(); |
|
return *this; |
|
} |
|
|
|
template <typename U> |
|
constexpr complex<T>& operator-=(const complex<U>& rhs) { |
|
real_ -= rhs.real(); |
|
imag_ -= rhs.imag(); |
|
return *this; |
|
} |
|
|
|
template <typename U> |
|
constexpr complex<T>& operator*=(const complex<U>& rhs) { |
|
|
|
T a = real_; |
|
T b = imag_; |
|
U c = rhs.real(); |
|
U d = rhs.imag(); |
|
real_ = a * c - b * d; |
|
imag_ = a * d + b * c; |
|
return *this; |
|
} |
|
|
|
#ifdef __APPLE__ |
|
#define FORCE_INLINE_APPLE __attribute__((always_inline)) |
|
#else |
|
#define FORCE_INLINE_APPLE |
|
#endif |
|
template <typename U> |
|
constexpr FORCE_INLINE_APPLE complex<T>& operator/=(const complex<U>& rhs) |
|
__ubsan_ignore_float_divide_by_zero__ { |
|
|
|
T a = real_; |
|
T b = imag_; |
|
U c = rhs.real(); |
|
U d = rhs.imag(); |
|
auto denominator = c * c + d * d; |
|
real_ = (a * c + b * d) / denominator; |
|
imag_ = (b * c - a * d) / denominator; |
|
return *this; |
|
} |
|
#undef FORCE_INLINE_APPLE |
|
|
|
template <typename U> |
|
constexpr complex<T>& operator=(const std::complex<U>& rhs) { |
|
real_ = rhs.real(); |
|
imag_ = rhs.imag(); |
|
return *this; |
|
} |
|
|
|
#if defined(__CUDACC__) || defined(__HIPCC__) |
|
template <typename U> |
|
C10_HOST_DEVICE complex<T>& operator=(const thrust::complex<U>& rhs) { |
|
real_ = rhs.real(); |
|
imag_ = rhs.imag(); |
|
return *this; |
|
} |
|
#endif |
|
|
|
template <typename U> |
|
explicit constexpr operator std::complex<U>() const { |
|
return std::complex<U>(std::complex<T>(real(), imag())); |
|
} |
|
|
|
#if defined(__CUDACC__) || defined(__HIPCC__) |
|
template <typename U> |
|
C10_HOST_DEVICE explicit operator thrust::complex<U>() const { |
|
return static_cast<thrust::complex<U>>(thrust::complex<T>(real(), imag())); |
|
} |
|
#endif |
|
|
|
|
|
explicit constexpr operator bool() const { |
|
return real() || imag(); |
|
} |
|
|
|
C10_HOST_DEVICE constexpr T real() const { |
|
return real_; |
|
} |
|
constexpr void real(T value) { |
|
real_ = value; |
|
} |
|
constexpr T imag() const { |
|
return imag_; |
|
} |
|
constexpr void imag(T value) { |
|
imag_ = value; |
|
} |
|
}; |
|
|
|
namespace complex_literals { |
|
|
|
constexpr complex<float> operator"" _if(long double imag) { |
|
return complex<float>(0.0f, static_cast<float>(imag)); |
|
} |
|
|
|
constexpr complex<double> operator"" _id(long double imag) { |
|
return complex<double>(0.0, static_cast<double>(imag)); |
|
} |
|
|
|
constexpr complex<float> operator"" _if(unsigned long long imag) { |
|
return complex<float>(0.0f, static_cast<float>(imag)); |
|
} |
|
|
|
constexpr complex<double> operator"" _id(unsigned long long imag) { |
|
return complex<double>(0.0, static_cast<double>(imag)); |
|
} |
|
|
|
} |
|
|
|
template <typename T> |
|
constexpr complex<T> operator+(const complex<T>& val) { |
|
return val; |
|
} |
|
|
|
template <typename T> |
|
constexpr complex<T> operator-(const complex<T>& val) { |
|
return complex<T>(-val.real(), -val.imag()); |
|
} |
|
|
|
template <typename T> |
|
constexpr complex<T> operator+(const complex<T>& lhs, const complex<T>& rhs) { |
|
complex<T> result = lhs; |
|
return result += rhs; |
|
} |
|
|
|
template <typename T> |
|
constexpr complex<T> operator+(const complex<T>& lhs, const T& rhs) { |
|
complex<T> result = lhs; |
|
return result += rhs; |
|
} |
|
|
|
template <typename T> |
|
constexpr complex<T> operator+(const T& lhs, const complex<T>& rhs) { |
|
return complex<T>(lhs + rhs.real(), rhs.imag()); |
|
} |
|
|
|
template <typename T> |
|
constexpr complex<T> operator-(const complex<T>& lhs, const complex<T>& rhs) { |
|
complex<T> result = lhs; |
|
return result -= rhs; |
|
} |
|
|
|
template <typename T> |
|
constexpr complex<T> operator-(const complex<T>& lhs, const T& rhs) { |
|
complex<T> result = lhs; |
|
return result -= rhs; |
|
} |
|
|
|
template <typename T> |
|
constexpr complex<T> operator-(const T& lhs, const complex<T>& rhs) { |
|
complex<T> result = -rhs; |
|
return result += lhs; |
|
} |
|
|
|
template <typename T> |
|
constexpr complex<T> operator*(const complex<T>& lhs, const complex<T>& rhs) { |
|
complex<T> result = lhs; |
|
return result *= rhs; |
|
} |
|
|
|
template <typename T> |
|
constexpr complex<T> operator*(const complex<T>& lhs, const T& rhs) { |
|
complex<T> result = lhs; |
|
return result *= rhs; |
|
} |
|
|
|
template <typename T> |
|
constexpr complex<T> operator*(const T& lhs, const complex<T>& rhs) { |
|
complex<T> result = rhs; |
|
return result *= lhs; |
|
} |
|
|
|
template <typename T> |
|
constexpr complex<T> operator/(const complex<T>& lhs, const complex<T>& rhs) { |
|
complex<T> result = lhs; |
|
return result /= rhs; |
|
} |
|
|
|
template <typename T> |
|
constexpr complex<T> operator/(const complex<T>& lhs, const T& rhs) { |
|
complex<T> result = lhs; |
|
return result /= rhs; |
|
} |
|
|
|
template <typename T> |
|
constexpr complex<T> operator/(const T& lhs, const complex<T>& rhs) { |
|
complex<T> result(lhs, T()); |
|
return result /= rhs; |
|
} |
|
|
|
|
|
|
|
|
|
|
|
#define COMPLEX_INTEGER_OP_TEMPLATE_CONDITION \ |
|
typename std::enable_if_t< \ |
|
std::is_floating_point<fT>::value && std::is_integral<iT>::value, \ |
|
int> = 0 |
|
|
|
template <typename fT, typename iT, COMPLEX_INTEGER_OP_TEMPLATE_CONDITION> |
|
constexpr c10::complex<fT> operator+(const c10::complex<fT>& a, const iT& b) { |
|
return a + static_cast<fT>(b); |
|
} |
|
|
|
template <typename fT, typename iT, COMPLEX_INTEGER_OP_TEMPLATE_CONDITION> |
|
constexpr c10::complex<fT> operator+(const iT& a, const c10::complex<fT>& b) { |
|
return static_cast<fT>(a) + b; |
|
} |
|
|
|
template <typename fT, typename iT, COMPLEX_INTEGER_OP_TEMPLATE_CONDITION> |
|
constexpr c10::complex<fT> operator-(const c10::complex<fT>& a, const iT& b) { |
|
return a - static_cast<fT>(b); |
|
} |
|
|
|
template <typename fT, typename iT, COMPLEX_INTEGER_OP_TEMPLATE_CONDITION> |
|
constexpr c10::complex<fT> operator-(const iT& a, const c10::complex<fT>& b) { |
|
return static_cast<fT>(a) - b; |
|
} |
|
|
|
template <typename fT, typename iT, COMPLEX_INTEGER_OP_TEMPLATE_CONDITION> |
|
constexpr c10::complex<fT> operator*(const c10::complex<fT>& a, const iT& b) { |
|
return a * static_cast<fT>(b); |
|
} |
|
|
|
template <typename fT, typename iT, COMPLEX_INTEGER_OP_TEMPLATE_CONDITION> |
|
constexpr c10::complex<fT> operator*(const iT& a, const c10::complex<fT>& b) { |
|
return static_cast<fT>(a) * b; |
|
} |
|
|
|
template <typename fT, typename iT, COMPLEX_INTEGER_OP_TEMPLATE_CONDITION> |
|
constexpr c10::complex<fT> operator/(const c10::complex<fT>& a, const iT& b) { |
|
return a / static_cast<fT>(b); |
|
} |
|
|
|
template <typename fT, typename iT, COMPLEX_INTEGER_OP_TEMPLATE_CONDITION> |
|
constexpr c10::complex<fT> operator/(const iT& a, const c10::complex<fT>& b) { |
|
return static_cast<fT>(a) / b; |
|
} |
|
|
|
#undef COMPLEX_INTEGER_OP_TEMPLATE_CONDITION |
|
|
|
template <typename T> |
|
constexpr bool operator==(const complex<T>& lhs, const complex<T>& rhs) { |
|
return (lhs.real() == rhs.real()) && (lhs.imag() == rhs.imag()); |
|
} |
|
|
|
template <typename T> |
|
constexpr bool operator==(const complex<T>& lhs, const T& rhs) { |
|
return (lhs.real() == rhs) && (lhs.imag() == T()); |
|
} |
|
|
|
template <typename T> |
|
constexpr bool operator==(const T& lhs, const complex<T>& rhs) { |
|
return (lhs == rhs.real()) && (T() == rhs.imag()); |
|
} |
|
|
|
template <typename T> |
|
constexpr bool operator!=(const complex<T>& lhs, const complex<T>& rhs) { |
|
return !(lhs == rhs); |
|
} |
|
|
|
template <typename T> |
|
constexpr bool operator!=(const complex<T>& lhs, const T& rhs) { |
|
return !(lhs == rhs); |
|
} |
|
|
|
template <typename T> |
|
constexpr bool operator!=(const T& lhs, const complex<T>& rhs) { |
|
return !(lhs == rhs); |
|
} |
|
|
|
template <typename T, typename CharT, typename Traits> |
|
std::basic_ostream<CharT, Traits>& operator<<( |
|
std::basic_ostream<CharT, Traits>& os, |
|
const complex<T>& x) { |
|
return (os << static_cast<std::complex<T>>(x)); |
|
} |
|
|
|
template <typename T, typename CharT, typename Traits> |
|
std::basic_istream<CharT, Traits>& operator>>( |
|
std::basic_istream<CharT, Traits>& is, |
|
complex<T>& x) { |
|
std::complex<T> tmp; |
|
is >> tmp; |
|
x = tmp; |
|
return is; |
|
} |
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
#if defined(__CUDACC__) || defined(__HIPCC__)
namespace c10_internal {
// Converts a c10::complex<T> to a thrust::complex<T> with the same real and
// imaginary parts. Only compiled under nvcc/hipcc.
//
// The name says this works around a CUDA 10.1 problem. When CUDA_VERSION is
// below 10020 (i.e. before CUDA 10.2), the thrust::complex is built
// component-wise from real()/imag(). Otherwise it uses static_cast, which
// goes through c10::complex's explicit conversion operator to
// thrust::complex<U>.
// NOTE(review): the exact compiler defect is not visible here; confirm it is
// gone on all supported toolchains before collapsing the two branches.
template <typename T>
C10_HOST_DEVICE constexpr thrust::complex<T>
cuda101bug_cast_c10_complex_to_thrust_complex(const c10::complex<T>& x) {
#if defined(CUDA_VERSION) && (CUDA_VERSION < 10020)
  // Pre-CUDA-10.2 path: construct the thrust value component-wise.
  return thrust::complex<T>(x.real(), x.imag());
#else
  return static_cast<thrust::complex<T>>(x);
#endif
}
} // namespace c10_internal
#endif
|
|
|
namespace std { |
|
|
|
template <typename T> |
|
constexpr T real(const c10::complex<T>& z) { |
|
return z.real(); |
|
} |
|
|
|
template <typename T> |
|
constexpr T imag(const c10::complex<T>& z) { |
|
return z.imag(); |
|
} |
|
|
|
template <typename T> |
|
C10_HOST_DEVICE T abs(const c10::complex<T>& z) { |
|
#if defined(__CUDACC__) || defined(__HIPCC__) |
|
return thrust::abs( |
|
c10_internal::cuda101bug_cast_c10_complex_to_thrust_complex(z)); |
|
#else |
|
return std::abs(static_cast<std::complex<T>>(z)); |
|
#endif |
|
} |
|
|
|
#if defined(USE_ROCM) |
|
#define ROCm_Bug(x) |
|
#else |
|
#define ROCm_Bug(x) x |
|
#endif |
|
|
|
template <typename T> |
|
C10_HOST_DEVICE T arg(const c10::complex<T>& z) { |
|
return ROCm_Bug(std)::atan2(std::imag(z), std::real(z)); |
|
} |
|
|
|
#undef ROCm_Bug |
|
|
|
template <typename T> |
|
constexpr T norm(const c10::complex<T>& z) { |
|
return z.real() * z.real() + z.imag() * z.imag(); |
|
} |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
template <typename T> |
|
constexpr c10::complex<T> conj(const c10::complex<T>& z) { |
|
return c10::complex<T>(z.real(), -z.imag()); |
|
} |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
} |
|
|
|
namespace c10 { |
|
|
|
template <typename T> |
|
C10_HOST_DEVICE complex<T> polar(const T& r, const T& theta = T()) { |
|
#if defined(__CUDACC__) || defined(__HIPCC__) |
|
return static_cast<complex<T>>(thrust::polar(r, theta)); |
|
#else |
|
|
|
|
|
return complex<T>(r * std::cos(theta), r * std::sin(theta)); |
|
#endif |
|
} |
|
|
|
} |
|
|
|
C10_CLANG_DIAGNOSTIC_POP() |
|
|
|
#define C10_INTERNAL_INCLUDE_COMPLEX_REMAINING_H |
|
|
|
#include <c10/util/complex_math.h> |
|
|
|
#include <c10/util/complex_utils.h> |
|
#undef C10_INTERNAL_INCLUDE_COMPLEX_REMAINING_H |
|
|