#pragma once #include #include #if defined(__CUDACC__) || defined(__HIPCC__) #include #endif C10_CLANG_DIAGNOSTIC_PUSH() #if C10_CLANG_HAS_WARNING("-Wimplicit-float-conversion") C10_CLANG_DIAGNOSTIC_IGNORE("-Wimplicit-float-conversion") #endif #if C10_CLANG_HAS_WARNING("-Wfloat-conversion") C10_CLANG_DIAGNOSTIC_IGNORE("-Wfloat-conversion") #endif namespace c10 { // c10::complex is an implementation of complex numbers that aims // to work on all devices supported by PyTorch // // Most of the APIs duplicates std::complex // Reference: https://en.cppreference.com/w/cpp/numeric/complex // // [NOTE: Complex Operator Unification] // Operators currently use a mix of std::complex, thrust::complex, and // c10::complex internally. The end state is that all operators will use // c10::complex internally. Until then, there may be some hacks to support all // variants. // // // [Note on Constructors] // // The APIs of constructors are mostly copied from C++ standard: // https://en.cppreference.com/w/cpp/numeric/complex/complex // // Since C++14, all constructors are constexpr in std::complex // // There are three types of constructors: // - initializing from real and imag: // `constexpr complex( const T& re = T(), const T& im = T() );` // - implicitly-declared copy constructor // - converting constructors // // Converting constructors: // - std::complex defines converting constructor between float/double/long // double, // while we define converting constructor between float/double. // - For these converting constructors, upcasting is implicit, downcasting is // explicit. // - We also define explicit casting from std::complex/thrust::complex // - Note that the conversion from thrust is not constexpr, because // thrust does not define them as constexpr ???? // // // [Operator =] // // The APIs of operator = are mostly copied from C++ standard: // https://en.cppreference.com/w/cpp/numeric/complex/operator%3D // // Since C++20, all operator= are constexpr. Although we are not building with // C++20, we also obey this behavior. // // There are three types of assign operator: // - Assign a real value from the same scalar type // - In std, this is templated as complex& operator=(const T& x) // with specialization `complex& operator=(T x)` for float/double/long // double Since we only support float and double, on will use `complex& // operator=(T x)` // - Copy assignment operator and converting assignment operator // - There is no specialization of converting assignment operators, which type // is // convertible is solely dependent on whether the scalar type is convertible // // In addition to the standard assignment, we also provide assignment operators // with std and thrust // // // [Casting operators] // // std::complex does not have casting operators. We define casting operators // casting to std::complex and thrust::complex // // // [Operator ""] // // std::complex has custom literals `i`, `if` and `il` defined in namespace // `std::literals::complex_literals`. We define our own custom literals in the // namespace `c10::complex_literals`. Our custom literals does not follow the // same behavior as in std::complex, instead, we define _if, _id to construct // float/double complex literals. // // // [real() and imag()] // // In C++20, there are two overload of these functions, one it to return the // real/imag, another is to set real/imag, they are both constexpr. We follow // this design. // // // [Operator +=,-=,*=,/=] // // Since C++20, these operators become constexpr. In our implementation, they // are also constexpr. // // There are two types of such operators: operating with a real number, or // operating with another complex number. For the operating with a real number, // the generic template form has argument type `const T &`, while the overload // for float/double/long double has `T`. We will follow the same type as // float/double/long double in std. // // [Unary operator +-] // // Since C++20, they are constexpr. We also make them expr // // [Binary operators +-*/] // // Each operator has three versions (taking + as example): // - complex + complex // - complex + real // - real + complex // // [Operator ==, !=] // // Each operator has three versions (taking == as example): // - complex == complex // - complex == real // - real == complex // // Some of them are removed on C++20, but we decide to keep them // // [Operator <<, >>] // // These are implemented by casting to std::complex // // // // TODO(@zasdfgbnm): c10::complex is not currently supported, // because: // - lots of members and functions of c10::Half are not constexpr // - thrust::complex only support float and double template struct alignas(sizeof(T) * 2) complex { using value_type = T; T real_ = T(0); T imag_ = T(0); constexpr complex() = default; C10_HOST_DEVICE constexpr complex(const T& re, const T& im = T()) : real_(re), imag_(im) {} template explicit constexpr complex(const std::complex& other) : complex(other.real(), other.imag()) {} #if defined(__CUDACC__) || defined(__HIPCC__) template explicit C10_HOST_DEVICE complex(const thrust::complex& other) : real_(other.real()), imag_(other.imag()) {} // NOTE can not be implemented as follow due to ROCm bug: // explicit C10_HOST_DEVICE complex(const thrust::complex &other): // complex(other.real(), other.imag()) {} #endif // Use SFINAE to specialize casting constructor for c10::complex and // c10::complex template C10_HOST_DEVICE explicit constexpr complex( const std::enable_if_t::value, complex>& other) : real_(other.real_), imag_(other.imag_) {} template C10_HOST_DEVICE constexpr complex( const std::enable_if_t::value, complex>& other) : real_(other.real_), imag_(other.imag_) {} constexpr complex& operator=(T re) { real_ = re; imag_ = 0; return *this; } constexpr complex& operator+=(T re) { real_ += re; return *this; } constexpr complex& operator-=(T re) { real_ -= re; return *this; } constexpr complex& operator*=(T re) { real_ *= re; imag_ *= re; return *this; } constexpr complex& operator/=(T re) { real_ /= re; imag_ /= re; return *this; } template constexpr complex& operator=(const complex& rhs) { real_ = rhs.real(); imag_ = rhs.imag(); return *this; } template constexpr complex& operator+=(const complex& rhs) { real_ += rhs.real(); imag_ += rhs.imag(); return *this; } template constexpr complex& operator-=(const complex& rhs) { real_ -= rhs.real(); imag_ -= rhs.imag(); return *this; } template constexpr complex& operator*=(const complex& rhs) { // (a + bi) * (c + di) = (a*c - b*d) + (a * d + b * c) i T a = real_; T b = imag_; U c = rhs.real(); U d = rhs.imag(); real_ = a * c - b * d; imag_ = a * d + b * c; return *this; } #ifdef __APPLE__ #define FORCE_INLINE_APPLE __attribute__((always_inline)) #else #define FORCE_INLINE_APPLE #endif template constexpr FORCE_INLINE_APPLE complex& operator/=(const complex& rhs) __ubsan_ignore_float_divide_by_zero__ { // (a + bi) / (c + di) = (ac + bd)/(c^2 + d^2) + (bc - ad)/(c^2 + d^2) i T a = real_; T b = imag_; U c = rhs.real(); U d = rhs.imag(); auto denominator = c * c + d * d; real_ = (a * c + b * d) / denominator; imag_ = (b * c - a * d) / denominator; return *this; } #undef FORCE_INLINE_APPLE template constexpr complex& operator=(const std::complex& rhs) { real_ = rhs.real(); imag_ = rhs.imag(); return *this; } #if defined(__CUDACC__) || defined(__HIPCC__) template C10_HOST_DEVICE complex& operator=(const thrust::complex& rhs) { real_ = rhs.real(); imag_ = rhs.imag(); return *this; } #endif template explicit constexpr operator std::complex() const { return std::complex(std::complex(real(), imag())); } #if defined(__CUDACC__) || defined(__HIPCC__) template C10_HOST_DEVICE explicit operator thrust::complex() const { return static_cast>(thrust::complex(real(), imag())); } #endif // consistent with NumPy behavior explicit constexpr operator bool() const { return real() || imag(); } C10_HOST_DEVICE constexpr T real() const { return real_; } constexpr void real(T value) { real_ = value; } constexpr T imag() const { return imag_; } constexpr void imag(T value) { imag_ = value; } }; namespace complex_literals { constexpr complex operator"" _if(long double imag) { return complex(0.0f, static_cast(imag)); } constexpr complex operator"" _id(long double imag) { return complex(0.0, static_cast(imag)); } constexpr complex operator"" _if(unsigned long long imag) { return complex(0.0f, static_cast(imag)); } constexpr complex operator"" _id(unsigned long long imag) { return complex(0.0, static_cast(imag)); } } // namespace complex_literals template constexpr complex operator+(const complex& val) { return val; } template constexpr complex operator-(const complex& val) { return complex(-val.real(), -val.imag()); } template constexpr complex operator+(const complex& lhs, const complex& rhs) { complex result = lhs; return result += rhs; } template constexpr complex operator+(const complex& lhs, const T& rhs) { complex result = lhs; return result += rhs; } template constexpr complex operator+(const T& lhs, const complex& rhs) { return complex(lhs + rhs.real(), rhs.imag()); } template constexpr complex operator-(const complex& lhs, const complex& rhs) { complex result = lhs; return result -= rhs; } template constexpr complex operator-(const complex& lhs, const T& rhs) { complex result = lhs; return result -= rhs; } template constexpr complex operator-(const T& lhs, const complex& rhs) { complex result = -rhs; return result += lhs; } template constexpr complex operator*(const complex& lhs, const complex& rhs) { complex result = lhs; return result *= rhs; } template constexpr complex operator*(const complex& lhs, const T& rhs) { complex result = lhs; return result *= rhs; } template constexpr complex operator*(const T& lhs, const complex& rhs) { complex result = rhs; return result *= lhs; } template constexpr complex operator/(const complex& lhs, const complex& rhs) { complex result = lhs; return result /= rhs; } template constexpr complex operator/(const complex& lhs, const T& rhs) { complex result = lhs; return result /= rhs; } template constexpr complex operator/(const T& lhs, const complex& rhs) { complex result(lhs, T()); return result /= rhs; } // Define operators between integral scalars and c10::complex. std::complex does // not support this when T is a floating-point number. This is useful because it // saves a lot of "static_cast" when operate a complex and an integer. This // makes the code both less verbose and potentially more efficient. #define COMPLEX_INTEGER_OP_TEMPLATE_CONDITION \ typename std::enable_if_t< \ std::is_floating_point::value && std::is_integral::value, \ int> = 0 template constexpr c10::complex operator+(const c10::complex& a, const iT& b) { return a + static_cast(b); } template constexpr c10::complex operator+(const iT& a, const c10::complex& b) { return static_cast(a) + b; } template constexpr c10::complex operator-(const c10::complex& a, const iT& b) { return a - static_cast(b); } template constexpr c10::complex operator-(const iT& a, const c10::complex& b) { return static_cast(a) - b; } template constexpr c10::complex operator*(const c10::complex& a, const iT& b) { return a * static_cast(b); } template constexpr c10::complex operator*(const iT& a, const c10::complex& b) { return static_cast(a) * b; } template constexpr c10::complex operator/(const c10::complex& a, const iT& b) { return a / static_cast(b); } template constexpr c10::complex operator/(const iT& a, const c10::complex& b) { return static_cast(a) / b; } #undef COMPLEX_INTEGER_OP_TEMPLATE_CONDITION template constexpr bool operator==(const complex& lhs, const complex& rhs) { return (lhs.real() == rhs.real()) && (lhs.imag() == rhs.imag()); } template constexpr bool operator==(const complex& lhs, const T& rhs) { return (lhs.real() == rhs) && (lhs.imag() == T()); } template constexpr bool operator==(const T& lhs, const complex& rhs) { return (lhs == rhs.real()) && (T() == rhs.imag()); } template constexpr bool operator!=(const complex& lhs, const complex& rhs) { return !(lhs == rhs); } template constexpr bool operator!=(const complex& lhs, const T& rhs) { return !(lhs == rhs); } template constexpr bool operator!=(const T& lhs, const complex& rhs) { return !(lhs == rhs); } template std::basic_ostream& operator<<( std::basic_ostream& os, const complex& x) { return (os << static_cast>(x)); } template std::basic_istream& operator>>( std::basic_istream& is, complex& x) { std::complex tmp; is >> tmp; x = tmp; return is; } } // namespace c10 // std functions // // The implementation of these functions also follow the design of C++20 #if defined(__CUDACC__) || defined(__HIPCC__) namespace c10_internal { template C10_HOST_DEVICE constexpr thrust::complex cuda101bug_cast_c10_complex_to_thrust_complex(const c10::complex& x) { #if defined(CUDA_VERSION) && (CUDA_VERSION < 10020) // This is to circumvent a CUDA compilation bug. See // https://github.com/pytorch/pytorch/pull/38941 . When the bug is fixed, we // should do static_cast directly. return thrust::complex(x.real(), x.imag()); #else return static_cast>(x); #endif } } // namespace c10_internal #endif namespace std { template constexpr T real(const c10::complex& z) { return z.real(); } template constexpr T imag(const c10::complex& z) { return z.imag(); } template C10_HOST_DEVICE T abs(const c10::complex& z) { #if defined(__CUDACC__) || defined(__HIPCC__) return thrust::abs( c10_internal::cuda101bug_cast_c10_complex_to_thrust_complex(z)); #else return std::abs(static_cast>(z)); #endif } #if defined(USE_ROCM) #define ROCm_Bug(x) #else #define ROCm_Bug(x) x #endif template C10_HOST_DEVICE T arg(const c10::complex& z) { return ROCm_Bug(std)::atan2(std::imag(z), std::real(z)); } #undef ROCm_Bug template constexpr T norm(const c10::complex& z) { return z.real() * z.real() + z.imag() * z.imag(); } // For std::conj, there are other versions of it: // constexpr std::complex conj( float z ); // template< class DoubleOrInteger > // constexpr std::complex conj( DoubleOrInteger z ); // constexpr std::complex conj( long double z ); // These are not implemented // TODO(@zasdfgbnm): implement them as c10::conj template constexpr c10::complex conj(const c10::complex& z) { return c10::complex(z.real(), -z.imag()); } // Thrust does not have complex --> complex version of thrust::proj, // so this function is not implemented at c10 right now. // TODO(@zasdfgbnm): implement it by ourselves // There is no c10 version of std::polar, because std::polar always // returns std::complex. Use c10::polar instead; } // namespace std namespace c10 { template C10_HOST_DEVICE complex polar(const T& r, const T& theta = T()) { #if defined(__CUDACC__) || defined(__HIPCC__) return static_cast>(thrust::polar(r, theta)); #else // std::polar() requires r >= 0, so spell out the explicit implementation to // avoid a branch. return complex(r * std::cos(theta), r * std::sin(theta)); #endif } } // namespace c10 C10_CLANG_DIAGNOSTIC_POP() #define C10_INTERNAL_INCLUDE_COMPLEX_REMAINING_H // math functions are included in a separate file #include // IWYU pragma: keep // utilities for complex types #include // IWYU pragma: keep #undef C10_INTERNAL_INCLUDE_COMPLEX_REMAINING_H