|
#pragma once |
|
|
|
#include <c10/core/DeviceType.h> |
|
#include <c10/core/DispatchKey.h> |
|
#include <c10/core/DispatchKeySet.h> |
|
#include <c10/util/Exception.h> |
|
|
|
#include <stdexcept> |
|
|
|
namespace c10 { |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
// Enumerates the tensor backends handled by the conversion helpers in this
// header. A Backend combines a device with a tensor representation: for
// example CPU, SparseCPU, SparseCsrCPU, QuantizedCPU and MkldnnCPU are
// distinct Backends that all map to DeviceType::CPU in backendToDeviceType.
//
// NOTE: enumerators use implicit sequential values starting at 0, so
// inserting or reordering entries changes the underlying integer of every
// later enumerator. NumOptions is declared last, so its integer value equals
// the number of enumerators before it (presumably used as a count elsewhere;
// keep it last).
enum class Backend {
  // Dense backends.
  CPU,
  CUDA,
  HIP,
  VE,
  FPGA,
  IPU,
  XPU,
  // Sparse backends (reported by isSparse).
  SparseCPU,
  SparseCUDA,
  // Sparse CSR backends (reported by isSparseCsr, not by isSparse).
  SparseCsrCPU,
  SparseCsrCUDA,
  SparseHIP,
  SparseVE,
  SparseXPU,
  ORT,
  XLA,
  Vulkan,
  Metal,
  Meta,
  // Quantized backends; each maps to the DeviceType of its dense counterpart.
  QuantizedCPU,
  QuantizedCUDA,
  QuantizedXPU,
  // A regular enumerator, not a sentinel: it converts to/from
  // DispatchKey::Undefined but has no DeviceType (backendToDeviceType fails).
  Undefined,
  MkldnnCPU,
  MPS,
  HPU,
  Lazy,
  PrivateUse1,
  // Sentinel, not a real backend; must remain the last enumerator.
  NumOptions
};
|
|
|
// Converts a DispatchKey into the Backend it belongs to.
//
// For CPU, CUDA, XLA, Lazy, MPS, IPU, XPU and HPU, the Autograd key (e.g.
// DispatchKey::AutogradCPU) resolves to the same Backend as the plain key.
// DispatchKey::Undefined maps to Backend::Undefined. Any key not listed
// below fails a TORCH_CHECK with "Unrecognized tensor type ID".
static inline Backend dispatchKeyToBackend(DispatchKey t) {
  switch (t) {
    case DispatchKey::CPU:
    case DispatchKey::AutogradCPU:
      return Backend::CPU;
    case DispatchKey::CUDA:
    case DispatchKey::AutogradCUDA:
      return Backend::CUDA;
    case DispatchKey::HIP:
      return Backend::HIP;
    case DispatchKey::VE:
      return Backend::VE;
    case DispatchKey::FPGA:
      return Backend::FPGA;
    case DispatchKey::ORT:
      return Backend::ORT;
    case DispatchKey::XLA:
    case DispatchKey::AutogradXLA:
      return Backend::XLA;
    case DispatchKey::Lazy:
    case DispatchKey::AutogradLazy:
      return Backend::Lazy;
    case DispatchKey::MPS:
    case DispatchKey::AutogradMPS:
      return Backend::MPS;
    case DispatchKey::Vulkan:
      return Backend::Vulkan;
    case DispatchKey::Metal:
      return Backend::Metal;
    case DispatchKey::Meta:
      return Backend::Meta;
    case DispatchKey::SparseCPU:
      return Backend::SparseCPU;
    case DispatchKey::SparseCUDA:
      return Backend::SparseCUDA;
    case DispatchKey::SparseHIP:
      return Backend::SparseHIP;
    case DispatchKey::SparseVE:
      return Backend::SparseVE;
    case DispatchKey::SparseCsrCPU:
      return Backend::SparseCsrCPU;
    case DispatchKey::SparseCsrCUDA:
      return Backend::SparseCsrCUDA;
    case DispatchKey::MkldnnCPU:
      return Backend::MkldnnCPU;
    case DispatchKey::QuantizedCPU:
      return Backend::QuantizedCPU;
    case DispatchKey::QuantizedCUDA:
      return Backend::QuantizedCUDA;
    case DispatchKey::IPU:
    case DispatchKey::AutogradIPU:
      return Backend::IPU;
    case DispatchKey::XPU:
    case DispatchKey::AutogradXPU:
      return Backend::XPU;
    case DispatchKey::SparseXPU:
      return Backend::SparseXPU;
    case DispatchKey::QuantizedXPU:
      return Backend::QuantizedXPU;
    case DispatchKey::HPU:
    case DispatchKey::AutogradHPU:
      return Backend::HPU;
    case DispatchKey::PrivateUse1:
      return Backend::PrivateUse1;
    case DispatchKey::Undefined:
      return Backend::Undefined;
    default:
      // Every remaining key has no Backend equivalent.
      break;
  }
  TORCH_CHECK(false, "Unrecognized tensor type ID: ", t);
}
|
|
|
// Converts a Backend into its canonical (non-autograd) DispatchKey.
//
// Every Backend enumerator except the NumOptions sentinel has a case here,
// so this inverts dispatchKeyToBackend for non-autograd keys.
//
// Throws std::runtime_error("Unknown backend") for any value without a
// mapping (e.g. Backend::NumOptions). The exception type is kept as
// std::runtime_error so existing callers that catch it keep working.
static inline DispatchKey backendToDispatchKey(Backend b) {
  switch (b) {
    case Backend::CPU:
      return DispatchKey::CPU;
    case Backend::CUDA:
      return DispatchKey::CUDA;
    case Backend::HIP:
      return DispatchKey::HIP;
    case Backend::VE:
      return DispatchKey::VE;
    case Backend::FPGA:
      return DispatchKey::FPGA;
    case Backend::ORT:
      return DispatchKey::ORT;
    case Backend::XLA:
      return DispatchKey::XLA;
    case Backend::Lazy:
      return DispatchKey::Lazy;
    case Backend::IPU:
      return DispatchKey::IPU;
    case Backend::XPU:
      return DispatchKey::XPU;
    case Backend::SparseXPU:
      return DispatchKey::SparseXPU;
    case Backend::SparseCPU:
      return DispatchKey::SparseCPU;
    case Backend::SparseCUDA:
      return DispatchKey::SparseCUDA;
    case Backend::SparseHIP:
      return DispatchKey::SparseHIP;
    case Backend::SparseVE:
      return DispatchKey::SparseVE;
    case Backend::SparseCsrCPU:
      return DispatchKey::SparseCsrCPU;
    case Backend::SparseCsrCUDA:
      return DispatchKey::SparseCsrCUDA;
    case Backend::MkldnnCPU:
      return DispatchKey::MkldnnCPU;
    case Backend::Vulkan:
      return DispatchKey::Vulkan;
    case Backend::Metal:
      return DispatchKey::Metal;
    case Backend::Meta:
      return DispatchKey::Meta;
    case Backend::QuantizedCPU:
      return DispatchKey::QuantizedCPU;
    case Backend::QuantizedCUDA:
      return DispatchKey::QuantizedCUDA;
    case Backend::QuantizedXPU:
      // Needed so the QuantizedXPU Backend produced by dispatchKeyToBackend
      // converts back instead of falling into the "Unknown backend" throw.
      return DispatchKey::QuantizedXPU;
    case Backend::Undefined:
      return DispatchKey::Undefined;
    case Backend::MPS:
      return DispatchKey::MPS;
    case Backend::HPU:
      return DispatchKey::HPU;
    case Backend::PrivateUse1:
      return DispatchKey::PrivateUse1;
    default:
      throw std::runtime_error("Unknown backend");
  }
}
|
|
|
// Returns the DeviceType on which tensors of the given Backend live.
//
// Layout variants share their dense backend's device: sparse, sparse-CSR,
// quantized and Mkldnn backends all resolve to the matching CPU / CUDA /
// HIP / VE / XPU device type. Backend::Undefined has no device and fails a
// TORCH_CHECK, as does any value without a case (e.g. NumOptions).
static inline DeviceType backendToDeviceType(Backend b) {
  switch (b) {
    // CPU-resident backends.
    case Backend::CPU:
    case Backend::SparseCPU:
    case Backend::SparseCsrCPU:
    case Backend::MkldnnCPU:
    case Backend::QuantizedCPU:
      return DeviceType::CPU;

    // CUDA-resident backends.
    case Backend::CUDA:
    case Backend::SparseCUDA:
    case Backend::SparseCsrCUDA:
    case Backend::QuantizedCUDA:
      return DeviceType::CUDA;

    case Backend::HIP:
    case Backend::SparseHIP:
      return DeviceType::HIP;

    case Backend::VE:
    case Backend::SparseVE:
      return DeviceType::VE;

    case Backend::XPU:
    case Backend::SparseXPU:
    case Backend::QuantizedXPU:
      return DeviceType::XPU;

    // Backends with a one-to-one device type.
    case Backend::FPGA:
      return DeviceType::FPGA;
    case Backend::ORT:
      return DeviceType::ORT;
    case Backend::XLA:
      return DeviceType::XLA;
    case Backend::Lazy:
      return DeviceType::Lazy;
    case Backend::IPU:
      return DeviceType::IPU;
    case Backend::Vulkan:
      return DeviceType::Vulkan;
    case Backend::Metal:
      return DeviceType::Metal;
    case Backend::Meta:
      return DeviceType::Meta;
    case Backend::MPS:
      return DeviceType::MPS;
    case Backend::HPU:
      return DeviceType::HPU;
    case Backend::PrivateUse1:
      return DeviceType::PrivateUse1;

    case Backend::Undefined:
      TORCH_CHECK(false, "Undefined backend is not a valid device type");
    default:
      TORCH_CHECK(false, "Unknown backend");
  }
}
|
|
|
|
|
// Returns a static, human-readable name for the given Backend.
//
// The returned pointer refers to a string literal and never needs freeing.
// PrivateUse1 is spelled "PrivateUseOne". Values without a case here
// (Backend::Undefined and Backend::NumOptions) yield "UNKNOWN_BACKEND".
static inline const char* toString(Backend b) {
  switch (b) {
    // Listed in enum declaration order.
    case Backend::CPU:
      return "CPU";
    case Backend::CUDA:
      return "CUDA";
    case Backend::HIP:
      return "HIP";
    case Backend::VE:
      return "VE";
    case Backend::FPGA:
      return "FPGA";
    case Backend::IPU:
      return "IPU";
    case Backend::XPU:
      return "XPU";
    case Backend::SparseCPU:
      return "SparseCPU";
    case Backend::SparseCUDA:
      return "SparseCUDA";
    case Backend::SparseCsrCPU:
      return "SparseCsrCPU";
    case Backend::SparseCsrCUDA:
      return "SparseCsrCUDA";
    case Backend::SparseHIP:
      return "SparseHIP";
    case Backend::SparseVE:
      return "SparseVE";
    case Backend::SparseXPU:
      return "SparseXPU";
    case Backend::ORT:
      return "ORT";
    case Backend::XLA:
      return "XLA";
    case Backend::Vulkan:
      return "Vulkan";
    case Backend::Metal:
      return "Metal";
    case Backend::Meta:
      return "Meta";
    case Backend::QuantizedCPU:
      return "QuantizedCPU";
    case Backend::QuantizedCUDA:
      return "QuantizedCUDA";
    case Backend::QuantizedXPU:
      return "QuantizedXPU";
    case Backend::MkldnnCPU:
      return "MkldnnCPU";
    case Backend::MPS:
      return "MPS";
    case Backend::HPU:
      return "HPU";
    case Backend::Lazy:
      return "Lazy";
    case Backend::PrivateUse1:
      return "PrivateUseOne";
    default:
      return "UNKNOWN_BACKEND";
  }
}
|
|
|
// Reports whether `b` is one of the Sparse* backends (SparseCPU, SparseCUDA,
// SparseHIP, SparseVE, SparseXPU). The SparseCsr* backends are NOT included
// here; use isSparseCsr for those.
static inline bool isSparse(Backend b) {
  return b == Backend::SparseCPU || b == Backend::SparseCUDA ||
      b == Backend::SparseHIP || b == Backend::SparseVE ||
      b == Backend::SparseXPU;
}
|
|
|
// Reports whether `b` is one of the sparse CSR backends (SparseCsrCPU or
// SparseCsrCUDA). Every other Backend, including the Sparse* ones reported
// by isSparse, yields false.
static inline bool isSparseCsr(Backend b) {
  return b == Backend::SparseCsrCPU || b == Backend::SparseCsrCUDA;
}
|
|
|
} |
|
|