|
#pragma once |
|
|
|
#include <c10/util/TypeTraits.h> |
|
|
|
namespace c10 { |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
template <class FuncType_, FuncType_* func_ptr_> |
|
struct CompileTimeFunctionPointer final { |
|
static_assert( |
|
guts::is_function_type<FuncType_>::value, |
|
"TORCH_FN can only wrap function types."); |
|
using FuncType = FuncType_; |
|
|
|
static constexpr FuncType* func_ptr() { |
|
return func_ptr_; |
|
} |
|
}; |
|
|
|
template <class T> |
|
struct is_compile_time_function_pointer : std::false_type {}; |
|
template <class FuncType, FuncType* func_ptr> |
|
struct is_compile_time_function_pointer< |
|
CompileTimeFunctionPointer<FuncType, func_ptr>> : std::true_type {}; |
|
|
|
} |
|
|
|
// TORCH_FN_TYPE(func) names the ::c10::CompileTimeFunctionPointer type that
// wraps the function `func`; TORCH_FN(func) value-initializes an instance of
// that type.
//
// `func` may be spelled `f`, `&f`, or `(f)`: decltype yields, respectively, a
// function type, a pointer-to-function type, or an lvalue reference to a
// function type, and remove_reference_t + remove_pointer_t reduce all three
// to the plain function type expected as the first template argument.
//
// Both macros are variadic so that template-ids containing commas can be
// passed without extra parentheses, e.g. TORCH_FN(my_kernel<int, float>).
// With a single named parameter the preprocessor would split such an
// argument at the comma and reject the call. Single-argument uses expand to
// exactly the same tokens as before.
#define TORCH_FN_TYPE(...)                                                 \
  ::c10::CompileTimeFunctionPointer<                                       \
      std::remove_pointer_t<std::remove_reference_t<decltype(__VA_ARGS__)>>, \
      __VA_ARGS__>

#define TORCH_FN(...) TORCH_FN_TYPE(__VA_ARGS__)()
|
|