#pragma once

#include <c10/core/DeviceGuard.h>
#include <c10/core/impl/DeviceGuardImplInterface.h>
#include <c10/core/impl/GPUTrace.h>
#include <c10/macros/Macros.h>
#include <c10/util/Exception.h>

#include <c10/cuda/CUDACachingAllocator.h>
#include <c10/cuda/CUDAException.h>
#include <c10/cuda/CUDAFunctions.h>
#include <c10/cuda/CUDAStream.h>

#include <cuda_runtime_api.h>

namespace c10 {
namespace cuda {
namespace impl {
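
// CUDA backend for c10's generic device and stream guards: DeviceGuard,
// StreamGuard, and the generic event machinery dispatch to this struct
// through DeviceGuardImplInterface whenever the device type is CUDA.
//
// A minimal usage sketch (illustrative only; user code goes through the
// guard classes rather than calling this implementation directly):
//
//   c10::Device dev(c10::DeviceType::CUDA, 0);
//   {
//     c10::DeviceGuard guard(dev); // switches the current CUDA device to 0
//     // ... enqueue work on device 0's current stream ...
//   } // the previously current device is restored here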
struct CUDAGuardImpl final : public c10::impl::DeviceGuardImplInterface {
  static constexpr DeviceType static_type = DeviceType::CUDA;

  CUDAGuardImpl() {}
  explicit CUDAGuardImpl(DeviceType t) {
    TORCH_INTERNAL_ASSERT(t == DeviceType::CUDA);
  }
  DeviceType type() const override {
    return DeviceType::CUDA;
  }
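  // Sets the current CUDA device to d and returns the device that was
  // previously current; cudaSetDevice is skipped when d is already current.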
  Device exchangeDevice(Device d) const override {
    TORCH_INTERNAL_ASSERT(d.is_cuda());
    Device old_device = getDevice();
    if (old_device.index() != d.index()) {
      C10_CUDA_CHECK(cudaSetDevice(d.index()));
    }
    return old_device;
  }
  Device getDevice() const override {
    int device;
    C10_CUDA_CHECK(cudaGetDevice(&device));
    return Device(DeviceType::CUDA, device);
  }
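  // Like getDevice(), but reports failure as c10::nullopt (with a warning)
  // instead of throwing, so it can be used from noexcept paths such as
  // uncheckedSetDevice below.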
  c10::optional<Device> uncheckedGetDevice() const noexcept {
    int device;
    const auto err = C10_CUDA_ERROR_HANDLED(cudaGetDevice(&device));
    C10_CUDA_CHECK_WARN(err);
    if (err != cudaSuccess) {
      return c10::nullopt;
    }
    return Device(DeviceType::CUDA, device);
  }
  void setDevice(Device d) const override {
    TORCH_INTERNAL_ASSERT(d.is_cuda());
    Device current_device = getDevice();
    if (current_device != d) {
      C10_CUDA_CHECK(cudaSetDevice(d.index()));
    }
  }
  void uncheckedSetDevice(Device d) const noexcept override {
    auto current_device = uncheckedGetDevice();
    if (!current_device.has_value() || current_device.value() != d) {
      C10_CUDA_CHECK_WARN(cudaSetDevice(d.index()));
    }
  }
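  // Stream accessors: return c10::Stream handles for the thread-local current
  // CUDA stream, the default stream, or a stream taken from the global pool
  // on the given device.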
  Stream getStream(Device d) const noexcept override {
    return getCurrentCUDAStream(d.index()).unwrap();
  }
  Stream getDefaultStream(Device d) const override {
    return getDefaultCUDAStream(d.index());
  }
  Stream getStreamFromGlobalPool(Device d, bool isHighPriority = false)
      const override {
    return getStreamFromPool(isHighPriority, d.index());
  }
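
  // NB: makes s the current stream for its device but does NOT change the
  // current device.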
  Stream exchangeStream(Stream s) const noexcept override {
    CUDAStream cs(s);
    auto old_stream = getCurrentCUDAStream(s.device().index());
    setCurrentCUDAStream(cs);
    return old_stream.unwrap();
  }
  DeviceIndex deviceCount() const noexcept override {
    return device_count();
  }
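
  // Event-related functions.
  // createEvent is a CUDA-specific helper rather than an interface override;
  // the generic interface creates events lazily through record() below. Note
  // that the PYTORCH_DEFAULT flag maps to cudaEventDisableTiming, so timing
  // must be requested explicitly via CUDA_EVENT_DEFAULT.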
  void createEvent(cudaEvent_t* cuda_event, const EventFlag flag) const {
    auto cuda_flag = cudaEventDefault;
    switch (flag) {
      case EventFlag::PYTORCH_DEFAULT:
      case EventFlag::CUDA_EVENT_DISABLE_TIMING:
        cuda_flag = cudaEventDisableTiming;
        break;
      case EventFlag::BACKEND_DEFAULT:
      case EventFlag::CUDA_EVENT_DEFAULT:
        cuda_flag = cudaEventDefault;
        break;
      default:
        TORCH_CHECK(false, "CUDA event received unknown flag");
    }

    C10_CUDA_CHECK(cudaEventCreateWithFlags(cuda_event, cuda_flag));
    const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace();
    if (C10_UNLIKELY(interp)) {
      (*interp)->trace_gpu_event_creation(
          reinterpret_cast<uintptr_t>(cuda_event));
    }
  }
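
  // Switches to the event's device, destroys the event, then restores the
  // original device. Failures only warn (C10_CUDA_CHECK_WARN) because this
  // is noexcept and may run during destruction.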
  void destroyEvent(void* event, const DeviceIndex device_index)
      const noexcept override {
    if (!event)
      return;
    auto cuda_event = static_cast<cudaEvent_t>(event);
    int orig_device;
    C10_CUDA_CHECK_WARN(cudaGetDevice(&orig_device));
    C10_CUDA_CHECK_WARN(cudaSetDevice(device_index));
    const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace();
    if (C10_UNLIKELY(interp)) {
      (*interp)->trace_gpu_event_deletion(
          reinterpret_cast<uintptr_t>(cuda_event));
    }
    C10_CUDA_CHECK_WARN(cudaEventDestroy(cuda_event));
    C10_CUDA_CHECK_WARN(cudaSetDevice(orig_device));
  }
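
  // Records the event on the given stream: checks that the event's device
  // matches the stream's device, switches to that device, lazily creates the
  // cudaEvent_t on first use, records it, then restores the original device.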
  void record(
      void** event,
      const Stream& stream,
      const DeviceIndex device_index,
      const EventFlag flag) const override {
    TORCH_CHECK(
        device_index == -1 || device_index == stream.device_index(),
        "Event device index ",
        device_index,
        " does not match recording stream's device index ",
        stream.device_index(),
        ".");

    cudaEvent_t cuda_event = static_cast<cudaEvent_t>(*event);
    CUDAStream cuda_stream{stream};

    const auto orig_device = getDevice();
    setDevice(stream.device());

    if (!cuda_event)
      createEvent(&cuda_event, flag);
    C10_CUDA_CHECK(cudaEventRecord(cuda_event, cuda_stream));

    *event = cuda_event;
    const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace();
    if (C10_UNLIKELY(interp)) {
      (*interp)->trace_gpu_event_record(
          reinterpret_cast<uintptr_t>(cuda_event),
          reinterpret_cast<uintptr_t>(cuda_stream.stream()));
    }

    setDevice(orig_device);
  }
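
  // Makes all future work submitted to the stream wait for the event
  // (cudaStreamWaitEvent); the calling host thread is not blocked.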
  void block(void* event, const Stream& stream) const override {
    if (!event)
      return;
    cudaEvent_t cuda_event = static_cast<cudaEvent_t>(event);
    CUDAStream cuda_stream{stream};
    const auto orig_device = getDevice();
    setDevice(stream.device());
    C10_CUDA_CHECK(cudaStreamWaitEvent(cuda_stream, cuda_event, 0));
    const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace();
    if (C10_UNLIKELY(interp)) {
      (*interp)->trace_gpu_event_wait(
          reinterpret_cast<uintptr_t>(cuda_event),
          reinterpret_cast<uintptr_t>(cuda_stream.stream()));
    }
    setDevice(orig_device);
  }
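
  // May be called from any device. Returns true if the event has completed
  // (or was never recorded); a cudaErrorNotReady result is cleared from the
  // CUDA error state instead of being raised.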
  bool queryEvent(void* event) const override {
    if (!event)
      return true;
    cudaEvent_t cuda_event = static_cast<cudaEvent_t>(event);
    const cudaError_t err = C10_CUDA_ERROR_HANDLED(cudaEventQuery(cuda_event));
    if (err != cudaErrorNotReady) {
      C10_CUDA_CHECK(err);
    } else {
      (void)cudaGetLastError();
    }
    return (err == cudaSuccess);
  }
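
  // Stream-related functions: query whether all work submitted to a stream
  // has completed, or block the host thread until it has.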
  bool queryStream(const Stream& stream) const override {
    CUDAStream cuda_stream{stream};
    return cuda_stream.query();
  }

  void synchronizeStream(const Stream& stream) const override {
    CUDAStream cuda_stream{stream};
    cuda_stream.synchronize();
  }
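
  // Informs the CUDA caching allocator that data_ptr is used on stream, so
  // its memory is not reused until the work queued on that stream up to this
  // point has completed.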
  void recordDataPtrOnStream(const c10::DataPtr& data_ptr, const Stream& stream)
      const override {
    CUDAStream cuda_stream{stream};
    CUDACachingAllocator::recordStream(data_ptr, cuda_stream);
  }
};

} // namespace impl
} // namespace cuda
} // namespace c10