|
#pragma once |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
#include <c10/core/Device.h> |
|
#include <c10/core/impl/GPUTrace.h> |
|
#include <c10/cuda/CUDAException.h> |
|
#include <c10/cuda/CUDAMacros.h> |
|
#include <cuda_runtime_api.h> |
|
namespace c10 { |
|
namespace cuda { |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
// Returns the number of CUDA devices. Declared `noexcept`, so it never
// propagates an exception to the caller.
// NOTE(review): implementation is not in this header; presumably it reports 0
// (rather than failing) when the driver/runtime is unusable — confirm in the
// .cpp before relying on that.
C10_CUDA_API DeviceIndex device_count() noexcept;

// Variant of device_count() that is NOT `noexcept`.
// NOTE(review): judging by the name, this presumably raises an error when no
// CUDA device is available instead of returning 0 — TODO confirm.
C10_CUDA_API DeviceIndex device_count_ensure_non_zero();

// Returns the index of the currently selected CUDA device.
C10_CUDA_API DeviceIndex current_device();

// Makes `device` the currently selected CUDA device.
C10_CUDA_API void set_device(DeviceIndex device);

// Blocks until the current device has completed its outstanding work.
// NOTE(review): presumably a checked wrapper around cudaDeviceSynchronize —
// confirm in the implementation.
C10_CUDA_API void device_synchronize();

// Invoked by the inline synchronizing helpers in this header
// (memcpy_and_sync, stream_synchronize) whenever the sync debug mode is not
// SyncDebugMode::L_DISABLED.
// NOTE(review): by its name, presumably emits a warning for L_WARN and raises
// for L_ERROR — verify against the implementation.
C10_CUDA_API void warn_or_error_on_sync();
|
|
|
// Controls how synchronizing helpers react to a host<->device sync.
// Values are spelled out explicitly so the numeric encoding is obvious.
enum class SyncDebugMode {
  L_DISABLED = 0, // no reporting; the helpers skip warn_or_error_on_sync()
  L_WARN = 1,
  L_ERROR = 2,
};
|
|
|
|
|
|
|
|
|
|
|
class WarningState { |
|
public: |
|
void set_sync_debug_mode(SyncDebugMode l) { |
|
sync_debug_mode = l; |
|
} |
|
|
|
SyncDebugMode get_sync_debug_mode() { |
|
return sync_debug_mode; |
|
} |
|
|
|
private: |
|
SyncDebugMode sync_debug_mode = SyncDebugMode::L_DISABLED; |
|
}; |
|
|
|
// Returns the single shared WarningState instance.
// The function-local static is constructed on first call (C++11 guarantees
// its initialization is thread-safe; later reads/writes of its mode are not
// synchronized — see WarningState).
// Do not rename `warning_state_`: for an inline function the local static's
// name is part of its mangled symbol, so renaming it could give binaries built
// against different versions of this header separate instances.
C10_CUDA_API __inline__ WarningState& warning_state() {
  static WarningState warning_state_;
  return warning_state_;
}
|
|
|
|
|
// Copies `nbytes` bytes from `src` to `dst` on `stream` and does not return
// until the copy has completed.
//
// Before copying it:
//   * calls warn_or_error_on_sync() if the sync debug mode is not L_DISABLED;
//   * notifies the registered GPU tracer (if any) that `stream` is about to be
//     synchronized.
//
// Parameters:
//   dst    - destination buffer (host or device, according to `kind`).
//   src    - source buffer; only read, hence `const` (callers passing a plain
//            `void*` still compile unchanged).
//   nbytes - number of bytes to copy; assumed non-negative (it is converted to
//            size_t for the runtime call).
//   kind   - direction of the copy, forwarded to the runtime.
//   stream - stream on which the copy is issued and then synchronized.
// Runtime errors are reported through C10_CUDA_CHECK.
C10_CUDA_API void __inline__ memcpy_and_sync(
    void* dst,
    const void* src,
    int64_t nbytes,
    cudaMemcpyKind kind,
    cudaStream_t stream) {
  if (C10_UNLIKELY(
          warning_state().get_sync_debug_mode() != SyncDebugMode::L_DISABLED)) {
    warn_or_error_on_sync();
  }
  const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace();
  if (C10_UNLIKELY(interp)) {
    (*interp)->trace_gpu_stream_synchronization(
        reinterpret_cast<uintptr_t>(stream));
  }
  // Runtime APIs take the byte count as size_t; make the conversion explicit.
  const size_t count = static_cast<size_t>(nbytes);
#if defined(TORCH_HIP_VERSION) && (TORCH_HIP_VERSION >= 301)
  // On ROCm, a single hipMemcpyWithStream call performs the copy and waits.
  C10_CUDA_CHECK(hipMemcpyWithStream(dst, src, count, kind, stream));
#else
  C10_CUDA_CHECK(cudaMemcpyAsync(dst, src, count, kind, stream));
  C10_CUDA_CHECK(cudaStreamSynchronize(stream));
#endif
}
|
|
|
// Blocks the calling host thread until all work queued on `stream` is done.
//
// Before blocking it reports the sync to warn_or_error_on_sync() (unless the
// debug mode is L_DISABLED) and notifies the registered GPU tracer, if one is
// installed. Errors from cudaStreamSynchronize go through C10_CUDA_CHECK.
C10_CUDA_API void __inline__ stream_synchronize(cudaStream_t stream) {
  const bool sync_checks_enabled =
      warning_state().get_sync_debug_mode() != SyncDebugMode::L_DISABLED;
  if (C10_UNLIKELY(sync_checks_enabled)) {
    warn_or_error_on_sync();
  }

  const c10::impl::PyInterpreter* tracer = c10::impl::GPUTrace::get_trace();
  if (C10_UNLIKELY(tracer != nullptr)) {
    (*tracer)->trace_gpu_stream_synchronization(
        reinterpret_cast<uintptr_t>(stream));
  }

  C10_CUDA_CHECK(cudaStreamSynchronize(stream));
}
|
|
|
} |
|
} |
|
|