Add files using upload-large-folder tool
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- .venv/Lib/site-packages/torch/fx/__pycache__/__init__.cpython-39.pyc +0 -0
- .venv/Lib/site-packages/torch/fx/__pycache__/_compatibility.cpython-39.pyc +0 -0
- .venv/Lib/site-packages/torch/fx/__pycache__/_lazy_graph_module.cpython-39.pyc +0 -0
- .venv/Lib/site-packages/torch/fx/__pycache__/_pytree.cpython-39.pyc +0 -0
- .venv/Lib/site-packages/torch/fx/__pycache__/_symbolic_trace.cpython-39.pyc +0 -0
- .venv/Lib/site-packages/torch/fx/__pycache__/_utils.cpython-39.pyc +0 -0
- .venv/Lib/site-packages/torch/fx/__pycache__/traceback.cpython-39.pyc +0 -0
- .venv/Lib/site-packages/torch/include/ATen/ATen.h +37 -0
- .venv/Lib/site-packages/torch/include/ATen/AccumulateType.h +173 -0
- .venv/Lib/site-packages/torch/include/ATen/ArrayRef.h +2 -0
- .venv/Lib/site-packages/torch/include/ATen/Backend.h +2 -0
- .venv/Lib/site-packages/torch/include/ATen/CPUApplyUtils.h +343 -0
- .venv/Lib/site-packages/torch/include/ATen/CPUFixedAllocator.h +33 -0
- .venv/Lib/site-packages/torch/include/ATen/CPUFunctions.h +29 -0
- .venv/Lib/site-packages/torch/include/ATen/CPUFunctions_inl.h +540 -0
- .venv/Lib/site-packages/torch/include/ATen/CPUGeneratorImpl.h +49 -0
- .venv/Lib/site-packages/torch/include/ATen/CUDAFunctions.h +29 -0
- .venv/Lib/site-packages/torch/include/ATen/CUDAFunctions_inl.h +623 -0
- .venv/Lib/site-packages/torch/include/ATen/CompositeExplicitAutogradNonFunctionalFunctions.h +29 -0
- .venv/Lib/site-packages/torch/include/ATen/CompositeExplicitAutogradNonFunctionalFunctions_inl.h +323 -0
- .venv/Lib/site-packages/torch/include/ATen/CompositeImplicitAutogradFunctions.h +29 -0
- .venv/Lib/site-packages/torch/include/ATen/CompositeImplicitAutogradFunctions_inl.h +502 -0
- .venv/Lib/site-packages/torch/include/ATen/CompositeImplicitAutogradNestedTensorFunctions.h +29 -0
- .venv/Lib/site-packages/torch/include/ATen/CompositeImplicitAutogradNestedTensorFunctions_inl.h +25 -0
- .venv/Lib/site-packages/torch/include/ATen/Config.h +21 -0
- .venv/Lib/site-packages/torch/include/ATen/Context.h +610 -0
- .venv/Lib/site-packages/torch/include/ATen/Device.h +2 -0
- .venv/Lib/site-packages/torch/include/ATen/DeviceAccelerator.h +27 -0
- .venv/Lib/site-packages/torch/include/ATen/DeviceGuard.h +41 -0
- .venv/Lib/site-packages/torch/include/ATen/DimVector.h +2 -0
- .venv/Lib/site-packages/torch/include/ATen/Dimname.h +1 -0
- .venv/Lib/site-packages/torch/include/ATen/Dispatch.h +808 -0
- .venv/Lib/site-packages/torch/include/ATen/FuncTorchTLS.h +46 -0
- .venv/Lib/site-packages/torch/include/ATen/FunctionalTensorWrapper.h +454 -0
- .venv/Lib/site-packages/torch/include/ATen/Functions.h +1454 -0
- .venv/Lib/site-packages/torch/include/ATen/Generator.h +2 -0
- .venv/Lib/site-packages/torch/include/ATen/InferSize.h +88 -0
- .venv/Lib/site-packages/torch/include/ATen/InitialTensorOptions.h +15 -0
- .venv/Lib/site-packages/torch/include/ATen/Layout.h +2 -0
- .venv/Lib/site-packages/torch/include/ATen/LegacyBatchedFallback.h +25 -0
- .venv/Lib/site-packages/torch/include/ATen/LegacyBatchedTensorImpl.h +160 -0
- .venv/Lib/site-packages/torch/include/ATen/LegacyVmapMode.h +26 -0
- .venv/Lib/site-packages/torch/include/ATen/LegacyVmapTransforms.h +183 -0
- .venv/Lib/site-packages/torch/include/ATen/LinalgBackend.h +31 -0
- .venv/Lib/site-packages/torch/include/ATen/MapAllocator.h +143 -0
- .venv/Lib/site-packages/torch/include/ATen/MatrixRef.h +107 -0
- .venv/Lib/site-packages/torch/include/ATen/MemoryOverlap.h +42 -0
- .venv/Lib/site-packages/torch/include/ATen/MetaFunctions.h +29 -0
- .venv/Lib/site-packages/torch/include/ATen/MetaFunctions_inl.h +325 -0
- .venv/Lib/site-packages/torch/include/ATen/MethodOperators.h +443 -0
.venv/Lib/site-packages/torch/fx/__pycache__/__init__.cpython-39.pyc
ADDED
|
Binary file (4.12 kB). View file
|
|
|
.venv/Lib/site-packages/torch/fx/__pycache__/_compatibility.cpython-39.pyc
ADDED
|
Binary file (1.32 kB). View file
|
|
|
.venv/Lib/site-packages/torch/fx/__pycache__/_lazy_graph_module.cpython-39.pyc
ADDED
|
Binary file (6.52 kB). View file
|
|
|
.venv/Lib/site-packages/torch/fx/__pycache__/_pytree.cpython-39.pyc
ADDED
|
Binary file (3.75 kB). View file
|
|
|
.venv/Lib/site-packages/torch/fx/__pycache__/_symbolic_trace.cpython-39.pyc
ADDED
|
Binary file (36.7 kB). View file
|
|
|
.venv/Lib/site-packages/torch/fx/__pycache__/_utils.cpython-39.pyc
ADDED
|
Binary file (2.07 kB). View file
|
|
|
.venv/Lib/site-packages/torch/fx/__pycache__/traceback.cpython-39.pyc
ADDED
|
Binary file (2.37 kB). View file
|
|
|
.venv/Lib/site-packages/torch/include/ATen/ATen.h
ADDED
|
@@ -0,0 +1,37 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#if !defined(_MSC_VER) && __cplusplus < 201703L
|
| 4 |
+
#error C++17 or later compatible compiler is required to use ATen.
|
| 5 |
+
#endif
|
| 6 |
+
|
| 7 |
+
#include <ATen/Context.h>
|
| 8 |
+
#include <ATen/Device.h>
|
| 9 |
+
#include <ATen/DeviceGuard.h>
|
| 10 |
+
#include <ATen/DimVector.h>
|
| 11 |
+
#include <ATen/Dispatch.h>
|
| 12 |
+
#include <ATen/Formatting.h>
|
| 13 |
+
#include <ATen/Functions.h>
|
| 14 |
+
#include <ATen/NamedTensor.h>
|
| 15 |
+
#include <ATen/ScalarOps.h>
|
| 16 |
+
#include <ATen/Tensor.h>
|
| 17 |
+
#include <ATen/TensorGeometry.h>
|
| 18 |
+
#include <ATen/TensorIndexing.h>
|
| 19 |
+
#include <ATen/TensorOperators.h>
|
| 20 |
+
#include <ATen/Version.h>
|
| 21 |
+
#include <ATen/core/ATenGeneral.h>
|
| 22 |
+
#include <ATen/core/Generator.h>
|
| 23 |
+
#include <ATen/core/Reduction.h>
|
| 24 |
+
#include <ATen/core/Scalar.h>
|
| 25 |
+
#include <ATen/core/UnsafeFromTH.h>
|
| 26 |
+
#include <ATen/core/ivalue.h>
|
| 27 |
+
#include <ATen/core/jit_type.h>
|
| 28 |
+
#include <c10/core/Allocator.h>
|
| 29 |
+
#include <c10/core/InferenceMode.h>
|
| 30 |
+
#include <c10/core/Layout.h>
|
| 31 |
+
#include <c10/core/Storage.h>
|
| 32 |
+
#include <c10/core/TensorOptions.h>
|
| 33 |
+
#include <c10/util/Exception.h>
|
| 34 |
+
|
| 35 |
+
// TODO: try to remove this
|
| 36 |
+
// There is some back story, see https://github.com/pytorch/pytorch/issues/48684
|
| 37 |
+
#include <ATen/NativeFunctions.h>
|
.venv/Lib/site-packages/torch/include/ATen/AccumulateType.h
ADDED
|
@@ -0,0 +1,173 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
#include <ATen/Config.h>
|
| 3 |
+
#include <c10/core/DeviceType.h>
|
| 4 |
+
#include <c10/core/ScalarType.h>
|
| 5 |
+
#include <c10/util/BFloat16.h>
|
| 6 |
+
#include <c10/util/Float8_e4m3fn.h>
|
| 7 |
+
#include <c10/util/Float8_e4m3fnuz.h>
|
| 8 |
+
#include <c10/util/Float8_e5m2.h>
|
| 9 |
+
#include <c10/util/Float8_e5m2fnuz.h>
|
| 10 |
+
#include <c10/util/Half.h>
|
| 11 |
+
|
| 12 |
+
// Defines the accumulation type for a scalar type.
|
| 13 |
+
// Example:
|
| 14 |
+
// using accscalar_t = acc_type<scalar_t, /*is_cuda*/true>;
|
| 15 |
+
//
|
| 16 |
+
// Accumulation types are an important concept in numeric computing
|
| 17 |
+
// because you frequently want to perform intermediate computations
|
| 18 |
+
// at a higher precision than the input and output precision, to avoid
|
| 19 |
+
// compounding internal rounding errors. Accumulation is the most
|
| 20 |
+
// well-known intermediate computation (it is of great importance for
|
| 21 |
+
// sum reduction and matrix multiply, for example), but in PyTorch
|
| 22 |
+
// acc_type ends up getting used for all sorts of other intermediate
|
| 23 |
+
// computations, so it perhaps would be more accurately (ahem) called an
|
| 24 |
+
// "accurate" type. acc_type is especially important for reduced
|
| 25 |
+
// precision operations like float16 and bfloat16, where relatively
|
| 26 |
+
// benign looking inputs can easily end up overflowing/underflowing.
|
| 27 |
+
//
|
| 28 |
+
// acc_type is parametrized by whether or not you are running on CUDA
|
| 29 |
+
// or not, because on CUDA double precision operations are expensive
|
| 30 |
+
// and so by default, we don't actually want to use double as an
|
| 31 |
+
// acc_type on CUDA. A lot of things are typed out below, but
|
| 32 |
+
// basically, the table is generated by a few rules:
|
| 33 |
+
//
|
| 34 |
+
// If bool:
|
| 35 |
+
// Use 'bool' as acc_type.
|
| 36 |
+
// If floating point:
|
| 37 |
+
// If CUDA, use 'float' as acc_type (unless scalar_t is double),
|
| 38 |
+
// otherwise (CPU) use 'double'
|
| 39 |
+
// If integral:
|
| 40 |
+
// Use 'int64_t' as acc_type
|
| 41 |
+
//
|
| 42 |
+
// You're not forced to use this template; if you happen to know
|
| 43 |
+
// something specific about your use case, you can specify your own
|
| 44 |
+
// desired behavior. This template, however, will give you a reasonable
|
| 45 |
+
// default that will work for all dtypes supported in PyTorch.
|
| 46 |
+
|
| 47 |
+
#if defined(__CUDACC__)
|
| 48 |
+
#include <cuda.h>
|
| 49 |
+
#include <cuda_fp16.h>
|
| 50 |
+
#elif defined(__HIPCC__)
|
| 51 |
+
#include <hip/hip_fp16.h>
|
| 52 |
+
#include <hip/hip_runtime.h>
|
| 53 |
+
#endif
|
| 54 |
+
|
| 55 |
+
namespace at {
|
| 56 |
+
|
| 57 |
+
template <typename T, c10::DeviceType D>
|
| 58 |
+
struct AccumulateTypeDevice {};
|
| 59 |
+
|
| 60 |
+
template <typename T, bool>
|
| 61 |
+
struct AccumulateType {};
|
| 62 |
+
|
| 63 |
+
template <typename T>
|
| 64 |
+
struct AccumulateType<T, false> {
|
| 65 |
+
using type = typename AccumulateTypeDevice<T, c10::DeviceType::CPU>::type;
|
| 66 |
+
};
|
| 67 |
+
|
| 68 |
+
template <typename T>
|
| 69 |
+
struct AccumulateType<T, true> {
|
| 70 |
+
using type = typename AccumulateTypeDevice<T, c10::DeviceType::CUDA>::type;
|
| 71 |
+
};
|
| 72 |
+
|
| 73 |
+
template <typename T, c10::DeviceType device>
|
| 74 |
+
using acc_type_device = typename AccumulateTypeDevice<T, device>::type;
|
| 75 |
+
|
| 76 |
+
template <typename T, bool is_cuda>
|
| 77 |
+
using acc_type = typename AccumulateType<T, is_cuda>::type;
|
| 78 |
+
|
| 79 |
+
#define ACC_TYPE(t, acc_t, device_type) \
|
| 80 |
+
template <> \
|
| 81 |
+
struct AccumulateTypeDevice<t, device_type> { \
|
| 82 |
+
using type = acc_t; \
|
| 83 |
+
};
|
| 84 |
+
#define MPS_ACC_TYPE(t, acc_t) ACC_TYPE(t, acc_t, c10::DeviceType::MPS)
|
| 85 |
+
#define XPU_ACC_TYPE(t, acc_t) ACC_TYPE(t, acc_t, c10::DeviceType::XPU)
|
| 86 |
+
#define CUDA_ACC_TYPE(t, acc_t) ACC_TYPE(t, acc_t, c10::DeviceType::CUDA)
|
| 87 |
+
#define CPU_ACC_TYPE(t, acc_t) ACC_TYPE(t, acc_t, c10::DeviceType::CPU)
|
| 88 |
+
|
| 89 |
+
MPS_ACC_TYPE(BFloat16, float);
|
| 90 |
+
MPS_ACC_TYPE(Half, float);
|
| 91 |
+
MPS_ACC_TYPE(Float8_e5m2, float);
|
| 92 |
+
MPS_ACC_TYPE(Float8_e4m3fn, float);
|
| 93 |
+
MPS_ACC_TYPE(Float8_e5m2fnuz, float);
|
| 94 |
+
MPS_ACC_TYPE(Float8_e4m3fnuz, float);
|
| 95 |
+
MPS_ACC_TYPE(float, float);
|
| 96 |
+
MPS_ACC_TYPE(double, float);
|
| 97 |
+
MPS_ACC_TYPE(int8_t, int64_t);
|
| 98 |
+
MPS_ACC_TYPE(uint8_t, int64_t);
|
| 99 |
+
MPS_ACC_TYPE(char, int64_t);
|
| 100 |
+
MPS_ACC_TYPE(int16_t, int64_t);
|
| 101 |
+
MPS_ACC_TYPE(int32_t, int64_t);
|
| 102 |
+
MPS_ACC_TYPE(int64_t, int64_t);
|
| 103 |
+
MPS_ACC_TYPE(bool, bool);
|
| 104 |
+
MPS_ACC_TYPE(c10::complex<Half>, c10::complex<float>);
|
| 105 |
+
MPS_ACC_TYPE(c10::complex<float>, c10::complex<float>);
|
| 106 |
+
MPS_ACC_TYPE(c10::complex<double>, c10::complex<float>);
|
| 107 |
+
|
| 108 |
+
XPU_ACC_TYPE(BFloat16, float);
|
| 109 |
+
XPU_ACC_TYPE(Half, float);
|
| 110 |
+
XPU_ACC_TYPE(Float8_e5m2, float);
|
| 111 |
+
XPU_ACC_TYPE(Float8_e4m3fn, float);
|
| 112 |
+
XPU_ACC_TYPE(Float8_e5m2fnuz, float);
|
| 113 |
+
XPU_ACC_TYPE(Float8_e4m3fnuz, float);
|
| 114 |
+
XPU_ACC_TYPE(float, float);
|
| 115 |
+
XPU_ACC_TYPE(double, double);
|
| 116 |
+
XPU_ACC_TYPE(int8_t, int64_t);
|
| 117 |
+
XPU_ACC_TYPE(uint8_t, int64_t);
|
| 118 |
+
XPU_ACC_TYPE(char, int64_t);
|
| 119 |
+
XPU_ACC_TYPE(int16_t, int64_t);
|
| 120 |
+
XPU_ACC_TYPE(int32_t, int64_t);
|
| 121 |
+
XPU_ACC_TYPE(int64_t, int64_t);
|
| 122 |
+
XPU_ACC_TYPE(bool, bool);
|
| 123 |
+
XPU_ACC_TYPE(c10::complex<Half>, c10::complex<float>);
|
| 124 |
+
XPU_ACC_TYPE(c10::complex<float>, c10::complex<float>);
|
| 125 |
+
XPU_ACC_TYPE(c10::complex<double>, c10::complex<double>);
|
| 126 |
+
|
| 127 |
+
#if defined(__CUDACC__) || defined(__HIPCC__)
|
| 128 |
+
CUDA_ACC_TYPE(half, float);
|
| 129 |
+
#endif
|
| 130 |
+
CUDA_ACC_TYPE(BFloat16, float);
|
| 131 |
+
CUDA_ACC_TYPE(Half, float);
|
| 132 |
+
CUDA_ACC_TYPE(Float8_e5m2, float);
|
| 133 |
+
CUDA_ACC_TYPE(Float8_e4m3fn, float);
|
| 134 |
+
CUDA_ACC_TYPE(Float8_e5m2fnuz, float);
|
| 135 |
+
CUDA_ACC_TYPE(Float8_e4m3fnuz, float);
|
| 136 |
+
CUDA_ACC_TYPE(float, float);
|
| 137 |
+
CUDA_ACC_TYPE(double, double);
|
| 138 |
+
CUDA_ACC_TYPE(int8_t, int64_t);
|
| 139 |
+
CUDA_ACC_TYPE(uint8_t, int64_t);
|
| 140 |
+
CUDA_ACC_TYPE(char, int64_t);
|
| 141 |
+
CUDA_ACC_TYPE(int16_t, int64_t);
|
| 142 |
+
CUDA_ACC_TYPE(int32_t, int64_t);
|
| 143 |
+
CUDA_ACC_TYPE(int64_t, int64_t);
|
| 144 |
+
CUDA_ACC_TYPE(bool, bool);
|
| 145 |
+
CUDA_ACC_TYPE(c10::complex<Half>, c10::complex<float>);
|
| 146 |
+
CUDA_ACC_TYPE(c10::complex<float>, c10::complex<float>);
|
| 147 |
+
CUDA_ACC_TYPE(c10::complex<double>, c10::complex<double>);
|
| 148 |
+
|
| 149 |
+
CPU_ACC_TYPE(BFloat16, float);
|
| 150 |
+
CPU_ACC_TYPE(Half, float);
|
| 151 |
+
CPU_ACC_TYPE(Float8_e5m2, float);
|
| 152 |
+
CPU_ACC_TYPE(Float8_e4m3fn, float);
|
| 153 |
+
CPU_ACC_TYPE(Float8_e5m2fnuz, float);
|
| 154 |
+
CPU_ACC_TYPE(Float8_e4m3fnuz, float);
|
| 155 |
+
CPU_ACC_TYPE(float, double);
|
| 156 |
+
CPU_ACC_TYPE(double, double);
|
| 157 |
+
CPU_ACC_TYPE(int8_t, int64_t);
|
| 158 |
+
CPU_ACC_TYPE(uint8_t, int64_t);
|
| 159 |
+
CPU_ACC_TYPE(char, int64_t);
|
| 160 |
+
CPU_ACC_TYPE(int16_t, int64_t);
|
| 161 |
+
CPU_ACC_TYPE(int32_t, int64_t);
|
| 162 |
+
CPU_ACC_TYPE(int64_t, int64_t);
|
| 163 |
+
CPU_ACC_TYPE(bool, bool);
|
| 164 |
+
CPU_ACC_TYPE(c10::complex<Half>, c10::complex<float>);
|
| 165 |
+
CPU_ACC_TYPE(c10::complex<float>, c10::complex<double>);
|
| 166 |
+
CPU_ACC_TYPE(c10::complex<double>, c10::complex<double>);
|
| 167 |
+
|
| 168 |
+
TORCH_API c10::ScalarType toAccumulateType(
|
| 169 |
+
c10::ScalarType type,
|
| 170 |
+
c10::DeviceType device);
|
| 171 |
+
TORCH_API c10::ScalarType toAccumulateType(c10::ScalarType type, bool is_cuda);
|
| 172 |
+
|
| 173 |
+
} // namespace at
|
.venv/Lib/site-packages/torch/include/ATen/ArrayRef.h
ADDED
|
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
#include <c10/util/ArrayRef.h>
|
.venv/Lib/site-packages/torch/include/ATen/Backend.h
ADDED
|
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
#include <c10/core/Backend.h>
|
.venv/Lib/site-packages/torch/include/ATen/CPUApplyUtils.h
ADDED
|
@@ -0,0 +1,343 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#include <ATen/CollapseDims.h>
|
| 4 |
+
#include <ATen/Parallel.h>
|
| 5 |
+
#include <ATen/TensorUtils.h>
|
| 6 |
+
#include <c10/util/irange.h>
|
| 7 |
+
#include <cstring>
|
| 8 |
+
#include <limits>
|
| 9 |
+
|
| 10 |
+
namespace at {
|
| 11 |
+
|
| 12 |
+
/*
|
| 13 |
+
* The basic strategy for apply is as follows:
|
| 14 |
+
*
|
| 15 |
+
* 1. Starting with the outermost index, loop until we reach a dimension where
|
| 16 |
+
* the data is no longer contiguous, i.e. the stride at that dimension is not
|
| 17 |
+
* equal to the size of the tensor defined by the outer dimensions. Let's call
|
| 18 |
+
* this outer (contiguous) tensor A. Note that if the Tensor is contiguous, then
|
| 19 |
+
* A is equal to the entire Tensor. Let's call the inner tensor B.
|
| 20 |
+
*
|
| 21 |
+
* 2. We loop through the indices in B, starting at its outermost dimension. For
|
| 22 |
+
* example, if B is a 2x2 matrix, then we do:
|
| 23 |
+
*
|
| 24 |
+
* B[0][0]
|
| 25 |
+
* B[0][1]
|
| 26 |
+
* B[1][0]
|
| 27 |
+
* B[1][1]
|
| 28 |
+
*
|
| 29 |
+
* We set the offset into the underlying storage as (storageOffset + stride_B *
|
| 30 |
+
* index_B), i.e. basically we compute the offset into the storage as we would
|
| 31 |
+
* normally for a Tensor. But because we are guaranteed the subsequent data is
|
| 32 |
+
* contiguous in memory, we can simply loop for sizeof(A) iterations and perform
|
| 33 |
+
* the operation, without having to follow the order described by the strides of
|
| 34 |
+
* A.
|
| 35 |
+
*
|
| 36 |
+
* 3. As an optimization, we merge dimensions of A that are contiguous in
|
| 37 |
+
* memory. For example, if A is a 3x3x3x3 tensor narrowed from a 3x3x4x3 tensor,
|
| 38 |
+
* then the first two dimensions can be merged for the purposes of APPLY,
|
| 39 |
+
* reducing the number of nested loops.
|
| 40 |
+
*/
|
| 41 |
+
|
| 42 |
+
inline Tensor sort_strides(Tensor& tensor_) {
|
| 43 |
+
IntArrayRef strides = tensor_.strides();
|
| 44 |
+
std::vector<int64_t> indices;
|
| 45 |
+
indices.reserve(tensor_.ndimension());
|
| 46 |
+
for (const auto i : c10::irange(tensor_.ndimension())) {
|
| 47 |
+
indices.push_back(i);
|
| 48 |
+
}
|
| 49 |
+
std::sort(indices.begin(), indices.end(), [&strides](int64_t i1, int64_t i2) {
|
| 50 |
+
return strides[i1] > strides[i2];
|
| 51 |
+
});
|
| 52 |
+
Tensor tensor = tensor_.permute(indices);
|
| 53 |
+
return tensor;
|
| 54 |
+
}
|
| 55 |
+
|
| 56 |
+
template <typename T, int N>
|
| 57 |
+
struct strided_tensor_iter_fixed {
|
| 58 |
+
public:
|
| 59 |
+
T* data_ = NULL;
|
| 60 |
+
int64_t dim_ = 0;
|
| 61 |
+
|
| 62 |
+
int64_t counter_[N] = {0};
|
| 63 |
+
int64_t sizes_[N] = {0};
|
| 64 |
+
int64_t strides_[N] = {0};
|
| 65 |
+
|
| 66 |
+
strided_tensor_iter_fixed(strided_tensor_iter_fixed const&) = delete;
|
| 67 |
+
void operator=(strided_tensor_iter_fixed const& x) = delete;
|
| 68 |
+
strided_tensor_iter_fixed(strided_tensor_iter_fixed&&) = default;
|
| 69 |
+
strided_tensor_iter_fixed(
|
| 70 |
+
Tensor& tensor,
|
| 71 |
+
C10_UNUSED bool sort_strides = false)
|
| 72 |
+
: data_(tensor.data_ptr<T>()) {
|
| 73 |
+
std::memset(counter_, 0, sizeof(int64_t) * N);
|
| 74 |
+
if (tensor.dim() > 0) {
|
| 75 |
+
std::memcpy(
|
| 76 |
+
sizes_, tensor.sizes().data(), tensor.dim() * sizeof(int64_t));
|
| 77 |
+
std::memcpy(
|
| 78 |
+
strides_, tensor.strides().data(), tensor.dim() * sizeof(int64_t));
|
| 79 |
+
}
|
| 80 |
+
dim_ = std::get<1>(collapse_dims(sizes_, strides_, tensor.ndimension()));
|
| 81 |
+
}
|
| 82 |
+
};
|
| 83 |
+
|
| 84 |
+
template <typename T>
|
| 85 |
+
struct strided_tensor_iter {
|
| 86 |
+
private:
|
| 87 |
+
public:
|
| 88 |
+
T* data_ = NULL;
|
| 89 |
+
int64_t dim_;
|
| 90 |
+
|
| 91 |
+
std::vector<int64_t> counter_;
|
| 92 |
+
std::vector<int64_t> sizes_;
|
| 93 |
+
std::vector<int64_t> strides_;
|
| 94 |
+
|
| 95 |
+
strided_tensor_iter(strided_tensor_iter const&) = delete;
|
| 96 |
+
void operator=(strided_tensor_iter const& x) = delete;
|
| 97 |
+
strided_tensor_iter(strided_tensor_iter&&) = default;
|
| 98 |
+
strided_tensor_iter(Tensor& tensor)
|
| 99 |
+
: data_(tensor.data_ptr<T>()),
|
| 100 |
+
dim_(tensor.ndimension()),
|
| 101 |
+
counter_(dim_, 0),
|
| 102 |
+
sizes_(tensor.sizes().vec()),
|
| 103 |
+
strides_(tensor.strides().vec()) {
|
| 104 |
+
dim_ = std::get<1>(collapse_dims(sizes_.data(), strides_.data(), dim_));
|
| 105 |
+
}
|
| 106 |
+
};
|
| 107 |
+
|
| 108 |
+
inline bool _all_equal_numel(at::ArrayRef<Tensor> tensors) {
|
| 109 |
+
if (tensors.empty())
|
| 110 |
+
return true;
|
| 111 |
+
int64_t all_numel = tensors[0].numel();
|
| 112 |
+
for (const auto i : c10::irange(1, tensors.size())) {
|
| 113 |
+
if (tensors[i].numel() != all_numel)
|
| 114 |
+
return false;
|
| 115 |
+
}
|
| 116 |
+
return true;
|
| 117 |
+
}
|
| 118 |
+
|
| 119 |
+
inline std::string _all_equal_numel_error(at::ArrayRef<Tensor> tensors) {
|
| 120 |
+
std::ostringstream oss;
|
| 121 |
+
oss << "inconsistent tensor size, expected ";
|
| 122 |
+
for (size_t i = 0; i < tensors.size() - 1; i++) {
|
| 123 |
+
oss << tensors[i].sizes() << ", ";
|
| 124 |
+
}
|
| 125 |
+
oss << "and " << tensors[tensors.size() - 1].sizes()
|
| 126 |
+
<< " to have the same number of elements, but got ";
|
| 127 |
+
for (size_t i = 0; i < tensors.size() - 1; i++) {
|
| 128 |
+
oss << tensors[i].numel() << ", ";
|
| 129 |
+
}
|
| 130 |
+
oss << "and " << tensors[tensors.size() - 1].numel()
|
| 131 |
+
<< " elements respectively";
|
| 132 |
+
return oss.str();
|
| 133 |
+
}
|
| 134 |
+
|
| 135 |
+
inline bool _apply_preamble(ArrayRef<Tensor> tensors) {
|
| 136 |
+
checkDeviceType("CPU_tensor_apply", tensors, kCPU);
|
| 137 |
+
checkLayout("CPU_tensor_apply", tensors, kStrided);
|
| 138 |
+
if (!_all_equal_numel(tensors))
|
| 139 |
+
AT_ERROR(_all_equal_numel_error(tensors));
|
| 140 |
+
// An empty tensor has no elements
|
| 141 |
+
for (auto& t : tensors)
|
| 142 |
+
if (t.numel() == 0)
|
| 143 |
+
return false;
|
| 144 |
+
return true;
|
| 145 |
+
}
|
| 146 |
+
|
| 147 |
+
inline int64_t _max_dim_tensors(ArrayRef<Tensor> tensors) {
|
| 148 |
+
int64_t dim = 0;
|
| 149 |
+
for (auto& t : tensors)
|
| 150 |
+
dim = std::max(dim, t.ndimension());
|
| 151 |
+
return dim;
|
| 152 |
+
}
|
| 153 |
+
|
| 154 |
+
inline void iterate(int64_t /*size*/){};
|
| 155 |
+
|
| 156 |
+
template <typename Arg, typename... Args>
|
| 157 |
+
inline void iterate(int64_t size, Arg& iter, Args&... iter_tail) {
|
| 158 |
+
iter.counter_[iter.dim_ - 1] += size;
|
| 159 |
+
iter.data_ = iter.data_ + size * iter.strides_[iter.dim_ - 1];
|
| 160 |
+
iterate(size, iter_tail...);
|
| 161 |
+
}
|
| 162 |
+
|
| 163 |
+
inline bool iterate_continue() {
|
| 164 |
+
return true;
|
| 165 |
+
};
|
| 166 |
+
|
| 167 |
+
template <typename Arg, typename... Args>
|
| 168 |
+
inline bool iterate_continue(Arg& iter, Args&... iter_tail) {
|
| 169 |
+
return iter.counter_[iter.dim_ - 1] < iter.sizes_[iter.dim_ - 1] &&
|
| 170 |
+
iterate_continue(iter_tail...);
|
| 171 |
+
}
|
| 172 |
+
|
| 173 |
+
inline int64_t max_iterate_size() {
|
| 174 |
+
return std::numeric_limits<int64_t>::max();
|
| 175 |
+
};
|
| 176 |
+
|
| 177 |
+
template <typename Arg, typename... Args>
|
| 178 |
+
inline int64_t max_iterate_size(Arg& iter, Args&... iter_tail) {
|
| 179 |
+
return std::min(
|
| 180 |
+
(iter.sizes_[iter.dim_ - 1] - iter.counter_[iter.dim_ - 1]),
|
| 181 |
+
max_iterate_size(iter_tail...));
|
| 182 |
+
}
|
| 183 |
+
|
| 184 |
+
inline void iterate_overflow(){};
|
| 185 |
+
|
| 186 |
+
template <typename Arg, typename... Args>
|
| 187 |
+
inline void iterate_overflow(Arg& iter, Args&... iter_tail) {
|
| 188 |
+
if (iter.counter_[iter.dim_ - 1] == iter.sizes_[iter.dim_ - 1]) {
|
| 189 |
+
for (int64_t i = iter.dim_ - 1; i > 0; i--) {
|
| 190 |
+
if (iter.counter_[i] == iter.sizes_[i]) {
|
| 191 |
+
iter.counter_[i] = 0;
|
| 192 |
+
iter.counter_[i - 1]++;
|
| 193 |
+
iter.data_ = iter.data_ - (iter.sizes_[i] * iter.strides_[i]) +
|
| 194 |
+
iter.strides_[i - 1];
|
| 195 |
+
}
|
| 196 |
+
}
|
| 197 |
+
}
|
| 198 |
+
iterate_overflow(iter_tail...);
|
| 199 |
+
}
|
| 200 |
+
|
| 201 |
+
inline void forward(int64_t /*offset*/){};
|
| 202 |
+
|
| 203 |
+
template <typename Arg, typename... Args>
|
| 204 |
+
inline void forward(int64_t offset, Arg& iter, Args&... iter_tail) {
|
| 205 |
+
int64_t multi = offset;
|
| 206 |
+
for (int64_t i = iter.dim_ - 1; i >= 0; i--) {
|
| 207 |
+
int64_t inc = multi % iter.sizes_[i];
|
| 208 |
+
multi = multi / iter.sizes_[i];
|
| 209 |
+
iter.data_ = iter.data_ + inc * iter.strides_[i];
|
| 210 |
+
iter.counter_[i] += inc;
|
| 211 |
+
}
|
| 212 |
+
forward(offset, iter_tail...);
|
| 213 |
+
}
|
| 214 |
+
|
| 215 |
+
inline int64_t max_dim() {
|
| 216 |
+
return 0;
|
| 217 |
+
}
|
| 218 |
+
|
| 219 |
+
template <typename Arg, typename... Args>
|
| 220 |
+
inline int64_t max_dim(Arg& iter, Args&... iter_tail) {
|
| 221 |
+
return std::max(iter.dim_, max_dim(iter_tail...));
|
| 222 |
+
}
|
| 223 |
+
|
| 224 |
+
inline void apply_op(){};
|
| 225 |
+
|
| 226 |
+
template <typename Op, typename... Args>
|
| 227 |
+
inline void apply_op(
|
| 228 |
+
int64_t numel,
|
| 229 |
+
int64_t offset,
|
| 230 |
+
const Op& op,
|
| 231 |
+
Args... iters) {
|
| 232 |
+
// For 0-dim tensors
|
| 233 |
+
if (numel == 1 && max_dim(iters...) == 0) {
|
| 234 |
+
op(*iters.data_...);
|
| 235 |
+
return;
|
| 236 |
+
}
|
| 237 |
+
if (offset > 0)
|
| 238 |
+
forward(offset, iters...);
|
| 239 |
+
// Splitting this into chunks helps the compiler create faster assembly
|
| 240 |
+
for (int64_t i = 0; i < numel;) {
|
| 241 |
+
for (; iterate_continue(iters...) && i < numel;) {
|
| 242 |
+
op(*iters.data_...);
|
| 243 |
+
iterate(1, iters...);
|
| 244 |
+
i++;
|
| 245 |
+
}
|
| 246 |
+
iterate_overflow(iters...);
|
| 247 |
+
}
|
| 248 |
+
}
|
| 249 |
+
|
| 250 |
+
/*
|
| 251 |
+
Apply a pointwise operator to sequence of tensors
|
| 252 |
+
|
| 253 |
+
The calling convention for op is a function/functor that takes the same
|
| 254 |
+
number of pointers of type scalar as the number of given tensors. For example,
|
| 255 |
+
to compute a = b * c, op would be of the form:
|
| 256 |
+
[](scalar* a_val, const scalar* b_val, const scalar* c_val) { a_val[0] =
|
| 257 |
+
b_val[0] * c_val[0]; };
|
| 258 |
+
*/
|
| 259 |
+
|
| 260 |
+
template <typename scalar1, typename scalar2, typename Op>
|
| 261 |
+
inline void CPU_tensor_apply2(Tensor tensor1, Tensor tensor2, const Op op) {
|
| 262 |
+
if (!_apply_preamble({tensor1, tensor2}))
|
| 263 |
+
return;
|
| 264 |
+
if (_max_dim_tensors({tensor1, tensor2}) <= 8) {
|
| 265 |
+
apply_op(
|
| 266 |
+
tensor1.numel(),
|
| 267 |
+
0,
|
| 268 |
+
op,
|
| 269 |
+
strided_tensor_iter_fixed<scalar1, 8>(tensor1),
|
| 270 |
+
strided_tensor_iter_fixed<scalar2, 8>(tensor2));
|
| 271 |
+
} else {
|
| 272 |
+
apply_op(
|
| 273 |
+
tensor1.numel(),
|
| 274 |
+
0,
|
| 275 |
+
op,
|
| 276 |
+
strided_tensor_iter<scalar1>(tensor1),
|
| 277 |
+
strided_tensor_iter<scalar2>(tensor2));
|
| 278 |
+
}
|
| 279 |
+
}
|
| 280 |
+
|
| 281 |
+
template <typename scalar1, typename scalar2, typename scalar3, typename Op>
|
| 282 |
+
inline void CPU_tensor_apply3(
|
| 283 |
+
Tensor tensor1,
|
| 284 |
+
Tensor tensor2,
|
| 285 |
+
Tensor tensor3,
|
| 286 |
+
const Op op) {
|
| 287 |
+
if (!_apply_preamble({tensor1, tensor2, tensor3}))
|
| 288 |
+
return;
|
| 289 |
+
if (_max_dim_tensors({tensor1, tensor2, tensor3}) <= 8) {
|
| 290 |
+
apply_op(
|
| 291 |
+
tensor1.numel(),
|
| 292 |
+
0,
|
| 293 |
+
op,
|
| 294 |
+
strided_tensor_iter_fixed<scalar1, 8>(tensor1),
|
| 295 |
+
strided_tensor_iter_fixed<scalar2, 8>(tensor2),
|
| 296 |
+
strided_tensor_iter_fixed<scalar3, 8>(tensor3));
|
| 297 |
+
} else {
|
| 298 |
+
apply_op(
|
| 299 |
+
tensor1.numel(),
|
| 300 |
+
0,
|
| 301 |
+
op,
|
| 302 |
+
strided_tensor_iter<scalar1>(tensor1),
|
| 303 |
+
strided_tensor_iter<scalar2>(tensor2),
|
| 304 |
+
strided_tensor_iter<scalar3>(tensor3));
|
| 305 |
+
}
|
| 306 |
+
}
|
| 307 |
+
|
| 308 |
+
template <
|
| 309 |
+
typename scalar1,
|
| 310 |
+
typename scalar2,
|
| 311 |
+
typename scalar3,
|
| 312 |
+
typename scalar4,
|
| 313 |
+
typename Op>
|
| 314 |
+
inline void CPU_tensor_apply4(
|
| 315 |
+
Tensor tensor1,
|
| 316 |
+
Tensor tensor2,
|
| 317 |
+
Tensor tensor3,
|
| 318 |
+
Tensor tensor4,
|
| 319 |
+
const Op op) {
|
| 320 |
+
if (!_apply_preamble({tensor1, tensor2, tensor3, tensor4}))
|
| 321 |
+
return;
|
| 322 |
+
if (_max_dim_tensors({tensor1, tensor2, tensor3, tensor4}) <= 8) {
|
| 323 |
+
apply_op(
|
| 324 |
+
tensor1.numel(),
|
| 325 |
+
0,
|
| 326 |
+
op,
|
| 327 |
+
strided_tensor_iter_fixed<scalar1, 8>(tensor1),
|
| 328 |
+
strided_tensor_iter_fixed<scalar2, 8>(tensor2),
|
| 329 |
+
strided_tensor_iter_fixed<scalar3, 8>(tensor3),
|
| 330 |
+
strided_tensor_iter_fixed<scalar4, 8>(tensor4));
|
| 331 |
+
} else {
|
| 332 |
+
apply_op(
|
| 333 |
+
tensor1.numel(),
|
| 334 |
+
0,
|
| 335 |
+
op,
|
| 336 |
+
strided_tensor_iter<scalar1>(tensor1),
|
| 337 |
+
strided_tensor_iter<scalar2>(tensor2),
|
| 338 |
+
strided_tensor_iter<scalar3>(tensor3),
|
| 339 |
+
strided_tensor_iter<scalar4>(tensor4));
|
| 340 |
+
}
|
| 341 |
+
}
|
| 342 |
+
|
| 343 |
+
} // namespace at
|
.venv/Lib/site-packages/torch/include/ATen/CPUFixedAllocator.h
ADDED
|
@@ -0,0 +1,33 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#include <c10/core/Allocator.h>
|
| 4 |
+
#include <c10/util/Exception.h>
|
| 5 |
+
|
| 6 |
+
// This file creates a fake allocator that just throws exceptions if
|
| 7 |
+
// it is actually used.
|
| 8 |
+
|
| 9 |
+
// state passed to the allocator is the std::function<void(void*)> called
|
| 10 |
+
// when the blob is release by ATen
|
| 11 |
+
|
| 12 |
+
namespace at {
|
| 13 |
+
|
| 14 |
+
static cpu_fixed_malloc(void*, ptrdiff_t) {
|
| 15 |
+
AT_ERROR("attempting to resize a tensor view of an external blob");
|
| 16 |
+
}
|
| 17 |
+
|
| 18 |
+
static cpu_fixed_realloc(void*, void*, ptrdiff_t) {
|
| 19 |
+
AT_ERROR("attempting to resize a tensor view of an external blob");
|
| 20 |
+
}
|
| 21 |
+
|
| 22 |
+
static cpu_fixed_free(void* state, void* allocation) {
|
| 23 |
+
auto on_release = static_cast<std::function<void(void*)>*>(state);
|
| 24 |
+
(*on_release)(allocation);
|
| 25 |
+
delete on_release;
|
| 26 |
+
}
|
| 27 |
+
|
| 28 |
+
static Allocator CPU_fixed_allocator = {
|
| 29 |
+
cpu_fixed_malloc,
|
| 30 |
+
cpu_fixed_realloc,
|
| 31 |
+
cpu_fixed_free};
|
| 32 |
+
|
| 33 |
+
} // namespace at
|
.venv/Lib/site-packages/torch/include/ATen/CPUFunctions.h
ADDED
|
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#include <ATen/core/TensorBody.h>
|
| 2 |
+
|
| 3 |
+
// TODO Undo all logic introduced for Note [Avoiding Include Cycles In Static Dispatch]
|
| 4 |
+
// Code introduced to avoid cyclic dependency in static dispatch is no longer
|
| 5 |
+
// needed as static dispatch logic is moved from TensorBody.h, which caused cycles in the first place,
|
| 6 |
+
// to Operators.cpp for supporting multiple backends with multiple kernels.
|
| 7 |
+
//
|
| 8 |
+
// Note [Avoiding Include Cycles In Static Dispatch]
|
| 9 |
+
// In order to avoid #include cycles in the static dispatch build, we've carefully split out
|
| 10 |
+
// the static function definition files into {DispatchKey}Functions.h and {DispatchKey}Functions_inl.h.
|
| 11 |
+
//
|
| 12 |
+
// Without this split, the include cycle looks like TensorBody.h -> CPUFunctions.h -> TensorBody.h.
|
| 13 |
+
// - TensorBody.h #includes CPUFunctions.h in the static dispatch build, because the tensor methods
|
| 14 |
+
// all need to call into the fastpath C++ API defined in CPUFunctions.h. The methods are also all
|
| 15 |
+
// directly inlined into TensorBody.h.
|
| 16 |
+
// - CPUFunctions.h #includes TensorBody.h because it contains function declarations for the entire C++ API,
|
| 17 |
+
// which include functions that have defaultable std::optional<Tensor> arguments.
|
| 18 |
+
// That requires knowing the full Tensor class definition.
|
| 19 |
+
//
|
| 20 |
+
// We break the cycle by doing the following:
|
| 21 |
+
// - Split out CPUFunction.h into two files: CPUFunctions.h and CPUFunctions_inl.h
|
| 22 |
+
// - CPUFunction.h is a dummy file that just includes the Tensor class and includes CPUFunctions_inl.,
|
| 23 |
+
// - CPUFunctions_inl.h includes everything else
|
| 24 |
+
// - (only in the static dispatch build) TensorBody.h makes sure to finish defining the Tensor class,
|
| 25 |
+
// and then it includes CPUFunctions_inl.h.
|
| 26 |
+
// - All other files that want the cpu fastpath functions can include CPUFunctions.h directly.
|
| 27 |
+
// - This also means that static dispatch build, CPUFunctions.h only needs to
|
| 28 |
+
// #include TensorBody.h, and it will automatically bring in CPUFunctions_inl.h.
|
| 29 |
+
#include <ATen/CPUFunctions_inl.h>
|
.venv/Lib/site-packages/torch/include/ATen/CPUFunctions_inl.h
ADDED
|
@@ -0,0 +1,540 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
// @generated by torchgen/gen.py from DispatchKeyFunctions_inl.h
|
| 3 |
+
|
| 4 |
+
// NB: The implementing C++ file is RegisterDispatchKey.cpp
|
| 5 |
+
|
| 6 |
+
// The only #includes we need are for custom classes that have defaults in the C++ API
|
| 7 |
+
#include <c10/core/MemoryFormat.h>
|
| 8 |
+
#include <c10/core/Scalar.h>
|
| 9 |
+
#include <ATen/core/Reduction.h>
|
| 10 |
+
|
| 11 |
+
#if defined(AT_PER_OPERATOR_HEADERS) && defined(TORCH_ASSERT_ONLY_METHOD_OPERATORS)
|
| 12 |
+
#error This change adds a dependency on all pytorch operators, meaning the \
|
| 13 |
+
file will need to be re-compiled every time an operator is changed or added. \
|
| 14 |
+
Consider including a specific operator from \
|
| 15 |
+
<ATen/ops/{my_operator}_cpu_dispatch.h>. \
|
| 16 |
+
See NOTE [TORCH_ASSERT_ONLY_METHOD_OPERATORS].
|
| 17 |
+
#endif
|
| 18 |
+
|
| 19 |
+
#include <ATen/ops/_adaptive_avg_pool2d_cpu_dispatch.h>
|
| 20 |
+
#include <ATen/ops/_adaptive_avg_pool2d_backward_cpu_dispatch.h>
|
| 21 |
+
#include <ATen/ops/_adaptive_avg_pool3d_cpu_dispatch.h>
|
| 22 |
+
#include <ATen/ops/_adaptive_avg_pool3d_backward_cpu_dispatch.h>
|
| 23 |
+
#include <ATen/ops/_add_relu_cpu_dispatch.h>
|
| 24 |
+
#include <ATen/ops/_addmm_activation_cpu_dispatch.h>
|
| 25 |
+
#include <ATen/ops/_aminmax_cpu_dispatch.h>
|
| 26 |
+
#include <ATen/ops/_amp_foreach_non_finite_check_and_unscale_cpu_dispatch.h>
|
| 27 |
+
#include <ATen/ops/_amp_update_scale_cpu_dispatch.h>
|
| 28 |
+
#include <ATen/ops/_assert_async_cpu_dispatch.h>
|
| 29 |
+
#include <ATen/ops/_batch_norm_with_update_cpu_dispatch.h>
|
| 30 |
+
#include <ATen/ops/_cdist_backward_cpu_dispatch.h>
|
| 31 |
+
#include <ATen/ops/_cdist_forward_cpu_dispatch.h>
|
| 32 |
+
#include <ATen/ops/_cholesky_solve_helper_cpu_dispatch.h>
|
| 33 |
+
#include <ATen/ops/_compute_linear_combination_cpu_dispatch.h>
|
| 34 |
+
#include <ATen/ops/_convert_indices_from_coo_to_csr_cpu_dispatch.h>
|
| 35 |
+
#include <ATen/ops/_convert_indices_from_csr_to_coo_cpu_dispatch.h>
|
| 36 |
+
#include <ATen/ops/_convert_weight_to_int4pack_cpu_dispatch.h>
|
| 37 |
+
#include <ATen/ops/_ctc_loss_cpu_dispatch.h>
|
| 38 |
+
#include <ATen/ops/_ctc_loss_backward_cpu_dispatch.h>
|
| 39 |
+
#include <ATen/ops/_cummax_helper_cpu_dispatch.h>
|
| 40 |
+
#include <ATen/ops/_cummin_helper_cpu_dispatch.h>
|
| 41 |
+
#include <ATen/ops/_dirichlet_grad_cpu_dispatch.h>
|
| 42 |
+
#include <ATen/ops/_efficientzerotensor_cpu_dispatch.h>
|
| 43 |
+
#include <ATen/ops/_embedding_bag_cpu_dispatch.h>
|
| 44 |
+
#include <ATen/ops/_embedding_bag_backward_cpu_dispatch.h>
|
| 45 |
+
#include <ATen/ops/_embedding_bag_dense_backward_cpu_dispatch.h>
|
| 46 |
+
#include <ATen/ops/_embedding_bag_forward_only_cpu_dispatch.h>
|
| 47 |
+
#include <ATen/ops/_embedding_bag_per_sample_weights_backward_cpu_dispatch.h>
|
| 48 |
+
#include <ATen/ops/_empty_affine_quantized_cpu_dispatch.h>
|
| 49 |
+
#include <ATen/ops/_empty_per_channel_affine_quantized_cpu_dispatch.h>
|
| 50 |
+
#include <ATen/ops/_fake_quantize_learnable_per_channel_affine_cpu_dispatch.h>
|
| 51 |
+
#include <ATen/ops/_fake_quantize_learnable_per_channel_affine_backward_cpu_dispatch.h>
|
| 52 |
+
#include <ATen/ops/_fake_quantize_learnable_per_tensor_affine_cpu_dispatch.h>
|
| 53 |
+
#include <ATen/ops/_fake_quantize_learnable_per_tensor_affine_backward_cpu_dispatch.h>
|
| 54 |
+
#include <ATen/ops/_fake_quantize_per_tensor_affine_cachemask_tensor_qparams_cpu_dispatch.h>
|
| 55 |
+
#include <ATen/ops/_fft_c2c_cpu_dispatch.h>
|
| 56 |
+
#include <ATen/ops/_fft_c2r_cpu_dispatch.h>
|
| 57 |
+
#include <ATen/ops/_fft_r2c_cpu_dispatch.h>
|
| 58 |
+
#include <ATen/ops/_foobar_cpu_dispatch.h>
|
| 59 |
+
#include <ATen/ops/_functional_assert_async_cpu_dispatch.h>
|
| 60 |
+
#include <ATen/ops/_fused_adagrad_cpu_dispatch.h>
|
| 61 |
+
#include <ATen/ops/_fused_adam_cpu_dispatch.h>
|
| 62 |
+
#include <ATen/ops/_fused_adamw_cpu_dispatch.h>
|
| 63 |
+
#include <ATen/ops/_fused_moving_avg_obs_fq_helper_cpu_dispatch.h>
|
| 64 |
+
#include <ATen/ops/_fused_sdp_choice_cpu_dispatch.h>
|
| 65 |
+
#include <ATen/ops/_fused_sgd_cpu_dispatch.h>
|
| 66 |
+
#include <ATen/ops/_histogramdd_bin_edges_cpu_dispatch.h>
|
| 67 |
+
#include <ATen/ops/_histogramdd_from_bin_cts_cpu_dispatch.h>
|
| 68 |
+
#include <ATen/ops/_histogramdd_from_bin_tensors_cpu_dispatch.h>
|
| 69 |
+
#include <ATen/ops/_index_put_impl_cpu_dispatch.h>
|
| 70 |
+
#include <ATen/ops/_int_mm_cpu_dispatch.h>
|
| 71 |
+
#include <ATen/ops/_jagged_to_padded_dense_forward_cpu_dispatch.h>
|
| 72 |
+
#include <ATen/ops/_linalg_det_cpu_dispatch.h>
|
| 73 |
+
#include <ATen/ops/_linalg_eigh_cpu_dispatch.h>
|
| 74 |
+
#include <ATen/ops/_linalg_eigvals_cpu_dispatch.h>
|
| 75 |
+
#include <ATen/ops/_linalg_slogdet_cpu_dispatch.h>
|
| 76 |
+
#include <ATen/ops/_linalg_solve_ex_cpu_dispatch.h>
|
| 77 |
+
#include <ATen/ops/_linalg_svd_cpu_dispatch.h>
|
| 78 |
+
#include <ATen/ops/_local_scalar_dense_cpu_dispatch.h>
|
| 79 |
+
#include <ATen/ops/_log_softmax_cpu_dispatch.h>
|
| 80 |
+
#include <ATen/ops/_log_softmax_backward_data_cpu_dispatch.h>
|
| 81 |
+
#include <ATen/ops/_logcumsumexp_cpu_dispatch.h>
|
| 82 |
+
#include <ATen/ops/_make_dep_token_cpu_dispatch.h>
|
| 83 |
+
#include <ATen/ops/_make_per_channel_quantized_tensor_cpu_dispatch.h>
|
| 84 |
+
#include <ATen/ops/_make_per_tensor_quantized_tensor_cpu_dispatch.h>
|
| 85 |
+
#include <ATen/ops/_masked_softmax_cpu_dispatch.h>
|
| 86 |
+
#include <ATen/ops/_masked_softmax_backward_cpu_dispatch.h>
|
| 87 |
+
#include <ATen/ops/_native_batch_norm_legit_cpu_dispatch.h>
|
| 88 |
+
#include <ATen/ops/_native_multi_head_attention_cpu_dispatch.h>
|
| 89 |
+
#include <ATen/ops/_nested_compute_contiguous_strides_offsets_cpu_dispatch.h>
|
| 90 |
+
#include <ATen/ops/_nested_from_padded_cpu_dispatch.h>
|
| 91 |
+
#include <ATen/ops/_nested_tensor_from_mask_cpu_dispatch.h>
|
| 92 |
+
#include <ATen/ops/_nested_tensor_from_mask_left_aligned_cpu_dispatch.h>
|
| 93 |
+
#include <ATen/ops/_nested_view_from_buffer_cpu_dispatch.h>
|
| 94 |
+
#include <ATen/ops/_padded_dense_to_jagged_forward_cpu_dispatch.h>
|
| 95 |
+
#include <ATen/ops/_pdist_backward_cpu_dispatch.h>
|
| 96 |
+
#include <ATen/ops/_pdist_forward_cpu_dispatch.h>
|
| 97 |
+
#include <ATen/ops/_prelu_kernel_cpu_dispatch.h>
|
| 98 |
+
#include <ATen/ops/_prelu_kernel_backward_cpu_dispatch.h>
|
| 99 |
+
#include <ATen/ops/_reshape_alias_cpu_dispatch.h>
|
| 100 |
+
#include <ATen/ops/_sample_dirichlet_cpu_dispatch.h>
|
| 101 |
+
#include <ATen/ops/_scaled_dot_product_flash_attention_for_cpu_cpu_dispatch.h>
|
| 102 |
+
#include <ATen/ops/_scaled_dot_product_flash_attention_for_cpu_backward_cpu_dispatch.h>
|
| 103 |
+
#include <ATen/ops/_segment_reduce_backward_cpu_dispatch.h>
|
| 104 |
+
#include <ATen/ops/_slow_conv2d_backward_cpu_dispatch.h>
|
| 105 |
+
#include <ATen/ops/_slow_conv2d_forward_cpu_dispatch.h>
|
| 106 |
+
#include <ATen/ops/_softmax_cpu_dispatch.h>
|
| 107 |
+
#include <ATen/ops/_softmax_backward_data_cpu_dispatch.h>
|
| 108 |
+
#include <ATen/ops/_spdiags_cpu_dispatch.h>
|
| 109 |
+
#include <ATen/ops/_stack_cpu_dispatch.h>
|
| 110 |
+
#include <ATen/ops/_standard_gamma_cpu_dispatch.h>
|
| 111 |
+
#include <ATen/ops/_standard_gamma_grad_cpu_dispatch.h>
|
| 112 |
+
#include <ATen/ops/_test_functorch_fallback_cpu_dispatch.h>
|
| 113 |
+
#include <ATen/ops/_test_optional_filled_intlist_cpu_dispatch.h>
|
| 114 |
+
#include <ATen/ops/_test_optional_floatlist_cpu_dispatch.h>
|
| 115 |
+
#include <ATen/ops/_test_optional_intlist_cpu_dispatch.h>
|
| 116 |
+
#include <ATen/ops/_to_sparse_cpu_dispatch.h>
|
| 117 |
+
#include <ATen/ops/_to_sparse_bsc_cpu_dispatch.h>
|
| 118 |
+
#include <ATen/ops/_to_sparse_bsr_cpu_dispatch.h>
|
| 119 |
+
#include <ATen/ops/_to_sparse_csc_cpu_dispatch.h>
|
| 120 |
+
#include <ATen/ops/_to_sparse_csr_cpu_dispatch.h>
|
| 121 |
+
#include <ATen/ops/_transform_bias_rescale_qkv_cpu_dispatch.h>
|
| 122 |
+
#include <ATen/ops/_transformer_encoder_layer_fwd_cpu_dispatch.h>
|
| 123 |
+
#include <ATen/ops/_unique_cpu_dispatch.h>
|
| 124 |
+
#include <ATen/ops/_unique2_cpu_dispatch.h>
|
| 125 |
+
#include <ATen/ops/_upsample_bicubic2d_aa_cpu_dispatch.h>
|
| 126 |
+
#include <ATen/ops/_upsample_bicubic2d_aa_backward_cpu_dispatch.h>
|
| 127 |
+
#include <ATen/ops/_upsample_bilinear2d_aa_cpu_dispatch.h>
|
| 128 |
+
#include <ATen/ops/_upsample_bilinear2d_aa_backward_cpu_dispatch.h>
|
| 129 |
+
#include <ATen/ops/_upsample_nearest_exact1d_cpu_dispatch.h>
|
| 130 |
+
#include <ATen/ops/_upsample_nearest_exact1d_backward_cpu_dispatch.h>
|
| 131 |
+
#include <ATen/ops/_upsample_nearest_exact2d_cpu_dispatch.h>
|
| 132 |
+
#include <ATen/ops/_upsample_nearest_exact2d_backward_cpu_dispatch.h>
|
| 133 |
+
#include <ATen/ops/_upsample_nearest_exact3d_cpu_dispatch.h>
|
| 134 |
+
#include <ATen/ops/_upsample_nearest_exact3d_backward_cpu_dispatch.h>
|
| 135 |
+
#include <ATen/ops/_validate_compressed_sparse_indices_cpu_dispatch.h>
|
| 136 |
+
#include <ATen/ops/_weight_int4pack_mm_cpu_dispatch.h>
|
| 137 |
+
#include <ATen/ops/_weight_int8pack_mm_cpu_dispatch.h>
|
| 138 |
+
#include <ATen/ops/_weight_norm_interface_cpu_dispatch.h>
|
| 139 |
+
#include <ATen/ops/_weight_norm_interface_backward_cpu_dispatch.h>
|
| 140 |
+
#include <ATen/ops/abs_cpu_dispatch.h>
|
| 141 |
+
#include <ATen/ops/acos_cpu_dispatch.h>
|
| 142 |
+
#include <ATen/ops/acosh_cpu_dispatch.h>
|
| 143 |
+
#include <ATen/ops/adaptive_avg_pool2d_cpu_dispatch.h>
|
| 144 |
+
#include <ATen/ops/adaptive_avg_pool3d_cpu_dispatch.h>
|
| 145 |
+
#include <ATen/ops/adaptive_avg_pool3d_backward_cpu_dispatch.h>
|
| 146 |
+
#include <ATen/ops/adaptive_max_pool2d_cpu_dispatch.h>
|
| 147 |
+
#include <ATen/ops/adaptive_max_pool2d_backward_cpu_dispatch.h>
|
| 148 |
+
#include <ATen/ops/adaptive_max_pool3d_cpu_dispatch.h>
|
| 149 |
+
#include <ATen/ops/adaptive_max_pool3d_backward_cpu_dispatch.h>
|
| 150 |
+
#include <ATen/ops/add_cpu_dispatch.h>
|
| 151 |
+
#include <ATen/ops/addbmm_cpu_dispatch.h>
|
| 152 |
+
#include <ATen/ops/addcdiv_cpu_dispatch.h>
|
| 153 |
+
#include <ATen/ops/addcmul_cpu_dispatch.h>
|
| 154 |
+
#include <ATen/ops/addmm_cpu_dispatch.h>
|
| 155 |
+
#include <ATen/ops/addmv_cpu_dispatch.h>
|
| 156 |
+
#include <ATen/ops/addr_cpu_dispatch.h>
|
| 157 |
+
#include <ATen/ops/all_cpu_dispatch.h>
|
| 158 |
+
#include <ATen/ops/amax_cpu_dispatch.h>
|
| 159 |
+
#include <ATen/ops/amin_cpu_dispatch.h>
|
| 160 |
+
#include <ATen/ops/aminmax_cpu_dispatch.h>
|
| 161 |
+
#include <ATen/ops/angle_cpu_dispatch.h>
|
| 162 |
+
#include <ATen/ops/any_cpu_dispatch.h>
|
| 163 |
+
#include <ATen/ops/arange_cpu_dispatch.h>
|
| 164 |
+
#include <ATen/ops/argmax_cpu_dispatch.h>
|
| 165 |
+
#include <ATen/ops/argmin_cpu_dispatch.h>
|
| 166 |
+
#include <ATen/ops/as_strided_cpu_dispatch.h>
|
| 167 |
+
#include <ATen/ops/asin_cpu_dispatch.h>
|
| 168 |
+
#include <ATen/ops/asinh_cpu_dispatch.h>
|
| 169 |
+
#include <ATen/ops/atan_cpu_dispatch.h>
|
| 170 |
+
#include <ATen/ops/atan2_cpu_dispatch.h>
|
| 171 |
+
#include <ATen/ops/atanh_cpu_dispatch.h>
|
| 172 |
+
#include <ATen/ops/avg_pool2d_cpu_dispatch.h>
|
| 173 |
+
#include <ATen/ops/avg_pool2d_backward_cpu_dispatch.h>
|
| 174 |
+
#include <ATen/ops/avg_pool3d_cpu_dispatch.h>
|
| 175 |
+
#include <ATen/ops/avg_pool3d_backward_cpu_dispatch.h>
|
| 176 |
+
#include <ATen/ops/baddbmm_cpu_dispatch.h>
|
| 177 |
+
#include <ATen/ops/batch_norm_backward_cpu_dispatch.h>
|
| 178 |
+
#include <ATen/ops/batch_norm_update_stats_cpu_dispatch.h>
|
| 179 |
+
#include <ATen/ops/bernoulli_cpu_dispatch.h>
|
| 180 |
+
#include <ATen/ops/binary_cross_entropy_cpu_dispatch.h>
|
| 181 |
+
#include <ATen/ops/binary_cross_entropy_backward_cpu_dispatch.h>
|
| 182 |
+
#include <ATen/ops/bincount_cpu_dispatch.h>
|
| 183 |
+
#include <ATen/ops/binomial_cpu_dispatch.h>
|
| 184 |
+
#include <ATen/ops/bitwise_and_cpu_dispatch.h>
|
| 185 |
+
#include <ATen/ops/bitwise_left_shift_cpu_dispatch.h>
|
| 186 |
+
#include <ATen/ops/bitwise_not_cpu_dispatch.h>
|
| 187 |
+
#include <ATen/ops/bitwise_or_cpu_dispatch.h>
|
| 188 |
+
#include <ATen/ops/bitwise_right_shift_cpu_dispatch.h>
|
| 189 |
+
#include <ATen/ops/bitwise_xor_cpu_dispatch.h>
|
| 190 |
+
#include <ATen/ops/bmm_cpu_dispatch.h>
|
| 191 |
+
#include <ATen/ops/bucketize_cpu_dispatch.h>
|
| 192 |
+
#include <ATen/ops/cat_cpu_dispatch.h>
|
| 193 |
+
#include <ATen/ops/cauchy_cpu_dispatch.h>
|
| 194 |
+
#include <ATen/ops/ceil_cpu_dispatch.h>
|
| 195 |
+
#include <ATen/ops/channel_shuffle_cpu_dispatch.h>
|
| 196 |
+
#include <ATen/ops/cholesky_cpu_dispatch.h>
|
| 197 |
+
#include <ATen/ops/cholesky_inverse_cpu_dispatch.h>
|
| 198 |
+
#include <ATen/ops/clamp_cpu_dispatch.h>
|
| 199 |
+
#include <ATen/ops/clamp_max_cpu_dispatch.h>
|
| 200 |
+
#include <ATen/ops/clamp_min_cpu_dispatch.h>
|
| 201 |
+
#include <ATen/ops/col2im_cpu_dispatch.h>
|
| 202 |
+
#include <ATen/ops/complex_cpu_dispatch.h>
|
| 203 |
+
#include <ATen/ops/conj_physical_cpu_dispatch.h>
|
| 204 |
+
#include <ATen/ops/copysign_cpu_dispatch.h>
|
| 205 |
+
#include <ATen/ops/cos_cpu_dispatch.h>
|
| 206 |
+
#include <ATen/ops/cosh_cpu_dispatch.h>
|
| 207 |
+
#include <ATen/ops/count_nonzero_cpu_dispatch.h>
|
| 208 |
+
#include <ATen/ops/cumprod_cpu_dispatch.h>
|
| 209 |
+
#include <ATen/ops/cumsum_cpu_dispatch.h>
|
| 210 |
+
#include <ATen/ops/dequantize_cpu_dispatch.h>
|
| 211 |
+
#include <ATen/ops/digamma_cpu_dispatch.h>
|
| 212 |
+
#include <ATen/ops/div_cpu_dispatch.h>
|
| 213 |
+
#include <ATen/ops/dot_cpu_dispatch.h>
|
| 214 |
+
#include <ATen/ops/elu_cpu_dispatch.h>
|
| 215 |
+
#include <ATen/ops/elu_backward_cpu_dispatch.h>
|
| 216 |
+
#include <ATen/ops/embedding_dense_backward_cpu_dispatch.h>
|
| 217 |
+
#include <ATen/ops/embedding_renorm_cpu_dispatch.h>
|
| 218 |
+
#include <ATen/ops/empty_cpu_dispatch.h>
|
| 219 |
+
#include <ATen/ops/empty_strided_cpu_dispatch.h>
|
| 220 |
+
#include <ATen/ops/eq_cpu_dispatch.h>
|
| 221 |
+
#include <ATen/ops/equal_cpu_dispatch.h>
|
| 222 |
+
#include <ATen/ops/erf_cpu_dispatch.h>
|
| 223 |
+
#include <ATen/ops/erfc_cpu_dispatch.h>
|
| 224 |
+
#include <ATen/ops/erfinv_cpu_dispatch.h>
|
| 225 |
+
#include <ATen/ops/exp_cpu_dispatch.h>
|
| 226 |
+
#include <ATen/ops/exp2_cpu_dispatch.h>
|
| 227 |
+
#include <ATen/ops/expm1_cpu_dispatch.h>
|
| 228 |
+
#include <ATen/ops/exponential_cpu_dispatch.h>
|
| 229 |
+
#include <ATen/ops/eye_cpu_dispatch.h>
|
| 230 |
+
#include <ATen/ops/fake_quantize_per_channel_affine_cachemask_cpu_dispatch.h>
|
| 231 |
+
#include <ATen/ops/fake_quantize_per_tensor_affine_cachemask_cpu_dispatch.h>
|
| 232 |
+
#include <ATen/ops/fill_cpu_dispatch.h>
|
| 233 |
+
#include <ATen/ops/flip_cpu_dispatch.h>
|
| 234 |
+
#include <ATen/ops/floor_cpu_dispatch.h>
|
| 235 |
+
#include <ATen/ops/floor_divide_cpu_dispatch.h>
|
| 236 |
+
#include <ATen/ops/fmax_cpu_dispatch.h>
|
| 237 |
+
#include <ATen/ops/fmin_cpu_dispatch.h>
|
| 238 |
+
#include <ATen/ops/fmod_cpu_dispatch.h>
|
| 239 |
+
#include <ATen/ops/frac_cpu_dispatch.h>
|
| 240 |
+
#include <ATen/ops/fractional_max_pool2d_cpu_dispatch.h>
|
| 241 |
+
#include <ATen/ops/fractional_max_pool2d_backward_cpu_dispatch.h>
|
| 242 |
+
#include <ATen/ops/fractional_max_pool3d_cpu_dispatch.h>
|
| 243 |
+
#include <ATen/ops/fractional_max_pool3d_backward_cpu_dispatch.h>
|
| 244 |
+
#include <ATen/ops/frexp_cpu_dispatch.h>
|
| 245 |
+
#include <ATen/ops/from_file_cpu_dispatch.h>
|
| 246 |
+
#include <ATen/ops/gather_cpu_dispatch.h>
|
| 247 |
+
#include <ATen/ops/gcd_cpu_dispatch.h>
|
| 248 |
+
#include <ATen/ops/ge_cpu_dispatch.h>
|
| 249 |
+
#include <ATen/ops/gelu_cpu_dispatch.h>
|
| 250 |
+
#include <ATen/ops/gelu_backward_cpu_dispatch.h>
|
| 251 |
+
#include <ATen/ops/geometric_cpu_dispatch.h>
|
| 252 |
+
#include <ATen/ops/geqrf_cpu_dispatch.h>
|
| 253 |
+
#include <ATen/ops/glu_cpu_dispatch.h>
|
| 254 |
+
#include <ATen/ops/glu_backward_cpu_dispatch.h>
|
| 255 |
+
#include <ATen/ops/glu_backward_jvp_cpu_dispatch.h>
|
| 256 |
+
#include <ATen/ops/glu_jvp_cpu_dispatch.h>
|
| 257 |
+
#include <ATen/ops/grid_sampler_2d_cpu_dispatch.h>
|
| 258 |
+
#include <ATen/ops/grid_sampler_2d_backward_cpu_dispatch.h>
|
| 259 |
+
#include <ATen/ops/grid_sampler_3d_cpu_dispatch.h>
|
| 260 |
+
#include <ATen/ops/grid_sampler_3d_backward_cpu_dispatch.h>
|
| 261 |
+
#include <ATen/ops/gt_cpu_dispatch.h>
|
| 262 |
+
#include <ATen/ops/hardshrink_cpu_dispatch.h>
|
| 263 |
+
#include <ATen/ops/hardshrink_backward_cpu_dispatch.h>
|
| 264 |
+
#include <ATen/ops/hardsigmoid_cpu_dispatch.h>
|
| 265 |
+
#include <ATen/ops/hardsigmoid_backward_cpu_dispatch.h>
|
| 266 |
+
#include <ATen/ops/hardswish_cpu_dispatch.h>
|
| 267 |
+
#include <ATen/ops/hardswish_backward_cpu_dispatch.h>
|
| 268 |
+
#include <ATen/ops/hardtanh_cpu_dispatch.h>
|
| 269 |
+
#include <ATen/ops/hardtanh_backward_cpu_dispatch.h>
|
| 270 |
+
#include <ATen/ops/heaviside_cpu_dispatch.h>
|
| 271 |
+
#include <ATen/ops/histc_cpu_dispatch.h>
|
| 272 |
+
#include <ATen/ops/histogram_cpu_dispatch.h>
|
| 273 |
+
#include <ATen/ops/huber_loss_cpu_dispatch.h>
|
| 274 |
+
#include <ATen/ops/huber_loss_backward_cpu_dispatch.h>
|
| 275 |
+
#include <ATen/ops/hypot_cpu_dispatch.h>
|
| 276 |
+
#include <ATen/ops/i0_cpu_dispatch.h>
|
| 277 |
+
#include <ATen/ops/igamma_cpu_dispatch.h>
|
| 278 |
+
#include <ATen/ops/igammac_cpu_dispatch.h>
|
| 279 |
+
#include <ATen/ops/im2col_cpu_dispatch.h>
|
| 280 |
+
#include <ATen/ops/index_cpu_dispatch.h>
|
| 281 |
+
#include <ATen/ops/index_add_cpu_dispatch.h>
|
| 282 |
+
#include <ATen/ops/index_copy_cpu_dispatch.h>
|
| 283 |
+
#include <ATen/ops/index_fill_cpu_dispatch.h>
|
| 284 |
+
#include <ATen/ops/index_reduce_cpu_dispatch.h>
|
| 285 |
+
#include <ATen/ops/index_select_cpu_dispatch.h>
|
| 286 |
+
#include <ATen/ops/is_set_to_cpu_dispatch.h>
|
| 287 |
+
#include <ATen/ops/isin_cpu_dispatch.h>
|
| 288 |
+
#include <ATen/ops/isnan_cpu_dispatch.h>
|
| 289 |
+
#include <ATen/ops/isneginf_cpu_dispatch.h>
|
| 290 |
+
#include <ATen/ops/isposinf_cpu_dispatch.h>
|
| 291 |
+
#include <ATen/ops/kthvalue_cpu_dispatch.h>
|
| 292 |
+
#include <ATen/ops/lcm_cpu_dispatch.h>
|
| 293 |
+
#include <ATen/ops/le_cpu_dispatch.h>
|
| 294 |
+
#include <ATen/ops/leaky_relu_cpu_dispatch.h>
|
| 295 |
+
#include <ATen/ops/leaky_relu_backward_cpu_dispatch.h>
|
| 296 |
+
#include <ATen/ops/lerp_cpu_dispatch.h>
|
| 297 |
+
#include <ATen/ops/lgamma_cpu_dispatch.h>
|
| 298 |
+
#include <ATen/ops/linalg_cholesky_ex_cpu_dispatch.h>
|
| 299 |
+
#include <ATen/ops/linalg_cross_cpu_dispatch.h>
|
| 300 |
+
#include <ATen/ops/linalg_eig_cpu_dispatch.h>
|
| 301 |
+
#include <ATen/ops/linalg_eigvals_cpu_dispatch.h>
|
| 302 |
+
#include <ATen/ops/linalg_householder_product_cpu_dispatch.h>
|
| 303 |
+
#include <ATen/ops/linalg_inv_ex_cpu_dispatch.h>
|
| 304 |
+
#include <ATen/ops/linalg_ldl_factor_ex_cpu_dispatch.h>
|
| 305 |
+
#include <ATen/ops/linalg_ldl_solve_cpu_dispatch.h>
|
| 306 |
+
#include <ATen/ops/linalg_lstsq_cpu_dispatch.h>
|
| 307 |
+
#include <ATen/ops/linalg_lu_cpu_dispatch.h>
|
| 308 |
+
#include <ATen/ops/linalg_lu_factor_ex_cpu_dispatch.h>
|
| 309 |
+
#include <ATen/ops/linalg_lu_solve_cpu_dispatch.h>
|
| 310 |
+
#include <ATen/ops/linalg_matrix_exp_cpu_dispatch.h>
|
| 311 |
+
#include <ATen/ops/linalg_qr_cpu_dispatch.h>
|
| 312 |
+
#include <ATen/ops/linalg_solve_triangular_cpu_dispatch.h>
|
| 313 |
+
#include <ATen/ops/linalg_vector_norm_cpu_dispatch.h>
|
| 314 |
+
#include <ATen/ops/linspace_cpu_dispatch.h>
|
| 315 |
+
#include <ATen/ops/log_cpu_dispatch.h>
|
| 316 |
+
#include <ATen/ops/log10_cpu_dispatch.h>
|
| 317 |
+
#include <ATen/ops/log1p_cpu_dispatch.h>
|
| 318 |
+
#include <ATen/ops/log2_cpu_dispatch.h>
|
| 319 |
+
#include <ATen/ops/log_normal_cpu_dispatch.h>
|
| 320 |
+
#include <ATen/ops/log_sigmoid_backward_cpu_dispatch.h>
|
| 321 |
+
#include <ATen/ops/log_sigmoid_forward_cpu_dispatch.h>
|
| 322 |
+
#include <ATen/ops/logaddexp_cpu_dispatch.h>
|
| 323 |
+
#include <ATen/ops/logaddexp2_cpu_dispatch.h>
|
| 324 |
+
#include <ATen/ops/logical_and_cpu_dispatch.h>
|
| 325 |
+
#include <ATen/ops/logical_not_cpu_dispatch.h>
|
| 326 |
+
#include <ATen/ops/logical_or_cpu_dispatch.h>
|
| 327 |
+
#include <ATen/ops/logical_xor_cpu_dispatch.h>
|
| 328 |
+
#include <ATen/ops/logit_cpu_dispatch.h>
|
| 329 |
+
#include <ATen/ops/logit_backward_cpu_dispatch.h>
|
| 330 |
+
#include <ATen/ops/logspace_cpu_dispatch.h>
|
| 331 |
+
#include <ATen/ops/lshift_cpu_dispatch.h>
|
| 332 |
+
#include <ATen/ops/lt_cpu_dispatch.h>
|
| 333 |
+
#include <ATen/ops/lu_unpack_cpu_dispatch.h>
|
| 334 |
+
#include <ATen/ops/masked_fill_cpu_dispatch.h>
|
| 335 |
+
#include <ATen/ops/masked_scatter_cpu_dispatch.h>
|
| 336 |
+
#include <ATen/ops/masked_select_cpu_dispatch.h>
|
| 337 |
+
#include <ATen/ops/max_cpu_dispatch.h>
|
| 338 |
+
#include <ATen/ops/max_pool2d_with_indices_cpu_dispatch.h>
|
| 339 |
+
#include <ATen/ops/max_pool2d_with_indices_backward_cpu_dispatch.h>
|
| 340 |
+
#include <ATen/ops/max_pool3d_with_indices_cpu_dispatch.h>
|
| 341 |
+
#include <ATen/ops/max_pool3d_with_indices_backward_cpu_dispatch.h>
|
| 342 |
+
#include <ATen/ops/max_unpool2d_cpu_dispatch.h>
|
| 343 |
+
#include <ATen/ops/max_unpool3d_cpu_dispatch.h>
|
| 344 |
+
#include <ATen/ops/maximum_cpu_dispatch.h>
|
| 345 |
+
#include <ATen/ops/mean_cpu_dispatch.h>
|
| 346 |
+
#include <ATen/ops/median_cpu_dispatch.h>
|
| 347 |
+
#include <ATen/ops/min_cpu_dispatch.h>
|
| 348 |
+
#include <ATen/ops/minimum_cpu_dispatch.h>
|
| 349 |
+
#include <ATen/ops/mish_cpu_dispatch.h>
|
| 350 |
+
#include <ATen/ops/mish_backward_cpu_dispatch.h>
|
| 351 |
+
#include <ATen/ops/mkldnn_rnn_layer_cpu_dispatch.h>
|
| 352 |
+
#include <ATen/ops/mkldnn_rnn_layer_backward_cpu_dispatch.h>
|
| 353 |
+
#include <ATen/ops/mm_cpu_dispatch.h>
|
| 354 |
+
#include <ATen/ops/mode_cpu_dispatch.h>
|
| 355 |
+
#include <ATen/ops/mse_loss_cpu_dispatch.h>
|
| 356 |
+
#include <ATen/ops/mse_loss_backward_cpu_dispatch.h>
|
| 357 |
+
#include <ATen/ops/mul_cpu_dispatch.h>
|
| 358 |
+
#include <ATen/ops/multi_margin_loss_cpu_dispatch.h>
|
| 359 |
+
#include <ATen/ops/multi_margin_loss_backward_cpu_dispatch.h>
|
| 360 |
+
#include <ATen/ops/multilabel_margin_loss_backward_cpu_dispatch.h>
|
| 361 |
+
#include <ATen/ops/multilabel_margin_loss_forward_cpu_dispatch.h>
|
| 362 |
+
#include <ATen/ops/multinomial_cpu_dispatch.h>
|
| 363 |
+
#include <ATen/ops/mvlgamma_cpu_dispatch.h>
|
| 364 |
+
#include <ATen/ops/nan_to_num_cpu_dispatch.h>
|
| 365 |
+
#include <ATen/ops/nanmedian_cpu_dispatch.h>
|
| 366 |
+
#include <ATen/ops/nansum_cpu_dispatch.h>
|
| 367 |
+
#include <ATen/ops/narrow_copy_cpu_dispatch.h>
|
| 368 |
+
#include <ATen/ops/native_batch_norm_cpu_dispatch.h>
|
| 369 |
+
#include <ATen/ops/native_batch_norm_backward_cpu_dispatch.h>
|
| 370 |
+
#include <ATen/ops/native_channel_shuffle_cpu_dispatch.h>
|
| 371 |
+
#include <ATen/ops/native_dropout_cpu_dispatch.h>
|
| 372 |
+
#include <ATen/ops/native_dropout_backward_cpu_dispatch.h>
|
| 373 |
+
#include <ATen/ops/native_group_norm_cpu_dispatch.h>
|
| 374 |
+
#include <ATen/ops/native_group_norm_backward_cpu_dispatch.h>
|
| 375 |
+
#include <ATen/ops/native_layer_norm_cpu_dispatch.h>
|
| 376 |
+
#include <ATen/ops/native_layer_norm_backward_cpu_dispatch.h>
|
| 377 |
+
#include <ATen/ops/ne_cpu_dispatch.h>
|
| 378 |
+
#include <ATen/ops/neg_cpu_dispatch.h>
|
| 379 |
+
#include <ATen/ops/nextafter_cpu_dispatch.h>
|
| 380 |
+
#include <ATen/ops/nll_loss2d_backward_cpu_dispatch.h>
|
| 381 |
+
#include <ATen/ops/nll_loss2d_forward_cpu_dispatch.h>
|
| 382 |
+
#include <ATen/ops/nll_loss_backward_cpu_dispatch.h>
|
| 383 |
+
#include <ATen/ops/nll_loss_forward_cpu_dispatch.h>
|
| 384 |
+
#include <ATen/ops/nonzero_cpu_dispatch.h>
|
| 385 |
+
#include <ATen/ops/nonzero_static_cpu_dispatch.h>
|
| 386 |
+
#include <ATen/ops/norm_cpu_dispatch.h>
|
| 387 |
+
#include <ATen/ops/normal_cpu_dispatch.h>
|
| 388 |
+
#include <ATen/ops/ormqr_cpu_dispatch.h>
|
| 389 |
+
#include <ATen/ops/pixel_shuffle_cpu_dispatch.h>
|
| 390 |
+
#include <ATen/ops/pixel_unshuffle_cpu_dispatch.h>
|
| 391 |
+
#include <ATen/ops/poisson_cpu_dispatch.h>
|
| 392 |
+
#include <ATen/ops/polar_cpu_dispatch.h>
|
| 393 |
+
#include <ATen/ops/polygamma_cpu_dispatch.h>
|
| 394 |
+
#include <ATen/ops/pow_cpu_dispatch.h>
|
| 395 |
+
#include <ATen/ops/prod_cpu_dispatch.h>
|
| 396 |
+
#include <ATen/ops/put_cpu_dispatch.h>
|
| 397 |
+
#include <ATen/ops/quantize_per_channel_cpu_dispatch.h>
|
| 398 |
+
#include <ATen/ops/quantize_per_tensor_cpu_dispatch.h>
|
| 399 |
+
#include <ATen/ops/quantize_per_tensor_dynamic_cpu_dispatch.h>
|
| 400 |
+
#include <ATen/ops/random_cpu_dispatch.h>
|
| 401 |
+
#include <ATen/ops/randperm_cpu_dispatch.h>
|
| 402 |
+
#include <ATen/ops/range_cpu_dispatch.h>
|
| 403 |
+
#include <ATen/ops/reciprocal_cpu_dispatch.h>
|
| 404 |
+
#include <ATen/ops/reflection_pad1d_cpu_dispatch.h>
|
| 405 |
+
#include <ATen/ops/reflection_pad1d_backward_cpu_dispatch.h>
|
| 406 |
+
#include <ATen/ops/reflection_pad2d_cpu_dispatch.h>
|
| 407 |
+
#include <ATen/ops/reflection_pad2d_backward_cpu_dispatch.h>
|
| 408 |
+
#include <ATen/ops/reflection_pad3d_cpu_dispatch.h>
|
| 409 |
+
#include <ATen/ops/reflection_pad3d_backward_cpu_dispatch.h>
|
| 410 |
+
#include <ATen/ops/relu_cpu_dispatch.h>
|
| 411 |
+
#include <ATen/ops/remainder_cpu_dispatch.h>
|
| 412 |
+
#include <ATen/ops/renorm_cpu_dispatch.h>
|
| 413 |
+
#include <ATen/ops/repeat_interleave_cpu_dispatch.h>
|
| 414 |
+
#include <ATen/ops/replication_pad1d_cpu_dispatch.h>
|
| 415 |
+
#include <ATen/ops/replication_pad1d_backward_cpu_dispatch.h>
|
| 416 |
+
#include <ATen/ops/replication_pad2d_cpu_dispatch.h>
|
| 417 |
+
#include <ATen/ops/replication_pad2d_backward_cpu_dispatch.h>
|
| 418 |
+
#include <ATen/ops/replication_pad3d_cpu_dispatch.h>
|
| 419 |
+
#include <ATen/ops/replication_pad3d_backward_cpu_dispatch.h>
|
| 420 |
+
#include <ATen/ops/resize_cpu_dispatch.h>
|
| 421 |
+
#include <ATen/ops/roll_cpu_dispatch.h>
|
| 422 |
+
#include <ATen/ops/round_cpu_dispatch.h>
|
| 423 |
+
#include <ATen/ops/rrelu_with_noise_cpu_dispatch.h>
|
| 424 |
+
#include <ATen/ops/rshift_cpu_dispatch.h>
|
| 425 |
+
#include <ATen/ops/rsqrt_cpu_dispatch.h>
|
| 426 |
+
#include <ATen/ops/rsub_cpu_dispatch.h>
|
| 427 |
+
#include <ATen/ops/scatter_cpu_dispatch.h>
|
| 428 |
+
#include <ATen/ops/scatter_add_cpu_dispatch.h>
|
| 429 |
+
#include <ATen/ops/scatter_reduce_cpu_dispatch.h>
|
| 430 |
+
#include <ATen/ops/searchsorted_cpu_dispatch.h>
|
| 431 |
+
#include <ATen/ops/segment_reduce_cpu_dispatch.h>
|
| 432 |
+
#include <ATen/ops/set_cpu_dispatch.h>
|
| 433 |
+
#include <ATen/ops/sgn_cpu_dispatch.h>
|
| 434 |
+
#include <ATen/ops/sigmoid_cpu_dispatch.h>
|
| 435 |
+
#include <ATen/ops/sigmoid_backward_cpu_dispatch.h>
|
| 436 |
+
#include <ATen/ops/sign_cpu_dispatch.h>
|
| 437 |
+
#include <ATen/ops/signbit_cpu_dispatch.h>
|
| 438 |
+
#include <ATen/ops/silu_cpu_dispatch.h>
|
| 439 |
+
#include <ATen/ops/silu_backward_cpu_dispatch.h>
|
| 440 |
+
#include <ATen/ops/sin_cpu_dispatch.h>
|
| 441 |
+
#include <ATen/ops/sinc_cpu_dispatch.h>
|
| 442 |
+
#include <ATen/ops/sinh_cpu_dispatch.h>
|
| 443 |
+
#include <ATen/ops/slow_conv3d_forward_cpu_dispatch.h>
|
| 444 |
+
#include <ATen/ops/slow_conv_dilated2d_cpu_dispatch.h>
|
| 445 |
+
#include <ATen/ops/slow_conv_dilated3d_cpu_dispatch.h>
|
| 446 |
+
#include <ATen/ops/slow_conv_transpose2d_cpu_dispatch.h>
|
| 447 |
+
#include <ATen/ops/slow_conv_transpose3d_cpu_dispatch.h>
|
| 448 |
+
#include <ATen/ops/smooth_l1_loss_cpu_dispatch.h>
|
| 449 |
+
#include <ATen/ops/smooth_l1_loss_backward_cpu_dispatch.h>
|
| 450 |
+
#include <ATen/ops/softplus_cpu_dispatch.h>
|
| 451 |
+
#include <ATen/ops/softplus_backward_cpu_dispatch.h>
|
| 452 |
+
#include <ATen/ops/softshrink_cpu_dispatch.h>
|
| 453 |
+
#include <ATen/ops/softshrink_backward_cpu_dispatch.h>
|
| 454 |
+
#include <ATen/ops/sort_cpu_dispatch.h>
|
| 455 |
+
#include <ATen/ops/special_airy_ai_cpu_dispatch.h>
|
| 456 |
+
#include <ATen/ops/special_bessel_j0_cpu_dispatch.h>
|
| 457 |
+
#include <ATen/ops/special_bessel_j1_cpu_dispatch.h>
|
| 458 |
+
#include <ATen/ops/special_bessel_y0_cpu_dispatch.h>
|
| 459 |
+
#include <ATen/ops/special_bessel_y1_cpu_dispatch.h>
|
| 460 |
+
#include <ATen/ops/special_chebyshev_polynomial_t_cpu_dispatch.h>
|
| 461 |
+
#include <ATen/ops/special_chebyshev_polynomial_u_cpu_dispatch.h>
|
| 462 |
+
#include <ATen/ops/special_chebyshev_polynomial_v_cpu_dispatch.h>
|
| 463 |
+
#include <ATen/ops/special_chebyshev_polynomial_w_cpu_dispatch.h>
|
| 464 |
+
#include <ATen/ops/special_entr_cpu_dispatch.h>
|
| 465 |
+
#include <ATen/ops/special_erfcx_cpu_dispatch.h>
|
| 466 |
+
#include <ATen/ops/special_hermite_polynomial_h_cpu_dispatch.h>
|
| 467 |
+
#include <ATen/ops/special_hermite_polynomial_he_cpu_dispatch.h>
|
| 468 |
+
#include <ATen/ops/special_i0e_cpu_dispatch.h>
|
| 469 |
+
#include <ATen/ops/special_i1_cpu_dispatch.h>
|
| 470 |
+
#include <ATen/ops/special_i1e_cpu_dispatch.h>
|
| 471 |
+
#include <ATen/ops/special_laguerre_polynomial_l_cpu_dispatch.h>
|
| 472 |
+
#include <ATen/ops/special_legendre_polynomial_p_cpu_dispatch.h>
|
| 473 |
+
#include <ATen/ops/special_log_ndtr_cpu_dispatch.h>
|
| 474 |
+
#include <ATen/ops/special_modified_bessel_i0_cpu_dispatch.h>
|
| 475 |
+
#include <ATen/ops/special_modified_bessel_i1_cpu_dispatch.h>
|
| 476 |
+
#include <ATen/ops/special_modified_bessel_k0_cpu_dispatch.h>
|
| 477 |
+
#include <ATen/ops/special_modified_bessel_k1_cpu_dispatch.h>
|
| 478 |
+
#include <ATen/ops/special_ndtri_cpu_dispatch.h>
|
| 479 |
+
#include <ATen/ops/special_scaled_modified_bessel_k0_cpu_dispatch.h>
|
| 480 |
+
#include <ATen/ops/special_scaled_modified_bessel_k1_cpu_dispatch.h>
|
| 481 |
+
#include <ATen/ops/special_shifted_chebyshev_polynomial_t_cpu_dispatch.h>
|
| 482 |
+
#include <ATen/ops/special_shifted_chebyshev_polynomial_u_cpu_dispatch.h>
|
| 483 |
+
#include <ATen/ops/special_shifted_chebyshev_polynomial_v_cpu_dispatch.h>
|
| 484 |
+
#include <ATen/ops/special_shifted_chebyshev_polynomial_w_cpu_dispatch.h>
|
| 485 |
+
#include <ATen/ops/special_spherical_bessel_j0_cpu_dispatch.h>
|
| 486 |
+
#include <ATen/ops/special_xlog1py_cpu_dispatch.h>
|
| 487 |
+
#include <ATen/ops/special_zeta_cpu_dispatch.h>
|
| 488 |
+
#include <ATen/ops/sqrt_cpu_dispatch.h>
|
| 489 |
+
#include <ATen/ops/sspaddmm_cpu_dispatch.h>
|
| 490 |
+
#include <ATen/ops/std_cpu_dispatch.h>
|
| 491 |
+
#include <ATen/ops/std_mean_cpu_dispatch.h>
|
| 492 |
+
#include <ATen/ops/sub_cpu_dispatch.h>
|
| 493 |
+
#include <ATen/ops/sum_cpu_dispatch.h>
|
| 494 |
+
#include <ATen/ops/take_cpu_dispatch.h>
|
| 495 |
+
#include <ATen/ops/tan_cpu_dispatch.h>
|
| 496 |
+
#include <ATen/ops/tanh_cpu_dispatch.h>
|
| 497 |
+
#include <ATen/ops/tanh_backward_cpu_dispatch.h>
|
| 498 |
+
#include <ATen/ops/threshold_cpu_dispatch.h>
|
| 499 |
+
#include <ATen/ops/threshold_backward_cpu_dispatch.h>
|
| 500 |
+
#include <ATen/ops/to_mkldnn_cpu_dispatch.h>
|
| 501 |
+
#include <ATen/ops/topk_cpu_dispatch.h>
|
| 502 |
+
#include <ATen/ops/trace_cpu_dispatch.h>
|
| 503 |
+
#include <ATen/ops/triangular_solve_cpu_dispatch.h>
|
| 504 |
+
#include <ATen/ops/tril_cpu_dispatch.h>
|
| 505 |
+
#include <ATen/ops/tril_indices_cpu_dispatch.h>
|
| 506 |
+
#include <ATen/ops/triu_cpu_dispatch.h>
|
| 507 |
+
#include <ATen/ops/triu_indices_cpu_dispatch.h>
|
| 508 |
+
#include <ATen/ops/trunc_cpu_dispatch.h>
|
| 509 |
+
#include <ATen/ops/unfold_cpu_dispatch.h>
|
| 510 |
+
#include <ATen/ops/unfold_backward_cpu_dispatch.h>
|
| 511 |
+
#include <ATen/ops/uniform_cpu_dispatch.h>
|
| 512 |
+
#include <ATen/ops/unique_consecutive_cpu_dispatch.h>
|
| 513 |
+
#include <ATen/ops/unique_dim_cpu_dispatch.h>
|
| 514 |
+
#include <ATen/ops/unique_dim_consecutive_cpu_dispatch.h>
|
| 515 |
+
#include <ATen/ops/upsample_bicubic2d_cpu_dispatch.h>
|
| 516 |
+
#include <ATen/ops/upsample_bicubic2d_backward_cpu_dispatch.h>
|
| 517 |
+
#include <ATen/ops/upsample_bilinear2d_cpu_dispatch.h>
|
| 518 |
+
#include <ATen/ops/upsample_bilinear2d_backward_cpu_dispatch.h>
|
| 519 |
+
#include <ATen/ops/upsample_linear1d_cpu_dispatch.h>
|
| 520 |
+
#include <ATen/ops/upsample_linear1d_backward_cpu_dispatch.h>
|
| 521 |
+
#include <ATen/ops/upsample_nearest1d_cpu_dispatch.h>
|
| 522 |
+
#include <ATen/ops/upsample_nearest1d_backward_cpu_dispatch.h>
|
| 523 |
+
#include <ATen/ops/upsample_nearest2d_cpu_dispatch.h>
|
| 524 |
+
#include <ATen/ops/upsample_nearest2d_backward_cpu_dispatch.h>
|
| 525 |
+
#include <ATen/ops/upsample_nearest3d_cpu_dispatch.h>
|
| 526 |
+
#include <ATen/ops/upsample_nearest3d_backward_cpu_dispatch.h>
|
| 527 |
+
#include <ATen/ops/upsample_trilinear3d_cpu_dispatch.h>
|
| 528 |
+
#include <ATen/ops/upsample_trilinear3d_backward_cpu_dispatch.h>
|
| 529 |
+
#include <ATen/ops/var_cpu_dispatch.h>
|
| 530 |
+
#include <ATen/ops/var_mean_cpu_dispatch.h>
|
| 531 |
+
#include <ATen/ops/vdot_cpu_dispatch.h>
|
| 532 |
+
#include <ATen/ops/view_cpu_dispatch.h>
|
| 533 |
+
#include <ATen/ops/view_as_complex_cpu_dispatch.h>
|
| 534 |
+
#include <ATen/ops/view_as_real_cpu_dispatch.h>
|
| 535 |
+
#include <ATen/ops/where_cpu_dispatch.h>
|
| 536 |
+
#include <ATen/ops/xlogy_cpu_dispatch.h>
|
| 537 |
+
#include <ATen/ops/zero_cpu_dispatch.h>
|
| 538 |
+
|
| 539 |
+
|
| 540 |
+
|
.venv/Lib/site-packages/torch/include/ATen/CPUGeneratorImpl.h
ADDED
|
@@ -0,0 +1,49 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#include <ATen/core/Generator.h>
|
| 4 |
+
#include <ATen/core/MT19937RNGEngine.h>
|
| 5 |
+
#include <c10/core/GeneratorImpl.h>
|
| 6 |
+
#include <optional>
|
| 7 |
+
|
| 8 |
+
namespace at {
|
| 9 |
+
|
| 10 |
+
struct TORCH_API CPUGeneratorImpl : public c10::GeneratorImpl {
|
| 11 |
+
// Constructors
|
| 12 |
+
CPUGeneratorImpl(uint64_t seed_in = default_rng_seed_val);
|
| 13 |
+
~CPUGeneratorImpl() override = default;
|
| 14 |
+
|
| 15 |
+
// CPUGeneratorImpl methods
|
| 16 |
+
std::shared_ptr<CPUGeneratorImpl> clone() const;
|
| 17 |
+
void set_current_seed(uint64_t seed) override;
|
| 18 |
+
void set_offset(uint64_t offset) override;
|
| 19 |
+
uint64_t get_offset() const override;
|
| 20 |
+
uint64_t current_seed() const override;
|
| 21 |
+
uint64_t seed() override;
|
| 22 |
+
void set_state(const c10::TensorImpl& new_state) override;
|
| 23 |
+
c10::intrusive_ptr<c10::TensorImpl> get_state() const override;
|
| 24 |
+
static c10::DeviceType device_type();
|
| 25 |
+
uint32_t random();
|
| 26 |
+
uint64_t random64();
|
| 27 |
+
std::optional<float> next_float_normal_sample();
|
| 28 |
+
std::optional<double> next_double_normal_sample();
|
| 29 |
+
void set_next_float_normal_sample(std::optional<float> randn);
|
| 30 |
+
void set_next_double_normal_sample(std::optional<double> randn);
|
| 31 |
+
at::mt19937 engine();
|
| 32 |
+
void set_engine(at::mt19937 engine);
|
| 33 |
+
|
| 34 |
+
private:
|
| 35 |
+
CPUGeneratorImpl* clone_impl() const override;
|
| 36 |
+
at::mt19937 engine_;
|
| 37 |
+
std::optional<float> next_float_normal_sample_;
|
| 38 |
+
std::optional<double> next_double_normal_sample_;
|
| 39 |
+
};
|
| 40 |
+
|
| 41 |
+
namespace detail {
|
| 42 |
+
|
| 43 |
+
TORCH_API const Generator& getDefaultCPUGenerator();
|
| 44 |
+
TORCH_API Generator
|
| 45 |
+
createCPUGenerator(uint64_t seed_val = default_rng_seed_val);
|
| 46 |
+
|
| 47 |
+
} // namespace detail
|
| 48 |
+
|
| 49 |
+
} // namespace at
|
.venv/Lib/site-packages/torch/include/ATen/CUDAFunctions.h
ADDED
|
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#include <ATen/core/TensorBody.h>
|
| 2 |
+
|
| 3 |
+
// TODO Undo all logic introduced for Note [Avoiding Include Cycles In Static Dispatch]
|
| 4 |
+
// Code introduced to avoid cyclic dependency in static dispatch is no longer
|
| 5 |
+
// needed as static dispatch logic is moved from TensorBody.h, which caused cycles in the first place,
|
| 6 |
+
// to Operators.cpp for supporting multiple backends with multiple kernels.
|
| 7 |
+
//
|
| 8 |
+
// Note [Avoiding Include Cycles In Static Dispatch]
|
| 9 |
+
// In order to avoid #include cycles in the static dispatch build, we've carefully split out
|
| 10 |
+
// the static function definition files into {DispatchKey}Functions.h and {DispatchKey}Functions_inl.h.
|
| 11 |
+
//
|
| 12 |
+
// Without this split, the include cycle looks like TensorBody.h -> CPUFunctions.h -> TensorBody.h.
|
| 13 |
+
// - TensorBody.h #includes CPUFunctions.h in the static dispatch build, because the tensor methods
|
| 14 |
+
// all need to call into the fastpath C++ API defined in CPUFunctions.h. The methods are also all
|
| 15 |
+
// directly inlined into TensorBody.h.
|
| 16 |
+
// - CPUFunctions.h #includes TensorBody.h because it contains function declarations for the entire C++ API,
|
| 17 |
+
// which include functions that have defaultable std::optional<Tensor> arguments.
|
| 18 |
+
// That requires knowing the full Tensor class definition.
|
| 19 |
+
//
|
| 20 |
+
// We break the cycle by doing the following:
|
| 21 |
+
// - Split out CPUFunction.h into two files: CPUFunctions.h and CPUFunctions_inl.h
|
| 22 |
+
// - CPUFunction.h is a dummy file that just includes the Tensor class and includes CPUFunctions_inl.,
|
| 23 |
+
// - CPUFunctions_inl.h includes everything else
|
| 24 |
+
// - (only in the static dispatch build) TensorBody.h makes sure to finish defining the Tensor class,
|
| 25 |
+
// and then it includes CPUFunctions_inl.h.
|
| 26 |
+
// - All other files that want the cpu fastpath functions can include CPUFunctions.h directly.
|
| 27 |
+
// - This also means that static dispatch build, CPUFunctions.h only needs to
|
| 28 |
+
// #include TensorBody.h, and it will automatically bring in CPUFunctions_inl.h.
|
| 29 |
+
#include <ATen/CUDAFunctions_inl.h>
|
.venv/Lib/site-packages/torch/include/ATen/CUDAFunctions_inl.h
ADDED
|
@@ -0,0 +1,623 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
// @generated by torchgen/gen.py from DispatchKeyFunctions_inl.h
|
| 3 |
+
|
| 4 |
+
// NB: The implementing C++ file is RegisterDispatchKey.cpp
|
| 5 |
+
|
| 6 |
+
// The only #includes we need are for custom classes that have defaults in the C++ API
|
| 7 |
+
#include <c10/core/MemoryFormat.h>
|
| 8 |
+
#include <c10/core/Scalar.h>
|
| 9 |
+
#include <ATen/core/Reduction.h>
|
| 10 |
+
|
| 11 |
+
#if defined(AT_PER_OPERATOR_HEADERS) && defined(TORCH_ASSERT_ONLY_METHOD_OPERATORS)
|
| 12 |
+
#error This change adds a dependency on all pytorch operators, meaning the \
|
| 13 |
+
file will need to be re-compiled every time an operator is changed or added. \
|
| 14 |
+
Consider including a specific operator from \
|
| 15 |
+
<ATen/ops/{my_operator}_cuda_dispatch.h>. \
|
| 16 |
+
See NOTE [TORCH_ASSERT_ONLY_METHOD_OPERATORS].
|
| 17 |
+
#endif
|
| 18 |
+
|
| 19 |
+
#include <ATen/ops/_adaptive_avg_pool2d_cuda_dispatch.h>
|
| 20 |
+
#include <ATen/ops/_adaptive_avg_pool2d_backward_cuda_dispatch.h>
|
| 21 |
+
#include <ATen/ops/_adaptive_avg_pool3d_cuda_dispatch.h>
|
| 22 |
+
#include <ATen/ops/_adaptive_avg_pool3d_backward_cuda_dispatch.h>
|
| 23 |
+
#include <ATen/ops/_addmm_activation_cuda_dispatch.h>
|
| 24 |
+
#include <ATen/ops/_aminmax_cuda_dispatch.h>
|
| 25 |
+
#include <ATen/ops/_amp_foreach_non_finite_check_and_unscale_cuda_dispatch.h>
|
| 26 |
+
#include <ATen/ops/_amp_update_scale_cuda_dispatch.h>
|
| 27 |
+
#include <ATen/ops/_assert_async_cuda_dispatch.h>
|
| 28 |
+
#include <ATen/ops/_batch_norm_with_update_cuda_dispatch.h>
|
| 29 |
+
#include <ATen/ops/_cdist_backward_cuda_dispatch.h>
|
| 30 |
+
#include <ATen/ops/_cdist_forward_cuda_dispatch.h>
|
| 31 |
+
#include <ATen/ops/_cholesky_solve_helper_cuda_dispatch.h>
|
| 32 |
+
#include <ATen/ops/_chunk_cat_cuda_dispatch.h>
|
| 33 |
+
#include <ATen/ops/_compute_linear_combination_cuda_dispatch.h>
|
| 34 |
+
#include <ATen/ops/_conv_depthwise2d_cuda_dispatch.h>
|
| 35 |
+
#include <ATen/ops/_convert_indices_from_coo_to_csr_cuda_dispatch.h>
|
| 36 |
+
#include <ATen/ops/_convert_indices_from_csr_to_coo_cuda_dispatch.h>
|
| 37 |
+
#include <ATen/ops/_convert_weight_to_int4pack_cuda_dispatch.h>
|
| 38 |
+
#include <ATen/ops/_cslt_compress_cuda_dispatch.h>
|
| 39 |
+
#include <ATen/ops/_cslt_sparse_mm_cuda_dispatch.h>
|
| 40 |
+
#include <ATen/ops/_cslt_sparse_mm_search_cuda_dispatch.h>
|
| 41 |
+
#include <ATen/ops/_ctc_loss_cuda_dispatch.h>
|
| 42 |
+
#include <ATen/ops/_ctc_loss_backward_cuda_dispatch.h>
|
| 43 |
+
#include <ATen/ops/_cudnn_ctc_loss_cuda_dispatch.h>
|
| 44 |
+
#include <ATen/ops/_cudnn_init_dropout_state_cuda_dispatch.h>
|
| 45 |
+
#include <ATen/ops/_cudnn_rnn_cuda_dispatch.h>
|
| 46 |
+
#include <ATen/ops/_cudnn_rnn_backward_cuda_dispatch.h>
|
| 47 |
+
#include <ATen/ops/_cudnn_rnn_flatten_weight_cuda_dispatch.h>
|
| 48 |
+
#include <ATen/ops/_cummax_helper_cuda_dispatch.h>
|
| 49 |
+
#include <ATen/ops/_cummin_helper_cuda_dispatch.h>
|
| 50 |
+
#include <ATen/ops/_dirichlet_grad_cuda_dispatch.h>
|
| 51 |
+
#include <ATen/ops/_efficient_attention_backward_cuda_dispatch.h>
|
| 52 |
+
#include <ATen/ops/_efficient_attention_forward_cuda_dispatch.h>
|
| 53 |
+
#include <ATen/ops/_efficientzerotensor_cuda_dispatch.h>
|
| 54 |
+
#include <ATen/ops/_embedding_bag_cuda_dispatch.h>
|
| 55 |
+
#include <ATen/ops/_embedding_bag_backward_cuda_dispatch.h>
|
| 56 |
+
#include <ATen/ops/_embedding_bag_dense_backward_cuda_dispatch.h>
|
| 57 |
+
#include <ATen/ops/_embedding_bag_forward_only_cuda_dispatch.h>
|
| 58 |
+
#include <ATen/ops/_embedding_bag_per_sample_weights_backward_cuda_dispatch.h>
|
| 59 |
+
#include <ATen/ops/_fake_quantize_learnable_per_channel_affine_cuda_dispatch.h>
|
| 60 |
+
#include <ATen/ops/_fake_quantize_learnable_per_channel_affine_backward_cuda_dispatch.h>
|
| 61 |
+
#include <ATen/ops/_fake_quantize_learnable_per_tensor_affine_cuda_dispatch.h>
|
| 62 |
+
#include <ATen/ops/_fake_quantize_learnable_per_tensor_affine_backward_cuda_dispatch.h>
|
| 63 |
+
#include <ATen/ops/_fake_quantize_per_tensor_affine_cachemask_tensor_qparams_cuda_dispatch.h>
|
| 64 |
+
#include <ATen/ops/_fft_c2c_cuda_dispatch.h>
|
| 65 |
+
#include <ATen/ops/_fft_c2r_cuda_dispatch.h>
|
| 66 |
+
#include <ATen/ops/_fft_r2c_cuda_dispatch.h>
|
| 67 |
+
#include <ATen/ops/_fill_mem_eff_dropout_mask_cuda_dispatch.h>
|
| 68 |
+
#include <ATen/ops/_flash_attention_backward_cuda_dispatch.h>
|
| 69 |
+
#include <ATen/ops/_flash_attention_forward_cuda_dispatch.h>
|
| 70 |
+
#include <ATen/ops/_foreach_abs_cuda_dispatch.h>
|
| 71 |
+
#include <ATen/ops/_foreach_acos_cuda_dispatch.h>
|
| 72 |
+
#include <ATen/ops/_foreach_add_cuda_dispatch.h>
|
| 73 |
+
#include <ATen/ops/_foreach_addcdiv_cuda_dispatch.h>
|
| 74 |
+
#include <ATen/ops/_foreach_addcmul_cuda_dispatch.h>
|
| 75 |
+
#include <ATen/ops/_foreach_asin_cuda_dispatch.h>
|
| 76 |
+
#include <ATen/ops/_foreach_atan_cuda_dispatch.h>
|
| 77 |
+
#include <ATen/ops/_foreach_ceil_cuda_dispatch.h>
|
| 78 |
+
#include <ATen/ops/_foreach_clamp_max_cuda_dispatch.h>
|
| 79 |
+
#include <ATen/ops/_foreach_clamp_min_cuda_dispatch.h>
|
| 80 |
+
#include <ATen/ops/_foreach_copy_cuda_dispatch.h>
|
| 81 |
+
#include <ATen/ops/_foreach_cos_cuda_dispatch.h>
|
| 82 |
+
#include <ATen/ops/_foreach_cosh_cuda_dispatch.h>
|
| 83 |
+
#include <ATen/ops/_foreach_div_cuda_dispatch.h>
|
| 84 |
+
#include <ATen/ops/_foreach_erf_cuda_dispatch.h>
|
| 85 |
+
#include <ATen/ops/_foreach_erfc_cuda_dispatch.h>
|
| 86 |
+
#include <ATen/ops/_foreach_exp_cuda_dispatch.h>
|
| 87 |
+
#include <ATen/ops/_foreach_expm1_cuda_dispatch.h>
|
| 88 |
+
#include <ATen/ops/_foreach_floor_cuda_dispatch.h>
|
| 89 |
+
#include <ATen/ops/_foreach_frac_cuda_dispatch.h>
|
| 90 |
+
#include <ATen/ops/_foreach_lerp_cuda_dispatch.h>
|
| 91 |
+
#include <ATen/ops/_foreach_lgamma_cuda_dispatch.h>
|
| 92 |
+
#include <ATen/ops/_foreach_log_cuda_dispatch.h>
|
| 93 |
+
#include <ATen/ops/_foreach_log10_cuda_dispatch.h>
|
| 94 |
+
#include <ATen/ops/_foreach_log1p_cuda_dispatch.h>
|
| 95 |
+
#include <ATen/ops/_foreach_log2_cuda_dispatch.h>
|
| 96 |
+
#include <ATen/ops/_foreach_max_cuda_dispatch.h>
|
| 97 |
+
#include <ATen/ops/_foreach_maximum_cuda_dispatch.h>
|
| 98 |
+
#include <ATen/ops/_foreach_minimum_cuda_dispatch.h>
|
| 99 |
+
#include <ATen/ops/_foreach_mul_cuda_dispatch.h>
|
| 100 |
+
#include <ATen/ops/_foreach_neg_cuda_dispatch.h>
|
| 101 |
+
#include <ATen/ops/_foreach_norm_cuda_dispatch.h>
|
| 102 |
+
#include <ATen/ops/_foreach_pow_cuda_dispatch.h>
|
| 103 |
+
#include <ATen/ops/_foreach_reciprocal_cuda_dispatch.h>
|
| 104 |
+
#include <ATen/ops/_foreach_round_cuda_dispatch.h>
|
| 105 |
+
#include <ATen/ops/_foreach_sigmoid_cuda_dispatch.h>
|
| 106 |
+
#include <ATen/ops/_foreach_sign_cuda_dispatch.h>
|
| 107 |
+
#include <ATen/ops/_foreach_sin_cuda_dispatch.h>
|
| 108 |
+
#include <ATen/ops/_foreach_sinh_cuda_dispatch.h>
|
| 109 |
+
#include <ATen/ops/_foreach_sqrt_cuda_dispatch.h>
|
| 110 |
+
#include <ATen/ops/_foreach_sub_cuda_dispatch.h>
|
| 111 |
+
#include <ATen/ops/_foreach_tan_cuda_dispatch.h>
|
| 112 |
+
#include <ATen/ops/_foreach_tanh_cuda_dispatch.h>
|
| 113 |
+
#include <ATen/ops/_foreach_trunc_cuda_dispatch.h>
|
| 114 |
+
#include <ATen/ops/_foreach_zero_cuda_dispatch.h>
|
| 115 |
+
#include <ATen/ops/_fused_adam_cuda_dispatch.h>
|
| 116 |
+
#include <ATen/ops/_fused_adamw_cuda_dispatch.h>
|
| 117 |
+
#include <ATen/ops/_fused_dropout_cuda_dispatch.h>
|
| 118 |
+
#include <ATen/ops/_fused_moving_avg_obs_fq_helper_cuda_dispatch.h>
|
| 119 |
+
#include <ATen/ops/_fused_sdp_choice_cuda_dispatch.h>
|
| 120 |
+
#include <ATen/ops/_fused_sgd_cuda_dispatch.h>
|
| 121 |
+
#include <ATen/ops/_index_put_impl_cuda_dispatch.h>
|
| 122 |
+
#include <ATen/ops/_int_mm_cuda_dispatch.h>
|
| 123 |
+
#include <ATen/ops/_jagged_to_padded_dense_forward_cuda_dispatch.h>
|
| 124 |
+
#include <ATen/ops/_linalg_det_cuda_dispatch.h>
|
| 125 |
+
#include <ATen/ops/_linalg_eigh_cuda_dispatch.h>
|
| 126 |
+
#include <ATen/ops/_linalg_eigvals_cuda_dispatch.h>
|
| 127 |
+
#include <ATen/ops/_linalg_slogdet_cuda_dispatch.h>
|
| 128 |
+
#include <ATen/ops/_linalg_solve_ex_cuda_dispatch.h>
|
| 129 |
+
#include <ATen/ops/_linalg_svd_cuda_dispatch.h>
|
| 130 |
+
#include <ATen/ops/_local_scalar_dense_cuda_dispatch.h>
|
| 131 |
+
#include <ATen/ops/_log_softmax_cuda_dispatch.h>
|
| 132 |
+
#include <ATen/ops/_log_softmax_backward_data_cuda_dispatch.h>
|
| 133 |
+
#include <ATen/ops/_logcumsumexp_cuda_dispatch.h>
|
| 134 |
+
#include <ATen/ops/_make_per_channel_quantized_tensor_cuda_dispatch.h>
|
| 135 |
+
#include <ATen/ops/_make_per_tensor_quantized_tensor_cuda_dispatch.h>
|
| 136 |
+
#include <ATen/ops/_masked_scale_cuda_dispatch.h>
|
| 137 |
+
#include <ATen/ops/_masked_softmax_cuda_dispatch.h>
|
| 138 |
+
#include <ATen/ops/_masked_softmax_backward_cuda_dispatch.h>
|
| 139 |
+
#include <ATen/ops/_mixed_dtypes_linear_cuda_dispatch.h>
|
| 140 |
+
#include <ATen/ops/_native_batch_norm_legit_cuda_dispatch.h>
|
| 141 |
+
#include <ATen/ops/_native_multi_head_attention_cuda_dispatch.h>
|
| 142 |
+
#include <ATen/ops/_nested_compute_contiguous_strides_offsets_cuda_dispatch.h>
|
| 143 |
+
#include <ATen/ops/_nested_from_padded_cuda_dispatch.h>
|
| 144 |
+
#include <ATen/ops/_nested_tensor_from_mask_cuda_dispatch.h>
|
| 145 |
+
#include <ATen/ops/_nested_tensor_from_mask_left_aligned_cuda_dispatch.h>
|
| 146 |
+
#include <ATen/ops/_nested_view_from_buffer_cuda_dispatch.h>
|
| 147 |
+
#include <ATen/ops/_padded_dense_to_jagged_forward_cuda_dispatch.h>
|
| 148 |
+
#include <ATen/ops/_pdist_backward_cuda_dispatch.h>
|
| 149 |
+
#include <ATen/ops/_pdist_forward_cuda_dispatch.h>
|
| 150 |
+
#include <ATen/ops/_prelu_kernel_cuda_dispatch.h>
|
| 151 |
+
#include <ATen/ops/_prelu_kernel_backward_cuda_dispatch.h>
|
| 152 |
+
#include <ATen/ops/_reshape_alias_cuda_dispatch.h>
|
| 153 |
+
#include <ATen/ops/_sample_dirichlet_cuda_dispatch.h>
|
| 154 |
+
#include <ATen/ops/_scaled_dot_product_cudnn_attention_cuda_dispatch.h>
|
| 155 |
+
#include <ATen/ops/_scaled_dot_product_cudnn_attention_backward_cuda_dispatch.h>
|
| 156 |
+
#include <ATen/ops/_scaled_dot_product_efficient_attention_cuda_dispatch.h>
|
| 157 |
+
#include <ATen/ops/_scaled_dot_product_efficient_attention_backward_cuda_dispatch.h>
|
| 158 |
+
#include <ATen/ops/_scaled_dot_product_flash_attention_cuda_dispatch.h>
|
| 159 |
+
#include <ATen/ops/_scaled_dot_product_flash_attention_backward_cuda_dispatch.h>
|
| 160 |
+
#include <ATen/ops/_scaled_mm_cuda_dispatch.h>
|
| 161 |
+
#include <ATen/ops/_segment_reduce_backward_cuda_dispatch.h>
|
| 162 |
+
#include <ATen/ops/_slow_conv2d_backward_cuda_dispatch.h>
|
| 163 |
+
#include <ATen/ops/_slow_conv2d_forward_cuda_dispatch.h>
|
| 164 |
+
#include <ATen/ops/_softmax_cuda_dispatch.h>
|
| 165 |
+
#include <ATen/ops/_softmax_backward_data_cuda_dispatch.h>
|
| 166 |
+
#include <ATen/ops/_sparse_semi_structured_addmm_cuda_dispatch.h>
|
| 167 |
+
#include <ATen/ops/_sparse_semi_structured_apply_cuda_dispatch.h>
|
| 168 |
+
#include <ATen/ops/_sparse_semi_structured_apply_dense_cuda_dispatch.h>
|
| 169 |
+
#include <ATen/ops/_sparse_semi_structured_linear_cuda_dispatch.h>
|
| 170 |
+
#include <ATen/ops/_sparse_semi_structured_mm_cuda_dispatch.h>
|
| 171 |
+
#include <ATen/ops/_sparse_semi_structured_tile_cuda_dispatch.h>
|
| 172 |
+
#include <ATen/ops/_standard_gamma_cuda_dispatch.h>
|
| 173 |
+
#include <ATen/ops/_standard_gamma_grad_cuda_dispatch.h>
|
| 174 |
+
#include <ATen/ops/_thnn_fused_gru_cell_cuda_dispatch.h>
|
| 175 |
+
#include <ATen/ops/_thnn_fused_gru_cell_backward_cuda_dispatch.h>
|
| 176 |
+
#include <ATen/ops/_thnn_fused_lstm_cell_cuda_dispatch.h>
|
| 177 |
+
#include <ATen/ops/_thnn_fused_lstm_cell_backward_impl_cuda_dispatch.h>
|
| 178 |
+
#include <ATen/ops/_to_sparse_cuda_dispatch.h>
|
| 179 |
+
#include <ATen/ops/_to_sparse_bsc_cuda_dispatch.h>
|
| 180 |
+
#include <ATen/ops/_to_sparse_bsr_cuda_dispatch.h>
|
| 181 |
+
#include <ATen/ops/_to_sparse_csc_cuda_dispatch.h>
|
| 182 |
+
#include <ATen/ops/_to_sparse_csr_cuda_dispatch.h>
|
| 183 |
+
#include <ATen/ops/_to_sparse_semi_structured_cuda_dispatch.h>
|
| 184 |
+
#include <ATen/ops/_transform_bias_rescale_qkv_cuda_dispatch.h>
|
| 185 |
+
#include <ATen/ops/_transformer_encoder_layer_fwd_cuda_dispatch.h>
|
| 186 |
+
#include <ATen/ops/_triton_multi_head_attention_cuda_dispatch.h>
|
| 187 |
+
#include <ATen/ops/_triton_scaled_dot_attention_cuda_dispatch.h>
|
| 188 |
+
#include <ATen/ops/_unique_cuda_dispatch.h>
|
| 189 |
+
#include <ATen/ops/_unique2_cuda_dispatch.h>
|
| 190 |
+
#include <ATen/ops/_upsample_bicubic2d_aa_cuda_dispatch.h>
|
| 191 |
+
#include <ATen/ops/_upsample_bicubic2d_aa_backward_cuda_dispatch.h>
|
| 192 |
+
#include <ATen/ops/_upsample_bilinear2d_aa_cuda_dispatch.h>
|
| 193 |
+
#include <ATen/ops/_upsample_bilinear2d_aa_backward_cuda_dispatch.h>
|
| 194 |
+
#include <ATen/ops/_upsample_nearest_exact1d_cuda_dispatch.h>
|
| 195 |
+
#include <ATen/ops/_upsample_nearest_exact1d_backward_cuda_dispatch.h>
|
| 196 |
+
#include <ATen/ops/_upsample_nearest_exact2d_cuda_dispatch.h>
|
| 197 |
+
#include <ATen/ops/_upsample_nearest_exact2d_backward_cuda_dispatch.h>
|
| 198 |
+
#include <ATen/ops/_upsample_nearest_exact3d_cuda_dispatch.h>
|
| 199 |
+
#include <ATen/ops/_upsample_nearest_exact3d_backward_cuda_dispatch.h>
|
| 200 |
+
#include <ATen/ops/_use_cudnn_ctc_loss_cuda_dispatch.h>
|
| 201 |
+
#include <ATen/ops/_validate_compressed_sparse_indices_cuda_dispatch.h>
|
| 202 |
+
#include <ATen/ops/_weight_int4pack_mm_cuda_dispatch.h>
|
| 203 |
+
#include <ATen/ops/_weight_norm_interface_cuda_dispatch.h>
|
| 204 |
+
#include <ATen/ops/_weight_norm_interface_backward_cuda_dispatch.h>
|
| 205 |
+
#include <ATen/ops/abs_cuda_dispatch.h>
|
| 206 |
+
#include <ATen/ops/acos_cuda_dispatch.h>
|
| 207 |
+
#include <ATen/ops/acosh_cuda_dispatch.h>
|
| 208 |
+
#include <ATen/ops/adaptive_avg_pool2d_cuda_dispatch.h>
|
| 209 |
+
#include <ATen/ops/adaptive_avg_pool3d_cuda_dispatch.h>
|
| 210 |
+
#include <ATen/ops/adaptive_avg_pool3d_backward_cuda_dispatch.h>
|
| 211 |
+
#include <ATen/ops/adaptive_max_pool2d_cuda_dispatch.h>
|
| 212 |
+
#include <ATen/ops/adaptive_max_pool2d_backward_cuda_dispatch.h>
|
| 213 |
+
#include <ATen/ops/adaptive_max_pool3d_cuda_dispatch.h>
|
| 214 |
+
#include <ATen/ops/adaptive_max_pool3d_backward_cuda_dispatch.h>
|
| 215 |
+
#include <ATen/ops/add_cuda_dispatch.h>
|
| 216 |
+
#include <ATen/ops/addbmm_cuda_dispatch.h>
|
| 217 |
+
#include <ATen/ops/addcdiv_cuda_dispatch.h>
|
| 218 |
+
#include <ATen/ops/addcmul_cuda_dispatch.h>
|
| 219 |
+
#include <ATen/ops/addmm_cuda_dispatch.h>
|
| 220 |
+
#include <ATen/ops/addmv_cuda_dispatch.h>
|
| 221 |
+
#include <ATen/ops/addr_cuda_dispatch.h>
|
| 222 |
+
#include <ATen/ops/all_cuda_dispatch.h>
|
| 223 |
+
#include <ATen/ops/amax_cuda_dispatch.h>
|
| 224 |
+
#include <ATen/ops/amin_cuda_dispatch.h>
|
| 225 |
+
#include <ATen/ops/aminmax_cuda_dispatch.h>
|
| 226 |
+
#include <ATen/ops/angle_cuda_dispatch.h>
|
| 227 |
+
#include <ATen/ops/any_cuda_dispatch.h>
|
| 228 |
+
#include <ATen/ops/arange_cuda_dispatch.h>
|
| 229 |
+
#include <ATen/ops/argmax_cuda_dispatch.h>
|
| 230 |
+
#include <ATen/ops/argmin_cuda_dispatch.h>
|
| 231 |
+
#include <ATen/ops/as_strided_cuda_dispatch.h>
|
| 232 |
+
#include <ATen/ops/asin_cuda_dispatch.h>
|
| 233 |
+
#include <ATen/ops/asinh_cuda_dispatch.h>
|
| 234 |
+
#include <ATen/ops/atan_cuda_dispatch.h>
|
| 235 |
+
#include <ATen/ops/atan2_cuda_dispatch.h>
|
| 236 |
+
#include <ATen/ops/atanh_cuda_dispatch.h>
|
| 237 |
+
#include <ATen/ops/avg_pool2d_cuda_dispatch.h>
|
| 238 |
+
#include <ATen/ops/avg_pool2d_backward_cuda_dispatch.h>
|
| 239 |
+
#include <ATen/ops/avg_pool3d_cuda_dispatch.h>
|
| 240 |
+
#include <ATen/ops/avg_pool3d_backward_cuda_dispatch.h>
|
| 241 |
+
#include <ATen/ops/baddbmm_cuda_dispatch.h>
|
| 242 |
+
#include <ATen/ops/batch_norm_backward_cuda_dispatch.h>
|
| 243 |
+
#include <ATen/ops/batch_norm_backward_elemt_cuda_dispatch.h>
|
| 244 |
+
#include <ATen/ops/batch_norm_backward_reduce_cuda_dispatch.h>
|
| 245 |
+
#include <ATen/ops/batch_norm_elemt_cuda_dispatch.h>
|
| 246 |
+
#include <ATen/ops/batch_norm_gather_stats_cuda_dispatch.h>
|
| 247 |
+
#include <ATen/ops/batch_norm_gather_stats_with_counts_cuda_dispatch.h>
|
| 248 |
+
#include <ATen/ops/batch_norm_stats_cuda_dispatch.h>
|
| 249 |
+
#include <ATen/ops/batch_norm_update_stats_cuda_dispatch.h>
|
| 250 |
+
#include <ATen/ops/bernoulli_cuda_dispatch.h>
|
| 251 |
+
#include <ATen/ops/binary_cross_entropy_cuda_dispatch.h>
|
| 252 |
+
#include <ATen/ops/binary_cross_entropy_backward_cuda_dispatch.h>
|
| 253 |
+
#include <ATen/ops/bincount_cuda_dispatch.h>
|
| 254 |
+
#include <ATen/ops/binomial_cuda_dispatch.h>
|
| 255 |
+
#include <ATen/ops/bitwise_and_cuda_dispatch.h>
|
| 256 |
+
#include <ATen/ops/bitwise_left_shift_cuda_dispatch.h>
|
| 257 |
+
#include <ATen/ops/bitwise_not_cuda_dispatch.h>
|
| 258 |
+
#include <ATen/ops/bitwise_or_cuda_dispatch.h>
|
| 259 |
+
#include <ATen/ops/bitwise_right_shift_cuda_dispatch.h>
|
| 260 |
+
#include <ATen/ops/bitwise_xor_cuda_dispatch.h>
|
| 261 |
+
#include <ATen/ops/bmm_cuda_dispatch.h>
|
| 262 |
+
#include <ATen/ops/bucketize_cuda_dispatch.h>
|
| 263 |
+
#include <ATen/ops/cat_cuda_dispatch.h>
|
| 264 |
+
#include <ATen/ops/cauchy_cuda_dispatch.h>
|
| 265 |
+
#include <ATen/ops/ceil_cuda_dispatch.h>
|
| 266 |
+
#include <ATen/ops/channel_shuffle_cuda_dispatch.h>
|
| 267 |
+
#include <ATen/ops/cholesky_cuda_dispatch.h>
|
| 268 |
+
#include <ATen/ops/cholesky_inverse_cuda_dispatch.h>
|
| 269 |
+
#include <ATen/ops/clamp_cuda_dispatch.h>
|
| 270 |
+
#include <ATen/ops/clamp_max_cuda_dispatch.h>
|
| 271 |
+
#include <ATen/ops/clamp_min_cuda_dispatch.h>
|
| 272 |
+
#include <ATen/ops/col2im_cuda_dispatch.h>
|
| 273 |
+
#include <ATen/ops/complex_cuda_dispatch.h>
|
| 274 |
+
#include <ATen/ops/conj_physical_cuda_dispatch.h>
|
| 275 |
+
#include <ATen/ops/conv_depthwise3d_cuda_dispatch.h>
|
| 276 |
+
#include <ATen/ops/convolution_backward_cuda_dispatch.h>
|
| 277 |
+
#include <ATen/ops/copysign_cuda_dispatch.h>
|
| 278 |
+
#include <ATen/ops/cos_cuda_dispatch.h>
|
| 279 |
+
#include <ATen/ops/cosh_cuda_dispatch.h>
|
| 280 |
+
#include <ATen/ops/count_nonzero_cuda_dispatch.h>
|
| 281 |
+
#include <ATen/ops/cudnn_affine_grid_generator_cuda_dispatch.h>
|
| 282 |
+
#include <ATen/ops/cudnn_affine_grid_generator_backward_cuda_dispatch.h>
|
| 283 |
+
#include <ATen/ops/cudnn_batch_norm_cuda_dispatch.h>
|
| 284 |
+
#include <ATen/ops/cudnn_batch_norm_backward_cuda_dispatch.h>
|
| 285 |
+
#include <ATen/ops/cudnn_convolution_cuda_dispatch.h>
|
| 286 |
+
#include <ATen/ops/cudnn_convolution_add_relu_cuda_dispatch.h>
|
| 287 |
+
#include <ATen/ops/cudnn_convolution_relu_cuda_dispatch.h>
|
| 288 |
+
#include <ATen/ops/cudnn_convolution_transpose_cuda_dispatch.h>
|
| 289 |
+
#include <ATen/ops/cudnn_grid_sampler_cuda_dispatch.h>
|
| 290 |
+
#include <ATen/ops/cudnn_grid_sampler_backward_cuda_dispatch.h>
|
| 291 |
+
#include <ATen/ops/cumprod_cuda_dispatch.h>
|
| 292 |
+
#include <ATen/ops/cumsum_cuda_dispatch.h>
|
| 293 |
+
#include <ATen/ops/dequantize_cuda_dispatch.h>
|
| 294 |
+
#include <ATen/ops/digamma_cuda_dispatch.h>
|
| 295 |
+
#include <ATen/ops/div_cuda_dispatch.h>
|
| 296 |
+
#include <ATen/ops/dot_cuda_dispatch.h>
|
| 297 |
+
#include <ATen/ops/elu_cuda_dispatch.h>
|
| 298 |
+
#include <ATen/ops/elu_backward_cuda_dispatch.h>
|
| 299 |
+
#include <ATen/ops/embedding_dense_backward_cuda_dispatch.h>
|
| 300 |
+
#include <ATen/ops/embedding_renorm_cuda_dispatch.h>
|
| 301 |
+
#include <ATen/ops/empty_cuda_dispatch.h>
|
| 302 |
+
#include <ATen/ops/empty_strided_cuda_dispatch.h>
|
| 303 |
+
#include <ATen/ops/eq_cuda_dispatch.h>
|
| 304 |
+
#include <ATen/ops/equal_cuda_dispatch.h>
|
| 305 |
+
#include <ATen/ops/erf_cuda_dispatch.h>
|
| 306 |
+
#include <ATen/ops/erfc_cuda_dispatch.h>
|
| 307 |
+
#include <ATen/ops/erfinv_cuda_dispatch.h>
|
| 308 |
+
#include <ATen/ops/exp_cuda_dispatch.h>
|
| 309 |
+
#include <ATen/ops/exp2_cuda_dispatch.h>
|
| 310 |
+
#include <ATen/ops/expm1_cuda_dispatch.h>
|
| 311 |
+
#include <ATen/ops/exponential_cuda_dispatch.h>
|
| 312 |
+
#include <ATen/ops/eye_cuda_dispatch.h>
|
| 313 |
+
#include <ATen/ops/fake_quantize_per_channel_affine_cachemask_cuda_dispatch.h>
|
| 314 |
+
#include <ATen/ops/fake_quantize_per_tensor_affine_cachemask_cuda_dispatch.h>
|
| 315 |
+
#include <ATen/ops/fill_cuda_dispatch.h>
|
| 316 |
+
#include <ATen/ops/flip_cuda_dispatch.h>
|
| 317 |
+
#include <ATen/ops/floor_cuda_dispatch.h>
|
| 318 |
+
#include <ATen/ops/floor_divide_cuda_dispatch.h>
|
| 319 |
+
#include <ATen/ops/fmax_cuda_dispatch.h>
|
| 320 |
+
#include <ATen/ops/fmin_cuda_dispatch.h>
|
| 321 |
+
#include <ATen/ops/fmod_cuda_dispatch.h>
|
| 322 |
+
#include <ATen/ops/frac_cuda_dispatch.h>
|
| 323 |
+
#include <ATen/ops/fractional_max_pool2d_cuda_dispatch.h>
|
| 324 |
+
#include <ATen/ops/fractional_max_pool2d_backward_cuda_dispatch.h>
|
| 325 |
+
#include <ATen/ops/fractional_max_pool3d_cuda_dispatch.h>
|
| 326 |
+
#include <ATen/ops/fractional_max_pool3d_backward_cuda_dispatch.h>
|
| 327 |
+
#include <ATen/ops/frexp_cuda_dispatch.h>
|
| 328 |
+
#include <ATen/ops/gather_cuda_dispatch.h>
|
| 329 |
+
#include <ATen/ops/gcd_cuda_dispatch.h>
|
| 330 |
+
#include <ATen/ops/ge_cuda_dispatch.h>
|
| 331 |
+
#include <ATen/ops/gelu_cuda_dispatch.h>
|
| 332 |
+
#include <ATen/ops/gelu_backward_cuda_dispatch.h>
|
| 333 |
+
#include <ATen/ops/geometric_cuda_dispatch.h>
|
| 334 |
+
#include <ATen/ops/geqrf_cuda_dispatch.h>
|
| 335 |
+
#include <ATen/ops/glu_cuda_dispatch.h>
|
| 336 |
+
#include <ATen/ops/glu_backward_cuda_dispatch.h>
|
| 337 |
+
#include <ATen/ops/glu_backward_jvp_cuda_dispatch.h>
|
| 338 |
+
#include <ATen/ops/glu_jvp_cuda_dispatch.h>
|
| 339 |
+
#include <ATen/ops/grid_sampler_2d_cuda_dispatch.h>
|
| 340 |
+
#include <ATen/ops/grid_sampler_2d_backward_cuda_dispatch.h>
|
| 341 |
+
#include <ATen/ops/grid_sampler_3d_cuda_dispatch.h>
|
| 342 |
+
#include <ATen/ops/grid_sampler_3d_backward_cuda_dispatch.h>
|
| 343 |
+
#include <ATen/ops/gt_cuda_dispatch.h>
|
| 344 |
+
#include <ATen/ops/hardshrink_cuda_dispatch.h>
|
| 345 |
+
#include <ATen/ops/hardshrink_backward_cuda_dispatch.h>
|
| 346 |
+
#include <ATen/ops/hardsigmoid_cuda_dispatch.h>
|
| 347 |
+
#include <ATen/ops/hardsigmoid_backward_cuda_dispatch.h>
|
| 348 |
+
#include <ATen/ops/hardswish_cuda_dispatch.h>
|
| 349 |
+
#include <ATen/ops/hardswish_backward_cuda_dispatch.h>
|
| 350 |
+
#include <ATen/ops/hardtanh_cuda_dispatch.h>
|
| 351 |
+
#include <ATen/ops/hardtanh_backward_cuda_dispatch.h>
|
| 352 |
+
#include <ATen/ops/heaviside_cuda_dispatch.h>
|
| 353 |
+
#include <ATen/ops/histc_cuda_dispatch.h>
|
| 354 |
+
#include <ATen/ops/huber_loss_cuda_dispatch.h>
|
| 355 |
+
#include <ATen/ops/huber_loss_backward_cuda_dispatch.h>
|
| 356 |
+
#include <ATen/ops/hypot_cuda_dispatch.h>
|
| 357 |
+
#include <ATen/ops/i0_cuda_dispatch.h>
|
| 358 |
+
#include <ATen/ops/igamma_cuda_dispatch.h>
|
| 359 |
+
#include <ATen/ops/igammac_cuda_dispatch.h>
|
| 360 |
+
#include <ATen/ops/im2col_cuda_dispatch.h>
|
| 361 |
+
#include <ATen/ops/index_cuda_dispatch.h>
|
| 362 |
+
#include <ATen/ops/index_add_cuda_dispatch.h>
|
| 363 |
+
#include <ATen/ops/index_copy_cuda_dispatch.h>
|
| 364 |
+
#include <ATen/ops/index_fill_cuda_dispatch.h>
|
| 365 |
+
#include <ATen/ops/index_reduce_cuda_dispatch.h>
|
| 366 |
+
#include <ATen/ops/index_select_cuda_dispatch.h>
|
| 367 |
+
#include <ATen/ops/is_set_to_cuda_dispatch.h>
|
| 368 |
+
#include <ATen/ops/isin_cuda_dispatch.h>
|
| 369 |
+
#include <ATen/ops/isnan_cuda_dispatch.h>
|
| 370 |
+
#include <ATen/ops/isneginf_cuda_dispatch.h>
|
| 371 |
+
#include <ATen/ops/isposinf_cuda_dispatch.h>
|
| 372 |
+
#include <ATen/ops/kthvalue_cuda_dispatch.h>
|
| 373 |
+
#include <ATen/ops/lcm_cuda_dispatch.h>
|
| 374 |
+
#include <ATen/ops/le_cuda_dispatch.h>
|
| 375 |
+
#include <ATen/ops/leaky_relu_cuda_dispatch.h>
|
| 376 |
+
#include <ATen/ops/leaky_relu_backward_cuda_dispatch.h>
|
| 377 |
+
#include <ATen/ops/lerp_cuda_dispatch.h>
|
| 378 |
+
#include <ATen/ops/lgamma_cuda_dispatch.h>
|
| 379 |
+
#include <ATen/ops/linalg_cholesky_ex_cuda_dispatch.h>
|
| 380 |
+
#include <ATen/ops/linalg_cross_cuda_dispatch.h>
|
| 381 |
+
#include <ATen/ops/linalg_eig_cuda_dispatch.h>
|
| 382 |
+
#include <ATen/ops/linalg_eigvals_cuda_dispatch.h>
|
| 383 |
+
#include <ATen/ops/linalg_householder_product_cuda_dispatch.h>
|
| 384 |
+
#include <ATen/ops/linalg_inv_ex_cuda_dispatch.h>
|
| 385 |
+
#include <ATen/ops/linalg_ldl_factor_ex_cuda_dispatch.h>
|
| 386 |
+
#include <ATen/ops/linalg_ldl_solve_cuda_dispatch.h>
|
| 387 |
+
#include <ATen/ops/linalg_lstsq_cuda_dispatch.h>
|
| 388 |
+
#include <ATen/ops/linalg_lu_cuda_dispatch.h>
|
| 389 |
+
#include <ATen/ops/linalg_lu_factor_ex_cuda_dispatch.h>
|
| 390 |
+
#include <ATen/ops/linalg_lu_solve_cuda_dispatch.h>
|
| 391 |
+
#include <ATen/ops/linalg_matrix_exp_cuda_dispatch.h>
|
| 392 |
+
#include <ATen/ops/linalg_qr_cuda_dispatch.h>
|
| 393 |
+
#include <ATen/ops/linalg_solve_triangular_cuda_dispatch.h>
|
| 394 |
+
#include <ATen/ops/linalg_vector_norm_cuda_dispatch.h>
|
| 395 |
+
#include <ATen/ops/linspace_cuda_dispatch.h>
|
| 396 |
+
#include <ATen/ops/log_cuda_dispatch.h>
|
| 397 |
+
#include <ATen/ops/log10_cuda_dispatch.h>
|
| 398 |
+
#include <ATen/ops/log1p_cuda_dispatch.h>
|
| 399 |
+
#include <ATen/ops/log2_cuda_dispatch.h>
|
| 400 |
+
#include <ATen/ops/log_normal_cuda_dispatch.h>
|
| 401 |
+
#include <ATen/ops/log_sigmoid_backward_cuda_dispatch.h>
|
| 402 |
+
#include <ATen/ops/log_sigmoid_forward_cuda_dispatch.h>
|
| 403 |
+
#include <ATen/ops/logaddexp_cuda_dispatch.h>
|
| 404 |
+
#include <ATen/ops/logaddexp2_cuda_dispatch.h>
|
| 405 |
+
#include <ATen/ops/logical_and_cuda_dispatch.h>
|
| 406 |
+
#include <ATen/ops/logical_not_cuda_dispatch.h>
|
| 407 |
+
#include <ATen/ops/logical_or_cuda_dispatch.h>
|
| 408 |
+
#include <ATen/ops/logical_xor_cuda_dispatch.h>
|
| 409 |
+
#include <ATen/ops/logit_cuda_dispatch.h>
|
| 410 |
+
#include <ATen/ops/logit_backward_cuda_dispatch.h>
|
| 411 |
+
#include <ATen/ops/logspace_cuda_dispatch.h>
|
| 412 |
+
#include <ATen/ops/lshift_cuda_dispatch.h>
|
| 413 |
+
#include <ATen/ops/lt_cuda_dispatch.h>
|
| 414 |
+
#include <ATen/ops/lu_unpack_cuda_dispatch.h>
|
| 415 |
+
#include <ATen/ops/masked_fill_cuda_dispatch.h>
|
| 416 |
+
#include <ATen/ops/masked_scatter_cuda_dispatch.h>
|
| 417 |
+
#include <ATen/ops/masked_select_cuda_dispatch.h>
|
| 418 |
+
#include <ATen/ops/max_cuda_dispatch.h>
|
| 419 |
+
#include <ATen/ops/max_pool2d_with_indices_cuda_dispatch.h>
|
| 420 |
+
#include <ATen/ops/max_pool2d_with_indices_backward_cuda_dispatch.h>
|
| 421 |
+
#include <ATen/ops/max_pool3d_with_indices_cuda_dispatch.h>
|
| 422 |
+
#include <ATen/ops/max_pool3d_with_indices_backward_cuda_dispatch.h>
|
| 423 |
+
#include <ATen/ops/max_unpool2d_cuda_dispatch.h>
|
| 424 |
+
#include <ATen/ops/max_unpool3d_cuda_dispatch.h>
|
| 425 |
+
#include <ATen/ops/maximum_cuda_dispatch.h>
|
| 426 |
+
#include <ATen/ops/mean_cuda_dispatch.h>
|
| 427 |
+
#include <ATen/ops/median_cuda_dispatch.h>
|
| 428 |
+
#include <ATen/ops/min_cuda_dispatch.h>
|
| 429 |
+
#include <ATen/ops/minimum_cuda_dispatch.h>
|
| 430 |
+
#include <ATen/ops/miopen_batch_norm_cuda_dispatch.h>
|
| 431 |
+
#include <ATen/ops/miopen_batch_norm_backward_cuda_dispatch.h>
|
| 432 |
+
#include <ATen/ops/miopen_convolution_cuda_dispatch.h>
|
| 433 |
+
#include <ATen/ops/miopen_convolution_add_relu_cuda_dispatch.h>
|
| 434 |
+
#include <ATen/ops/miopen_convolution_relu_cuda_dispatch.h>
|
| 435 |
+
#include <ATen/ops/miopen_convolution_transpose_cuda_dispatch.h>
|
| 436 |
+
#include <ATen/ops/miopen_depthwise_convolution_cuda_dispatch.h>
|
| 437 |
+
#include <ATen/ops/miopen_rnn_cuda_dispatch.h>
|
| 438 |
+
#include <ATen/ops/miopen_rnn_backward_cuda_dispatch.h>
|
| 439 |
+
#include <ATen/ops/mish_cuda_dispatch.h>
|
| 440 |
+
#include <ATen/ops/mish_backward_cuda_dispatch.h>
|
| 441 |
+
#include <ATen/ops/mm_cuda_dispatch.h>
|
| 442 |
+
#include <ATen/ops/mode_cuda_dispatch.h>
|
| 443 |
+
#include <ATen/ops/mse_loss_cuda_dispatch.h>
|
| 444 |
+
#include <ATen/ops/mse_loss_backward_cuda_dispatch.h>
|
| 445 |
+
#include <ATen/ops/mul_cuda_dispatch.h>
|
| 446 |
+
#include <ATen/ops/multi_margin_loss_cuda_dispatch.h>
|
| 447 |
+
#include <ATen/ops/multi_margin_loss_backward_cuda_dispatch.h>
|
| 448 |
+
#include <ATen/ops/multilabel_margin_loss_backward_cuda_dispatch.h>
|
| 449 |
+
#include <ATen/ops/multilabel_margin_loss_forward_cuda_dispatch.h>
|
| 450 |
+
#include <ATen/ops/multinomial_cuda_dispatch.h>
|
| 451 |
+
#include <ATen/ops/mvlgamma_cuda_dispatch.h>
|
| 452 |
+
#include <ATen/ops/nan_to_num_cuda_dispatch.h>
|
| 453 |
+
#include <ATen/ops/nanmedian_cuda_dispatch.h>
|
| 454 |
+
#include <ATen/ops/nansum_cuda_dispatch.h>
|
| 455 |
+
#include <ATen/ops/native_batch_norm_cuda_dispatch.h>
|
| 456 |
+
#include <ATen/ops/native_batch_norm_backward_cuda_dispatch.h>
|
| 457 |
+
#include <ATen/ops/native_dropout_cuda_dispatch.h>
|
| 458 |
+
#include <ATen/ops/native_dropout_backward_cuda_dispatch.h>
|
| 459 |
+
#include <ATen/ops/native_group_norm_cuda_dispatch.h>
|
| 460 |
+
#include <ATen/ops/native_group_norm_backward_cuda_dispatch.h>
|
| 461 |
+
#include <ATen/ops/native_layer_norm_cuda_dispatch.h>
|
| 462 |
+
#include <ATen/ops/native_layer_norm_backward_cuda_dispatch.h>
|
| 463 |
+
#include <ATen/ops/ne_cuda_dispatch.h>
|
| 464 |
+
#include <ATen/ops/neg_cuda_dispatch.h>
|
| 465 |
+
#include <ATen/ops/nextafter_cuda_dispatch.h>
|
| 466 |
+
#include <ATen/ops/nll_loss2d_backward_cuda_dispatch.h>
|
| 467 |
+
#include <ATen/ops/nll_loss2d_forward_cuda_dispatch.h>
|
| 468 |
+
#include <ATen/ops/nll_loss_backward_cuda_dispatch.h>
|
| 469 |
+
#include <ATen/ops/nll_loss_forward_cuda_dispatch.h>
|
| 470 |
+
#include <ATen/ops/nonzero_cuda_dispatch.h>
|
| 471 |
+
#include <ATen/ops/norm_cuda_dispatch.h>
|
| 472 |
+
#include <ATen/ops/normal_cuda_dispatch.h>
|
| 473 |
+
#include <ATen/ops/ormqr_cuda_dispatch.h>
|
| 474 |
+
#include <ATen/ops/poisson_cuda_dispatch.h>
|
| 475 |
+
#include <ATen/ops/polar_cuda_dispatch.h>
|
| 476 |
+
#include <ATen/ops/polygamma_cuda_dispatch.h>
|
| 477 |
+
#include <ATen/ops/pow_cuda_dispatch.h>
|
| 478 |
+
#include <ATen/ops/prod_cuda_dispatch.h>
|
| 479 |
+
#include <ATen/ops/put_cuda_dispatch.h>
|
| 480 |
+
#include <ATen/ops/quantize_per_channel_cuda_dispatch.h>
|
| 481 |
+
#include <ATen/ops/quantize_per_tensor_cuda_dispatch.h>
|
| 482 |
+
#include <ATen/ops/quantize_per_tensor_dynamic_cuda_dispatch.h>
|
| 483 |
+
#include <ATen/ops/random_cuda_dispatch.h>
|
| 484 |
+
#include <ATen/ops/randperm_cuda_dispatch.h>
|
| 485 |
+
#include <ATen/ops/range_cuda_dispatch.h>
|
| 486 |
+
#include <ATen/ops/reciprocal_cuda_dispatch.h>
|
| 487 |
+
#include <ATen/ops/record_stream_cuda_dispatch.h>
|
| 488 |
+
#include <ATen/ops/reflection_pad1d_cuda_dispatch.h>
|
| 489 |
+
#include <ATen/ops/reflection_pad1d_backward_cuda_dispatch.h>
|
| 490 |
+
#include <ATen/ops/reflection_pad2d_cuda_dispatch.h>
|
| 491 |
+
#include <ATen/ops/reflection_pad2d_backward_cuda_dispatch.h>
|
| 492 |
+
#include <ATen/ops/reflection_pad3d_cuda_dispatch.h>
|
| 493 |
+
#include <ATen/ops/reflection_pad3d_backward_cuda_dispatch.h>
|
| 494 |
+
#include <ATen/ops/relu_cuda_dispatch.h>
|
| 495 |
+
#include <ATen/ops/remainder_cuda_dispatch.h>
|
| 496 |
+
#include <ATen/ops/renorm_cuda_dispatch.h>
|
| 497 |
+
#include <ATen/ops/repeat_interleave_cuda_dispatch.h>
|
| 498 |
+
#include <ATen/ops/replication_pad1d_cuda_dispatch.h>
|
| 499 |
+
#include <ATen/ops/replication_pad1d_backward_cuda_dispatch.h>
|
| 500 |
+
#include <ATen/ops/replication_pad2d_cuda_dispatch.h>
|
| 501 |
+
#include <ATen/ops/replication_pad2d_backward_cuda_dispatch.h>
|
| 502 |
+
#include <ATen/ops/replication_pad3d_cuda_dispatch.h>
|
| 503 |
+
#include <ATen/ops/replication_pad3d_backward_cuda_dispatch.h>
|
| 504 |
+
#include <ATen/ops/resize_cuda_dispatch.h>
|
| 505 |
+
#include <ATen/ops/roll_cuda_dispatch.h>
|
| 506 |
+
#include <ATen/ops/round_cuda_dispatch.h>
|
| 507 |
+
#include <ATen/ops/rrelu_with_noise_cuda_dispatch.h>
|
| 508 |
+
#include <ATen/ops/rshift_cuda_dispatch.h>
|
| 509 |
+
#include <ATen/ops/rsqrt_cuda_dispatch.h>
|
| 510 |
+
#include <ATen/ops/rsub_cuda_dispatch.h>
|
| 511 |
+
#include <ATen/ops/scatter_cuda_dispatch.h>
|
| 512 |
+
#include <ATen/ops/scatter_add_cuda_dispatch.h>
|
| 513 |
+
#include <ATen/ops/scatter_reduce_cuda_dispatch.h>
|
| 514 |
+
#include <ATen/ops/searchsorted_cuda_dispatch.h>
|
| 515 |
+
#include <ATen/ops/segment_reduce_cuda_dispatch.h>
|
| 516 |
+
#include <ATen/ops/set_cuda_dispatch.h>
|
| 517 |
+
#include <ATen/ops/sgn_cuda_dispatch.h>
|
| 518 |
+
#include <ATen/ops/sigmoid_cuda_dispatch.h>
|
| 519 |
+
#include <ATen/ops/sigmoid_backward_cuda_dispatch.h>
|
| 520 |
+
#include <ATen/ops/sign_cuda_dispatch.h>
|
| 521 |
+
#include <ATen/ops/signbit_cuda_dispatch.h>
|
| 522 |
+
#include <ATen/ops/silu_cuda_dispatch.h>
|
| 523 |
+
#include <ATen/ops/silu_backward_cuda_dispatch.h>
|
| 524 |
+
#include <ATen/ops/sin_cuda_dispatch.h>
|
| 525 |
+
#include <ATen/ops/sinc_cuda_dispatch.h>
|
| 526 |
+
#include <ATen/ops/sinh_cuda_dispatch.h>
|
| 527 |
+
#include <ATen/ops/slow_conv_dilated2d_cuda_dispatch.h>
|
| 528 |
+
#include <ATen/ops/slow_conv_dilated3d_cuda_dispatch.h>
|
| 529 |
+
#include <ATen/ops/slow_conv_transpose2d_cuda_dispatch.h>
|
| 530 |
+
#include <ATen/ops/slow_conv_transpose3d_cuda_dispatch.h>
|
| 531 |
+
#include <ATen/ops/smooth_l1_loss_cuda_dispatch.h>
|
| 532 |
+
#include <ATen/ops/smooth_l1_loss_backward_cuda_dispatch.h>
|
| 533 |
+
#include <ATen/ops/softplus_cuda_dispatch.h>
|
| 534 |
+
#include <ATen/ops/softplus_backward_cuda_dispatch.h>
|
| 535 |
+
#include <ATen/ops/softshrink_cuda_dispatch.h>
|
| 536 |
+
#include <ATen/ops/softshrink_backward_cuda_dispatch.h>
|
| 537 |
+
#include <ATen/ops/sort_cuda_dispatch.h>
|
| 538 |
+
#include <ATen/ops/special_airy_ai_cuda_dispatch.h>
|
| 539 |
+
#include <ATen/ops/special_bessel_j0_cuda_dispatch.h>
|
| 540 |
+
#include <ATen/ops/special_bessel_j1_cuda_dispatch.h>
|
| 541 |
+
#include <ATen/ops/special_bessel_y0_cuda_dispatch.h>
|
| 542 |
+
#include <ATen/ops/special_bessel_y1_cuda_dispatch.h>
|
| 543 |
+
#include <ATen/ops/special_chebyshev_polynomial_t_cuda_dispatch.h>
|
| 544 |
+
#include <ATen/ops/special_chebyshev_polynomial_u_cuda_dispatch.h>
|
| 545 |
+
#include <ATen/ops/special_chebyshev_polynomial_v_cuda_dispatch.h>
|
| 546 |
+
#include <ATen/ops/special_chebyshev_polynomial_w_cuda_dispatch.h>
|
| 547 |
+
#include <ATen/ops/special_entr_cuda_dispatch.h>
|
| 548 |
+
#include <ATen/ops/special_erfcx_cuda_dispatch.h>
|
| 549 |
+
#include <ATen/ops/special_hermite_polynomial_h_cuda_dispatch.h>
|
| 550 |
+
#include <ATen/ops/special_hermite_polynomial_he_cuda_dispatch.h>
|
| 551 |
+
#include <ATen/ops/special_i0e_cuda_dispatch.h>
|
| 552 |
+
#include <ATen/ops/special_i1_cuda_dispatch.h>
|
| 553 |
+
#include <ATen/ops/special_i1e_cuda_dispatch.h>
|
| 554 |
+
#include <ATen/ops/special_laguerre_polynomial_l_cuda_dispatch.h>
|
| 555 |
+
#include <ATen/ops/special_legendre_polynomial_p_cuda_dispatch.h>
|
| 556 |
+
#include <ATen/ops/special_log_ndtr_cuda_dispatch.h>
|
| 557 |
+
#include <ATen/ops/special_modified_bessel_i0_cuda_dispatch.h>
|
| 558 |
+
#include <ATen/ops/special_modified_bessel_i1_cuda_dispatch.h>
|
| 559 |
+
#include <ATen/ops/special_modified_bessel_k0_cuda_dispatch.h>
|
| 560 |
+
#include <ATen/ops/special_modified_bessel_k1_cuda_dispatch.h>
|
| 561 |
+
#include <ATen/ops/special_ndtri_cuda_dispatch.h>
|
| 562 |
+
#include <ATen/ops/special_scaled_modified_bessel_k0_cuda_dispatch.h>
|
| 563 |
+
#include <ATen/ops/special_scaled_modified_bessel_k1_cuda_dispatch.h>
|
| 564 |
+
#include <ATen/ops/special_shifted_chebyshev_polynomial_t_cuda_dispatch.h>
|
| 565 |
+
#include <ATen/ops/special_shifted_chebyshev_polynomial_u_cuda_dispatch.h>
|
| 566 |
+
#include <ATen/ops/special_shifted_chebyshev_polynomial_v_cuda_dispatch.h>
|
| 567 |
+
#include <ATen/ops/special_shifted_chebyshev_polynomial_w_cuda_dispatch.h>
|
| 568 |
+
#include <ATen/ops/special_spherical_bessel_j0_cuda_dispatch.h>
|
| 569 |
+
#include <ATen/ops/special_xlog1py_cuda_dispatch.h>
|
| 570 |
+
#include <ATen/ops/special_zeta_cuda_dispatch.h>
|
| 571 |
+
#include <ATen/ops/split_with_sizes_copy_cuda_dispatch.h>
|
| 572 |
+
#include <ATen/ops/sqrt_cuda_dispatch.h>
|
| 573 |
+
#include <ATen/ops/sspaddmm_cuda_dispatch.h>
|
| 574 |
+
#include <ATen/ops/std_cuda_dispatch.h>
|
| 575 |
+
#include <ATen/ops/std_mean_cuda_dispatch.h>
|
| 576 |
+
#include <ATen/ops/sub_cuda_dispatch.h>
|
| 577 |
+
#include <ATen/ops/sum_cuda_dispatch.h>
|
| 578 |
+
#include <ATen/ops/take_cuda_dispatch.h>
|
| 579 |
+
#include <ATen/ops/tan_cuda_dispatch.h>
|
| 580 |
+
#include <ATen/ops/tanh_cuda_dispatch.h>
|
| 581 |
+
#include <ATen/ops/tanh_backward_cuda_dispatch.h>
|
| 582 |
+
#include <ATen/ops/threshold_cuda_dispatch.h>
|
| 583 |
+
#include <ATen/ops/threshold_backward_cuda_dispatch.h>
|
| 584 |
+
#include <ATen/ops/topk_cuda_dispatch.h>
|
| 585 |
+
#include <ATen/ops/trace_cuda_dispatch.h>
|
| 586 |
+
#include <ATen/ops/triangular_solve_cuda_dispatch.h>
|
| 587 |
+
#include <ATen/ops/tril_cuda_dispatch.h>
|
| 588 |
+
#include <ATen/ops/tril_indices_cuda_dispatch.h>
|
| 589 |
+
#include <ATen/ops/triu_cuda_dispatch.h>
|
| 590 |
+
#include <ATen/ops/triu_indices_cuda_dispatch.h>
|
| 591 |
+
#include <ATen/ops/trunc_cuda_dispatch.h>
|
| 592 |
+
#include <ATen/ops/unfold_cuda_dispatch.h>
|
| 593 |
+
#include <ATen/ops/unfold_backward_cuda_dispatch.h>
|
| 594 |
+
#include <ATen/ops/uniform_cuda_dispatch.h>
|
| 595 |
+
#include <ATen/ops/unique_consecutive_cuda_dispatch.h>
|
| 596 |
+
#include <ATen/ops/unique_dim_cuda_dispatch.h>
|
| 597 |
+
#include <ATen/ops/unique_dim_consecutive_cuda_dispatch.h>
|
| 598 |
+
#include <ATen/ops/upsample_bicubic2d_cuda_dispatch.h>
|
| 599 |
+
#include <ATen/ops/upsample_bicubic2d_backward_cuda_dispatch.h>
|
| 600 |
+
#include <ATen/ops/upsample_bilinear2d_cuda_dispatch.h>
|
| 601 |
+
#include <ATen/ops/upsample_bilinear2d_backward_cuda_dispatch.h>
|
| 602 |
+
#include <ATen/ops/upsample_linear1d_cuda_dispatch.h>
|
| 603 |
+
#include <ATen/ops/upsample_linear1d_backward_cuda_dispatch.h>
|
| 604 |
+
#include <ATen/ops/upsample_nearest1d_cuda_dispatch.h>
|
| 605 |
+
#include <ATen/ops/upsample_nearest1d_backward_cuda_dispatch.h>
|
| 606 |
+
#include <ATen/ops/upsample_nearest2d_cuda_dispatch.h>
|
| 607 |
+
#include <ATen/ops/upsample_nearest2d_backward_cuda_dispatch.h>
|
| 608 |
+
#include <ATen/ops/upsample_nearest3d_cuda_dispatch.h>
|
| 609 |
+
#include <ATen/ops/upsample_nearest3d_backward_cuda_dispatch.h>
|
| 610 |
+
#include <ATen/ops/upsample_trilinear3d_cuda_dispatch.h>
|
| 611 |
+
#include <ATen/ops/upsample_trilinear3d_backward_cuda_dispatch.h>
|
| 612 |
+
#include <ATen/ops/var_cuda_dispatch.h>
|
| 613 |
+
#include <ATen/ops/var_mean_cuda_dispatch.h>
|
| 614 |
+
#include <ATen/ops/vdot_cuda_dispatch.h>
|
| 615 |
+
#include <ATen/ops/view_cuda_dispatch.h>
|
| 616 |
+
#include <ATen/ops/view_as_complex_cuda_dispatch.h>
|
| 617 |
+
#include <ATen/ops/view_as_real_cuda_dispatch.h>
|
| 618 |
+
#include <ATen/ops/where_cuda_dispatch.h>
|
| 619 |
+
#include <ATen/ops/xlogy_cuda_dispatch.h>
|
| 620 |
+
#include <ATen/ops/zero_cuda_dispatch.h>
|
| 621 |
+
|
| 622 |
+
|
| 623 |
+
|
.venv/Lib/site-packages/torch/include/ATen/CompositeExplicitAutogradNonFunctionalFunctions.h
ADDED
|
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#include <ATen/core/TensorBody.h>
|
| 2 |
+
|
| 3 |
+
// TODO Undo all logic introduced for Note [Avoiding Include Cycles In Static Dispatch]
|
| 4 |
+
// Code introduced to avoid cyclic dependency in static dispatch is no longer
|
| 5 |
+
// needed as static dispatch logic is moved from TensorBody.h, which caused cycles in the first place,
|
| 6 |
+
// to Operators.cpp for supporting multiple backends with multiple kernels.
|
| 7 |
+
//
|
| 8 |
+
// Note [Avoiding Include Cycles In Static Dispatch]
|
| 9 |
+
// In order to avoid #include cycles in the static dispatch build, we've carefully split out
|
| 10 |
+
// the static function definition files into {DispatchKey}Functions.h and {DispatchKey}Functions_inl.h.
|
| 11 |
+
//
|
| 12 |
+
// Without this split, the include cycle looks like TensorBody.h -> CPUFunctions.h -> TensorBody.h.
|
| 13 |
+
// - TensorBody.h #includes CPUFunctions.h in the static dispatch build, because the tensor methods
|
| 14 |
+
// all need to call into the fastpath C++ API defined in CPUFunctions.h. The methods are also all
|
| 15 |
+
// directly inlined into TensorBody.h.
|
| 16 |
+
// - CPUFunctions.h #includes TensorBody.h because it contains function declarations for the entire C++ API,
|
| 17 |
+
// which include functions that have defaultable std::optional<Tensor> arguments.
|
| 18 |
+
// That requires knowing the full Tensor class definition.
|
| 19 |
+
//
|
| 20 |
+
// We break the cycle by doing the following:
|
| 21 |
+
// - Split out CPUFunction.h into two files: CPUFunctions.h and CPUFunctions_inl.h
|
| 22 |
+
// - CPUFunction.h is a dummy file that just includes the Tensor class and includes CPUFunctions_inl.,
|
| 23 |
+
// - CPUFunctions_inl.h includes everything else
|
| 24 |
+
// - (only in the static dispatch build) TensorBody.h makes sure to finish defining the Tensor class,
|
| 25 |
+
// and then it includes CPUFunctions_inl.h.
|
| 26 |
+
// - All other files that want the cpu fastpath functions can include CPUFunctions.h directly.
|
| 27 |
+
// - This also means that static dispatch build, CPUFunctions.h only needs to
|
| 28 |
+
// #include TensorBody.h, and it will automatically bring in CPUFunctions_inl.h.
|
| 29 |
+
#include <ATen/CompositeExplicitAutogradNonFunctionalFunctions_inl.h>
|
.venv/Lib/site-packages/torch/include/ATen/CompositeExplicitAutogradNonFunctionalFunctions_inl.h
ADDED
|
@@ -0,0 +1,323 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
// @generated by torchgen/gen.py from DispatchKeyFunctions_inl.h
|
| 3 |
+
|
| 4 |
+
// NB: The implementing C++ file is RegisterDispatchKey.cpp
|
| 5 |
+
|
| 6 |
+
// The only #includes we need are for custom classes that have defaults in the C++ API
|
| 7 |
+
#include <c10/core/MemoryFormat.h>
|
| 8 |
+
#include <c10/core/Scalar.h>
|
| 9 |
+
#include <ATen/core/Reduction.h>
|
| 10 |
+
|
| 11 |
+
#if defined(AT_PER_OPERATOR_HEADERS) && defined(TORCH_ASSERT_ONLY_METHOD_OPERATORS)
|
| 12 |
+
#error This change adds a dependency on all pytorch operators, meaning the \
|
| 13 |
+
file will need to be re-compiled every time an operator is changed or added. \
|
| 14 |
+
Consider including a specific operator from \
|
| 15 |
+
<ATen/ops/{my_operator}_compositeexplicitautogradnonfunctional_dispatch.h>. \
|
| 16 |
+
See NOTE [TORCH_ASSERT_ONLY_METHOD_OPERATORS].
|
| 17 |
+
#endif
|
| 18 |
+
|
| 19 |
+
#include <ATen/ops/_addmm_activation_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 20 |
+
#include <ATen/ops/_conj_copy_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 21 |
+
#include <ATen/ops/_convert_indices_from_coo_to_csr_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 22 |
+
#include <ATen/ops/_convert_indices_from_csr_to_coo_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 23 |
+
#include <ATen/ops/_fw_primal_copy_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 24 |
+
#include <ATen/ops/_indices_copy_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 25 |
+
#include <ATen/ops/_linalg_det_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 26 |
+
#include <ATen/ops/_linalg_eigh_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 27 |
+
#include <ATen/ops/_linalg_slogdet_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 28 |
+
#include <ATen/ops/_linalg_solve_ex_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 29 |
+
#include <ATen/ops/_linalg_svd_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 30 |
+
#include <ATen/ops/_log_softmax_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 31 |
+
#include <ATen/ops/_log_softmax_backward_data_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 32 |
+
#include <ATen/ops/_make_dual_copy_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 33 |
+
#include <ATen/ops/_neg_view_copy_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 34 |
+
#include <ATen/ops/_nested_get_values_copy_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 35 |
+
#include <ATen/ops/_nested_view_from_buffer_copy_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 36 |
+
#include <ATen/ops/_nested_view_from_jagged_copy_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 37 |
+
#include <ATen/ops/_reshape_alias_copy_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 38 |
+
#include <ATen/ops/_softmax_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 39 |
+
#include <ATen/ops/_softmax_backward_data_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 40 |
+
#include <ATen/ops/_sparse_broadcast_to_copy_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 41 |
+
#include <ATen/ops/_test_autograd_multiple_dispatch_view_copy_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 42 |
+
#include <ATen/ops/_trilinear_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 43 |
+
#include <ATen/ops/_upsample_bicubic2d_aa_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 44 |
+
#include <ATen/ops/_upsample_bicubic2d_aa_backward_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 45 |
+
#include <ATen/ops/_upsample_bilinear2d_aa_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 46 |
+
#include <ATen/ops/_upsample_bilinear2d_aa_backward_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 47 |
+
#include <ATen/ops/_upsample_nearest_exact1d_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 48 |
+
#include <ATen/ops/_upsample_nearest_exact1d_backward_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 49 |
+
#include <ATen/ops/_upsample_nearest_exact2d_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 50 |
+
#include <ATen/ops/_upsample_nearest_exact2d_backward_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 51 |
+
#include <ATen/ops/_upsample_nearest_exact3d_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 52 |
+
#include <ATen/ops/_upsample_nearest_exact3d_backward_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 53 |
+
#include <ATen/ops/_values_copy_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 54 |
+
#include <ATen/ops/acos_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 55 |
+
#include <ATen/ops/acosh_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 56 |
+
#include <ATen/ops/adaptive_max_pool2d_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 57 |
+
#include <ATen/ops/adaptive_max_pool2d_backward_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 58 |
+
#include <ATen/ops/adaptive_max_pool3d_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 59 |
+
#include <ATen/ops/adaptive_max_pool3d_backward_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 60 |
+
#include <ATen/ops/add_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 61 |
+
#include <ATen/ops/addcdiv_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 62 |
+
#include <ATen/ops/addcmul_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 63 |
+
#include <ATen/ops/addmm_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 64 |
+
#include <ATen/ops/addmv_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 65 |
+
#include <ATen/ops/alias_copy_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 66 |
+
#include <ATen/ops/all_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 67 |
+
#include <ATen/ops/amax_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 68 |
+
#include <ATen/ops/amin_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 69 |
+
#include <ATen/ops/aminmax_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 70 |
+
#include <ATen/ops/any_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 71 |
+
#include <ATen/ops/argmax_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 72 |
+
#include <ATen/ops/argmin_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 73 |
+
#include <ATen/ops/as_strided_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 74 |
+
#include <ATen/ops/as_strided_copy_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 75 |
+
#include <ATen/ops/as_strided_scatter_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 76 |
+
#include <ATen/ops/asin_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 77 |
+
#include <ATen/ops/asinh_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 78 |
+
#include <ATen/ops/atan_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 79 |
+
#include <ATen/ops/atan2_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 80 |
+
#include <ATen/ops/atanh_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 81 |
+
#include <ATen/ops/avg_pool2d_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 82 |
+
#include <ATen/ops/avg_pool2d_backward_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 83 |
+
#include <ATen/ops/avg_pool3d_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 84 |
+
#include <ATen/ops/avg_pool3d_backward_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 85 |
+
#include <ATen/ops/baddbmm_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 86 |
+
#include <ATen/ops/bernoulli_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 87 |
+
#include <ATen/ops/bitwise_and_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 88 |
+
#include <ATen/ops/bitwise_left_shift_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 89 |
+
#include <ATen/ops/bitwise_not_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 90 |
+
#include <ATen/ops/bitwise_or_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 91 |
+
#include <ATen/ops/bitwise_right_shift_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 92 |
+
#include <ATen/ops/bitwise_xor_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 93 |
+
#include <ATen/ops/bmm_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 94 |
+
#include <ATen/ops/cat_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 95 |
+
#include <ATen/ops/ccol_indices_copy_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 96 |
+
#include <ATen/ops/ceil_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 97 |
+
#include <ATen/ops/clamp_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 98 |
+
#include <ATen/ops/clamp_max_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 99 |
+
#include <ATen/ops/clamp_min_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 100 |
+
#include <ATen/ops/col_indices_copy_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 101 |
+
#include <ATen/ops/copy_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 102 |
+
#include <ATen/ops/copysign_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 103 |
+
#include <ATen/ops/cos_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 104 |
+
#include <ATen/ops/cosh_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 105 |
+
#include <ATen/ops/crow_indices_copy_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 106 |
+
#include <ATen/ops/cumprod_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 107 |
+
#include <ATen/ops/cumsum_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 108 |
+
#include <ATen/ops/detach_copy_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 109 |
+
#include <ATen/ops/diag_embed_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 110 |
+
#include <ATen/ops/diagonal_copy_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 111 |
+
#include <ATen/ops/diagonal_scatter_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 112 |
+
#include <ATen/ops/digamma_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 113 |
+
#include <ATen/ops/div_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 114 |
+
#include <ATen/ops/elu_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 115 |
+
#include <ATen/ops/elu_backward_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 116 |
+
#include <ATen/ops/eq_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 117 |
+
#include <ATen/ops/erf_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 118 |
+
#include <ATen/ops/erfc_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 119 |
+
#include <ATen/ops/erfinv_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 120 |
+
#include <ATen/ops/exp_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 121 |
+
#include <ATen/ops/exp2_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 122 |
+
#include <ATen/ops/expand_copy_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 123 |
+
#include <ATen/ops/expm1_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 124 |
+
#include <ATen/ops/floor_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 125 |
+
#include <ATen/ops/fmax_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 126 |
+
#include <ATen/ops/fmin_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 127 |
+
#include <ATen/ops/fmod_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 128 |
+
#include <ATen/ops/frac_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 129 |
+
#include <ATen/ops/fractional_max_pool2d_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 130 |
+
#include <ATen/ops/fractional_max_pool2d_backward_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 131 |
+
#include <ATen/ops/fractional_max_pool3d_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 132 |
+
#include <ATen/ops/gather_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 133 |
+
#include <ATen/ops/gcd_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 134 |
+
#include <ATen/ops/ge_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 135 |
+
#include <ATen/ops/gelu_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 136 |
+
#include <ATen/ops/gelu_backward_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 137 |
+
#include <ATen/ops/glu_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 138 |
+
#include <ATen/ops/gt_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 139 |
+
#include <ATen/ops/hardshrink_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 140 |
+
#include <ATen/ops/hardshrink_backward_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 141 |
+
#include <ATen/ops/hardsigmoid_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 142 |
+
#include <ATen/ops/hardsigmoid_backward_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 143 |
+
#include <ATen/ops/heaviside_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 144 |
+
#include <ATen/ops/hypot_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 145 |
+
#include <ATen/ops/i0_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 146 |
+
#include <ATen/ops/igamma_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 147 |
+
#include <ATen/ops/igammac_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 148 |
+
#include <ATen/ops/index_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 149 |
+
#include <ATen/ops/index_add_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 150 |
+
#include <ATen/ops/index_copy_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 151 |
+
#include <ATen/ops/index_reduce_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 152 |
+
#include <ATen/ops/indices_copy_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 153 |
+
#include <ATen/ops/isin_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 154 |
+
#include <ATen/ops/isneginf_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 155 |
+
#include <ATen/ops/isposinf_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 156 |
+
#include <ATen/ops/lcm_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 157 |
+
#include <ATen/ops/le_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 158 |
+
#include <ATen/ops/leaky_relu_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 159 |
+
#include <ATen/ops/leaky_relu_backward_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 160 |
+
#include <ATen/ops/lerp_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 161 |
+
#include <ATen/ops/lgamma_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 162 |
+
#include <ATen/ops/lift_fresh_copy_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 163 |
+
#include <ATen/ops/linalg_cholesky_ex_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 164 |
+
#include <ATen/ops/linalg_cross_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 165 |
+
#include <ATen/ops/linalg_inv_ex_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 166 |
+
#include <ATen/ops/linalg_ldl_factor_ex_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 167 |
+
#include <ATen/ops/linalg_ldl_solve_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 168 |
+
#include <ATen/ops/linalg_lu_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 169 |
+
#include <ATen/ops/linalg_lu_factor_ex_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 170 |
+
#include <ATen/ops/linalg_lu_solve_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 171 |
+
#include <ATen/ops/linalg_pinv_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 172 |
+
#include <ATen/ops/linalg_qr_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 173 |
+
#include <ATen/ops/linalg_vector_norm_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 174 |
+
#include <ATen/ops/log_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 175 |
+
#include <ATen/ops/log10_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 176 |
+
#include <ATen/ops/log1p_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 177 |
+
#include <ATen/ops/log2_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 178 |
+
#include <ATen/ops/logaddexp_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 179 |
+
#include <ATen/ops/logaddexp2_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 180 |
+
#include <ATen/ops/logit_backward_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 181 |
+
#include <ATen/ops/logsumexp_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 182 |
+
#include <ATen/ops/lt_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 183 |
+
#include <ATen/ops/lu_unpack_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 184 |
+
#include <ATen/ops/max_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 185 |
+
#include <ATen/ops/max_pool2d_with_indices_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 186 |
+
#include <ATen/ops/max_pool2d_with_indices_backward_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 187 |
+
#include <ATen/ops/maximum_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 188 |
+
#include <ATen/ops/mean_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 189 |
+
#include <ATen/ops/min_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 190 |
+
#include <ATen/ops/minimum_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 191 |
+
#include <ATen/ops/mish_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 192 |
+
#include <ATen/ops/mm_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 193 |
+
#include <ATen/ops/mse_loss_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 194 |
+
#include <ATen/ops/mul_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 195 |
+
#include <ATen/ops/narrow_copy_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 196 |
+
#include <ATen/ops/ne_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 197 |
+
#include <ATen/ops/neg_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 198 |
+
#include <ATen/ops/new_empty_strided_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 199 |
+
#include <ATen/ops/nextafter_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 200 |
+
#include <ATen/ops/nll_loss_backward_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 201 |
+
#include <ATen/ops/nll_loss_forward_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 202 |
+
#include <ATen/ops/norm_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 203 |
+
#include <ATen/ops/permute_copy_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 204 |
+
#include <ATen/ops/pixel_shuffle_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 205 |
+
#include <ATen/ops/pixel_unshuffle_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 206 |
+
#include <ATen/ops/polygamma_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 207 |
+
#include <ATen/ops/pow_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 208 |
+
#include <ATen/ops/prod_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 209 |
+
#include <ATen/ops/reciprocal_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 210 |
+
#include <ATen/ops/reflection_pad1d_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 211 |
+
#include <ATen/ops/reflection_pad1d_backward_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 212 |
+
#include <ATen/ops/reflection_pad3d_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 213 |
+
#include <ATen/ops/reflection_pad3d_backward_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 214 |
+
#include <ATen/ops/remainder_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 215 |
+
#include <ATen/ops/renorm_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 216 |
+
#include <ATen/ops/replication_pad1d_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 217 |
+
#include <ATen/ops/replication_pad1d_backward_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 218 |
+
#include <ATen/ops/replication_pad2d_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 219 |
+
#include <ATen/ops/replication_pad3d_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 220 |
+
#include <ATen/ops/round_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 221 |
+
#include <ATen/ops/row_indices_copy_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 222 |
+
#include <ATen/ops/rsqrt_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 223 |
+
#include <ATen/ops/scatter_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 224 |
+
#include <ATen/ops/scatter_add_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 225 |
+
#include <ATen/ops/scatter_reduce_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 226 |
+
#include <ATen/ops/select_backward_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 227 |
+
#include <ATen/ops/select_copy_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 228 |
+
#include <ATen/ops/select_scatter_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 229 |
+
#include <ATen/ops/sgn_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 230 |
+
#include <ATen/ops/sigmoid_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 231 |
+
#include <ATen/ops/sigmoid_backward_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 232 |
+
#include <ATen/ops/sign_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 233 |
+
#include <ATen/ops/signbit_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 234 |
+
#include <ATen/ops/silu_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 235 |
+
#include <ATen/ops/silu_backward_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 236 |
+
#include <ATen/ops/sin_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 237 |
+
#include <ATen/ops/sinc_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 238 |
+
#include <ATen/ops/sinh_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 239 |
+
#include <ATen/ops/slice_copy_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 240 |
+
#include <ATen/ops/slice_scatter_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 241 |
+
#include <ATen/ops/slow_conv_transpose2d_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 242 |
+
#include <ATen/ops/smooth_l1_loss_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 243 |
+
#include <ATen/ops/softplus_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 244 |
+
#include <ATen/ops/softplus_backward_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 245 |
+
#include <ATen/ops/softshrink_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 246 |
+
#include <ATen/ops/softshrink_backward_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 247 |
+
#include <ATen/ops/sort_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 248 |
+
#include <ATen/ops/special_airy_ai_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 249 |
+
#include <ATen/ops/special_bessel_j0_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 250 |
+
#include <ATen/ops/special_bessel_j1_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 251 |
+
#include <ATen/ops/special_bessel_y0_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 252 |
+
#include <ATen/ops/special_bessel_y1_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 253 |
+
#include <ATen/ops/special_chebyshev_polynomial_t_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 254 |
+
#include <ATen/ops/special_chebyshev_polynomial_u_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 255 |
+
#include <ATen/ops/special_chebyshev_polynomial_v_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 256 |
+
#include <ATen/ops/special_chebyshev_polynomial_w_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 257 |
+
#include <ATen/ops/special_entr_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 258 |
+
#include <ATen/ops/special_erfcx_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 259 |
+
#include <ATen/ops/special_hermite_polynomial_h_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 260 |
+
#include <ATen/ops/special_hermite_polynomial_he_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 261 |
+
#include <ATen/ops/special_i0e_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 262 |
+
#include <ATen/ops/special_i1_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 263 |
+
#include <ATen/ops/special_i1e_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 264 |
+
#include <ATen/ops/special_laguerre_polynomial_l_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 265 |
+
#include <ATen/ops/special_legendre_polynomial_p_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 266 |
+
#include <ATen/ops/special_log_ndtr_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 267 |
+
#include <ATen/ops/special_modified_bessel_i0_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 268 |
+
#include <ATen/ops/special_modified_bessel_i1_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 269 |
+
#include <ATen/ops/special_modified_bessel_k0_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 270 |
+
#include <ATen/ops/special_modified_bessel_k1_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 271 |
+
#include <ATen/ops/special_ndtri_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 272 |
+
#include <ATen/ops/special_scaled_modified_bessel_k0_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 273 |
+
#include <ATen/ops/special_scaled_modified_bessel_k1_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 274 |
+
#include <ATen/ops/special_shifted_chebyshev_polynomial_t_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 275 |
+
#include <ATen/ops/special_shifted_chebyshev_polynomial_u_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 276 |
+
#include <ATen/ops/special_shifted_chebyshev_polynomial_v_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 277 |
+
#include <ATen/ops/special_shifted_chebyshev_polynomial_w_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 278 |
+
#include <ATen/ops/special_spherical_bessel_j0_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 279 |
+
#include <ATen/ops/special_xlog1py_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 280 |
+
#include <ATen/ops/special_zeta_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 281 |
+
#include <ATen/ops/split_copy_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 282 |
+
#include <ATen/ops/split_with_sizes_copy_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 283 |
+
#include <ATen/ops/sqrt_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 284 |
+
#include <ATen/ops/squeeze_copy_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 285 |
+
#include <ATen/ops/sub_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 286 |
+
#include <ATen/ops/sum_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 287 |
+
#include <ATen/ops/t_copy_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 288 |
+
#include <ATen/ops/tan_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 289 |
+
#include <ATen/ops/tanh_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 290 |
+
#include <ATen/ops/tanh_backward_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 291 |
+
#include <ATen/ops/threshold_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 292 |
+
#include <ATen/ops/threshold_backward_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 293 |
+
#include <ATen/ops/topk_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 294 |
+
#include <ATen/ops/transpose_copy_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 295 |
+
#include <ATen/ops/triangular_solve_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 296 |
+
#include <ATen/ops/tril_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 297 |
+
#include <ATen/ops/triu_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 298 |
+
#include <ATen/ops/trunc_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 299 |
+
#include <ATen/ops/unbind_copy_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 300 |
+
#include <ATen/ops/unfold_copy_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 301 |
+
#include <ATen/ops/unsqueeze_copy_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 302 |
+
#include <ATen/ops/upsample_bicubic2d_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 303 |
+
#include <ATen/ops/upsample_bicubic2d_backward_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 304 |
+
#include <ATen/ops/upsample_bilinear2d_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 305 |
+
#include <ATen/ops/upsample_bilinear2d_backward_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 306 |
+
#include <ATen/ops/upsample_linear1d_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 307 |
+
#include <ATen/ops/upsample_linear1d_backward_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 308 |
+
#include <ATen/ops/upsample_nearest1d_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 309 |
+
#include <ATen/ops/upsample_nearest1d_backward_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 310 |
+
#include <ATen/ops/upsample_nearest2d_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 311 |
+
#include <ATen/ops/upsample_nearest2d_backward_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 312 |
+
#include <ATen/ops/upsample_nearest3d_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 313 |
+
#include <ATen/ops/upsample_nearest3d_backward_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 314 |
+
#include <ATen/ops/upsample_trilinear3d_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 315 |
+
#include <ATen/ops/upsample_trilinear3d_backward_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 316 |
+
#include <ATen/ops/values_copy_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 317 |
+
#include <ATen/ops/view_as_complex_copy_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 318 |
+
#include <ATen/ops/view_as_real_copy_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 319 |
+
#include <ATen/ops/view_copy_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 320 |
+
#include <ATen/ops/xlogy_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 321 |
+
|
| 322 |
+
|
| 323 |
+
|
.venv/Lib/site-packages/torch/include/ATen/CompositeImplicitAutogradFunctions.h
ADDED
|
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#include <ATen/core/TensorBody.h>
|
| 2 |
+
|
| 3 |
+
// TODO Undo all logic introduced for Note [Avoiding Include Cycles In Static Dispatch]
|
| 4 |
+
// Code introduced to avoid cyclic dependency in static dispatch is no longer
|
| 5 |
+
// needed as static dispatch logic is moved from TensorBody.h, which caused cycles in the first place,
|
| 6 |
+
// to Operators.cpp for supporting multiple backends with multiple kernels.
|
| 7 |
+
//
|
| 8 |
+
// Note [Avoiding Include Cycles In Static Dispatch]
|
| 9 |
+
// In order to avoid #include cycles in the static dispatch build, we've carefully split out
|
| 10 |
+
// the static function definition files into {DispatchKey}Functions.h and {DispatchKey}Functions_inl.h.
|
| 11 |
+
//
|
| 12 |
+
// Without this split, the include cycle looks like TensorBody.h -> CPUFunctions.h -> TensorBody.h.
|
| 13 |
+
// - TensorBody.h #includes CPUFunctions.h in the static dispatch build, because the tensor methods
|
| 14 |
+
// all need to call into the fastpath C++ API defined in CPUFunctions.h. The methods are also all
|
| 15 |
+
// directly inlined into TensorBody.h.
|
| 16 |
+
// - CPUFunctions.h #includes TensorBody.h because it contains function declarations for the entire C++ API,
|
| 17 |
+
// which include functions that have defaultable std::optional<Tensor> arguments.
|
| 18 |
+
// That requires knowing the full Tensor class definition.
|
| 19 |
+
//
|
| 20 |
+
// We break the cycle by doing the following:
|
| 21 |
+
// - Split out CPUFunction.h into two files: CPUFunctions.h and CPUFunctions_inl.h
|
| 22 |
+
// - CPUFunction.h is a dummy file that just includes the Tensor class and includes CPUFunctions_inl.,
|
| 23 |
+
// - CPUFunctions_inl.h includes everything else
|
| 24 |
+
// - (only in the static dispatch build) TensorBody.h makes sure to finish defining the Tensor class,
|
| 25 |
+
// and then it includes CPUFunctions_inl.h.
|
| 26 |
+
// - All other files that want the cpu fastpath functions can include CPUFunctions.h directly.
|
| 27 |
+
// - This also means that static dispatch build, CPUFunctions.h only needs to
|
| 28 |
+
// #include TensorBody.h, and it will automatically bring in CPUFunctions_inl.h.
|
| 29 |
+
#include <ATen/CompositeImplicitAutogradFunctions_inl.h>
|
.venv/Lib/site-packages/torch/include/ATen/CompositeImplicitAutogradFunctions_inl.h
ADDED
|
@@ -0,0 +1,502 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
// @generated by torchgen/gen.py from DispatchKeyFunctions_inl.h
|
| 3 |
+
|
| 4 |
+
// NB: The implementing C++ file is RegisterDispatchKey.cpp
|
| 5 |
+
|
| 6 |
+
// The only #includes we need are for custom classes that have defaults in the C++ API
|
| 7 |
+
#include <c10/core/MemoryFormat.h>
|
| 8 |
+
#include <c10/core/Scalar.h>
|
| 9 |
+
#include <ATen/core/Reduction.h>
|
| 10 |
+
|
| 11 |
+
#if defined(AT_PER_OPERATOR_HEADERS) && defined(TORCH_ASSERT_ONLY_METHOD_OPERATORS)
|
| 12 |
+
#error This change adds a dependency on all pytorch operators, meaning the \
|
| 13 |
+
file will need to be re-compiled every time an operator is changed or added. \
|
| 14 |
+
Consider including a specific operator from \
|
| 15 |
+
<ATen/ops/{my_operator}_compositeimplicitautograd_dispatch.h>. \
|
| 16 |
+
See NOTE [TORCH_ASSERT_ONLY_METHOD_OPERATORS].
|
| 17 |
+
#endif
|
| 18 |
+
|
| 19 |
+
#include <ATen/ops/_add_batch_dim_compositeimplicitautograd_dispatch.h>
|
| 20 |
+
#include <ATen/ops/_assert_tensor_metadata_compositeimplicitautograd_dispatch.h>
|
| 21 |
+
#include <ATen/ops/_autocast_to_full_precision_compositeimplicitautograd_dispatch.h>
|
| 22 |
+
#include <ATen/ops/_autocast_to_reduced_precision_compositeimplicitautograd_dispatch.h>
|
| 23 |
+
#include <ATen/ops/_backward_compositeimplicitautograd_dispatch.h>
|
| 24 |
+
#include <ATen/ops/_batch_norm_impl_index_compositeimplicitautograd_dispatch.h>
|
| 25 |
+
#include <ATen/ops/_batch_norm_impl_index_backward_compositeimplicitautograd_dispatch.h>
|
| 26 |
+
#include <ATen/ops/_cast_Byte_compositeimplicitautograd_dispatch.h>
|
| 27 |
+
#include <ATen/ops/_cast_Char_compositeimplicitautograd_dispatch.h>
|
| 28 |
+
#include <ATen/ops/_cast_Double_compositeimplicitautograd_dispatch.h>
|
| 29 |
+
#include <ATen/ops/_cast_Float_compositeimplicitautograd_dispatch.h>
|
| 30 |
+
#include <ATen/ops/_cast_Half_compositeimplicitautograd_dispatch.h>
|
| 31 |
+
#include <ATen/ops/_cast_Int_compositeimplicitautograd_dispatch.h>
|
| 32 |
+
#include <ATen/ops/_cast_Long_compositeimplicitautograd_dispatch.h>
|
| 33 |
+
#include <ATen/ops/_cast_Short_compositeimplicitautograd_dispatch.h>
|
| 34 |
+
#include <ATen/ops/_choose_qparams_per_tensor_compositeimplicitautograd_dispatch.h>
|
| 35 |
+
#include <ATen/ops/_convolution_compositeimplicitautograd_dispatch.h>
|
| 36 |
+
#include <ATen/ops/_convolution_double_backward_compositeimplicitautograd_dispatch.h>
|
| 37 |
+
#include <ATen/ops/_convolution_mode_compositeimplicitautograd_dispatch.h>
|
| 38 |
+
#include <ATen/ops/_cufft_clear_plan_cache_compositeimplicitautograd_dispatch.h>
|
| 39 |
+
#include <ATen/ops/_cufft_get_plan_cache_max_size_compositeimplicitautograd_dispatch.h>
|
| 40 |
+
#include <ATen/ops/_cufft_get_plan_cache_size_compositeimplicitautograd_dispatch.h>
|
| 41 |
+
#include <ATen/ops/_cufft_set_plan_cache_max_size_compositeimplicitautograd_dispatch.h>
|
| 42 |
+
#include <ATen/ops/_debug_has_internal_overlap_compositeimplicitautograd_dispatch.h>
|
| 43 |
+
#include <ATen/ops/_dim_arange_compositeimplicitautograd_dispatch.h>
|
| 44 |
+
#include <ATen/ops/_embedding_bag_sparse_backward_compositeimplicitautograd_dispatch.h>
|
| 45 |
+
#include <ATen/ops/_gather_sparse_backward_compositeimplicitautograd_dispatch.h>
|
| 46 |
+
#include <ATen/ops/_grid_sampler_2d_cpu_fallback_backward_compositeimplicitautograd_dispatch.h>
|
| 47 |
+
#include <ATen/ops/_has_compatible_shallow_copy_type_compositeimplicitautograd_dispatch.h>
|
| 48 |
+
#include <ATen/ops/_is_zerotensor_compositeimplicitautograd_dispatch.h>
|
| 49 |
+
#include <ATen/ops/_lu_with_info_compositeimplicitautograd_dispatch.h>
|
| 50 |
+
#include <ATen/ops/_nnpack_available_compositeimplicitautograd_dispatch.h>
|
| 51 |
+
#include <ATen/ops/_pack_padded_sequence_backward_compositeimplicitautograd_dispatch.h>
|
| 52 |
+
#include <ATen/ops/_pad_circular_compositeimplicitautograd_dispatch.h>
|
| 53 |
+
#include <ATen/ops/_pad_enum_compositeimplicitautograd_dispatch.h>
|
| 54 |
+
#include <ATen/ops/_pad_packed_sequence_compositeimplicitautograd_dispatch.h>
|
| 55 |
+
#include <ATen/ops/_propagate_xla_data_compositeimplicitautograd_dispatch.h>
|
| 56 |
+
#include <ATen/ops/_remove_batch_dim_compositeimplicitautograd_dispatch.h>
|
| 57 |
+
#include <ATen/ops/_reshape_from_tensor_compositeimplicitautograd_dispatch.h>
|
| 58 |
+
#include <ATen/ops/_rowwise_prune_compositeimplicitautograd_dispatch.h>
|
| 59 |
+
#include <ATen/ops/_saturate_weight_to_fp16_compositeimplicitautograd_dispatch.h>
|
| 60 |
+
#include <ATen/ops/_scaled_dot_product_attention_math_compositeimplicitautograd_dispatch.h>
|
| 61 |
+
#include <ATen/ops/_shape_as_tensor_compositeimplicitautograd_dispatch.h>
|
| 62 |
+
#include <ATen/ops/_sobol_engine_draw_compositeimplicitautograd_dispatch.h>
|
| 63 |
+
#include <ATen/ops/_sobol_engine_ff_compositeimplicitautograd_dispatch.h>
|
| 64 |
+
#include <ATen/ops/_sobol_engine_initialize_state_compositeimplicitautograd_dispatch.h>
|
| 65 |
+
#include <ATen/ops/_sobol_engine_scramble_compositeimplicitautograd_dispatch.h>
|
| 66 |
+
#include <ATen/ops/_sparse_bsc_tensor_unsafe_compositeimplicitautograd_dispatch.h>
|
| 67 |
+
#include <ATen/ops/_sparse_bsr_tensor_unsafe_compositeimplicitautograd_dispatch.h>
|
| 68 |
+
#include <ATen/ops/_sparse_compressed_tensor_unsafe_compositeimplicitautograd_dispatch.h>
|
| 69 |
+
#include <ATen/ops/_sparse_coo_tensor_unsafe_compositeimplicitautograd_dispatch.h>
|
| 70 |
+
#include <ATen/ops/_sparse_csc_tensor_unsafe_compositeimplicitautograd_dispatch.h>
|
| 71 |
+
#include <ATen/ops/_sparse_csr_tensor_unsafe_compositeimplicitautograd_dispatch.h>
|
| 72 |
+
#include <ATen/ops/_sparse_log_softmax_compositeimplicitautograd_dispatch.h>
|
| 73 |
+
#include <ATen/ops/_sparse_mm_compositeimplicitautograd_dispatch.h>
|
| 74 |
+
#include <ATen/ops/_sparse_softmax_compositeimplicitautograd_dispatch.h>
|
| 75 |
+
#include <ATen/ops/_sparse_sum_compositeimplicitautograd_dispatch.h>
|
| 76 |
+
#include <ATen/ops/_test_ambiguous_defaults_compositeimplicitautograd_dispatch.h>
|
| 77 |
+
#include <ATen/ops/_test_autograd_multiple_dispatch_compositeimplicitautograd_dispatch.h>
|
| 78 |
+
#include <ATen/ops/_test_check_tensor_compositeimplicitautograd_dispatch.h>
|
| 79 |
+
#include <ATen/ops/_test_serialization_subcmul_compositeimplicitautograd_dispatch.h>
|
| 80 |
+
#include <ATen/ops/_test_string_default_compositeimplicitautograd_dispatch.h>
|
| 81 |
+
#include <ATen/ops/_thnn_differentiable_gru_cell_backward_compositeimplicitautograd_dispatch.h>
|
| 82 |
+
#include <ATen/ops/_thnn_differentiable_lstm_cell_backward_compositeimplicitautograd_dispatch.h>
|
| 83 |
+
#include <ATen/ops/_thnn_fused_lstm_cell_backward_compositeimplicitautograd_dispatch.h>
|
| 84 |
+
#include <ATen/ops/_to_cpu_compositeimplicitautograd_dispatch.h>
|
| 85 |
+
#include <ATen/ops/_unpack_dual_compositeimplicitautograd_dispatch.h>
|
| 86 |
+
#include <ATen/ops/_upsample_bicubic2d_aa_compositeimplicitautograd_dispatch.h>
|
| 87 |
+
#include <ATen/ops/_upsample_bilinear2d_aa_compositeimplicitautograd_dispatch.h>
|
| 88 |
+
#include <ATen/ops/_upsample_nearest_exact1d_compositeimplicitautograd_dispatch.h>
|
| 89 |
+
#include <ATen/ops/_upsample_nearest_exact2d_compositeimplicitautograd_dispatch.h>
|
| 90 |
+
#include <ATen/ops/_upsample_nearest_exact3d_compositeimplicitautograd_dispatch.h>
|
| 91 |
+
#include <ATen/ops/_use_cudnn_rnn_flatten_weight_compositeimplicitautograd_dispatch.h>
|
| 92 |
+
#include <ATen/ops/_validate_sparse_bsc_tensor_args_compositeimplicitautograd_dispatch.h>
|
| 93 |
+
#include <ATen/ops/_validate_sparse_bsr_tensor_args_compositeimplicitautograd_dispatch.h>
|
| 94 |
+
#include <ATen/ops/_validate_sparse_compressed_tensor_args_compositeimplicitautograd_dispatch.h>
|
| 95 |
+
#include <ATen/ops/_validate_sparse_coo_tensor_args_compositeimplicitautograd_dispatch.h>
|
| 96 |
+
#include <ATen/ops/_validate_sparse_csc_tensor_args_compositeimplicitautograd_dispatch.h>
|
| 97 |
+
#include <ATen/ops/_validate_sparse_csr_tensor_args_compositeimplicitautograd_dispatch.h>
|
| 98 |
+
#include <ATen/ops/_version_compositeimplicitautograd_dispatch.h>
|
| 99 |
+
#include <ATen/ops/_weight_norm_compositeimplicitautograd_dispatch.h>
|
| 100 |
+
#include <ATen/ops/_weight_norm_differentiable_backward_compositeimplicitautograd_dispatch.h>
|
| 101 |
+
#include <ATen/ops/_wrapped_linear_prepack_compositeimplicitautograd_dispatch.h>
|
| 102 |
+
#include <ATen/ops/_wrapped_quantized_linear_prepacked_compositeimplicitautograd_dispatch.h>
|
| 103 |
+
#include <ATen/ops/absolute_compositeimplicitautograd_dispatch.h>
|
| 104 |
+
#include <ATen/ops/adaptive_avg_pool1d_compositeimplicitautograd_dispatch.h>
|
| 105 |
+
#include <ATen/ops/adaptive_avg_pool2d_compositeimplicitautograd_dispatch.h>
|
| 106 |
+
#include <ATen/ops/adaptive_avg_pool3d_compositeimplicitautograd_dispatch.h>
|
| 107 |
+
#include <ATen/ops/adaptive_max_pool1d_compositeimplicitautograd_dispatch.h>
|
| 108 |
+
#include <ATen/ops/adjoint_compositeimplicitautograd_dispatch.h>
|
| 109 |
+
#include <ATen/ops/affine_grid_generator_backward_compositeimplicitautograd_dispatch.h>
|
| 110 |
+
#include <ATen/ops/align_as_compositeimplicitautograd_dispatch.h>
|
| 111 |
+
#include <ATen/ops/align_tensors_compositeimplicitautograd_dispatch.h>
|
| 112 |
+
#include <ATen/ops/align_to_compositeimplicitautograd_dispatch.h>
|
| 113 |
+
#include <ATen/ops/all_compositeimplicitautograd_dispatch.h>
|
| 114 |
+
#include <ATen/ops/alpha_dropout_compositeimplicitautograd_dispatch.h>
|
| 115 |
+
#include <ATen/ops/and_compositeimplicitautograd_dispatch.h>
|
| 116 |
+
#include <ATen/ops/any_compositeimplicitautograd_dispatch.h>
|
| 117 |
+
#include <ATen/ops/arccos_compositeimplicitautograd_dispatch.h>
|
| 118 |
+
#include <ATen/ops/arccosh_compositeimplicitautograd_dispatch.h>
|
| 119 |
+
#include <ATen/ops/arcsin_compositeimplicitautograd_dispatch.h>
|
| 120 |
+
#include <ATen/ops/arcsinh_compositeimplicitautograd_dispatch.h>
|
| 121 |
+
#include <ATen/ops/arctan_compositeimplicitautograd_dispatch.h>
|
| 122 |
+
#include <ATen/ops/arctan2_compositeimplicitautograd_dispatch.h>
|
| 123 |
+
#include <ATen/ops/arctanh_compositeimplicitautograd_dispatch.h>
|
| 124 |
+
#include <ATen/ops/argsort_compositeimplicitautograd_dispatch.h>
|
| 125 |
+
#include <ATen/ops/argwhere_compositeimplicitautograd_dispatch.h>
|
| 126 |
+
#include <ATen/ops/atleast_1d_compositeimplicitautograd_dispatch.h>
|
| 127 |
+
#include <ATen/ops/atleast_2d_compositeimplicitautograd_dispatch.h>
|
| 128 |
+
#include <ATen/ops/atleast_3d_compositeimplicitautograd_dispatch.h>
|
| 129 |
+
#include <ATen/ops/avg_pool1d_compositeimplicitautograd_dispatch.h>
|
| 130 |
+
#include <ATen/ops/batch_norm_compositeimplicitautograd_dispatch.h>
|
| 131 |
+
#include <ATen/ops/bilinear_compositeimplicitautograd_dispatch.h>
|
| 132 |
+
#include <ATen/ops/broadcast_tensors_compositeimplicitautograd_dispatch.h>
|
| 133 |
+
#include <ATen/ops/broadcast_to_compositeimplicitautograd_dispatch.h>
|
| 134 |
+
#include <ATen/ops/can_cast_compositeimplicitautograd_dispatch.h>
|
| 135 |
+
#include <ATen/ops/cartesian_prod_compositeimplicitautograd_dispatch.h>
|
| 136 |
+
#include <ATen/ops/cat_compositeimplicitautograd_dispatch.h>
|
| 137 |
+
#include <ATen/ops/cdist_compositeimplicitautograd_dispatch.h>
|
| 138 |
+
#include <ATen/ops/chain_matmul_compositeimplicitautograd_dispatch.h>
|
| 139 |
+
#include <ATen/ops/chalf_compositeimplicitautograd_dispatch.h>
|
| 140 |
+
#include <ATen/ops/choose_qparams_optimized_compositeimplicitautograd_dispatch.h>
|
| 141 |
+
#include <ATen/ops/chunk_compositeimplicitautograd_dispatch.h>
|
| 142 |
+
#include <ATen/ops/clip_compositeimplicitautograd_dispatch.h>
|
| 143 |
+
#include <ATen/ops/coalesce_compositeimplicitautograd_dispatch.h>
|
| 144 |
+
#include <ATen/ops/column_stack_compositeimplicitautograd_dispatch.h>
|
| 145 |
+
#include <ATen/ops/combinations_compositeimplicitautograd_dispatch.h>
|
| 146 |
+
#include <ATen/ops/concat_compositeimplicitautograd_dispatch.h>
|
| 147 |
+
#include <ATen/ops/concatenate_compositeimplicitautograd_dispatch.h>
|
| 148 |
+
#include <ATen/ops/conj_compositeimplicitautograd_dispatch.h>
|
| 149 |
+
#include <ATen/ops/conj_physical_compositeimplicitautograd_dispatch.h>
|
| 150 |
+
#include <ATen/ops/contiguous_compositeimplicitautograd_dispatch.h>
|
| 151 |
+
#include <ATen/ops/conv1d_compositeimplicitautograd_dispatch.h>
|
| 152 |
+
#include <ATen/ops/conv2d_compositeimplicitautograd_dispatch.h>
|
| 153 |
+
#include <ATen/ops/conv3d_compositeimplicitautograd_dispatch.h>
|
| 154 |
+
#include <ATen/ops/conv_tbc_backward_compositeimplicitautograd_dispatch.h>
|
| 155 |
+
#include <ATen/ops/conv_transpose1d_compositeimplicitautograd_dispatch.h>
|
| 156 |
+
#include <ATen/ops/conv_transpose2d_compositeimplicitautograd_dispatch.h>
|
| 157 |
+
#include <ATen/ops/conv_transpose3d_compositeimplicitautograd_dispatch.h>
|
| 158 |
+
#include <ATen/ops/corrcoef_compositeimplicitautograd_dispatch.h>
|
| 159 |
+
#include <ATen/ops/cosine_embedding_loss_compositeimplicitautograd_dispatch.h>
|
| 160 |
+
#include <ATen/ops/cosine_similarity_compositeimplicitautograd_dispatch.h>
|
| 161 |
+
#include <ATen/ops/cov_compositeimplicitautograd_dispatch.h>
|
| 162 |
+
#include <ATen/ops/cross_compositeimplicitautograd_dispatch.h>
|
| 163 |
+
#include <ATen/ops/cross_entropy_loss_compositeimplicitautograd_dispatch.h>
|
| 164 |
+
#include <ATen/ops/ctc_loss_compositeimplicitautograd_dispatch.h>
|
| 165 |
+
#include <ATen/ops/cudnn_is_acceptable_compositeimplicitautograd_dispatch.h>
|
| 166 |
+
#include <ATen/ops/cummax_compositeimplicitautograd_dispatch.h>
|
| 167 |
+
#include <ATen/ops/cummaxmin_backward_compositeimplicitautograd_dispatch.h>
|
| 168 |
+
#include <ATen/ops/cummin_compositeimplicitautograd_dispatch.h>
|
| 169 |
+
#include <ATen/ops/cumprod_compositeimplicitautograd_dispatch.h>
|
| 170 |
+
#include <ATen/ops/cumprod_backward_compositeimplicitautograd_dispatch.h>
|
| 171 |
+
#include <ATen/ops/cumsum_compositeimplicitautograd_dispatch.h>
|
| 172 |
+
#include <ATen/ops/cumulative_trapezoid_compositeimplicitautograd_dispatch.h>
|
| 173 |
+
#include <ATen/ops/data_compositeimplicitautograd_dispatch.h>
|
| 174 |
+
#include <ATen/ops/det_compositeimplicitautograd_dispatch.h>
|
| 175 |
+
#include <ATen/ops/diag_compositeimplicitautograd_dispatch.h>
|
| 176 |
+
#include <ATen/ops/diagflat_compositeimplicitautograd_dispatch.h>
|
| 177 |
+
#include <ATen/ops/diagonal_compositeimplicitautograd_dispatch.h>
|
| 178 |
+
#include <ATen/ops/diff_compositeimplicitautograd_dispatch.h>
|
| 179 |
+
#include <ATen/ops/divide_compositeimplicitautograd_dispatch.h>
|
| 180 |
+
#include <ATen/ops/dropout_compositeimplicitautograd_dispatch.h>
|
| 181 |
+
#include <ATen/ops/dsplit_compositeimplicitautograd_dispatch.h>
|
| 182 |
+
#include <ATen/ops/dstack_compositeimplicitautograd_dispatch.h>
|
| 183 |
+
#include <ATen/ops/einsum_compositeimplicitautograd_dispatch.h>
|
| 184 |
+
#include <ATen/ops/embedding_backward_compositeimplicitautograd_dispatch.h>
|
| 185 |
+
#include <ATen/ops/embedding_bag_compositeimplicitautograd_dispatch.h>
|
| 186 |
+
#include <ATen/ops/embedding_sparse_backward_compositeimplicitautograd_dispatch.h>
|
| 187 |
+
#include <ATen/ops/empty_compositeimplicitautograd_dispatch.h>
|
| 188 |
+
#include <ATen/ops/expand_as_compositeimplicitautograd_dispatch.h>
|
| 189 |
+
#include <ATen/ops/fake_quantize_per_channel_affine_compositeimplicitautograd_dispatch.h>
|
| 190 |
+
#include <ATen/ops/fake_quantize_per_channel_affine_cachemask_backward_compositeimplicitautograd_dispatch.h>
|
| 191 |
+
#include <ATen/ops/fake_quantize_per_tensor_affine_compositeimplicitautograd_dispatch.h>
|
| 192 |
+
#include <ATen/ops/fake_quantize_per_tensor_affine_cachemask_backward_compositeimplicitautograd_dispatch.h>
|
| 193 |
+
#include <ATen/ops/fbgemm_linear_fp16_weight_compositeimplicitautograd_dispatch.h>
|
| 194 |
+
#include <ATen/ops/fbgemm_linear_fp16_weight_fp32_activation_compositeimplicitautograd_dispatch.h>
|
| 195 |
+
#include <ATen/ops/fbgemm_linear_int8_weight_compositeimplicitautograd_dispatch.h>
|
| 196 |
+
#include <ATen/ops/fbgemm_linear_int8_weight_fp32_activation_compositeimplicitautograd_dispatch.h>
|
| 197 |
+
#include <ATen/ops/fbgemm_linear_quantize_weight_compositeimplicitautograd_dispatch.h>
|
| 198 |
+
#include <ATen/ops/fbgemm_pack_gemm_matrix_fp16_compositeimplicitautograd_dispatch.h>
|
| 199 |
+
#include <ATen/ops/fbgemm_pack_quantized_matrix_compositeimplicitautograd_dispatch.h>
|
| 200 |
+
#include <ATen/ops/feature_alpha_dropout_compositeimplicitautograd_dispatch.h>
|
| 201 |
+
#include <ATen/ops/feature_dropout_compositeimplicitautograd_dispatch.h>
|
| 202 |
+
#include <ATen/ops/fft_fft_compositeimplicitautograd_dispatch.h>
|
| 203 |
+
#include <ATen/ops/fft_fft2_compositeimplicitautograd_dispatch.h>
|
| 204 |
+
#include <ATen/ops/fft_fftn_compositeimplicitautograd_dispatch.h>
|
| 205 |
+
#include <ATen/ops/fft_fftshift_compositeimplicitautograd_dispatch.h>
|
| 206 |
+
#include <ATen/ops/fft_hfft_compositeimplicitautograd_dispatch.h>
|
| 207 |
+
#include <ATen/ops/fft_hfft2_compositeimplicitautograd_dispatch.h>
|
| 208 |
+
#include <ATen/ops/fft_hfftn_compositeimplicitautograd_dispatch.h>
|
| 209 |
+
#include <ATen/ops/fft_ifft_compositeimplicitautograd_dispatch.h>
|
| 210 |
+
#include <ATen/ops/fft_ifft2_compositeimplicitautograd_dispatch.h>
|
| 211 |
+
#include <ATen/ops/fft_ifftn_compositeimplicitautograd_dispatch.h>
|
| 212 |
+
#include <ATen/ops/fft_ifftshift_compositeimplicitautograd_dispatch.h>
|
| 213 |
+
#include <ATen/ops/fft_ihfft_compositeimplicitautograd_dispatch.h>
|
| 214 |
+
#include <ATen/ops/fft_ihfft2_compositeimplicitautograd_dispatch.h>
|
| 215 |
+
#include <ATen/ops/fft_ihfftn_compositeimplicitautograd_dispatch.h>
|
| 216 |
+
#include <ATen/ops/fft_irfft_compositeimplicitautograd_dispatch.h>
|
| 217 |
+
#include <ATen/ops/fft_irfft2_compositeimplicitautograd_dispatch.h>
|
| 218 |
+
#include <ATen/ops/fft_irfftn_compositeimplicitautograd_dispatch.h>
|
| 219 |
+
#include <ATen/ops/fft_rfft_compositeimplicitautograd_dispatch.h>
|
| 220 |
+
#include <ATen/ops/fft_rfft2_compositeimplicitautograd_dispatch.h>
|
| 221 |
+
#include <ATen/ops/fft_rfftn_compositeimplicitautograd_dispatch.h>
|
| 222 |
+
#include <ATen/ops/fill_diagonal_compositeimplicitautograd_dispatch.h>
|
| 223 |
+
#include <ATen/ops/fix_compositeimplicitautograd_dispatch.h>
|
| 224 |
+
#include <ATen/ops/flatten_compositeimplicitautograd_dispatch.h>
|
| 225 |
+
#include <ATen/ops/flatten_dense_tensors_compositeimplicitautograd_dispatch.h>
|
| 226 |
+
#include <ATen/ops/fliplr_compositeimplicitautograd_dispatch.h>
|
| 227 |
+
#include <ATen/ops/flipud_compositeimplicitautograd_dispatch.h>
|
| 228 |
+
#include <ATen/ops/float_power_compositeimplicitautograd_dispatch.h>
|
| 229 |
+
#include <ATen/ops/frobenius_norm_compositeimplicitautograd_dispatch.h>
|
| 230 |
+
#include <ATen/ops/fused_moving_avg_obs_fake_quant_compositeimplicitautograd_dispatch.h>
|
| 231 |
+
#include <ATen/ops/gather_compositeimplicitautograd_dispatch.h>
|
| 232 |
+
#include <ATen/ops/gather_backward_compositeimplicitautograd_dispatch.h>
|
| 233 |
+
#include <ATen/ops/ger_compositeimplicitautograd_dispatch.h>
|
| 234 |
+
#include <ATen/ops/gradient_compositeimplicitautograd_dispatch.h>
|
| 235 |
+
#include <ATen/ops/greater_compositeimplicitautograd_dispatch.h>
|
| 236 |
+
#include <ATen/ops/greater_equal_compositeimplicitautograd_dispatch.h>
|
| 237 |
+
#include <ATen/ops/grid_sampler_compositeimplicitautograd_dispatch.h>
|
| 238 |
+
#include <ATen/ops/group_norm_compositeimplicitautograd_dispatch.h>
|
| 239 |
+
#include <ATen/ops/gru_compositeimplicitautograd_dispatch.h>
|
| 240 |
+
#include <ATen/ops/gru_cell_compositeimplicitautograd_dispatch.h>
|
| 241 |
+
#include <ATen/ops/hinge_embedding_loss_compositeimplicitautograd_dispatch.h>
|
| 242 |
+
#include <ATen/ops/histogramdd_compositeimplicitautograd_dispatch.h>
|
| 243 |
+
#include <ATen/ops/hsplit_compositeimplicitautograd_dispatch.h>
|
| 244 |
+
#include <ATen/ops/hstack_compositeimplicitautograd_dispatch.h>
|
| 245 |
+
#include <ATen/ops/imag_compositeimplicitautograd_dispatch.h>
|
| 246 |
+
#include <ATen/ops/index_add_compositeimplicitautograd_dispatch.h>
|
| 247 |
+
#include <ATen/ops/index_copy_compositeimplicitautograd_dispatch.h>
|
| 248 |
+
#include <ATen/ops/index_fill_compositeimplicitautograd_dispatch.h>
|
| 249 |
+
#include <ATen/ops/index_select_compositeimplicitautograd_dispatch.h>
|
| 250 |
+
#include <ATen/ops/index_select_backward_compositeimplicitautograd_dispatch.h>
|
| 251 |
+
#include <ATen/ops/infinitely_differentiable_gelu_backward_compositeimplicitautograd_dispatch.h>
|
| 252 |
+
#include <ATen/ops/inner_compositeimplicitautograd_dispatch.h>
|
| 253 |
+
#include <ATen/ops/instance_norm_compositeimplicitautograd_dispatch.h>
|
| 254 |
+
#include <ATen/ops/inverse_compositeimplicitautograd_dispatch.h>
|
| 255 |
+
#include <ATen/ops/is_complex_compositeimplicitautograd_dispatch.h>
|
| 256 |
+
#include <ATen/ops/is_conj_compositeimplicitautograd_dispatch.h>
|
| 257 |
+
#include <ATen/ops/is_distributed_compositeimplicitautograd_dispatch.h>
|
| 258 |
+
#include <ATen/ops/is_floating_point_compositeimplicitautograd_dispatch.h>
|
| 259 |
+
#include <ATen/ops/is_inference_compositeimplicitautograd_dispatch.h>
|
| 260 |
+
#include <ATen/ops/is_leaf_compositeimplicitautograd_dispatch.h>
|
| 261 |
+
#include <ATen/ops/is_neg_compositeimplicitautograd_dispatch.h>
|
| 262 |
+
#include <ATen/ops/is_nonzero_compositeimplicitautograd_dispatch.h>
|
| 263 |
+
#include <ATen/ops/is_signed_compositeimplicitautograd_dispatch.h>
|
| 264 |
+
#include <ATen/ops/is_vulkan_available_compositeimplicitautograd_dispatch.h>
|
| 265 |
+
#include <ATen/ops/isclose_compositeimplicitautograd_dispatch.h>
|
| 266 |
+
#include <ATen/ops/isfinite_compositeimplicitautograd_dispatch.h>
|
| 267 |
+
#include <ATen/ops/isreal_compositeimplicitautograd_dispatch.h>
|
| 268 |
+
#include <ATen/ops/istft_compositeimplicitautograd_dispatch.h>
|
| 269 |
+
#include <ATen/ops/item_compositeimplicitautograd_dispatch.h>
|
| 270 |
+
#include <ATen/ops/kl_div_compositeimplicitautograd_dispatch.h>
|
| 271 |
+
#include <ATen/ops/kron_compositeimplicitautograd_dispatch.h>
|
| 272 |
+
#include <ATen/ops/kthvalue_compositeimplicitautograd_dispatch.h>
|
| 273 |
+
#include <ATen/ops/l1_loss_compositeimplicitautograd_dispatch.h>
|
| 274 |
+
#include <ATen/ops/layer_norm_compositeimplicitautograd_dispatch.h>
|
| 275 |
+
#include <ATen/ops/ldexp_compositeimplicitautograd_dispatch.h>
|
| 276 |
+
#include <ATen/ops/less_compositeimplicitautograd_dispatch.h>
|
| 277 |
+
#include <ATen/ops/less_equal_compositeimplicitautograd_dispatch.h>
|
| 278 |
+
#include <ATen/ops/linalg_cholesky_compositeimplicitautograd_dispatch.h>
|
| 279 |
+
#include <ATen/ops/linalg_cond_compositeimplicitautograd_dispatch.h>
|
| 280 |
+
#include <ATen/ops/linalg_det_compositeimplicitautograd_dispatch.h>
|
| 281 |
+
#include <ATen/ops/linalg_diagonal_compositeimplicitautograd_dispatch.h>
|
| 282 |
+
#include <ATen/ops/linalg_eigh_compositeimplicitautograd_dispatch.h>
|
| 283 |
+
#include <ATen/ops/linalg_eigvals_compositeimplicitautograd_dispatch.h>
|
| 284 |
+
#include <ATen/ops/linalg_eigvalsh_compositeimplicitautograd_dispatch.h>
|
| 285 |
+
#include <ATen/ops/linalg_inv_compositeimplicitautograd_dispatch.h>
|
| 286 |
+
#include <ATen/ops/linalg_ldl_factor_compositeimplicitautograd_dispatch.h>
|
| 287 |
+
#include <ATen/ops/linalg_lu_factor_compositeimplicitautograd_dispatch.h>
|
| 288 |
+
#include <ATen/ops/linalg_matmul_compositeimplicitautograd_dispatch.h>
|
| 289 |
+
#include <ATen/ops/linalg_matrix_norm_compositeimplicitautograd_dispatch.h>
|
| 290 |
+
#include <ATen/ops/linalg_matrix_power_compositeimplicitautograd_dispatch.h>
|
| 291 |
+
#include <ATen/ops/linalg_matrix_rank_compositeimplicitautograd_dispatch.h>
|
| 292 |
+
#include <ATen/ops/linalg_multi_dot_compositeimplicitautograd_dispatch.h>
|
| 293 |
+
#include <ATen/ops/linalg_norm_compositeimplicitautograd_dispatch.h>
|
| 294 |
+
#include <ATen/ops/linalg_pinv_compositeimplicitautograd_dispatch.h>
|
| 295 |
+
#include <ATen/ops/linalg_slogdet_compositeimplicitautograd_dispatch.h>
|
| 296 |
+
#include <ATen/ops/linalg_solve_compositeimplicitautograd_dispatch.h>
|
| 297 |
+
#include <ATen/ops/linalg_solve_ex_compositeimplicitautograd_dispatch.h>
|
| 298 |
+
#include <ATen/ops/linalg_svd_compositeimplicitautograd_dispatch.h>
|
| 299 |
+
#include <ATen/ops/linalg_svdvals_compositeimplicitautograd_dispatch.h>
|
| 300 |
+
#include <ATen/ops/linalg_tensorinv_compositeimplicitautograd_dispatch.h>
|
| 301 |
+
#include <ATen/ops/linalg_tensorsolve_compositeimplicitautograd_dispatch.h>
|
| 302 |
+
#include <ATen/ops/linalg_vander_compositeimplicitautograd_dispatch.h>
|
| 303 |
+
#include <ATen/ops/linalg_vecdot_compositeimplicitautograd_dispatch.h>
|
| 304 |
+
#include <ATen/ops/linear_compositeimplicitautograd_dispatch.h>
|
| 305 |
+
#include <ATen/ops/log_sigmoid_compositeimplicitautograd_dispatch.h>
|
| 306 |
+
#include <ATen/ops/log_softmax_compositeimplicitautograd_dispatch.h>
|
| 307 |
+
#include <ATen/ops/logcumsumexp_compositeimplicitautograd_dispatch.h>
|
| 308 |
+
#include <ATen/ops/logdet_compositeimplicitautograd_dispatch.h>
|
| 309 |
+
#include <ATen/ops/logsumexp_compositeimplicitautograd_dispatch.h>
|
| 310 |
+
#include <ATen/ops/lstm_compositeimplicitautograd_dispatch.h>
|
| 311 |
+
#include <ATen/ops/lstm_cell_compositeimplicitautograd_dispatch.h>
|
| 312 |
+
#include <ATen/ops/lu_solve_compositeimplicitautograd_dispatch.h>
|
| 313 |
+
#include <ATen/ops/mH_compositeimplicitautograd_dispatch.h>
|
| 314 |
+
#include <ATen/ops/mT_compositeimplicitautograd_dispatch.h>
|
| 315 |
+
#include <ATen/ops/margin_ranking_loss_compositeimplicitautograd_dispatch.h>
|
| 316 |
+
#include <ATen/ops/masked_select_backward_compositeimplicitautograd_dispatch.h>
|
| 317 |
+
#include <ATen/ops/matmul_compositeimplicitautograd_dispatch.h>
|
| 318 |
+
#include <ATen/ops/matrix_H_compositeimplicitautograd_dispatch.h>
|
| 319 |
+
#include <ATen/ops/matrix_exp_compositeimplicitautograd_dispatch.h>
|
| 320 |
+
#include <ATen/ops/matrix_exp_backward_compositeimplicitautograd_dispatch.h>
|
| 321 |
+
#include <ATen/ops/matrix_power_compositeimplicitautograd_dispatch.h>
|
| 322 |
+
#include <ATen/ops/max_compositeimplicitautograd_dispatch.h>
|
| 323 |
+
#include <ATen/ops/max_pool1d_compositeimplicitautograd_dispatch.h>
|
| 324 |
+
#include <ATen/ops/max_pool1d_with_indices_compositeimplicitautograd_dispatch.h>
|
| 325 |
+
#include <ATen/ops/max_pool2d_compositeimplicitautograd_dispatch.h>
|
| 326 |
+
#include <ATen/ops/max_pool3d_compositeimplicitautograd_dispatch.h>
|
| 327 |
+
#include <ATen/ops/mean_compositeimplicitautograd_dispatch.h>
|
| 328 |
+
#include <ATen/ops/median_compositeimplicitautograd_dispatch.h>
|
| 329 |
+
#include <ATen/ops/meshgrid_compositeimplicitautograd_dispatch.h>
|
| 330 |
+
#include <ATen/ops/min_compositeimplicitautograd_dispatch.h>
|
| 331 |
+
#include <ATen/ops/mish_backward_compositeimplicitautograd_dispatch.h>
|
| 332 |
+
#include <ATen/ops/mode_compositeimplicitautograd_dispatch.h>
|
| 333 |
+
#include <ATen/ops/moveaxis_compositeimplicitautograd_dispatch.h>
|
| 334 |
+
#include <ATen/ops/movedim_compositeimplicitautograd_dispatch.h>
|
| 335 |
+
#include <ATen/ops/msort_compositeimplicitautograd_dispatch.h>
|
| 336 |
+
#include <ATen/ops/multilabel_margin_loss_compositeimplicitautograd_dispatch.h>
|
| 337 |
+
#include <ATen/ops/multiply_compositeimplicitautograd_dispatch.h>
|
| 338 |
+
#include <ATen/ops/nanmean_compositeimplicitautograd_dispatch.h>
|
| 339 |
+
#include <ATen/ops/nanmedian_compositeimplicitautograd_dispatch.h>
|
| 340 |
+
#include <ATen/ops/nanquantile_compositeimplicitautograd_dispatch.h>
|
| 341 |
+
#include <ATen/ops/narrow_compositeimplicitautograd_dispatch.h>
|
| 342 |
+
#include <ATen/ops/native_channel_shuffle_compositeimplicitautograd_dispatch.h>
|
| 343 |
+
#include <ATen/ops/negative_compositeimplicitautograd_dispatch.h>
|
| 344 |
+
#include <ATen/ops/nested_to_padded_tensor_compositeimplicitautograd_dispatch.h>
|
| 345 |
+
#include <ATen/ops/nll_loss_compositeimplicitautograd_dispatch.h>
|
| 346 |
+
#include <ATen/ops/nll_loss2d_compositeimplicitautograd_dispatch.h>
|
| 347 |
+
#include <ATen/ops/nll_loss_nd_compositeimplicitautograd_dispatch.h>
|
| 348 |
+
#include <ATen/ops/nonzero_numpy_compositeimplicitautograd_dispatch.h>
|
| 349 |
+
#include <ATen/ops/norm_compositeimplicitautograd_dispatch.h>
|
| 350 |
+
#include <ATen/ops/norm_except_dim_compositeimplicitautograd_dispatch.h>
|
| 351 |
+
#include <ATen/ops/not_equal_compositeimplicitautograd_dispatch.h>
|
| 352 |
+
#include <ATen/ops/nuclear_norm_compositeimplicitautograd_dispatch.h>
|
| 353 |
+
#include <ATen/ops/numpy_T_compositeimplicitautograd_dispatch.h>
|
| 354 |
+
#include <ATen/ops/one_hot_compositeimplicitautograd_dispatch.h>
|
| 355 |
+
#include <ATen/ops/or_compositeimplicitautograd_dispatch.h>
|
| 356 |
+
#include <ATen/ops/orgqr_compositeimplicitautograd_dispatch.h>
|
| 357 |
+
#include <ATen/ops/outer_compositeimplicitautograd_dispatch.h>
|
| 358 |
+
#include <ATen/ops/output_nr_compositeimplicitautograd_dispatch.h>
|
| 359 |
+
#include <ATen/ops/pad_compositeimplicitautograd_dispatch.h>
|
| 360 |
+
#include <ATen/ops/pad_sequence_compositeimplicitautograd_dispatch.h>
|
| 361 |
+
#include <ATen/ops/pairwise_distance_compositeimplicitautograd_dispatch.h>
|
| 362 |
+
#include <ATen/ops/pdist_compositeimplicitautograd_dispatch.h>
|
| 363 |
+
#include <ATen/ops/pin_memory_compositeimplicitautograd_dispatch.h>
|
| 364 |
+
#include <ATen/ops/pinverse_compositeimplicitautograd_dispatch.h>
|
| 365 |
+
#include <ATen/ops/poisson_nll_loss_compositeimplicitautograd_dispatch.h>
|
| 366 |
+
#include <ATen/ops/positive_compositeimplicitautograd_dispatch.h>
|
| 367 |
+
#include <ATen/ops/prelu_compositeimplicitautograd_dispatch.h>
|
| 368 |
+
#include <ATen/ops/prod_compositeimplicitautograd_dispatch.h>
|
| 369 |
+
#include <ATen/ops/promote_types_compositeimplicitautograd_dispatch.h>
|
| 370 |
+
#include <ATen/ops/qr_compositeimplicitautograd_dispatch.h>
|
| 371 |
+
#include <ATen/ops/quantile_compositeimplicitautograd_dispatch.h>
|
| 372 |
+
#include <ATen/ops/quantized_gru_cell_compositeimplicitautograd_dispatch.h>
|
| 373 |
+
#include <ATen/ops/quantized_lstm_cell_compositeimplicitautograd_dispatch.h>
|
| 374 |
+
#include <ATen/ops/quantized_rnn_relu_cell_compositeimplicitautograd_dispatch.h>
|
| 375 |
+
#include <ATen/ops/quantized_rnn_tanh_cell_compositeimplicitautograd_dispatch.h>
|
| 376 |
+
#include <ATen/ops/rand_compositeimplicitautograd_dispatch.h>
|
| 377 |
+
#include <ATen/ops/randn_compositeimplicitautograd_dispatch.h>
|
| 378 |
+
#include <ATen/ops/ravel_compositeimplicitautograd_dispatch.h>
|
| 379 |
+
#include <ATen/ops/real_compositeimplicitautograd_dispatch.h>
|
| 380 |
+
#include <ATen/ops/refine_names_compositeimplicitautograd_dispatch.h>
|
| 381 |
+
#include <ATen/ops/relu6_compositeimplicitautograd_dispatch.h>
|
| 382 |
+
#include <ATen/ops/rename_compositeimplicitautograd_dispatch.h>
|
| 383 |
+
#include <ATen/ops/repeat_interleave_compositeimplicitautograd_dispatch.h>
|
| 384 |
+
#include <ATen/ops/requires_grad_compositeimplicitautograd_dispatch.h>
|
| 385 |
+
#include <ATen/ops/reshape_compositeimplicitautograd_dispatch.h>
|
| 386 |
+
#include <ATen/ops/reshape_as_compositeimplicitautograd_dispatch.h>
|
| 387 |
+
#include <ATen/ops/resolve_conj_compositeimplicitautograd_dispatch.h>
|
| 388 |
+
#include <ATen/ops/resolve_neg_compositeimplicitautograd_dispatch.h>
|
| 389 |
+
#include <ATen/ops/result_type_compositeimplicitautograd_dispatch.h>
|
| 390 |
+
#include <ATen/ops/retain_grad_compositeimplicitautograd_dispatch.h>
|
| 391 |
+
#include <ATen/ops/retains_grad_compositeimplicitautograd_dispatch.h>
|
| 392 |
+
#include <ATen/ops/rms_norm_compositeimplicitautograd_dispatch.h>
|
| 393 |
+
#include <ATen/ops/rnn_relu_compositeimplicitautograd_dispatch.h>
|
| 394 |
+
#include <ATen/ops/rnn_relu_cell_compositeimplicitautograd_dispatch.h>
|
| 395 |
+
#include <ATen/ops/rnn_tanh_compositeimplicitautograd_dispatch.h>
|
| 396 |
+
#include <ATen/ops/rnn_tanh_cell_compositeimplicitautograd_dispatch.h>
|
| 397 |
+
#include <ATen/ops/row_stack_compositeimplicitautograd_dispatch.h>
|
| 398 |
+
#include <ATen/ops/rrelu_compositeimplicitautograd_dispatch.h>
|
| 399 |
+
#include <ATen/ops/scaled_dot_product_attention_compositeimplicitautograd_dispatch.h>
|
| 400 |
+
#include <ATen/ops/scatter_compositeimplicitautograd_dispatch.h>
|
| 401 |
+
#include <ATen/ops/scatter_add_compositeimplicitautograd_dispatch.h>
|
| 402 |
+
#include <ATen/ops/select_compositeimplicitautograd_dispatch.h>
|
| 403 |
+
#include <ATen/ops/selu_compositeimplicitautograd_dispatch.h>
|
| 404 |
+
#include <ATen/ops/set_compositeimplicitautograd_dispatch.h>
|
| 405 |
+
#include <ATen/ops/set_data_compositeimplicitautograd_dispatch.h>
|
| 406 |
+
#include <ATen/ops/silu_backward_compositeimplicitautograd_dispatch.h>
|
| 407 |
+
#include <ATen/ops/size_compositeimplicitautograd_dispatch.h>
|
| 408 |
+
#include <ATen/ops/slogdet_compositeimplicitautograd_dispatch.h>
|
| 409 |
+
#include <ATen/ops/slow_conv3d_compositeimplicitautograd_dispatch.h>
|
| 410 |
+
#include <ATen/ops/smm_compositeimplicitautograd_dispatch.h>
|
| 411 |
+
#include <ATen/ops/softmax_compositeimplicitautograd_dispatch.h>
|
| 412 |
+
#include <ATen/ops/sort_compositeimplicitautograd_dispatch.h>
|
| 413 |
+
#include <ATen/ops/sparse_bsc_tensor_compositeimplicitautograd_dispatch.h>
|
| 414 |
+
#include <ATen/ops/sparse_bsr_tensor_compositeimplicitautograd_dispatch.h>
|
| 415 |
+
#include <ATen/ops/sparse_coo_tensor_compositeimplicitautograd_dispatch.h>
|
| 416 |
+
#include <ATen/ops/sparse_csc_tensor_compositeimplicitautograd_dispatch.h>
|
| 417 |
+
#include <ATen/ops/sparse_csr_tensor_compositeimplicitautograd_dispatch.h>
|
| 418 |
+
#include <ATen/ops/special_digamma_compositeimplicitautograd_dispatch.h>
|
| 419 |
+
#include <ATen/ops/special_erf_compositeimplicitautograd_dispatch.h>
|
| 420 |
+
#include <ATen/ops/special_erfc_compositeimplicitautograd_dispatch.h>
|
| 421 |
+
#include <ATen/ops/special_erfinv_compositeimplicitautograd_dispatch.h>
|
| 422 |
+
#include <ATen/ops/special_exp2_compositeimplicitautograd_dispatch.h>
|
| 423 |
+
#include <ATen/ops/special_expit_compositeimplicitautograd_dispatch.h>
|
| 424 |
+
#include <ATen/ops/special_expm1_compositeimplicitautograd_dispatch.h>
|
| 425 |
+
#include <ATen/ops/special_gammainc_compositeimplicitautograd_dispatch.h>
|
| 426 |
+
#include <ATen/ops/special_gammaincc_compositeimplicitautograd_dispatch.h>
|
| 427 |
+
#include <ATen/ops/special_gammaln_compositeimplicitautograd_dispatch.h>
|
| 428 |
+
#include <ATen/ops/special_i0_compositeimplicitautograd_dispatch.h>
|
| 429 |
+
#include <ATen/ops/special_log1p_compositeimplicitautograd_dispatch.h>
|
| 430 |
+
#include <ATen/ops/special_log_softmax_compositeimplicitautograd_dispatch.h>
|
| 431 |
+
#include <ATen/ops/special_logit_compositeimplicitautograd_dispatch.h>
|
| 432 |
+
#include <ATen/ops/special_logsumexp_compositeimplicitautograd_dispatch.h>
|
| 433 |
+
#include <ATen/ops/special_multigammaln_compositeimplicitautograd_dispatch.h>
|
| 434 |
+
#include <ATen/ops/special_ndtr_compositeimplicitautograd_dispatch.h>
|
| 435 |
+
#include <ATen/ops/special_polygamma_compositeimplicitautograd_dispatch.h>
|
| 436 |
+
#include <ATen/ops/special_psi_compositeimplicitautograd_dispatch.h>
|
| 437 |
+
#include <ATen/ops/special_round_compositeimplicitautograd_dispatch.h>
|
| 438 |
+
#include <ATen/ops/special_sinc_compositeimplicitautograd_dispatch.h>
|
| 439 |
+
#include <ATen/ops/special_softmax_compositeimplicitautograd_dispatch.h>
|
| 440 |
+
#include <ATen/ops/special_xlogy_compositeimplicitautograd_dispatch.h>
|
| 441 |
+
#include <ATen/ops/split_compositeimplicitautograd_dispatch.h>
|
| 442 |
+
#include <ATen/ops/square_compositeimplicitautograd_dispatch.h>
|
| 443 |
+
#include <ATen/ops/squeeze_compositeimplicitautograd_dispatch.h>
|
| 444 |
+
#include <ATen/ops/sspaddmm_compositeimplicitautograd_dispatch.h>
|
| 445 |
+
#include <ATen/ops/std_compositeimplicitautograd_dispatch.h>
|
| 446 |
+
#include <ATen/ops/std_mean_compositeimplicitautograd_dispatch.h>
|
| 447 |
+
#include <ATen/ops/stft_compositeimplicitautograd_dispatch.h>
|
| 448 |
+
#include <ATen/ops/stride_compositeimplicitautograd_dispatch.h>
|
| 449 |
+
#include <ATen/ops/subtract_compositeimplicitautograd_dispatch.h>
|
| 450 |
+
#include <ATen/ops/sum_compositeimplicitautograd_dispatch.h>
|
| 451 |
+
#include <ATen/ops/sum_to_size_compositeimplicitautograd_dispatch.h>
|
| 452 |
+
#include <ATen/ops/svd_compositeimplicitautograd_dispatch.h>
|
| 453 |
+
#include <ATen/ops/swapaxes_compositeimplicitautograd_dispatch.h>
|
| 454 |
+
#include <ATen/ops/swapdims_compositeimplicitautograd_dispatch.h>
|
| 455 |
+
#include <ATen/ops/sym_numel_compositeimplicitautograd_dispatch.h>
|
| 456 |
+
#include <ATen/ops/sym_size_compositeimplicitautograd_dispatch.h>
|
| 457 |
+
#include <ATen/ops/sym_storage_offset_compositeimplicitautograd_dispatch.h>
|
| 458 |
+
#include <ATen/ops/sym_stride_compositeimplicitautograd_dispatch.h>
|
| 459 |
+
#include <ATen/ops/take_along_dim_compositeimplicitautograd_dispatch.h>
|
| 460 |
+
#include <ATen/ops/tensor_split_compositeimplicitautograd_dispatch.h>
|
| 461 |
+
#include <ATen/ops/tensordot_compositeimplicitautograd_dispatch.h>
|
| 462 |
+
#include <ATen/ops/thnn_conv2d_compositeimplicitautograd_dispatch.h>
|
| 463 |
+
#include <ATen/ops/tile_compositeimplicitautograd_dispatch.h>
|
| 464 |
+
#include <ATen/ops/to_compositeimplicitautograd_dispatch.h>
|
| 465 |
+
#include <ATen/ops/to_dense_compositeimplicitautograd_dispatch.h>
|
| 466 |
+
#include <ATen/ops/to_dense_backward_compositeimplicitautograd_dispatch.h>
|
| 467 |
+
#include <ATen/ops/to_mkldnn_backward_compositeimplicitautograd_dispatch.h>
|
| 468 |
+
#include <ATen/ops/to_sparse_compositeimplicitautograd_dispatch.h>
|
| 469 |
+
#include <ATen/ops/to_sparse_bsc_compositeimplicitautograd_dispatch.h>
|
| 470 |
+
#include <ATen/ops/to_sparse_bsr_compositeimplicitautograd_dispatch.h>
|
| 471 |
+
#include <ATen/ops/to_sparse_csc_compositeimplicitautograd_dispatch.h>
|
| 472 |
+
#include <ATen/ops/to_sparse_csr_compositeimplicitautograd_dispatch.h>
|
| 473 |
+
#include <ATen/ops/trace_backward_compositeimplicitautograd_dispatch.h>
|
| 474 |
+
#include <ATen/ops/transpose_compositeimplicitautograd_dispatch.h>
|
| 475 |
+
#include <ATen/ops/trapezoid_compositeimplicitautograd_dispatch.h>
|
| 476 |
+
#include <ATen/ops/trapz_compositeimplicitautograd_dispatch.h>
|
| 477 |
+
#include <ATen/ops/triplet_margin_loss_compositeimplicitautograd_dispatch.h>
|
| 478 |
+
#include <ATen/ops/true_divide_compositeimplicitautograd_dispatch.h>
|
| 479 |
+
#include <ATen/ops/type_as_compositeimplicitautograd_dispatch.h>
|
| 480 |
+
#include <ATen/ops/unbind_compositeimplicitautograd_dispatch.h>
|
| 481 |
+
#include <ATen/ops/unflatten_compositeimplicitautograd_dispatch.h>
|
| 482 |
+
#include <ATen/ops/unflatten_dense_tensors_compositeimplicitautograd_dispatch.h>
|
| 483 |
+
#include <ATen/ops/unsafe_chunk_compositeimplicitautograd_dispatch.h>
|
| 484 |
+
#include <ATen/ops/upsample_bicubic2d_compositeimplicitautograd_dispatch.h>
|
| 485 |
+
#include <ATen/ops/upsample_bilinear2d_compositeimplicitautograd_dispatch.h>
|
| 486 |
+
#include <ATen/ops/upsample_linear1d_compositeimplicitautograd_dispatch.h>
|
| 487 |
+
#include <ATen/ops/upsample_nearest1d_compositeimplicitautograd_dispatch.h>
|
| 488 |
+
#include <ATen/ops/upsample_nearest2d_compositeimplicitautograd_dispatch.h>
|
| 489 |
+
#include <ATen/ops/upsample_nearest3d_compositeimplicitautograd_dispatch.h>
|
| 490 |
+
#include <ATen/ops/upsample_trilinear3d_compositeimplicitautograd_dispatch.h>
|
| 491 |
+
#include <ATen/ops/value_selecting_reduction_backward_compositeimplicitautograd_dispatch.h>
|
| 492 |
+
#include <ATen/ops/vander_compositeimplicitautograd_dispatch.h>
|
| 493 |
+
#include <ATen/ops/var_compositeimplicitautograd_dispatch.h>
|
| 494 |
+
#include <ATen/ops/var_mean_compositeimplicitautograd_dispatch.h>
|
| 495 |
+
#include <ATen/ops/view_as_compositeimplicitautograd_dispatch.h>
|
| 496 |
+
#include <ATen/ops/vsplit_compositeimplicitautograd_dispatch.h>
|
| 497 |
+
#include <ATen/ops/vstack_compositeimplicitautograd_dispatch.h>
|
| 498 |
+
#include <ATen/ops/where_compositeimplicitautograd_dispatch.h>
|
| 499 |
+
#include <ATen/ops/xor_compositeimplicitautograd_dispatch.h>
|
| 500 |
+
|
| 501 |
+
|
| 502 |
+
|
.venv/Lib/site-packages/torch/include/ATen/CompositeImplicitAutogradNestedTensorFunctions.h
ADDED
|
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#include <ATen/core/TensorBody.h>
|
| 2 |
+
|
| 3 |
+
// TODO Undo all logic introduced for Note [Avoiding Include Cycles In Static Dispatch]
|
| 4 |
+
// Code introduced to avoid cyclic dependency in static dispatch is no longer
|
| 5 |
+
// needed as static dispatch logic is moved from TensorBody.h, which caused cycles in the first place,
|
| 6 |
+
// to Operators.cpp for supporting multiple backends with multiple kernels.
|
| 7 |
+
//
|
| 8 |
+
// Note [Avoiding Include Cycles In Static Dispatch]
|
| 9 |
+
// In order to avoid #include cycles in the static dispatch build, we've carefully split out
|
| 10 |
+
// the static function definition files into {DispatchKey}Functions.h and {DispatchKey}Functions_inl.h.
|
| 11 |
+
//
|
| 12 |
+
// Without this split, the include cycle looks like TensorBody.h -> CPUFunctions.h -> TensorBody.h.
|
| 13 |
+
// - TensorBody.h #includes CPUFunctions.h in the static dispatch build, because the tensor methods
|
| 14 |
+
// all need to call into the fastpath C++ API defined in CPUFunctions.h. The methods are also all
|
| 15 |
+
// directly inlined into TensorBody.h.
|
| 16 |
+
// - CPUFunctions.h #includes TensorBody.h because it contains function declarations for the entire C++ API,
|
| 17 |
+
// which include functions that have defaultable std::optional<Tensor> arguments.
|
| 18 |
+
// That requires knowing the full Tensor class definition.
|
| 19 |
+
//
|
| 20 |
+
// We break the cycle by doing the following:
|
| 21 |
+
// - Split out CPUFunction.h into two files: CPUFunctions.h and CPUFunctions_inl.h
|
| 22 |
+
// - CPUFunction.h is a dummy file that just includes the Tensor class and includes CPUFunctions_inl.,
|
| 23 |
+
// - CPUFunctions_inl.h includes everything else
|
| 24 |
+
// - (only in the static dispatch build) TensorBody.h makes sure to finish defining the Tensor class,
|
| 25 |
+
// and then it includes CPUFunctions_inl.h.
|
| 26 |
+
// - All other files that want the cpu fastpath functions can include CPUFunctions.h directly.
|
| 27 |
+
// - This also means that static dispatch build, CPUFunctions.h only needs to
|
| 28 |
+
// #include TensorBody.h, and it will automatically bring in CPUFunctions_inl.h.
|
| 29 |
+
#include <ATen/CompositeImplicitAutogradNestedTensorFunctions_inl.h>
|
.venv/Lib/site-packages/torch/include/ATen/CompositeImplicitAutogradNestedTensorFunctions_inl.h
ADDED
|
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
// @generated by torchgen/gen.py from DispatchKeyFunctions_inl.h
|
| 3 |
+
|
| 4 |
+
// NB: The implementing C++ file is RegisterDispatchKey.cpp
|
| 5 |
+
|
| 6 |
+
// The only #includes we need are for custom classes that have defaults in the C++ API
|
| 7 |
+
#include <c10/core/MemoryFormat.h>
|
| 8 |
+
#include <c10/core/Scalar.h>
|
| 9 |
+
#include <ATen/core/Reduction.h>
|
| 10 |
+
|
| 11 |
+
#if defined(AT_PER_OPERATOR_HEADERS) && defined(TORCH_ASSERT_ONLY_METHOD_OPERATORS)
|
| 12 |
+
#error This change adds a dependency on all pytorch operators, meaning the \
|
| 13 |
+
file will need to be re-compiled every time an operator is changed or added. \
|
| 14 |
+
Consider including a specific operator from \
|
| 15 |
+
<ATen/ops/{my_operator}_compositeimplicitautogradnestedtensor_dispatch.h>. \
|
| 16 |
+
See NOTE [TORCH_ASSERT_ONLY_METHOD_OPERATORS].
|
| 17 |
+
#endif
|
| 18 |
+
|
| 19 |
+
#include <ATen/ops/randn_like_compositeimplicitautogradnestedtensor_dispatch.h>
|
| 20 |
+
#include <ATen/ops/reshape_compositeimplicitautogradnestedtensor_dispatch.h>
|
| 21 |
+
#include <ATen/ops/reshape_as_compositeimplicitautogradnestedtensor_dispatch.h>
|
| 22 |
+
#include <ATen/ops/zeros_like_compositeimplicitautogradnestedtensor_dispatch.h>
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
|
.venv/Lib/site-packages/torch/include/ATen/Config.h
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
// Test these using #if AT_MKL_ENABLED(), not #ifdef, so that it's
|
| 4 |
+
// obvious if you forgot to include Config.h
|
| 5 |
+
// c.f. https://stackoverflow.com/questions/33759787/generating-an-error-if-checked-boolean-macro-is-not-defined
|
| 6 |
+
//
|
| 7 |
+
// DO NOT put the macros for CUDA libraries in this file; they belong in cuda/CUDAConfig.h
|
| 8 |
+
|
| 9 |
+
#define AT_MKLDNN_ENABLED() 1
|
| 10 |
+
#define AT_MKLDNN_ACL_ENABLED() 0
|
| 11 |
+
#define AT_MKL_ENABLED() 1
|
| 12 |
+
#define AT_MKL_SEQUENTIAL() 0
|
| 13 |
+
#define AT_POCKETFFT_ENABLED() 0
|
| 14 |
+
#define AT_NNPACK_ENABLED() 0
|
| 15 |
+
#define CAFFE2_STATIC_LINK_CUDA() 0
|
| 16 |
+
#define AT_BUILD_WITH_BLAS() 1
|
| 17 |
+
#define AT_BUILD_WITH_LAPACK() 1
|
| 18 |
+
#define AT_PARALLEL_OPENMP 1
|
| 19 |
+
#define AT_PARALLEL_NATIVE 0
|
| 20 |
+
#define AT_BLAS_F2C() 0
|
| 21 |
+
#define AT_BLAS_USE_CBLAS_DOT() 0
|
.venv/Lib/site-packages/torch/include/ATen/Context.h
ADDED
|
@@ -0,0 +1,610 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#include <ATen/BlasBackend.h>
|
| 4 |
+
#include <ATen/CPUGeneratorImpl.h>
|
| 5 |
+
#include <ATen/DeviceAccelerator.h>
|
| 6 |
+
#include <ATen/LinalgBackend.h>
|
| 7 |
+
#include <ATen/core/ATenGeneral.h>
|
| 8 |
+
#include <ATen/core/DeprecatedTypeProperties.h>
|
| 9 |
+
#include <ATen/core/Generator.h>
|
| 10 |
+
#include <ATen/core/LegacyTypeDispatch.h>
|
| 11 |
+
#include <ATen/detail/AcceleratorHooksInterface.h>
|
| 12 |
+
#include <ATen/detail/CUDAHooksInterface.h>
|
| 13 |
+
#include <ATen/detail/HIPHooksInterface.h>
|
| 14 |
+
#include <ATen/detail/IPUHooksInterface.h>
|
| 15 |
+
#include <ATen/detail/MAIAHooksInterface.h>
|
| 16 |
+
#include <ATen/detail/MPSHooksInterface.h>
|
| 17 |
+
#include <ATen/detail/MTIAHooksInterface.h>
|
| 18 |
+
#include <ATen/detail/PrivateUse1HooksInterface.h>
|
| 19 |
+
#include <ATen/detail/XPUHooksInterface.h>
|
| 20 |
+
#include <c10/core/QEngine.h>
|
| 21 |
+
#include <c10/core/impl/DeviceGuardImplInterface.h>
|
| 22 |
+
#include <c10/util/CallOnce.h>
|
| 23 |
+
#include <c10/util/Exception.h>
|
| 24 |
+
#include <c10/util/env.h>
|
| 25 |
+
#include <c10/util/irange.h>
|
| 26 |
+
|
| 27 |
+
#include <cstdint>
|
| 28 |
+
#include <mutex>
|
| 29 |
+
|
| 30 |
+
namespace at {
|
| 31 |
+
|
| 32 |
+
class Tensor;
|
| 33 |
+
|
| 34 |
+
enum class TORCH_API Float32MatmulPrecision { HIGHEST, HIGH, MEDIUM };
|
| 35 |
+
|
| 36 |
+
class TORCH_API Context {
|
| 37 |
+
public:
|
| 38 |
+
Context();
|
| 39 |
+
|
| 40 |
+
const Generator& defaultGenerator(Device device) {
|
| 41 |
+
c10::DeviceType device_type = device.type();
|
| 42 |
+
initCUDAIfNeeded(device_type);
|
| 43 |
+
initHIPIfNeeded(device_type);
|
| 44 |
+
if (device_type == at::kCPU) {
|
| 45 |
+
return at::detail::getDefaultCPUGenerator();
|
| 46 |
+
} else if (device_type == at::kCUDA) {
|
| 47 |
+
return at::detail::getCUDAHooks().getDefaultCUDAGenerator(device.index());
|
| 48 |
+
} else if (device_type == at::kMPS) {
|
| 49 |
+
return at::detail::getMPSHooks().getDefaultMPSGenerator();
|
| 50 |
+
} else if (device_type == at::kXPU) {
|
| 51 |
+
return at::detail::getXPUHooks().getDefaultXPUGenerator(device.index());
|
| 52 |
+
} else if (device_type == at::kIPU) {
|
| 53 |
+
return at::detail::getIPUHooks().getDefaultIPUGenerator(device.index());
|
| 54 |
+
} else if (device_type == at::kPrivateUse1) {
|
| 55 |
+
return at::detail::getPrivateUse1Hooks().getDefaultGenerator(
|
| 56 |
+
device.index());
|
| 57 |
+
} else {
|
| 58 |
+
AT_ERROR(c10::DeviceTypeName(device_type), " device type not enabled.");
|
| 59 |
+
}
|
| 60 |
+
}
|
| 61 |
+
const AcceleratorHooksInterface& getAcceleratorHooksInterface(
|
| 62 |
+
std::optional<c10::DeviceType> opt_device_type = std::nullopt) {
|
| 63 |
+
c10::DeviceType device_type = opt_device_type.has_value()
|
| 64 |
+
? opt_device_type.value()
|
| 65 |
+
: at::getAccelerator(true).value();
|
| 66 |
+
if (device_type == at::kCUDA) {
|
| 67 |
+
return at::detail::getCUDAHooks();
|
| 68 |
+
} else if (device_type == at::kXPU) {
|
| 69 |
+
return at::detail::getXPUHooks();
|
| 70 |
+
} else if (device_type == at::kMPS) {
|
| 71 |
+
return at::detail::getMPSHooks();
|
| 72 |
+
} else if (device_type == at::kPrivateUse1) {
|
| 73 |
+
return at::detail::getPrivateUse1Hooks();
|
| 74 |
+
} else if (device_type == at::kMTIA) {
|
| 75 |
+
return at::detail::getMTIAHooks();
|
| 76 |
+
} else if (device_type == at::kHIP) {
|
| 77 |
+
return at::detail::getHIPHooks();
|
| 78 |
+
} else {
|
| 79 |
+
AT_ERROR(
|
| 80 |
+
c10::DeviceTypeName(device_type), " device type not an accelerator.");
|
| 81 |
+
}
|
| 82 |
+
}
|
| 83 |
+
Device getDeviceFromPtr(void* data, c10::DeviceType device_type) {
|
| 84 |
+
initCUDAIfNeeded(device_type);
|
| 85 |
+
initHIPIfNeeded(device_type);
|
| 86 |
+
initXPUIfNeeded(device_type);
|
| 87 |
+
if (device_type == at::kCPU) {
|
| 88 |
+
return c10::DeviceType::CPU;
|
| 89 |
+
} else if (device_type == at::kCUDA) {
|
| 90 |
+
return at::detail::getCUDAHooks().getDeviceFromPtr(data);
|
| 91 |
+
} else if (device_type == at::kXPU) {
|
| 92 |
+
return at::detail::getXPUHooks().getDeviceFromPtr(data);
|
| 93 |
+
} else if (device_type == at::kPrivateUse1) {
|
| 94 |
+
return at::detail::getPrivateUse1Hooks().getDeviceFromPtr(data);
|
| 95 |
+
} else {
|
| 96 |
+
AT_ERROR(c10::DeviceTypeName(device_type), " device type not enabled.");
|
| 97 |
+
}
|
| 98 |
+
}
|
| 99 |
+
bool isPinnedPtr(
|
| 100 |
+
const void* data,
|
| 101 |
+
std::optional<c10::DeviceType> device_type = std::nullopt) {
|
| 102 |
+
auto opt_device_type =
|
| 103 |
+
device_type.has_value() ? device_type : at::getAccelerator();
|
| 104 |
+
if (!opt_device_type.has_value() || // there is no accelerator
|
| 105 |
+
!at::isAccelerator(
|
| 106 |
+
opt_device_type.value())) { // passed device not an accelerator
|
| 107 |
+
return false;
|
| 108 |
+
}
|
| 109 |
+
return getAcceleratorHooksInterface(opt_device_type.value())
|
| 110 |
+
.isPinnedPtr(data);
|
| 111 |
+
}
|
| 112 |
+
Allocator* getPinnedMemoryAllocator(
|
| 113 |
+
std::optional<c10::DeviceType> device_type = std::nullopt) {
|
| 114 |
+
return getAcceleratorHooksInterface(device_type).getPinnedMemoryAllocator();
|
| 115 |
+
}
|
| 116 |
+
static bool hasOpenMP();
|
| 117 |
+
static bool hasMKL();
|
| 118 |
+
static bool hasLAPACK();
|
| 119 |
+
static bool hasMKLDNN();
|
| 120 |
+
static bool hasMAGMA() {
|
| 121 |
+
return detail::getCUDAHooks().hasMAGMA();
|
| 122 |
+
}
|
| 123 |
+
static bool hasCUDA() {
|
| 124 |
+
return detail::getCUDAHooks().hasCUDA();
|
| 125 |
+
}
|
| 126 |
+
static bool hasMTIA() {
|
| 127 |
+
return detail::getMTIAHooks().hasMTIA();
|
| 128 |
+
}
|
| 129 |
+
static bool hasCUDART() {
|
| 130 |
+
return detail::getCUDAHooks().hasCUDART();
|
| 131 |
+
}
|
| 132 |
+
static long versionCUDART() {
|
| 133 |
+
return detail::getCUDAHooks().versionCUDART();
|
| 134 |
+
}
|
| 135 |
+
static bool hasCuDNN() {
|
| 136 |
+
return detail::getCUDAHooks().hasCuDNN();
|
| 137 |
+
}
|
| 138 |
+
static long versionCuDNN() {
|
| 139 |
+
return detail::getCUDAHooks().versionCuDNN();
|
| 140 |
+
}
|
| 141 |
+
static bool hasCuSOLVER() {
|
| 142 |
+
return detail::getCUDAHooks().hasCuSOLVER();
|
| 143 |
+
}
|
| 144 |
+
static bool hasCuBLASLt() {
|
| 145 |
+
return detail::getCUDAHooks().hasCuBLASLt();
|
| 146 |
+
}
|
| 147 |
+
static bool hasHIP() {
|
| 148 |
+
return detail::getHIPHooks().hasHIP();
|
| 149 |
+
}
|
| 150 |
+
static bool hasMPS() {
|
| 151 |
+
return detail::getMPSHooks().hasMPS();
|
| 152 |
+
}
|
| 153 |
+
static bool hasIPU() {
|
| 154 |
+
return c10::impl::hasDeviceGuardImpl(c10::DeviceType::IPU);
|
| 155 |
+
}
|
| 156 |
+
static bool hasXLA() {
|
| 157 |
+
return c10::impl::hasDeviceGuardImpl(c10::DeviceType::XLA);
|
| 158 |
+
}
|
| 159 |
+
static bool hasXPU() {
|
| 160 |
+
return detail::getXPUHooks().hasXPU();
|
| 161 |
+
}
|
| 162 |
+
static bool hasLazy() {
|
| 163 |
+
return c10::impl::hasDeviceGuardImpl(c10::DeviceType::Lazy);
|
| 164 |
+
}
|
| 165 |
+
static bool hasMAIA() {
|
| 166 |
+
return c10::impl::hasDeviceGuardImpl(c10::DeviceType::MAIA);
|
| 167 |
+
}
|
| 168 |
+
// defined in header so that getNonVariableType has ability to inline
|
| 169 |
+
// call_once check. getNonVariableType is called fairly frequently
|
| 170 |
+
void lazyInitCUDA() {
|
| 171 |
+
c10::call_once(thc_init, [&] { detail::getCUDAHooks().initCUDA(); });
|
| 172 |
+
}
|
| 173 |
+
void lazyInitHIP() {
|
| 174 |
+
c10::call_once(thh_init, [&] { detail::getHIPHooks().initHIP(); });
|
| 175 |
+
}
|
| 176 |
+
void lazyInitXPU() {
|
| 177 |
+
c10::call_once(thx_init, [&] { detail::getXPUHooks().initXPU(); });
|
| 178 |
+
}
|
| 179 |
+
void lazyInitMTIA() {
|
| 180 |
+
c10::call_once(th_mtia_init, [&] { detail::getMTIAHooks().initMTIA(); });
|
| 181 |
+
}
|
| 182 |
+
void lazyInitPrivateUse1() {
|
| 183 |
+
c10::call_once(thp_init, [&] {
|
| 184 |
+
if (isPrivateUse1HooksRegistered()) {
|
| 185 |
+
at::detail::getPrivateUse1Hooks().initPrivateUse1();
|
| 186 |
+
}
|
| 187 |
+
});
|
| 188 |
+
}
|
| 189 |
+
static const at::cuda::NVRTC& getNVRTC() {
|
| 190 |
+
return detail::getCUDAHooks().nvrtc();
|
| 191 |
+
}
|
| 192 |
+
|
| 193 |
+
static bool setFlushDenormal(bool on);
|
| 194 |
+
|
| 195 |
+
// NB: This method is *purely* whether or not a user requested
|
| 196 |
+
// that CuDNN was enabled, it doesn't actually say anything about
|
| 197 |
+
// whether or not CuDNN is actually usable. Use cudnn_is_acceptable
|
| 198 |
+
// to test this instead
|
| 199 |
+
bool userEnabledCuDNN() const;
|
| 200 |
+
void setUserEnabledCuDNN(bool e);
|
| 201 |
+
bool userEnabledMkldnn() const;
|
| 202 |
+
void setUserEnabledMkldnn(bool e);
|
| 203 |
+
bool benchmarkCuDNN() const;
|
| 204 |
+
void setBenchmarkCuDNN(bool);
|
| 205 |
+
int benchmarkLimitCuDNN() const;
|
| 206 |
+
void setBenchmarkLimitCuDNN(int);
|
| 207 |
+
bool deterministicCuDNN() const;
|
| 208 |
+
void setDeterministicCuDNN(bool);
|
| 209 |
+
bool deterministicMkldnn() const;
|
| 210 |
+
void setDeterministicMkldnn(bool);
|
| 211 |
+
bool userEnabledNNPACK() const;
|
| 212 |
+
void setUserEnabledNNPACK(bool e);
|
| 213 |
+
|
| 214 |
+
// Note [Disabling Fused SDP Kernels]
|
| 215 |
+
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
| 216 |
+
// Flash and Memory Efficient SDP kernels are enabled by default.
|
| 217 |
+
// However, they can be disabled by setting
|
| 218 |
+
// at::globalContext().setUserEnabledFlashSDP(false) flag.
|
| 219 |
+
// This is useful for debugging purposes. For example, if you want to
|
| 220 |
+
// compare the performance of the flash SDP kernels with the unfused
|
| 221 |
+
// kernel, you can disable the flash SDP kernels. By disabling
|
| 222 |
+
// the math SDP kernel, you can force your code to use flash kernels.
|
| 223 |
+
// The math SDP kernel can be disabled by setting
|
| 224 |
+
// at::globalContext().setUserEnabledMathSDP(false) flag.
|
| 225 |
+
void setSDPUseFlash(bool);
|
| 226 |
+
bool userEnabledFlashSDP() const;
|
| 227 |
+
|
| 228 |
+
void setSDPUseMemEfficient(bool);
|
| 229 |
+
bool userEnabledMemEfficientSDP() const;
|
| 230 |
+
|
| 231 |
+
void setSDPUseMath(bool);
|
| 232 |
+
bool userEnabledMathSDP() const;
|
| 233 |
+
|
| 234 |
+
void setSDPUseCuDNN(bool);
|
| 235 |
+
bool userEnabledCuDNNSDP() const;
|
| 236 |
+
|
| 237 |
+
void setAllowFP16BF16ReductionMathSDP(bool);
|
| 238 |
+
bool allowFP16BF16ReductionMathSDP() const;
|
| 239 |
+
|
| 240 |
+
void setSDPUseOverrideable(bool);
|
| 241 |
+
bool userEnabledOverrideableSDP() const;
|
| 242 |
+
|
| 243 |
+
at::LinalgBackend linalgPreferredBackend() const;
|
| 244 |
+
void setLinalgPreferredBackend(at::LinalgBackend);
|
| 245 |
+
|
| 246 |
+
at::BlasBackend blasPreferredBackend();
|
| 247 |
+
void setBlasPreferredBackend(at::BlasBackend);
|
| 248 |
+
|
| 249 |
+
// Note [Enabling Deterministic Operations]
|
| 250 |
+
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
| 251 |
+
// Operations in PyTorch that normally act nondeterministically, but have an
|
| 252 |
+
// alternate deterministic implementation, should satisfy the following
|
| 253 |
+
// requirements:
|
| 254 |
+
//
|
| 255 |
+
// * Include this comment: "See Note [Enabling Deterministic Operations]"
|
| 256 |
+
//
|
| 257 |
+
// * Check the value of `at::globalContext().deterministicAlgorithms()` to
|
| 258 |
+
// toggle
|
| 259 |
+
// between nondeterministic and deterministic implementations.
|
| 260 |
+
//
|
| 261 |
+
// * Have an entry in the list of PyTorch operations that toggle between
|
| 262 |
+
// nondeterministic
|
| 263 |
+
// and deterministic implementations, in the docstring of
|
| 264 |
+
// `use_deterministic_algorithms()` in torch/__init__.py
|
| 265 |
+
//
|
| 266 |
+
// `example_func()` below shows an example of toggling between
|
| 267 |
+
// nondeterministic and deterministic implementations:
|
| 268 |
+
//
|
| 269 |
+
// void example_func() {
|
| 270 |
+
// // See Note [Enabling Deterministic Operations]
|
| 271 |
+
// if (at::globalContext().deterministicAlgorithms()) {
|
| 272 |
+
// example_func_deterministic();
|
| 273 |
+
// } else {
|
| 274 |
+
// example_func_nondeterministic();
|
| 275 |
+
// }
|
| 276 |
+
// }
|
| 277 |
+
|
| 278 |
+
bool deterministicAlgorithms() const;
|
| 279 |
+
bool deterministicAlgorithmsWarnOnly() const;
|
| 280 |
+
void setDeterministicAlgorithms(bool, bool);
|
| 281 |
+
bool deterministicFillUninitializedMemory() const;
|
| 282 |
+
void setDeterministicFillUninitializedMemory(bool);
|
| 283 |
+
|
| 284 |
+
// Note [Writing Nondeterministic Operations]
|
| 285 |
+
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
| 286 |
+
// Operations in PyTorch that act nondeterministically and do not have an
|
| 287 |
+
// alternate deterministic implementation should satisfy the following
|
| 288 |
+
// requirements:
|
| 289 |
+
//
|
| 290 |
+
// * Include this comment: "See Note [Writing Nondeterministic Operations]"
|
| 291 |
+
//
|
| 292 |
+
// * Include a comment explaining why the operation is nondeterministic.
|
| 293 |
+
//
|
| 294 |
+
// * Throw an error when `Context::deterministicAlgorithms()` is true. Most
|
| 295 |
+
// of the time, this should be accomplished by calling
|
| 296 |
+
// `at::globalContext().alertNotDeterminstic()`. However, if the
|
| 297 |
+
// nondeterministic behavior is caused by the CuBLAS workspace
|
| 298 |
+
// configuration in CUDA >= 10.2,
|
| 299 |
+
// `at::globalContext().alertCuBLASConfigNotDeterministic()` should be
|
| 300 |
+
// called instead (in this case, a comment explaining why the operation is
|
| 301 |
+
// nondeterministic is not necessary). See below for details on these
|
| 302 |
+
// methods.
|
| 303 |
+
//
|
| 304 |
+
// * Have an entry in the list of nondeterministic PyTorch operations in the
|
| 305 |
+
// docstring of `use_deterministic_algorithms()` in torch/__init__.py
|
| 306 |
+
//
|
| 307 |
+
// * Have a test function in `test/test_torch.py` whose name begins with
|
| 308 |
+
// `test_nondeterministic_alert_`. Alternatively, if CuBLAS workspace
|
| 309 |
+
// configuration is the reason for nondeterminism, the operation should be
|
| 310 |
+
// included in the `test_cublas_config_nondeterministic_alert` test. Any new
|
| 311 |
+
// tests should ideally follow a pattern similar to the existing ones.
|
| 312 |
+
//
|
| 313 |
+
// `example_func()` below shows an example of the comments and error-throwing
|
| 314 |
+
// code for a nondeterministic operation:
|
| 315 |
+
//
|
| 316 |
+
// void example_func() {
|
| 317 |
+
// // See Note [Writing Nondeterministic Operations]
|
| 318 |
+
// // Nondeterministic because <reason>
|
| 319 |
+
// at::globalContext().alertNondeterministic("example_func");
|
| 320 |
+
// ...
|
| 321 |
+
// }
|
| 322 |
+
|
| 323 |
+
// Throws an error if `Context::deterministicAlgorithms()` is true
|
| 324 |
+
static void alertNotDeterministic(c10::string_view const& caller);
|
| 325 |
+
|
| 326 |
+
// Throws an error if `Context::deterministicAlgorithms()` is true, CUDA
|
| 327 |
+
// >= 10.2, and CUBLAS_WORKSPACE_CONFIG is not set to either ":16:8" or
|
| 328 |
+
// ":4096:8". For more details:
|
| 329 |
+
// https://docs.nvidia.com/cuda/cublas/index.html#results-reproducibility
|
| 330 |
+
void alertCuBLASConfigNotDeterministic() const;
|
| 331 |
+
|
| 332 |
+
void setFloat32MatmulPrecision(const std::string& s);
|
| 333 |
+
bool allowTF32CuDNN() const;
|
| 334 |
+
void setAllowTF32CuDNN(bool);
|
| 335 |
+
bool allowTF32CuBLAS() const;
|
| 336 |
+
void setAllowTF32CuBLAS(bool);
|
| 337 |
+
Float32MatmulPrecision float32MatmulPrecision() const;
|
| 338 |
+
void setFloat32MatmulPrecision(Float32MatmulPrecision p);
|
| 339 |
+
bool allowFP16ReductionCuBLAS() const;
|
| 340 |
+
void setAllowFP16ReductionCuBLAS(bool);
|
| 341 |
+
bool allowBF16ReductionCuBLAS() const;
|
| 342 |
+
void setAllowBF16ReductionCuBLAS(bool);
|
| 343 |
+
at::QEngine qEngine() const;
|
| 344 |
+
void setQEngine(at::QEngine e);
|
| 345 |
+
static const std::vector<at::QEngine>& supportedQEngines();
|
| 346 |
+
static bool isXNNPACKAvailable();
|
| 347 |
+
void setCheckSparseTensorInvariants(bool e);
|
| 348 |
+
bool checkSparseTensorInvariants() const;
|
| 349 |
+
// This method is used to release the original weight after pre-packing.
|
| 350 |
+
// It should be called once before loading/running the model.
|
| 351 |
+
// NB: By default it is set to true for mobile builds.
|
| 352 |
+
void setReleaseWeightsWhenPrepacking(bool e);
|
| 353 |
+
bool releaseWeightsWhenPrepacking() const;
|
| 354 |
+
|
| 355 |
+
void setDisplayVmapFallbackWarnings(bool enabled);
|
| 356 |
+
bool areVmapFallbackWarningsEnabled() const;
|
| 357 |
+
|
| 358 |
+
void setDefaultMobileCPUAllocator();
|
| 359 |
+
void unsetDefaultMobileCPUAllocator();
|
| 360 |
+
bool allowFP16ReductionCPU() const;
|
| 361 |
+
void setAllowFP16ReductionCPU(bool);
|
| 362 |
+
|
| 363 |
+
private:
|
| 364 |
+
void initCUDAIfNeeded(c10::DeviceType p) {
|
| 365 |
+
if (p == c10::DeviceType::CUDA) {
|
| 366 |
+
lazyInitCUDA();
|
| 367 |
+
}
|
| 368 |
+
}
|
| 369 |
+
void initHIPIfNeeded(c10::DeviceType p) {
|
| 370 |
+
if (p == c10::DeviceType::HIP) {
|
| 371 |
+
lazyInitHIP();
|
| 372 |
+
}
|
| 373 |
+
}
|
| 374 |
+
void initXPUIfNeeded(c10::DeviceType p) {
|
| 375 |
+
if (p == c10::DeviceType::XPU) {
|
| 376 |
+
lazyInitXPU();
|
| 377 |
+
}
|
| 378 |
+
}
|
| 379 |
+
static bool checkCuBLASConfigDeterministic();
|
| 380 |
+
c10::once_flag thc_init;
|
| 381 |
+
c10::once_flag thh_init;
|
| 382 |
+
c10::once_flag thx_init;
|
| 383 |
+
c10::once_flag th_mtia_init;
|
| 384 |
+
c10::once_flag thp_init;
|
| 385 |
+
bool enabled_cudnn = true;
|
| 386 |
+
bool deterministic_cudnn = false;
|
| 387 |
+
bool deterministic_mkldnn = false;
|
| 388 |
+
bool _deterministic_algorithms = false;
|
| 389 |
+
bool _deterministic_algorithms_warn_only = false;
|
| 390 |
+
bool _deterministic_fill_uninitialized_memory = true;
|
| 391 |
+
bool enabled_flashSDP = true;
|
| 392 |
+
bool enabled_mem_efficientSDP = true;
|
| 393 |
+
bool enabled_mathSDP = true;
|
| 394 |
+
bool enabled_cudnnSDP = true;
|
| 395 |
+
bool enabled_overrideable = true;
|
| 396 |
+
bool allow_fp16_bf16_reduction_mathSDP = false;
|
| 397 |
+
#ifdef USE_ROCM
|
| 398 |
+
bool benchmark_cudnn = true;
|
| 399 |
+
#else
|
| 400 |
+
bool benchmark_cudnn = false;
|
| 401 |
+
#endif
|
| 402 |
+
Float32MatmulPrecision float32_matmul_precision =
|
| 403 |
+
c10::utils::check_env("TORCH_ALLOW_TF32_CUBLAS_OVERRIDE") == true
|
| 404 |
+
? at::Float32MatmulPrecision::HIGH
|
| 405 |
+
: at::Float32MatmulPrecision::HIGHEST;
|
| 406 |
+
int benchmark_limit_cudnn = 10;
|
| 407 |
+
bool allow_tf32_cudnn = true;
|
| 408 |
+
bool allow_fp16_reduction_cublas = true;
|
| 409 |
+
bool allow_bf16_reduction_cublas = true;
|
| 410 |
+
bool enabled_mkldnn = true;
|
| 411 |
+
bool enabled_nnpack = true;
|
| 412 |
+
at::LinalgBackend linalg_preferred_backend =
|
| 413 |
+
c10::utils::check_env("TORCH_LINALG_PREFER_CUSOLVER") == true
|
| 414 |
+
? at::LinalgBackend::Cusolver
|
| 415 |
+
: at::LinalgBackend::Default;
|
| 416 |
+
at::BlasBackend blas_preferred_backend =
|
| 417 |
+
#ifdef USE_ROCM
|
| 418 |
+
(c10::utils::check_env("TORCH_BLAS_PREFER_HIPBLASLT") != false)
|
| 419 |
+
#else
|
| 420 |
+
(c10::utils::check_env("TORCH_BLAS_PREFER_CUBLASLT") == true)
|
| 421 |
+
#endif
|
| 422 |
+
? at::BlasBackend::Cublaslt
|
| 423 |
+
: at::BlasBackend::Cublas;
|
| 424 |
+
#ifdef C10_MOBILE
|
| 425 |
+
bool release_original_weights = true;
|
| 426 |
+
#else
|
| 427 |
+
bool release_original_weights = false;
|
| 428 |
+
#endif
|
| 429 |
+
bool display_vmap_fallback_warnings_ = false;
|
| 430 |
+
std::optional<at::QEngine> quantized_engine = std::nullopt;
|
| 431 |
+
bool enable_sparse_tensor_invariant_checks = false;
|
| 432 |
+
bool allow_fp16_reduction_cpu = false;
|
| 433 |
+
|
| 434 |
+
Allocator* prev_allocator_ptr_{nullptr};
|
| 435 |
+
};
|
| 436 |
+
|
| 437 |
+
TORCH_API Context& globalContext();
|
| 438 |
+
|
| 439 |
+
inline void init() {
|
| 440 |
+
globalContext();
|
| 441 |
+
}
|
| 442 |
+
|
| 443 |
+
TORCH_API Allocator* getCPUAllocator();
|
| 444 |
+
|
| 445 |
+
inline DeprecatedTypeProperties& getDeprecatedTypeProperties(
|
| 446 |
+
Backend p,
|
| 447 |
+
ScalarType s) {
|
| 448 |
+
return globalDeprecatedTypePropertiesRegistry().getDeprecatedTypeProperties(
|
| 449 |
+
p, s);
|
| 450 |
+
}
|
| 451 |
+
|
| 452 |
+
inline DeprecatedTypeProperties& CPU(ScalarType s) {
|
| 453 |
+
return globalDeprecatedTypePropertiesRegistry().getDeprecatedTypeProperties(
|
| 454 |
+
Backend::CPU, s);
|
| 455 |
+
}
|
| 456 |
+
|
| 457 |
+
inline DeprecatedTypeProperties& CUDA(ScalarType s) {
|
| 458 |
+
return globalDeprecatedTypePropertiesRegistry().getDeprecatedTypeProperties(
|
| 459 |
+
Backend::CUDA, s);
|
| 460 |
+
}
|
| 461 |
+
|
| 462 |
+
inline DeprecatedTypeProperties& HIP(ScalarType s) {
|
| 463 |
+
return globalDeprecatedTypePropertiesRegistry().getDeprecatedTypeProperties(
|
| 464 |
+
Backend::HIP, s);
|
| 465 |
+
}
|
| 466 |
+
|
| 467 |
+
inline DeprecatedTypeProperties& MPS(ScalarType s) {
|
| 468 |
+
return globalDeprecatedTypePropertiesRegistry().getDeprecatedTypeProperties(
|
| 469 |
+
Backend::MPS, s);
|
| 470 |
+
}
|
| 471 |
+
|
| 472 |
+
inline bool hasCUDA() {
|
| 473 |
+
return globalContext().hasCUDA();
|
| 474 |
+
}
|
| 475 |
+
|
| 476 |
+
inline bool hasMTIA() {
|
| 477 |
+
return globalContext().hasMTIA();
|
| 478 |
+
}
|
| 479 |
+
|
| 480 |
+
inline bool hasHIP() {
|
| 481 |
+
return globalContext().hasHIP();
|
| 482 |
+
}
|
| 483 |
+
|
| 484 |
+
inline bool hasIPU() {
|
| 485 |
+
return globalContext().hasIPU();
|
| 486 |
+
}
|
| 487 |
+
|
| 488 |
+
inline bool hasXLA() {
|
| 489 |
+
return globalContext().hasXLA();
|
| 490 |
+
}
|
| 491 |
+
|
| 492 |
+
inline bool hasMPS() {
|
| 493 |
+
return globalContext().hasMPS();
|
| 494 |
+
}
|
| 495 |
+
|
| 496 |
+
inline bool hasMAIA() {
|
| 497 |
+
return globalContext().hasMAIA();
|
| 498 |
+
}
|
| 499 |
+
|
| 500 |
+
inline bool hasXPU() {
|
| 501 |
+
return globalContext().hasXPU();
|
| 502 |
+
}
|
| 503 |
+
|
| 504 |
+
// Despite its name, this function returns the number of *CUDA* GPUs.
|
| 505 |
+
inline size_t getNumGPUs() {
|
| 506 |
+
// WARNING: DO NOT ADD LOGIC TO HANDLE OTHER DEVICE TYPES TO THIS
|
| 507 |
+
// FUNCTION. If you are interested in interrogating the number of
|
| 508 |
+
// devices for a specific device type, add that function to the
|
| 509 |
+
// relevant library (e.g., similar to at::cuda::device_count())
|
| 510 |
+
if (hasCUDA() && hasHIP()) {
|
| 511 |
+
throw std::runtime_error(
|
| 512 |
+
"Enabling both CUDA and HIP in ATen is not supported, as HIP masquerades "
|
| 513 |
+
"to be CUDA (e.g., when you say CUDA, on a HIP build of ATen, this actually "
|
| 514 |
+
"means HIP. Rebuild PyTorch with one or the other disabled.");
|
| 515 |
+
} else if (hasCUDA()) {
|
| 516 |
+
return detail::getCUDAHooks().getNumGPUs();
|
| 517 |
+
} else if (hasHIP()) {
|
| 518 |
+
return detail::getHIPHooks().getNumGPUs();
|
| 519 |
+
} else {
|
| 520 |
+
return 0;
|
| 521 |
+
}
|
| 522 |
+
}
|
| 523 |
+
|
| 524 |
+
inline bool hasOpenMP() {
|
| 525 |
+
return globalContext().hasOpenMP();
|
| 526 |
+
}
|
| 527 |
+
|
| 528 |
+
inline bool hasMKL() {
|
| 529 |
+
return globalContext().hasMKL();
|
| 530 |
+
}
|
| 531 |
+
|
| 532 |
+
inline bool hasLAPACK() {
|
| 533 |
+
return globalContext().hasLAPACK();
|
| 534 |
+
}
|
| 535 |
+
|
| 536 |
+
inline bool hasMAGMA() {
|
| 537 |
+
return globalContext().hasMAGMA();
|
| 538 |
+
}
|
| 539 |
+
|
| 540 |
+
inline bool hasMKLDNN() {
|
| 541 |
+
return globalContext().hasMKLDNN();
|
| 542 |
+
}
|
| 543 |
+
|
| 544 |
+
inline void manual_seed(uint64_t seed) {
|
| 545 |
+
auto gen = globalContext().defaultGenerator(c10::DeviceType::CPU);
|
| 546 |
+
{
|
| 547 |
+
// See Note [Acquire lock when using random generators]
|
| 548 |
+
std::lock_guard<std::mutex> lock(gen.mutex());
|
| 549 |
+
gen.set_current_seed(seed);
|
| 550 |
+
}
|
| 551 |
+
// NB: Sometimes we build with CUDA, but we don't have any GPUs
|
| 552 |
+
// available. In that case, we must not seed CUDA; it will fail!
|
| 553 |
+
const auto cuda_num_gpus = detail::getCUDAHooks().getNumGPUs();
|
| 554 |
+
if (hasCUDA() && cuda_num_gpus > 0) {
|
| 555 |
+
for (const auto i : c10::irange(cuda_num_gpus)) {
|
| 556 |
+
auto cuda_gen = globalContext().defaultGenerator(
|
| 557 |
+
Device(at::kCUDA, static_cast<c10::DeviceIndex>(i)));
|
| 558 |
+
{
|
| 559 |
+
// See Note [Acquire lock when using random generators]
|
| 560 |
+
std::lock_guard<std::mutex> lock(cuda_gen.mutex());
|
| 561 |
+
cuda_gen.set_current_seed(seed);
|
| 562 |
+
}
|
| 563 |
+
}
|
| 564 |
+
}
|
| 565 |
+
|
| 566 |
+
const auto xpu_num_gpus = detail::getXPUHooks().getNumGPUs();
|
| 567 |
+
if (hasXPU() && xpu_num_gpus) {
|
| 568 |
+
for (const auto i : c10::irange(xpu_num_gpus)) {
|
| 569 |
+
auto xpu_gen = globalContext().defaultGenerator(
|
| 570 |
+
Device(at::kXPU, static_cast<c10::DeviceIndex>(i)));
|
| 571 |
+
{
|
| 572 |
+
// See Note [Acquire lock when using random generators]
|
| 573 |
+
std::lock_guard<std::mutex> lock(xpu_gen.mutex());
|
| 574 |
+
xpu_gen.set_current_seed(seed);
|
| 575 |
+
}
|
| 576 |
+
}
|
| 577 |
+
}
|
| 578 |
+
|
| 579 |
+
if (hasMPS()) {
|
| 580 |
+
auto mps_gen = globalContext().defaultGenerator(c10::DeviceType::MPS);
|
| 581 |
+
// See Note [Acquire lock when using random generators]
|
| 582 |
+
std::lock_guard<std::mutex> lock(mps_gen.mutex());
|
| 583 |
+
mps_gen.set_current_seed(seed);
|
| 584 |
+
}
|
| 585 |
+
}
|
| 586 |
+
|
| 587 |
+
// When the global flag `allow_tf32` is set to true, cuBLAS handles are
|
| 588 |
+
// automatically configured to use math mode CUBLAS_TF32_TENSOR_OP_MATH.
|
| 589 |
+
// For some operators, such as addmv, TF32 offers no performance improvement
|
| 590 |
+
// but causes precision loss. To help this case, this class implements
|
| 591 |
+
// a RAII guard that can be used to quickly disable TF32 within its scope.
|
| 592 |
+
//
|
| 593 |
+
// Usage:
|
| 594 |
+
// NoTF32Guard disable_tf32;
|
| 595 |
+
struct TORCH_API NoTF32Guard {
|
| 596 |
+
NoTF32Guard();
|
| 597 |
+
~NoTF32Guard();
|
| 598 |
+
static bool should_disable_tf32();
|
| 599 |
+
|
| 600 |
+
private:
|
| 601 |
+
bool changed = false;
|
| 602 |
+
};
|
| 603 |
+
|
| 604 |
+
struct TORCH_API ROCmBackwardPassGuard {
|
| 605 |
+
ROCmBackwardPassGuard();
|
| 606 |
+
~ROCmBackwardPassGuard();
|
| 607 |
+
static bool is_backward_pass();
|
| 608 |
+
};
|
| 609 |
+
|
| 610 |
+
} // namespace at
|
.venv/Lib/site-packages/torch/include/ATen/Device.h
ADDED
|
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
#include <c10/core/Device.h>
|
.venv/Lib/site-packages/torch/include/ATen/DeviceAccelerator.h
ADDED
|
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#include <c10/core/DeviceType.h>
|
| 4 |
+
#include <c10/macros/Macros.h>
|
| 5 |
+
|
| 6 |
+
#include <ATen/detail/MTIAHooksInterface.h>
|
| 7 |
+
#include <optional>
|
| 8 |
+
|
| 9 |
+
// This file defines the top level Accelerator concept for PyTorch.
|
| 10 |
+
// A device is an accelerator per the definition here if:
|
| 11 |
+
// - It is mutually exclusive with all other accelerators
|
| 12 |
+
// - It performs asynchronous compute via a Stream/Event system
|
| 13 |
+
// - It provides a set of common APIs as defined by AcceleratorHooksInterface
|
| 14 |
+
//
|
| 15 |
+
// As of today, accelerator devices are (in no particular order):
|
| 16 |
+
// CUDA, MTIA, XPU, HIP, MPS, PrivateUse1
|
| 17 |
+
|
| 18 |
+
namespace at {
|
| 19 |
+
|
| 20 |
+
// Ensures that only one accelerator is available (at
|
| 21 |
+
// compile time if possible) and return it.
|
| 22 |
+
// When checked is true, the returned optional always has a value.
|
| 23 |
+
TORCH_API std::optional<c10::DeviceType> getAccelerator(bool checked = false);
|
| 24 |
+
|
| 25 |
+
TORCH_API bool isAccelerator(c10::DeviceType d);
|
| 26 |
+
|
| 27 |
+
} // namespace at
|
.venv/Lib/site-packages/torch/include/ATen/DeviceGuard.h
ADDED
|
@@ -0,0 +1,41 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#include <ATen/core/IListRef.h>
|
| 4 |
+
#include <ATen/core/Tensor.h>
|
| 5 |
+
#include <c10/core/DeviceGuard.h>
|
| 6 |
+
#include <c10/core/ScalarType.h> // TensorList whyyyyy
|
| 7 |
+
|
| 8 |
+
namespace at {
|
| 9 |
+
|
| 10 |
+
// Are you here because you're wondering why DeviceGuard(tensor) no
|
| 11 |
+
// longer works? For code organization reasons, we have temporarily(?)
|
| 12 |
+
// removed this constructor from DeviceGuard. The new way to
|
| 13 |
+
// spell it is:
|
| 14 |
+
//
|
| 15 |
+
// OptionalDeviceGuard guard(device_of(tensor));
|
| 16 |
+
|
| 17 |
+
/// Return the Device of a Tensor, if the Tensor is defined.
|
| 18 |
+
inline std::optional<Device> device_of(const Tensor& t) {
|
| 19 |
+
if (t.defined()) {
|
| 20 |
+
return std::make_optional(t.device());
|
| 21 |
+
} else {
|
| 22 |
+
return std::nullopt;
|
| 23 |
+
}
|
| 24 |
+
}
|
| 25 |
+
|
| 26 |
+
inline std::optional<Device> device_of(const std::optional<Tensor>& t) {
|
| 27 |
+
return t.has_value() ? device_of(t.value()) : std::nullopt;
|
| 28 |
+
}
|
| 29 |
+
|
| 30 |
+
/// Return the Device of a TensorList, if the list is non-empty and
|
| 31 |
+
/// the first Tensor is defined. (This function implicitly assumes
|
| 32 |
+
/// that all tensors in the list have the same device.)
|
| 33 |
+
inline std::optional<Device> device_of(ITensorListRef t) {
|
| 34 |
+
if (!t.empty()) {
|
| 35 |
+
return device_of(t.front());
|
| 36 |
+
} else {
|
| 37 |
+
return std::nullopt;
|
| 38 |
+
}
|
| 39 |
+
}
|
| 40 |
+
|
| 41 |
+
} // namespace at
|
.venv/Lib/site-packages/torch/include/ATen/DimVector.h
ADDED
|
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
#include <ATen/core/DimVector.h>
|
.venv/Lib/site-packages/torch/include/ATen/Dimname.h
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
#include <ATen/core/Dimname.h>
|
.venv/Lib/site-packages/torch/include/ATen/Dispatch.h
ADDED
|
@@ -0,0 +1,808 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#include <ATen/core/DeprecatedTypeProperties.h>
|
| 4 |
+
#include <c10/macros/Macros.h>
|
| 5 |
+
#include <c10/util/Exception.h>
|
| 6 |
+
#include <c10/util/Half.h>
|
| 7 |
+
#include <c10/util/Metaprogramming.h>
|
| 8 |
+
#include <c10/util/complex.h>
|
| 9 |
+
#include <c10/util/string_view.h>
|
| 10 |
+
|
| 11 |
+
#ifdef __CUDACC__
|
| 12 |
+
#include <cuda.h> // For CUDA_VERSION
|
| 13 |
+
#endif
|
| 14 |
+
|
| 15 |
+
#ifdef TEMPLATE_SELECTIVE_BUILD
|
| 16 |
+
#include <ATen/selected_mobile_ops.h>
|
| 17 |
+
#else
|
| 18 |
+
namespace at {
|
| 19 |
+
/**
|
| 20 |
+
* The method should_include_kernel_dtype() returns true/false
|
| 21 |
+
* based on whether the switching code for a specific dtype should be
|
| 22 |
+
* included based on build time constants generated from tracing model
|
| 23 |
+
* execution. This method will be implemented via code-generation and
|
| 24 |
+
* included in this file when code-gen is ready.
|
| 25 |
+
*/
|
| 26 |
+
inline constexpr bool should_include_kernel_dtype(
|
| 27 |
+
const char* /*kernel_tag_str*/,
|
| 28 |
+
at::ScalarType /*scalar_type*/
|
| 29 |
+
) {
|
| 30 |
+
return true;
|
| 31 |
+
}
|
| 32 |
+
} // namespace at
|
| 33 |
+
#endif
|
| 34 |
+
|
| 35 |
+
/**
|
| 36 |
+
* In the Facebook internal build (using BUCK), this macro is enabled by
|
| 37 |
+
* passing in -c pt.enable_record_kernel_dtype=1 when building the tracer
|
| 38 |
+
* binary.
|
| 39 |
+
*/
|
| 40 |
+
#if defined ENABLE_RECORD_KERNEL_FUNCTION_DTYPE
|
| 41 |
+
namespace at {
|
| 42 |
+
namespace detail {
|
| 43 |
+
TORCH_API void record_kernel_function_dtype(std::string name);
|
| 44 |
+
}
|
| 45 |
+
} // namespace at
|
| 46 |
+
|
| 47 |
+
#define RECORD_KERNEL_FUNCTION_DTYPE(NAME, enum_type) \
|
| 48 |
+
at::detail::record_kernel_function_dtype( \
|
| 49 |
+
std::string(NAME) + "$" + toString(enum_type));
|
| 50 |
+
#else
|
| 51 |
+
#define RECORD_KERNEL_FUNCTION_DTYPE(NAME, enum_type)
|
| 52 |
+
#endif
|
| 53 |
+
|
| 54 |
+
#define AT_PRIVATE_CHECK_SELECTIVE_BUILD(enum_type) \
|
| 55 |
+
do { \
|
| 56 |
+
if constexpr (!at::should_include_kernel_dtype( \
|
| 57 |
+
at_dispatch_name, enum_type)) { \
|
| 58 |
+
AT_ERROR( \
|
| 59 |
+
"dtype '", \
|
| 60 |
+
toString(enum_type), \
|
| 61 |
+
"' not selected for kernel tag ", \
|
| 62 |
+
at_dispatch_name); \
|
| 63 |
+
} \
|
| 64 |
+
} while (0)
|
| 65 |
+
|
| 66 |
+
#define AT_PRIVATE_CASE_TYPE_USING_HINT(enum_type, HINT, ...) \
|
| 67 |
+
case enum_type: { \
|
| 68 |
+
AT_PRIVATE_CHECK_SELECTIVE_BUILD(enum_type); \
|
| 69 |
+
using HINT C10_UNUSED = c10::impl::ScalarTypeToCPPTypeT<enum_type>; \
|
| 70 |
+
return __VA_ARGS__(); \
|
| 71 |
+
}
|
| 72 |
+
|
| 73 |
+
#define AT_DISPATCH_CASE(enum_type, ...) \
|
| 74 |
+
AT_PRIVATE_CASE_TYPE_USING_HINT(enum_type, scalar_t, __VA_ARGS__)
|
| 75 |
+
|
| 76 |
+
#define AT_DISPATCH_CASE_QINT(enum_type, scalar_type, ...) \
|
| 77 |
+
case enum_type: { \
|
| 78 |
+
AT_PRIVATE_CHECK_SELECTIVE_BUILD(enum_type); \
|
| 79 |
+
using scalar_t = scalar_type; \
|
| 80 |
+
using underlying_t C10_UNUSED = typename scalar_t::underlying; \
|
| 81 |
+
const auto& SCALAR_TYPE C10_UNUSED = enum_type; \
|
| 82 |
+
const auto& UNDERLYING_TYPE C10_UNUSED = toUnderlying(enum_type); \
|
| 83 |
+
return __VA_ARGS__(); \
|
| 84 |
+
}
|
| 85 |
+
|
| 86 |
+
#define AT_QINT_SUB_BYTE_PRIVATE_CASE_TYPE( \
|
| 87 |
+
enum_type, scalar_type, bitwidth, qmin, qmax, ...) \
|
| 88 |
+
case enum_type: { \
|
| 89 |
+
AT_PRIVATE_CHECK_SELECTIVE_BUILD(enum_type); \
|
| 90 |
+
using scalar_t = scalar_type; \
|
| 91 |
+
using underlying_t C10_UNUSED = typename scalar_t::underlying; \
|
| 92 |
+
const auto& SCALAR_TYPE C10_UNUSED = enum_type; \
|
| 93 |
+
const auto& UNDERLYING_TYPE C10_UNUSED = toUnderlying(enum_type); \
|
| 94 |
+
C10_UNUSED int bit_width = bitwidth; \
|
| 95 |
+
C10_UNUSED int64_t quant_min = qmin; \
|
| 96 |
+
C10_UNUSED int64_t quant_max = qmax; \
|
| 97 |
+
return __VA_ARGS__(); \
|
| 98 |
+
}
|
| 99 |
+
|
| 100 |
+
namespace detail {
|
| 101 |
+
|
| 102 |
+
inline at::ScalarType scalar_type(at::ScalarType s) {
|
| 103 |
+
return s;
|
| 104 |
+
}
|
| 105 |
+
|
| 106 |
+
C10_DEPRECATED_MESSAGE(
|
| 107 |
+
"passing at::DeprecatedTypeProperties to an AT_DISPATCH macro is deprecated, "
|
| 108 |
+
"pass an at::ScalarType instead")
|
| 109 |
+
inline at::ScalarType scalar_type(const at::DeprecatedTypeProperties& t) {
|
| 110 |
+
return t.scalarType();
|
| 111 |
+
}
|
| 112 |
+
|
| 113 |
+
C10_DEPRECATED_MESSAGE(
|
| 114 |
+
"AT_DISPATCH_ALL_TYPES_AND_HALF is deprecated, "
|
| 115 |
+
"use AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, ...) instead")
|
| 116 |
+
inline void deprecated_AT_DISPATCH_ALL_TYPES_AND_HALF() {}
|
| 117 |
+
|
| 118 |
+
C10_DEPRECATED_MESSAGE(
|
| 119 |
+
"AT_DISPATCH_ALL_TYPES_AND_HALF_AND_COMPLEX is deprecated, "
|
| 120 |
+
"use AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND(at::ScalarType::Half, ...) "
|
| 121 |
+
"instead")
|
| 122 |
+
inline void deprecated_AT_DISPATCH_ALL_TYPES_AND_HALF_AND_COMPLEX() {}
|
| 123 |
+
|
| 124 |
+
} // namespace detail
|
| 125 |
+
|
| 126 |
+
// The AT_DISPATCH_* family of macros provides the ability to
|
| 127 |
+
// conveniently generate specializations of a kernel over all of the
|
| 128 |
+
// dtypes we care about in PyTorch. We call it "dispatch" because
|
| 129 |
+
// we are "dispatching" to the correct, dtype-specific kernel.
|
| 130 |
+
//
|
| 131 |
+
// A standard usage looks like:
|
| 132 |
+
//
|
| 133 |
+
// AT_DISPATCH_ALL_TYPES(self.scalar_type(), "op_name", [&] {
|
| 134 |
+
// // Your code here, with 'scalar_t' now defined to
|
| 135 |
+
// // be the dtype in question
|
| 136 |
+
// });
|
| 137 |
+
//
|
| 138 |
+
// There are many variations of this macro, so it's important to
|
| 139 |
+
// understand exactly /which/ dtypes you want to get instantiated, as
|
| 140 |
+
// well as what the "default" set is.
|
| 141 |
+
//
|
| 142 |
+
// The default set of dtypes that are instantiated (e.g., by
|
| 143 |
+
// AT_DISPATCH_ALL_TYPES) are floating point types (float, double),
|
| 144 |
+
// and integral types (int32_t, int64_t, int16_t, int8_t, uint8_t),
|
| 145 |
+
// but NOT booleans (bool), half-precision floats (Half) or
|
| 146 |
+
// complex number (c10::complex<float>, c10::complex<double>).
|
| 147 |
+
// This "cut" is somewhat historical (the default types are the
|
| 148 |
+
// ones that TH historically supported), but it also reflects the
|
| 149 |
+
// fact that the non-default types are "poorly" behaved (booleans
|
| 150 |
+
// are NOT integers mod 2, half precision operations ~essentially
|
| 151 |
+
// don't exist on CPU, complex numbers are an experimental application).
|
| 152 |
+
//
|
| 153 |
+
// Here are the questions you should generally ask to decide which
|
| 154 |
+
// dispatch you want:
|
| 155 |
+
//
|
| 156 |
+
// 1. Is this an integral or floating point specific operation?
|
| 157 |
+
// (If so, you'll want one of the FLOATING or INTEGRAL macros.)
|
| 158 |
+
//
|
| 159 |
+
// 2. Should half be supported? (If you're on CPU, the answer is almost
|
| 160 |
+
// definitely no. If you do want support, use one of the AND_HALF
|
| 161 |
+
// macros)
|
| 162 |
+
//
|
| 163 |
+
// Much rarer situations:
|
| 164 |
+
//
|
| 165 |
+
// 3. Should bool be supported? (You often have to write your kernel
|
| 166 |
+
// differently if arithmetic operations are involved.) If so,
|
| 167 |
+
// Use AT_DISPATCH_ALL_TYPES_AND along with ScalarType::Bool
|
| 168 |
+
//
|
| 169 |
+
// 4. Should complex be supported? The answer is almost always no,
|
| 170 |
+
// unless you are working on "generic" code that should work on
|
| 171 |
+
// all dtypes.
|
| 172 |
+
//
|
| 173 |
+
// Parameters:
|
| 174 |
+
// -----------
|
| 175 |
+
//
|
| 176 |
+
// 1. The NAME argument is a "tag" that is used to trace and then
|
| 177 |
+
// conditionally compile fragments of the case statements such
|
| 178 |
+
// that the kernel functions are specialized only for the dtypes
|
| 179 |
+
// that are needed. The NAME parameter *must* be a build time
|
| 180 |
+
// const char* (can't be std::string, etc...)
|
| 181 |
+
//
|
| 182 |
+
// Please ensure that the NAME is unique for every implementation
|
| 183 |
+
// or you run the risk of over-including code for the kernel
|
| 184 |
+
// functions. There is no risk of missing out on any code, so
|
| 185 |
+
// it's mostly a risk of a Type-2 error, and not a Type-1 error.
|
| 186 |
+
//
|
| 187 |
+
// Switch-like syntax:
|
| 188 |
+
// -------------------
|
| 189 |
+
// There is also a switch-case like syntax which is useful if a kernel
|
| 190 |
+
// needs to be specialized for particular scalar types
|
| 191 |
+
//
|
| 192 |
+
// AT_DISPATCH_SWITCH(self.scalar_type(), "op_name",
|
| 193 |
+
// AT_DISPATCH_CASE_INTEGRAL_TYPES([&] {
|
| 194 |
+
// op_integral<scalar_t>(iter);
|
| 195 |
+
// })
|
| 196 |
+
// AT_DISPATCH_CASE_FLOATING_TYPES([&] {
|
| 197 |
+
// op_floating<scalar_t>(iter);
|
| 198 |
+
// })
|
| 199 |
+
// AT_DISPATCH_CASE(kBool, [&] {
|
| 200 |
+
// op_bool(iter);
|
| 201 |
+
// })
|
| 202 |
+
// );
|
| 203 |
+
//
|
| 204 |
+
// For each AT_DISPATCH_FOO macro, there is a corresponding
|
| 205 |
+
// AT_DISPATCH_CASE_FOO macro which can be used inside of an
|
| 206 |
+
// AT_DISPATCH_SWITCH block.
|
| 207 |
+
|
| 208 |
+
// NB: the the_type variable is not used, but we have kept it for
|
| 209 |
+
// backwards compatibility. It's probably not used by anyone though;
|
| 210 |
+
// but we're just being safe (and it doesn't hurt.) Note we must
|
| 211 |
+
// use it to shut up warnings about unused store.
|
| 212 |
+
|
| 213 |
+
#define AT_DISPATCH_SWITCH(TYPE, NAME, ...) \
|
| 214 |
+
[&] { \
|
| 215 |
+
const auto& the_type = TYPE; \
|
| 216 |
+
constexpr const char* at_dispatch_name = NAME; \
|
| 217 |
+
/* don't use TYPE again in case it is an expensive or side-effect op */ \
|
| 218 |
+
at::ScalarType _st = ::detail::scalar_type(the_type); \
|
| 219 |
+
RECORD_KERNEL_FUNCTION_DTYPE(at_dispatch_name, _st); \
|
| 220 |
+
switch (_st) { \
|
| 221 |
+
__VA_ARGS__ \
|
| 222 |
+
default: \
|
| 223 |
+
AT_ERROR( \
|
| 224 |
+
'"', \
|
| 225 |
+
at_dispatch_name, \
|
| 226 |
+
"\" not implemented for '", \
|
| 227 |
+
toString(_st), \
|
| 228 |
+
"'"); \
|
| 229 |
+
} \
|
| 230 |
+
}()
|
| 231 |
+
|
| 232 |
+
#define AT_DISPATCH_CASE_FLOATING_TYPES(...) \
|
| 233 |
+
AT_DISPATCH_CASE(at::ScalarType::Double, __VA_ARGS__) \
|
| 234 |
+
AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__)
|
| 235 |
+
|
| 236 |
+
#define AT_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...) \
|
| 237 |
+
AT_DISPATCH_SWITCH(TYPE, NAME, AT_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__))
|
| 238 |
+
|
| 239 |
+
#define AT_DISPATCH_CASE_FLOATING_TYPES_AND_HALF(...) \
|
| 240 |
+
AT_DISPATCH_CASE(at::ScalarType::Double, __VA_ARGS__) \
|
| 241 |
+
AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \
|
| 242 |
+
AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__)
|
| 243 |
+
|
| 244 |
+
#define AT_DISPATCH_FLOATING_TYPES_AND_HALF(TYPE, NAME, ...) \
|
| 245 |
+
AT_DISPATCH_SWITCH( \
|
| 246 |
+
TYPE, NAME, AT_DISPATCH_CASE_FLOATING_TYPES_AND_HALF(__VA_ARGS__))
|
| 247 |
+
|
| 248 |
+
#define AT_DISPATCH_CASE_REDUCED_FLOATING_TYPES(...) \
|
| 249 |
+
AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__) \
|
| 250 |
+
AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__)
|
| 251 |
+
|
| 252 |
+
#define AT_DISPATCH_REDUCED_FLOATING_TYPES(TYPE, NAME, ...) \
|
| 253 |
+
AT_DISPATCH_SWITCH( \
|
| 254 |
+
TYPE, NAME, AT_DISPATCH_CASE_REDUCED_FLOATING_TYPES(__VA_ARGS__))
|
| 255 |
+
|
| 256 |
+
#define AT_DISPATCH_CASE_FLOATING_TYPES_AND(SCALARTYPE, ...) \
|
| 257 |
+
AT_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__) \
|
| 258 |
+
AT_DISPATCH_CASE(SCALARTYPE, __VA_ARGS__)
|
| 259 |
+
|
| 260 |
+
#define AT_DISPATCH_FLOATING_TYPES_AND(SCALARTYPE, TYPE, NAME, ...) \
|
| 261 |
+
AT_DISPATCH_SWITCH( \
|
| 262 |
+
TYPE, \
|
| 263 |
+
NAME, \
|
| 264 |
+
AT_DISPATCH_CASE_FLOATING_TYPES_AND(SCALARTYPE, __VA_ARGS__))
|
| 265 |
+
|
| 266 |
+
#define AT_DISPATCH_CASE_FLOATING_TYPES_AND2(SCALARTYPE1, SCALARTYPE2, ...) \
|
| 267 |
+
AT_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__) \
|
| 268 |
+
AT_DISPATCH_CASE(SCALARTYPE1, __VA_ARGS__) \
|
| 269 |
+
AT_DISPATCH_CASE(SCALARTYPE2, __VA_ARGS__)
|
| 270 |
+
|
| 271 |
+
#define AT_DISPATCH_FLOATING_TYPES_AND2( \
|
| 272 |
+
SCALARTYPE1, SCALARTYPE2, TYPE, NAME, ...) \
|
| 273 |
+
AT_DISPATCH_SWITCH( \
|
| 274 |
+
TYPE, \
|
| 275 |
+
NAME, \
|
| 276 |
+
AT_DISPATCH_CASE_FLOATING_TYPES_AND2( \
|
| 277 |
+
SCALARTYPE1, SCALARTYPE2, __VA_ARGS__))
|
| 278 |
+
|
| 279 |
+
#define AT_DISPATCH_CASE_FLOATING_TYPES_AND3( \
|
| 280 |
+
SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, ...) \
|
| 281 |
+
AT_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__) \
|
| 282 |
+
AT_DISPATCH_CASE(SCALARTYPE1, __VA_ARGS__) \
|
| 283 |
+
AT_DISPATCH_CASE(SCALARTYPE2, __VA_ARGS__) \
|
| 284 |
+
AT_DISPATCH_CASE(SCALARTYPE3, __VA_ARGS__)
|
| 285 |
+
|
| 286 |
+
#define AT_DISPATCH_FLOATING_TYPES_AND3( \
|
| 287 |
+
SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, TYPE, NAME, ...) \
|
| 288 |
+
AT_DISPATCH_SWITCH( \
|
| 289 |
+
TYPE, \
|
| 290 |
+
NAME, \
|
| 291 |
+
AT_DISPATCH_CASE_FLOATING_TYPES_AND3( \
|
| 292 |
+
SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, __VA_ARGS__))
|
| 293 |
+
|
| 294 |
+
#define AT_DISPATCH_CASE_FLOATING_TYPES_AND4( \
|
| 295 |
+
SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, SCALARTYPE4, ...) \
|
| 296 |
+
AT_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__) \
|
| 297 |
+
AT_DISPATCH_CASE(SCALARTYPE1, __VA_ARGS__) \
|
| 298 |
+
AT_DISPATCH_CASE(SCALARTYPE2, __VA_ARGS__) \
|
| 299 |
+
AT_DISPATCH_CASE(SCALARTYPE3, __VA_ARGS__) \
|
| 300 |
+
AT_DISPATCH_CASE(SCALARTYPE4, __VA_ARGS__)
|
| 301 |
+
|
| 302 |
+
#define AT_DISPATCH_FLOATING_TYPES_AND4( \
|
| 303 |
+
SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, SCALARTYPE4, TYPE, NAME, ...) \
|
| 304 |
+
AT_DISPATCH_SWITCH( \
|
| 305 |
+
TYPE, \
|
| 306 |
+
NAME, \
|
| 307 |
+
AT_DISPATCH_CASE_FLOATING_TYPES_AND4( \
|
| 308 |
+
SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, SCALARTYPE4, __VA_ARGS__))
|
| 309 |
+
|
| 310 |
+
#define AT_DISPATCH_CASE_COMPLEX_TYPES(...) \
|
| 311 |
+
AT_DISPATCH_CASE(at::ScalarType::ComplexDouble, __VA_ARGS__) \
|
| 312 |
+
AT_DISPATCH_CASE(at::ScalarType::ComplexFloat, __VA_ARGS__)
|
| 313 |
+
|
| 314 |
+
#define AT_DISPATCH_COMPLEX_TYPES(TYPE, NAME, ...) \
|
| 315 |
+
AT_DISPATCH_SWITCH(TYPE, NAME, AT_DISPATCH_CASE_COMPLEX_TYPES(__VA_ARGS__))
|
| 316 |
+
|
| 317 |
+
#define AT_DISPATCH_CASE_COMPLEX_TYPES_AND(SCALARTYPE, ...) \
|
| 318 |
+
AT_DISPATCH_CASE_COMPLEX_TYPES(__VA_ARGS__) \
|
| 319 |
+
AT_DISPATCH_CASE(SCALARTYPE, __VA_ARGS__)
|
| 320 |
+
|
| 321 |
+
#define AT_DISPATCH_COMPLEX_TYPES_AND(SCALARTYPE, TYPE, NAME, ...) \
|
| 322 |
+
AT_DISPATCH_SWITCH( \
|
| 323 |
+
TYPE, NAME, AT_DISPATCH_CASE_COMPLEX_TYPES_AND(SCALARTYPE, __VA_ARGS__))
|
| 324 |
+
|
| 325 |
+
#define AT_DISPATCH_CASE_FLOATING_AND_COMPLEX_TYPES(...) \
|
| 326 |
+
AT_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__) \
|
| 327 |
+
AT_DISPATCH_CASE_COMPLEX_TYPES(__VA_ARGS__)
|
| 328 |
+
|
| 329 |
+
#define AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(TYPE, NAME, ...) \
|
| 330 |
+
AT_DISPATCH_SWITCH( \
|
| 331 |
+
TYPE, NAME, AT_DISPATCH_CASE_FLOATING_AND_COMPLEX_TYPES(__VA_ARGS__))
|
| 332 |
+
|
| 333 |
+
#define AT_DISPATCH_CASE_FLOATING_AND_COMPLEX_TYPES_AND1(SCALARTYPE, ...) \
|
| 334 |
+
AT_DISPATCH_CASE_FLOATING_AND_COMPLEX_TYPES(__VA_ARGS__) \
|
| 335 |
+
AT_DISPATCH_CASE(SCALARTYPE, __VA_ARGS__)
|
| 336 |
+
|
| 337 |
+
#define AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND1( \
|
| 338 |
+
SCALARTYPE, TYPE, NAME, ...) \
|
| 339 |
+
AT_DISPATCH_SWITCH( \
|
| 340 |
+
TYPE, \
|
| 341 |
+
NAME, \
|
| 342 |
+
AT_DISPATCH_CASE_FLOATING_AND_COMPLEX_TYPES_AND1( \
|
| 343 |
+
SCALARTYPE, __VA_ARGS__))
|
| 344 |
+
|
| 345 |
+
#define AT_DISPATCH_CASE_FLOATING_AND_COMPLEX_TYPES_AND2( \
|
| 346 |
+
SCALARTYPE1, SCALARTYPE2, ...) \
|
| 347 |
+
AT_DISPATCH_CASE_FLOATING_AND_COMPLEX_TYPES(__VA_ARGS__) \
|
| 348 |
+
AT_DISPATCH_CASE(SCALARTYPE1, __VA_ARGS__) \
|
| 349 |
+
AT_DISPATCH_CASE(SCALARTYPE2, __VA_ARGS__)
|
| 350 |
+
|
| 351 |
+
#define AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2( \
|
| 352 |
+
SCALARTYPE1, SCALARTYPE2, TYPE, NAME, ...) \
|
| 353 |
+
AT_DISPATCH_SWITCH( \
|
| 354 |
+
TYPE, \
|
| 355 |
+
NAME, \
|
| 356 |
+
AT_DISPATCH_CASE_FLOATING_AND_COMPLEX_TYPES_AND2( \
|
| 357 |
+
SCALARTYPE1, SCALARTYPE2, __VA_ARGS__))
|
| 358 |
+
|
| 359 |
+
#define AT_DISPATCH_CASE_FLOATING_AND_COMPLEX_TYPES_AND3( \
|
| 360 |
+
SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, ...) \
|
| 361 |
+
AT_DISPATCH_CASE_FLOATING_AND_COMPLEX_TYPES(__VA_ARGS__) \
|
| 362 |
+
AT_DISPATCH_CASE(SCALARTYPE1, __VA_ARGS__) \
|
| 363 |
+
AT_DISPATCH_CASE(SCALARTYPE2, __VA_ARGS__) \
|
| 364 |
+
AT_DISPATCH_CASE(SCALARTYPE3, __VA_ARGS__)
|
| 365 |
+
|
| 366 |
+
#define AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND3( \
|
| 367 |
+
SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, TYPE, NAME, ...) \
|
| 368 |
+
AT_DISPATCH_SWITCH( \
|
| 369 |
+
TYPE, \
|
| 370 |
+
NAME, \
|
| 371 |
+
AT_DISPATCH_CASE_FLOATING_AND_COMPLEX_TYPES_AND3( \
|
| 372 |
+
SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, __VA_ARGS__))
|
| 373 |
+
|
| 374 |
+
#define AT_DISPATCH_CASE_FLOATING_AND_COMPLEX_TYPES_AND4( \
|
| 375 |
+
SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, SCALARTYPE4, ...) \
|
| 376 |
+
AT_DISPATCH_CASE_FLOATING_AND_COMPLEX_TYPES(__VA_ARGS__) \
|
| 377 |
+
AT_DISPATCH_CASE(SCALARTYPE1, __VA_ARGS__) \
|
| 378 |
+
AT_DISPATCH_CASE(SCALARTYPE2, __VA_ARGS__) \
|
| 379 |
+
AT_DISPATCH_CASE(SCALARTYPE3, __VA_ARGS__) \
|
| 380 |
+
AT_DISPATCH_CASE(SCALARTYPE4, __VA_ARGS__)
|
| 381 |
+
|
| 382 |
+
#define AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND4( \
|
| 383 |
+
SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, SCALARTYPE4, TYPE, NAME, ...) \
|
| 384 |
+
AT_DISPATCH_SWITCH( \
|
| 385 |
+
TYPE, \
|
| 386 |
+
NAME, \
|
| 387 |
+
AT_DISPATCH_CASE_FLOATING_AND_COMPLEX_TYPES_AND4( \
|
| 388 |
+
SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, SCALARTYPE4, __VA_ARGS__))
|
| 389 |
+
|
| 390 |
+
#define AT_DISPATCH_CASE_FLOATING_AND_COMPLEX_TYPES_AND5( \
|
| 391 |
+
SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, SCALARTYPE4, SCALARTYPE5, ...) \
|
| 392 |
+
AT_DISPATCH_CASE_FLOATING_AND_COMPLEX_TYPES(__VA_ARGS__) \
|
| 393 |
+
AT_DISPATCH_CASE(SCALARTYPE1, __VA_ARGS__) \
|
| 394 |
+
AT_DISPATCH_CASE(SCALARTYPE2, __VA_ARGS__) \
|
| 395 |
+
AT_DISPATCH_CASE(SCALARTYPE3, __VA_ARGS__) \
|
| 396 |
+
AT_DISPATCH_CASE(SCALARTYPE4, __VA_ARGS__) \
|
| 397 |
+
AT_DISPATCH_CASE(SCALARTYPE5, __VA_ARGS__)
|
| 398 |
+
|
| 399 |
+
#define AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND5( \
|
| 400 |
+
SCALARTYPE1, \
|
| 401 |
+
SCALARTYPE2, \
|
| 402 |
+
SCALARTYPE3, \
|
| 403 |
+
SCALARTYPE4, \
|
| 404 |
+
SCALARTYPE5, \
|
| 405 |
+
TYPE, \
|
| 406 |
+
NAME, \
|
| 407 |
+
...) \
|
| 408 |
+
AT_DISPATCH_SWITCH( \
|
| 409 |
+
TYPE, \
|
| 410 |
+
NAME, \
|
| 411 |
+
AT_DISPATCH_CASE_FLOATING_AND_COMPLEX_TYPES_AND5( \
|
| 412 |
+
SCALARTYPE1, \
|
| 413 |
+
SCALARTYPE2, \
|
| 414 |
+
SCALARTYPE3, \
|
| 415 |
+
SCALARTYPE4, \
|
| 416 |
+
SCALARTYPE5, \
|
| 417 |
+
__VA_ARGS__))
|
| 418 |
+
|
| 419 |
+
#define AT_DISPATCH_CASE_FLOATING_AND_COMPLEX_TYPES_AND6( \
|
| 420 |
+
SCALARTYPE1, \
|
| 421 |
+
SCALARTYPE2, \
|
| 422 |
+
SCALARTYPE3, \
|
| 423 |
+
SCALARTYPE4, \
|
| 424 |
+
SCALARTYPE5, \
|
| 425 |
+
SCALARTYPE6, \
|
| 426 |
+
...) \
|
| 427 |
+
AT_DISPATCH_CASE_FLOATING_AND_COMPLEX_TYPES(__VA_ARGS__) \
|
| 428 |
+
AT_DISPATCH_CASE(SCALARTYPE1, __VA_ARGS__) \
|
| 429 |
+
AT_DISPATCH_CASE(SCALARTYPE2, __VA_ARGS__) \
|
| 430 |
+
AT_DISPATCH_CASE(SCALARTYPE3, __VA_ARGS__) \
|
| 431 |
+
AT_DISPATCH_CASE(SCALARTYPE4, __VA_ARGS__) \
|
| 432 |
+
AT_DISPATCH_CASE(SCALARTYPE5, __VA_ARGS__) \
|
| 433 |
+
AT_DISPATCH_CASE(SCALARTYPE6, __VA_ARGS__)
|
| 434 |
+
|
| 435 |
+
#define AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND6( \
|
| 436 |
+
SCALARTYPE1, \
|
| 437 |
+
SCALARTYPE2, \
|
| 438 |
+
SCALARTYPE3, \
|
| 439 |
+
SCALARTYPE4, \
|
| 440 |
+
SCALARTYPE5, \
|
| 441 |
+
SCALARTYPE6, \
|
| 442 |
+
TYPE, \
|
| 443 |
+
NAME, \
|
| 444 |
+
...) \
|
| 445 |
+
AT_DISPATCH_SWITCH( \
|
| 446 |
+
TYPE, \
|
| 447 |
+
NAME, \
|
| 448 |
+
AT_DISPATCH_CASE_FLOATING_AND_COMPLEX_TYPES_AND6( \
|
| 449 |
+
SCALARTYPE1, \
|
| 450 |
+
SCALARTYPE2, \
|
| 451 |
+
SCALARTYPE3, \
|
| 452 |
+
SCALARTYPE4, \
|
| 453 |
+
SCALARTYPE5, \
|
| 454 |
+
SCALARTYPE6, \
|
| 455 |
+
__VA_ARGS__))
|
| 456 |
+
|
| 457 |
+
#define AT_DISPATCH_CASE_INTEGRAL_TYPES(...) \
|
| 458 |
+
AT_DISPATCH_CASE(at::ScalarType::Byte, __VA_ARGS__) \
|
| 459 |
+
AT_DISPATCH_CASE(at::ScalarType::Char, __VA_ARGS__) \
|
| 460 |
+
AT_DISPATCH_CASE(at::ScalarType::Int, __VA_ARGS__) \
|
| 461 |
+
AT_DISPATCH_CASE(at::ScalarType::Long, __VA_ARGS__) \
|
| 462 |
+
AT_DISPATCH_CASE(at::ScalarType::Short, __VA_ARGS__)
|
| 463 |
+
|
| 464 |
+
#define AT_DISPATCH_INTEGRAL_TYPES(TYPE, NAME, ...) \
|
| 465 |
+
AT_DISPATCH_SWITCH(TYPE, NAME, AT_DISPATCH_CASE_INTEGRAL_TYPES(__VA_ARGS__))
|
| 466 |
+
|
| 467 |
+
#define AT_DISPATCH_CASE_INTEGRAL_TYPES_AND(SCALARTYPE, ...) \
|
| 468 |
+
AT_DISPATCH_CASE_INTEGRAL_TYPES(__VA_ARGS__) \
|
| 469 |
+
AT_DISPATCH_CASE(SCALARTYPE, __VA_ARGS__)
|
| 470 |
+
|
| 471 |
+
#define AT_DISPATCH_INTEGRAL_TYPES_AND(SCALARTYPE, TYPE, NAME, ...) \
|
| 472 |
+
AT_DISPATCH_SWITCH( \
|
| 473 |
+
TYPE, \
|
| 474 |
+
NAME, \
|
| 475 |
+
AT_DISPATCH_CASE_INTEGRAL_TYPES_AND(SCALARTYPE, __VA_ARGS__))
|
| 476 |
+
|
| 477 |
+
#define AT_DISPATCH_CASE_ALL_TYPES(...) \
|
| 478 |
+
AT_DISPATCH_CASE_INTEGRAL_TYPES(__VA_ARGS__) \
|
| 479 |
+
AT_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__)
|
| 480 |
+
|
| 481 |
+
#define AT_DISPATCH_ALL_TYPES(TYPE, NAME, ...) \
|
| 482 |
+
AT_DISPATCH_SWITCH(TYPE, NAME, AT_DISPATCH_CASE_ALL_TYPES(__VA_ARGS__))
|
| 483 |
+
|
| 484 |
+
#define AT_DISPATCH_CASE_QINT_TYPES(...) \
|
| 485 |
+
AT_DISPATCH_CASE_QINT(at::kQInt8, at::qint8, __VA_ARGS__) \
|
| 486 |
+
AT_DISPATCH_CASE_QINT(at::kQUInt8, at::quint8, __VA_ARGS__) \
|
| 487 |
+
AT_DISPATCH_CASE_QINT(at::kQInt32, at::qint32, __VA_ARGS__)
|
| 488 |
+
|
| 489 |
+
#define AT_DISPATCH_QINT_TYPES(TYPE, NAME, ...) \
|
| 490 |
+
AT_DISPATCH_SWITCH(TYPE, NAME, AT_DISPATCH_CASE_QINT_TYPES(__VA_ARGS__))
|
| 491 |
+
|
| 492 |
+
#define AT_DISPATCH_CASE_QINT_TYPES_AND(SCALARTYPE, ...) \
|
| 493 |
+
AT_DISPATCH_CASE_QINT_TYPES(__VA_ARGS__) \
|
| 494 |
+
AT_DISPATCH_CASE(SCALARTYPE, __VA_ARGS__)
|
| 495 |
+
|
| 496 |
+
#define AT_DISPATCH_QINT_TYPES_AND(SCALARTYPE, TYPE, NAME, ...) \
|
| 497 |
+
AT_DISPATCH_SWITCH( \
|
| 498 |
+
TYPE, NAME, AT_DISPATCH_CASE_QINT_TYPES_AND(SCALARTYPE, __VA_ARGS__))
|
| 499 |
+
|
| 500 |
+
#define AT_DISPATCH_CASE_QINT_BYTE_TYPES(...) \
|
| 501 |
+
AT_DISPATCH_CASE_QINT(at::kQInt8, at::qint8, __VA_ARGS__) \
|
| 502 |
+
AT_DISPATCH_CASE_QINT(at::kQUInt8, at::quint8, __VA_ARGS__)
|
| 503 |
+
|
| 504 |
+
#define AT_DISPATCH_QINT_BYTE_TYPES(TYPE, NAME, ...) \
|
| 505 |
+
AT_DISPATCH_SWITCH(TYPE, NAME, AT_DISPATCH_CASE_QINT_BYTE_TYPES(__VA_ARGS__))
|
| 506 |
+
|
| 507 |
+
#define AT_DISPATCH_CASE_QINT_AND_SUB_BYTE_TYPES(...) \
|
| 508 |
+
AT_QINT_SUB_BYTE_PRIVATE_CASE_TYPE( \
|
| 509 |
+
at::kQInt8, at::qint8, CHAR_BIT, SCHAR_MIN, SCHAR_MAX, __VA_ARGS__) \
|
| 510 |
+
AT_QINT_SUB_BYTE_PRIVATE_CASE_TYPE( \
|
| 511 |
+
at::kQUInt8, at::quint8, CHAR_BIT, 0, UCHAR_MAX, __VA_ARGS__) \
|
| 512 |
+
AT_QINT_SUB_BYTE_PRIVATE_CASE_TYPE( \
|
| 513 |
+
at::kQInt32, \
|
| 514 |
+
at::qint32, \
|
| 515 |
+
CHAR_BIT * sizeof(int), \
|
| 516 |
+
INT_MIN, \
|
| 517 |
+
INT_MAX, \
|
| 518 |
+
__VA_ARGS__) \
|
| 519 |
+
AT_QINT_SUB_BYTE_PRIVATE_CASE_TYPE( \
|
| 520 |
+
at::kQUInt4x2, at::quint4x2, 4, 0, 15, __VA_ARGS__) \
|
| 521 |
+
AT_QINT_SUB_BYTE_PRIVATE_CASE_TYPE( \
|
| 522 |
+
at::kQUInt2x4, at::quint2x4, 2, 0, 3, __VA_ARGS__)
|
| 523 |
+
|
| 524 |
+
#define AT_DISPATCH_QINT_AND_SUB_BYTE_TYPES(TYPE, NAME, ...) \
|
| 525 |
+
AT_DISPATCH_SWITCH( \
|
| 526 |
+
TYPE, NAME, AT_DISPATCH_CASE_QINT_AND_SUB_BYTE_TYPES(__VA_ARGS__))
|
| 527 |
+
|
| 528 |
+
#define AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX(...) \
|
| 529 |
+
AT_DISPATCH_CASE_ALL_TYPES(__VA_ARGS__) \
|
| 530 |
+
AT_DISPATCH_CASE_COMPLEX_TYPES(__VA_ARGS__)
|
| 531 |
+
|
| 532 |
+
#define AT_DISPATCH_ALL_TYPES_AND_COMPLEX(TYPE, NAME, ...) \
|
| 533 |
+
AT_DISPATCH_SWITCH( \
|
| 534 |
+
TYPE, NAME, AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX(__VA_ARGS__))
|
| 535 |
+
|
| 536 |
+
#define AT_DISPATCH_CASE_ALL_TYPES_AND(SCALARTYPE, ...) \
|
| 537 |
+
AT_DISPATCH_CASE_ALL_TYPES(__VA_ARGS__) \
|
| 538 |
+
AT_DISPATCH_CASE(SCALARTYPE, __VA_ARGS__)
|
| 539 |
+
|
| 540 |
+
#define AT_DISPATCH_ALL_TYPES_AND(SCALARTYPE, TYPE, NAME, ...) \
|
| 541 |
+
AT_DISPATCH_SWITCH( \
|
| 542 |
+
TYPE, NAME, AT_DISPATCH_CASE_ALL_TYPES_AND(SCALARTYPE, __VA_ARGS__))
|
| 543 |
+
|
| 544 |
+
#define AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX_AND(SCALARTYPE, ...) \
|
| 545 |
+
AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX(__VA_ARGS__) \
|
| 546 |
+
AT_DISPATCH_CASE(SCALARTYPE, __VA_ARGS__)
|
| 547 |
+
|
| 548 |
+
#define AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND(SCALARTYPE, TYPE, NAME, ...) \
|
| 549 |
+
AT_DISPATCH_SWITCH( \
|
| 550 |
+
TYPE, \
|
| 551 |
+
NAME, \
|
| 552 |
+
AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX_AND(SCALARTYPE, __VA_ARGS__))
|
| 553 |
+
|
| 554 |
+
#define AT_DISPATCH_CASE_ALL_TYPES_AND2(SCALARTYPE1, SCALARTYPE2, ...) \
|
| 555 |
+
AT_DISPATCH_CASE_ALL_TYPES(__VA_ARGS__) \
|
| 556 |
+
AT_DISPATCH_CASE(SCALARTYPE1, __VA_ARGS__) \
|
| 557 |
+
AT_DISPATCH_CASE(SCALARTYPE2, __VA_ARGS__)
|
| 558 |
+
|
| 559 |
+
#define AT_DISPATCH_ALL_TYPES_AND2(SCALARTYPE1, SCALARTYPE2, TYPE, NAME, ...) \
|
| 560 |
+
AT_DISPATCH_SWITCH( \
|
| 561 |
+
TYPE, \
|
| 562 |
+
NAME, \
|
| 563 |
+
AT_DISPATCH_CASE_ALL_TYPES_AND2(SCALARTYPE1, SCALARTYPE2, __VA_ARGS__))
|
| 564 |
+
|
| 565 |
+
#define AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX_AND2( \
|
| 566 |
+
SCALARTYPE1, SCALARTYPE2, ...) \
|
| 567 |
+
AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX(__VA_ARGS__) \
|
| 568 |
+
AT_DISPATCH_CASE(SCALARTYPE1, __VA_ARGS__) \
|
| 569 |
+
AT_DISPATCH_CASE(SCALARTYPE2, __VA_ARGS__)
|
| 570 |
+
|
| 571 |
+
#define AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2( \
|
| 572 |
+
SCALARTYPE1, SCALARTYPE2, TYPE, NAME, ...) \
|
| 573 |
+
AT_DISPATCH_SWITCH( \
|
| 574 |
+
TYPE, \
|
| 575 |
+
NAME, \
|
| 576 |
+
AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX_AND2( \
|
| 577 |
+
SCALARTYPE1, SCALARTYPE2, __VA_ARGS__))
|
| 578 |
+
|
| 579 |
+
#define AT_DISPATCH_CASE_ALL_TYPES_AND3( \
|
| 580 |
+
SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, ...) \
|
| 581 |
+
AT_DISPATCH_CASE_ALL_TYPES(__VA_ARGS__) \
|
| 582 |
+
AT_DISPATCH_CASE(SCALARTYPE1, __VA_ARGS__) \
|
| 583 |
+
AT_DISPATCH_CASE(SCALARTYPE2, __VA_ARGS__) \
|
| 584 |
+
AT_DISPATCH_CASE(SCALARTYPE3, __VA_ARGS__)
|
| 585 |
+
|
| 586 |
+
#define AT_DISPATCH_ALL_TYPES_AND3( \
|
| 587 |
+
SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, TYPE, NAME, ...) \
|
| 588 |
+
AT_DISPATCH_SWITCH( \
|
| 589 |
+
TYPE, \
|
| 590 |
+
NAME, \
|
| 591 |
+
AT_DISPATCH_CASE_ALL_TYPES_AND3( \
|
| 592 |
+
SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, __VA_ARGS__))
|
| 593 |
+
|
| 594 |
+
#define AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX_AND3( \
|
| 595 |
+
SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, ...) \
|
| 596 |
+
AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX(__VA_ARGS__) \
|
| 597 |
+
AT_DISPATCH_CASE(SCALARTYPE1, __VA_ARGS__) \
|
| 598 |
+
AT_DISPATCH_CASE(SCALARTYPE2, __VA_ARGS__) \
|
| 599 |
+
AT_DISPATCH_CASE(SCALARTYPE3, __VA_ARGS__)
|
| 600 |
+
|
| 601 |
+
#define AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3( \
|
| 602 |
+
SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, TYPE, NAME, ...) \
|
| 603 |
+
AT_DISPATCH_SWITCH( \
|
| 604 |
+
TYPE, \
|
| 605 |
+
NAME, \
|
| 606 |
+
AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX_AND3( \
|
| 607 |
+
SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, __VA_ARGS__))
|
| 608 |
+
|
| 609 |
+
#define AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX_AND4( \
|
| 610 |
+
SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, SCALARTYPE4, ...) \
|
| 611 |
+
AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX(__VA_ARGS__) \
|
| 612 |
+
AT_DISPATCH_CASE(SCALARTYPE1, __VA_ARGS__) \
|
| 613 |
+
AT_DISPATCH_CASE(SCALARTYPE2, __VA_ARGS__) \
|
| 614 |
+
AT_DISPATCH_CASE(SCALARTYPE3, __VA_ARGS__) \
|
| 615 |
+
AT_DISPATCH_CASE(SCALARTYPE4, __VA_ARGS__)
|
| 616 |
+
|
| 617 |
+
#define AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND4( \
|
| 618 |
+
SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, SCALARTYPE4, TYPE, NAME, ...) \
|
| 619 |
+
AT_DISPATCH_SWITCH( \
|
| 620 |
+
TYPE, \
|
| 621 |
+
NAME, \
|
| 622 |
+
AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX_AND4( \
|
| 623 |
+
SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, SCALARTYPE4, __VA_ARGS__))
|
| 624 |
+
|
| 625 |
+
#define AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX_AND5( \
|
| 626 |
+
SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, SCALARTYPE4, SCALARTYPE5, ...) \
|
| 627 |
+
AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX(__VA_ARGS__) \
|
| 628 |
+
AT_DISPATCH_CASE(SCALARTYPE1, __VA_ARGS__) \
|
| 629 |
+
AT_DISPATCH_CASE(SCALARTYPE2, __VA_ARGS__) \
|
| 630 |
+
AT_DISPATCH_CASE(SCALARTYPE3, __VA_ARGS__) \
|
| 631 |
+
AT_DISPATCH_CASE(SCALARTYPE4, __VA_ARGS__) \
|
| 632 |
+
AT_DISPATCH_CASE(SCALARTYPE5, __VA_ARGS__)
|
| 633 |
+
|
| 634 |
+
#define AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND5( \
|
| 635 |
+
SCALARTYPE1, \
|
| 636 |
+
SCALARTYPE2, \
|
| 637 |
+
SCALARTYPE3, \
|
| 638 |
+
SCALARTYPE4, \
|
| 639 |
+
SCALARTYPE5, \
|
| 640 |
+
TYPE, \
|
| 641 |
+
NAME, \
|
| 642 |
+
...) \
|
| 643 |
+
AT_DISPATCH_SWITCH( \
|
| 644 |
+
TYPE, \
|
| 645 |
+
NAME, \
|
| 646 |
+
AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX_AND5( \
|
| 647 |
+
SCALARTYPE1, \
|
| 648 |
+
SCALARTYPE2, \
|
| 649 |
+
SCALARTYPE3, \
|
| 650 |
+
SCALARTYPE4, \
|
| 651 |
+
SCALARTYPE5, \
|
| 652 |
+
__VA_ARGS__))
|
| 653 |
+
|
| 654 |
+
#define AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX_AND6( \
|
| 655 |
+
SCALARTYPE1, \
|
| 656 |
+
SCALARTYPE2, \
|
| 657 |
+
SCALARTYPE3, \
|
| 658 |
+
SCALARTYPE4, \
|
| 659 |
+
SCALARTYPE5, \
|
| 660 |
+
SCALARTYPE6, \
|
| 661 |
+
...) \
|
| 662 |
+
AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX(__VA_ARGS__) \
|
| 663 |
+
AT_DISPATCH_CASE(SCALARTYPE1, __VA_ARGS__) \
|
| 664 |
+
AT_DISPATCH_CASE(SCALARTYPE2, __VA_ARGS__) \
|
| 665 |
+
AT_DISPATCH_CASE(SCALARTYPE3, __VA_ARGS__) \
|
| 666 |
+
AT_DISPATCH_CASE(SCALARTYPE4, __VA_ARGS__) \
|
| 667 |
+
AT_DISPATCH_CASE(SCALARTYPE5, __VA_ARGS__) \
|
| 668 |
+
AT_DISPATCH_CASE(SCALARTYPE6, __VA_ARGS__)
|
| 669 |
+
|
| 670 |
+
#define AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND6( \
|
| 671 |
+
SCALARTYPE1, \
|
| 672 |
+
SCALARTYPE2, \
|
| 673 |
+
SCALARTYPE3, \
|
| 674 |
+
SCALARTYPE4, \
|
| 675 |
+
SCALARTYPE5, \
|
| 676 |
+
SCALARTYPE6, \
|
| 677 |
+
TYPE, \
|
| 678 |
+
NAME, \
|
| 679 |
+
...) \
|
| 680 |
+
AT_DISPATCH_SWITCH( \
|
| 681 |
+
TYPE, \
|
| 682 |
+
NAME, \
|
| 683 |
+
AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX_AND6( \
|
| 684 |
+
SCALARTYPE1, \
|
| 685 |
+
SCALARTYPE2, \
|
| 686 |
+
SCALARTYPE3, \
|
| 687 |
+
SCALARTYPE4, \
|
| 688 |
+
SCALARTYPE5, \
|
| 689 |
+
SCALARTYPE6, \
|
| 690 |
+
__VA_ARGS__))
|
| 691 |
+
|
| 692 |
+
#define AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX_AND7( \
|
| 693 |
+
SCALARTYPE1, \
|
| 694 |
+
SCALARTYPE2, \
|
| 695 |
+
SCALARTYPE3, \
|
| 696 |
+
SCALARTYPE4, \
|
| 697 |
+
SCALARTYPE5, \
|
| 698 |
+
SCALARTYPE6, \
|
| 699 |
+
SCALARTYPE7, \
|
| 700 |
+
...) \
|
| 701 |
+
AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX(__VA_ARGS__) \
|
| 702 |
+
AT_DISPATCH_CASE(SCALARTYPE1, __VA_ARGS__) \
|
| 703 |
+
AT_DISPATCH_CASE(SCALARTYPE2, __VA_ARGS__) \
|
| 704 |
+
AT_DISPATCH_CASE(SCALARTYPE3, __VA_ARGS__) \
|
| 705 |
+
AT_DISPATCH_CASE(SCALARTYPE4, __VA_ARGS__) \
|
| 706 |
+
AT_DISPATCH_CASE(SCALARTYPE5, __VA_ARGS__) \
|
| 707 |
+
AT_DISPATCH_CASE(SCALARTYPE6, __VA_ARGS__) \
|
| 708 |
+
AT_DISPATCH_CASE(SCALARTYPE7, __VA_ARGS__)
|
| 709 |
+
|
| 710 |
+
#define AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND7( \
|
| 711 |
+
SCALARTYPE1, \
|
| 712 |
+
SCALARTYPE2, \
|
| 713 |
+
SCALARTYPE3, \
|
| 714 |
+
SCALARTYPE4, \
|
| 715 |
+
SCALARTYPE5, \
|
| 716 |
+
SCALARTYPE6, \
|
| 717 |
+
SCALARTYPE7, \
|
| 718 |
+
TYPE, \
|
| 719 |
+
NAME, \
|
| 720 |
+
...) \
|
| 721 |
+
AT_DISPATCH_SWITCH( \
|
| 722 |
+
TYPE, \
|
| 723 |
+
NAME, \
|
| 724 |
+
AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX_AND7( \
|
| 725 |
+
SCALARTYPE1, \
|
| 726 |
+
SCALARTYPE2, \
|
| 727 |
+
SCALARTYPE3, \
|
| 728 |
+
SCALARTYPE4, \
|
| 729 |
+
SCALARTYPE5, \
|
| 730 |
+
SCALARTYPE6, \
|
| 731 |
+
SCALARTYPE7, \
|
| 732 |
+
__VA_ARGS__))
|
| 733 |
+
|
| 734 |
+
#define AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX_AND8( \
|
| 735 |
+
SCALARTYPE1, \
|
| 736 |
+
SCALARTYPE2, \
|
| 737 |
+
SCALARTYPE3, \
|
| 738 |
+
SCALARTYPE4, \
|
| 739 |
+
SCALARTYPE5, \
|
| 740 |
+
SCALARTYPE6, \
|
| 741 |
+
SCALARTYPE7, \
|
| 742 |
+
SCALARTYPE8, \
|
| 743 |
+
...) \
|
| 744 |
+
AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX(__VA_ARGS__) \
|
| 745 |
+
AT_DISPATCH_CASE(SCALARTYPE1, __VA_ARGS__) \
|
| 746 |
+
AT_DISPATCH_CASE(SCALARTYPE2, __VA_ARGS__) \
|
| 747 |
+
AT_DISPATCH_CASE(SCALARTYPE3, __VA_ARGS__) \
|
| 748 |
+
AT_DISPATCH_CASE(SCALARTYPE4, __VA_ARGS__) \
|
| 749 |
+
AT_DISPATCH_CASE(SCALARTYPE5, __VA_ARGS__) \
|
| 750 |
+
AT_DISPATCH_CASE(SCALARTYPE6, __VA_ARGS__) \
|
| 751 |
+
AT_DISPATCH_CASE(SCALARTYPE7, __VA_ARGS__) \
|
| 752 |
+
AT_DISPATCH_CASE(SCALARTYPE8, __VA_ARGS__)
|
| 753 |
+
|
| 754 |
+
#define AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND8( \
|
| 755 |
+
SCALARTYPE1, \
|
| 756 |
+
SCALARTYPE2, \
|
| 757 |
+
SCALARTYPE3, \
|
| 758 |
+
SCALARTYPE4, \
|
| 759 |
+
SCALARTYPE5, \
|
| 760 |
+
SCALARTYPE6, \
|
| 761 |
+
SCALARTYPE7, \
|
| 762 |
+
SCALARTYPE8, \
|
| 763 |
+
TYPE, \
|
| 764 |
+
NAME, \
|
| 765 |
+
...) \
|
| 766 |
+
AT_DISPATCH_SWITCH( \
|
| 767 |
+
TYPE, \
|
| 768 |
+
NAME, \
|
| 769 |
+
AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX_AND8( \
|
| 770 |
+
SCALARTYPE1, \
|
| 771 |
+
SCALARTYPE2, \
|
| 772 |
+
SCALARTYPE3, \
|
| 773 |
+
SCALARTYPE4, \
|
| 774 |
+
SCALARTYPE5, \
|
| 775 |
+
SCALARTYPE6, \
|
| 776 |
+
SCALARTYPE7, \
|
| 777 |
+
SCALARTYPE8, \
|
| 778 |
+
__VA_ARGS__))
|
| 779 |
+
|
| 780 |
+
#define AT_DISPATCH_CASE_BIT_TYPES(...) \
|
| 781 |
+
AT_DISPATCH_CASE(at::ScalarType::Bits1x8, __VA_ARGS__) \
|
| 782 |
+
AT_DISPATCH_CASE(at::ScalarType::Bits2x4, __VA_ARGS__) \
|
| 783 |
+
AT_DISPATCH_CASE(at::ScalarType::Bits4x2, __VA_ARGS__) \
|
| 784 |
+
AT_DISPATCH_CASE(at::ScalarType::Bits8, __VA_ARGS__) \
|
| 785 |
+
AT_DISPATCH_CASE(at::ScalarType::Bits16, __VA_ARGS__)
|
| 786 |
+
|
| 787 |
+
#define AT_DISPATCH_BIT_TYPES(TYPE, NAME, ...) \
|
| 788 |
+
AT_DISPATCH_SWITCH(TYPE, NAME, AT_DISPATCH_CASE_BIT_TYPES(__VA_ARGS__))
|
| 789 |
+
|
| 790 |
+
#define AT_DISPATCH_INDEX_TYPES(TYPE, NAME, ...) \
|
| 791 |
+
AT_DISPATCH_SWITCH( \
|
| 792 |
+
TYPE, \
|
| 793 |
+
NAME, \
|
| 794 |
+
AT_PRIVATE_CASE_TYPE_USING_HINT( \
|
| 795 |
+
at::ScalarType::Int, index_t, __VA_ARGS__) \
|
| 796 |
+
AT_PRIVATE_CASE_TYPE_USING_HINT( \
|
| 797 |
+
at::ScalarType::Long, index_t, __VA_ARGS__))
|
| 798 |
+
|
| 799 |
+
// ----------------------------------------------------------------------------
|
| 800 |
+
// DEPRECATED MACROS, DON'T USE THESE
|
| 801 |
+
// ----------------------------------------------------------------------------
|
| 802 |
+
|
| 803 |
+
#define AT_DISPATCH_ALL_TYPES_AND_HALF(TYPE, NAME, ...) \
|
| 804 |
+
detail::deprecated_AT_DISPATCH_ALL_TYPES_AND_HALF(); \
|
| 805 |
+
AT_DISPATCH_SWITCH( \
|
| 806 |
+
TYPE, \
|
| 807 |
+
NAME, \
|
| 808 |
+
AT_DISPATCH_CASE_ALL_TYPES_AND(at::ScalarType::Half, __VA_ARGS__))
|
.venv/Lib/site-packages/torch/include/ATen/FuncTorchTLS.h
ADDED
|
@@ -0,0 +1,46 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#include <c10/macros/Macros.h>
|
| 4 |
+
#include <memory>
|
| 5 |
+
|
| 6 |
+
namespace at::functorch {
|
| 7 |
+
|
| 8 |
+
// NOTE [functorch TLS in pytorch/pytorch]
|
| 9 |
+
//
|
| 10 |
+
// functorch lives out-of-tree. However, it has some TLS that needs to be
|
| 11 |
+
// propagated. The solution for that is we store a pointer to the TLS
|
| 12 |
+
// inside pytorch/pytorch and extend FuncTorchTLSBase inside functorch to
|
| 13 |
+
// include whatever functorch needs.
|
| 14 |
+
//
|
| 15 |
+
// We need to store a pointer due to the indirection:
|
| 16 |
+
// inside functorch, we will create a subclass of FunctorchTLSBase called
|
| 17 |
+
// FuncTorchTLSImpl that actually contains metadata, like the DynamicLayerStack.
|
| 18 |
+
// FuncTorchTLSBase doesn't have any metadata because it hasn't been defined
|
| 19 |
+
// yet.
|
| 20 |
+
//
|
| 21 |
+
// Here in pytorch/pytorch, we will pass around FuncTorchTLSBase*, but inside
|
| 22 |
+
// functorch, we will assign a FuncTorchTLSImpl* to the FunctorchTLSBase*.
|
| 23 |
+
// We can't directly pass around FunctorchTLSBase (without a pointer) because
|
| 24 |
+
// FuncTorchTLSImpl does not fit inside a FuncTorchTLSBase by virtue of having
|
| 25 |
+
// more elements.
|
| 26 |
+
struct TORCH_API FuncTorchTLSBase {
|
| 27 |
+
virtual ~FuncTorchTLSBase() = default;
|
| 28 |
+
virtual std::unique_ptr<FuncTorchTLSBase> deepcopy() const = 0;
|
| 29 |
+
|
| 30 |
+
virtual int64_t checkSupportsSingleLevelAutogradFunction() const = 0;
|
| 31 |
+
virtual void checkSupportsCppAutogradFunction() const = 0;
|
| 32 |
+
virtual void checkSupportsInplaceRequiresGrad() const = 0;
|
| 33 |
+
virtual void checkSupportsRetainGrad() const = 0;
|
| 34 |
+
};
|
| 35 |
+
|
| 36 |
+
// returns deepcopy of the functorch tls
|
| 37 |
+
TORCH_API std::unique_ptr<FuncTorchTLSBase> getCopyOfFuncTorchTLS();
|
| 38 |
+
|
| 39 |
+
// sets the functorch tls. always does a deep copy.
|
| 40 |
+
TORCH_API void setFuncTorchTLS(
|
| 41 |
+
const std::shared_ptr<const FuncTorchTLSBase>& state);
|
| 42 |
+
|
| 43 |
+
// get a mutable reference to the functorch tls
|
| 44 |
+
TORCH_API std::unique_ptr<FuncTorchTLSBase>& functorchTLSAccessor();
|
| 45 |
+
|
| 46 |
+
} // namespace at::functorch
|
.venv/Lib/site-packages/torch/include/ATen/FunctionalTensorWrapper.h
ADDED
|
@@ -0,0 +1,454 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
#pragma once
|
| 3 |
+
|
| 4 |
+
#include <ATen/ArrayRef.h>
|
| 5 |
+
#include <ATen/FunctionalStorageImpl.h>
|
| 6 |
+
#include <ATen/core/IListRef.h>
|
| 7 |
+
#include <ATen/core/List.h>
|
| 8 |
+
#include <ATen/core/boxing/BoxedKernel.h>
|
| 9 |
+
#include <ATen/core/boxing/impl/boxing.h>
|
| 10 |
+
#include <ATen/core/dispatch/Dispatcher.h>
|
| 11 |
+
|
| 12 |
+
#include <c10/core/DispatchKey.h>
|
| 13 |
+
|
| 14 |
+
namespace at {
|
| 15 |
+
|
| 16 |
+
// Note [Functionalization Pass In Core]
|
| 17 |
+
// The Functionalization pass is used to remove aliasing from a pytorch program.
|
| 18 |
+
//
|
| 19 |
+
// This is useful for backends that don't support aliasing, like XLA and Vulkan.
|
| 20 |
+
// It's also necessary in order to remove mutation from a program, which is
|
| 21 |
+
// needed in Functorch.
|
| 22 |
+
//
|
| 23 |
+
// Consider this program:
|
| 24 |
+
// a = torch.ones(...)
|
| 25 |
+
// b = a.view(...)
|
| 26 |
+
// b.add_(1)
|
| 27 |
+
//
|
| 28 |
+
// In this program, b is meant to alias with a due to the use of view(). At the
|
| 29 |
+
// end of the program, both a and b are full of 2's. However, backends that
|
| 30 |
+
// don't support aliasing aren't able to correctly implement the view()
|
| 31 |
+
// operator. Instead, they can opt into the Functionalization pass, which will
|
| 32 |
+
// sit between the user and the backend, and provide the necessary aliasing
|
| 33 |
+
// logic.
|
| 34 |
+
//
|
| 35 |
+
// The functionalization pass will turn the above program into a slightly
|
| 36 |
+
// different program that has the same semantics, transparently to the user,
|
| 37 |
+
// that backends like XLA/Vulkan are able to implement a = torch.ones(...) b =
|
| 38 |
+
// a.view_copy(...) # view() replaced with view_copy(). Backends like
|
| 39 |
+
// XLA/Vulkan can implement this! b.add_(1) a.add_(1) # Our functionalization
|
| 40 |
+
// pass machinery knows that a and b are aliased - it applies b's mutation to a
|
| 41 |
+
// too.
|
| 42 |
+
//
|
| 43 |
+
// So, how does the functionalization pass keep track of which tensors are
|
| 44 |
+
// aliased? The pass works by wrapping EVERY tensor in the program inside of a
|
| 45 |
+
// FunctionalTensorWrapper, which knows about its alias'd tensors.
|
| 46 |
+
//
|
| 47 |
+
// See Note [Functionalization: Alias Removal] for details on the aliasing
|
| 48 |
+
// machinery. See Note [Functionalization: Mutation Removal] for details on
|
| 49 |
+
// mutation removal.
|
| 50 |
+
struct TORCH_API FunctionalTensorWrapper : public c10::TensorImpl {
|
| 51 |
+
explicit FunctionalTensorWrapper(const Tensor& value);
|
| 52 |
+
// Additional constructor to create a FunctionalTensorWrapper directly from an
|
| 53 |
+
// underlying tensor that was created from a view. For example, the code b =
|
| 54 |
+
// a.view1() will generate a constructor call to FunctionalTensorWrapper(b, a,
|
| 55 |
+
// view1_meta)
|
| 56 |
+
explicit FunctionalTensorWrapper(
|
| 57 |
+
const Tensor& view_value,
|
| 58 |
+
const FunctionalTensorWrapper* base,
|
| 59 |
+
const functionalization::ViewMeta& meta);
|
| 60 |
+
|
| 61 |
+
// Get the underlying, actual tensor, that doesn't know anything about
|
| 62 |
+
// functionalization.
|
| 63 |
+
const Tensor& value() const {
|
| 64 |
+
return value_;
|
| 65 |
+
};
|
| 66 |
+
// The concept of "level" is only ever important to functorch; it's exposed
|
| 67 |
+
// here as more of a hook for functorch to use.
|
| 68 |
+
int64_t level() const {
|
| 69 |
+
return level_;
|
| 70 |
+
};
|
| 71 |
+
void set_level(int64_t level) {
|
| 72 |
+
level_ = level;
|
| 73 |
+
}
|
| 74 |
+
bool has_metadata_mutation() const {
|
| 75 |
+
return has_metadata_mutation_;
|
| 76 |
+
};
|
| 77 |
+
|
| 78 |
+
void mark_mutation() {
|
| 79 |
+
functional_storage_impl()->mark_mutation();
|
| 80 |
+
}
|
| 81 |
+
// Denotes a mutation that's hidden from autograd,
|
| 82 |
+
// e.g. for the purposes of passing a tensor to a triton kernel
|
| 83 |
+
void mark_mutation_hidden_from_autograd() {
|
| 84 |
+
functional_storage_impl()->mark_mutation_hidden_from_autograd();
|
| 85 |
+
}
|
| 86 |
+
void mark_mutation_during_no_grad_or_inference_mode() {
|
| 87 |
+
functional_storage_impl()->mark_mutation_during_no_grad_or_inference_mode();
|
| 88 |
+
}
|
| 89 |
+
// Are all the mutations happening to the tensor hidden from autograd
|
| 90 |
+
bool are_all_mutations_hidden_from_autograd() const {
|
| 91 |
+
return functional_storage_impl()->are_all_mutations_hidden_from_autograd();
|
| 92 |
+
}
|
| 93 |
+
// Did all mutations happen under no_grad or inference_mode
|
| 94 |
+
// (We also need to ignore mutations fully hidden from autograd here)
|
| 95 |
+
bool are_all_mutations_under_no_grad_or_inference_mode() const {
|
| 96 |
+
return functional_storage_impl()
|
| 97 |
+
->are_all_mutations_under_no_grad_or_inference_mode();
|
| 98 |
+
}
|
| 99 |
+
|
| 100 |
+
void maybe_mark_symbolic(const functionalization::ViewMeta& meta) {
|
| 101 |
+
is_symbolic_ = is_symbolic_ | meta.has_symbolic_inputs;
|
| 102 |
+
}
|
| 103 |
+
|
| 104 |
+
bool is_symbolic() const {
|
| 105 |
+
return is_symbolic_;
|
| 106 |
+
}
|
| 107 |
+
|
| 108 |
+
// Runs the forward_fn of every ViewMeta collected in the current instance
|
| 109 |
+
// to some other base.
|
| 110 |
+
Tensor apply_view_metas(const Tensor& base);
|
| 111 |
+
|
| 112 |
+
// Sync's the underlying tensor with its alias, if it's out of date. This
|
| 113 |
+
// involves two steps: 1) Apply any pending updates/mutations to the alias 2)
|
| 114 |
+
// Replay the views (if any) to regenerate the current tensor off of the
|
| 115 |
+
// updated alias.
|
| 116 |
+
void sync_();
|
| 117 |
+
// Performs step (1) of the sync. This is its own public API because it's
|
| 118 |
+
// needed by view_inplace ops like transpose_. See Note [Functionalization
|
| 119 |
+
// Pass - Inplace View Ops]
|
| 120 |
+
void regenerate_from_base();
|
| 121 |
+
// Performs step (2) of the sync. This is its own public API because it's
|
| 122 |
+
// needed by functorch. functorch wants to make sure that all input tensors to
|
| 123 |
+
// a functionalized program have been properly synced so it can properly
|
| 124 |
+
// propagate mutations to inputs. It can't just call sync_(), because the
|
| 125 |
+
// FunctionalTensorWrapper will look like it has no aliases and sync_ will be
|
| 126 |
+
// a noop. We use the reference count on storage_ to determine if the wrapper
|
| 127 |
+
// is aliased, and by the time functorch is ready to propagate updates to
|
| 128 |
+
// inputs, any intermediate views of the input created by the program will
|
| 129 |
+
// have been deallocated. This function also returns whether or not the base
|
| 130 |
+
// actually had any updates to apply.
|
| 131 |
+
bool apply_updates();
|
| 132 |
+
// Takes the current state of value_ and snapshots it, sending it as a pending
|
| 133 |
+
// update to the alias.
|
| 134 |
+
void commit_update();
|
| 135 |
+
// When any tensor is mutated, the tensor increments its alias's "generation".
|
| 136 |
+
// Separately, each tensor maintains its own "generation" counter, which is
|
| 137 |
+
// used to determine if it's up-to-date with its alias. The act of syncing a
|
| 138 |
+
// tensor will set a tensor's generation equal to its alias's generation.
|
| 139 |
+
bool is_up_to_date() const;
|
| 140 |
+
// Freezes the storage of this tensor, preventing subsequent mutations
|
| 141 |
+
void freeze_storage() const;
|
| 142 |
+
// Every FunctionalTensorWrapper contains a vector<ViewMeta> objects
|
| 143 |
+
// describing the series of view ops that ran to generate the current tensor
|
| 144 |
+
// from the base tensor. This method is used by inplace-view ops like
|
| 145 |
+
// transpose_. It appends a ViewMeta to the existing stack, and refreshes the
|
| 146 |
+
// tensor by replaying the views off of the alias.
|
| 147 |
+
void mutate_view_meta(const at::functionalization::ViewMeta& meta);
|
| 148 |
+
|
| 149 |
+
// Custom implementation of self.set_(src)
|
| 150 |
+
void set__impl(const FunctionalTensorWrapper* other);
|
| 151 |
+
|
| 152 |
+
// Custom implementation of resize_storage_bytes_(self, new_size)
|
| 153 |
+
void storage_resize_(const c10::SymInt& new_size);
|
| 154 |
+
|
| 155 |
+
// Returns whether the current tensor's data was ever mutated
|
| 156 |
+
bool has_data_mutation();
|
| 157 |
+
//
|
| 158 |
+
// Returns whether the current FunctionalTensorWrapper
|
| 159 |
+
// experienced a set_() call.
|
| 160 |
+
bool was_storage_changed() {
|
| 161 |
+
return was_storage_changed_;
|
| 162 |
+
}
|
| 163 |
+
|
| 164 |
+
void set_storage_changed() {
|
| 165 |
+
was_storage_changed_ = true;
|
| 166 |
+
}
|
| 167 |
+
|
| 168 |
+
// A FunctionalTensor is considered a base if its not a view of another
|
| 169 |
+
// tensor.
|
| 170 |
+
bool isBaseTensor() const {
|
| 171 |
+
return view_metas_.empty();
|
| 172 |
+
}
|
| 173 |
+
|
| 174 |
+
c10::SymInt get_storage_size(bool before) {
|
| 175 |
+
return functional_storage_impl()->get_storage_size(before);
|
| 176 |
+
}
|
| 177 |
+
|
| 178 |
+
// Returns whether the FunctionalTensor experienced an
|
| 179 |
+
// untyped_storage().resize_() call
|
| 180 |
+
bool was_inductor_storage_resized() {
|
| 181 |
+
return functional_storage_impl()->was_inductor_storage_resized();
|
| 182 |
+
}
|
| 183 |
+
|
| 184 |
+
// The functionalization pass can be used to remove mutations.
|
| 185 |
+
// It does so by replacing any mutation op with it's corresponding
|
| 186 |
+
// out-of-place op, followed by a call to replace_(). e.g:
|
| 187 |
+
//
|
| 188 |
+
// a.add_(1)
|
| 189 |
+
//
|
| 190 |
+
// will turn into:
|
| 191 |
+
//
|
| 192 |
+
// tmp = a.add(1)
|
| 193 |
+
// a.replace_(tmp)
|
| 194 |
+
//
|
| 195 |
+
// replace_() swaps out the wrapped tensor, value_, with tmp.
|
| 196 |
+
void replace_(const Tensor& other, bool from_lazy_regenerate = false);
|
| 197 |
+
|
| 198 |
+
bool is_multi_output_view() {
|
| 199 |
+
return is_multi_output_view_;
|
| 200 |
+
}
|
| 201 |
+
|
| 202 |
+
// See Note[resize_() in functionalization pass]
|
| 203 |
+
void maybe_replace_storage(const Tensor& other);
|
| 204 |
+
|
| 205 |
+
// Replaces the storage with a new functional storage,
|
| 206 |
+
// and clears the view_metas_ stack.
|
| 207 |
+
// WARNING: Calling this function will sever the aliasing relationship between
|
| 208 |
+
// the current FunctionalTensorWrapper and any of its outstanding aliases.
|
| 209 |
+
// Please only call if you know what you're doing.
|
| 210 |
+
void _unsafe_reset_storage();
|
| 211 |
+
|
| 212 |
+
c10::intrusive_ptr<TensorImpl> shallow_copy_and_detach(
|
| 213 |
+
const c10::VariableVersion& version_counter,
|
| 214 |
+
bool allow_tensor_metadata_change) const override;
|
| 215 |
+
|
| 216 |
+
c10::intrusive_ptr<TensorImpl> shallow_copy_and_detach(
|
| 217 |
+
c10::VariableVersion&& version_counter,
|
| 218 |
+
bool allow_tensor_metadata_change) const override;
|
| 219 |
+
|
| 220 |
+
~FunctionalTensorWrapper() override = default;
|
| 221 |
+
|
| 222 |
+
// FunctionalTensorWrapper overrides all custom size/stride function,
|
| 223 |
+
// so that if the inner tensor has a custom implementation
|
| 224 |
+
// we make sure to call that implementation.
|
| 225 |
+
at::IntArrayRef sizes_custom() const override;
|
| 226 |
+
at::IntArrayRef strides_custom() const override;
|
| 227 |
+
int64_t dim_custom() const override;
|
| 228 |
+
int64_t numel_custom() const override;
|
| 229 |
+
bool is_contiguous_custom(at::MemoryFormat memory_format) const override;
|
| 230 |
+
c10::SymIntArrayRef sym_sizes_custom() const override;
|
| 231 |
+
c10::SymInt sym_size_custom(int64_t d) const override;
|
| 232 |
+
c10::SymIntArrayRef sym_strides_custom() const override;
|
| 233 |
+
c10::SymInt sym_storage_offset_custom() const override;
|
| 234 |
+
c10::Device device_custom() const override;
|
| 235 |
+
c10::Layout layout_impl() const override;
|
| 236 |
+
|
| 237 |
+
private:
|
| 238 |
+
const char* tensorimpl_type_name() const override;
|
| 239 |
+
void set_constructor_metadata();
|
| 240 |
+
functionalization::FunctionalStorageImpl* functional_storage_impl() const;
|
| 241 |
+
|
| 242 |
+
// This is used to re-implement shallow_copy_and_detach for
|
| 243 |
+
// FunctionalTensorWrapper. The implementation is identical, but we just need
|
| 244 |
+
// to return a subclass instead of a plain TensorImpl.
|
| 245 |
+
// TODO: maybe it's possible to arrange for that to happen automatically
|
| 246 |
+
// without an override here?
|
| 247 |
+
template <typename VariableVersion>
|
| 248 |
+
c10::intrusive_ptr<TensorImpl> shallow_copy_and_detach_core(
|
| 249 |
+
VariableVersion&& version_counter,
|
| 250 |
+
bool allow_tensor_metadata_change) const;
|
| 251 |
+
|
| 252 |
+
void shallow_copy_from(const c10::intrusive_ptr<TensorImpl>& impl) override;
|
| 253 |
+
void copy_tensor_metadata_and_refresh(
|
| 254 |
+
const FunctionalTensorWrapper* src_impl,
|
| 255 |
+
FunctionalTensorWrapper* dest_impl,
|
| 256 |
+
const c10::VariableVersion& version_counter,
|
| 257 |
+
bool allow_tensor_metadata_change) const;
|
| 258 |
+
|
| 259 |
+
// Note that value is not taken by reference: internally, the wrapper will
|
| 260 |
+
// change the value tensor that it points to over time.
|
| 261 |
+
Tensor value_;
|
| 262 |
+
int64_t level_{};
|
| 263 |
+
// These two counters are used for identifying
|
| 264 |
+
// whether all the mutations on a given tensor are hidden from autograd or
|
| 265 |
+
// not. If we have an input mutation that is hidden from autograd, then once
|
| 266 |
+
// we convert the input mutation to a copy_() we know it will be safe to hide
|
| 267 |
+
// the copy_() from autograd as well.
|
| 268 |
+
bool has_metadata_mutation_ = false;
|
| 269 |
+
bool is_multi_output_view_ = false;
|
| 270 |
+
// Did the tensor experience a set_() call.
|
| 271 |
+
bool was_storage_changed_ = false;
|
| 272 |
+
// Did the tensor experience any view operation with symbolic int.
|
| 273 |
+
bool is_symbolic_ = false;
|
| 274 |
+
|
| 275 |
+
size_t generation_ = 0;
|
| 276 |
+
std::vector<at::functionalization::ViewMeta> view_metas_;
|
| 277 |
+
|
| 278 |
+
protected:
|
| 279 |
+
static void copy_tensor_metadata(
|
| 280 |
+
const FunctionalTensorWrapper* src_impl,
|
| 281 |
+
FunctionalTensorWrapper* dest_impl,
|
| 282 |
+
const c10::VariableVersion& version_counter,
|
| 283 |
+
bool allow_tensor_metadata_change);
|
| 284 |
+
};
|
| 285 |
+
|
| 286 |
+
// Utility functions for the functionalization pass.
|
| 287 |
+
|
| 288 |
+
namespace functionalization {
|
| 289 |
+
namespace impl {
|
| 290 |
+
|
| 291 |
+
TORCH_API inline FunctionalTensorWrapper* unsafeGetFunctionalWrapper(
|
| 292 |
+
const Tensor& tensor) {
|
| 293 |
+
auto functional_impl =
|
| 294 |
+
static_cast<FunctionalTensorWrapper*>(tensor.unsafeGetTensorImpl());
|
| 295 |
+
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(functional_impl != nullptr);
|
| 296 |
+
return functional_impl;
|
| 297 |
+
}
|
| 298 |
+
|
| 299 |
+
TORCH_API bool isBaseTensor(const at::Tensor& tensor);
|
| 300 |
+
|
| 301 |
+
TORCH_API bool isFunctionalTensor(const at::Tensor& tensor);
|
| 302 |
+
TORCH_API bool isFunctionalTensor(const std::optional<Tensor>& t);
|
| 303 |
+
TORCH_API bool isFunctionalTensor(
|
| 304 |
+
const c10::List<std::optional<Tensor>>& t_list);
|
| 305 |
+
TORCH_API bool isFunctionalTensor(ITensorListRef list);
|
| 306 |
+
|
| 307 |
+
TORCH_API Tensor to_functional_tensor(const Tensor& tensor);
|
| 308 |
+
TORCH_API std::optional<Tensor> to_functional_tensor(
|
| 309 |
+
const std::optional<Tensor>& tensor);
|
| 310 |
+
TORCH_API c10::List<std::optional<Tensor>> to_functional_tensor(
|
| 311 |
+
const c10::List<std::optional<Tensor>>& t_list);
|
| 312 |
+
TORCH_API std::vector<Tensor> to_functional_tensor(ITensorListRef t_list);
|
| 313 |
+
|
| 314 |
+
TORCH_API void freeze_functional_tensor(const Tensor& tensor);
|
| 315 |
+
|
| 316 |
+
TORCH_API Tensor
|
| 317 |
+
from_functional_tensor(const Tensor& tensor, bool assert_functional = true);
|
| 318 |
+
TORCH_API std::optional<Tensor> from_functional_tensor(
|
| 319 |
+
const std::optional<Tensor>& t,
|
| 320 |
+
bool assert_functional = true);
|
| 321 |
+
TORCH_API c10::List<std::optional<Tensor>> from_functional_tensor(
|
| 322 |
+
const c10::List<std::optional<Tensor>>& t_list);
|
| 323 |
+
TORCH_API std::vector<Tensor> from_functional_tensor(ITensorListRef t_list);
|
| 324 |
+
|
| 325 |
+
TORCH_API void sync(const at::Tensor& t);
|
| 326 |
+
TORCH_API void sync(const std::optional<Tensor>& t);
|
| 327 |
+
TORCH_API void sync(const c10::List<std::optional<Tensor>>& t_list);
|
| 328 |
+
TORCH_API void sync(ITensorListRef t_list);
|
| 329 |
+
|
| 330 |
+
TORCH_API void replace_(const Tensor& functional_tensor, const Tensor& other);
|
| 331 |
+
TORCH_API void replace_(
|
| 332 |
+
const ITensorListRef functional_tensor,
|
| 333 |
+
ITensorListRef other);
|
| 334 |
+
|
| 335 |
+
TORCH_API void commit_update(const Tensor& functional_tensor);
|
| 336 |
+
TORCH_API void commit_update(ITensorListRef functional_tensor);
|
| 337 |
+
|
| 338 |
+
TORCH_API void unsafe_reset_storage(const Tensor& functional_tensor);
|
| 339 |
+
|
| 340 |
+
TORCH_API void mark_mutation_hidden_from_autograd(
|
| 341 |
+
const Tensor& functional_tensor);
|
| 342 |
+
|
| 343 |
+
TORCH_API bool are_all_mutations_hidden_from_autograd(
|
| 344 |
+
const Tensor& functional_tensor);
|
| 345 |
+
|
| 346 |
+
TORCH_API bool are_all_mutations_under_no_grad_or_inference_mode(
|
| 347 |
+
const Tensor& functional_tensor);
|
| 348 |
+
|
| 349 |
+
// These two methods are XLA-specific logic and are no-ops
|
| 350 |
+
// for the normal functionalization flow.
|
| 351 |
+
TORCH_API void propagate_xla_data(
|
| 352 |
+
const Tensor& functional_tensor,
|
| 353 |
+
const Tensor& other);
|
| 354 |
+
TORCH_API void propagate_xla_data(
|
| 355 |
+
const ITensorListRef functional_tensor,
|
| 356 |
+
ITensorListRef other);
|
| 357 |
+
|
| 358 |
+
TORCH_API void propagate_xla_data_direct(
|
| 359 |
+
const Tensor& tensor,
|
| 360 |
+
const Tensor& other);
|
| 361 |
+
TORCH_API void propagate_xla_data_direct(
|
| 362 |
+
const ITensorListRef tensor,
|
| 363 |
+
ITensorListRef other);
|
| 364 |
+
|
| 365 |
+
Tensor create_functional_tensor_with_view_meta(
|
| 366 |
+
const Tensor& view_to_wrap,
|
| 367 |
+
const Tensor& base,
|
| 368 |
+
functionalization::ViewMeta meta,
|
| 369 |
+
int64_t out_idx = 0);
|
| 370 |
+
std::vector<Tensor> create_functional_tensor_with_view_meta(
|
| 371 |
+
ITensorListRef view_to_wrap,
|
| 372 |
+
const Tensor& base,
|
| 373 |
+
const functionalization::ViewMeta& meta);
|
| 374 |
+
|
| 375 |
+
void mutate_view_meta(
|
| 376 |
+
const Tensor& self,
|
| 377 |
+
const functionalization::ViewMeta& meta);
|
| 378 |
+
|
| 379 |
+
void set_sizes_strides_offset(const Tensor& out, const Tensor& meta_out);
|
| 380 |
+
void set_sizes_strides_offset(
|
| 381 |
+
const std::vector<Tensor>& outs,
|
| 382 |
+
const std::vector<Tensor>& meta_outs);
|
| 383 |
+
|
| 384 |
+
// ~~~~~ TLS used in functionalization ~~~~~
|
| 385 |
+
|
| 386 |
+
TORCH_API bool getFunctionalizationReapplyViewsTLS();
|
| 387 |
+
TORCH_API void setFunctionalizationReapplyViewsTLS(bool reapply_views);
|
| 388 |
+
|
| 389 |
+
class TORCH_API FunctionalizationReapplyViewsGuard {
|
| 390 |
+
public:
|
| 391 |
+
FunctionalizationReapplyViewsGuard(bool reapply_views)
|
| 392 |
+
: prev_(getFunctionalizationReapplyViewsTLS()) {
|
| 393 |
+
setFunctionalizationReapplyViewsTLS(reapply_views);
|
| 394 |
+
}
|
| 395 |
+
|
| 396 |
+
~FunctionalizationReapplyViewsGuard() {
|
| 397 |
+
setFunctionalizationReapplyViewsTLS(prev_);
|
| 398 |
+
}
|
| 399 |
+
|
| 400 |
+
FunctionalizationReapplyViewsGuard(
|
| 401 |
+
const FunctionalizationReapplyViewsGuard&) = delete;
|
| 402 |
+
FunctionalizationReapplyViewsGuard operator=(
|
| 403 |
+
const FunctionalizationReapplyViewsGuard&) = delete;
|
| 404 |
+
FunctionalizationReapplyViewsGuard(FunctionalizationReapplyViewsGuard&&) =
|
| 405 |
+
delete;
|
| 406 |
+
FunctionalizationReapplyViewsGuard operator=(
|
| 407 |
+
FunctionalizationReapplyViewsGuard&&) = delete;
|
| 408 |
+
|
| 409 |
+
private:
|
| 410 |
+
bool prev_;
|
| 411 |
+
};
|
| 412 |
+
|
| 413 |
+
} // namespace impl
|
| 414 |
+
|
| 415 |
+
// Helper function to call an out-of-place composite aten kernel that may use
|
| 416 |
+
// mutations / views internally, and functionalize them.
|
| 417 |
+
TORCH_API void functionalize_op_helper(
|
| 418 |
+
const c10::OperatorHandle& op,
|
| 419 |
+
torch::jit::Stack* stack);
|
| 420 |
+
|
| 421 |
+
template <class Op, bool symint, class ReturnType, class... ParameterTypes>
|
| 422 |
+
struct _functionalize_aten_op final {};
|
| 423 |
+
|
| 424 |
+
template <class Op, bool symint, class ReturnType, class... ParameterTypes>
|
| 425 |
+
struct _functionalize_aten_op<Op, symint, ReturnType(ParameterTypes...)> final {
|
| 426 |
+
static ReturnType call(
|
| 427 |
+
typename c10::maybe_keep_symint<symint, ParameterTypes>::type... args) {
|
| 428 |
+
using FuncType = ReturnType(
|
| 429 |
+
typename c10::maybe_keep_symint<symint, ParameterTypes>::type...);
|
| 430 |
+
auto op = c10::Dispatcher::singleton()
|
| 431 |
+
.findSchemaOrThrow(
|
| 432 |
+
(const char*)Op::name, (const char*)Op::overload_name)
|
| 433 |
+
.typed<FuncType>();
|
| 434 |
+
|
| 435 |
+
return c10::impl::BoxedKernelWrapper<FuncType>::call(
|
| 436 |
+
c10::BoxedKernel::makeFromFunction<functionalize_op_helper>(),
|
| 437 |
+
op,
|
| 438 |
+
// BoxedKernelWrapper knows to ignore this keyset argument,
|
| 439 |
+
// because functionalize_op_helper doesn't take in a DispatchKeySet
|
| 440 |
+
c10::DispatchKeySet(),
|
| 441 |
+
args...);
|
| 442 |
+
}
|
| 443 |
+
};
|
| 444 |
+
|
| 445 |
+
template <class Op>
|
| 446 |
+
using functionalize_aten_op =
|
| 447 |
+
_functionalize_aten_op<Op, false, typename Op::schema>;
|
| 448 |
+
|
| 449 |
+
template <class Op>
|
| 450 |
+
using functionalize_aten_op_symint =
|
| 451 |
+
_functionalize_aten_op<Op, true, typename Op::schema>;
|
| 452 |
+
|
| 453 |
+
} // namespace functionalization
|
| 454 |
+
} // namespace at
|
.venv/Lib/site-packages/torch/include/ATen/Functions.h
ADDED
|
@@ -0,0 +1,1454 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
// @generated by torchgen/gen.py from Functions.h
|
| 4 |
+
|
| 5 |
+
#ifdef TORCH_ASSERT_NO_OPERATORS
|
| 6 |
+
#error This change adds a dependency on native_functions.yaml, \
|
| 7 |
+
meaning the file will need to be re-compiled every time an operator \
|
| 8 |
+
is changed or added. Consider if your change would be better placed in \
|
| 9 |
+
another file, or if a more specific header might achieve the same goal. \
|
| 10 |
+
See NOTE: [Tensor vs. TensorBase]
|
| 11 |
+
#endif
|
| 12 |
+
|
| 13 |
+
#if defined(AT_PER_OPERATOR_HEADERS) && defined(TORCH_ASSERT_ONLY_METHOD_OPERATORS)
|
| 14 |
+
#error This change adds a dependency on all pytorch operators, meaning the \
|
| 15 |
+
file will need to be re-compiled every time an operator is changed or added. \
|
| 16 |
+
Consider including a specific operator from <ATen/ops/{my_operator}.h> and \
|
| 17 |
+
see NOTE [TORCH_ASSERT_ONLY_METHOD_OPERATORS].
|
| 18 |
+
#endif
|
| 19 |
+
|
| 20 |
+
// NOTE: [TORCH_ASSERT_ONLY_METHOD_OPERATORS]
|
| 21 |
+
//
|
| 22 |
+
// In ATen, certain generated headers files include the definitions of
|
| 23 |
+
// every single operator in PyTorch. Unfortunately this means every
|
| 24 |
+
// time an operator signature is updated or changed in
|
| 25 |
+
// native_functions.yaml, you (and every other PyTorch developer) need
|
| 26 |
+
// to recompile every source file that includes any of these headers.
|
| 27 |
+
//
|
| 28 |
+
// To break up these header dependencies, and improve incremental
|
| 29 |
+
// build times for all PyTorch developers. These headers are split
|
| 30 |
+
// into per-operator headers in the `ATen/ops` folder. This limits
|
| 31 |
+
// incremental builds to only changes to methods of `Tensor`, or files
|
| 32 |
+
// that use the specific operator being changed. With `at::sum` as an
|
| 33 |
+
// example, you should include
|
| 34 |
+
//
|
| 35 |
+
// <ATen/ops/sum.h> // instead of ATen/Functions.h
|
| 36 |
+
// <ATen/ops/sum_native.h> // instead of ATen/NativeFunctions.h
|
| 37 |
+
// <ATen/ops/sum_ops.h> // instead of ATen/Operators.h
|
| 38 |
+
// <ATen/ops/sum_cpu_dispatch.h> // instead of ATen/CPUFunctions.h
|
| 39 |
+
//
|
| 40 |
+
// However, even if you're careful to use this in your own code.
|
| 41 |
+
// `Functions.h` might be included indirectly through another header
|
| 42 |
+
// without you realising. To avoid this, you can add
|
| 43 |
+
//
|
| 44 |
+
// #define TORCH_ASSERT_ONLY_METHOD_OPERATORS
|
| 45 |
+
//
|
| 46 |
+
// to the top of your source file. This way any time the non-specific
|
| 47 |
+
// headers are included, the compiler will error out.
|
| 48 |
+
//
|
| 49 |
+
// Also, be aware that `ops` are not available in all build
|
| 50 |
+
// configurations (namely fb-internal) so you must guard these
|
| 51 |
+
// includes with `#ifdef AT_PER_OPERATOR_HEADERS`. e.g.
|
| 52 |
+
//
|
| 53 |
+
// #ifndef AT_PER_OPERATOR_HEADERS
|
| 54 |
+
// #include <ATen/Functions.h>
|
| 55 |
+
// #else
|
| 56 |
+
// #include <ATen/ops/sum.h>
|
| 57 |
+
// #endif
|
| 58 |
+
|
| 59 |
+
#include <ATen/Context.h>
|
| 60 |
+
#include <ATen/DeviceGuard.h>
|
| 61 |
+
#include <ATen/TensorUtils.h>
|
| 62 |
+
#include <ATen/TracerMode.h>
|
| 63 |
+
#include <ATen/core/Generator.h>
|
| 64 |
+
#include <ATen/core/Reduction.h>
|
| 65 |
+
#include <c10/core/SymInt.h>
|
| 66 |
+
#include <ATen/core/Tensor.h>
|
| 67 |
+
#include <c10/core/Scalar.h>
|
| 68 |
+
#include <c10/core/Storage.h>
|
| 69 |
+
#include <c10/core/TensorOptions.h>
|
| 70 |
+
#include <c10/util/Deprecated.h>
|
| 71 |
+
#include <optional>
|
| 72 |
+
#include <c10/util/OptionalArrayRef.h>
|
| 73 |
+
|
| 74 |
+
#include <ATen/ops/from_blob.h>
|
| 75 |
+
#include <ATen/ops/tensor.h>
|
| 76 |
+
|
| 77 |
+
#include <ATen/ops/_adaptive_avg_pool2d.h>
|
| 78 |
+
#include <ATen/ops/_adaptive_avg_pool2d_backward.h>
|
| 79 |
+
#include <ATen/ops/_adaptive_avg_pool3d.h>
|
| 80 |
+
#include <ATen/ops/_adaptive_avg_pool3d_backward.h>
|
| 81 |
+
#include <ATen/ops/_add_batch_dim.h>
|
| 82 |
+
#include <ATen/ops/_add_relu.h>
|
| 83 |
+
#include <ATen/ops/_addmm_activation.h>
|
| 84 |
+
#include <ATen/ops/_aminmax.h>
|
| 85 |
+
#include <ATen/ops/_amp_foreach_non_finite_check_and_unscale.h>
|
| 86 |
+
#include <ATen/ops/_amp_update_scale.h>
|
| 87 |
+
#include <ATen/ops/_assert_async.h>
|
| 88 |
+
#include <ATen/ops/_assert_scalar.h>
|
| 89 |
+
#include <ATen/ops/_assert_tensor_metadata.h>
|
| 90 |
+
#include <ATen/ops/_autocast_to_full_precision.h>
|
| 91 |
+
#include <ATen/ops/_autocast_to_reduced_precision.h>
|
| 92 |
+
#include <ATen/ops/_backward.h>
|
| 93 |
+
#include <ATen/ops/_batch_norm_impl_index.h>
|
| 94 |
+
#include <ATen/ops/_batch_norm_impl_index_backward.h>
|
| 95 |
+
#include <ATen/ops/_batch_norm_no_update.h>
|
| 96 |
+
#include <ATen/ops/_batch_norm_with_update.h>
|
| 97 |
+
#include <ATen/ops/_cast_Byte.h>
|
| 98 |
+
#include <ATen/ops/_cast_Char.h>
|
| 99 |
+
#include <ATen/ops/_cast_Double.h>
|
| 100 |
+
#include <ATen/ops/_cast_Float.h>
|
| 101 |
+
#include <ATen/ops/_cast_Half.h>
|
| 102 |
+
#include <ATen/ops/_cast_Int.h>
|
| 103 |
+
#include <ATen/ops/_cast_Long.h>
|
| 104 |
+
#include <ATen/ops/_cast_Short.h>
|
| 105 |
+
#include <ATen/ops/_cdist_backward.h>
|
| 106 |
+
#include <ATen/ops/_cdist_forward.h>
|
| 107 |
+
#include <ATen/ops/_cholesky_solve_helper.h>
|
| 108 |
+
#include <ATen/ops/_choose_qparams_per_tensor.h>
|
| 109 |
+
#include <ATen/ops/_chunk_cat.h>
|
| 110 |
+
#include <ATen/ops/_coalesce.h>
|
| 111 |
+
#include <ATen/ops/_coalesced.h>
|
| 112 |
+
#include <ATen/ops/_compute_linear_combination.h>
|
| 113 |
+
#include <ATen/ops/_conj.h>
|
| 114 |
+
#include <ATen/ops/_conj_copy.h>
|
| 115 |
+
#include <ATen/ops/_conj_physical.h>
|
| 116 |
+
#include <ATen/ops/_conv_depthwise2d.h>
|
| 117 |
+
#include <ATen/ops/_convert_indices_from_coo_to_csr.h>
|
| 118 |
+
#include <ATen/ops/_convert_indices_from_csr_to_coo.h>
|
| 119 |
+
#include <ATen/ops/_convert_weight_to_int4pack.h>
|
| 120 |
+
#include <ATen/ops/_convolution.h>
|
| 121 |
+
#include <ATen/ops/_convolution_double_backward.h>
|
| 122 |
+
#include <ATen/ops/_convolution_mode.h>
|
| 123 |
+
#include <ATen/ops/_copy_from.h>
|
| 124 |
+
#include <ATen/ops/_copy_from_and_resize.h>
|
| 125 |
+
#include <ATen/ops/_cslt_compress.h>
|
| 126 |
+
#include <ATen/ops/_cslt_sparse_mm.h>
|
| 127 |
+
#include <ATen/ops/_cslt_sparse_mm_search.h>
|
| 128 |
+
#include <ATen/ops/_ctc_loss.h>
|
| 129 |
+
#include <ATen/ops/_ctc_loss_backward.h>
|
| 130 |
+
#include <ATen/ops/_cudnn_ctc_loss.h>
|
| 131 |
+
#include <ATen/ops/_cudnn_init_dropout_state.h>
|
| 132 |
+
#include <ATen/ops/_cudnn_rnn.h>
|
| 133 |
+
#include <ATen/ops/_cudnn_rnn_backward.h>
|
| 134 |
+
#include <ATen/ops/_cudnn_rnn_flatten_weight.h>
|
| 135 |
+
#include <ATen/ops/_cufft_clear_plan_cache.h>
|
| 136 |
+
#include <ATen/ops/_cufft_get_plan_cache_max_size.h>
|
| 137 |
+
#include <ATen/ops/_cufft_get_plan_cache_size.h>
|
| 138 |
+
#include <ATen/ops/_cufft_set_plan_cache_max_size.h>
|
| 139 |
+
#include <ATen/ops/_cummax_helper.h>
|
| 140 |
+
#include <ATen/ops/_cummin_helper.h>
|
| 141 |
+
#include <ATen/ops/_debug_has_internal_overlap.h>
|
| 142 |
+
#include <ATen/ops/_dimI.h>
|
| 143 |
+
#include <ATen/ops/_dimV.h>
|
| 144 |
+
#include <ATen/ops/_dim_arange.h>
|
| 145 |
+
#include <ATen/ops/_dirichlet_grad.h>
|
| 146 |
+
#include <ATen/ops/_efficient_attention_backward.h>
|
| 147 |
+
#include <ATen/ops/_efficient_attention_forward.h>
|
| 148 |
+
#include <ATen/ops/_efficientzerotensor.h>
|
| 149 |
+
#include <ATen/ops/_embedding_bag.h>
|
| 150 |
+
#include <ATen/ops/_embedding_bag_backward.h>
|
| 151 |
+
#include <ATen/ops/_embedding_bag_dense_backward.h>
|
| 152 |
+
#include <ATen/ops/_embedding_bag_forward_only.h>
|
| 153 |
+
#include <ATen/ops/_embedding_bag_per_sample_weights_backward.h>
|
| 154 |
+
#include <ATen/ops/_embedding_bag_sparse_backward.h>
|
| 155 |
+
#include <ATen/ops/_empty_affine_quantized.h>
|
| 156 |
+
#include <ATen/ops/_empty_per_channel_affine_quantized.h>
|
| 157 |
+
#include <ATen/ops/_euclidean_dist.h>
|
| 158 |
+
#include <ATen/ops/_fake_quantize_learnable_per_channel_affine.h>
|
| 159 |
+
#include <ATen/ops/_fake_quantize_learnable_per_channel_affine_backward.h>
|
| 160 |
+
#include <ATen/ops/_fake_quantize_learnable_per_tensor_affine.h>
|
| 161 |
+
#include <ATen/ops/_fake_quantize_learnable_per_tensor_affine_backward.h>
|
| 162 |
+
#include <ATen/ops/_fake_quantize_per_tensor_affine_cachemask_tensor_qparams.h>
|
| 163 |
+
#include <ATen/ops/_fft_c2c.h>
|
| 164 |
+
#include <ATen/ops/_fft_c2r.h>
|
| 165 |
+
#include <ATen/ops/_fft_r2c.h>
|
| 166 |
+
#include <ATen/ops/_fill_mem_eff_dropout_mask.h>
|
| 167 |
+
#include <ATen/ops/_flash_attention_backward.h>
|
| 168 |
+
#include <ATen/ops/_flash_attention_forward.h>
|
| 169 |
+
#include <ATen/ops/_foobar.h>
|
| 170 |
+
#include <ATen/ops/_foreach_abs.h>
|
| 171 |
+
#include <ATen/ops/_foreach_acos.h>
|
| 172 |
+
#include <ATen/ops/_foreach_add.h>
|
| 173 |
+
#include <ATen/ops/_foreach_addcdiv.h>
|
| 174 |
+
#include <ATen/ops/_foreach_addcmul.h>
|
| 175 |
+
#include <ATen/ops/_foreach_asin.h>
|
| 176 |
+
#include <ATen/ops/_foreach_atan.h>
|
| 177 |
+
#include <ATen/ops/_foreach_ceil.h>
|
| 178 |
+
#include <ATen/ops/_foreach_clamp_max.h>
|
| 179 |
+
#include <ATen/ops/_foreach_clamp_min.h>
|
| 180 |
+
#include <ATen/ops/_foreach_copy.h>
|
| 181 |
+
#include <ATen/ops/_foreach_cos.h>
|
| 182 |
+
#include <ATen/ops/_foreach_cosh.h>
|
| 183 |
+
#include <ATen/ops/_foreach_div.h>
|
| 184 |
+
#include <ATen/ops/_foreach_erf.h>
|
| 185 |
+
#include <ATen/ops/_foreach_erfc.h>
|
| 186 |
+
#include <ATen/ops/_foreach_exp.h>
|
| 187 |
+
#include <ATen/ops/_foreach_expm1.h>
|
| 188 |
+
#include <ATen/ops/_foreach_floor.h>
|
| 189 |
+
#include <ATen/ops/_foreach_frac.h>
|
| 190 |
+
#include <ATen/ops/_foreach_lerp.h>
|
| 191 |
+
#include <ATen/ops/_foreach_lgamma.h>
|
| 192 |
+
#include <ATen/ops/_foreach_log.h>
|
| 193 |
+
#include <ATen/ops/_foreach_log10.h>
|
| 194 |
+
#include <ATen/ops/_foreach_log1p.h>
|
| 195 |
+
#include <ATen/ops/_foreach_log2.h>
|
| 196 |
+
#include <ATen/ops/_foreach_max.h>
|
| 197 |
+
#include <ATen/ops/_foreach_maximum.h>
|
| 198 |
+
#include <ATen/ops/_foreach_minimum.h>
|
| 199 |
+
#include <ATen/ops/_foreach_mul.h>
|
| 200 |
+
#include <ATen/ops/_foreach_neg.h>
|
| 201 |
+
#include <ATen/ops/_foreach_norm.h>
|
| 202 |
+
#include <ATen/ops/_foreach_pow.h>
|
| 203 |
+
#include <ATen/ops/_foreach_reciprocal.h>
|
| 204 |
+
#include <ATen/ops/_foreach_round.h>
|
| 205 |
+
#include <ATen/ops/_foreach_sigmoid.h>
|
| 206 |
+
#include <ATen/ops/_foreach_sign.h>
|
| 207 |
+
#include <ATen/ops/_foreach_sin.h>
|
| 208 |
+
#include <ATen/ops/_foreach_sinh.h>
|
| 209 |
+
#include <ATen/ops/_foreach_sqrt.h>
|
| 210 |
+
#include <ATen/ops/_foreach_sub.h>
|
| 211 |
+
#include <ATen/ops/_foreach_tan.h>
|
| 212 |
+
#include <ATen/ops/_foreach_tanh.h>
|
| 213 |
+
#include <ATen/ops/_foreach_trunc.h>
|
| 214 |
+
#include <ATen/ops/_foreach_zero.h>
|
| 215 |
+
#include <ATen/ops/_functional_assert_async.h>
|
| 216 |
+
#include <ATen/ops/_functional_assert_scalar.h>
|
| 217 |
+
#include <ATen/ops/_functional_sym_constrain_range.h>
|
| 218 |
+
#include <ATen/ops/_functional_sym_constrain_range_for_size.h>
|
| 219 |
+
#include <ATen/ops/_fused_adagrad.h>
|
| 220 |
+
#include <ATen/ops/_fused_adam.h>
|
| 221 |
+
#include <ATen/ops/_fused_adamw.h>
|
| 222 |
+
#include <ATen/ops/_fused_dropout.h>
|
| 223 |
+
#include <ATen/ops/_fused_moving_avg_obs_fq_helper.h>
|
| 224 |
+
#include <ATen/ops/_fused_sdp_choice.h>
|
| 225 |
+
#include <ATen/ops/_fused_sgd.h>
|
| 226 |
+
#include <ATen/ops/_fw_primal.h>
|
| 227 |
+
#include <ATen/ops/_fw_primal_copy.h>
|
| 228 |
+
#include <ATen/ops/_gather_sparse_backward.h>
|
| 229 |
+
#include <ATen/ops/_grid_sampler_2d_cpu_fallback.h>
|
| 230 |
+
#include <ATen/ops/_grid_sampler_2d_cpu_fallback_backward.h>
|
| 231 |
+
#include <ATen/ops/_has_compatible_shallow_copy_type.h>
|
| 232 |
+
#include <ATen/ops/_has_same_storage_numel.h>
|
| 233 |
+
#include <ATen/ops/_histogramdd_bin_edges.h>
|
| 234 |
+
#include <ATen/ops/_histogramdd_from_bin_cts.h>
|
| 235 |
+
#include <ATen/ops/_histogramdd_from_bin_tensors.h>
|
| 236 |
+
#include <ATen/ops/_index_put_impl.h>
|
| 237 |
+
#include <ATen/ops/_indices.h>
|
| 238 |
+
#include <ATen/ops/_indices_copy.h>
|
| 239 |
+
#include <ATen/ops/_int_mm.h>
|
| 240 |
+
#include <ATen/ops/_is_all_true.h>
|
| 241 |
+
#include <ATen/ops/_is_any_true.h>
|
| 242 |
+
#include <ATen/ops/_is_zerotensor.h>
|
| 243 |
+
#include <ATen/ops/_jagged_to_padded_dense_forward.h>
|
| 244 |
+
#include <ATen/ops/_lazy_clone.h>
|
| 245 |
+
#include <ATen/ops/_linalg_check_errors.h>
|
| 246 |
+
#include <ATen/ops/_linalg_det.h>
|
| 247 |
+
#include <ATen/ops/_linalg_eigh.h>
|
| 248 |
+
#include <ATen/ops/_linalg_eigvals.h>
|
| 249 |
+
#include <ATen/ops/_linalg_slogdet.h>
|
| 250 |
+
#include <ATen/ops/_linalg_solve_ex.h>
|
| 251 |
+
#include <ATen/ops/_linalg_svd.h>
|
| 252 |
+
#include <ATen/ops/_local_scalar_dense.h>
|
| 253 |
+
#include <ATen/ops/_log_softmax.h>
|
| 254 |
+
#include <ATen/ops/_log_softmax_backward_data.h>
|
| 255 |
+
#include <ATen/ops/_logcumsumexp.h>
|
| 256 |
+
#include <ATen/ops/_lstm_mps.h>
|
| 257 |
+
#include <ATen/ops/_lu_with_info.h>
|
| 258 |
+
#include <ATen/ops/_make_dep_token.h>
|
| 259 |
+
#include <ATen/ops/_make_dual.h>
|
| 260 |
+
#include <ATen/ops/_make_dual_copy.h>
|
| 261 |
+
#include <ATen/ops/_make_per_channel_quantized_tensor.h>
|
| 262 |
+
#include <ATen/ops/_make_per_tensor_quantized_tensor.h>
|
| 263 |
+
#include <ATen/ops/_masked_scale.h>
|
| 264 |
+
#include <ATen/ops/_masked_softmax.h>
|
| 265 |
+
#include <ATen/ops/_masked_softmax_backward.h>
|
| 266 |
+
#include <ATen/ops/_mixed_dtypes_linear.h>
|
| 267 |
+
#include <ATen/ops/_mkldnn_reshape.h>
|
| 268 |
+
#include <ATen/ops/_mkldnn_transpose.h>
|
| 269 |
+
#include <ATen/ops/_mps_convolution.h>
|
| 270 |
+
#include <ATen/ops/_mps_convolution_transpose.h>
|
| 271 |
+
#include <ATen/ops/_native_batch_norm_legit.h>
|
| 272 |
+
#include <ATen/ops/_native_batch_norm_legit_no_training.h>
|
| 273 |
+
#include <ATen/ops/_native_multi_head_attention.h>
|
| 274 |
+
#include <ATen/ops/_neg_view.h>
|
| 275 |
+
#include <ATen/ops/_neg_view_copy.h>
|
| 276 |
+
#include <ATen/ops/_nested_compute_contiguous_strides_offsets.h>
|
| 277 |
+
#include <ATen/ops/_nested_from_padded.h>
|
| 278 |
+
#include <ATen/ops/_nested_from_padded_and_nested_example.h>
|
| 279 |
+
#include <ATen/ops/_nested_get_jagged_dummy.h>
|
| 280 |
+
#include <ATen/ops/_nested_get_lengths.h>
|
| 281 |
+
#include <ATen/ops/_nested_get_max_seqlen.h>
|
| 282 |
+
#include <ATen/ops/_nested_get_min_seqlen.h>
|
| 283 |
+
#include <ATen/ops/_nested_get_offsets.h>
|
| 284 |
+
#include <ATen/ops/_nested_get_ragged_idx.h>
|
| 285 |
+
#include <ATen/ops/_nested_get_values.h>
|
| 286 |
+
#include <ATen/ops/_nested_get_values_copy.h>
|
| 287 |
+
#include <ATen/ops/_nested_select_backward.h>
|
| 288 |
+
#include <ATen/ops/_nested_sum_backward.h>
|
| 289 |
+
#include <ATen/ops/_nested_tensor_from_mask.h>
|
| 290 |
+
#include <ATen/ops/_nested_tensor_from_mask_left_aligned.h>
|
| 291 |
+
#include <ATen/ops/_nested_tensor_from_tensor_list.h>
|
| 292 |
+
#include <ATen/ops/_nested_tensor_size.h>
|
| 293 |
+
#include <ATen/ops/_nested_tensor_softmax_with_shape.h>
|
| 294 |
+
#include <ATen/ops/_nested_tensor_storage_offsets.h>
|
| 295 |
+
#include <ATen/ops/_nested_tensor_strides.h>
|
| 296 |
+
#include <ATen/ops/_nested_view_from_buffer.h>
|
| 297 |
+
#include <ATen/ops/_nested_view_from_buffer_copy.h>
|
| 298 |
+
#include <ATen/ops/_nested_view_from_jagged.h>
|
| 299 |
+
#include <ATen/ops/_nested_view_from_jagged_copy.h>
|
| 300 |
+
#include <ATen/ops/_new_zeros_with_same_feature_meta.h>
|
| 301 |
+
#include <ATen/ops/_nnpack_available.h>
|
| 302 |
+
#include <ATen/ops/_nnpack_spatial_convolution.h>
|
| 303 |
+
#include <ATen/ops/_nnz.h>
|
| 304 |
+
#include <ATen/ops/_pack_padded_sequence.h>
|
| 305 |
+
#include <ATen/ops/_pack_padded_sequence_backward.h>
|
| 306 |
+
#include <ATen/ops/_pad_circular.h>
|
| 307 |
+
#include <ATen/ops/_pad_enum.h>
|
| 308 |
+
#include <ATen/ops/_pad_packed_sequence.h>
|
| 309 |
+
#include <ATen/ops/_padded_dense_to_jagged_forward.h>
|
| 310 |
+
#include <ATen/ops/_pdist_backward.h>
|
| 311 |
+
#include <ATen/ops/_pdist_forward.h>
|
| 312 |
+
#include <ATen/ops/_pin_memory.h>
|
| 313 |
+
#include <ATen/ops/_prelu_kernel.h>
|
| 314 |
+
#include <ATen/ops/_prelu_kernel_backward.h>
|
| 315 |
+
#include <ATen/ops/_print.h>
|
| 316 |
+
#include <ATen/ops/_propagate_xla_data.h>
|
| 317 |
+
#include <ATen/ops/_remove_batch_dim.h>
|
| 318 |
+
#include <ATen/ops/_reshape_alias.h>
|
| 319 |
+
#include <ATen/ops/_reshape_alias_copy.h>
|
| 320 |
+
#include <ATen/ops/_reshape_copy.h>
|
| 321 |
+
#include <ATen/ops/_reshape_from_tensor.h>
|
| 322 |
+
#include <ATen/ops/_resize_output.h>
|
| 323 |
+
#include <ATen/ops/_rowwise_prune.h>
|
| 324 |
+
#include <ATen/ops/_safe_softmax.h>
|
| 325 |
+
#include <ATen/ops/_sample_dirichlet.h>
|
| 326 |
+
#include <ATen/ops/_saturate_weight_to_fp16.h>
|
| 327 |
+
#include <ATen/ops/_scaled_dot_product_attention_math.h>
|
| 328 |
+
#include <ATen/ops/_scaled_dot_product_attention_math_for_mps.h>
|
| 329 |
+
#include <ATen/ops/_scaled_dot_product_cudnn_attention.h>
|
| 330 |
+
#include <ATen/ops/_scaled_dot_product_cudnn_attention_backward.h>
|
| 331 |
+
#include <ATen/ops/_scaled_dot_product_efficient_attention.h>
|
| 332 |
+
#include <ATen/ops/_scaled_dot_product_efficient_attention_backward.h>
|
| 333 |
+
#include <ATen/ops/_scaled_dot_product_flash_attention.h>
|
| 334 |
+
#include <ATen/ops/_scaled_dot_product_flash_attention_backward.h>
|
| 335 |
+
#include <ATen/ops/_scaled_dot_product_flash_attention_for_cpu.h>
|
| 336 |
+
#include <ATen/ops/_scaled_dot_product_flash_attention_for_cpu_backward.h>
|
| 337 |
+
#include <ATen/ops/_scaled_dot_product_fused_attention_overrideable.h>
|
| 338 |
+
#include <ATen/ops/_scaled_dot_product_fused_attention_overrideable_backward.h>
|
| 339 |
+
#include <ATen/ops/_scaled_mm.h>
|
| 340 |
+
#include <ATen/ops/_segment_reduce_backward.h>
|
| 341 |
+
#include <ATen/ops/_shape_as_tensor.h>
|
| 342 |
+
#include <ATen/ops/_slow_conv2d_backward.h>
|
| 343 |
+
#include <ATen/ops/_slow_conv2d_forward.h>
|
| 344 |
+
#include <ATen/ops/_sobol_engine_draw.h>
|
| 345 |
+
#include <ATen/ops/_sobol_engine_ff.h>
|
| 346 |
+
#include <ATen/ops/_sobol_engine_initialize_state.h>
|
| 347 |
+
#include <ATen/ops/_sobol_engine_scramble.h>
|
| 348 |
+
#include <ATen/ops/_softmax.h>
|
| 349 |
+
#include <ATen/ops/_softmax_backward_data.h>
|
| 350 |
+
#include <ATen/ops/_sparse_addmm.h>
|
| 351 |
+
#include <ATen/ops/_sparse_broadcast_to.h>
|
| 352 |
+
#include <ATen/ops/_sparse_broadcast_to_copy.h>
|
| 353 |
+
#include <ATen/ops/_sparse_bsc_tensor_unsafe.h>
|
| 354 |
+
#include <ATen/ops/_sparse_bsr_tensor_unsafe.h>
|
| 355 |
+
#include <ATen/ops/_sparse_compressed_tensor_unsafe.h>
|
| 356 |
+
#include <ATen/ops/_sparse_compressed_tensor_with_dims.h>
|
| 357 |
+
#include <ATen/ops/_sparse_coo_tensor_unsafe.h>
|
| 358 |
+
#include <ATen/ops/_sparse_coo_tensor_with_dims.h>
|
| 359 |
+
#include <ATen/ops/_sparse_coo_tensor_with_dims_and_tensors.h>
|
| 360 |
+
#include <ATen/ops/_sparse_csc_tensor_unsafe.h>
|
| 361 |
+
#include <ATen/ops/_sparse_csr_prod.h>
|
| 362 |
+
#include <ATen/ops/_sparse_csr_sum.h>
|
| 363 |
+
#include <ATen/ops/_sparse_csr_tensor_unsafe.h>
|
| 364 |
+
#include <ATen/ops/_sparse_log_softmax.h>
|
| 365 |
+
#include <ATen/ops/_sparse_log_softmax_backward_data.h>
|
| 366 |
+
#include <ATen/ops/_sparse_mask_projection.h>
|
| 367 |
+
#include <ATen/ops/_sparse_mm.h>
|
| 368 |
+
#include <ATen/ops/_sparse_mm_reduce_impl.h>
|
| 369 |
+
#include <ATen/ops/_sparse_mm_reduce_impl_backward.h>
|
| 370 |
+
#include <ATen/ops/_sparse_semi_structured_addmm.h>
|
| 371 |
+
#include <ATen/ops/_sparse_semi_structured_apply.h>
|
| 372 |
+
#include <ATen/ops/_sparse_semi_structured_apply_dense.h>
|
| 373 |
+
#include <ATen/ops/_sparse_semi_structured_linear.h>
|
| 374 |
+
#include <ATen/ops/_sparse_semi_structured_mm.h>
|
| 375 |
+
#include <ATen/ops/_sparse_semi_structured_tile.h>
|
| 376 |
+
#include <ATen/ops/_sparse_softmax.h>
|
| 377 |
+
#include <ATen/ops/_sparse_softmax_backward_data.h>
|
| 378 |
+
#include <ATen/ops/_sparse_sparse_matmul.h>
|
| 379 |
+
#include <ATen/ops/_sparse_sum.h>
|
| 380 |
+
#include <ATen/ops/_sparse_sum_backward.h>
|
| 381 |
+
#include <ATen/ops/_spdiags.h>
|
| 382 |
+
#include <ATen/ops/_spsolve.h>
|
| 383 |
+
#include <ATen/ops/_stack.h>
|
| 384 |
+
#include <ATen/ops/_standard_gamma.h>
|
| 385 |
+
#include <ATen/ops/_standard_gamma_grad.h>
|
| 386 |
+
#include <ATen/ops/_test_ambiguous_defaults.h>
|
| 387 |
+
#include <ATen/ops/_test_autograd_multiple_dispatch.h>
|
| 388 |
+
#include <ATen/ops/_test_autograd_multiple_dispatch_view.h>
|
| 389 |
+
#include <ATen/ops/_test_autograd_multiple_dispatch_view_copy.h>
|
| 390 |
+
#include <ATen/ops/_test_check_tensor.h>
|
| 391 |
+
#include <ATen/ops/_test_functorch_fallback.h>
|
| 392 |
+
#include <ATen/ops/_test_optional_filled_intlist.h>
|
| 393 |
+
#include <ATen/ops/_test_optional_floatlist.h>
|
| 394 |
+
#include <ATen/ops/_test_optional_intlist.h>
|
| 395 |
+
#include <ATen/ops/_test_parallel_materialize.h>
|
| 396 |
+
#include <ATen/ops/_test_serialization_subcmul.h>
|
| 397 |
+
#include <ATen/ops/_test_string_default.h>
|
| 398 |
+
#include <ATen/ops/_test_warn_in_autograd.h>
|
| 399 |
+
#include <ATen/ops/_thnn_differentiable_gru_cell_backward.h>
|
| 400 |
+
#include <ATen/ops/_thnn_differentiable_lstm_cell_backward.h>
|
| 401 |
+
#include <ATen/ops/_thnn_fused_gru_cell.h>
|
| 402 |
+
#include <ATen/ops/_thnn_fused_gru_cell_backward.h>
|
| 403 |
+
#include <ATen/ops/_thnn_fused_lstm_cell.h>
|
| 404 |
+
#include <ATen/ops/_thnn_fused_lstm_cell_backward.h>
|
| 405 |
+
#include <ATen/ops/_thnn_fused_lstm_cell_backward_impl.h>
|
| 406 |
+
#include <ATen/ops/_to_copy.h>
|
| 407 |
+
#include <ATen/ops/_to_cpu.h>
|
| 408 |
+
#include <ATen/ops/_to_dense.h>
|
| 409 |
+
#include <ATen/ops/_to_sparse.h>
|
| 410 |
+
#include <ATen/ops/_to_sparse_bsc.h>
|
| 411 |
+
#include <ATen/ops/_to_sparse_bsr.h>
|
| 412 |
+
#include <ATen/ops/_to_sparse_csc.h>
|
| 413 |
+
#include <ATen/ops/_to_sparse_csr.h>
|
| 414 |
+
#include <ATen/ops/_to_sparse_semi_structured.h>
|
| 415 |
+
#include <ATen/ops/_transform_bias_rescale_qkv.h>
|
| 416 |
+
#include <ATen/ops/_transformer_encoder_layer_fwd.h>
|
| 417 |
+
#include <ATen/ops/_trilinear.h>
|
| 418 |
+
#include <ATen/ops/_triton_multi_head_attention.h>
|
| 419 |
+
#include <ATen/ops/_triton_scaled_dot_attention.h>
|
| 420 |
+
#include <ATen/ops/_unique.h>
|
| 421 |
+
#include <ATen/ops/_unique2.h>
|
| 422 |
+
#include <ATen/ops/_unpack_dual.h>
|
| 423 |
+
#include <ATen/ops/_unsafe_index.h>
|
| 424 |
+
#include <ATen/ops/_unsafe_index_put.h>
|
| 425 |
+
#include <ATen/ops/_unsafe_masked_index.h>
|
| 426 |
+
#include <ATen/ops/_unsafe_masked_index_put_accumulate.h>
|
| 427 |
+
#include <ATen/ops/_unsafe_view.h>
|
| 428 |
+
#include <ATen/ops/_upsample_bicubic2d_aa.h>
|
| 429 |
+
#include <ATen/ops/_upsample_bicubic2d_aa_backward.h>
|
| 430 |
+
#include <ATen/ops/_upsample_bilinear2d_aa.h>
|
| 431 |
+
#include <ATen/ops/_upsample_bilinear2d_aa_backward.h>
|
| 432 |
+
#include <ATen/ops/_upsample_nearest_exact1d.h>
|
| 433 |
+
#include <ATen/ops/_upsample_nearest_exact1d_backward.h>
|
| 434 |
+
#include <ATen/ops/_upsample_nearest_exact2d.h>
|
| 435 |
+
#include <ATen/ops/_upsample_nearest_exact2d_backward.h>
|
| 436 |
+
#include <ATen/ops/_upsample_nearest_exact3d.h>
|
| 437 |
+
#include <ATen/ops/_upsample_nearest_exact3d_backward.h>
|
| 438 |
+
#include <ATen/ops/_use_cudnn_ctc_loss.h>
|
| 439 |
+
#include <ATen/ops/_use_cudnn_rnn_flatten_weight.h>
|
| 440 |
+
#include <ATen/ops/_validate_compressed_sparse_indices.h>
|
| 441 |
+
#include <ATen/ops/_validate_sparse_bsc_tensor_args.h>
|
| 442 |
+
#include <ATen/ops/_validate_sparse_bsr_tensor_args.h>
|
| 443 |
+
#include <ATen/ops/_validate_sparse_compressed_tensor_args.h>
|
| 444 |
+
#include <ATen/ops/_validate_sparse_coo_tensor_args.h>
|
| 445 |
+
#include <ATen/ops/_validate_sparse_csc_tensor_args.h>
|
| 446 |
+
#include <ATen/ops/_validate_sparse_csr_tensor_args.h>
|
| 447 |
+
#include <ATen/ops/_values.h>
|
| 448 |
+
#include <ATen/ops/_values_copy.h>
|
| 449 |
+
#include <ATen/ops/_version.h>
|
| 450 |
+
#include <ATen/ops/_weight_int4pack_mm.h>
|
| 451 |
+
#include <ATen/ops/_weight_int8pack_mm.h>
|
| 452 |
+
#include <ATen/ops/_weight_norm.h>
|
| 453 |
+
#include <ATen/ops/_weight_norm_differentiable_backward.h>
|
| 454 |
+
#include <ATen/ops/_weight_norm_interface.h>
|
| 455 |
+
#include <ATen/ops/_weight_norm_interface_backward.h>
|
| 456 |
+
#include <ATen/ops/_wrapped_linear_prepack.h>
|
| 457 |
+
#include <ATen/ops/_wrapped_quantized_linear_prepacked.h>
|
| 458 |
+
#include <ATen/ops/abs.h>
|
| 459 |
+
#include <ATen/ops/absolute.h>
|
| 460 |
+
#include <ATen/ops/acos.h>
|
| 461 |
+
#include <ATen/ops/acosh.h>
|
| 462 |
+
#include <ATen/ops/adaptive_avg_pool1d.h>
|
| 463 |
+
#include <ATen/ops/adaptive_avg_pool2d.h>
|
| 464 |
+
#include <ATen/ops/adaptive_avg_pool3d.h>
|
| 465 |
+
#include <ATen/ops/adaptive_avg_pool3d_backward.h>
|
| 466 |
+
#include <ATen/ops/adaptive_max_pool1d.h>
|
| 467 |
+
#include <ATen/ops/adaptive_max_pool2d.h>
|
| 468 |
+
#include <ATen/ops/adaptive_max_pool2d_backward.h>
|
| 469 |
+
#include <ATen/ops/adaptive_max_pool3d.h>
|
| 470 |
+
#include <ATen/ops/adaptive_max_pool3d_backward.h>
|
| 471 |
+
#include <ATen/ops/add.h>
|
| 472 |
+
#include <ATen/ops/addbmm.h>
|
| 473 |
+
#include <ATen/ops/addcdiv.h>
|
| 474 |
+
#include <ATen/ops/addcmul.h>
|
| 475 |
+
#include <ATen/ops/addmm.h>
|
| 476 |
+
#include <ATen/ops/addmv.h>
|
| 477 |
+
#include <ATen/ops/addr.h>
|
| 478 |
+
#include <ATen/ops/adjoint.h>
|
| 479 |
+
#include <ATen/ops/affine_grid_generator.h>
|
| 480 |
+
#include <ATen/ops/affine_grid_generator_backward.h>
|
| 481 |
+
#include <ATen/ops/alias.h>
|
| 482 |
+
#include <ATen/ops/alias_copy.h>
|
| 483 |
+
#include <ATen/ops/align_as.h>
|
| 484 |
+
#include <ATen/ops/align_tensors.h>
|
| 485 |
+
#include <ATen/ops/align_to.h>
|
| 486 |
+
#include <ATen/ops/all.h>
|
| 487 |
+
#include <ATen/ops/allclose.h>
|
| 488 |
+
#include <ATen/ops/alpha_dropout.h>
|
| 489 |
+
#include <ATen/ops/amax.h>
|
| 490 |
+
#include <ATen/ops/amin.h>
|
| 491 |
+
#include <ATen/ops/aminmax.h>
|
| 492 |
+
#include <ATen/ops/and.h>
|
| 493 |
+
#include <ATen/ops/angle.h>
|
| 494 |
+
#include <ATen/ops/any.h>
|
| 495 |
+
#include <ATen/ops/arange.h>
|
| 496 |
+
#include <ATen/ops/arccos.h>
|
| 497 |
+
#include <ATen/ops/arccosh.h>
|
| 498 |
+
#include <ATen/ops/arcsin.h>
|
| 499 |
+
#include <ATen/ops/arcsinh.h>
|
| 500 |
+
#include <ATen/ops/arctan.h>
|
| 501 |
+
#include <ATen/ops/arctan2.h>
|
| 502 |
+
#include <ATen/ops/arctanh.h>
|
| 503 |
+
#include <ATen/ops/argmax.h>
|
| 504 |
+
#include <ATen/ops/argmin.h>
|
| 505 |
+
#include <ATen/ops/argsort.h>
|
| 506 |
+
#include <ATen/ops/argwhere.h>
|
| 507 |
+
#include <ATen/ops/as_strided.h>
|
| 508 |
+
#include <ATen/ops/as_strided_copy.h>
|
| 509 |
+
#include <ATen/ops/as_strided_scatter.h>
|
| 510 |
+
#include <ATen/ops/asin.h>
|
| 511 |
+
#include <ATen/ops/asinh.h>
|
| 512 |
+
#include <ATen/ops/atan.h>
|
| 513 |
+
#include <ATen/ops/atan2.h>
|
| 514 |
+
#include <ATen/ops/atanh.h>
|
| 515 |
+
#include <ATen/ops/atleast_1d.h>
|
| 516 |
+
#include <ATen/ops/atleast_2d.h>
|
| 517 |
+
#include <ATen/ops/atleast_3d.h>
|
| 518 |
+
#include <ATen/ops/avg_pool1d.h>
|
| 519 |
+
#include <ATen/ops/avg_pool2d.h>
|
| 520 |
+
#include <ATen/ops/avg_pool2d_backward.h>
|
| 521 |
+
#include <ATen/ops/avg_pool3d.h>
|
| 522 |
+
#include <ATen/ops/avg_pool3d_backward.h>
|
| 523 |
+
#include <ATen/ops/baddbmm.h>
|
| 524 |
+
#include <ATen/ops/bartlett_window.h>
|
| 525 |
+
#include <ATen/ops/batch_norm.h>
|
| 526 |
+
#include <ATen/ops/batch_norm_backward.h>
|
| 527 |
+
#include <ATen/ops/batch_norm_backward_elemt.h>
|
| 528 |
+
#include <ATen/ops/batch_norm_backward_reduce.h>
|
| 529 |
+
#include <ATen/ops/batch_norm_elemt.h>
|
| 530 |
+
#include <ATen/ops/batch_norm_gather_stats.h>
|
| 531 |
+
#include <ATen/ops/batch_norm_gather_stats_with_counts.h>
|
| 532 |
+
#include <ATen/ops/batch_norm_stats.h>
|
| 533 |
+
#include <ATen/ops/batch_norm_update_stats.h>
|
| 534 |
+
#include <ATen/ops/bernoulli.h>
|
| 535 |
+
#include <ATen/ops/bilinear.h>
|
| 536 |
+
#include <ATen/ops/binary_cross_entropy.h>
|
| 537 |
+
#include <ATen/ops/binary_cross_entropy_backward.h>
|
| 538 |
+
#include <ATen/ops/binary_cross_entropy_with_logits.h>
|
| 539 |
+
#include <ATen/ops/bincount.h>
|
| 540 |
+
#include <ATen/ops/binomial.h>
|
| 541 |
+
#include <ATen/ops/bitwise_and.h>
|
| 542 |
+
#include <ATen/ops/bitwise_left_shift.h>
|
| 543 |
+
#include <ATen/ops/bitwise_not.h>
|
| 544 |
+
#include <ATen/ops/bitwise_or.h>
|
| 545 |
+
#include <ATen/ops/bitwise_right_shift.h>
|
| 546 |
+
#include <ATen/ops/bitwise_xor.h>
|
| 547 |
+
#include <ATen/ops/blackman_window.h>
|
| 548 |
+
#include <ATen/ops/block_diag.h>
|
| 549 |
+
#include <ATen/ops/bmm.h>
|
| 550 |
+
#include <ATen/ops/broadcast_tensors.h>
|
| 551 |
+
#include <ATen/ops/broadcast_to.h>
|
| 552 |
+
#include <ATen/ops/bucketize.h>
|
| 553 |
+
#include <ATen/ops/can_cast.h>
|
| 554 |
+
#include <ATen/ops/cartesian_prod.h>
|
| 555 |
+
#include <ATen/ops/cat.h>
|
| 556 |
+
#include <ATen/ops/cauchy.h>
|
| 557 |
+
#include <ATen/ops/ccol_indices.h>
|
| 558 |
+
#include <ATen/ops/ccol_indices_copy.h>
|
| 559 |
+
#include <ATen/ops/cdist.h>
|
| 560 |
+
#include <ATen/ops/ceil.h>
|
| 561 |
+
#include <ATen/ops/celu.h>
|
| 562 |
+
#include <ATen/ops/chain_matmul.h>
|
| 563 |
+
#include <ATen/ops/chalf.h>
|
| 564 |
+
#include <ATen/ops/channel_shuffle.h>
|
| 565 |
+
#include <ATen/ops/cholesky.h>
|
| 566 |
+
#include <ATen/ops/cholesky_inverse.h>
|
| 567 |
+
#include <ATen/ops/cholesky_solve.h>
|
| 568 |
+
#include <ATen/ops/choose_qparams_optimized.h>
|
| 569 |
+
#include <ATen/ops/chunk.h>
|
| 570 |
+
#include <ATen/ops/clamp.h>
|
| 571 |
+
#include <ATen/ops/clamp_max.h>
|
| 572 |
+
#include <ATen/ops/clamp_min.h>
|
| 573 |
+
#include <ATen/ops/clip.h>
|
| 574 |
+
#include <ATen/ops/clone.h>
|
| 575 |
+
#include <ATen/ops/coalesce.h>
|
| 576 |
+
#include <ATen/ops/col2im.h>
|
| 577 |
+
#include <ATen/ops/col_indices.h>
|
| 578 |
+
#include <ATen/ops/col_indices_copy.h>
|
| 579 |
+
#include <ATen/ops/column_stack.h>
|
| 580 |
+
#include <ATen/ops/combinations.h>
|
| 581 |
+
#include <ATen/ops/complex.h>
|
| 582 |
+
#include <ATen/ops/concat.h>
|
| 583 |
+
#include <ATen/ops/concatenate.h>
|
| 584 |
+
#include <ATen/ops/conj.h>
|
| 585 |
+
#include <ATen/ops/conj_physical.h>
|
| 586 |
+
#include <ATen/ops/constant_pad_nd.h>
|
| 587 |
+
#include <ATen/ops/contiguous.h>
|
| 588 |
+
#include <ATen/ops/conv1d.h>
|
| 589 |
+
#include <ATen/ops/conv2d.h>
|
| 590 |
+
#include <ATen/ops/conv3d.h>
|
| 591 |
+
#include <ATen/ops/conv_depthwise3d.h>
|
| 592 |
+
#include <ATen/ops/conv_tbc.h>
|
| 593 |
+
#include <ATen/ops/conv_tbc_backward.h>
|
| 594 |
+
#include <ATen/ops/conv_transpose1d.h>
|
| 595 |
+
#include <ATen/ops/conv_transpose2d.h>
|
| 596 |
+
#include <ATen/ops/conv_transpose3d.h>
|
| 597 |
+
#include <ATen/ops/convolution.h>
|
| 598 |
+
#include <ATen/ops/convolution_backward.h>
|
| 599 |
+
#include <ATen/ops/convolution_backward_overrideable.h>
|
| 600 |
+
#include <ATen/ops/convolution_overrideable.h>
|
| 601 |
+
#include <ATen/ops/copy.h>
|
| 602 |
+
#include <ATen/ops/copy_sparse_to_sparse.h>
|
| 603 |
+
#include <ATen/ops/copysign.h>
|
| 604 |
+
#include <ATen/ops/corrcoef.h>
|
| 605 |
+
#include <ATen/ops/cos.h>
|
| 606 |
+
#include <ATen/ops/cosh.h>
|
| 607 |
+
#include <ATen/ops/cosine_embedding_loss.h>
|
| 608 |
+
#include <ATen/ops/cosine_similarity.h>
|
| 609 |
+
#include <ATen/ops/count_nonzero.h>
|
| 610 |
+
#include <ATen/ops/cov.h>
|
| 611 |
+
#include <ATen/ops/cross.h>
|
| 612 |
+
#include <ATen/ops/cross_entropy_loss.h>
|
| 613 |
+
#include <ATen/ops/crow_indices.h>
|
| 614 |
+
#include <ATen/ops/crow_indices_copy.h>
|
| 615 |
+
#include <ATen/ops/ctc_loss.h>
|
| 616 |
+
#include <ATen/ops/cudnn_affine_grid_generator.h>
|
| 617 |
+
#include <ATen/ops/cudnn_affine_grid_generator_backward.h>
|
| 618 |
+
#include <ATen/ops/cudnn_batch_norm.h>
|
| 619 |
+
#include <ATen/ops/cudnn_batch_norm_backward.h>
|
| 620 |
+
#include <ATen/ops/cudnn_convolution.h>
|
| 621 |
+
#include <ATen/ops/cudnn_convolution_add_relu.h>
|
| 622 |
+
#include <ATen/ops/cudnn_convolution_relu.h>
|
| 623 |
+
#include <ATen/ops/cudnn_convolution_transpose.h>
|
| 624 |
+
#include <ATen/ops/cudnn_grid_sampler.h>
|
| 625 |
+
#include <ATen/ops/cudnn_grid_sampler_backward.h>
|
| 626 |
+
#include <ATen/ops/cudnn_is_acceptable.h>
|
| 627 |
+
#include <ATen/ops/cummax.h>
|
| 628 |
+
#include <ATen/ops/cummaxmin_backward.h>
|
| 629 |
+
#include <ATen/ops/cummin.h>
|
| 630 |
+
#include <ATen/ops/cumprod.h>
|
| 631 |
+
#include <ATen/ops/cumprod_backward.h>
|
| 632 |
+
#include <ATen/ops/cumsum.h>
|
| 633 |
+
#include <ATen/ops/cumulative_trapezoid.h>
|
| 634 |
+
#include <ATen/ops/data.h>
|
| 635 |
+
#include <ATen/ops/deg2rad.h>
|
| 636 |
+
#include <ATen/ops/dense_dim.h>
|
| 637 |
+
#include <ATen/ops/dequantize.h>
|
| 638 |
+
#include <ATen/ops/det.h>
|
| 639 |
+
#include <ATen/ops/detach.h>
|
| 640 |
+
#include <ATen/ops/detach_copy.h>
|
| 641 |
+
#include <ATen/ops/diag.h>
|
| 642 |
+
#include <ATen/ops/diag_embed.h>
|
| 643 |
+
#include <ATen/ops/diagflat.h>
|
| 644 |
+
#include <ATen/ops/diagonal.h>
|
| 645 |
+
#include <ATen/ops/diagonal_backward.h>
|
| 646 |
+
#include <ATen/ops/diagonal_copy.h>
|
| 647 |
+
#include <ATen/ops/diagonal_scatter.h>
|
| 648 |
+
#include <ATen/ops/diff.h>
|
| 649 |
+
#include <ATen/ops/digamma.h>
|
| 650 |
+
#include <ATen/ops/dist.h>
|
| 651 |
+
#include <ATen/ops/div.h>
|
| 652 |
+
#include <ATen/ops/divide.h>
|
| 653 |
+
#include <ATen/ops/dot.h>
|
| 654 |
+
#include <ATen/ops/dropout.h>
|
| 655 |
+
#include <ATen/ops/dsplit.h>
|
| 656 |
+
#include <ATen/ops/dstack.h>
|
| 657 |
+
#include <ATen/ops/einsum.h>
|
| 658 |
+
#include <ATen/ops/elu.h>
|
| 659 |
+
#include <ATen/ops/elu_backward.h>
|
| 660 |
+
#include <ATen/ops/embedding.h>
|
| 661 |
+
#include <ATen/ops/embedding_backward.h>
|
| 662 |
+
#include <ATen/ops/embedding_bag.h>
|
| 663 |
+
#include <ATen/ops/embedding_dense_backward.h>
|
| 664 |
+
#include <ATen/ops/embedding_renorm.h>
|
| 665 |
+
#include <ATen/ops/embedding_sparse_backward.h>
|
| 666 |
+
#include <ATen/ops/empty.h>
|
| 667 |
+
#include <ATen/ops/empty_like.h>
|
| 668 |
+
#include <ATen/ops/empty_permuted.h>
|
| 669 |
+
#include <ATen/ops/empty_quantized.h>
|
| 670 |
+
#include <ATen/ops/empty_strided.h>
|
| 671 |
+
#include <ATen/ops/eq.h>
|
| 672 |
+
#include <ATen/ops/equal.h>
|
| 673 |
+
#include <ATen/ops/erf.h>
|
| 674 |
+
#include <ATen/ops/erfc.h>
|
| 675 |
+
#include <ATen/ops/erfinv.h>
|
| 676 |
+
#include <ATen/ops/exp.h>
|
| 677 |
+
#include <ATen/ops/exp2.h>
|
| 678 |
+
#include <ATen/ops/expand.h>
|
| 679 |
+
#include <ATen/ops/expand_as.h>
|
| 680 |
+
#include <ATen/ops/expand_copy.h>
|
| 681 |
+
#include <ATen/ops/expm1.h>
|
| 682 |
+
#include <ATen/ops/exponential.h>
|
| 683 |
+
#include <ATen/ops/eye.h>
|
| 684 |
+
#include <ATen/ops/fake_quantize_per_channel_affine.h>
|
| 685 |
+
#include <ATen/ops/fake_quantize_per_channel_affine_cachemask.h>
|
| 686 |
+
#include <ATen/ops/fake_quantize_per_channel_affine_cachemask_backward.h>
|
| 687 |
+
#include <ATen/ops/fake_quantize_per_tensor_affine.h>
|
| 688 |
+
#include <ATen/ops/fake_quantize_per_tensor_affine_cachemask.h>
|
| 689 |
+
#include <ATen/ops/fake_quantize_per_tensor_affine_cachemask_backward.h>
|
| 690 |
+
#include <ATen/ops/fbgemm_linear_fp16_weight.h>
|
| 691 |
+
#include <ATen/ops/fbgemm_linear_fp16_weight_fp32_activation.h>
|
| 692 |
+
#include <ATen/ops/fbgemm_linear_int8_weight.h>
|
| 693 |
+
#include <ATen/ops/fbgemm_linear_int8_weight_fp32_activation.h>
|
| 694 |
+
#include <ATen/ops/fbgemm_linear_quantize_weight.h>
|
| 695 |
+
#include <ATen/ops/fbgemm_pack_gemm_matrix_fp16.h>
|
| 696 |
+
#include <ATen/ops/fbgemm_pack_quantized_matrix.h>
|
| 697 |
+
#include <ATen/ops/feature_alpha_dropout.h>
|
| 698 |
+
#include <ATen/ops/feature_dropout.h>
|
| 699 |
+
#include <ATen/ops/fft_fft.h>
|
| 700 |
+
#include <ATen/ops/fft_fft2.h>
|
| 701 |
+
#include <ATen/ops/fft_fftfreq.h>
|
| 702 |
+
#include <ATen/ops/fft_fftn.h>
|
| 703 |
+
#include <ATen/ops/fft_fftshift.h>
|
| 704 |
+
#include <ATen/ops/fft_hfft.h>
|
| 705 |
+
#include <ATen/ops/fft_hfft2.h>
|
| 706 |
+
#include <ATen/ops/fft_hfftn.h>
|
| 707 |
+
#include <ATen/ops/fft_ifft.h>
|
| 708 |
+
#include <ATen/ops/fft_ifft2.h>
|
| 709 |
+
#include <ATen/ops/fft_ifftn.h>
|
| 710 |
+
#include <ATen/ops/fft_ifftshift.h>
|
| 711 |
+
#include <ATen/ops/fft_ihfft.h>
|
| 712 |
+
#include <ATen/ops/fft_ihfft2.h>
|
| 713 |
+
#include <ATen/ops/fft_ihfftn.h>
|
| 714 |
+
#include <ATen/ops/fft_irfft.h>
|
| 715 |
+
#include <ATen/ops/fft_irfft2.h>
|
| 716 |
+
#include <ATen/ops/fft_irfftn.h>
|
| 717 |
+
#include <ATen/ops/fft_rfft.h>
|
| 718 |
+
#include <ATen/ops/fft_rfft2.h>
|
| 719 |
+
#include <ATen/ops/fft_rfftfreq.h>
|
| 720 |
+
#include <ATen/ops/fft_rfftn.h>
|
| 721 |
+
#include <ATen/ops/fill.h>
|
| 722 |
+
#include <ATen/ops/fill_diagonal.h>
|
| 723 |
+
#include <ATen/ops/fix.h>
|
| 724 |
+
#include <ATen/ops/flatten.h>
|
| 725 |
+
#include <ATen/ops/flatten_dense_tensors.h>
|
| 726 |
+
#include <ATen/ops/flip.h>
|
| 727 |
+
#include <ATen/ops/fliplr.h>
|
| 728 |
+
#include <ATen/ops/flipud.h>
|
| 729 |
+
#include <ATen/ops/float_power.h>
|
| 730 |
+
#include <ATen/ops/floor.h>
|
| 731 |
+
#include <ATen/ops/floor_divide.h>
|
| 732 |
+
#include <ATen/ops/fmax.h>
|
| 733 |
+
#include <ATen/ops/fmin.h>
|
| 734 |
+
#include <ATen/ops/fmod.h>
|
| 735 |
+
#include <ATen/ops/frac.h>
|
| 736 |
+
#include <ATen/ops/fractional_max_pool2d.h>
|
| 737 |
+
#include <ATen/ops/fractional_max_pool2d_backward.h>
|
| 738 |
+
#include <ATen/ops/fractional_max_pool3d.h>
|
| 739 |
+
#include <ATen/ops/fractional_max_pool3d_backward.h>
|
| 740 |
+
#include <ATen/ops/frexp.h>
|
| 741 |
+
#include <ATen/ops/frobenius_norm.h>
|
| 742 |
+
#include <ATen/ops/from_file.h>
|
| 743 |
+
#include <ATen/ops/full.h>
|
| 744 |
+
#include <ATen/ops/full_like.h>
|
| 745 |
+
#include <ATen/ops/fused_moving_avg_obs_fake_quant.h>
|
| 746 |
+
#include <ATen/ops/gather.h>
|
| 747 |
+
#include <ATen/ops/gather_backward.h>
|
| 748 |
+
#include <ATen/ops/gcd.h>
|
| 749 |
+
#include <ATen/ops/ge.h>
|
| 750 |
+
#include <ATen/ops/gelu.h>
|
| 751 |
+
#include <ATen/ops/gelu_backward.h>
|
| 752 |
+
#include <ATen/ops/geometric.h>
|
| 753 |
+
#include <ATen/ops/geqrf.h>
|
| 754 |
+
#include <ATen/ops/ger.h>
|
| 755 |
+
#include <ATen/ops/glu.h>
|
| 756 |
+
#include <ATen/ops/glu_backward.h>
|
| 757 |
+
#include <ATen/ops/glu_backward_jvp.h>
|
| 758 |
+
#include <ATen/ops/glu_jvp.h>
|
| 759 |
+
#include <ATen/ops/gradient.h>
|
| 760 |
+
#include <ATen/ops/greater.h>
|
| 761 |
+
#include <ATen/ops/greater_equal.h>
|
| 762 |
+
#include <ATen/ops/grid_sampler.h>
|
| 763 |
+
#include <ATen/ops/grid_sampler_2d.h>
|
| 764 |
+
#include <ATen/ops/grid_sampler_2d_backward.h>
|
| 765 |
+
#include <ATen/ops/grid_sampler_3d.h>
|
| 766 |
+
#include <ATen/ops/grid_sampler_3d_backward.h>
|
| 767 |
+
#include <ATen/ops/group_norm.h>
|
| 768 |
+
#include <ATen/ops/gru.h>
|
| 769 |
+
#include <ATen/ops/gru_cell.h>
|
| 770 |
+
#include <ATen/ops/gt.h>
|
| 771 |
+
#include <ATen/ops/hamming_window.h>
|
| 772 |
+
#include <ATen/ops/hann_window.h>
|
| 773 |
+
#include <ATen/ops/hardshrink.h>
|
| 774 |
+
#include <ATen/ops/hardshrink_backward.h>
|
| 775 |
+
#include <ATen/ops/hardsigmoid.h>
|
| 776 |
+
#include <ATen/ops/hardsigmoid_backward.h>
|
| 777 |
+
#include <ATen/ops/hardswish.h>
|
| 778 |
+
#include <ATen/ops/hardswish_backward.h>
|
| 779 |
+
#include <ATen/ops/hardtanh.h>
|
| 780 |
+
#include <ATen/ops/hardtanh_backward.h>
|
| 781 |
+
#include <ATen/ops/heaviside.h>
|
| 782 |
+
#include <ATen/ops/hinge_embedding_loss.h>
|
| 783 |
+
#include <ATen/ops/histc.h>
|
| 784 |
+
#include <ATen/ops/histogram.h>
|
| 785 |
+
#include <ATen/ops/histogramdd.h>
|
| 786 |
+
#include <ATen/ops/hsplit.h>
|
| 787 |
+
#include <ATen/ops/hspmm.h>
|
| 788 |
+
#include <ATen/ops/hstack.h>
|
| 789 |
+
#include <ATen/ops/huber_loss.h>
|
| 790 |
+
#include <ATen/ops/huber_loss_backward.h>
|
| 791 |
+
#include <ATen/ops/hypot.h>
|
| 792 |
+
#include <ATen/ops/i0.h>
|
| 793 |
+
#include <ATen/ops/igamma.h>
|
| 794 |
+
#include <ATen/ops/igammac.h>
|
| 795 |
+
#include <ATen/ops/im2col.h>
|
| 796 |
+
#include <ATen/ops/imag.h>
|
| 797 |
+
#include <ATen/ops/index.h>
|
| 798 |
+
#include <ATen/ops/index_add.h>
|
| 799 |
+
#include <ATen/ops/index_copy.h>
|
| 800 |
+
#include <ATen/ops/index_fill.h>
|
| 801 |
+
#include <ATen/ops/index_put.h>
|
| 802 |
+
#include <ATen/ops/index_reduce.h>
|
| 803 |
+
#include <ATen/ops/index_select.h>
|
| 804 |
+
#include <ATen/ops/index_select_backward.h>
|
| 805 |
+
#include <ATen/ops/indices.h>
|
| 806 |
+
#include <ATen/ops/indices_copy.h>
|
| 807 |
+
#include <ATen/ops/infinitely_differentiable_gelu_backward.h>
|
| 808 |
+
#include <ATen/ops/inner.h>
|
| 809 |
+
#include <ATen/ops/instance_norm.h>
|
| 810 |
+
#include <ATen/ops/int_repr.h>
|
| 811 |
+
#include <ATen/ops/inverse.h>
|
| 812 |
+
#include <ATen/ops/is_coalesced.h>
|
| 813 |
+
#include <ATen/ops/is_complex.h>
|
| 814 |
+
#include <ATen/ops/is_conj.h>
|
| 815 |
+
#include <ATen/ops/is_distributed.h>
|
| 816 |
+
#include <ATen/ops/is_floating_point.h>
|
| 817 |
+
#include <ATen/ops/is_inference.h>
|
| 818 |
+
#include <ATen/ops/is_leaf.h>
|
| 819 |
+
#include <ATen/ops/is_neg.h>
|
| 820 |
+
#include <ATen/ops/is_nonzero.h>
|
| 821 |
+
#include <ATen/ops/is_pinned.h>
|
| 822 |
+
#include <ATen/ops/is_same_size.h>
|
| 823 |
+
#include <ATen/ops/is_set_to.h>
|
| 824 |
+
#include <ATen/ops/is_signed.h>
|
| 825 |
+
#include <ATen/ops/is_vulkan_available.h>
|
| 826 |
+
#include <ATen/ops/isclose.h>
|
| 827 |
+
#include <ATen/ops/isfinite.h>
|
| 828 |
+
#include <ATen/ops/isin.h>
|
| 829 |
+
#include <ATen/ops/isinf.h>
|
| 830 |
+
#include <ATen/ops/isnan.h>
|
| 831 |
+
#include <ATen/ops/isneginf.h>
|
| 832 |
+
#include <ATen/ops/isposinf.h>
|
| 833 |
+
#include <ATen/ops/isreal.h>
|
| 834 |
+
#include <ATen/ops/istft.h>
|
| 835 |
+
#include <ATen/ops/item.h>
|
| 836 |
+
#include <ATen/ops/kaiser_window.h>
|
| 837 |
+
#include <ATen/ops/kl_div.h>
|
| 838 |
+
#include <ATen/ops/kron.h>
|
| 839 |
+
#include <ATen/ops/kthvalue.h>
|
| 840 |
+
#include <ATen/ops/l1_loss.h>
|
| 841 |
+
#include <ATen/ops/layer_norm.h>
|
| 842 |
+
#include <ATen/ops/lcm.h>
|
| 843 |
+
#include <ATen/ops/ldexp.h>
|
| 844 |
+
#include <ATen/ops/le.h>
|
| 845 |
+
#include <ATen/ops/leaky_relu.h>
|
| 846 |
+
#include <ATen/ops/leaky_relu_backward.h>
|
| 847 |
+
#include <ATen/ops/lerp.h>
|
| 848 |
+
#include <ATen/ops/less.h>
|
| 849 |
+
#include <ATen/ops/less_equal.h>
|
| 850 |
+
#include <ATen/ops/lgamma.h>
|
| 851 |
+
#include <ATen/ops/lift.h>
|
| 852 |
+
#include <ATen/ops/lift_fresh.h>
|
| 853 |
+
#include <ATen/ops/lift_fresh_copy.h>
|
| 854 |
+
#include <ATen/ops/linalg_cholesky.h>
|
| 855 |
+
#include <ATen/ops/linalg_cholesky_ex.h>
|
| 856 |
+
#include <ATen/ops/linalg_cond.h>
|
| 857 |
+
#include <ATen/ops/linalg_cross.h>
|
| 858 |
+
#include <ATen/ops/linalg_det.h>
|
| 859 |
+
#include <ATen/ops/linalg_diagonal.h>
|
| 860 |
+
#include <ATen/ops/linalg_eig.h>
|
| 861 |
+
#include <ATen/ops/linalg_eigh.h>
|
| 862 |
+
#include <ATen/ops/linalg_eigvals.h>
|
| 863 |
+
#include <ATen/ops/linalg_eigvalsh.h>
|
| 864 |
+
#include <ATen/ops/linalg_householder_product.h>
|
| 865 |
+
#include <ATen/ops/linalg_inv.h>
|
| 866 |
+
#include <ATen/ops/linalg_inv_ex.h>
|
| 867 |
+
#include <ATen/ops/linalg_ldl_factor.h>
|
| 868 |
+
#include <ATen/ops/linalg_ldl_factor_ex.h>
|
| 869 |
+
#include <ATen/ops/linalg_ldl_solve.h>
|
| 870 |
+
#include <ATen/ops/linalg_lstsq.h>
|
| 871 |
+
#include <ATen/ops/linalg_lu.h>
|
| 872 |
+
#include <ATen/ops/linalg_lu_factor.h>
|
| 873 |
+
#include <ATen/ops/linalg_lu_factor_ex.h>
|
| 874 |
+
#include <ATen/ops/linalg_lu_solve.h>
|
| 875 |
+
#include <ATen/ops/linalg_matmul.h>
|
| 876 |
+
#include <ATen/ops/linalg_matrix_exp.h>
|
| 877 |
+
#include <ATen/ops/linalg_matrix_norm.h>
|
| 878 |
+
#include <ATen/ops/linalg_matrix_power.h>
|
| 879 |
+
#include <ATen/ops/linalg_matrix_rank.h>
|
| 880 |
+
#include <ATen/ops/linalg_multi_dot.h>
|
| 881 |
+
#include <ATen/ops/linalg_norm.h>
|
| 882 |
+
#include <ATen/ops/linalg_pinv.h>
|
| 883 |
+
#include <ATen/ops/linalg_qr.h>
|
| 884 |
+
#include <ATen/ops/linalg_slogdet.h>
|
| 885 |
+
#include <ATen/ops/linalg_solve.h>
|
| 886 |
+
#include <ATen/ops/linalg_solve_ex.h>
|
| 887 |
+
#include <ATen/ops/linalg_solve_triangular.h>
|
| 888 |
+
#include <ATen/ops/linalg_svd.h>
|
| 889 |
+
#include <ATen/ops/linalg_svdvals.h>
|
| 890 |
+
#include <ATen/ops/linalg_tensorinv.h>
|
| 891 |
+
#include <ATen/ops/linalg_tensorsolve.h>
|
| 892 |
+
#include <ATen/ops/linalg_vander.h>
|
| 893 |
+
#include <ATen/ops/linalg_vecdot.h>
|
| 894 |
+
#include <ATen/ops/linalg_vector_norm.h>
|
| 895 |
+
#include <ATen/ops/linear.h>
|
| 896 |
+
#include <ATen/ops/linear_backward.h>
|
| 897 |
+
#include <ATen/ops/linspace.h>
|
| 898 |
+
#include <ATen/ops/log.h>
|
| 899 |
+
#include <ATen/ops/log10.h>
|
| 900 |
+
#include <ATen/ops/log1p.h>
|
| 901 |
+
#include <ATen/ops/log2.h>
|
| 902 |
+
#include <ATen/ops/log_normal.h>
|
| 903 |
+
#include <ATen/ops/log_sigmoid.h>
|
| 904 |
+
#include <ATen/ops/log_sigmoid_backward.h>
|
| 905 |
+
#include <ATen/ops/log_sigmoid_forward.h>
|
| 906 |
+
#include <ATen/ops/log_softmax.h>
|
| 907 |
+
#include <ATen/ops/logaddexp.h>
|
| 908 |
+
#include <ATen/ops/logaddexp2.h>
|
| 909 |
+
#include <ATen/ops/logcumsumexp.h>
|
| 910 |
+
#include <ATen/ops/logdet.h>
|
| 911 |
+
#include <ATen/ops/logical_and.h>
|
| 912 |
+
#include <ATen/ops/logical_not.h>
|
| 913 |
+
#include <ATen/ops/logical_or.h>
|
| 914 |
+
#include <ATen/ops/logical_xor.h>
|
| 915 |
+
#include <ATen/ops/logit.h>
|
| 916 |
+
#include <ATen/ops/logit_backward.h>
|
| 917 |
+
#include <ATen/ops/logspace.h>
|
| 918 |
+
#include <ATen/ops/logsumexp.h>
|
| 919 |
+
#include <ATen/ops/lshift.h>
|
| 920 |
+
#include <ATen/ops/lstm.h>
|
| 921 |
+
#include <ATen/ops/lstm_cell.h>
|
| 922 |
+
#include <ATen/ops/lstm_mps_backward.h>
|
| 923 |
+
#include <ATen/ops/lt.h>
|
| 924 |
+
#include <ATen/ops/lu_solve.h>
|
| 925 |
+
#include <ATen/ops/lu_unpack.h>
|
| 926 |
+
#include <ATen/ops/mH.h>
|
| 927 |
+
#include <ATen/ops/mT.h>
|
| 928 |
+
#include <ATen/ops/margin_ranking_loss.h>
|
| 929 |
+
#include <ATen/ops/masked_fill.h>
|
| 930 |
+
#include <ATen/ops/masked_scatter.h>
|
| 931 |
+
#include <ATen/ops/masked_scatter_backward.h>
|
| 932 |
+
#include <ATen/ops/masked_select.h>
|
| 933 |
+
#include <ATen/ops/masked_select_backward.h>
|
| 934 |
+
#include <ATen/ops/matmul.h>
|
| 935 |
+
#include <ATen/ops/matmul_backward.h>
|
| 936 |
+
#include <ATen/ops/matrix_H.h>
|
| 937 |
+
#include <ATen/ops/matrix_exp.h>
|
| 938 |
+
#include <ATen/ops/matrix_exp_backward.h>
|
| 939 |
+
#include <ATen/ops/matrix_power.h>
|
| 940 |
+
#include <ATen/ops/max.h>
|
| 941 |
+
#include <ATen/ops/max_pool1d.h>
|
| 942 |
+
#include <ATen/ops/max_pool1d_with_indices.h>
|
| 943 |
+
#include <ATen/ops/max_pool2d.h>
|
| 944 |
+
#include <ATen/ops/max_pool2d_backward.h>
|
| 945 |
+
#include <ATen/ops/max_pool2d_with_indices.h>
|
| 946 |
+
#include <ATen/ops/max_pool2d_with_indices_backward.h>
|
| 947 |
+
#include <ATen/ops/max_pool3d.h>
|
| 948 |
+
#include <ATen/ops/max_pool3d_with_indices.h>
|
| 949 |
+
#include <ATen/ops/max_pool3d_with_indices_backward.h>
|
| 950 |
+
#include <ATen/ops/max_unpool2d.h>
|
| 951 |
+
#include <ATen/ops/max_unpool3d.h>
|
| 952 |
+
#include <ATen/ops/maximum.h>
|
| 953 |
+
#include <ATen/ops/mean.h>
|
| 954 |
+
#include <ATen/ops/median.h>
|
| 955 |
+
#include <ATen/ops/meshgrid.h>
|
| 956 |
+
#include <ATen/ops/min.h>
|
| 957 |
+
#include <ATen/ops/minimum.h>
|
| 958 |
+
#include <ATen/ops/miopen_batch_norm.h>
|
| 959 |
+
#include <ATen/ops/miopen_batch_norm_backward.h>
|
| 960 |
+
#include <ATen/ops/miopen_convolution.h>
|
| 961 |
+
#include <ATen/ops/miopen_convolution_add_relu.h>
|
| 962 |
+
#include <ATen/ops/miopen_convolution_relu.h>
|
| 963 |
+
#include <ATen/ops/miopen_convolution_transpose.h>
|
| 964 |
+
#include <ATen/ops/miopen_depthwise_convolution.h>
|
| 965 |
+
#include <ATen/ops/miopen_rnn.h>
|
| 966 |
+
#include <ATen/ops/miopen_rnn_backward.h>
|
| 967 |
+
#include <ATen/ops/mish.h>
|
| 968 |
+
#include <ATen/ops/mish_backward.h>
|
| 969 |
+
#include <ATen/ops/mkldnn_adaptive_avg_pool2d.h>
|
| 970 |
+
#include <ATen/ops/mkldnn_adaptive_avg_pool2d_backward.h>
|
| 971 |
+
#include <ATen/ops/mkldnn_convolution.h>
|
| 972 |
+
#include <ATen/ops/mkldnn_linear.h>
|
| 973 |
+
#include <ATen/ops/mkldnn_linear_backward.h>
|
| 974 |
+
#include <ATen/ops/mkldnn_linear_backward_input.h>
|
| 975 |
+
#include <ATen/ops/mkldnn_linear_backward_weights.h>
|
| 976 |
+
#include <ATen/ops/mkldnn_max_pool2d.h>
|
| 977 |
+
#include <ATen/ops/mkldnn_max_pool2d_backward.h>
|
| 978 |
+
#include <ATen/ops/mkldnn_max_pool3d.h>
|
| 979 |
+
#include <ATen/ops/mkldnn_max_pool3d_backward.h>
|
| 980 |
+
#include <ATen/ops/mkldnn_reorder_conv2d_weight.h>
|
| 981 |
+
#include <ATen/ops/mkldnn_reorder_conv3d_weight.h>
|
| 982 |
+
#include <ATen/ops/mkldnn_rnn_layer.h>
|
| 983 |
+
#include <ATen/ops/mkldnn_rnn_layer_backward.h>
|
| 984 |
+
#include <ATen/ops/mm.h>
|
| 985 |
+
#include <ATen/ops/mode.h>
|
| 986 |
+
#include <ATen/ops/moveaxis.h>
|
| 987 |
+
#include <ATen/ops/movedim.h>
|
| 988 |
+
#include <ATen/ops/mps_convolution_backward.h>
|
| 989 |
+
#include <ATen/ops/mps_convolution_transpose_backward.h>
|
| 990 |
+
#include <ATen/ops/mse_loss.h>
|
| 991 |
+
#include <ATen/ops/mse_loss_backward.h>
|
| 992 |
+
#include <ATen/ops/msort.h>
|
| 993 |
+
#include <ATen/ops/mul.h>
|
| 994 |
+
#include <ATen/ops/multi_margin_loss.h>
|
| 995 |
+
#include <ATen/ops/multi_margin_loss_backward.h>
|
| 996 |
+
#include <ATen/ops/multilabel_margin_loss.h>
|
| 997 |
+
#include <ATen/ops/multilabel_margin_loss_backward.h>
|
| 998 |
+
#include <ATen/ops/multilabel_margin_loss_forward.h>
|
| 999 |
+
#include <ATen/ops/multinomial.h>
|
| 1000 |
+
#include <ATen/ops/multiply.h>
|
| 1001 |
+
#include <ATen/ops/mv.h>
|
| 1002 |
+
#include <ATen/ops/mvlgamma.h>
|
| 1003 |
+
#include <ATen/ops/nan_to_num.h>
|
| 1004 |
+
#include <ATen/ops/nanmean.h>
|
| 1005 |
+
#include <ATen/ops/nanmedian.h>
|
| 1006 |
+
#include <ATen/ops/nanquantile.h>
|
| 1007 |
+
#include <ATen/ops/nansum.h>
|
| 1008 |
+
#include <ATen/ops/narrow.h>
|
| 1009 |
+
#include <ATen/ops/narrow_copy.h>
|
| 1010 |
+
#include <ATen/ops/native_batch_norm.h>
|
| 1011 |
+
#include <ATen/ops/native_batch_norm_backward.h>
|
| 1012 |
+
#include <ATen/ops/native_channel_shuffle.h>
|
| 1013 |
+
#include <ATen/ops/native_dropout.h>
|
| 1014 |
+
#include <ATen/ops/native_dropout_backward.h>
|
| 1015 |
+
#include <ATen/ops/native_group_norm.h>
|
| 1016 |
+
#include <ATen/ops/native_group_norm_backward.h>
|
| 1017 |
+
#include <ATen/ops/native_layer_norm.h>
|
| 1018 |
+
#include <ATen/ops/native_layer_norm_backward.h>
|
| 1019 |
+
#include <ATen/ops/native_norm.h>
|
| 1020 |
+
#include <ATen/ops/ne.h>
|
| 1021 |
+
#include <ATen/ops/neg.h>
|
| 1022 |
+
#include <ATen/ops/negative.h>
|
| 1023 |
+
#include <ATen/ops/nested_to_padded_tensor.h>
|
| 1024 |
+
#include <ATen/ops/new_empty.h>
|
| 1025 |
+
#include <ATen/ops/new_empty_strided.h>
|
| 1026 |
+
#include <ATen/ops/new_full.h>
|
| 1027 |
+
#include <ATen/ops/new_ones.h>
|
| 1028 |
+
#include <ATen/ops/new_zeros.h>
|
| 1029 |
+
#include <ATen/ops/nextafter.h>
|
| 1030 |
+
#include <ATen/ops/nll_loss.h>
|
| 1031 |
+
#include <ATen/ops/nll_loss2d.h>
|
| 1032 |
+
#include <ATen/ops/nll_loss2d_backward.h>
|
| 1033 |
+
#include <ATen/ops/nll_loss2d_forward.h>
|
| 1034 |
+
#include <ATen/ops/nll_loss_backward.h>
|
| 1035 |
+
#include <ATen/ops/nll_loss_forward.h>
|
| 1036 |
+
#include <ATen/ops/nll_loss_nd.h>
|
| 1037 |
+
#include <ATen/ops/nonzero.h>
|
| 1038 |
+
#include <ATen/ops/nonzero_numpy.h>
|
| 1039 |
+
#include <ATen/ops/nonzero_static.h>
|
| 1040 |
+
#include <ATen/ops/norm.h>
|
| 1041 |
+
#include <ATen/ops/norm_except_dim.h>
|
| 1042 |
+
#include <ATen/ops/normal.h>
|
| 1043 |
+
#include <ATen/ops/not_equal.h>
|
| 1044 |
+
#include <ATen/ops/nuclear_norm.h>
|
| 1045 |
+
#include <ATen/ops/numpy_T.h>
|
| 1046 |
+
#include <ATen/ops/one_hot.h>
|
| 1047 |
+
#include <ATen/ops/ones.h>
|
| 1048 |
+
#include <ATen/ops/ones_like.h>
|
| 1049 |
+
#include <ATen/ops/or.h>
|
| 1050 |
+
#include <ATen/ops/orgqr.h>
|
| 1051 |
+
#include <ATen/ops/ormqr.h>
|
| 1052 |
+
#include <ATen/ops/outer.h>
|
| 1053 |
+
#include <ATen/ops/output_nr.h>
|
| 1054 |
+
#include <ATen/ops/pad.h>
|
| 1055 |
+
#include <ATen/ops/pad_sequence.h>
|
| 1056 |
+
#include <ATen/ops/pairwise_distance.h>
|
| 1057 |
+
#include <ATen/ops/pdist.h>
|
| 1058 |
+
#include <ATen/ops/permute.h>
|
| 1059 |
+
#include <ATen/ops/permute_copy.h>
|
| 1060 |
+
#include <ATen/ops/pin_memory.h>
|
| 1061 |
+
#include <ATen/ops/pinverse.h>
|
| 1062 |
+
#include <ATen/ops/pixel_shuffle.h>
|
| 1063 |
+
#include <ATen/ops/pixel_unshuffle.h>
|
| 1064 |
+
#include <ATen/ops/poisson.h>
|
| 1065 |
+
#include <ATen/ops/poisson_nll_loss.h>
|
| 1066 |
+
#include <ATen/ops/polar.h>
|
| 1067 |
+
#include <ATen/ops/polygamma.h>
|
| 1068 |
+
#include <ATen/ops/positive.h>
|
| 1069 |
+
#include <ATen/ops/pow.h>
|
| 1070 |
+
#include <ATen/ops/prelu.h>
|
| 1071 |
+
#include <ATen/ops/prod.h>
|
| 1072 |
+
#include <ATen/ops/promote_types.h>
|
| 1073 |
+
#include <ATen/ops/put.h>
|
| 1074 |
+
#include <ATen/ops/q_per_channel_axis.h>
|
| 1075 |
+
#include <ATen/ops/q_per_channel_scales.h>
|
| 1076 |
+
#include <ATen/ops/q_per_channel_zero_points.h>
|
| 1077 |
+
#include <ATen/ops/q_scale.h>
|
| 1078 |
+
#include <ATen/ops/q_zero_point.h>
|
| 1079 |
+
#include <ATen/ops/qr.h>
|
| 1080 |
+
#include <ATen/ops/qscheme.h>
|
| 1081 |
+
#include <ATen/ops/quantile.h>
|
| 1082 |
+
#include <ATen/ops/quantize_per_channel.h>
|
| 1083 |
+
#include <ATen/ops/quantize_per_tensor.h>
|
| 1084 |
+
#include <ATen/ops/quantize_per_tensor_dynamic.h>
|
| 1085 |
+
#include <ATen/ops/quantized_batch_norm.h>
|
| 1086 |
+
#include <ATen/ops/quantized_gru_cell.h>
|
| 1087 |
+
#include <ATen/ops/quantized_lstm_cell.h>
|
| 1088 |
+
#include <ATen/ops/quantized_max_pool1d.h>
|
| 1089 |
+
#include <ATen/ops/quantized_max_pool2d.h>
|
| 1090 |
+
#include <ATen/ops/quantized_max_pool3d.h>
|
| 1091 |
+
#include <ATen/ops/quantized_rnn_relu_cell.h>
|
| 1092 |
+
#include <ATen/ops/quantized_rnn_tanh_cell.h>
|
| 1093 |
+
#include <ATen/ops/rad2deg.h>
|
| 1094 |
+
#include <ATen/ops/rand.h>
|
| 1095 |
+
#include <ATen/ops/rand_like.h>
|
| 1096 |
+
#include <ATen/ops/randint.h>
|
| 1097 |
+
#include <ATen/ops/randint_like.h>
|
| 1098 |
+
#include <ATen/ops/randn.h>
|
| 1099 |
+
#include <ATen/ops/randn_like.h>
|
| 1100 |
+
#include <ATen/ops/random.h>
|
| 1101 |
+
#include <ATen/ops/randperm.h>
|
| 1102 |
+
#include <ATen/ops/range.h>
|
| 1103 |
+
#include <ATen/ops/ravel.h>
|
| 1104 |
+
#include <ATen/ops/real.h>
|
| 1105 |
+
#include <ATen/ops/reciprocal.h>
|
| 1106 |
+
#include <ATen/ops/record_stream.h>
|
| 1107 |
+
#include <ATen/ops/refine_names.h>
|
| 1108 |
+
#include <ATen/ops/reflection_pad1d.h>
|
| 1109 |
+
#include <ATen/ops/reflection_pad1d_backward.h>
|
| 1110 |
+
#include <ATen/ops/reflection_pad2d.h>
|
| 1111 |
+
#include <ATen/ops/reflection_pad2d_backward.h>
|
| 1112 |
+
#include <ATen/ops/reflection_pad3d.h>
|
| 1113 |
+
#include <ATen/ops/reflection_pad3d_backward.h>
|
| 1114 |
+
#include <ATen/ops/relu.h>
|
| 1115 |
+
#include <ATen/ops/relu6.h>
|
| 1116 |
+
#include <ATen/ops/remainder.h>
|
| 1117 |
+
#include <ATen/ops/rename.h>
|
| 1118 |
+
#include <ATen/ops/renorm.h>
|
| 1119 |
+
#include <ATen/ops/repeat.h>
|
| 1120 |
+
#include <ATen/ops/repeat_interleave.h>
|
| 1121 |
+
#include <ATen/ops/replication_pad1d.h>
|
| 1122 |
+
#include <ATen/ops/replication_pad1d_backward.h>
|
| 1123 |
+
#include <ATen/ops/replication_pad2d.h>
|
| 1124 |
+
#include <ATen/ops/replication_pad2d_backward.h>
|
| 1125 |
+
#include <ATen/ops/replication_pad3d.h>
|
| 1126 |
+
#include <ATen/ops/replication_pad3d_backward.h>
|
| 1127 |
+
#include <ATen/ops/requires_grad.h>
|
| 1128 |
+
#include <ATen/ops/reshape.h>
|
| 1129 |
+
#include <ATen/ops/reshape_as.h>
|
| 1130 |
+
#include <ATen/ops/resize.h>
|
| 1131 |
+
#include <ATen/ops/resize_as.h>
|
| 1132 |
+
#include <ATen/ops/resize_as_sparse.h>
|
| 1133 |
+
#include <ATen/ops/resolve_conj.h>
|
| 1134 |
+
#include <ATen/ops/resolve_neg.h>
|
| 1135 |
+
#include <ATen/ops/result_type.h>
|
| 1136 |
+
#include <ATen/ops/retain_grad.h>
|
| 1137 |
+
#include <ATen/ops/retains_grad.h>
|
| 1138 |
+
#include <ATen/ops/rms_norm.h>
|
| 1139 |
+
#include <ATen/ops/rnn_relu.h>
|
| 1140 |
+
#include <ATen/ops/rnn_relu_cell.h>
|
| 1141 |
+
#include <ATen/ops/rnn_tanh.h>
|
| 1142 |
+
#include <ATen/ops/rnn_tanh_cell.h>
|
| 1143 |
+
#include <ATen/ops/roll.h>
|
| 1144 |
+
#include <ATen/ops/rot90.h>
|
| 1145 |
+
#include <ATen/ops/round.h>
|
| 1146 |
+
#include <ATen/ops/row_indices.h>
|
| 1147 |
+
#include <ATen/ops/row_indices_copy.h>
|
| 1148 |
+
#include <ATen/ops/row_stack.h>
|
| 1149 |
+
#include <ATen/ops/rrelu.h>
|
| 1150 |
+
#include <ATen/ops/rrelu_with_noise.h>
|
| 1151 |
+
#include <ATen/ops/rrelu_with_noise_backward.h>
|
| 1152 |
+
#include <ATen/ops/rshift.h>
|
| 1153 |
+
#include <ATen/ops/rsqrt.h>
|
| 1154 |
+
#include <ATen/ops/rsub.h>
|
| 1155 |
+
#include <ATen/ops/scalar_tensor.h>
|
| 1156 |
+
#include <ATen/ops/scaled_dot_product_attention.h>
|
| 1157 |
+
#include <ATen/ops/scatter.h>
|
| 1158 |
+
#include <ATen/ops/scatter_add.h>
|
| 1159 |
+
#include <ATen/ops/scatter_reduce.h>
|
| 1160 |
+
#include <ATen/ops/searchsorted.h>
|
| 1161 |
+
#include <ATen/ops/segment_reduce.h>
|
| 1162 |
+
#include <ATen/ops/select.h>
|
| 1163 |
+
#include <ATen/ops/select_backward.h>
|
| 1164 |
+
#include <ATen/ops/select_copy.h>
|
| 1165 |
+
#include <ATen/ops/select_scatter.h>
|
| 1166 |
+
#include <ATen/ops/selu.h>
|
| 1167 |
+
#include <ATen/ops/set.h>
|
| 1168 |
+
#include <ATen/ops/set_data.h>
|
| 1169 |
+
#include <ATen/ops/sgn.h>
|
| 1170 |
+
#include <ATen/ops/sigmoid.h>
|
| 1171 |
+
#include <ATen/ops/sigmoid_backward.h>
|
| 1172 |
+
#include <ATen/ops/sign.h>
|
| 1173 |
+
#include <ATen/ops/signbit.h>
|
| 1174 |
+
#include <ATen/ops/silu.h>
|
| 1175 |
+
#include <ATen/ops/silu_backward.h>
|
| 1176 |
+
#include <ATen/ops/sin.h>
|
| 1177 |
+
#include <ATen/ops/sinc.h>
|
| 1178 |
+
#include <ATen/ops/sinh.h>
|
| 1179 |
+
#include <ATen/ops/size.h>
|
| 1180 |
+
#include <ATen/ops/slice.h>
|
| 1181 |
+
#include <ATen/ops/slice_backward.h>
|
| 1182 |
+
#include <ATen/ops/slice_copy.h>
|
| 1183 |
+
#include <ATen/ops/slice_inverse.h>
|
| 1184 |
+
#include <ATen/ops/slice_scatter.h>
|
| 1185 |
+
#include <ATen/ops/slogdet.h>
|
| 1186 |
+
#include <ATen/ops/slow_conv3d.h>
|
| 1187 |
+
#include <ATen/ops/slow_conv3d_forward.h>
|
| 1188 |
+
#include <ATen/ops/slow_conv_dilated2d.h>
|
| 1189 |
+
#include <ATen/ops/slow_conv_dilated3d.h>
|
| 1190 |
+
#include <ATen/ops/slow_conv_transpose2d.h>
|
| 1191 |
+
#include <ATen/ops/slow_conv_transpose3d.h>
|
| 1192 |
+
#include <ATen/ops/smm.h>
|
| 1193 |
+
#include <ATen/ops/smooth_l1_loss.h>
|
| 1194 |
+
#include <ATen/ops/smooth_l1_loss_backward.h>
|
| 1195 |
+
#include <ATen/ops/soft_margin_loss.h>
|
| 1196 |
+
#include <ATen/ops/soft_margin_loss_backward.h>
|
| 1197 |
+
#include <ATen/ops/softmax.h>
|
| 1198 |
+
#include <ATen/ops/softplus.h>
|
| 1199 |
+
#include <ATen/ops/softplus_backward.h>
|
| 1200 |
+
#include <ATen/ops/softshrink.h>
|
| 1201 |
+
#include <ATen/ops/softshrink_backward.h>
|
| 1202 |
+
#include <ATen/ops/sort.h>
|
| 1203 |
+
#include <ATen/ops/sparse_bsc_tensor.h>
|
| 1204 |
+
#include <ATen/ops/sparse_bsr_tensor.h>
|
| 1205 |
+
#include <ATen/ops/sparse_compressed_tensor.h>
|
| 1206 |
+
#include <ATen/ops/sparse_coo_tensor.h>
|
| 1207 |
+
#include <ATen/ops/sparse_csc_tensor.h>
|
| 1208 |
+
#include <ATen/ops/sparse_csr_tensor.h>
|
| 1209 |
+
#include <ATen/ops/sparse_dim.h>
|
| 1210 |
+
#include <ATen/ops/sparse_mask.h>
|
| 1211 |
+
#include <ATen/ops/sparse_resize.h>
|
| 1212 |
+
#include <ATen/ops/sparse_resize_and_clear.h>
|
| 1213 |
+
#include <ATen/ops/sparse_sampled_addmm.h>
|
| 1214 |
+
#include <ATen/ops/special_airy_ai.h>
|
| 1215 |
+
#include <ATen/ops/special_bessel_j0.h>
|
| 1216 |
+
#include <ATen/ops/special_bessel_j1.h>
|
| 1217 |
+
#include <ATen/ops/special_bessel_y0.h>
|
| 1218 |
+
#include <ATen/ops/special_bessel_y1.h>
|
| 1219 |
+
#include <ATen/ops/special_chebyshev_polynomial_t.h>
|
| 1220 |
+
#include <ATen/ops/special_chebyshev_polynomial_u.h>
|
| 1221 |
+
#include <ATen/ops/special_chebyshev_polynomial_v.h>
|
| 1222 |
+
#include <ATen/ops/special_chebyshev_polynomial_w.h>
|
| 1223 |
+
#include <ATen/ops/special_digamma.h>
|
| 1224 |
+
#include <ATen/ops/special_entr.h>
|
| 1225 |
+
#include <ATen/ops/special_erf.h>
|
| 1226 |
+
#include <ATen/ops/special_erfc.h>
|
| 1227 |
+
#include <ATen/ops/special_erfcx.h>
|
| 1228 |
+
#include <ATen/ops/special_erfinv.h>
|
| 1229 |
+
#include <ATen/ops/special_exp2.h>
|
| 1230 |
+
#include <ATen/ops/special_expit.h>
|
| 1231 |
+
#include <ATen/ops/special_expm1.h>
|
| 1232 |
+
#include <ATen/ops/special_gammainc.h>
|
| 1233 |
+
#include <ATen/ops/special_gammaincc.h>
|
| 1234 |
+
#include <ATen/ops/special_gammaln.h>
|
| 1235 |
+
#include <ATen/ops/special_hermite_polynomial_h.h>
|
| 1236 |
+
#include <ATen/ops/special_hermite_polynomial_he.h>
|
| 1237 |
+
#include <ATen/ops/special_i0.h>
|
| 1238 |
+
#include <ATen/ops/special_i0e.h>
|
| 1239 |
+
#include <ATen/ops/special_i1.h>
|
| 1240 |
+
#include <ATen/ops/special_i1e.h>
|
| 1241 |
+
#include <ATen/ops/special_laguerre_polynomial_l.h>
|
| 1242 |
+
#include <ATen/ops/special_legendre_polynomial_p.h>
|
| 1243 |
+
#include <ATen/ops/special_log1p.h>
|
| 1244 |
+
#include <ATen/ops/special_log_ndtr.h>
|
| 1245 |
+
#include <ATen/ops/special_log_softmax.h>
|
| 1246 |
+
#include <ATen/ops/special_logit.h>
|
| 1247 |
+
#include <ATen/ops/special_logsumexp.h>
|
| 1248 |
+
#include <ATen/ops/special_modified_bessel_i0.h>
|
| 1249 |
+
#include <ATen/ops/special_modified_bessel_i1.h>
|
| 1250 |
+
#include <ATen/ops/special_modified_bessel_k0.h>
|
| 1251 |
+
#include <ATen/ops/special_modified_bessel_k1.h>
|
| 1252 |
+
#include <ATen/ops/special_multigammaln.h>
|
| 1253 |
+
#include <ATen/ops/special_ndtr.h>
|
| 1254 |
+
#include <ATen/ops/special_ndtri.h>
|
| 1255 |
+
#include <ATen/ops/special_polygamma.h>
|
| 1256 |
+
#include <ATen/ops/special_psi.h>
|
| 1257 |
+
#include <ATen/ops/special_round.h>
|
| 1258 |
+
#include <ATen/ops/special_scaled_modified_bessel_k0.h>
|
| 1259 |
+
#include <ATen/ops/special_scaled_modified_bessel_k1.h>
|
| 1260 |
+
#include <ATen/ops/special_shifted_chebyshev_polynomial_t.h>
|
| 1261 |
+
#include <ATen/ops/special_shifted_chebyshev_polynomial_u.h>
|
| 1262 |
+
#include <ATen/ops/special_shifted_chebyshev_polynomial_v.h>
|
| 1263 |
+
#include <ATen/ops/special_shifted_chebyshev_polynomial_w.h>
|
| 1264 |
+
#include <ATen/ops/special_sinc.h>
|
| 1265 |
+
#include <ATen/ops/special_softmax.h>
|
| 1266 |
+
#include <ATen/ops/special_spherical_bessel_j0.h>
|
| 1267 |
+
#include <ATen/ops/special_xlog1py.h>
|
| 1268 |
+
#include <ATen/ops/special_xlogy.h>
|
| 1269 |
+
#include <ATen/ops/special_zeta.h>
|
| 1270 |
+
#include <ATen/ops/split.h>
|
| 1271 |
+
#include <ATen/ops/split_copy.h>
|
| 1272 |
+
#include <ATen/ops/split_with_sizes.h>
|
| 1273 |
+
#include <ATen/ops/split_with_sizes_copy.h>
|
| 1274 |
+
#include <ATen/ops/sqrt.h>
|
| 1275 |
+
#include <ATen/ops/square.h>
|
| 1276 |
+
#include <ATen/ops/squeeze.h>
|
| 1277 |
+
#include <ATen/ops/squeeze_copy.h>
|
| 1278 |
+
#include <ATen/ops/sspaddmm.h>
|
| 1279 |
+
#include <ATen/ops/stack.h>
|
| 1280 |
+
#include <ATen/ops/std.h>
|
| 1281 |
+
#include <ATen/ops/std_mean.h>
|
| 1282 |
+
#include <ATen/ops/stft.h>
|
| 1283 |
+
#include <ATen/ops/stride.h>
|
| 1284 |
+
#include <ATen/ops/sub.h>
|
| 1285 |
+
#include <ATen/ops/subtract.h>
|
| 1286 |
+
#include <ATen/ops/sum.h>
|
| 1287 |
+
#include <ATen/ops/sum_to_size.h>
|
| 1288 |
+
#include <ATen/ops/svd.h>
|
| 1289 |
+
#include <ATen/ops/swapaxes.h>
|
| 1290 |
+
#include <ATen/ops/swapdims.h>
|
| 1291 |
+
#include <ATen/ops/sym_constrain_range.h>
|
| 1292 |
+
#include <ATen/ops/sym_constrain_range_for_size.h>
|
| 1293 |
+
#include <ATen/ops/sym_numel.h>
|
| 1294 |
+
#include <ATen/ops/sym_size.h>
|
| 1295 |
+
#include <ATen/ops/sym_storage_offset.h>
|
| 1296 |
+
#include <ATen/ops/sym_stride.h>
|
| 1297 |
+
#include <ATen/ops/t.h>
|
| 1298 |
+
#include <ATen/ops/t_copy.h>
|
| 1299 |
+
#include <ATen/ops/take.h>
|
| 1300 |
+
#include <ATen/ops/take_along_dim.h>
|
| 1301 |
+
#include <ATen/ops/tan.h>
|
| 1302 |
+
#include <ATen/ops/tanh.h>
|
| 1303 |
+
#include <ATen/ops/tanh_backward.h>
|
| 1304 |
+
#include <ATen/ops/tensor_split.h>
|
| 1305 |
+
#include <ATen/ops/tensordot.h>
|
| 1306 |
+
#include <ATen/ops/thnn_conv2d.h>
|
| 1307 |
+
#include <ATen/ops/threshold.h>
|
| 1308 |
+
#include <ATen/ops/threshold_backward.h>
|
| 1309 |
+
#include <ATen/ops/tile.h>
|
| 1310 |
+
#include <ATen/ops/to.h>
|
| 1311 |
+
#include <ATen/ops/to_dense.h>
|
| 1312 |
+
#include <ATen/ops/to_dense_backward.h>
|
| 1313 |
+
#include <ATen/ops/to_mkldnn.h>
|
| 1314 |
+
#include <ATen/ops/to_mkldnn_backward.h>
|
| 1315 |
+
#include <ATen/ops/to_padded_tensor.h>
|
| 1316 |
+
#include <ATen/ops/to_sparse.h>
|
| 1317 |
+
#include <ATen/ops/to_sparse_bsc.h>
|
| 1318 |
+
#include <ATen/ops/to_sparse_bsr.h>
|
| 1319 |
+
#include <ATen/ops/to_sparse_csc.h>
|
| 1320 |
+
#include <ATen/ops/to_sparse_csr.h>
|
| 1321 |
+
#include <ATen/ops/topk.h>
|
| 1322 |
+
#include <ATen/ops/trace.h>
|
| 1323 |
+
#include <ATen/ops/trace_backward.h>
|
| 1324 |
+
#include <ATen/ops/transpose.h>
|
| 1325 |
+
#include <ATen/ops/transpose_copy.h>
|
| 1326 |
+
#include <ATen/ops/trapezoid.h>
|
| 1327 |
+
#include <ATen/ops/trapz.h>
|
| 1328 |
+
#include <ATen/ops/triangular_solve.h>
|
| 1329 |
+
#include <ATen/ops/tril.h>
|
| 1330 |
+
#include <ATen/ops/tril_indices.h>
|
| 1331 |
+
#include <ATen/ops/triplet_margin_loss.h>
|
| 1332 |
+
#include <ATen/ops/triu.h>
|
| 1333 |
+
#include <ATen/ops/triu_indices.h>
|
| 1334 |
+
#include <ATen/ops/true_divide.h>
|
| 1335 |
+
#include <ATen/ops/trunc.h>
|
| 1336 |
+
#include <ATen/ops/type_as.h>
|
| 1337 |
+
#include <ATen/ops/unbind.h>
|
| 1338 |
+
#include <ATen/ops/unbind_copy.h>
|
| 1339 |
+
#include <ATen/ops/unflatten.h>
|
| 1340 |
+
#include <ATen/ops/unflatten_dense_tensors.h>
|
| 1341 |
+
#include <ATen/ops/unfold.h>
|
| 1342 |
+
#include <ATen/ops/unfold_backward.h>
|
| 1343 |
+
#include <ATen/ops/unfold_copy.h>
|
| 1344 |
+
#include <ATen/ops/uniform.h>
|
| 1345 |
+
#include <ATen/ops/unique_consecutive.h>
|
| 1346 |
+
#include <ATen/ops/unique_dim.h>
|
| 1347 |
+
#include <ATen/ops/unique_dim_consecutive.h>
|
| 1348 |
+
#include <ATen/ops/unsafe_chunk.h>
|
| 1349 |
+
#include <ATen/ops/unsafe_split.h>
|
| 1350 |
+
#include <ATen/ops/unsafe_split_with_sizes.h>
|
| 1351 |
+
#include <ATen/ops/unsqueeze.h>
|
| 1352 |
+
#include <ATen/ops/unsqueeze_copy.h>
|
| 1353 |
+
#include <ATen/ops/upsample_bicubic2d.h>
|
| 1354 |
+
#include <ATen/ops/upsample_bicubic2d_backward.h>
|
| 1355 |
+
#include <ATen/ops/upsample_bilinear2d.h>
|
| 1356 |
+
#include <ATen/ops/upsample_bilinear2d_backward.h>
|
| 1357 |
+
#include <ATen/ops/upsample_linear1d.h>
|
| 1358 |
+
#include <ATen/ops/upsample_linear1d_backward.h>
|
| 1359 |
+
#include <ATen/ops/upsample_nearest1d.h>
|
| 1360 |
+
#include <ATen/ops/upsample_nearest1d_backward.h>
|
| 1361 |
+
#include <ATen/ops/upsample_nearest2d.h>
|
| 1362 |
+
#include <ATen/ops/upsample_nearest2d_backward.h>
|
| 1363 |
+
#include <ATen/ops/upsample_nearest3d.h>
|
| 1364 |
+
#include <ATen/ops/upsample_nearest3d_backward.h>
|
| 1365 |
+
#include <ATen/ops/upsample_trilinear3d.h>
|
| 1366 |
+
#include <ATen/ops/upsample_trilinear3d_backward.h>
|
| 1367 |
+
#include <ATen/ops/value_selecting_reduction_backward.h>
|
| 1368 |
+
#include <ATen/ops/values.h>
|
| 1369 |
+
#include <ATen/ops/values_copy.h>
|
| 1370 |
+
#include <ATen/ops/vander.h>
|
| 1371 |
+
#include <ATen/ops/var.h>
|
| 1372 |
+
#include <ATen/ops/var_mean.h>
|
| 1373 |
+
#include <ATen/ops/vdot.h>
|
| 1374 |
+
#include <ATen/ops/view.h>
|
| 1375 |
+
#include <ATen/ops/view_as.h>
|
| 1376 |
+
#include <ATen/ops/view_as_complex.h>
|
| 1377 |
+
#include <ATen/ops/view_as_complex_copy.h>
|
| 1378 |
+
#include <ATen/ops/view_as_real.h>
|
| 1379 |
+
#include <ATen/ops/view_as_real_copy.h>
|
| 1380 |
+
#include <ATen/ops/view_copy.h>
|
| 1381 |
+
#include <ATen/ops/vsplit.h>
|
| 1382 |
+
#include <ATen/ops/vstack.h>
|
| 1383 |
+
#include <ATen/ops/where.h>
|
| 1384 |
+
#include <ATen/ops/xlogy.h>
|
| 1385 |
+
#include <ATen/ops/xor.h>
|
| 1386 |
+
#include <ATen/ops/zero.h>
|
| 1387 |
+
#include <ATen/ops/zeros.h>
|
| 1388 |
+
#include <ATen/ops/zeros_like.h>
|
| 1389 |
+
|
| 1390 |
+
namespace at {
|
| 1391 |
+
|
| 1392 |
+
|
| 1393 |
+
|
| 1394 |
+
// Special C++ only overloads for std()-like functions (See gh-40287)
|
| 1395 |
+
// These are needed because int -> bool conversion takes precedence over int -> IntArrayRef
|
| 1396 |
+
// So, for example std(0) would select the std(unbiased=False) overload
|
| 1397 |
+
TORCH_API inline Tensor var(const Tensor& self, int dim) {
|
| 1398 |
+
return at::var(self, IntArrayRef{dim});
|
| 1399 |
+
}
|
| 1400 |
+
TORCH_API inline std::tuple<Tensor, Tensor> var_mean(const Tensor& self, int dim) {
|
| 1401 |
+
return at::var_mean(self, IntArrayRef{dim});
|
| 1402 |
+
}
|
| 1403 |
+
TORCH_API inline Tensor std(const Tensor& self, int dim) {
|
| 1404 |
+
return at::std(self, IntArrayRef{dim});
|
| 1405 |
+
}
|
| 1406 |
+
TORCH_API inline std::tuple<Tensor, Tensor> std_mean(const Tensor& self, int dim) {
|
| 1407 |
+
return at::std_mean(self, IntArrayRef{dim});
|
| 1408 |
+
}
|
| 1409 |
+
|
| 1410 |
+
inline int64_t numel(const Tensor& tensor) {
|
| 1411 |
+
return tensor.numel();
|
| 1412 |
+
}
|
| 1413 |
+
|
| 1414 |
+
inline int64_t size(const Tensor& tensor, int64_t dim) {
|
| 1415 |
+
return tensor.size(dim);
|
| 1416 |
+
}
|
| 1417 |
+
|
| 1418 |
+
inline int64_t stride(const Tensor& tensor, int64_t dim) {
|
| 1419 |
+
return tensor.stride(dim);
|
| 1420 |
+
}
|
| 1421 |
+
|
| 1422 |
+
inline bool is_complex(const Tensor& tensor) {
|
| 1423 |
+
return tensor.is_complex();
|
| 1424 |
+
}
|
| 1425 |
+
|
| 1426 |
+
inline bool is_floating_point(const Tensor& tensor) {
|
| 1427 |
+
return tensor.is_floating_point();
|
| 1428 |
+
}
|
| 1429 |
+
|
| 1430 |
+
inline bool is_signed(const Tensor& tensor) {
|
| 1431 |
+
return tensor.is_signed();
|
| 1432 |
+
}
|
| 1433 |
+
|
| 1434 |
+
inline bool is_inference(const Tensor& tensor) {
|
| 1435 |
+
return tensor.is_inference();
|
| 1436 |
+
}
|
| 1437 |
+
|
| 1438 |
+
inline bool _is_zerotensor(const Tensor& tensor) {
|
| 1439 |
+
return tensor._is_zerotensor();
|
| 1440 |
+
}
|
| 1441 |
+
|
| 1442 |
+
inline bool is_conj(const Tensor& tensor) {
|
| 1443 |
+
return tensor.is_conj();
|
| 1444 |
+
}
|
| 1445 |
+
|
| 1446 |
+
inline Tensor conj(const Tensor& tensor) {
|
| 1447 |
+
return tensor.conj();
|
| 1448 |
+
}
|
| 1449 |
+
|
| 1450 |
+
inline bool is_neg(const Tensor& tensor) {
|
| 1451 |
+
return tensor.is_neg();
|
| 1452 |
+
}
|
| 1453 |
+
|
| 1454 |
+
}
|
.venv/Lib/site-packages/torch/include/ATen/Generator.h
ADDED
|
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
#include <ATen/core/Generator.h>
|
.venv/Lib/site-packages/torch/include/ATen/InferSize.h
ADDED
|
@@ -0,0 +1,88 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#include <ATen/DimVector.h>
|
| 4 |
+
#include <c10/core/ScalarType.h>
|
| 5 |
+
#include <c10/core/SymIntArrayRef.h>
|
| 6 |
+
#include <c10/util/DimVector.h>
|
| 7 |
+
#include <optional>
|
| 8 |
+
#include <sstream>
|
| 9 |
+
#include <vector>
|
| 10 |
+
|
| 11 |
+
namespace at {
|
| 12 |
+
|
| 13 |
+
// Infers the size of a dim with size -1, if it exists. Also checks that new
|
| 14 |
+
// shape is compatible with the number of elements.
|
| 15 |
+
//
|
| 16 |
+
// templated to handle std::vector<int64_t> and DimVector use cases, see
|
| 17 |
+
// below
|
| 18 |
+
//
|
| 19 |
+
template <typename InputArrayRef, typename NumelType, typename ResultVec>
|
| 20 |
+
inline void infer_size_impl(
|
| 21 |
+
InputArrayRef shape,
|
| 22 |
+
NumelType numel,
|
| 23 |
+
ResultVec& res) {
|
| 24 |
+
NumelType newsize = 1;
|
| 25 |
+
// N.B. this is an index, not a sym dim!
|
| 26 |
+
std::optional<int64_t> infer_dim;
|
| 27 |
+
for (int64_t dim = 0, ndim = shape.size(); dim != ndim; dim++) {
|
| 28 |
+
if (shape[dim] == -1) {
|
| 29 |
+
if (infer_dim) {
|
| 30 |
+
throw std::runtime_error("only one dimension can be inferred");
|
| 31 |
+
}
|
| 32 |
+
infer_dim = dim;
|
| 33 |
+
} else if (shape[dim] >= 0) {
|
| 34 |
+
newsize *= shape[dim];
|
| 35 |
+
} else {
|
| 36 |
+
AT_ERROR("invalid shape dimension ", shape[dim]);
|
| 37 |
+
}
|
| 38 |
+
}
|
| 39 |
+
|
| 40 |
+
if (TORCH_GUARD_SIZE_OBLIVIOUS(sym_eq(numel, newsize)) ||
|
| 41 |
+
(infer_dim && newsize > 0 && numel % newsize == 0)) {
|
| 42 |
+
if (infer_dim) {
|
| 43 |
+
// We have a degree of freedom here to select the dimension size; follow
|
| 44 |
+
// NumPy semantics and just bail. However, a nice error message is needed
|
| 45 |
+
// because users often use `view` as a way to flatten & unflatten
|
| 46 |
+
// dimensions and will otherwise be confused why
|
| 47 |
+
// empty_tensor.view( 0, 0)
|
| 48 |
+
// works yet
|
| 49 |
+
// empty_tensor.view(-1, 0)
|
| 50 |
+
// doesn't.
|
| 51 |
+
TORCH_CHECK(
|
| 52 |
+
newsize != 0,
|
| 53 |
+
"cannot reshape tensor of 0 elements into shape ",
|
| 54 |
+
shape,
|
| 55 |
+
" because the unspecified dimension size -1 can be any "
|
| 56 |
+
"value and is ambiguous");
|
| 57 |
+
res[*infer_dim] = numel / newsize;
|
| 58 |
+
}
|
| 59 |
+
return;
|
| 60 |
+
}
|
| 61 |
+
|
| 62 |
+
std::ostringstream ss;
|
| 63 |
+
ss << "shape '" << shape << "' is invalid for input of size " << numel;
|
| 64 |
+
throw std::runtime_error(ss.str());
|
| 65 |
+
}
|
| 66 |
+
|
| 67 |
+
inline std::vector<int64_t> infer_size(IntArrayRef shape, int64_t numel) {
|
| 68 |
+
auto res = shape.vec();
|
| 69 |
+
infer_size_impl(shape, numel, res);
|
| 70 |
+
return res;
|
| 71 |
+
}
|
| 72 |
+
|
| 73 |
+
inline at::DimVector infer_size_dv(IntArrayRef shape, int64_t numel) {
|
| 74 |
+
auto res = at::DimVector(shape);
|
| 75 |
+
infer_size_impl(shape, numel, res);
|
| 76 |
+
return res;
|
| 77 |
+
}
|
| 78 |
+
|
| 79 |
+
inline at::SymDimVector infer_size_dv(
|
| 80 |
+
c10::SymIntArrayRef shape,
|
| 81 |
+
c10::SymInt numel) {
|
| 82 |
+
auto res = at::SymDimVector(shape);
|
| 83 |
+
infer_size_impl<c10::SymIntArrayRef, c10::SymInt, at::SymDimVector>(
|
| 84 |
+
shape, std::move(numel), res);
|
| 85 |
+
return res;
|
| 86 |
+
}
|
| 87 |
+
|
| 88 |
+
} // namespace at
|
.venv/Lib/site-packages/torch/include/ATen/InitialTensorOptions.h
ADDED
|
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#include <c10/core/TensorOptions.h>
|
| 4 |
+
|
| 5 |
+
namespace at {
|
| 6 |
+
|
| 7 |
+
// Represents the initial TensorOptions, before the "defaults" are ever changed.
|
| 8 |
+
// This is designed to be used in library code, where the explicit devices,
|
| 9 |
+
// dtypes, etc. are known. NOTE: this is not a stable API.
|
| 10 |
+
inline TensorOptions initialTensorOptions() {
|
| 11 |
+
return TensorOptions(kCPU).dtype(kFloat).layout(kStrided).requires_grad(
|
| 12 |
+
false);
|
| 13 |
+
}
|
| 14 |
+
|
| 15 |
+
} // namespace at
|
.venv/Lib/site-packages/torch/include/ATen/Layout.h
ADDED
|
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
#include <c10/core/Layout.h>
|
.venv/Lib/site-packages/torch/include/ATen/LegacyBatchedFallback.h
ADDED
|
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
#include <ATen/ATen.h>
|
| 3 |
+
#include <ATen/core/op_registration/op_registration.h>
|
| 4 |
+
#include <torch/library.h>
|
| 5 |
+
|
| 6 |
+
namespace at {
|
| 7 |
+
|
| 8 |
+
// If an operator doesn't have a batching rule implemented then we fallback
|
| 9 |
+
// to this implementation. The fallback only works on out-of-place operators
|
| 10 |
+
// that return only tensors with new memory. (e.g., no in-place operators, no
|
| 11 |
+
// view operations).
|
| 12 |
+
//
|
| 13 |
+
// The fallback effectively takes all of the BatchedTensors in `stack`, slices
|
| 14 |
+
// them, and runs `op` on all of the corresponding slices to produce slices
|
| 15 |
+
// of the outputs. The output slices then get `torch.stack`ed to create the
|
| 16 |
+
// final returns.
|
| 17 |
+
//
|
| 18 |
+
// The performance of the fallback is not very good because it introduces an
|
| 19 |
+
// extra copy from stacking the sliced outputs. Because of this, we prefer to
|
| 20 |
+
// write batching rules for operators whenever possible.
|
| 21 |
+
void batchedTensorForLoopFallback(
|
| 22 |
+
const c10::OperatorHandle& op,
|
| 23 |
+
torch::jit::Stack* stack);
|
| 24 |
+
|
| 25 |
+
} // namespace at
|
.venv/Lib/site-packages/torch/include/ATen/LegacyBatchedTensorImpl.h
ADDED
|
@@ -0,0 +1,160 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#include <bitset>
|
| 4 |
+
|
| 5 |
+
#include <ATen/ArrayRef.h>
|
| 6 |
+
#include <ATen/SmallVector.h>
|
| 7 |
+
#include <ATen/Tensor.h>
|
| 8 |
+
|
| 9 |
+
namespace at {
|
| 10 |
+
|
| 11 |
+
// We assume this in a few other places in the codebase,
|
| 12 |
+
// but there isn't a centralized definition.
|
| 13 |
+
constexpr int64_t kVmapMaxTensorDims = 64;
|
| 14 |
+
|
| 15 |
+
// The valid vmap levels range from [0, 64). This effectively means that we
|
| 16 |
+
// support a maximum of 64 nested vmaps.
|
| 17 |
+
constexpr int64_t kVmapNumLevels = 64;
|
| 18 |
+
|
| 19 |
+
// Store this number of elements of BatchDims on the stack. Most people will
|
| 20 |
+
// probably use <= 5 nested vmaps, but adjust this number as necessary.
|
| 21 |
+
constexpr int64_t kBatchDimsStackSize = 5;
|
| 22 |
+
|
| 23 |
+
// a BatchDim represents a "private" dimension on a Tensor created inside of
|
| 24 |
+
// vmap. It is a (level, dim) tuple, with the `dim` indicating which dimension
|
| 25 |
+
// is being vmap'ed over and the `level` being an identifier for which vmap
|
| 26 |
+
// said dimension was created inside. The `dim` corresponds to a "physical
|
| 27 |
+
// dim" - it is a dimension index on the underlying physical tensor that is
|
| 28 |
+
// being vmapped over.
|
| 29 |
+
struct BatchDim {
|
| 30 |
+
BatchDim(int64_t level, int64_t dim) : dim_(dim), level_(level) {}
|
| 31 |
+
int64_t dim() const {
|
| 32 |
+
return dim_;
|
| 33 |
+
}
|
| 34 |
+
int64_t level() const {
|
| 35 |
+
return level_;
|
| 36 |
+
}
|
| 37 |
+
|
| 38 |
+
private:
|
| 39 |
+
int64_t dim_;
|
| 40 |
+
int64_t level_;
|
| 41 |
+
};
|
| 42 |
+
|
| 43 |
+
using BatchDims = SmallVector<BatchDim, kBatchDimsStackSize>;
|
| 44 |
+
using BatchDimsRef = ArrayRef<BatchDim>;
|
| 45 |
+
|
| 46 |
+
// A BatchedTensorImpl holds an underlying Tensor and a list of BatchDim
|
| 47 |
+
// NB: We use the term "BatchedTensor" to mean a Tensor that is backed with a
|
| 48 |
+
// BatchedTensorImpl.
|
| 49 |
+
//
|
| 50 |
+
// The batch dimensions are treated as being "private"; they are not
|
| 51 |
+
// user-visible. For example, in the following Tensor,
|
| 52 |
+
// bt = BatchedTensorImpl(ones(2, 3, 5, 7), [(lvl=1, dim=0), (lvl=2, dim=1)])
|
| 53 |
+
// dimensions 0 and 1 are batch dimensions.
|
| 54 |
+
//
|
| 55 |
+
// bt.sizes() returns (5, 7); bt.sum(0) performs a reduction over the (public)
|
| 56 |
+
// dim 0, which is equivalent to dim 3 in the underlying ones(2, 3, 5, 7)
|
| 57 |
+
// tensor.
|
| 58 |
+
struct TORCH_API BatchedTensorImpl : public c10::TensorImpl {
|
| 59 |
+
explicit BatchedTensorImpl(Tensor value, BatchDims bdims);
|
| 60 |
+
|
| 61 |
+
// Returns a reference to BatchDims that represent which dimensions of this
|
| 62 |
+
// tensor are private.
|
| 63 |
+
BatchDimsRef bdims() const {
|
| 64 |
+
return bdims_;
|
| 65 |
+
}
|
| 66 |
+
|
| 67 |
+
// BatchedTensorImpl wraps a Tensor
|
| 68 |
+
const Tensor& value() const {
|
| 69 |
+
return value_;
|
| 70 |
+
};
|
| 71 |
+
|
| 72 |
+
// Given a public dimension index, return the dimension index in the
|
| 73 |
+
// underlying value() tensor. For example, if we have
|
| 74 |
+
// bt = BatchedTensorImpl(ones(2, 3, 5, 7), [(lvl=1, dim=0), (lvl=2,
|
| 75 |
+
// dim=2)])
|
| 76 |
+
// bt.actualDim(0) -> 1
|
| 77 |
+
// bt.actualDim(1) -> 3
|
| 78 |
+
// bt.actualDim(2) -> Error
|
| 79 |
+
int64_t actualDim(int64_t dim, bool wrap_dim = true) const;
|
| 80 |
+
|
| 81 |
+
// We have to override this because we opted into CustomStrides
|
| 82 |
+
IntArrayRef strides_custom() const override;
|
| 83 |
+
// Override a bunch of methods inherited from TensorImpl to return error
|
| 84 |
+
// messages.
|
| 85 |
+
bool is_contiguous_custom(at::MemoryFormat memory_format) const override;
|
| 86 |
+
void set_size(int64_t dim, int64_t new_size) override;
|
| 87 |
+
void set_stride(int64_t dim, int64_t new_stride) override;
|
| 88 |
+
void set_storage_offset(int64_t storage_offset) override;
|
| 89 |
+
#ifdef DEBUG
|
| 90 |
+
bool has_storage() const override;
|
| 91 |
+
#endif
|
| 92 |
+
|
| 93 |
+
private:
|
| 94 |
+
// see NOTE: [BatchedTensorImpl levels invariant]
|
| 95 |
+
void checkInvariants() const;
|
| 96 |
+
const char* tensorimpl_type_name() const override;
|
| 97 |
+
|
| 98 |
+
Tensor value_;
|
| 99 |
+
|
| 100 |
+
// Note: [BatchedTensorImpl levels invariant]
|
| 101 |
+
// There is an invariant that the BatchDims must be stored in increasing
|
| 102 |
+
// `level` order. That is, for i < j, bdims_[i].level must be less than
|
| 103 |
+
// bdims_[j].level.
|
| 104 |
+
BatchDims bdims_;
|
| 105 |
+
};
|
| 106 |
+
|
| 107 |
+
// NB: We use the term "BatchedTensor" to mean a Tensor that is backed with a
|
| 108 |
+
// BatchedTensorImpl.
|
| 109 |
+
inline bool isBatchedTensor(const Tensor& tensor) {
|
| 110 |
+
return tensor.unsafeGetTensorImpl()->key_set().has(DispatchKey::Batched);
|
| 111 |
+
}
|
| 112 |
+
|
| 113 |
+
// It is unsafe to call this on a Tensor that is not backed by a
|
| 114 |
+
// BatchedTensorImpl. Please use `maybeGetBatchedImpl` whenever possible.
|
| 115 |
+
inline BatchedTensorImpl* unsafeGetBatchedImpl(const Tensor& tensor) {
|
| 116 |
+
return static_cast<BatchedTensorImpl*>(tensor.unsafeGetTensorImpl());
|
| 117 |
+
}
|
| 118 |
+
|
| 119 |
+
inline BatchedTensorImpl* maybeGetBatchedImpl(const Tensor& tensor) {
|
| 120 |
+
if (!isBatchedTensor(tensor)) {
|
| 121 |
+
return nullptr;
|
| 122 |
+
}
|
| 123 |
+
return unsafeGetBatchedImpl(tensor);
|
| 124 |
+
}
|
| 125 |
+
|
| 126 |
+
// Returns a bitset. If bit i is set, then that means dim i is a batchdim.
|
| 127 |
+
inline std::bitset<kVmapMaxTensorDims> createBatchDimBitset(
|
| 128 |
+
BatchDimsRef bdims) {
|
| 129 |
+
std::bitset<kVmapMaxTensorDims> is_bdim;
|
| 130 |
+
for (const auto& bdim : bdims) {
|
| 131 |
+
is_bdim.set(bdim.dim());
|
| 132 |
+
}
|
| 133 |
+
return is_bdim;
|
| 134 |
+
}
|
| 135 |
+
|
| 136 |
+
// Creates a bitset for all of the levels present in `bdims`
|
| 137 |
+
inline std::bitset<kVmapNumLevels> createVmapLevelsBitset(BatchDimsRef bdims) {
|
| 138 |
+
std::bitset<kVmapNumLevels> result;
|
| 139 |
+
for (const auto& bdim : bdims) {
|
| 140 |
+
result.set(bdim.level());
|
| 141 |
+
}
|
| 142 |
+
return result;
|
| 143 |
+
}
|
| 144 |
+
|
| 145 |
+
inline std::ostream& operator<<(std::ostream& out, const BatchDim& bdim) {
|
| 146 |
+
out << "(lvl=" << bdim.level() << ", dim=" << bdim.dim() << ")";
|
| 147 |
+
return out;
|
| 148 |
+
}
|
| 149 |
+
|
| 150 |
+
// Use this to construct a BatchedTensor from a regular Tensor
|
| 151 |
+
TORCH_API Tensor makeBatched(const Tensor& tensor, BatchDims bdims);
|
| 152 |
+
|
| 153 |
+
// Adds a batch dim to `tensor`, returning a BatchedTensor
|
| 154 |
+
TORCH_API Tensor addBatchDim(const Tensor& tensor, int64_t level, int64_t dim);
|
| 155 |
+
|
| 156 |
+
// Checks if an inplace operation on self and other is "vmap compatible".
|
| 157 |
+
// See NOTE: [vmap-incompatible in-place operations] for the definition of this.
|
| 158 |
+
TORCH_API bool inplaceIsVmapCompatible(const Tensor& self, const Tensor& other);
|
| 159 |
+
|
| 160 |
+
} // namespace at
|
.venv/Lib/site-packages/torch/include/ATen/LegacyVmapMode.h
ADDED
|
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#include <c10/core/impl/LocalDispatchKeySet.h>
|
| 4 |
+
|
| 5 |
+
namespace at::impl {
|
| 6 |
+
|
| 7 |
+
// VmapMode contains a thread local count of how many nested vmaps
|
| 8 |
+
// we are currently inside. That number is known as the `vmap level`.
|
| 9 |
+
// VmapMode is used in the implementation of the Python `torch.vmap` API.
|
| 10 |
+
//
|
| 11 |
+
// NOTE: this is NOT the c++ api for torch.vmap. That doesn't exist yet.
|
| 12 |
+
|
| 13 |
+
struct TORCH_API VmapMode {
|
| 14 |
+
// Returns the vmap level, aka the count of how many nested vmaps we're in.
|
| 15 |
+
static int64_t current_vmap_level();
|
| 16 |
+
|
| 17 |
+
// Increment the count of nested vmaps. If this causes the vmap level to be
|
| 18 |
+
// greater than 0, then it enables DispatchKey::VmapMode on all tensors.
|
| 19 |
+
static int64_t increment_nesting();
|
| 20 |
+
|
| 21 |
+
// Decrements the count of nested vmaps. If this causes the vmap level to be
|
| 22 |
+
// equal to 0, then it disables DispatchKey::VmapMode on all tensors.
|
| 23 |
+
static int64_t decrement_nesting();
|
| 24 |
+
};
|
| 25 |
+
|
| 26 |
+
} // namespace at::impl
|
.venv/Lib/site-packages/torch/include/ATen/LegacyVmapTransforms.h
ADDED
|
@@ -0,0 +1,183 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#include <ATen/LegacyBatchedTensorImpl.h>
|
| 4 |
+
#include <ATen/core/IListRef.h>
|
| 5 |
+
|
| 6 |
+
namespace at {
|
| 7 |
+
|
| 8 |
+
// This file contains abstractions used for transforming *logical* vmap
|
| 9 |
+
// arguments into *physical* arguments. (Keep reading for definitions of these
|
| 10 |
+
// terms).
|
| 11 |
+
|
| 12 |
+
// NOTE: [Logical vs physical args]
|
| 13 |
+
// Consider the following vmap.
|
| 14 |
+
// vmap(vmap(func, in_dims=(2,)), in_dims=(0,))(torch.ones(2, 3, 4))
|
| 15 |
+
// This would produce a BatchedTensor wrapping a Tensor of size [2, 3, 4],
|
| 16 |
+
// with batch dims 0 and 2:
|
| 17 |
+
// BatchedTensor(ones(2, 3, 4), bdims=[(lvl=1,dim=0),(lvl=2,dim=2)])
|
| 18 |
+
//
|
| 19 |
+
// We say the *logical* view of the tensor has size [3] -- tensors inside
|
| 20 |
+
// `func` appear to have size [3].
|
| 21 |
+
// However, the *physical* underlying tensor (the one passed to vmap) has size
|
| 22 |
+
// [2, 3, 4].
|
| 23 |
+
//
|
| 24 |
+
// This notion of logical vs physical also extends to non-tensor arguments.
|
| 25 |
+
// Consider the previous tensor; let's assume the user called
|
| 26 |
+
// `torch.sum(tensor, dim=0)` inside of `func`. Then the logical
|
| 27 |
+
// dimension they are reducing over is dim 0 but the physical dim is dim 1
|
| 28 |
+
// (the first non-batch dimension)
|
| 29 |
+
|
| 30 |
+
// Forward declared; see NOTE: [What is a VmapPhysicalView?]
|
| 31 |
+
struct VmapPhysicalView;
|
| 32 |
+
|
| 33 |
+
// Most PyTorch operators take 4 or fewer inputs.
|
| 34 |
+
constexpr int64_t kVmapTransformStaticInputSize = 4;
|
| 35 |
+
using VmapPhysicalViewVec =
|
| 36 |
+
SmallVector<VmapPhysicalView, kVmapTransformStaticInputSize>;
|
| 37 |
+
|
| 38 |
+
// Pytorch generally advertises good performance for <= 5 dims.
|
| 39 |
+
// (see ATen/core/DimVector.h). We add a few extra dims (~3) for vmap
|
| 40 |
+
// dimensions to get 8. Adjust this number as necessary
|
| 41 |
+
constexpr int64_t kVmapStaticDimVecSize = 8;
|
| 42 |
+
using VmapDimVector = SmallVector<int64_t, kVmapStaticDimVecSize>;
|
| 43 |
+
using VmapSymDimVector = SmallVector<c10::SymInt, kVmapStaticDimVecSize>;
|
| 44 |
+
|
| 45 |
+
// NOTE: [What is an VmapTransform?]
|
| 46 |
+
// An *VmapTransform* converts logical views of tensors to physical views.
|
| 47 |
+
//
|
| 48 |
+
// Batching rules use VmapTransforms to convert logical arguments to
|
| 49 |
+
// physical arguments, then call one or more at:: operator that handles the
|
| 50 |
+
// physical arguments, and then converts the physical result back to a logical
|
| 51 |
+
// argument.
|
| 52 |
+
|
| 53 |
+
// VmapTransform for operators that take tensors with multiple batch dims.
|
| 54 |
+
// Given one or more logical views on Tensors, `logicalToPhysical`
|
| 55 |
+
// permutes all of the batch dims to the front of the tensor, aligns
|
| 56 |
+
// and expands the batch dims to match each other (according to their `level`),
|
| 57 |
+
// and returns a VmapPhysicalView on the tensor(s).
|
| 58 |
+
struct TORCH_API MultiBatchVmapTransform {
|
| 59 |
+
static VmapPhysicalView logicalToPhysical(const Tensor& logical_tensor);
|
| 60 |
+
static VmapPhysicalViewVec logicalToPhysical(ITensorListRef logical_tensors);
|
| 61 |
+
};
|
| 62 |
+
|
| 63 |
+
// VmapTransform for operators that broadcast all inputs.
|
| 64 |
+
// Given some logical views on Tensors, `logicalToPhysical`:
|
| 65 |
+
// - permutes all of the batch dims to the front of the tensors
|
| 66 |
+
// - aligns all the batch dims to the collective levels of all of the tensors.
|
| 67 |
+
// If a tensor does not have a batch dim for a vmap level, then it receives
|
| 68 |
+
// a size-one dimension for said level.
|
| 69 |
+
// - aligns the non-batch dims to have the same dimensionality, adding extra
|
| 70 |
+
// size-1 dimensions in between the batch dimensions and the non-batch
|
| 71 |
+
// dimensions so that the batch dimensions are lined up from the right.
|
| 72 |
+
//
|
| 73 |
+
// For example: given inputs of size (B, 2) and (B, 3, 2) where B is the batch
|
| 74 |
+
// dimension, BroadcastingVmapTransform returns VmapPhysicalViews that wrap
|
| 75 |
+
// tensors of size (B, 1, 2) and (B, 3, 2).
|
| 76 |
+
//
|
| 77 |
+
// Given inputs of size (B, 2) and (2,), BroadcastingVmapTransform returns
|
| 78 |
+
// VmapPhysicalViews wrapping tensors of size (B, 2) and (1, 2). We don't
|
| 79 |
+
// actually *need* to return a tensor of size (1, 2) for the second tensor
|
| 80 |
+
// because the broadcasting operation takes care of that for us, but we do
|
| 81 |
+
// it anyways to keep things simple.
|
| 82 |
+
struct TORCH_API BroadcastingVmapTransform {
|
| 83 |
+
static VmapPhysicalViewVec logicalToPhysical(TensorList logical_tensors);
|
| 84 |
+
};
|
| 85 |
+
|
| 86 |
+
// Forward declared, if you're reading this file head to toe, don't worry about
|
| 87 |
+
// it yet.
|
| 88 |
+
struct VmapPhysicalToLogicalMap;
|
| 89 |
+
|
| 90 |
+
// NOTE: [What is a VmapPhysicalView?]
|
| 91 |
+
// VmapPhysicalView represents a physical view on a Tensor.
|
| 92 |
+
//
|
| 93 |
+
// One can use it to further convert logical dimension indices, logical shapes,
|
| 94 |
+
// and more to their physical variants, or convert a new (physical) tensor into
|
| 95 |
+
// a logical BatchedTensor. (TODO(rzou): some of these are not yet implemented).
|
| 96 |
+
//
|
| 97 |
+
// VmapPhysicalView stores a physical tensor with all of its batch dimensions at
|
| 98 |
+
// the front and some levels that correspond to said batch dimensions.
|
| 99 |
+
//
|
| 100 |
+
// The levels bitset specifies which vmap levels correspond to the batch
|
| 101 |
+
// dimensions at the front of the tensor. In particular, the number of set bits
|
| 102 |
+
// corresponds to the number of batch dimensions on `tensor` and the rightmost
|
| 103 |
+
// bit of `levels` specifies the maximum number of nested vmaps we are in at
|
| 104 |
+
// this point in time.
|
| 105 |
+
// For example, given:
|
| 106 |
+
// physical_view = VmapPhysicalView(tensor=ones(2, 3, 4, 5, 6), levels={1, 3})
|
| 107 |
+
//
|
| 108 |
+
// Rightmost bit of `levels` is 3 indicating the number of nested vmaps less
|
| 109 |
+
// than or equal to 3.
|
| 110 |
+
// bitset: 010100
|
| 111 |
+
// ^
|
| 112 |
+
// |
|
| 113 |
+
// levels: 012345
|
| 114 |
+
struct TORCH_API VmapPhysicalView {
|
| 115 |
+
VmapPhysicalView(Tensor&& tensor, std::bitset<kVmapNumLevels> levels)
|
| 116 |
+
: levels_(levels), tensor_(std::move(tensor)) {
|
| 117 |
+
TORCH_INTERNAL_ASSERT(!isBatchedTensor(tensor_));
|
| 118 |
+
}
|
| 119 |
+
|
| 120 |
+
Tensor& tensor() {
|
| 121 |
+
return tensor_;
|
| 122 |
+
}
|
| 123 |
+
const Tensor& tensor() const {
|
| 124 |
+
return tensor_;
|
| 125 |
+
}
|
| 126 |
+
|
| 127 |
+
// Maps logical dim indices to physical dim indices. Also does dim wrapping.
|
| 128 |
+
//
|
| 129 |
+
// For example, given:
|
| 130 |
+
// physical_view = VmapPhysicalView(tensor=ones(2, 3, 4, 5), levels={1, 3})
|
| 131 |
+
//
|
| 132 |
+
// Then physical_view.getPhysicalDims({0, 1}) returns {2, 3}.
|
| 133 |
+
// This is because the size of levels tell us that the first two dimensions
|
| 134 |
+
// of `tensor_` are batch dimensions, so a logical dim of `n` is actually
|
| 135 |
+
// a physical dim of `n + 2`.
|
| 136 |
+
VmapDimVector getPhysicalDims(OptionalIntArrayRef logical_dims) const;
|
| 137 |
+
int64_t getPhysicalDim(int64_t logical_dim) const;
|
| 138 |
+
|
| 139 |
+
// Returns a VmapPhysicalToLogicalMap object. This can be used for
|
| 140 |
+
// mapping a physical tensor to a new logical tensor (BatchedTensor)
|
| 141 |
+
VmapPhysicalToLogicalMap getPhysicalToLogicalMap() const;
|
| 142 |
+
|
| 143 |
+
// Maps a logical shape to a physical shape by pre-pending the batch
|
| 144 |
+
// sizes to the logical shape.
|
| 145 |
+
VmapDimVector getPhysicalShape(IntArrayRef logical_shape) const;
|
| 146 |
+
|
| 147 |
+
int64_t numBatchDims() const;
|
| 148 |
+
|
| 149 |
+
private:
|
| 150 |
+
int64_t numLogicalDims() const;
|
| 151 |
+
|
| 152 |
+
std::bitset<kVmapNumLevels> levels_;
|
| 153 |
+
Tensor tensor_;
|
| 154 |
+
};
|
| 155 |
+
|
| 156 |
+
// Convenience struct used for mapping a physical tensor (a non-BatchedTensor)
|
| 157 |
+
// to a logical one (BatchedTensor). It holds some levels that are used to do
|
| 158 |
+
// the mapping and assumes that the batch dimensions in the physical tensor all
|
| 159 |
+
// occur at the front of the tensor.
|
| 160 |
+
struct TORCH_API VmapPhysicalToLogicalMap {
|
| 161 |
+
VmapPhysicalToLogicalMap(std::bitset<kVmapNumLevels> levels)
|
| 162 |
+
: levels_(levels) {}
|
| 163 |
+
|
| 164 |
+
// Maps a physical tensor to a new logical tensor (BatchedTensor).
|
| 165 |
+
// Assumes that all of the "batch dimensions" are at the front
|
| 166 |
+
// of the physical tensor. For example, given:
|
| 167 |
+
// - x = rank-4 Tensor with size 2, 3, 5, 7
|
| 168 |
+
// - levels = (2, 4)
|
| 169 |
+
// Returns:
|
| 170 |
+
// - BatchedTensor(x, bdims=[(dim=0,lvl=2), (dim=1, lvl=4)])
|
| 171 |
+
Tensor apply(const Tensor& physical_tensor) const;
|
| 172 |
+
|
| 173 |
+
// Given a vector of physical tensors,
|
| 174 |
+
// 1. maps each tensor to a new logical tensor. Assumes that all of the
|
| 175 |
+
// "batch dimensions" are at the front of the physical tensors.
|
| 176 |
+
// 2. stores the new logical tensors back into the passed-in vector. This is
|
| 177 |
+
// to avoid additional dynamic allocations.
|
| 178 |
+
void applyInplace(std::vector<Tensor>& physical_tensors) const;
|
| 179 |
+
|
| 180 |
+
std::bitset<kVmapNumLevels> levels_;
|
| 181 |
+
};
|
| 182 |
+
|
| 183 |
+
} // namespace at
|
.venv/Lib/site-packages/torch/include/ATen/LinalgBackend.h
ADDED
|
@@ -0,0 +1,31 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#include <c10/util/Exception.h>
|
| 4 |
+
|
| 5 |
+
#include <ostream>
|
| 6 |
+
#include <string>
|
| 7 |
+
|
| 8 |
+
namespace at {
|
| 9 |
+
|
| 10 |
+
enum class LinalgBackend : int8_t { Default, Cusolver, Magma };
|
| 11 |
+
|
| 12 |
+
inline std::string LinalgBackendToString(at::LinalgBackend backend) {
|
| 13 |
+
switch (backend) {
|
| 14 |
+
case LinalgBackend::Default:
|
| 15 |
+
return "at::LinalgBackend::Default";
|
| 16 |
+
case LinalgBackend::Cusolver:
|
| 17 |
+
return "at::LinalgBackend::Cusolver";
|
| 18 |
+
case LinalgBackend::Magma:
|
| 19 |
+
return "at::LinalgBackend::Magma";
|
| 20 |
+
default:
|
| 21 |
+
TORCH_CHECK(false, "Unknown linalg backend");
|
| 22 |
+
}
|
| 23 |
+
}
|
| 24 |
+
|
| 25 |
+
inline std::ostream& operator<<(
|
| 26 |
+
std::ostream& stream,
|
| 27 |
+
at::LinalgBackend backend) {
|
| 28 |
+
return stream << LinalgBackendToString(backend);
|
| 29 |
+
}
|
| 30 |
+
|
| 31 |
+
} // namespace at
|
.venv/Lib/site-packages/torch/include/ATen/MapAllocator.h
ADDED
|
@@ -0,0 +1,143 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#include <c10/core/Allocator.h>
|
| 4 |
+
#include <c10/util/string_view.h>
|
| 5 |
+
|
| 6 |
+
namespace at {
|
| 7 |
+
|
| 8 |
+
enum MappedAllocatorModes {
|
| 9 |
+
ALLOCATOR_MAPPED_SHARED = 1,
|
| 10 |
+
ALLOCATOR_MAPPED_SHAREDMEM = 2,
|
| 11 |
+
ALLOCATOR_MAPPED_EXCLUSIVE = 4,
|
| 12 |
+
ALLOCATOR_MAPPED_NOCREATE = 8,
|
| 13 |
+
ALLOCATOR_MAPPED_KEEPFD = 16,
|
| 14 |
+
ALLOCATOR_MAPPED_FROMFD = 32,
|
| 15 |
+
ALLOCATOR_MAPPED_UNLINK = 64
|
| 16 |
+
};
|
| 17 |
+
|
| 18 |
+
// Sentinel value/type to help distinguish the file descriptor constructor from
|
| 19 |
+
// the non-file descriptor constructor
|
| 20 |
+
enum WithFd { WITH_FD };
|
| 21 |
+
|
| 22 |
+
TORCH_API std::string NewProcessWideShmHandle();
|
| 23 |
+
|
| 24 |
+
class TORCH_API MapAllocator {
|
| 25 |
+
public:
|
| 26 |
+
MapAllocator(c10::string_view filename, int flags, size_t size);
|
| 27 |
+
MapAllocator(
|
| 28 |
+
WithFd,
|
| 29 |
+
c10::string_view filename,
|
| 30 |
+
int fd,
|
| 31 |
+
int flags,
|
| 32 |
+
size_t size);
|
| 33 |
+
MapAllocator(const MapAllocator&) = delete;
|
| 34 |
+
MapAllocator& operator=(const MapAllocator&) = delete;
|
| 35 |
+
MapAllocator(MapAllocator&&) = delete;
|
| 36 |
+
MapAllocator& operator=(MapAllocator&&) = delete;
|
| 37 |
+
|
| 38 |
+
const char* filename() const {
|
| 39 |
+
return filename_.c_str();
|
| 40 |
+
}
|
| 41 |
+
int fd() const {
|
| 42 |
+
#ifdef _WIN32
|
| 43 |
+
TORCH_CHECK(false, "MapAllocator::fd() is unsupported on Windows");
|
| 44 |
+
#else
|
| 45 |
+
return fd_;
|
| 46 |
+
#endif
|
| 47 |
+
}
|
| 48 |
+
ptrdiff_t size() const {
|
| 49 |
+
return size_;
|
| 50 |
+
}
|
| 51 |
+
// Return a pointer to the actual data for this allocator
|
| 52 |
+
// (in the case of the refcounted allocator, this is offset
|
| 53 |
+
// from the base pointer.)
|
| 54 |
+
virtual void* data() const {
|
| 55 |
+
return base_ptr_;
|
| 56 |
+
}
|
| 57 |
+
|
| 58 |
+
int flags() const {
|
| 59 |
+
return flags_;
|
| 60 |
+
}
|
| 61 |
+
|
| 62 |
+
static MapAllocator* fromDataPtr(const at::DataPtr&);
|
| 63 |
+
static at::DataPtr makeDataPtr(
|
| 64 |
+
c10::string_view filename,
|
| 65 |
+
int flags,
|
| 66 |
+
size_t size,
|
| 67 |
+
size_t* actual_size_out);
|
| 68 |
+
static at::DataPtr makeDataPtr(
|
| 69 |
+
WithFd,
|
| 70 |
+
const char* filename,
|
| 71 |
+
int fd,
|
| 72 |
+
int flags,
|
| 73 |
+
size_t size,
|
| 74 |
+
size_t* actual_size_out);
|
| 75 |
+
|
| 76 |
+
// Closes the data. Helps us avoid destructor shenanigans
|
| 77 |
+
virtual void close();
|
| 78 |
+
|
| 79 |
+
// This is very dangerous. You have to redefine this destructor for each
|
| 80 |
+
// subclass
|
| 81 |
+
virtual ~MapAllocator();
|
| 82 |
+
|
| 83 |
+
protected:
|
| 84 |
+
bool closed_ = false;
|
| 85 |
+
std::string filename_;
|
| 86 |
+
int flags_ = 0;
|
| 87 |
+
ptrdiff_t size_; /* mapped size */
|
| 88 |
+
#ifdef _WIN32
|
| 89 |
+
void* handle_;
|
| 90 |
+
void* event_;
|
| 91 |
+
std::string eventname_;
|
| 92 |
+
#else
|
| 93 |
+
int fd_ = -1;
|
| 94 |
+
#endif
|
| 95 |
+
void* base_ptr_ = nullptr;
|
| 96 |
+
};
|
| 97 |
+
|
| 98 |
+
// Base-from-member idiom
|
| 99 |
+
struct TORCH_API RefcountedMapAllocatorArgCheck {
|
| 100 |
+
RefcountedMapAllocatorArgCheck(int flags);
|
| 101 |
+
};
|
| 102 |
+
|
| 103 |
+
class TORCH_API RefcountedMapAllocator : private RefcountedMapAllocatorArgCheck,
|
| 104 |
+
public MapAllocator {
|
| 105 |
+
public:
|
| 106 |
+
RefcountedMapAllocator(const char* filename, int flags, size_t size);
|
| 107 |
+
RefcountedMapAllocator(
|
| 108 |
+
WithFd,
|
| 109 |
+
const char* filename,
|
| 110 |
+
int fd,
|
| 111 |
+
int flags,
|
| 112 |
+
size_t size);
|
| 113 |
+
|
| 114 |
+
static RefcountedMapAllocator* fromDataPtr(const at::DataPtr&);
|
| 115 |
+
static at::DataPtr makeDataPtr(
|
| 116 |
+
const char* filename,
|
| 117 |
+
int flags,
|
| 118 |
+
size_t size,
|
| 119 |
+
size_t* actual_size_out);
|
| 120 |
+
static at::DataPtr makeDataPtr(
|
| 121 |
+
WithFd,
|
| 122 |
+
const char* filename,
|
| 123 |
+
int fd,
|
| 124 |
+
int flags,
|
| 125 |
+
size_t size,
|
| 126 |
+
size_t* actual_size_out);
|
| 127 |
+
|
| 128 |
+
void* data() const override;
|
| 129 |
+
|
| 130 |
+
void incref();
|
| 131 |
+
int decref();
|
| 132 |
+
void close() override;
|
| 133 |
+
|
| 134 |
+
~RefcountedMapAllocator() override {
|
| 135 |
+
RefcountedMapAllocator::close();
|
| 136 |
+
}
|
| 137 |
+
|
| 138 |
+
protected:
|
| 139 |
+
void checkFlags();
|
| 140 |
+
void initializeAlloc();
|
| 141 |
+
};
|
| 142 |
+
|
| 143 |
+
} // namespace at
|
.venv/Lib/site-packages/torch/include/ATen/MatrixRef.h
ADDED
|
@@ -0,0 +1,107 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
#include <ATen/Utils.h>
|
| 3 |
+
#include <c10/util/ArrayRef.h>
|
| 4 |
+
|
| 5 |
+
namespace at {
|
| 6 |
+
/// MatrixRef - Like an ArrayRef, but with an extra recorded strides so that
|
| 7 |
+
/// we can easily view it as a multidimensional array.
|
| 8 |
+
///
|
| 9 |
+
/// Like ArrayRef, this class does not own the underlying data, it is expected
|
| 10 |
+
/// to be used in situations where the data resides in some other buffer.
|
| 11 |
+
///
|
| 12 |
+
/// This is intended to be trivially copyable, so it should be passed by
|
| 13 |
+
/// value.
|
| 14 |
+
///
|
| 15 |
+
/// For now, 2D only (so the copies are actually cheap, without having
|
| 16 |
+
/// to write a SmallVector class) and contiguous only (so we can
|
| 17 |
+
/// return non-strided ArrayRef on index).
|
| 18 |
+
///
|
| 19 |
+
/// P.S. dimension 0 indexes rows, dimension 1 indexes columns
|
| 20 |
+
template <typename T>
|
| 21 |
+
class MatrixRef {
|
| 22 |
+
public:
|
| 23 |
+
typedef size_t size_type;
|
| 24 |
+
|
| 25 |
+
private:
|
| 26 |
+
/// Underlying ArrayRef
|
| 27 |
+
ArrayRef<T> arr;
|
| 28 |
+
|
| 29 |
+
/// Stride of dim 0 (outer dimension)
|
| 30 |
+
size_type stride0;
|
| 31 |
+
|
| 32 |
+
// Stride of dim 1 is assumed to be 1
|
| 33 |
+
|
| 34 |
+
public:
|
| 35 |
+
/// Construct an empty Matrixref.
|
| 36 |
+
/*implicit*/ MatrixRef() : arr(nullptr), stride0(0) {}
|
| 37 |
+
|
| 38 |
+
/// Construct an MatrixRef from an ArrayRef and outer stride.
|
| 39 |
+
/*implicit*/ MatrixRef(ArrayRef<T> arr, size_type stride0)
|
| 40 |
+
: arr(arr), stride0(stride0) {
|
| 41 |
+
TORCH_CHECK(
|
| 42 |
+
arr.size() % stride0 == 0,
|
| 43 |
+
"MatrixRef: ArrayRef size ",
|
| 44 |
+
arr.size(),
|
| 45 |
+
" not divisible by stride ",
|
| 46 |
+
stride0)
|
| 47 |
+
}
|
| 48 |
+
|
| 49 |
+
/// @}
|
| 50 |
+
/// @name Simple Operations
|
| 51 |
+
/// @{
|
| 52 |
+
|
| 53 |
+
/// empty - Check if the matrix is empty.
|
| 54 |
+
bool empty() const {
|
| 55 |
+
return arr.empty();
|
| 56 |
+
}
|
| 57 |
+
|
| 58 |
+
const T* data() const {
|
| 59 |
+
return arr.data();
|
| 60 |
+
}
|
| 61 |
+
|
| 62 |
+
/// size - Get size a dimension
|
| 63 |
+
size_t size(size_t dim) const {
|
| 64 |
+
if (dim == 0) {
|
| 65 |
+
return arr.size() / stride0;
|
| 66 |
+
} else if (dim == 1) {
|
| 67 |
+
return stride0;
|
| 68 |
+
} else {
|
| 69 |
+
TORCH_CHECK(
|
| 70 |
+
0, "MatrixRef: out of bounds dimension ", dim, "; expected 0 or 1");
|
| 71 |
+
}
|
| 72 |
+
}
|
| 73 |
+
|
| 74 |
+
size_t numel() const {
|
| 75 |
+
return arr.size();
|
| 76 |
+
}
|
| 77 |
+
|
| 78 |
+
/// equals - Check for element-wise equality.
|
| 79 |
+
bool equals(MatrixRef RHS) const {
|
| 80 |
+
return stride0 == RHS.stride0 && arr.equals(RHS.arr);
|
| 81 |
+
}
|
| 82 |
+
|
| 83 |
+
/// @}
|
| 84 |
+
/// @name Operator Overloads
|
| 85 |
+
/// @{
|
| 86 |
+
ArrayRef<T> operator[](size_t Index) const {
|
| 87 |
+
return arr.slice(Index * stride0, stride0);
|
| 88 |
+
}
|
| 89 |
+
|
| 90 |
+
/// Disallow accidental assignment from a temporary.
|
| 91 |
+
///
|
| 92 |
+
/// The declaration here is extra complicated so that "arrayRef = {}"
|
| 93 |
+
/// continues to select the move assignment operator.
|
| 94 |
+
template <typename U>
|
| 95 |
+
std::enable_if_t<std::is_same_v<U, T>, MatrixRef<T>>& operator=(
|
| 96 |
+
U&& Temporary) = delete;
|
| 97 |
+
|
| 98 |
+
/// Disallow accidental assignment from a temporary.
|
| 99 |
+
///
|
| 100 |
+
/// The declaration here is extra complicated so that "arrayRef = {}"
|
| 101 |
+
/// continues to select the move assignment operator.
|
| 102 |
+
template <typename U>
|
| 103 |
+
std::enable_if_t<std::is_same_v<U, T>, MatrixRef<T>>& operator=(
|
| 104 |
+
std::initializer_list<U>) = delete;
|
| 105 |
+
};
|
| 106 |
+
|
| 107 |
+
} // end namespace at
|
.venv/Lib/site-packages/torch/include/ATen/MemoryOverlap.h
ADDED
|
@@ -0,0 +1,42 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#include <c10/macros/Export.h>
|
| 4 |
+
|
| 5 |
+
namespace c10 {
|
| 6 |
+
struct TensorImpl;
|
| 7 |
+
}
|
| 8 |
+
|
| 9 |
+
namespace at {
|
| 10 |
+
class TensorBase;
|
| 11 |
+
|
| 12 |
+
// MemOverlap: Whether or not there is memory overlap
|
| 13 |
+
//
|
| 14 |
+
// No: Absolutely no memory overlap
|
| 15 |
+
// Yes: Absolutely yes memory overlap
|
| 16 |
+
// TooHard: There might be memory overlap, but it was too expensive to compute.
|
| 17 |
+
//
|
| 18 |
+
// NB: Please update the python test for these if you renumber them.
|
| 19 |
+
enum class MemOverlap { No, Yes, TooHard };
|
| 20 |
+
|
| 21 |
+
enum class MemOverlapStatus { Full, Partial, No, TooHard };
|
| 22 |
+
|
| 23 |
+
TORCH_API MemOverlap has_internal_overlap(const TensorBase& t);
|
| 24 |
+
TORCH_API MemOverlap has_internal_overlap(c10::TensorImpl* t);
|
| 25 |
+
|
| 26 |
+
TORCH_API void assert_no_internal_overlap(const TensorBase& t);
|
| 27 |
+
TORCH_API void assert_no_internal_overlap(c10::TensorImpl* t);
|
| 28 |
+
|
| 29 |
+
TORCH_API MemOverlapStatus
|
| 30 |
+
get_overlap_status(const TensorBase& a, const TensorBase& b);
|
| 31 |
+
TORCH_API MemOverlapStatus
|
| 32 |
+
get_overlap_status(const c10::TensorImpl* a, const c10::TensorImpl* b);
|
| 33 |
+
|
| 34 |
+
TORCH_API void assert_no_partial_overlap(
|
| 35 |
+
const TensorBase& a,
|
| 36 |
+
const TensorBase& b);
|
| 37 |
+
void assert_no_partial_overlap(c10::TensorImpl* a, c10::TensorImpl* b);
|
| 38 |
+
|
| 39 |
+
TORCH_API void assert_no_overlap(const TensorBase& a, const TensorBase& b);
|
| 40 |
+
TORCH_API void assert_no_overlap(c10::TensorImpl* a, c10::TensorImpl* b);
|
| 41 |
+
|
| 42 |
+
} // namespace at
|
.venv/Lib/site-packages/torch/include/ATen/MetaFunctions.h
ADDED
|
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#include <ATen/core/TensorBody.h>
|
| 2 |
+
|
| 3 |
+
// TODO Undo all logic introduced for Note [Avoiding Include Cycles In Static Dispatch]
|
| 4 |
+
// Code introduced to avoid cyclic dependency in static dispatch is no longer
|
| 5 |
+
// needed as static dispatch logic is moved from TensorBody.h, which caused cycles in the first place,
|
| 6 |
+
// to Operators.cpp for supporting multiple backends with multiple kernels.
|
| 7 |
+
//
|
| 8 |
+
// Note [Avoiding Include Cycles In Static Dispatch]
|
| 9 |
+
// In order to avoid #include cycles in the static dispatch build, we've carefully split out
|
| 10 |
+
// the static function definition files into {DispatchKey}Functions.h and {DispatchKey}Functions_inl.h.
|
| 11 |
+
//
|
| 12 |
+
// Without this split, the include cycle looks like TensorBody.h -> CPUFunctions.h -> TensorBody.h.
|
| 13 |
+
// - TensorBody.h #includes CPUFunctions.h in the static dispatch build, because the tensor methods
|
| 14 |
+
// all need to call into the fastpath C++ API defined in CPUFunctions.h. The methods are also all
|
| 15 |
+
// directly inlined into TensorBody.h.
|
| 16 |
+
// - CPUFunctions.h #includes TensorBody.h because it contains function declarations for the entire C++ API,
|
| 17 |
+
// which include functions that have defaultable std::optional<Tensor> arguments.
|
| 18 |
+
// That requires knowing the full Tensor class definition.
|
| 19 |
+
//
|
| 20 |
+
// We break the cycle by doing the following:
|
| 21 |
+
// - Split out CPUFunction.h into two files: CPUFunctions.h and CPUFunctions_inl.h
|
| 22 |
+
// - CPUFunction.h is a dummy file that just includes the Tensor class and includes CPUFunctions_inl.,
|
| 23 |
+
// - CPUFunctions_inl.h includes everything else
|
| 24 |
+
// - (only in the static dispatch build) TensorBody.h makes sure to finish defining the Tensor class,
|
| 25 |
+
// and then it includes CPUFunctions_inl.h.
|
| 26 |
+
// - All other files that want the cpu fastpath functions can include CPUFunctions.h directly.
|
| 27 |
+
// - This also means that static dispatch build, CPUFunctions.h only needs to
|
| 28 |
+
// #include TensorBody.h, and it will automatically bring in CPUFunctions_inl.h.
|
| 29 |
+
#include <ATen/MetaFunctions_inl.h>
|
.venv/Lib/site-packages/torch/include/ATen/MetaFunctions_inl.h
ADDED
|
@@ -0,0 +1,325 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
// @generated by torchgen/gen.py from DispatchKeyFunctions_inl.h
|
| 3 |
+
|
| 4 |
+
// NB: The implementing C++ file is RegisterDispatchKey.cpp
|
| 5 |
+
|
| 6 |
+
// The only #includes we need are for custom classes that have defaults in the C++ API
|
| 7 |
+
#include <c10/core/MemoryFormat.h>
|
| 8 |
+
#include <c10/core/Scalar.h>
|
| 9 |
+
#include <ATen/core/Reduction.h>
|
| 10 |
+
|
| 11 |
+
#if defined(AT_PER_OPERATOR_HEADERS) && defined(TORCH_ASSERT_ONLY_METHOD_OPERATORS)
|
| 12 |
+
#error This change adds a dependency on all pytorch operators, meaning the \
|
| 13 |
+
file will need to be re-compiled every time an operator is changed or added. \
|
| 14 |
+
Consider including a specific operator from \
|
| 15 |
+
<ATen/ops/{my_operator}_meta_dispatch.h>. \
|
| 16 |
+
See NOTE [TORCH_ASSERT_ONLY_METHOD_OPERATORS].
|
| 17 |
+
#endif
|
| 18 |
+
|
| 19 |
+
#include <ATen/ops/_add_relu_meta_dispatch.h>
|
| 20 |
+
#include <ATen/ops/_addmm_activation_meta_dispatch.h>
|
| 21 |
+
#include <ATen/ops/_amp_update_scale_meta_dispatch.h>
|
| 22 |
+
#include <ATen/ops/_coalesced_meta_dispatch.h>
|
| 23 |
+
#include <ATen/ops/_convert_indices_from_coo_to_csr_meta_dispatch.h>
|
| 24 |
+
#include <ATen/ops/_convert_indices_from_csr_to_coo_meta_dispatch.h>
|
| 25 |
+
#include <ATen/ops/_ctc_loss_meta_dispatch.h>
|
| 26 |
+
#include <ATen/ops/_efficientzerotensor_meta_dispatch.h>
|
| 27 |
+
#include <ATen/ops/_fill_mem_eff_dropout_mask_meta_dispatch.h>
|
| 28 |
+
#include <ATen/ops/_fused_sdp_choice_meta_dispatch.h>
|
| 29 |
+
#include <ATen/ops/_index_put_impl_meta_dispatch.h>
|
| 30 |
+
#include <ATen/ops/_linalg_det_meta_dispatch.h>
|
| 31 |
+
#include <ATen/ops/_linalg_eigh_meta_dispatch.h>
|
| 32 |
+
#include <ATen/ops/_linalg_slogdet_meta_dispatch.h>
|
| 33 |
+
#include <ATen/ops/_linalg_solve_ex_meta_dispatch.h>
|
| 34 |
+
#include <ATen/ops/_linalg_svd_meta_dispatch.h>
|
| 35 |
+
#include <ATen/ops/_log_softmax_meta_dispatch.h>
|
| 36 |
+
#include <ATen/ops/_log_softmax_backward_data_meta_dispatch.h>
|
| 37 |
+
#include <ATen/ops/_mkldnn_transpose_meta_dispatch.h>
|
| 38 |
+
#include <ATen/ops/_reshape_alias_meta_dispatch.h>
|
| 39 |
+
#include <ATen/ops/_resize_output_meta_dispatch.h>
|
| 40 |
+
#include <ATen/ops/_softmax_meta_dispatch.h>
|
| 41 |
+
#include <ATen/ops/_softmax_backward_data_meta_dispatch.h>
|
| 42 |
+
#include <ATen/ops/_sparse_coo_tensor_with_dims_meta_dispatch.h>
|
| 43 |
+
#include <ATen/ops/_sparse_coo_tensor_with_dims_and_tensors_meta_dispatch.h>
|
| 44 |
+
#include <ATen/ops/_upsample_bicubic2d_aa_meta_dispatch.h>
|
| 45 |
+
#include <ATen/ops/_upsample_bicubic2d_aa_backward_meta_dispatch.h>
|
| 46 |
+
#include <ATen/ops/_upsample_bilinear2d_aa_meta_dispatch.h>
|
| 47 |
+
#include <ATen/ops/_upsample_bilinear2d_aa_backward_meta_dispatch.h>
|
| 48 |
+
#include <ATen/ops/_upsample_nearest_exact1d_meta_dispatch.h>
|
| 49 |
+
#include <ATen/ops/_upsample_nearest_exact1d_backward_meta_dispatch.h>
|
| 50 |
+
#include <ATen/ops/_upsample_nearest_exact2d_meta_dispatch.h>
|
| 51 |
+
#include <ATen/ops/_upsample_nearest_exact2d_backward_meta_dispatch.h>
|
| 52 |
+
#include <ATen/ops/_upsample_nearest_exact3d_meta_dispatch.h>
|
| 53 |
+
#include <ATen/ops/_upsample_nearest_exact3d_backward_meta_dispatch.h>
|
| 54 |
+
#include <ATen/ops/acos_meta_dispatch.h>
|
| 55 |
+
#include <ATen/ops/acosh_meta_dispatch.h>
|
| 56 |
+
#include <ATen/ops/adaptive_max_pool2d_meta_dispatch.h>
|
| 57 |
+
#include <ATen/ops/adaptive_max_pool2d_backward_meta_dispatch.h>
|
| 58 |
+
#include <ATen/ops/adaptive_max_pool3d_meta_dispatch.h>
|
| 59 |
+
#include <ATen/ops/adaptive_max_pool3d_backward_meta_dispatch.h>
|
| 60 |
+
#include <ATen/ops/add_meta_dispatch.h>
|
| 61 |
+
#include <ATen/ops/addbmm_meta_dispatch.h>
|
| 62 |
+
#include <ATen/ops/addcdiv_meta_dispatch.h>
|
| 63 |
+
#include <ATen/ops/addcmul_meta_dispatch.h>
|
| 64 |
+
#include <ATen/ops/addmm_meta_dispatch.h>
|
| 65 |
+
#include <ATen/ops/addmv_meta_dispatch.h>
|
| 66 |
+
#include <ATen/ops/all_meta_dispatch.h>
|
| 67 |
+
#include <ATen/ops/amax_meta_dispatch.h>
|
| 68 |
+
#include <ATen/ops/amin_meta_dispatch.h>
|
| 69 |
+
#include <ATen/ops/aminmax_meta_dispatch.h>
|
| 70 |
+
#include <ATen/ops/any_meta_dispatch.h>
|
| 71 |
+
#include <ATen/ops/arange_meta_dispatch.h>
|
| 72 |
+
#include <ATen/ops/argmax_meta_dispatch.h>
|
| 73 |
+
#include <ATen/ops/argmin_meta_dispatch.h>
|
| 74 |
+
#include <ATen/ops/as_strided_meta_dispatch.h>
|
| 75 |
+
#include <ATen/ops/asin_meta_dispatch.h>
|
| 76 |
+
#include <ATen/ops/asinh_meta_dispatch.h>
|
| 77 |
+
#include <ATen/ops/atan_meta_dispatch.h>
|
| 78 |
+
#include <ATen/ops/atan2_meta_dispatch.h>
|
| 79 |
+
#include <ATen/ops/atanh_meta_dispatch.h>
|
| 80 |
+
#include <ATen/ops/avg_pool2d_meta_dispatch.h>
|
| 81 |
+
#include <ATen/ops/avg_pool2d_backward_meta_dispatch.h>
|
| 82 |
+
#include <ATen/ops/avg_pool3d_meta_dispatch.h>
|
| 83 |
+
#include <ATen/ops/avg_pool3d_backward_meta_dispatch.h>
|
| 84 |
+
#include <ATen/ops/baddbmm_meta_dispatch.h>
|
| 85 |
+
#include <ATen/ops/bernoulli_meta_dispatch.h>
|
| 86 |
+
#include <ATen/ops/bitwise_and_meta_dispatch.h>
|
| 87 |
+
#include <ATen/ops/bitwise_left_shift_meta_dispatch.h>
|
| 88 |
+
#include <ATen/ops/bitwise_not_meta_dispatch.h>
|
| 89 |
+
#include <ATen/ops/bitwise_or_meta_dispatch.h>
|
| 90 |
+
#include <ATen/ops/bitwise_right_shift_meta_dispatch.h>
|
| 91 |
+
#include <ATen/ops/bitwise_xor_meta_dispatch.h>
|
| 92 |
+
#include <ATen/ops/bmm_meta_dispatch.h>
|
| 93 |
+
#include <ATen/ops/cat_meta_dispatch.h>
|
| 94 |
+
#include <ATen/ops/cauchy_meta_dispatch.h>
|
| 95 |
+
#include <ATen/ops/ceil_meta_dispatch.h>
|
| 96 |
+
#include <ATen/ops/clamp_meta_dispatch.h>
|
| 97 |
+
#include <ATen/ops/clamp_max_meta_dispatch.h>
|
| 98 |
+
#include <ATen/ops/clamp_min_meta_dispatch.h>
|
| 99 |
+
#include <ATen/ops/copy_meta_dispatch.h>
|
| 100 |
+
#include <ATen/ops/copy_sparse_to_sparse_meta_dispatch.h>
|
| 101 |
+
#include <ATen/ops/copysign_meta_dispatch.h>
|
| 102 |
+
#include <ATen/ops/cos_meta_dispatch.h>
|
| 103 |
+
#include <ATen/ops/cosh_meta_dispatch.h>
|
| 104 |
+
#include <ATen/ops/cumprod_meta_dispatch.h>
|
| 105 |
+
#include <ATen/ops/cumsum_meta_dispatch.h>
|
| 106 |
+
#include <ATen/ops/digamma_meta_dispatch.h>
|
| 107 |
+
#include <ATen/ops/div_meta_dispatch.h>
|
| 108 |
+
#include <ATen/ops/elu_meta_dispatch.h>
|
| 109 |
+
#include <ATen/ops/elu_backward_meta_dispatch.h>
|
| 110 |
+
#include <ATen/ops/embedding_renorm_meta_dispatch.h>
|
| 111 |
+
#include <ATen/ops/empty_meta_dispatch.h>
|
| 112 |
+
#include <ATen/ops/empty_strided_meta_dispatch.h>
|
| 113 |
+
#include <ATen/ops/eq_meta_dispatch.h>
|
| 114 |
+
#include <ATen/ops/erf_meta_dispatch.h>
|
| 115 |
+
#include <ATen/ops/erfc_meta_dispatch.h>
|
| 116 |
+
#include <ATen/ops/erfinv_meta_dispatch.h>
|
| 117 |
+
#include <ATen/ops/exp_meta_dispatch.h>
|
| 118 |
+
#include <ATen/ops/exp2_meta_dispatch.h>
|
| 119 |
+
#include <ATen/ops/expm1_meta_dispatch.h>
|
| 120 |
+
#include <ATen/ops/exponential_meta_dispatch.h>
|
| 121 |
+
#include <ATen/ops/eye_meta_dispatch.h>
|
| 122 |
+
#include <ATen/ops/fill_meta_dispatch.h>
|
| 123 |
+
#include <ATen/ops/floor_meta_dispatch.h>
|
| 124 |
+
#include <ATen/ops/floor_divide_meta_dispatch.h>
|
| 125 |
+
#include <ATen/ops/fmax_meta_dispatch.h>
|
| 126 |
+
#include <ATen/ops/fmin_meta_dispatch.h>
|
| 127 |
+
#include <ATen/ops/fmod_meta_dispatch.h>
|
| 128 |
+
#include <ATen/ops/frac_meta_dispatch.h>
|
| 129 |
+
#include <ATen/ops/fractional_max_pool2d_meta_dispatch.h>
|
| 130 |
+
#include <ATen/ops/fractional_max_pool2d_backward_meta_dispatch.h>
|
| 131 |
+
#include <ATen/ops/fractional_max_pool3d_meta_dispatch.h>
|
| 132 |
+
#include <ATen/ops/gather_meta_dispatch.h>
|
| 133 |
+
#include <ATen/ops/gcd_meta_dispatch.h>
|
| 134 |
+
#include <ATen/ops/ge_meta_dispatch.h>
|
| 135 |
+
#include <ATen/ops/gelu_meta_dispatch.h>
|
| 136 |
+
#include <ATen/ops/gelu_backward_meta_dispatch.h>
|
| 137 |
+
#include <ATen/ops/geometric_meta_dispatch.h>
|
| 138 |
+
#include <ATen/ops/glu_meta_dispatch.h>
|
| 139 |
+
#include <ATen/ops/gt_meta_dispatch.h>
|
| 140 |
+
#include <ATen/ops/hardshrink_meta_dispatch.h>
|
| 141 |
+
#include <ATen/ops/hardshrink_backward_meta_dispatch.h>
|
| 142 |
+
#include <ATen/ops/hardsigmoid_meta_dispatch.h>
|
| 143 |
+
#include <ATen/ops/hardsigmoid_backward_meta_dispatch.h>
|
| 144 |
+
#include <ATen/ops/hardswish_meta_dispatch.h>
|
| 145 |
+
#include <ATen/ops/hardtanh_meta_dispatch.h>
|
| 146 |
+
#include <ATen/ops/heaviside_meta_dispatch.h>
|
| 147 |
+
#include <ATen/ops/hypot_meta_dispatch.h>
|
| 148 |
+
#include <ATen/ops/i0_meta_dispatch.h>
|
| 149 |
+
#include <ATen/ops/igamma_meta_dispatch.h>
|
| 150 |
+
#include <ATen/ops/igammac_meta_dispatch.h>
|
| 151 |
+
#include <ATen/ops/index_meta_dispatch.h>
|
| 152 |
+
#include <ATen/ops/index_add_meta_dispatch.h>
|
| 153 |
+
#include <ATen/ops/index_copy_meta_dispatch.h>
|
| 154 |
+
#include <ATen/ops/index_fill_meta_dispatch.h>
|
| 155 |
+
#include <ATen/ops/index_reduce_meta_dispatch.h>
|
| 156 |
+
#include <ATen/ops/isin_meta_dispatch.h>
|
| 157 |
+
#include <ATen/ops/isneginf_meta_dispatch.h>
|
| 158 |
+
#include <ATen/ops/isposinf_meta_dispatch.h>
|
| 159 |
+
#include <ATen/ops/lcm_meta_dispatch.h>
|
| 160 |
+
#include <ATen/ops/le_meta_dispatch.h>
|
| 161 |
+
#include <ATen/ops/leaky_relu_meta_dispatch.h>
|
| 162 |
+
#include <ATen/ops/leaky_relu_backward_meta_dispatch.h>
|
| 163 |
+
#include <ATen/ops/lerp_meta_dispatch.h>
|
| 164 |
+
#include <ATen/ops/lgamma_meta_dispatch.h>
|
| 165 |
+
#include <ATen/ops/linalg_cholesky_ex_meta_dispatch.h>
|
| 166 |
+
#include <ATen/ops/linalg_cross_meta_dispatch.h>
|
| 167 |
+
#include <ATen/ops/linalg_inv_ex_meta_dispatch.h>
|
| 168 |
+
#include <ATen/ops/linalg_ldl_factor_ex_meta_dispatch.h>
|
| 169 |
+
#include <ATen/ops/linalg_ldl_solve_meta_dispatch.h>
|
| 170 |
+
#include <ATen/ops/linalg_lu_meta_dispatch.h>
|
| 171 |
+
#include <ATen/ops/linalg_lu_factor_ex_meta_dispatch.h>
|
| 172 |
+
#include <ATen/ops/linalg_lu_solve_meta_dispatch.h>
|
| 173 |
+
#include <ATen/ops/linalg_qr_meta_dispatch.h>
|
| 174 |
+
#include <ATen/ops/linalg_vector_norm_meta_dispatch.h>
|
| 175 |
+
#include <ATen/ops/linspace_meta_dispatch.h>
|
| 176 |
+
#include <ATen/ops/log_meta_dispatch.h>
|
| 177 |
+
#include <ATen/ops/log10_meta_dispatch.h>
|
| 178 |
+
#include <ATen/ops/log1p_meta_dispatch.h>
|
| 179 |
+
#include <ATen/ops/log2_meta_dispatch.h>
|
| 180 |
+
#include <ATen/ops/log_normal_meta_dispatch.h>
|
| 181 |
+
#include <ATen/ops/logaddexp_meta_dispatch.h>
|
| 182 |
+
#include <ATen/ops/logaddexp2_meta_dispatch.h>
|
| 183 |
+
#include <ATen/ops/logit_meta_dispatch.h>
|
| 184 |
+
#include <ATen/ops/logit_backward_meta_dispatch.h>
|
| 185 |
+
#include <ATen/ops/logspace_meta_dispatch.h>
|
| 186 |
+
#include <ATen/ops/lshift_meta_dispatch.h>
|
| 187 |
+
#include <ATen/ops/lt_meta_dispatch.h>
|
| 188 |
+
#include <ATen/ops/lu_unpack_meta_dispatch.h>
|
| 189 |
+
#include <ATen/ops/masked_fill_meta_dispatch.h>
|
| 190 |
+
#include <ATen/ops/masked_scatter_meta_dispatch.h>
|
| 191 |
+
#include <ATen/ops/max_meta_dispatch.h>
|
| 192 |
+
#include <ATen/ops/max_pool2d_with_indices_meta_dispatch.h>
|
| 193 |
+
#include <ATen/ops/max_pool2d_with_indices_backward_meta_dispatch.h>
|
| 194 |
+
#include <ATen/ops/maximum_meta_dispatch.h>
|
| 195 |
+
#include <ATen/ops/mean_meta_dispatch.h>
|
| 196 |
+
#include <ATen/ops/min_meta_dispatch.h>
|
| 197 |
+
#include <ATen/ops/minimum_meta_dispatch.h>
|
| 198 |
+
#include <ATen/ops/mish_meta_dispatch.h>
|
| 199 |
+
#include <ATen/ops/mm_meta_dispatch.h>
|
| 200 |
+
#include <ATen/ops/mse_loss_meta_dispatch.h>
|
| 201 |
+
#include <ATen/ops/mul_meta_dispatch.h>
|
| 202 |
+
#include <ATen/ops/ne_meta_dispatch.h>
|
| 203 |
+
#include <ATen/ops/neg_meta_dispatch.h>
|
| 204 |
+
#include <ATen/ops/nextafter_meta_dispatch.h>
|
| 205 |
+
#include <ATen/ops/nll_loss_backward_meta_dispatch.h>
|
| 206 |
+
#include <ATen/ops/nll_loss_forward_meta_dispatch.h>
|
| 207 |
+
#include <ATen/ops/norm_meta_dispatch.h>
|
| 208 |
+
#include <ATen/ops/normal_meta_dispatch.h>
|
| 209 |
+
#include <ATen/ops/polygamma_meta_dispatch.h>
|
| 210 |
+
#include <ATen/ops/pow_meta_dispatch.h>
|
| 211 |
+
#include <ATen/ops/prod_meta_dispatch.h>
|
| 212 |
+
#include <ATen/ops/put_meta_dispatch.h>
|
| 213 |
+
#include <ATen/ops/random_meta_dispatch.h>
|
| 214 |
+
#include <ATen/ops/range_meta_dispatch.h>
|
| 215 |
+
#include <ATen/ops/reciprocal_meta_dispatch.h>
|
| 216 |
+
#include <ATen/ops/reflection_pad1d_meta_dispatch.h>
|
| 217 |
+
#include <ATen/ops/reflection_pad1d_backward_meta_dispatch.h>
|
| 218 |
+
#include <ATen/ops/reflection_pad3d_meta_dispatch.h>
|
| 219 |
+
#include <ATen/ops/reflection_pad3d_backward_meta_dispatch.h>
|
| 220 |
+
#include <ATen/ops/relu_meta_dispatch.h>
|
| 221 |
+
#include <ATen/ops/remainder_meta_dispatch.h>
|
| 222 |
+
#include <ATen/ops/renorm_meta_dispatch.h>
|
| 223 |
+
#include <ATen/ops/replication_pad1d_meta_dispatch.h>
|
| 224 |
+
#include <ATen/ops/replication_pad1d_backward_meta_dispatch.h>
|
| 225 |
+
#include <ATen/ops/replication_pad2d_meta_dispatch.h>
|
| 226 |
+
#include <ATen/ops/replication_pad3d_meta_dispatch.h>
|
| 227 |
+
#include <ATen/ops/resize_meta_dispatch.h>
|
| 228 |
+
#include <ATen/ops/resize_as_sparse_meta_dispatch.h>
|
| 229 |
+
#include <ATen/ops/round_meta_dispatch.h>
|
| 230 |
+
#include <ATen/ops/rrelu_with_noise_meta_dispatch.h>
|
| 231 |
+
#include <ATen/ops/rshift_meta_dispatch.h>
|
| 232 |
+
#include <ATen/ops/rsqrt_meta_dispatch.h>
|
| 233 |
+
#include <ATen/ops/scatter_meta_dispatch.h>
|
| 234 |
+
#include <ATen/ops/scatter_add_meta_dispatch.h>
|
| 235 |
+
#include <ATen/ops/scatter_reduce_meta_dispatch.h>
|
| 236 |
+
#include <ATen/ops/set_meta_dispatch.h>
|
| 237 |
+
#include <ATen/ops/sgn_meta_dispatch.h>
|
| 238 |
+
#include <ATen/ops/sigmoid_meta_dispatch.h>
|
| 239 |
+
#include <ATen/ops/sigmoid_backward_meta_dispatch.h>
|
| 240 |
+
#include <ATen/ops/sign_meta_dispatch.h>
|
| 241 |
+
#include <ATen/ops/signbit_meta_dispatch.h>
|
| 242 |
+
#include <ATen/ops/silu_meta_dispatch.h>
|
| 243 |
+
#include <ATen/ops/silu_backward_meta_dispatch.h>
|
| 244 |
+
#include <ATen/ops/sin_meta_dispatch.h>
|
| 245 |
+
#include <ATen/ops/sinc_meta_dispatch.h>
|
| 246 |
+
#include <ATen/ops/sinh_meta_dispatch.h>
|
| 247 |
+
#include <ATen/ops/slow_conv_transpose2d_meta_dispatch.h>
|
| 248 |
+
#include <ATen/ops/smooth_l1_loss_meta_dispatch.h>
|
| 249 |
+
#include <ATen/ops/softplus_meta_dispatch.h>
|
| 250 |
+
#include <ATen/ops/softplus_backward_meta_dispatch.h>
|
| 251 |
+
#include <ATen/ops/softshrink_meta_dispatch.h>
|
| 252 |
+
#include <ATen/ops/softshrink_backward_meta_dispatch.h>
|
| 253 |
+
#include <ATen/ops/sort_meta_dispatch.h>
|
| 254 |
+
#include <ATen/ops/sparse_resize_meta_dispatch.h>
|
| 255 |
+
#include <ATen/ops/sparse_resize_and_clear_meta_dispatch.h>
|
| 256 |
+
#include <ATen/ops/special_airy_ai_meta_dispatch.h>
|
| 257 |
+
#include <ATen/ops/special_bessel_j0_meta_dispatch.h>
|
| 258 |
+
#include <ATen/ops/special_bessel_j1_meta_dispatch.h>
|
| 259 |
+
#include <ATen/ops/special_bessel_y0_meta_dispatch.h>
|
| 260 |
+
#include <ATen/ops/special_bessel_y1_meta_dispatch.h>
|
| 261 |
+
#include <ATen/ops/special_chebyshev_polynomial_t_meta_dispatch.h>
|
| 262 |
+
#include <ATen/ops/special_chebyshev_polynomial_u_meta_dispatch.h>
|
| 263 |
+
#include <ATen/ops/special_chebyshev_polynomial_v_meta_dispatch.h>
|
| 264 |
+
#include <ATen/ops/special_chebyshev_polynomial_w_meta_dispatch.h>
|
| 265 |
+
#include <ATen/ops/special_entr_meta_dispatch.h>
|
| 266 |
+
#include <ATen/ops/special_erfcx_meta_dispatch.h>
|
| 267 |
+
#include <ATen/ops/special_hermite_polynomial_h_meta_dispatch.h>
|
| 268 |
+
#include <ATen/ops/special_hermite_polynomial_he_meta_dispatch.h>
|
| 269 |
+
#include <ATen/ops/special_i0e_meta_dispatch.h>
|
| 270 |
+
#include <ATen/ops/special_i1_meta_dispatch.h>
|
| 271 |
+
#include <ATen/ops/special_i1e_meta_dispatch.h>
|
| 272 |
+
#include <ATen/ops/special_laguerre_polynomial_l_meta_dispatch.h>
|
| 273 |
+
#include <ATen/ops/special_legendre_polynomial_p_meta_dispatch.h>
|
| 274 |
+
#include <ATen/ops/special_log_ndtr_meta_dispatch.h>
|
| 275 |
+
#include <ATen/ops/special_modified_bessel_i0_meta_dispatch.h>
|
| 276 |
+
#include <ATen/ops/special_modified_bessel_i1_meta_dispatch.h>
|
| 277 |
+
#include <ATen/ops/special_modified_bessel_k0_meta_dispatch.h>
|
| 278 |
+
#include <ATen/ops/special_modified_bessel_k1_meta_dispatch.h>
|
| 279 |
+
#include <ATen/ops/special_ndtri_meta_dispatch.h>
|
| 280 |
+
#include <ATen/ops/special_scaled_modified_bessel_k0_meta_dispatch.h>
|
| 281 |
+
#include <ATen/ops/special_scaled_modified_bessel_k1_meta_dispatch.h>
|
| 282 |
+
#include <ATen/ops/special_shifted_chebyshev_polynomial_t_meta_dispatch.h>
|
| 283 |
+
#include <ATen/ops/special_shifted_chebyshev_polynomial_u_meta_dispatch.h>
|
| 284 |
+
#include <ATen/ops/special_shifted_chebyshev_polynomial_v_meta_dispatch.h>
|
| 285 |
+
#include <ATen/ops/special_shifted_chebyshev_polynomial_w_meta_dispatch.h>
|
| 286 |
+
#include <ATen/ops/special_spherical_bessel_j0_meta_dispatch.h>
|
| 287 |
+
#include <ATen/ops/special_xlog1py_meta_dispatch.h>
|
| 288 |
+
#include <ATen/ops/special_zeta_meta_dispatch.h>
|
| 289 |
+
#include <ATen/ops/sqrt_meta_dispatch.h>
|
| 290 |
+
#include <ATen/ops/sub_meta_dispatch.h>
|
| 291 |
+
#include <ATen/ops/sum_meta_dispatch.h>
|
| 292 |
+
#include <ATen/ops/tan_meta_dispatch.h>
|
| 293 |
+
#include <ATen/ops/tanh_meta_dispatch.h>
|
| 294 |
+
#include <ATen/ops/tanh_backward_meta_dispatch.h>
|
| 295 |
+
#include <ATen/ops/threshold_meta_dispatch.h>
|
| 296 |
+
#include <ATen/ops/threshold_backward_meta_dispatch.h>
|
| 297 |
+
#include <ATen/ops/topk_meta_dispatch.h>
|
| 298 |
+
#include <ATen/ops/triangular_solve_meta_dispatch.h>
|
| 299 |
+
#include <ATen/ops/tril_meta_dispatch.h>
|
| 300 |
+
#include <ATen/ops/triu_meta_dispatch.h>
|
| 301 |
+
#include <ATen/ops/trunc_meta_dispatch.h>
|
| 302 |
+
#include <ATen/ops/unfold_meta_dispatch.h>
|
| 303 |
+
#include <ATen/ops/uniform_meta_dispatch.h>
|
| 304 |
+
#include <ATen/ops/upsample_bicubic2d_meta_dispatch.h>
|
| 305 |
+
#include <ATen/ops/upsample_bicubic2d_backward_meta_dispatch.h>
|
| 306 |
+
#include <ATen/ops/upsample_bilinear2d_meta_dispatch.h>
|
| 307 |
+
#include <ATen/ops/upsample_bilinear2d_backward_meta_dispatch.h>
|
| 308 |
+
#include <ATen/ops/upsample_linear1d_meta_dispatch.h>
|
| 309 |
+
#include <ATen/ops/upsample_linear1d_backward_meta_dispatch.h>
|
| 310 |
+
#include <ATen/ops/upsample_nearest1d_meta_dispatch.h>
|
| 311 |
+
#include <ATen/ops/upsample_nearest1d_backward_meta_dispatch.h>
|
| 312 |
+
#include <ATen/ops/upsample_nearest2d_meta_dispatch.h>
|
| 313 |
+
#include <ATen/ops/upsample_nearest2d_backward_meta_dispatch.h>
|
| 314 |
+
#include <ATen/ops/upsample_nearest3d_meta_dispatch.h>
|
| 315 |
+
#include <ATen/ops/upsample_nearest3d_backward_meta_dispatch.h>
|
| 316 |
+
#include <ATen/ops/upsample_trilinear3d_meta_dispatch.h>
|
| 317 |
+
#include <ATen/ops/upsample_trilinear3d_backward_meta_dispatch.h>
|
| 318 |
+
#include <ATen/ops/view_meta_dispatch.h>
|
| 319 |
+
#include <ATen/ops/view_as_complex_meta_dispatch.h>
|
| 320 |
+
#include <ATen/ops/view_as_real_meta_dispatch.h>
|
| 321 |
+
#include <ATen/ops/xlogy_meta_dispatch.h>
|
| 322 |
+
#include <ATen/ops/zero_meta_dispatch.h>
|
| 323 |
+
|
| 324 |
+
|
| 325 |
+
|
.venv/Lib/site-packages/torch/include/ATen/MethodOperators.h
ADDED
|
@@ -0,0 +1,443 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
// @generated by torchgen/gen.py from MethodOperators.h
|
| 4 |
+
|
| 5 |
+
#ifdef TORCH_ASSERT_NO_OPERATORS
|
| 6 |
+
#error This change adds a dependency on native_functions.yaml, \
|
| 7 |
+
meaning the file will need to be re-compiled every time an operator \
|
| 8 |
+
is changed or added. Consider if your change would be better placed in \
|
| 9 |
+
another file, or if a more specific header might achieve the same goal. \
|
| 10 |
+
See NOTE: [Tensor vs. TensorBase]
|
| 11 |
+
#endif
|
| 12 |
+
|
| 13 |
+
// Forward declarations of any types needed in the operator signatures.
|
| 14 |
+
// We can't directly include these classes because it will cause circular include dependencies.
|
| 15 |
+
// This file is included by TensorBody.h, which defines the Tensor class.
|
| 16 |
+
#include <ATen/core/ATen_fwd.h>
|
| 17 |
+
|
| 18 |
+
#include <ATen/ops/_addmm_activation_ops.h>
|
| 19 |
+
#include <ATen/ops/_autocast_to_full_precision_ops.h>
|
| 20 |
+
#include <ATen/ops/_autocast_to_reduced_precision_ops.h>
|
| 21 |
+
#include <ATen/ops/_backward_ops.h>
|
| 22 |
+
#include <ATen/ops/_coalesced_ops.h>
|
| 23 |
+
#include <ATen/ops/_conj_ops.h>
|
| 24 |
+
#include <ATen/ops/_conj_physical_ops.h>
|
| 25 |
+
#include <ATen/ops/_dimI_ops.h>
|
| 26 |
+
#include <ATen/ops/_dimV_ops.h>
|
| 27 |
+
#include <ATen/ops/_fw_primal_ops.h>
|
| 28 |
+
#include <ATen/ops/_indices_ops.h>
|
| 29 |
+
#include <ATen/ops/_is_all_true_ops.h>
|
| 30 |
+
#include <ATen/ops/_is_any_true_ops.h>
|
| 31 |
+
#include <ATen/ops/_is_zerotensor_ops.h>
|
| 32 |
+
#include <ATen/ops/_lazy_clone_ops.h>
|
| 33 |
+
#include <ATen/ops/_neg_view_ops.h>
|
| 34 |
+
#include <ATen/ops/_nested_tensor_size_ops.h>
|
| 35 |
+
#include <ATen/ops/_nested_tensor_storage_offsets_ops.h>
|
| 36 |
+
#include <ATen/ops/_nested_tensor_strides_ops.h>
|
| 37 |
+
#include <ATen/ops/_nnz_ops.h>
|
| 38 |
+
#include <ATen/ops/_reshape_alias_ops.h>
|
| 39 |
+
#include <ATen/ops/_sparse_mask_projection_ops.h>
|
| 40 |
+
#include <ATen/ops/_to_dense_ops.h>
|
| 41 |
+
#include <ATen/ops/_to_sparse_bsc_ops.h>
|
| 42 |
+
#include <ATen/ops/_to_sparse_bsr_ops.h>
|
| 43 |
+
#include <ATen/ops/_to_sparse_csc_ops.h>
|
| 44 |
+
#include <ATen/ops/_to_sparse_csr_ops.h>
|
| 45 |
+
#include <ATen/ops/_to_sparse_ops.h>
|
| 46 |
+
#include <ATen/ops/_values_ops.h>
|
| 47 |
+
#include <ATen/ops/_version_ops.h>
|
| 48 |
+
#include <ATen/ops/abs_ops.h>
|
| 49 |
+
#include <ATen/ops/absolute_ops.h>
|
| 50 |
+
#include <ATen/ops/acos_ops.h>
|
| 51 |
+
#include <ATen/ops/acosh_ops.h>
|
| 52 |
+
#include <ATen/ops/add_ops.h>
|
| 53 |
+
#include <ATen/ops/addbmm_ops.h>
|
| 54 |
+
#include <ATen/ops/addcdiv_ops.h>
|
| 55 |
+
#include <ATen/ops/addcmul_ops.h>
|
| 56 |
+
#include <ATen/ops/addmm_ops.h>
|
| 57 |
+
#include <ATen/ops/addmv_ops.h>
|
| 58 |
+
#include <ATen/ops/addr_ops.h>
|
| 59 |
+
#include <ATen/ops/adjoint_ops.h>
|
| 60 |
+
#include <ATen/ops/alias_ops.h>
|
| 61 |
+
#include <ATen/ops/align_as_ops.h>
|
| 62 |
+
#include <ATen/ops/align_to_ops.h>
|
| 63 |
+
#include <ATen/ops/all_ops.h>
|
| 64 |
+
#include <ATen/ops/allclose_ops.h>
|
| 65 |
+
#include <ATen/ops/amax_ops.h>
|
| 66 |
+
#include <ATen/ops/amin_ops.h>
|
| 67 |
+
#include <ATen/ops/aminmax_ops.h>
|
| 68 |
+
#include <ATen/ops/and_ops.h>
|
| 69 |
+
#include <ATen/ops/angle_ops.h>
|
| 70 |
+
#include <ATen/ops/any_ops.h>
|
| 71 |
+
#include <ATen/ops/arccos_ops.h>
|
| 72 |
+
#include <ATen/ops/arccosh_ops.h>
|
| 73 |
+
#include <ATen/ops/arcsin_ops.h>
|
| 74 |
+
#include <ATen/ops/arcsinh_ops.h>
|
| 75 |
+
#include <ATen/ops/arctan2_ops.h>
|
| 76 |
+
#include <ATen/ops/arctan_ops.h>
|
| 77 |
+
#include <ATen/ops/arctanh_ops.h>
|
| 78 |
+
#include <ATen/ops/argmax_ops.h>
|
| 79 |
+
#include <ATen/ops/argmin_ops.h>
|
| 80 |
+
#include <ATen/ops/argsort_ops.h>
|
| 81 |
+
#include <ATen/ops/argwhere_ops.h>
|
| 82 |
+
#include <ATen/ops/as_strided_ops.h>
|
| 83 |
+
#include <ATen/ops/as_strided_scatter_ops.h>
|
| 84 |
+
#include <ATen/ops/asin_ops.h>
|
| 85 |
+
#include <ATen/ops/asinh_ops.h>
|
| 86 |
+
#include <ATen/ops/atan2_ops.h>
|
| 87 |
+
#include <ATen/ops/atan_ops.h>
|
| 88 |
+
#include <ATen/ops/atanh_ops.h>
|
| 89 |
+
#include <ATen/ops/baddbmm_ops.h>
|
| 90 |
+
#include <ATen/ops/bernoulli_ops.h>
|
| 91 |
+
#include <ATen/ops/bincount_ops.h>
|
| 92 |
+
#include <ATen/ops/bitwise_and_ops.h>
|
| 93 |
+
#include <ATen/ops/bitwise_left_shift_ops.h>
|
| 94 |
+
#include <ATen/ops/bitwise_not_ops.h>
|
| 95 |
+
#include <ATen/ops/bitwise_or_ops.h>
|
| 96 |
+
#include <ATen/ops/bitwise_right_shift_ops.h>
|
| 97 |
+
#include <ATen/ops/bitwise_xor_ops.h>
|
| 98 |
+
#include <ATen/ops/bmm_ops.h>
|
| 99 |
+
#include <ATen/ops/broadcast_to_ops.h>
|
| 100 |
+
#include <ATen/ops/cauchy_ops.h>
|
| 101 |
+
#include <ATen/ops/ccol_indices_ops.h>
|
| 102 |
+
#include <ATen/ops/ceil_ops.h>
|
| 103 |
+
#include <ATen/ops/chalf_ops.h>
|
| 104 |
+
#include <ATen/ops/cholesky_inverse_ops.h>
|
| 105 |
+
#include <ATen/ops/cholesky_ops.h>
|
| 106 |
+
#include <ATen/ops/cholesky_solve_ops.h>
|
| 107 |
+
#include <ATen/ops/chunk_ops.h>
|
| 108 |
+
#include <ATen/ops/clamp_max_ops.h>
|
| 109 |
+
#include <ATen/ops/clamp_min_ops.h>
|
| 110 |
+
#include <ATen/ops/clamp_ops.h>
|
| 111 |
+
#include <ATen/ops/clip_ops.h>
|
| 112 |
+
#include <ATen/ops/clone_ops.h>
|
| 113 |
+
#include <ATen/ops/coalesce_ops.h>
|
| 114 |
+
#include <ATen/ops/col_indices_ops.h>
|
| 115 |
+
#include <ATen/ops/conj_ops.h>
|
| 116 |
+
#include <ATen/ops/conj_physical_ops.h>
|
| 117 |
+
#include <ATen/ops/contiguous_ops.h>
|
| 118 |
+
#include <ATen/ops/copy_ops.h>
|
| 119 |
+
#include <ATen/ops/copysign_ops.h>
|
| 120 |
+
#include <ATen/ops/corrcoef_ops.h>
|
| 121 |
+
#include <ATen/ops/cos_ops.h>
|
| 122 |
+
#include <ATen/ops/cosh_ops.h>
|
| 123 |
+
#include <ATen/ops/count_nonzero_ops.h>
|
| 124 |
+
#include <ATen/ops/cov_ops.h>
|
| 125 |
+
#include <ATen/ops/cross_ops.h>
|
| 126 |
+
#include <ATen/ops/crow_indices_ops.h>
|
| 127 |
+
#include <ATen/ops/cummax_ops.h>
|
| 128 |
+
#include <ATen/ops/cummin_ops.h>
|
| 129 |
+
#include <ATen/ops/cumprod_ops.h>
|
| 130 |
+
#include <ATen/ops/cumsum_ops.h>
|
| 131 |
+
#include <ATen/ops/data_ops.h>
|
| 132 |
+
#include <ATen/ops/deg2rad_ops.h>
|
| 133 |
+
#include <ATen/ops/dense_dim_ops.h>
|
| 134 |
+
#include <ATen/ops/dequantize_ops.h>
|
| 135 |
+
#include <ATen/ops/det_ops.h>
|
| 136 |
+
#include <ATen/ops/detach_ops.h>
|
| 137 |
+
#include <ATen/ops/diag_embed_ops.h>
|
| 138 |
+
#include <ATen/ops/diag_ops.h>
|
| 139 |
+
#include <ATen/ops/diagflat_ops.h>
|
| 140 |
+
#include <ATen/ops/diagonal_ops.h>
|
| 141 |
+
#include <ATen/ops/diagonal_scatter_ops.h>
|
| 142 |
+
#include <ATen/ops/diff_ops.h>
|
| 143 |
+
#include <ATen/ops/digamma_ops.h>
|
| 144 |
+
#include <ATen/ops/dist_ops.h>
|
| 145 |
+
#include <ATen/ops/div_ops.h>
|
| 146 |
+
#include <ATen/ops/divide_ops.h>
|
| 147 |
+
#include <ATen/ops/dot_ops.h>
|
| 148 |
+
#include <ATen/ops/dsplit_ops.h>
|
| 149 |
+
#include <ATen/ops/eq_ops.h>
|
| 150 |
+
#include <ATen/ops/equal_ops.h>
|
| 151 |
+
#include <ATen/ops/erf_ops.h>
|
| 152 |
+
#include <ATen/ops/erfc_ops.h>
|
| 153 |
+
#include <ATen/ops/erfinv_ops.h>
|
| 154 |
+
#include <ATen/ops/exp2_ops.h>
|
| 155 |
+
#include <ATen/ops/exp_ops.h>
|
| 156 |
+
#include <ATen/ops/expand_as_ops.h>
|
| 157 |
+
#include <ATen/ops/expand_ops.h>
|
| 158 |
+
#include <ATen/ops/expm1_ops.h>
|
| 159 |
+
#include <ATen/ops/exponential_ops.h>
|
| 160 |
+
#include <ATen/ops/fill_diagonal_ops.h>
|
| 161 |
+
#include <ATen/ops/fill_ops.h>
|
| 162 |
+
#include <ATen/ops/fix_ops.h>
|
| 163 |
+
#include <ATen/ops/flatten_ops.h>
|
| 164 |
+
#include <ATen/ops/flip_ops.h>
|
| 165 |
+
#include <ATen/ops/fliplr_ops.h>
|
| 166 |
+
#include <ATen/ops/flipud_ops.h>
|
| 167 |
+
#include <ATen/ops/float_power_ops.h>
|
| 168 |
+
#include <ATen/ops/floor_divide_ops.h>
|
| 169 |
+
#include <ATen/ops/floor_ops.h>
|
| 170 |
+
#include <ATen/ops/fmax_ops.h>
|
| 171 |
+
#include <ATen/ops/fmin_ops.h>
|
| 172 |
+
#include <ATen/ops/fmod_ops.h>
|
| 173 |
+
#include <ATen/ops/frac_ops.h>
|
| 174 |
+
#include <ATen/ops/frexp_ops.h>
|
| 175 |
+
#include <ATen/ops/gather_ops.h>
|
| 176 |
+
#include <ATen/ops/gcd_ops.h>
|
| 177 |
+
#include <ATen/ops/ge_ops.h>
|
| 178 |
+
#include <ATen/ops/geometric_ops.h>
|
| 179 |
+
#include <ATen/ops/geqrf_ops.h>
|
| 180 |
+
#include <ATen/ops/ger_ops.h>
|
| 181 |
+
#include <ATen/ops/greater_equal_ops.h>
|
| 182 |
+
#include <ATen/ops/greater_ops.h>
|
| 183 |
+
#include <ATen/ops/gt_ops.h>
|
| 184 |
+
#include <ATen/ops/hardshrink_backward_ops.h>
|
| 185 |
+
#include <ATen/ops/hardshrink_ops.h>
|
| 186 |
+
#include <ATen/ops/heaviside_ops.h>
|
| 187 |
+
#include <ATen/ops/histc_ops.h>
|
| 188 |
+
#include <ATen/ops/histogram_ops.h>
|
| 189 |
+
#include <ATen/ops/hsplit_ops.h>
|
| 190 |
+
#include <ATen/ops/hypot_ops.h>
|
| 191 |
+
#include <ATen/ops/i0_ops.h>
|
| 192 |
+
#include <ATen/ops/igamma_ops.h>
|
| 193 |
+
#include <ATen/ops/igammac_ops.h>
|
| 194 |
+
#include <ATen/ops/index_add_ops.h>
|
| 195 |
+
#include <ATen/ops/index_copy_ops.h>
|
| 196 |
+
#include <ATen/ops/index_fill_ops.h>
|
| 197 |
+
#include <ATen/ops/index_ops.h>
|
| 198 |
+
#include <ATen/ops/index_put_ops.h>
|
| 199 |
+
#include <ATen/ops/index_reduce_ops.h>
|
| 200 |
+
#include <ATen/ops/index_select_ops.h>
|
| 201 |
+
#include <ATen/ops/indices_ops.h>
|
| 202 |
+
#include <ATen/ops/inner_ops.h>
|
| 203 |
+
#include <ATen/ops/int_repr_ops.h>
|
| 204 |
+
#include <ATen/ops/inverse_ops.h>
|
| 205 |
+
#include <ATen/ops/is_coalesced_ops.h>
|
| 206 |
+
#include <ATen/ops/is_complex_ops.h>
|
| 207 |
+
#include <ATen/ops/is_conj_ops.h>
|
| 208 |
+
#include <ATen/ops/is_distributed_ops.h>
|
| 209 |
+
#include <ATen/ops/is_floating_point_ops.h>
|
| 210 |
+
#include <ATen/ops/is_inference_ops.h>
|
| 211 |
+
#include <ATen/ops/is_leaf_ops.h>
|
| 212 |
+
#include <ATen/ops/is_neg_ops.h>
|
| 213 |
+
#include <ATen/ops/is_nonzero_ops.h>
|
| 214 |
+
#include <ATen/ops/is_pinned_ops.h>
|
| 215 |
+
#include <ATen/ops/is_same_size_ops.h>
|
| 216 |
+
#include <ATen/ops/is_set_to_ops.h>
|
| 217 |
+
#include <ATen/ops/is_signed_ops.h>
|
| 218 |
+
#include <ATen/ops/isclose_ops.h>
|
| 219 |
+
#include <ATen/ops/isfinite_ops.h>
|
| 220 |
+
#include <ATen/ops/isinf_ops.h>
|
| 221 |
+
#include <ATen/ops/isnan_ops.h>
|
| 222 |
+
#include <ATen/ops/isneginf_ops.h>
|
| 223 |
+
#include <ATen/ops/isposinf_ops.h>
|
| 224 |
+
#include <ATen/ops/isreal_ops.h>
|
| 225 |
+
#include <ATen/ops/istft_ops.h>
|
| 226 |
+
#include <ATen/ops/item_ops.h>
|
| 227 |
+
#include <ATen/ops/kron_ops.h>
|
| 228 |
+
#include <ATen/ops/kthvalue_ops.h>
|
| 229 |
+
#include <ATen/ops/lcm_ops.h>
|
| 230 |
+
#include <ATen/ops/ldexp_ops.h>
|
| 231 |
+
#include <ATen/ops/le_ops.h>
|
| 232 |
+
#include <ATen/ops/lerp_ops.h>
|
| 233 |
+
#include <ATen/ops/less_equal_ops.h>
|
| 234 |
+
#include <ATen/ops/less_ops.h>
|
| 235 |
+
#include <ATen/ops/lgamma_ops.h>
|
| 236 |
+
#include <ATen/ops/log10_ops.h>
|
| 237 |
+
#include <ATen/ops/log1p_ops.h>
|
| 238 |
+
#include <ATen/ops/log2_ops.h>
|
| 239 |
+
#include <ATen/ops/log_normal_ops.h>
|
| 240 |
+
#include <ATen/ops/log_ops.h>
|
| 241 |
+
#include <ATen/ops/log_softmax_ops.h>
|
| 242 |
+
#include <ATen/ops/logaddexp2_ops.h>
|
| 243 |
+
#include <ATen/ops/logaddexp_ops.h>
|
| 244 |
+
#include <ATen/ops/logcumsumexp_ops.h>
|
| 245 |
+
#include <ATen/ops/logdet_ops.h>
|
| 246 |
+
#include <ATen/ops/logical_and_ops.h>
|
| 247 |
+
#include <ATen/ops/logical_not_ops.h>
|
| 248 |
+
#include <ATen/ops/logical_or_ops.h>
|
| 249 |
+
#include <ATen/ops/logical_xor_ops.h>
|
| 250 |
+
#include <ATen/ops/logit_ops.h>
|
| 251 |
+
#include <ATen/ops/logsumexp_ops.h>
|
| 252 |
+
#include <ATen/ops/lshift_ops.h>
|
| 253 |
+
#include <ATen/ops/lt_ops.h>
|
| 254 |
+
#include <ATen/ops/lu_solve_ops.h>
|
| 255 |
+
#include <ATen/ops/mH_ops.h>
|
| 256 |
+
#include <ATen/ops/mT_ops.h>
|
| 257 |
+
#include <ATen/ops/masked_fill_ops.h>
|
| 258 |
+
#include <ATen/ops/masked_scatter_ops.h>
|
| 259 |
+
#include <ATen/ops/masked_select_ops.h>
|
| 260 |
+
#include <ATen/ops/matmul_ops.h>
|
| 261 |
+
#include <ATen/ops/matrix_H_ops.h>
|
| 262 |
+
#include <ATen/ops/matrix_exp_ops.h>
|
| 263 |
+
#include <ATen/ops/matrix_power_ops.h>
|
| 264 |
+
#include <ATen/ops/max_ops.h>
|
| 265 |
+
#include <ATen/ops/maximum_ops.h>
|
| 266 |
+
#include <ATen/ops/mean_ops.h>
|
| 267 |
+
#include <ATen/ops/median_ops.h>
|
| 268 |
+
#include <ATen/ops/min_ops.h>
|
| 269 |
+
#include <ATen/ops/minimum_ops.h>
|
| 270 |
+
#include <ATen/ops/mm_ops.h>
|
| 271 |
+
#include <ATen/ops/mode_ops.h>
|
| 272 |
+
#include <ATen/ops/moveaxis_ops.h>
|
| 273 |
+
#include <ATen/ops/movedim_ops.h>
|
| 274 |
+
#include <ATen/ops/msort_ops.h>
|
| 275 |
+
#include <ATen/ops/mul_ops.h>
|
| 276 |
+
#include <ATen/ops/multinomial_ops.h>
|
| 277 |
+
#include <ATen/ops/multiply_ops.h>
|
| 278 |
+
#include <ATen/ops/mv_ops.h>
|
| 279 |
+
#include <ATen/ops/mvlgamma_ops.h>
|
| 280 |
+
#include <ATen/ops/nan_to_num_ops.h>
|
| 281 |
+
#include <ATen/ops/nanmean_ops.h>
|
| 282 |
+
#include <ATen/ops/nanmedian_ops.h>
|
| 283 |
+
#include <ATen/ops/nanquantile_ops.h>
|
| 284 |
+
#include <ATen/ops/nansum_ops.h>
|
| 285 |
+
#include <ATen/ops/narrow_copy_ops.h>
|
| 286 |
+
#include <ATen/ops/narrow_ops.h>
|
| 287 |
+
#include <ATen/ops/ne_ops.h>
|
| 288 |
+
#include <ATen/ops/neg_ops.h>
|
| 289 |
+
#include <ATen/ops/negative_ops.h>
|
| 290 |
+
#include <ATen/ops/new_empty_ops.h>
|
| 291 |
+
#include <ATen/ops/new_empty_strided_ops.h>
|
| 292 |
+
#include <ATen/ops/new_full_ops.h>
|
| 293 |
+
#include <ATen/ops/new_ones_ops.h>
|
| 294 |
+
#include <ATen/ops/new_zeros_ops.h>
|
| 295 |
+
#include <ATen/ops/nextafter_ops.h>
|
| 296 |
+
#include <ATen/ops/nonzero_numpy_ops.h>
|
| 297 |
+
#include <ATen/ops/nonzero_ops.h>
|
| 298 |
+
#include <ATen/ops/nonzero_static_ops.h>
|
| 299 |
+
#include <ATen/ops/norm_ops.h>
|
| 300 |
+
#include <ATen/ops/normal_ops.h>
|
| 301 |
+
#include <ATen/ops/not_equal_ops.h>
|
| 302 |
+
#include <ATen/ops/numpy_T_ops.h>
|
| 303 |
+
#include <ATen/ops/or_ops.h>
|
| 304 |
+
#include <ATen/ops/orgqr_ops.h>
|
| 305 |
+
#include <ATen/ops/ormqr_ops.h>
|
| 306 |
+
#include <ATen/ops/outer_ops.h>
|
| 307 |
+
#include <ATen/ops/output_nr_ops.h>
|
| 308 |
+
#include <ATen/ops/permute_ops.h>
|
| 309 |
+
#include <ATen/ops/pin_memory_ops.h>
|
| 310 |
+
#include <ATen/ops/pinverse_ops.h>
|
| 311 |
+
#include <ATen/ops/polygamma_ops.h>
|
| 312 |
+
#include <ATen/ops/positive_ops.h>
|
| 313 |
+
#include <ATen/ops/pow_ops.h>
|
| 314 |
+
#include <ATen/ops/prelu_ops.h>
|
| 315 |
+
#include <ATen/ops/prod_ops.h>
|
| 316 |
+
#include <ATen/ops/put_ops.h>
|
| 317 |
+
#include <ATen/ops/q_per_channel_axis_ops.h>
|
| 318 |
+
#include <ATen/ops/q_per_channel_scales_ops.h>
|
| 319 |
+
#include <ATen/ops/q_per_channel_zero_points_ops.h>
|
| 320 |
+
#include <ATen/ops/q_scale_ops.h>
|
| 321 |
+
#include <ATen/ops/q_zero_point_ops.h>
|
| 322 |
+
#include <ATen/ops/qr_ops.h>
|
| 323 |
+
#include <ATen/ops/qscheme_ops.h>
|
| 324 |
+
#include <ATen/ops/quantile_ops.h>
|
| 325 |
+
#include <ATen/ops/rad2deg_ops.h>
|
| 326 |
+
#include <ATen/ops/random_ops.h>
|
| 327 |
+
#include <ATen/ops/ravel_ops.h>
|
| 328 |
+
#include <ATen/ops/reciprocal_ops.h>
|
| 329 |
+
#include <ATen/ops/record_stream_ops.h>
|
| 330 |
+
#include <ATen/ops/refine_names_ops.h>
|
| 331 |
+
#include <ATen/ops/relu_ops.h>
|
| 332 |
+
#include <ATen/ops/remainder_ops.h>
|
| 333 |
+
#include <ATen/ops/rename_ops.h>
|
| 334 |
+
#include <ATen/ops/renorm_ops.h>
|
| 335 |
+
#include <ATen/ops/repeat_interleave_ops.h>
|
| 336 |
+
#include <ATen/ops/repeat_ops.h>
|
| 337 |
+
#include <ATen/ops/requires_grad_ops.h>
|
| 338 |
+
#include <ATen/ops/reshape_as_ops.h>
|
| 339 |
+
#include <ATen/ops/reshape_ops.h>
|
| 340 |
+
#include <ATen/ops/resize_as_ops.h>
|
| 341 |
+
#include <ATen/ops/resize_as_sparse_ops.h>
|
| 342 |
+
#include <ATen/ops/resize_ops.h>
|
| 343 |
+
#include <ATen/ops/resolve_conj_ops.h>
|
| 344 |
+
#include <ATen/ops/resolve_neg_ops.h>
|
| 345 |
+
#include <ATen/ops/retain_grad_ops.h>
|
| 346 |
+
#include <ATen/ops/retains_grad_ops.h>
|
| 347 |
+
#include <ATen/ops/roll_ops.h>
|
| 348 |
+
#include <ATen/ops/rot90_ops.h>
|
| 349 |
+
#include <ATen/ops/round_ops.h>
|
| 350 |
+
#include <ATen/ops/row_indices_ops.h>
|
| 351 |
+
#include <ATen/ops/rshift_ops.h>
|
| 352 |
+
#include <ATen/ops/rsqrt_ops.h>
|
| 353 |
+
#include <ATen/ops/scatter_add_ops.h>
|
| 354 |
+
#include <ATen/ops/scatter_ops.h>
|
| 355 |
+
#include <ATen/ops/scatter_reduce_ops.h>
|
| 356 |
+
#include <ATen/ops/select_ops.h>
|
| 357 |
+
#include <ATen/ops/select_scatter_ops.h>
|
| 358 |
+
#include <ATen/ops/set_data_ops.h>
|
| 359 |
+
#include <ATen/ops/set_ops.h>
|
| 360 |
+
#include <ATen/ops/sgn_ops.h>
|
| 361 |
+
#include <ATen/ops/sigmoid_ops.h>
|
| 362 |
+
#include <ATen/ops/sign_ops.h>
|
| 363 |
+
#include <ATen/ops/signbit_ops.h>
|
| 364 |
+
#include <ATen/ops/sin_ops.h>
|
| 365 |
+
#include <ATen/ops/sinc_ops.h>
|
| 366 |
+
#include <ATen/ops/sinh_ops.h>
|
| 367 |
+
#include <ATen/ops/size_ops.h>
|
| 368 |
+
#include <ATen/ops/slice_inverse_ops.h>
|
| 369 |
+
#include <ATen/ops/slice_ops.h>
|
| 370 |
+
#include <ATen/ops/slice_scatter_ops.h>
|
| 371 |
+
#include <ATen/ops/slogdet_ops.h>
|
| 372 |
+
#include <ATen/ops/smm_ops.h>
|
| 373 |
+
#include <ATen/ops/softmax_ops.h>
|
| 374 |
+
#include <ATen/ops/sort_ops.h>
|
| 375 |
+
#include <ATen/ops/sparse_dim_ops.h>
|
| 376 |
+
#include <ATen/ops/sparse_mask_ops.h>
|
| 377 |
+
#include <ATen/ops/sparse_resize_and_clear_ops.h>
|
| 378 |
+
#include <ATen/ops/sparse_resize_ops.h>
|
| 379 |
+
#include <ATen/ops/split_ops.h>
|
| 380 |
+
#include <ATen/ops/split_with_sizes_ops.h>
|
| 381 |
+
#include <ATen/ops/sqrt_ops.h>
|
| 382 |
+
#include <ATen/ops/square_ops.h>
|
| 383 |
+
#include <ATen/ops/squeeze_ops.h>
|
| 384 |
+
#include <ATen/ops/sspaddmm_ops.h>
|
| 385 |
+
#include <ATen/ops/std_ops.h>
|
| 386 |
+
#include <ATen/ops/stft_ops.h>
|
| 387 |
+
#include <ATen/ops/stride_ops.h>
|
| 388 |
+
#include <ATen/ops/sub_ops.h>
|
| 389 |
+
#include <ATen/ops/subtract_ops.h>
|
| 390 |
+
#include <ATen/ops/sum_ops.h>
|
| 391 |
+
#include <ATen/ops/sum_to_size_ops.h>
|
| 392 |
+
#include <ATen/ops/svd_ops.h>
|
| 393 |
+
#include <ATen/ops/swapaxes_ops.h>
|
| 394 |
+
#include <ATen/ops/swapdims_ops.h>
|
| 395 |
+
#include <ATen/ops/t_ops.h>
|
| 396 |
+
#include <ATen/ops/take_along_dim_ops.h>
|
| 397 |
+
#include <ATen/ops/take_ops.h>
|
| 398 |
+
#include <ATen/ops/tan_ops.h>
|
| 399 |
+
#include <ATen/ops/tanh_ops.h>
|
| 400 |
+
#include <ATen/ops/tensor_split_ops.h>
|
| 401 |
+
#include <ATen/ops/tile_ops.h>
|
| 402 |
+
#include <ATen/ops/to_dense_ops.h>
|
| 403 |
+
#include <ATen/ops/to_mkldnn_ops.h>
|
| 404 |
+
#include <ATen/ops/to_ops.h>
|
| 405 |
+
#include <ATen/ops/to_padded_tensor_ops.h>
|
| 406 |
+
#include <ATen/ops/to_sparse_bsc_ops.h>
|
| 407 |
+
#include <ATen/ops/to_sparse_bsr_ops.h>
|
| 408 |
+
#include <ATen/ops/to_sparse_csc_ops.h>
|
| 409 |
+
#include <ATen/ops/to_sparse_csr_ops.h>
|
| 410 |
+
#include <ATen/ops/to_sparse_ops.h>
|
| 411 |
+
#include <ATen/ops/topk_ops.h>
|
| 412 |
+
#include <ATen/ops/trace_ops.h>
|
| 413 |
+
#include <ATen/ops/transpose_ops.h>
|
| 414 |
+
#include <ATen/ops/triangular_solve_ops.h>
|
| 415 |
+
#include <ATen/ops/tril_ops.h>
|
| 416 |
+
#include <ATen/ops/triu_ops.h>
|
| 417 |
+
#include <ATen/ops/true_divide_ops.h>
|
| 418 |
+
#include <ATen/ops/trunc_ops.h>
|
| 419 |
+
#include <ATen/ops/type_as_ops.h>
|
| 420 |
+
#include <ATen/ops/unbind_ops.h>
|
| 421 |
+
#include <ATen/ops/unflatten_ops.h>
|
| 422 |
+
#include <ATen/ops/unfold_ops.h>
|
| 423 |
+
#include <ATen/ops/uniform_ops.h>
|
| 424 |
+
#include <ATen/ops/unsafe_chunk_ops.h>
|
| 425 |
+
#include <ATen/ops/unsafe_split_ops.h>
|
| 426 |
+
#include <ATen/ops/unsafe_split_with_sizes_ops.h>
|
| 427 |
+
#include <ATen/ops/unsqueeze_ops.h>
|
| 428 |
+
#include <ATen/ops/values_ops.h>
|
| 429 |
+
#include <ATen/ops/var_ops.h>
|
| 430 |
+
#include <ATen/ops/vdot_ops.h>
|
| 431 |
+
#include <ATen/ops/view_as_ops.h>
|
| 432 |
+
#include <ATen/ops/view_ops.h>
|
| 433 |
+
#include <ATen/ops/vsplit_ops.h>
|
| 434 |
+
#include <ATen/ops/where_ops.h>
|
| 435 |
+
#include <ATen/ops/xlogy_ops.h>
|
| 436 |
+
#include <ATen/ops/xor_ops.h>
|
| 437 |
+
#include <ATen/ops/zero_ops.h>
|
| 438 |
+
|
| 439 |
+
namespace at {
|
| 440 |
+
namespace _ops {
|
| 441 |
+
|
| 442 |
+
} // namespace _ops
|
| 443 |
+
} // namespace at
|