|
#pragma once |
|
|
|
#include <assert.h> |
|
#include <stdint.h> |
|
#include <stdexcept> |
|
#include <string> |
|
#include <type_traits> |
|
#include <utility> |
|
|
|
#include <c10/core/OptionalRef.h> |
|
#include <c10/core/ScalarType.h> |
|
#include <c10/core/SymFloat.h> |
|
#include <c10/core/SymInt.h> |
|
#include <c10/macros/Macros.h> |
|
#include <c10/util/Exception.h> |
|
#include <c10/util/Half.h> |
|
#include <c10/util/TypeCast.h> |
|
#include <c10/util/intrusive_ptr.h> |
|
|
|
C10_CLANG_DIAGNOSTIC_PUSH() |
|
#if C10_CLANG_HAS_WARNING("-Wimplicit-int-float-conversion") |
|
C10_CLANG_DIAGNOSTIC_IGNORE("-Wimplicit-int-float-conversion") |
|
#endif |
|
|
|
namespace c10 { |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class C10_API Scalar { |
|
public: |
|
Scalar() : Scalar(int64_t(0)) {} |
|
|
|
void destroy() { |
|
if (Tag::HAS_si == tag || Tag::HAS_sd == tag) { |
|
raw::intrusive_ptr::decref(v.p); |
|
v.p = nullptr; |
|
} |
|
} |
|
|
|
~Scalar() { |
|
destroy(); |
|
} |
|
|
|
#define DEFINE_IMPLICIT_CTOR(type, name) \ |
|
Scalar(type vv) : Scalar(vv, true) {} |
|
|
|
AT_FORALL_SCALAR_TYPES_AND3(Half, BFloat16, ComplexHalf, DEFINE_IMPLICIT_CTOR) |
|
AT_FORALL_COMPLEX_TYPES(DEFINE_IMPLICIT_CTOR) |
|
|
|
#undef DEFINE_IMPLICIT_CTOR |
|
|
|
|
|
|
|
|
|
template < |
|
typename T, |
|
typename std::enable_if<std::is_same<T, bool>::value, bool>::type* = |
|
nullptr> |
|
Scalar(T vv) : tag(Tag::HAS_b) { |
|
v.i = convert<int64_t, bool>(vv); |
|
} |
|
|
|
#define DEFINE_ACCESSOR(type, name) \ |
|
type to##name() const { \ |
|
if (Tag::HAS_d == tag) { \ |
|
return checked_convert<type, double>(v.d, #type); \ |
|
} else if (Tag::HAS_z == tag) { \ |
|
return checked_convert<type, c10::complex<double>>(v.z, #type); \ |
|
} \ |
|
if (Tag::HAS_b == tag) { \ |
|
return checked_convert<type, bool>(v.i, #type); \ |
|
} else if (Tag::HAS_i == tag) { \ |
|
return checked_convert<type, int64_t>(v.i, #type); \ |
|
} else if (Tag::HAS_si == tag) { \ |
|
TORCH_CHECK(false, "tried to get " #name " out of SymInt") \ |
|
} else if (Tag::HAS_sd == tag) { \ |
|
TORCH_CHECK(false, "tried to get " #name " out of SymFloat") \ |
|
} \ |
|
TORCH_CHECK(false) \ |
|
} |
|
|
|
|
|
AT_FORALL_SCALAR_TYPES_WITH_COMPLEX(DEFINE_ACCESSOR) |
|
|
|
#undef DEFINE_ACCESSOR |
|
|
|
SymInt toSymInt() const { |
|
if (Tag::HAS_si == tag) { |
|
return c10::SymInt::toSymInt(intrusive_ptr<SymIntNodeImpl>::reclaim_copy( |
|
static_cast<SymIntNodeImpl*>(v.p))); |
|
} else { |
|
return toLong(); |
|
} |
|
} |
|
|
|
SymFloat toSymFloat() const { |
|
if (Tag::HAS_sd == tag) { |
|
return c10::SymFloat::toSymFloat( |
|
intrusive_ptr<SymFloatNodeImpl>::reclaim_copy( |
|
static_cast<SymFloatNodeImpl*>(v.p))); |
|
} else { |
|
return toLong(); |
|
} |
|
} |
|
|
|
|
|
|
|
template <typename T> |
|
T to() const = delete; |
|
|
|
|
|
const void* data_ptr() const { |
|
TORCH_INTERNAL_ASSERT(!isSymbolic()); |
|
return static_cast<const void*>(&v); |
|
} |
|
|
|
bool isFloatingPoint() const { |
|
return Tag::HAS_d == tag || Tag::HAS_sd == tag; |
|
} |
|
|
|
C10_DEPRECATED_MESSAGE( |
|
"isIntegral is deprecated. Please use the overload with 'includeBool' parameter instead.") |
|
bool isIntegral() const { |
|
return Tag::HAS_i == tag || Tag::HAS_si == tag; |
|
} |
|
bool isIntegral(bool includeBool) const { |
|
return Tag::HAS_i == tag || Tag::HAS_si == tag || |
|
(includeBool && isBoolean()); |
|
} |
|
|
|
bool isComplex() const { |
|
return Tag::HAS_z == tag; |
|
} |
|
bool isBoolean() const { |
|
return Tag::HAS_b == tag; |
|
} |
|
|
|
|
|
bool isSymInt() const { |
|
return Tag::HAS_si == tag; |
|
} |
|
bool isSymFloat() const { |
|
return Tag::HAS_sd == tag; |
|
} |
|
|
|
bool isSymbolic() const { |
|
return Tag::HAS_si == tag || Tag::HAS_sd == tag; |
|
} |
|
|
|
C10_ALWAYS_INLINE Scalar& operator=(Scalar&& other) { |
|
if (&other == this) { |
|
return *this; |
|
} |
|
|
|
destroy(); |
|
moveFrom(std::move(other)); |
|
return *this; |
|
} |
|
|
|
C10_ALWAYS_INLINE Scalar& operator=(const Scalar& other) { |
|
if (&other == this) { |
|
return *this; |
|
} |
|
|
|
*this = Scalar(other); |
|
return *this; |
|
} |
|
|
|
Scalar operator-() const; |
|
Scalar conj() const; |
|
Scalar log() const; |
|
|
|
template < |
|
typename T, |
|
typename std::enable_if<!c10::is_complex<T>::value, int>::type = 0> |
|
bool equal(T num) const { |
|
if (isComplex()) { |
|
TORCH_INTERNAL_ASSERT(!isSymbolic()); |
|
auto val = v.z; |
|
return (val.real() == num) && (val.imag() == T()); |
|
} else if (isFloatingPoint()) { |
|
TORCH_CHECK(!isSymbolic(), "NYI SymFloat equality"); |
|
return v.d == num; |
|
} else if (isIntegral(false)) { |
|
TORCH_CHECK(!isSymbolic(), "NYI SymInt equality"); |
|
return v.i == num; |
|
} else if (isBoolean()) { |
|
|
|
TORCH_INTERNAL_ASSERT(!isSymbolic()); |
|
return false; |
|
} else { |
|
TORCH_INTERNAL_ASSERT(false); |
|
} |
|
} |
|
|
|
template < |
|
typename T, |
|
typename std::enable_if<c10::is_complex<T>::value, int>::type = 0> |
|
bool equal(T num) const { |
|
if (isComplex()) { |
|
TORCH_INTERNAL_ASSERT(!isSymbolic()); |
|
return v.z == num; |
|
} else if (isFloatingPoint()) { |
|
TORCH_CHECK(!isSymbolic(), "NYI SymFloat equality"); |
|
return (v.d == num.real()) && (num.imag() == T()); |
|
} else if (isIntegral(false)) { |
|
TORCH_CHECK(!isSymbolic(), "NYI SymInt equality"); |
|
return (v.i == num.real()) && (num.imag() == T()); |
|
} else if (isBoolean()) { |
|
|
|
TORCH_INTERNAL_ASSERT(!isSymbolic()); |
|
return false; |
|
} else { |
|
TORCH_INTERNAL_ASSERT(false); |
|
} |
|
} |
|
|
|
bool equal(bool num) const { |
|
if (isBoolean()) { |
|
TORCH_INTERNAL_ASSERT(!isSymbolic()); |
|
return static_cast<bool>(v.i) == num; |
|
} else { |
|
return false; |
|
} |
|
} |
|
|
|
ScalarType type() const { |
|
if (isComplex()) { |
|
return ScalarType::ComplexDouble; |
|
} else if (isFloatingPoint()) { |
|
return ScalarType::Double; |
|
} else if (isIntegral(false)) { |
|
return ScalarType::Long; |
|
} else if (isBoolean()) { |
|
return ScalarType::Bool; |
|
} else { |
|
throw std::runtime_error("Unknown scalar type."); |
|
} |
|
} |
|
|
|
Scalar(Scalar&& rhs) noexcept : tag(rhs.tag) { |
|
moveFrom(std::move(rhs)); |
|
} |
|
|
|
Scalar(const Scalar& rhs) : tag(rhs.tag), v(rhs.v) { |
|
if (isSymbolic()) { |
|
c10::raw::intrusive_ptr::incref(v.p); |
|
} |
|
} |
|
|
|
Scalar(c10::SymInt si) { |
|
if (si.is_symbolic()) { |
|
tag = Tag::HAS_si; |
|
v.p = std::move(si).release(); |
|
} else { |
|
tag = Tag::HAS_i; |
|
v.i = si.as_int_unchecked(); |
|
} |
|
} |
|
|
|
Scalar(c10::SymFloat sd) { |
|
if (sd.is_symbolic()) { |
|
tag = Tag::HAS_sd; |
|
v.p = std::move(sd).release(); |
|
} else { |
|
tag = Tag::HAS_d; |
|
v.d = sd.as_float_unchecked(); |
|
} |
|
} |
|
|
|
|
|
|
|
private: |
|
enum class Tag { HAS_d, HAS_i, HAS_z, HAS_b, HAS_sd, HAS_si }; |
|
|
|
|
|
C10_ALWAYS_INLINE void moveFrom(Scalar&& rhs) noexcept { |
|
v = rhs.v; |
|
tag = rhs.tag; |
|
if (rhs.tag == Tag::HAS_si || rhs.tag == Tag::HAS_sd) { |
|
|
|
rhs.tag = Tag::HAS_i; |
|
rhs.v.i = 0; |
|
} |
|
} |
|
|
|
Tag tag; |
|
|
|
union v_t { |
|
double d; |
|
int64_t i; |
|
c10::complex<double> z; |
|
c10::intrusive_ptr_target* p; |
|
v_t() {} |
|
} v; |
|
|
|
template < |
|
typename T, |
|
typename std::enable_if< |
|
std::is_integral<T>::value && !std::is_same<T, bool>::value, |
|
bool>::type* = nullptr> |
|
Scalar(T vv, bool) : tag(Tag::HAS_i) { |
|
v.i = convert<decltype(v.i), T>(vv); |
|
} |
|
|
|
template < |
|
typename T, |
|
typename std::enable_if< |
|
!std::is_integral<T>::value && !c10::is_complex<T>::value, |
|
bool>::type* = nullptr> |
|
Scalar(T vv, bool) : tag(Tag::HAS_d) { |
|
v.d = convert<decltype(v.d), T>(vv); |
|
} |
|
|
|
template < |
|
typename T, |
|
typename std::enable_if<c10::is_complex<T>::value, bool>::type* = nullptr> |
|
Scalar(T vv, bool) : tag(Tag::HAS_z) { |
|
v.z = convert<decltype(v.z), T>(vv); |
|
} |
|
}; |
|
|
|
using OptionalScalarRef = c10::OptionalRef<Scalar>; |
|
|
|
|
|
#define DEFINE_TO(T, name) \ |
|
template <> \ |
|
inline T Scalar::to<T>() const { \ |
|
return to##name(); \ |
|
} |
|
AT_FORALL_SCALAR_TYPES_WITH_COMPLEX(DEFINE_TO) |
|
#undef DEFINE_TO |
|
|
|
} |
|
|
|
C10_CLANG_DIAGNOSTIC_POP() |
|
|