|
#pragma once |
|
|
|
#include <c10/core/ScalarType.h> |
|
#include <c10/macros/Macros.h> |
|
#include <c10/util/Load.h> |
|
#include <c10/util/TypeCast.h> |
|
|
|
namespace c10 {

// Dynamic type-casting helpers.
//
// These bridge two kinds of value:
//   - one whose element type is known only at runtime, described by a
//     c10::ScalarType tag plus an untyped pointer;
//   - one whose type is a compile-time template parameter.
//
//   - fetch_and_cast<dest_t>(src_type, ptr): loads the element at `ptr` as the
//     C++ type matching `src_type` (via c10::load) and converts it to `dest_t`
//     (via c10::convert).
//   - cast_and_store<src_t>(dest_type, ptr, value): converts `value` to the C++
//     type matching `dest_type` (via c10::convert) and writes it through `ptr`.
//
// Both functions are declared C10_HOST_DEVICE (presumably `__host__ __device__`
// under a CUDA/HIP compiler and empty otherwise — see c10/macros/Macros.h).
//
// Supported tags:
//   - The ScalarTypes enumerated by AT_FORALL_SCALAR_TYPES_WITH_COMPLEX are
//     dispatched; any other tag reaches ERROR_UNSUPPORTED_CAST.
//   - Quantized types (AT_FORALL_QINT_TYPES) get explicit specializations that
//     do no conversion. They only assert that the runtime tag matches the static
//     type, then pass the value through unchanged.

// Statement emitted when a ScalarType tag is not handled. It expands to a
// complete statement, including its trailing semicolon, so call sites write it
// without a `;`.
//
// NOTE(review): C10_HOST_DEVICE is presumably always defined by
// c10/macros/Macros.h, possibly as an empty macro in host-only builds. If so,
// this condition is always false and the TORCH_CHECK branch is dead.
// TORCH_CHECK presumably throws, which device code cannot do. So confirm the
// intended condition before changing it.
#if !defined(C10_HOST_DEVICE)
#define ERROR_UNSUPPORTED_CAST TORCH_CHECK(false, "Unexpected scalar type");
#else
#define ERROR_UNSUPPORTED_CAST CUDA_KERNEL_ASSERT(false);
#endif

// One switch arm for fetch_and_cast. When the runtime tag is `enum_name`, it
// loads a `cpp_type` from `ptr` and converts it to `dest_t`.
#define FETCH_AND_CAST_CASE(cpp_type, enum_name) \
  case ScalarType::enum_name:                    \
    return c10::convert<dest_t>(c10::load<cpp_type>(ptr));

// Reads one element whose dynamic type is `src_type` from `ptr` and returns it
// converted to the static type `dest_t`.
//
// @param src_type  runtime tag describing the element stored at `ptr`.
// @param ptr       address of the element to read.
// @return          the element converted with c10::convert<dest_t>.
//                  For an unhandled tag, ERROR_UNSUPPORTED_CAST fires; if
//                  execution continues past it, dest_t(0) is returned.
template <typename dest_t>
C10_HOST_DEVICE inline dest_t fetch_and_cast(
    const ScalarType src_type,
    const void* ptr) {
  switch (src_type) {
    AT_FORALL_SCALAR_TYPES_WITH_COMPLEX(FETCH_AND_CAST_CASE)
    default:
      break;
  }
  // Only reachable for an unhandled tag, because every case above returns.
  ERROR_UNSUPPORTED_CAST
  // Fallback so that every control path yields a value.
  return dest_t(0);
}

// One switch arm for cast_and_store. When the runtime tag is `enum_name`, it
// converts `value` to `cpp_type`, writes it through `ptr`, and returns.
#define CAST_AND_STORE_CASE(cpp_type, enum_name)                   \
  case ScalarType::enum_name:                                     \
    *static_cast<cpp_type*>(ptr) = c10::convert<cpp_type>(value); \
    return;

// Converts `value` (static type `src_t`) to the type named by `dest_type` and
// stores it at `ptr`.
//
// @param dest_type  runtime tag describing the element to be written at `ptr`.
// @param ptr        destination address.
// @param value      value to convert and store.
// For an unhandled tag, nothing is written and ERROR_UNSUPPORTED_CAST fires.
template <typename src_t>
C10_HOST_DEVICE inline void cast_and_store(
    const ScalarType dest_type,
    void* ptr,
    src_t value) {
  switch (dest_type) {
    AT_FORALL_SCALAR_TYPES_WITH_COMPLEX(CAST_AND_STORE_CASE)
    default:
      ERROR_UNSUPPORTED_CAST
  }
}

// Specializations of fetch_and_cast / cast_and_store for one quantized type.
// No conversion is performed:
//   - the runtime tag must be exactly `qint_tag` (checked with
//     CUDA_KERNEL_ASSERT);
//   - the value is loaded or stored as `qint_t` unchanged.
#define DEFINE_UNCASTABLE(qint_t, qint_tag)                  \
  template <>                                                \
  C10_HOST_DEVICE inline qint_t fetch_and_cast<qint_t>(      \
      const ScalarType src_type, const void* ptr) {          \
    CUDA_KERNEL_ASSERT(src_type == ScalarType::qint_tag);    \
    return c10::load<qint_t>(ptr);                           \
  }                                                          \
  template <>                                                \
  C10_HOST_DEVICE inline void cast_and_store<qint_t>(        \
      const ScalarType dest_type, void* ptr, qint_t value) { \
    CUDA_KERNEL_ASSERT(dest_type == ScalarType::qint_tag);   \
    *static_cast<qint_t*>(ptr) = value;                      \
  }

AT_FORALL_QINT_TYPES(DEFINE_UNCASTABLE)

// The helper macros above are private to this header. Undefine them so they do
// not leak into translation units that include it.
#undef ERROR_UNSUPPORTED_CAST
#undef DEFINE_UNCASTABLE
#undef CAST_AND_STORE_CASE
#undef FETCH_AND_CAST_CASE

} // namespace c10
|
|