|
#pragma once |
|
|
|
#include <c10/cuda/CUDAStream.h>

#include <ostream>
#include <utility>
|
|
|
|
|
|
|
|
|
namespace c10 { |
|
namespace cuda { |
|
|
|
// Integer type for capture ids; an alias for unsigned long long.
// NOTE(review): presumably chosen to match the id type the CUDA runtime uses
// for stream captures -- confirm against the CUDA headers.
using CaptureId_t = unsigned long long; |
|
|
|
|
|
|
|
// Memory-pool identifier, represented as a pair of two CaptureId_t values.
// NOTE(review): what `first` and `second` each stand for is not visible in
// this header -- confirm with the code that creates and consumes these ids.
using MempoolId_t = std::pair<CaptureId_t, CaptureId_t>; |
|
|
|
|
|
|
|
#if defined(CUDA_VERSION) && CUDA_VERSION >= 11000
// Scope guard for the stream-capture mode.
//
// The constructor passes `desired` to cudaThreadExchangeStreamCaptureMode,
// which (per the CUDA runtime documentation) installs that mode for the
// calling thread and writes the previously active mode back through the
// pointer. The destructor runs the same exchange again with the saved value,
// which puts the earlier mode back.
struct C10_CUDA_API CUDAStreamCaptureModeGuard {
  // Switches to `desired` and remembers the mode that was active before.
  // A failing CUDA call is reported through C10_CUDA_CHECK.
  CUDAStreamCaptureModeGuard(cudaStreamCaptureMode desired)
      : strictness_(desired) {
    C10_CUDA_CHECK(cudaThreadExchangeStreamCaptureMode(&strictness_));
  }

  // Every instance performs its own exchange when destroyed. A copied or
  // moved guard would therefore restore the saved mode a second time, at
  // whatever point the duplicate dies: possibly while the original guard is
  // still meant to be in force, or after a later mode change. Forbid both.
  CUDAStreamCaptureModeGuard(const CUDAStreamCaptureModeGuard&) = delete;
  CUDAStreamCaptureModeGuard(CUDAStreamCaptureModeGuard&&) = delete;
  CUDAStreamCaptureModeGuard& operator=(const CUDAStreamCaptureModeGuard&) =
      delete;
  CUDAStreamCaptureModeGuard& operator=(CUDAStreamCaptureModeGuard&&) = delete;

  // Exchanges the saved mode back in. Uses the _WARN flavour of the check
  // rather than C10_CUDA_CHECK, presumably so that a failure here cannot
  // escape from a destructor.
  ~CUDAStreamCaptureModeGuard() {
    C10_CUDA_CHECK_WARN(cudaThreadExchangeStreamCaptureMode(&strictness_));
  }

 private:
  // Holds `desired` until the constructor's exchange has run; afterwards it
  // holds the mode that the destructor will hand back.
  cudaStreamCaptureMode strictness_;
};
#endif
|
|
|
#if defined(CUDA_VERSION) && CUDA_VERSION >= 11000
// Compile-time checks pinning the integer values of the CUDA runtime's
// capture-status enumerators: None is 0, Active is 1 and Invalidated is 2.
// The build fails with the matching message if any of them differs.
static_assert(
    static_cast<int>(cudaStreamCaptureStatus::cudaStreamCaptureStatusNone) ==
        0,
    "unexpected int(cudaStreamCaptureStatusNone) value");
static_assert(
    static_cast<int>(cudaStreamCaptureStatus::cudaStreamCaptureStatusActive) ==
        1,
    "unexpected int(cudaStreamCaptureStatusActive) value");
static_assert(
    static_cast<int>(
        cudaStreamCaptureStatus::cudaStreamCaptureStatusInvalidated) == 2,
    "unexpected int(cudaStreamCaptureStatusInvalidated) value");
#endif
|
|
|
// Capture state of a stream, as a scoped enum backed by int.
//
// When building against CUDA 11.0 or newer, each enumerator takes its value
// from the cudaStreamCaptureStatus enumerator of the same name. In every other
// build only `None` exists, with the value 0.
enum class CaptureStatus : int {
#if defined(CUDA_VERSION) && CUDA_VERSION >= 11000
  None =
      static_cast<int>(cudaStreamCaptureStatus::cudaStreamCaptureStatusNone),
  Active =
      static_cast<int>(cudaStreamCaptureStatus::cudaStreamCaptureStatusActive),
  Invalidated = static_cast<int>(
      cudaStreamCaptureStatus::cudaStreamCaptureStatusInvalidated)
#else
  None = 0
#endif
};
|
|
|
inline std::ostream& operator<<(std::ostream& os, CaptureStatus status) { |
|
switch (status) { |
|
case CaptureStatus::None: |
|
os << "cudaStreamCaptureStatusNone"; |
|
break; |
|
#if defined(CUDA_VERSION) && CUDA_VERSION >= 11000 |
|
case CaptureStatus::Active: |
|
os << "cudaStreamCaptureStatusActive"; |
|
break; |
|
case CaptureStatus::Invalidated: |
|
os << "cudaStreamCaptureStatusInvalidated"; |
|
break; |
|
#endif |
|
default: |
|
TORCH_INTERNAL_ASSERT( |
|
false, "Unknown CUDA graph CaptureStatus", int(status)); |
|
} |
|
return os; |
|
} |
|
|
|
|
|
// Reports the capture status of the current CUDA stream.
//
// When building against CUDA 11.0 or newer, this asks cudaStreamIsCapturing
// about the stream returned by c10::cuda::getCurrentCUDAStream() and converts
// the answer to CaptureStatus; a failing CUDA call is reported through
// C10_CUDA_CHECK. In every other build it returns CaptureStatus::None without
// touching the runtime.
//
// NOTE(review): the "MayInitCtx" suffix suggests the query can create a CUDA
// context as a side effect -- confirm before calling it on paths that must not
// initialize CUDA.
inline CaptureStatus currentStreamCaptureStatusMayInitCtx() {
#if !(defined(CUDA_VERSION) && CUDA_VERSION >= 11000)
  return CaptureStatus::None;
#else
  cudaStreamCaptureStatus stream_status;
  C10_CUDA_CHECK(cudaStreamIsCapturing(
      c10::cuda::getCurrentCUDAStream(), &stream_status));
  return static_cast<CaptureStatus>(stream_status);
#endif
}
|
|
|
} |
|
} |
|
|