|
#ifndef C10_UTIL_EXCEPTION_H_ |
|
#define C10_UTIL_EXCEPTION_H_ |
|
|
|
#include <c10/macros/Macros.h> |
|
#include <c10/util/Deprecated.h> |
|
#include <c10/util/StringUtil.h> |
|
|
|
#include <cstddef> |
|
#include <exception> |
|
#include <ostream> |
|
#include <sstream> |
|
#include <string> |
|
#include <vector> |
|
|
|
// Compatibility shim for MSVC versions up to _MSC_VER 1900: map the standard
// __func__ identifier, which the macros in this header pass as the function
// name, onto MSVC's __FUNCTION__.
// NOTE(review): presumably needed because __func__ was missing or unreliable
// on those compilers — confirm. Defining a macro named __func__ (a reserved
// identifier) is only acceptable as a compiler-specific workaround like this.
#if defined(_MSC_VER) && _MSC_VER <= 1900
#define __func__ __FUNCTION__
#endif
|
|
|
namespace c10 { |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
// Base exception type of this header. C10_THROW_ERROR constructs it, or one of
// its subclasses, from a source location and a message. It derives from
// std::exception, so it can be caught either as c10::Error or as
// std::exception.
//
// NOTE(review): the constructors, add_context, refresh_what and compute_what
// are defined out of line. The relationships between members described below
// are inferred from names and should be confirmed in Exception.cpp.
class C10_API Error : public std::exception {
  // The error message; returned by msg().
  std::string msg_;

  // Additional context strings; returned by context(). Presumably appended to
  // by add_context().
  std::vector<std::string> context_;

  // Backtrace text; returned by backtrace().
  std::string backtrace_;

  // Cached strings backing what() and what_without_backtrace(). They are
  // stored as members so that the returned C strings remain valid for as long
  // as this object is alive and the strings are not rebuilt. Presumably
  // refresh_what() rebuilds them (e.g. after add_context()), which would
  // invalidate previously returned pointers.
  std::string what_;

  std::string what_without_backtrace_;

  // Opaque pointer passed to the constructors that accept `caller`. This
  // header only stores it and returns it from caller(); it is never
  // dereferenced here.
  const void* caller_;

 public:
  // Constructs from a source location and a message. This is the form used by
  // C10_THROW_ERROR, which passes {__func__, __FILE__, __LINE__} and a message.
  Error(SourceLocation source_location, std::string msg);

  // Constructs from a file/line pair, a condition string, a message and a
  // backtrace, plus an optional opaque caller pointer.
  Error(
      const char* file,
      const uint32_t line,
      const char* condition,
      const std::string& msg,
      const std::string& backtrace,
      const void* caller = nullptr);

  // Constructs from a message and an already-computed backtrace, plus an
  // optional opaque caller pointer.
  Error(std::string msg, std::string backtrace, const void* caller = nullptr);

  // Attaches an extra context string to this error. TORCH_RETHROW calls this
  // on a caught error before rethrowing it.
  void add_context(std::string msg);

  // Returns the message; the reference is valid while this Error is alive.
  const std::string& msg() const {
    return msg_;
  }

  // Returns the context strings attached to this error.
  const std::vector<std::string>& context() const {
    return context_;
  }

  // Returns the stored backtrace text.
  const std::string& backtrace() const {
    return backtrace_;
  }

  // std::exception interface. Returns the cached what_ string (presumably the
  // full message including the backtrace — see compute_what).
  const char* what() const noexcept override {
    return what_.c_str();
  }

  // Returns the opaque caller pointer given at construction (nullptr if the
  // constructor's default was used).
  const void* caller() const noexcept {
    return caller_;
  }

  // Like what(), but backed by what_without_backtrace_; presumably the message
  // and context without the backtrace.
  const char* what_without_backtrace() const noexcept {
    return what_without_backtrace_.c_str();
  }

 private:
  // Presumably recomputes what_ and what_without_backtrace_ from the other
  // members.
  void refresh_what();

  // Presumably formats the full message, including backtrace_ only when
  // include_backtrace is true.
  std::string compute_what(bool include_backtrace) const;
};
|
|
|
class C10_API WarningHandler { |
|
public: |
|
virtual ~WarningHandler() = default; |
|
|
|
virtual void process( |
|
const SourceLocation& source_location, |
|
const std::string& msg, |
|
const bool verbatim); |
|
}; |
|
|
|
namespace Warning { |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
// Emits a warning raised at `source_location` with the given message.
// `verbatim` is forwarded as-is; the TORCH_WARN / TORCH_WARN_ONCE macros in
// this header always pass false.
C10_API void warn(
    const SourceLocation& source_location,
    const std::string& msg,
    bool verbatim);
// Overload for C-string messages, so callers that already hold a const char*
// do not need to build a std::string at the call site.
C10_API void warn(
    SourceLocation source_location,
    const char* msg,
    bool verbatim);
// Overload for the empty compile-time message. TORCH_WARN and _TORCH_WARN_ONCE
// use it in STRIP_ERROR_MESSAGES builds, where the user's message is dropped.
C10_API void warn(
    SourceLocation source_location,
    ::c10::detail::CompileTimeEmptyString msg,
    bool verbatim);
|
|
|
|
|
|
|
|
|
|
|
// Setter/getter for the current WarningHandler. Neither function throws.
// NOTE(review): whether the handler is stored globally or per thread, and who
// owns the pointed-to handler, is decided by the out-of-line definitions —
// confirm before relying on either. WarningHandlerGuard uses this pair to swap
// a handler in for a scope.
C10_API void set_warning_handler(WarningHandler* handler) noexcept;

C10_API WarningHandler* get_warning_handler() noexcept;
|
|
|
class C10_API WarningHandlerGuard { |
|
WarningHandler* prev_handler_; |
|
|
|
public: |
|
WarningHandlerGuard(WarningHandler* new_handler) |
|
: prev_handler_(c10::Warning::get_warning_handler()) { |
|
c10::Warning::set_warning_handler(new_handler); |
|
} |
|
~WarningHandlerGuard() { |
|
c10::Warning::set_warning_handler(prev_handler_); |
|
} |
|
}; |
|
|
|
|
|
|
|
|
|
// Setter/getter for the "warn always" flag. Neither function throws.
// TORCH_WARN_ONCE reads get_warnAlways(): when it returns true, every
// evaluation of a TORCH_WARN_ONCE warns (via TORCH_WARN) instead of only the
// first one at that site.
C10_API void set_warnAlways(bool enabled) noexcept;

C10_API bool get_warnAlways() noexcept;
|
|
|
|
|
|
|
struct C10_API WarnAlways { |
|
public: |
|
explicit WarnAlways(bool setting = true); |
|
~WarnAlways(); |
|
|
|
private: |
|
bool prev_setting; |
|
}; |
|
|
|
} |
|
|
|
|
|
|
|
|
|
class C10_API IndexError : public Error { |
|
using Error::Error; |
|
}; |
|
|
|
|
|
|
|
class C10_API ValueError : public Error { |
|
using Error::Error; |
|
}; |
|
|
|
|
|
|
|
class C10_API TypeError : public Error { |
|
using Error::Error; |
|
}; |
|
|
|
|
|
|
|
class C10_API NotImplementedError : public Error { |
|
using Error::Error; |
|
}; |
|
|
|
|
|
|
|
class C10_API EnforceFiniteError : public Error { |
|
using Error::Error; |
|
}; |
|
|
|
|
|
|
|
class C10_API OnnxfiBackendSystemError : public Error { |
|
using Error::Error; |
|
}; |
|
|
|
|
|
|
|
class C10_API LinAlgError : public Error { |
|
using Error::Error; |
|
}; |
|
|
|
class C10_API OutOfMemoryError : public Error { |
|
using Error::Error; |
|
}; |
|
|
|
|
|
|
|
// Returns a printable description of `e`.
// NOTE(review): the exact format (e.g. whether the dynamic type name is
// included) is defined out of line — confirm in Exception.cpp.
C10_API std::string GetExceptionString(const std::exception& e);
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
// Throws ::c10::err_type built from the current source location and `msg`.
// `err_type` must be an unqualified class name in namespace c10 (the macro
// pastes it after `::c10::`), e.g. Error or IndexError. The braced first
// argument carries {__func__, __FILE__, __LINE__}. Presumably it initializes
// the SourceLocation parameter of Error(SourceLocation, std::string);
// SourceLocation is declared outside this header.
#define C10_THROW_ERROR(err_type, msg) \
  throw ::c10::err_type(               \
      {__func__, __FILE__, static_cast<uint32_t>(__LINE__)}, msg)
|
|
|
|
|
|
|
|
|
#define C10_EXPAND_MSVC_WORKAROUND(x) x |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
// Branch hint for the failure condition of the check/assert macros. Outside
// nvcc (__CUDACC__ undefined), the expression is wrapped in C10_UNLIKELY (from
// Macros.h). Under nvcc it is used bare, so no hint is applied.
// NOTE(review): presumably the hint is dropped because it interferes with
// nvcc's control-flow analysis when `e` is a constant (e.g. TORCH_CHECK(false,
// ...)) — confirm.
#if defined(__CUDACC__)
#define C10_UNLIKELY_OR_CONST(e) e
#else
#define C10_UNLIKELY_OR_CONST(e) C10_UNLIKELY(e)
#endif
|
|
|
|
|
|
|
|
|
|
|
// Adds context to the exception currently being handled and rethrows it.
// It must be used inside a catch handler, because it ends in `throw;`, which
// rethrows the in-flight exception object. `e` must name something with
// add_context(), such as a caught c10::Error.
// NOTE: catch by reference. `throw;` rethrows the original object, so context
// added to a by-value copy would be lost.
#ifdef STRIP_ERROR_MESSAGES
// Stripped builds: the context arguments are dropped (never evaluated) and
// `e` is unused.
#define TORCH_RETHROW(e, ...) throw
#else
#define TORCH_RETHROW(e, ...)               \
  do {                                      \
    e.add_context(::c10::str(__VA_ARGS__)); \
    throw;                                  \
  } while (false)
#endif
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
// Asserts an internal invariant (a bug in the library, not a user error).
// NOTE: both variants expand to a bare `if (...) { ... }` with no else. Using
// the macro as the unbraced body of an `if` that is followed by `else` does
// not compile.
#ifdef STRIP_ERROR_MESSAGES
// Stripped builds: on failure, reports only the stringized condition and the
// file name. The message arguments are ignored and never evaluated. This
// variant routes through torchCheckFail rather than torchInternalAssertFail.
#define TORCH_INTERNAL_ASSERT(cond, ...)                              \
  if (C10_UNLIKELY_OR_CONST(!(cond))) {                               \
    ::c10::detail::torchCheckFail(                                    \
        __func__,                                                     \
        __FILE__,                                                     \
        static_cast<uint32_t>(__LINE__),                              \
        #cond " INTERNAL ASSERT FAILED at " C10_STRINGIZE(__FILE__)); \
  }
#else
// On failure, calls the [[noreturn]] ::c10::detail::torchInternalAssertFail
// with a compile-time literal (condition, file, line and a "please report a
// bug" note) plus the user message built by c10::str(__VA_ARGS__). The message
// arguments are evaluated only when the assertion fails.
// NOTE(review): with no message arguments, c10::str() presumably yields
// ::c10::detail::CompileTimeEmptyString, which selects the inline overload
// that forwards to torchCheckFail — confirm in StringUtil.h.
#define TORCH_INTERNAL_ASSERT(cond, ...)                                         \
  if (C10_UNLIKELY_OR_CONST(!(cond))) {                                          \
    ::c10::detail::torchInternalAssertFail(                                      \
        __func__,                                                                \
        __FILE__,                                                                \
        static_cast<uint32_t>(__LINE__),                                         \
        #cond                                                                    \
        " INTERNAL ASSERT FAILED at " C10_STRINGIZE(__FILE__) ":" C10_STRINGIZE( \
            __LINE__) ", please report a bug to PyTorch. ",                      \
        c10::str(__VA_ARGS__));                                                  \
  }
#endif
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
// Like TORCH_CHECK, but lets the caller pick the exception type: `error_t` is
// an unqualified class name in namespace c10 (e.g. ValueError). Expands to
// TORCH_CHECK_WITH_MSG with an empty type tag. The exception actually thrown on
// failure is determined by TORCH_CHECK_WITH_MSG.
#define TORCH_CHECK_WITH(error_t, cond, ...) \
  TORCH_CHECK_WITH_MSG(error_t, cond, "", __VA_ARGS__)
|
|
|
#ifdef STRIP_ERROR_MESSAGES |
|
#define TORCH_CHECK_MSG(cond, type, ...) \ |
|
(#cond #type " CHECK FAILED at " C10_STRINGIZE(__FILE__)) |
|
#define TORCH_CHECK_WITH_MSG(error_t, cond, type, ...) \ |
|
if (C10_UNLIKELY_OR_CONST(!(cond))) { \ |
|
C10_THROW_ERROR(Error, TORCH_CHECK_MSG(cond, type, __VA_ARGS__)); \ |
|
} |
|
#else |
|
namespace c10 { |
|
namespace detail { |
|
template <typename... Args> |
|
decltype(auto) torchCheckMsgImpl(const char* , const Args&... args) { |
|
return ::c10::str(args...); |
|
} |
|
inline C10_API const char* torchCheckMsgImpl(const char* msg) { |
|
return msg; |
|
} |
|
|
|
inline C10_API const char* torchCheckMsgImpl( |
|
const char* , |
|
const char* args) { |
|
return args; |
|
} |
|
} |
|
} |
|
|
|
#define TORCH_CHECK_MSG(cond, type, ...) \ |
|
(::c10::detail::torchCheckMsgImpl( \ |
|
"Expected " #cond \ |
|
" to be true, but got false. " \ |
|
"(Could this error message be improved? If so, " \ |
|
"please report an enhancement request to PyTorch.)", \ |
|
##__VA_ARGS__)) |
|
#define TORCH_CHECK_WITH_MSG(error_t, cond, type, ...) \ |
|
if (C10_UNLIKELY_OR_CONST(!(cond))) { \ |
|
C10_THROW_ERROR(error_t, TORCH_CHECK_MSG(cond, type, __VA_ARGS__)); \ |
|
} |
|
#endif |
|
|
|
namespace c10 {
namespace detail {

// Out-of-line failure handlers called by the check/assert macros in this
// header. All are [[noreturn]]: control never comes back to the failing check.
// NOTE(review): presumably they throw c10::Error built from the arguments —
// defined in Exception.cpp; confirm.

// Reports a failed TORCH_CHECK. `func`, `file` and `line` identify the call
// site; `msg` is the already-built message. One overload takes std::string
// messages and one takes C-string messages (e.g. the default literal or a
// single user-supplied literal).
[[noreturn]] C10_API void torchCheckFail(
    const char* func,
    const char* file,
    uint32_t line,
    const std::string& msg);
[[noreturn]] C10_API void torchCheckFail(
    const char* func,
    const char* file,
    uint32_t line,
    const char* msg);

// Reports a failed TORCH_INTERNAL_ASSERT. `condMsg` is the compile-time
// literal describing the condition and location; `userMsg` is the optional
// user message. The overload is chosen by the type c10::str(...) produced.
[[noreturn]] C10_API void torchInternalAssertFail(
    const char* func,
    const char* file,
    uint32_t line,
    const char* condMsg,
    const char* userMsg);
// No user message: forwards only the condition literal to torchCheckFail.
[[noreturn]] inline C10_API void torchInternalAssertFail(
    const char* func,
    const char* file,
    uint32_t line,
    const char* condMsg,
    ::c10::detail::CompileTimeEmptyString /*userMsg*/) {
  torchCheckFail(func, file, line, condMsg);
}
[[noreturn]] C10_API void torchInternalAssertFail(
    const char* func,
    const char* file,
    uint32_t line,
    const char* condMsg,
    const std::string& userMsg);

} // namespace detail
} // namespace c10
|
|
|
// Checks a user-facing precondition. When `cond` is false, it calls the
// [[noreturn]] ::c10::detail::torchCheckFail with the current function, file,
// line and a message built by TORCH_CHECK_MSG.
// NOTE: expands to a bare `if (...) { ... }`. Do not use it as the unbraced
// body of an `if` that has an `else`.
#ifdef STRIP_ERROR_MESSAGES
// Stripped builds: the message is the compile-time literal produced by the
// stripped TORCH_CHECK_MSG. Message arguments are not evaluated.
#define TORCH_CHECK(cond, ...)                   \
  if (C10_UNLIKELY_OR_CONST(!(cond))) {          \
    ::c10::detail::torchCheckFail(               \
        __func__,                                \
        __FILE__,                                \
        static_cast<uint32_t>(__LINE__),         \
        TORCH_CHECK_MSG(cond, "", __VA_ARGS__)); \
  }
#else
// The message is the default "Expected ... to be true" text when no message
// arguments are given, otherwise the caller's arguments. Those arguments are
// evaluated only on failure. `##__VA_ARGS__` allows calling TORCH_CHECK(cond)
// with no message.
#define TORCH_CHECK(cond, ...)                     \
  if (C10_UNLIKELY_OR_CONST(!(cond))) {            \
    ::c10::detail::torchCheckFail(                 \
        __func__,                                  \
        __FILE__,                                  \
        static_cast<uint32_t>(__LINE__),           \
        TORCH_CHECK_MSG(cond, "", ##__VA_ARGS__)); \
  }
#endif
|
|
|
|
|
|
|
|
|
// TORCH_CHECK that is compiled out when building with nvcc or hipcc
// (__CUDACC__ / __HIPCC__ defined). In that case the macro expands to nothing,
// so neither the condition nor the message is evaluated. Otherwise it is
// identical to TORCH_CHECK.
#if defined(__CUDACC__) || defined(__HIPCC__)
#define TORCH_CHECK_IF_NOT_ON_CUDA(cond, ...)
#else
#define TORCH_CHECK_IF_NOT_ON_CUDA(cond, ...) TORCH_CHECK(cond, ##__VA_ARGS__)
#endif
|
|
|
|
|
|
|
|
|
|
|
// TORCH_INTERNAL_ASSERT that only runs in debug builds.
#ifdef NDEBUG
// Release builds: the assertion becomes the body of `while (false)`. It is
// still parsed and type-checked, which keeps the expression from rotting, but
// it never executes, so `cond` and the message are not evaluated.
#define TORCH_INTERNAL_ASSERT_DEBUG_ONLY(...) \
  while (false)                               \
  C10_EXPAND_MSVC_WORKAROUND(TORCH_INTERNAL_ASSERT(__VA_ARGS__))
#else
// Debug builds: exactly TORCH_INTERNAL_ASSERT.
#define TORCH_INTERNAL_ASSERT_DEBUG_ONLY(...) \
  C10_EXPAND_MSVC_WORKAROUND(TORCH_INTERNAL_ASSERT(__VA_ARGS__))
#endif
|
|
|
|
|
|
|
|
|
|
|
// Typed variants of TORCH_CHECK. Each forwards to TORCH_CHECK_WITH_MSG with a
// specific c10 exception class and a short tag. The tag appears only in the
// message of STRIP_ERROR_MESSAGES builds, where TORCH_CHECK_MSG stringizes it;
// non-stripped builds ignore it. Each tag names its own check kind so stripped
// messages are not mislabelled.

// Check whose failure is reported as LinAlgError.
#define TORCH_CHECK_LINALG(cond, ...) \
  TORCH_CHECK_WITH_MSG(LinAlgError, cond, "LINALG", __VA_ARGS__)

// Check whose failure is reported as IndexError.
#define TORCH_CHECK_INDEX(cond, ...) \
  TORCH_CHECK_WITH_MSG(IndexError, cond, "INDEX", __VA_ARGS__)

// Check whose failure is reported as ValueError.
#define TORCH_CHECK_VALUE(cond, ...) \
  TORCH_CHECK_WITH_MSG(ValueError, cond, "VALUE", __VA_ARGS__)

// Check whose failure is reported as TypeError.
#define TORCH_CHECK_TYPE(cond, ...) \
  TORCH_CHECK_WITH_MSG(TypeError, cond, "TYPE", __VA_ARGS__)

// Check whose failure is reported as NotImplementedError. Tagged
// "NOT_IMPLEMENTED" rather than the copy-pasted "TYPE", which made stripped
// builds describe a not-implemented failure as a type check.
#define TORCH_CHECK_NOT_IMPLEMENTED(cond, ...) \
  TORCH_CHECK_WITH_MSG(                        \
      NotImplementedError, cond, "NOT_IMPLEMENTED", __VA_ARGS__)
|
|
|
|
|
|
|
|
|
// Emits a warning through ::c10::Warning::warn, tagged with the current
// function, file and line, with verbatim = false.
#ifdef STRIP_ERROR_MESSAGES
// Stripped builds: a warning is still emitted for the source location, but
// with an empty compile-time message. The arguments are not evaluated.
#define TORCH_WARN(...)                                      \
  ::c10::Warning::warn(                                      \
      {__func__, __FILE__, static_cast<uint32_t>(__LINE__)}, \
      ::c10::detail::CompileTimeEmptyString{},               \
      false)
#else
// The message is the arguments concatenated by ::c10::str.
#define TORCH_WARN(...)                                      \
  ::c10::Warning::warn(                                      \
      {__func__, __FILE__, static_cast<uint32_t>(__LINE__)}, \
      ::c10::str(__VA_ARGS__),                               \
      false)
#endif
|
|
|
|
|
|
|
|
|
// Implementation detail of TORCH_WARN_ONCE. It declares a uniquely named
// static local whose initializer is an immediately invoked lambda that calls
// ::c10::Warning::warn. Static-local initialization completes at most once, so
// the warning fires only the first time control reaches this expansion. If
// warn() throws, initialization did not complete and is retried on the next
// pass. The lambda captures by reference ([&]) so message arguments can name
// locals.
// C10_UNUSED presumably suppresses the unused-variable warning (defined in
// Macros.h).
#ifdef STRIP_ERROR_MESSAGES
// Stripped builds: empty compile-time message; arguments are not evaluated.
#define _TORCH_WARN_ONCE(...)                                             \
  C10_UNUSED static const auto C10_ANONYMOUS_VARIABLE(torch_warn_once_) = \
      [&] {                                                               \
        ::c10::Warning::warn(                                             \
            {__func__, __FILE__, static_cast<uint32_t>(__LINE__)},        \
            ::c10::detail::CompileTimeEmptyString{},                      \
            false);                                                       \
        return true;                                                      \
      }()
#else
#define _TORCH_WARN_ONCE(...)                                             \
  C10_UNUSED static const auto C10_ANONYMOUS_VARIABLE(torch_warn_once_) = \
      [&] {                                                               \
        ::c10::Warning::warn(                                             \
            {__func__, __FILE__, static_cast<uint32_t>(__LINE__)},        \
            ::c10::str(__VA_ARGS__),                                      \
            false);                                                       \
        return true;                                                      \
      }()
#endif
|
|
|
// Warns once per expansion site, unless the "warn always" flag is set
// (::c10::Warning::get_warnAlways() returns true), in which case it warns on
// every evaluation via TORCH_WARN.
// NOTE: expands to a bare if/else that is not wrapped in do { } while (false).
// Using it as the unbraced body of an `if` that has an `else` does not compile.
#define TORCH_WARN_ONCE(...)                   \
  if (::c10::Warning::get_warnAlways()) {      \
    TORCH_WARN(__VA_ARGS__);                   \
  } else {                                     \
    _TORCH_WARN_ONCE(__VA_ARGS__);             \
  }
|
|
|
|
|
|
|
// TORCH_CHECK whose failure message is prefixed with
// "invalid argument <argN>: ". At least one message argument must follow
// `argN`: __VA_ARGS__ is pasted after a comma without `##`, so an empty list
// would leave a dangling comma in the generated call.
#define TORCH_CHECK_ARG(cond, argN, ...) \
  TORCH_CHECK(cond, "invalid argument ", argN, ": ", __VA_ARGS__)
|
|
|
|
|
|
|
|
|
|
|
namespace c10 {
namespace detail {

// Empty hooks called by the legacy AT_ERROR, AT_ASSERT and AT_ASSERTM macros
// in this header. They do nothing at runtime.
// NOTE(review): presumably they exist so a deprecation attribute (e.g. from
// <c10/util/Deprecated.h>, which this header includes but does not visibly
// use) can later be attached to steer callers toward TORCH_CHECK /
// TORCH_INTERNAL_ASSERT. Enabling it would warn at every existing call site —
// confirm intent first.

// Hook for AT_ERROR.
inline void deprecated_AT_ERROR() {}

// Hook for AT_ASSERT.
inline void deprecated_AT_ASSERT() {}

// Hook for AT_ASSERTM.
inline void deprecated_AT_ASSERTM() {}

} // namespace detail
} // namespace c10
|
|
|
|
|
|
|
|
|
|
|
// Legacy spelling of TORCH_INTERNAL_ASSERT(...). It is wrapped in
// do { } while (false) so it behaves as a single statement, and it calls the
// no-op deprecated_AT_ASSERT() hook that marks it as legacy. New code can use
// TORCH_INTERNAL_ASSERT directly.
#define AT_ASSERT(...)                                              \
  do {                                                              \
    ::c10::detail::deprecated_AT_ASSERT();                          \
    C10_EXPAND_MSVC_WORKAROUND(TORCH_INTERNAL_ASSERT(__VA_ARGS__)); \
  } while (false)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
// Legacy spelling of TORCH_INTERNAL_ASSERT(cond, ...) with a message. It is
// wrapped in do { } while (false) so it behaves as a single statement, and it
// calls the no-op deprecated_AT_ASSERTM() hook that marks it as legacy.
#define AT_ASSERTM(cond, ...)                                             \
  do {                                                                    \
    ::c10::detail::deprecated_AT_ASSERTM();                               \
    C10_EXPAND_MSVC_WORKAROUND(TORCH_INTERNAL_ASSERT(cond, __VA_ARGS__)); \
  } while (false)
|
|
|
|
|
|
|
|
|
|
|
// Legacy unconditional failure: TORCH_CHECK(false, ...) with the arguments
// pre-concatenated by ::c10::str. The check always fails, so control never
// continues past it (torchCheckFail is [[noreturn]]). It is wrapped in
// do { } while (false) and calls the no-op deprecated_AT_ERROR() hook that
// marks it as legacy.
#define AT_ERROR(...)                                                        \
  do {                                                                       \
    ::c10::detail::deprecated_AT_ERROR();                                    \
    C10_EXPAND_MSVC_WORKAROUND(TORCH_CHECK(false, ::c10::str(__VA_ARGS__))); \
  } while (false)
|
|
|
#endif |
|
|