Spaces:
Sleeping
Sleeping
Upload 5061 files
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- .gitattributes +15 -0
- torch/_C.cp310-win_amd64.pyd +0 -0
- torch/_C/_VariableFunctions.pyi +0 -0
- torch/_C/__init__.pyi +0 -0
- torch/_C/_autograd.pyi +123 -0
- torch/_C/_cpu.pyi +5 -0
- torch/_C/_cudnn.pyi +17 -0
- torch/_C/_distributed_autograd.pyi +26 -0
- torch/_C/_distributed_c10d.pyi +478 -0
- torch/_C/_distributed_rpc.pyi +188 -0
- torch/_C/_distributed_rpc_testing.pyi +35 -0
- torch/_C/_functions.pyi +11 -0
- torch/_C/_functorch.pyi +71 -0
- torch/_C/_itt.pyi +5 -0
- torch/_C/_lazy.pyi +28 -0
- torch/_C/_lazy_ts_backend.pyi +11 -0
- torch/_C/_monitor.pyi +44 -0
- torch/_C/_nn.pyi +86 -0
- torch/_C/_nvtx.pyi +6 -0
- torch/_C/_onnx.pyi +38 -0
- torch/_C/_profiler.pyi +238 -0
- torch/_C/_verbose.pyi +3 -0
- torch/_VF.py +30 -0
- torch/_VF.pyi +0 -0
- torch/__config__.py +22 -0
- torch/__future__.py +21 -0
- torch/_appdirs.py +666 -0
- torch/_awaits/__init__.py +54 -0
- torch/_awaits/__pycache__/__init__.cpython-310.pyc +0 -0
- torch/_classes.py +55 -0
- torch/_compile.py +30 -0
- torch/_custom_op/__init__.py +0 -0
- torch/_custom_op/__pycache__/__init__.cpython-310.pyc +0 -0
- torch/_custom_op/__pycache__/autograd.cpython-310.pyc +0 -0
- torch/_custom_op/__pycache__/functional.cpython-310.pyc +0 -0
- torch/_custom_op/__pycache__/impl.cpython-310.pyc +0 -0
- torch/_custom_op/autograd.py +274 -0
- torch/_custom_op/functional.py +187 -0
- torch/_custom_op/impl.py +976 -0
- torch/_custom_ops.py +322 -0
- torch/_decomp/__init__.py +444 -0
- torch/_decomp/__pycache__/__init__.cpython-310.pyc +0 -0
- torch/_decomp/__pycache__/decompositions.cpython-310.pyc +0 -0
- torch/_decomp/__pycache__/decompositions_for_jvp.cpython-310.pyc +0 -0
- torch/_decomp/__pycache__/decompositions_for_rng.cpython-310.pyc +0 -0
- torch/_decomp/decompositions.py +0 -0
- torch/_decomp/decompositions_for_jvp.py +302 -0
- torch/_decomp/decompositions_for_rng.py +263 -0
- torch/_deploy.py +105 -0
- torch/_dispatch/__init__.py +0 -0
.gitattributes
CHANGED
@@ -33,3 +33,18 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
36 |
+
torch/bin/fbgemm.dll filter=lfs diff=lfs merge=lfs -text
|
37 |
+
torch/bin/protoc.exe filter=lfs diff=lfs merge=lfs -text
|
38 |
+
torch/lib/dnnl.lib filter=lfs diff=lfs merge=lfs -text
|
39 |
+
torch/lib/fbgemm.dll filter=lfs diff=lfs merge=lfs -text
|
40 |
+
torch/lib/fbgemm.lib filter=lfs diff=lfs merge=lfs -text
|
41 |
+
torch/lib/fmt.lib filter=lfs diff=lfs merge=lfs -text
|
42 |
+
torch/lib/kineto.lib filter=lfs diff=lfs merge=lfs -text
|
43 |
+
torch/lib/libiomp5md.dll filter=lfs diff=lfs merge=lfs -text
|
44 |
+
torch/lib/libprotobuf-lite.lib filter=lfs diff=lfs merge=lfs -text
|
45 |
+
torch/lib/libprotobuf.lib filter=lfs diff=lfs merge=lfs -text
|
46 |
+
torch/lib/libprotoc.lib filter=lfs diff=lfs merge=lfs -text
|
47 |
+
torch/lib/torch_cpu.dll filter=lfs diff=lfs merge=lfs -text
|
48 |
+
torch/lib/torch_cpu.lib filter=lfs diff=lfs merge=lfs -text
|
49 |
+
torch/lib/torch_python.dll filter=lfs diff=lfs merge=lfs -text
|
50 |
+
torch/lib/XNNPACK.lib filter=lfs diff=lfs merge=lfs -text
|
torch/_C.cp310-win_amd64.pyd
ADDED
Binary file (10.2 kB). View file
|
|
torch/_C/_VariableFunctions.pyi
ADDED
The diff for this file is too large to render.
See raw diff
|
|
torch/_C/__init__.pyi
ADDED
The diff for this file is too large to render.
See raw diff
|
|
torch/_C/_autograd.pyi
ADDED
@@ -0,0 +1,123 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from enum import Enum
|
2 |
+
from typing import Any, Callable, List, Optional, Set
|
3 |
+
|
4 |
+
import torch
|
5 |
+
|
6 |
+
from ._profiler import (
|
7 |
+
_ProfilerEvent,
|
8 |
+
ActiveProfilerType,
|
9 |
+
ProfilerActivity,
|
10 |
+
ProfilerConfig,
|
11 |
+
)
|
12 |
+
|
13 |
+
# Defined in tools/autograd/init.cpp
|
14 |
+
|
15 |
+
class DeviceType(Enum):
|
16 |
+
CPU = ...
|
17 |
+
CUDA = ...
|
18 |
+
MKLDNN = ...
|
19 |
+
OPENGL = ...
|
20 |
+
OPENCL = ...
|
21 |
+
IDEEP = ...
|
22 |
+
HIP = ...
|
23 |
+
FPGA = ...
|
24 |
+
ORT = ...
|
25 |
+
XLA = ...
|
26 |
+
MPS = ...
|
27 |
+
HPU = ...
|
28 |
+
Meta = ...
|
29 |
+
Vulkan = ...
|
30 |
+
Metal = ...
|
31 |
+
PrivateUse1 = ...
|
32 |
+
|
33 |
+
class ProfilerEvent:
|
34 |
+
def cpu_elapsed_us(self, other: ProfilerEvent) -> float: ...
|
35 |
+
def cpu_memory_usage(self) -> int: ...
|
36 |
+
def cuda_elapsed_us(self, other: ProfilerEvent) -> float: ...
|
37 |
+
def privateuse1_elapsed_us(self, other: ProfilerEvent) -> float: ...
|
38 |
+
def cuda_memory_usage(self) -> int: ...
|
39 |
+
def device(self) -> int: ...
|
40 |
+
def handle(self) -> int: ...
|
41 |
+
def has_cuda(self) -> bool: ...
|
42 |
+
def is_remote(self) -> bool: ...
|
43 |
+
def kind(self) -> int: ...
|
44 |
+
def name(self) -> str: ...
|
45 |
+
def node_id(self) -> int: ...
|
46 |
+
def sequence_nr(self) -> int: ...
|
47 |
+
def shapes(self) -> List[List[int]]: ...
|
48 |
+
def thread_id(self) -> int: ...
|
49 |
+
def flops(self) -> float: ...
|
50 |
+
def is_async(self) -> bool: ...
|
51 |
+
|
52 |
+
class _KinetoEvent:
|
53 |
+
def name(self) -> str: ...
|
54 |
+
def device_index(self) -> int: ...
|
55 |
+
def start_us(self) -> int: ...
|
56 |
+
def duration_us(self) -> int: ...
|
57 |
+
def is_async(self) -> bool: ...
|
58 |
+
def linked_correlation_id(self) -> int: ...
|
59 |
+
def shapes(self) -> List[List[int]]: ...
|
60 |
+
def dtypes(self) -> List[str]: ...
|
61 |
+
def concrete_inputs(self) -> List[Any]: ...
|
62 |
+
def device_type(self) -> DeviceType: ...
|
63 |
+
def start_thread_id(self) -> int: ...
|
64 |
+
def end_thread_id(self) -> int: ...
|
65 |
+
def correlation_id(self) -> int: ...
|
66 |
+
def fwd_thread_id(self) -> int: ...
|
67 |
+
def stack(self) -> List[str]: ...
|
68 |
+
def scope(self) -> int: ...
|
69 |
+
def sequence_nr(self) -> int: ...
|
70 |
+
def flops(self) -> int: ...
|
71 |
+
def cuda_elapsed_us(self) -> int: ...
|
72 |
+
def privateuse1_elapsed_us(self) -> int: ...
|
73 |
+
|
74 |
+
class _ProfilerResult:
|
75 |
+
def events(self) -> List[_KinetoEvent]: ...
|
76 |
+
def legacy_events(self) -> List[List[ProfilerEvent]]: ...
|
77 |
+
def save(self, path: str) -> None: ...
|
78 |
+
def experimental_event_tree(self) -> List[_ProfilerEvent]: ...
|
79 |
+
def trace_start_us(self) -> int: ...
|
80 |
+
|
81 |
+
class SavedTensor: ...
|
82 |
+
|
83 |
+
def _enable_profiler(
|
84 |
+
config: ProfilerConfig,
|
85 |
+
activities: Set[ProfilerActivity],
|
86 |
+
) -> None: ...
|
87 |
+
def _prepare_profiler(
|
88 |
+
config: ProfilerConfig,
|
89 |
+
activities: Set[ProfilerActivity],
|
90 |
+
) -> None: ...
|
91 |
+
def _disable_profiler() -> _ProfilerResult: ...
|
92 |
+
def _profiler_enabled() -> bool: ...
|
93 |
+
def _add_metadata_json(key: str, value: str) -> None: ...
|
94 |
+
def _kineto_step() -> None: ...
|
95 |
+
def _get_sequence_nr() -> int: ...
|
96 |
+
def kineto_available() -> bool: ...
|
97 |
+
def _record_function_with_args_enter(name: str, *args) -> torch.Tensor: ...
|
98 |
+
def _record_function_with_args_exit(handle: torch.Tensor) -> None: ...
|
99 |
+
def _supported_activities() -> Set[ProfilerActivity]: ...
|
100 |
+
def _enable_record_function(enable: bool) -> None: ...
|
101 |
+
def _set_empty_test_observer(is_global: bool, sampling_prob: float) -> None: ...
|
102 |
+
def _push_saved_tensors_default_hooks(
|
103 |
+
pack_hook: Callable[[torch.Tensor], Any],
|
104 |
+
unpack_hook: Callable[[Any], torch.Tensor],
|
105 |
+
) -> None: ...
|
106 |
+
def _pop_saved_tensors_default_hooks() -> None: ...
|
107 |
+
def _unsafe_set_version_counter(t: torch.Tensor, prev_version: int) -> None: ...
|
108 |
+
def _enable_profiler_legacy(config: ProfilerConfig) -> None: ...
|
109 |
+
def _disable_profiler_legacy() -> List[List[ProfilerEvent]]: ...
|
110 |
+
def _profiler_type() -> ActiveProfilerType: ...
|
111 |
+
def _saved_tensors_hooks_enable() -> None: ...
|
112 |
+
def _saved_tensors_hooks_disable(message: str) -> None: ...
|
113 |
+
def _saved_tensors_hooks_get_disabled_error_message() -> Optional[str]: ...
|
114 |
+
|
115 |
+
class CreationMeta(Enum):
|
116 |
+
DEFAULT = ...
|
117 |
+
IN_CUSTOM_FUNCTION = ...
|
118 |
+
MULTI_OUTPUT_NODE = ...
|
119 |
+
NO_GRAD_MODE = ...
|
120 |
+
INFERENCE_MODE = ...
|
121 |
+
|
122 |
+
def _set_creation_meta(t: torch.Tensor, creation_meta: CreationMeta) -> None: ...
|
123 |
+
def _get_creation_meta(t: torch.Tensor) -> CreationMeta: ...
|
torch/_C/_cpu.pyi
ADDED
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from torch.types import _bool
|
2 |
+
|
3 |
+
# Defined in torch/csrc/cpu/Module.cpp
|
4 |
+
|
5 |
+
def _is_cpu_support_vnni() -> _bool: ...
|
torch/_C/_cudnn.pyi
ADDED
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from enum import Enum
|
2 |
+
|
3 |
+
from torch.types import _bool, Tuple
|
4 |
+
|
5 |
+
# Defined in torch/csrc/cuda/shared/cudnn.cpp
|
6 |
+
is_cuda: _bool
|
7 |
+
|
8 |
+
def getRuntimeVersion() -> Tuple[int, int, int]: ...
|
9 |
+
def getCompileVersion() -> Tuple[int, int, int]: ...
|
10 |
+
def getVersionInt() -> int: ...
|
11 |
+
|
12 |
+
class RNNMode(int, Enum):
|
13 |
+
value: int
|
14 |
+
rnn_relu = ...
|
15 |
+
rnn_tanh = ...
|
16 |
+
lstm = ...
|
17 |
+
gru = ...
|
torch/_C/_distributed_autograd.pyi
ADDED
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Any, Dict, List, Set
|
2 |
+
|
3 |
+
import torch
|
4 |
+
|
5 |
+
# This module is defined in torch/csrc/distributed/autograd/init.cpp
|
6 |
+
|
7 |
+
class DistAutogradContext:
|
8 |
+
def _context_id(self) -> int: ...
|
9 |
+
def _recv_functions(self) -> Dict[int, Any]: ...
|
10 |
+
def _send_functions(self) -> Dict[int, Any]: ...
|
11 |
+
def _known_worker_ids(self) -> Set[int]: ...
|
12 |
+
|
13 |
+
def _new_context() -> DistAutogradContext: ...
|
14 |
+
def _release_context(context_id: int) -> None: ...
|
15 |
+
def _get_max_id() -> int: ...
|
16 |
+
def _is_valid_context(worker_id: int) -> bool: ...
|
17 |
+
def _retrieve_context(context_id: int) -> DistAutogradContext: ...
|
18 |
+
def _current_context() -> DistAutogradContext: ...
|
19 |
+
def _init(worker_id: int) -> None: ...
|
20 |
+
def _get_debug_info() -> Dict[str, str]: ...
|
21 |
+
def backward(
|
22 |
+
context_id: int,
|
23 |
+
roots: List[torch.Tensor],
|
24 |
+
retain_graph=False,
|
25 |
+
) -> None: ...
|
26 |
+
def get_gradients(context_id: int) -> Dict[torch.Tensor, torch.Tensor]: ...
|
torch/_C/_distributed_c10d.pyi
ADDED
@@ -0,0 +1,478 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# mypy: disable-error-code="type-arg"
|
2 |
+
from datetime import timedelta
|
3 |
+
from enum import Enum
|
4 |
+
from typing import Any, Dict, List, Optional, overload, Tuple, Union
|
5 |
+
|
6 |
+
from torch import Tensor
|
7 |
+
from torch._C import ScriptObject
|
8 |
+
from torch.futures import Future
|
9 |
+
|
10 |
+
# This module is defined in torch/csrc/distributed/c10d/init.cpp
|
11 |
+
|
12 |
+
_DEFAULT_FIRST_BUCKET_BYTES: int
|
13 |
+
_DEFAULT_NO_TIMEOUT: timedelta
|
14 |
+
_DEFAULT_PG_TIMEOUT: timedelta
|
15 |
+
_DEFAULT_PG_NCCL_TIMEOUT: timedelta
|
16 |
+
|
17 |
+
class BuiltinCommHookType(Enum):
|
18 |
+
ALLREDUCE = ...
|
19 |
+
FP16_COMPRESS = ...
|
20 |
+
|
21 |
+
def _register_comm_hook(reducer: Reducer, state: Any, comm_hook: Any): ...
|
22 |
+
def _register_builtin_comm_hook(
|
23 |
+
reducer: Reducer,
|
24 |
+
comm_hook_type: BuiltinCommHookType,
|
25 |
+
): ...
|
26 |
+
|
27 |
+
class GradBucket:
|
28 |
+
def index(self) -> int: ...
|
29 |
+
def buffer(self) -> Tensor: ...
|
30 |
+
def gradients(self) -> List[Tensor]: ...
|
31 |
+
def is_last(self) -> bool: ...
|
32 |
+
def set_buffer(self, tensor: Tensor) -> None: ...
|
33 |
+
def parameters(self) -> List[Tensor]: ...
|
34 |
+
|
35 |
+
class Reducer:
|
36 |
+
def __init__(
|
37 |
+
self,
|
38 |
+
params: List[Tensor],
|
39 |
+
bucket_indices: List[List[int]],
|
40 |
+
per_bucket_size_limits: List[int],
|
41 |
+
process_group: ProcessGroup,
|
42 |
+
expect_sparse_gradients: List[bool] = ...,
|
43 |
+
bucket_bytes_cap: int = ..., # kDefaultBucketBytesCap in reducer.hpp
|
44 |
+
find_unused_parameters: bool = ...,
|
45 |
+
gradient_as_bucket_view: bool = ...,
|
46 |
+
param_to_name_mapping: Dict[int, str] = ...,
|
47 |
+
first_bucket_types_cap: int = ..., # kDefaultFirstBucketBytes in reducer.hpp
|
48 |
+
): ...
|
49 |
+
def prepare_for_forward(self) -> None: ...
|
50 |
+
def prepare_for_backward(self, output: List[Tensor]) -> None: ...
|
51 |
+
def get_backward_stats(self) -> List[int]: ...
|
52 |
+
def _install_post_backward_futures(self, futures: List[Future]) -> None: ...
|
53 |
+
def _rebuild_buckets(self) -> bool: ...
|
54 |
+
def _get_zeros_like_grad_buckets(self) -> List[GradBucket]: ...
|
55 |
+
def _push_all_rebuilt_params(self) -> None: ...
|
56 |
+
def _set_forward_pass_work_handle(
|
57 |
+
self,
|
58 |
+
work: Work,
|
59 |
+
use_static_world_size: bool,
|
60 |
+
): ...
|
61 |
+
def _get_local_used_map(self) -> Tensor: ...
|
62 |
+
def _set_ddp_runtime_logging_sample_rate(self, sample_rate: int) -> None: ...
|
63 |
+
def _set_static_graph(self) -> None: ...
|
64 |
+
def _run_comm_hook(self, bucket: GradBucket) -> Future: ...
|
65 |
+
def set_logger(self, logger: Logger) -> None: ...
|
66 |
+
def _remove_autograd_hooks(self) -> None: ...
|
67 |
+
def _check_reducer_finalized(self) -> None: ...
|
68 |
+
def _set_sparse_metadata(self, global_unique_ids: Dict[str, Tensor]) -> None: ...
|
69 |
+
def _reset_state(self) -> None: ...
|
70 |
+
def _update_process_group(self, new_process_group: ProcessGroup) -> None: ...
|
71 |
+
|
72 |
+
class DDPLoggingData:
|
73 |
+
strs_map: Dict[str, str]
|
74 |
+
ints_map: Dict[str, int]
|
75 |
+
|
76 |
+
class Logger:
|
77 |
+
def __init__(self, reducer: Reducer): ...
|
78 |
+
def set_construction_data_and_log(
|
79 |
+
self,
|
80 |
+
module_name: str,
|
81 |
+
device_ids: List[int],
|
82 |
+
output_device: int,
|
83 |
+
broadcast_buffers: bool,
|
84 |
+
has_sync_bn: bool,
|
85 |
+
static_graph: bool,
|
86 |
+
): ...
|
87 |
+
def set_runtime_stats_and_log(self) -> None: ...
|
88 |
+
def set_error_and_log(self, error: str) -> None: ...
|
89 |
+
def _get_ddp_logging_data(self) -> DDPLoggingData: ...
|
90 |
+
def _set_comm_hook_name(self, comm_hook: str) -> None: ...
|
91 |
+
def _set_uneven_input_join(self) -> None: ...
|
92 |
+
def _set_static_graph(self) -> None: ...
|
93 |
+
|
94 |
+
def get_debug_level(): ...
|
95 |
+
def set_debug_level(): ...
|
96 |
+
def set_debug_level_from_env(): ...
|
97 |
+
|
98 |
+
class DebugLevel(Enum):
|
99 |
+
OFF = ...
|
100 |
+
INFO = ...
|
101 |
+
DETAIL = ...
|
102 |
+
|
103 |
+
class ReduceOp:
|
104 |
+
def __init__(self, op: RedOpType): ...
|
105 |
+
|
106 |
+
SUM: RedOpType = ...
|
107 |
+
AVG: RedOpType = ...
|
108 |
+
PRODUCT: RedOpType = ...
|
109 |
+
MIN: RedOpType = ...
|
110 |
+
MAX: RedOpType = ...
|
111 |
+
BAND: RedOpType = ...
|
112 |
+
BOR: RedOpType = ...
|
113 |
+
BXOR: RedOpType = ...
|
114 |
+
PREMUL_SUM: RedOpType = ...
|
115 |
+
UNUSED: RedOpType = ...
|
116 |
+
|
117 |
+
class RedOpType(Enum): ...
|
118 |
+
|
119 |
+
class BroadcastOptions:
|
120 |
+
rootRank: int
|
121 |
+
rootTensor: int
|
122 |
+
timeout: timedelta
|
123 |
+
asyncOp: bool
|
124 |
+
|
125 |
+
class AllreduceOptions:
|
126 |
+
reduceOp: ReduceOp
|
127 |
+
timeout: timedelta
|
128 |
+
|
129 |
+
class AllreduceCoalescedOptions(AllreduceOptions): ...
|
130 |
+
|
131 |
+
class ReduceOptions:
|
132 |
+
reduceOp: ReduceOp
|
133 |
+
rootRank: int
|
134 |
+
rootTensor: int
|
135 |
+
timeout: timedelta
|
136 |
+
|
137 |
+
class AllgatherOptions:
|
138 |
+
timeout: timedelta
|
139 |
+
asyncOp: bool
|
140 |
+
|
141 |
+
class GatherOptions:
|
142 |
+
rootRank: int
|
143 |
+
timeout: timedelta
|
144 |
+
|
145 |
+
class ScatterOptions:
|
146 |
+
rootRank: int
|
147 |
+
timeout: timedelta
|
148 |
+
asyncOp: bool
|
149 |
+
|
150 |
+
class ReduceScatterOptions:
|
151 |
+
reduceOp: ReduceOp
|
152 |
+
timeout: timedelta
|
153 |
+
asyncOp: bool
|
154 |
+
|
155 |
+
class BarrierOptions:
|
156 |
+
device_ids: List[int]
|
157 |
+
timeout: timedelta
|
158 |
+
|
159 |
+
class AllToAllOptions:
|
160 |
+
timeout: timedelta
|
161 |
+
|
162 |
+
class Store:
|
163 |
+
def set(self, key: str, value: str): ...
|
164 |
+
def get(self, key: str) -> bytes: ...
|
165 |
+
def add(self, key: str, value: int) -> int: ...
|
166 |
+
def compare_set(
|
167 |
+
self,
|
168 |
+
key: str,
|
169 |
+
expected_value: str,
|
170 |
+
desired_value: str,
|
171 |
+
) -> bytes: ...
|
172 |
+
def delete_key(self, key: str) -> bool: ...
|
173 |
+
def num_keys(self) -> int: ...
|
174 |
+
def set_timeout(self, timeout: timedelta): ...
|
175 |
+
@overload
|
176 |
+
def wait(self, keys: List[str]): ...
|
177 |
+
@overload
|
178 |
+
def wait(self, keys: List[str], timeout: timedelta): ...
|
179 |
+
|
180 |
+
class FileStore(Store):
|
181 |
+
def __init__(self, path: str, numWorkers: int = ...): ...
|
182 |
+
|
183 |
+
class HashStore(Store):
|
184 |
+
def __init__(self): ...
|
185 |
+
|
186 |
+
class TCPStore(Store):
|
187 |
+
def __init__(
|
188 |
+
self,
|
189 |
+
host_name: str,
|
190 |
+
port: int,
|
191 |
+
world_size: Optional[int] = ...,
|
192 |
+
is_master: bool = ...,
|
193 |
+
timeout: timedelta = ...,
|
194 |
+
wait_for_workers: bool = ...,
|
195 |
+
multi_tenant: bool = ...,
|
196 |
+
master_listen_fd: Optional[int] = ...,
|
197 |
+
use_libuv: Optional[bool] = ...,
|
198 |
+
): ...
|
199 |
+
@property
|
200 |
+
def host(self) -> str: ...
|
201 |
+
@property
|
202 |
+
def port(self) -> int: ...
|
203 |
+
|
204 |
+
class PrefixStore(Store):
|
205 |
+
def __init__(self, prefix: str, store: Store): ...
|
206 |
+
@property
|
207 |
+
def underlying_store(self) -> Store: ...
|
208 |
+
|
209 |
+
class Work:
|
210 |
+
def is_completed(self) -> bool: ...
|
211 |
+
def is_success(self) -> bool: ...
|
212 |
+
def exception(self) -> Any: ...
|
213 |
+
def wait(self, timeout: timedelta = ...) -> bool: ...
|
214 |
+
def source_rank(self) -> int: ...
|
215 |
+
def _source_rank(self) -> int: ...
|
216 |
+
def result(self) -> List[Tensor]: ...
|
217 |
+
def synchronize(self): ...
|
218 |
+
def boxed(self) -> ScriptObject: ...
|
219 |
+
@staticmethod
|
220 |
+
def unbox(obj: ScriptObject) -> Work: ...
|
221 |
+
|
222 |
+
class ProcessGroup:
|
223 |
+
class Options: ...
|
224 |
+
|
225 |
+
def __init__(self): ...
|
226 |
+
def rank(self) -> int: ...
|
227 |
+
def size(self) -> int: ...
|
228 |
+
@overload
|
229 |
+
def broadcast(
|
230 |
+
self,
|
231 |
+
tensors: List[Tensor],
|
232 |
+
opts=...,
|
233 |
+
) -> Work: ...
|
234 |
+
@overload
|
235 |
+
def broadcast(
|
236 |
+
self,
|
237 |
+
tensor: Tensor,
|
238 |
+
root: int,
|
239 |
+
) -> Work: ...
|
240 |
+
@overload
|
241 |
+
def allreduce(
|
242 |
+
self,
|
243 |
+
tensors: List[Tensor],
|
244 |
+
opts: AllreduceOptions = ...,
|
245 |
+
) -> Work: ...
|
246 |
+
@overload
|
247 |
+
def allreduce(
|
248 |
+
self,
|
249 |
+
tensors: List[Tensor],
|
250 |
+
op=...,
|
251 |
+
) -> Work: ...
|
252 |
+
@overload
|
253 |
+
def allreduce(
|
254 |
+
self,
|
255 |
+
tensor: Tensor,
|
256 |
+
op=...,
|
257 |
+
) -> Work: ...
|
258 |
+
def allreduce_coalesced(
|
259 |
+
self,
|
260 |
+
tensors: List[Tensor],
|
261 |
+
opts=...,
|
262 |
+
) -> Work: ...
|
263 |
+
@overload
|
264 |
+
def reduce(
|
265 |
+
self,
|
266 |
+
tensors: List[Tensor],
|
267 |
+
opts=...,
|
268 |
+
) -> Work: ...
|
269 |
+
@overload
|
270 |
+
def reduce(
|
271 |
+
self,
|
272 |
+
tensor: Tensor,
|
273 |
+
root: int,
|
274 |
+
op=...,
|
275 |
+
) -> Work: ...
|
276 |
+
@overload
|
277 |
+
def allgather(
|
278 |
+
self,
|
279 |
+
output_tensors: List[List[Tensor]],
|
280 |
+
input_tensors: List[Tensor],
|
281 |
+
opts=...,
|
282 |
+
) -> Work: ...
|
283 |
+
@overload
|
284 |
+
def allgather(
|
285 |
+
self,
|
286 |
+
output_tensors: List[Tensor],
|
287 |
+
input_tensor: Tensor,
|
288 |
+
) -> Work: ...
|
289 |
+
def _allgather_base(
|
290 |
+
self,
|
291 |
+
output: Tensor,
|
292 |
+
input: Tensor,
|
293 |
+
opts=...,
|
294 |
+
) -> Work: ...
|
295 |
+
def allgather_coalesced(
|
296 |
+
self,
|
297 |
+
output_lists: List[List[Tensor]],
|
298 |
+
input_list: List[Tensor],
|
299 |
+
opts=...,
|
300 |
+
) -> Work: ...
|
301 |
+
@overload
|
302 |
+
def gather(
|
303 |
+
self,
|
304 |
+
output_tensors: List[List[Tensor]],
|
305 |
+
input_tensors: List[Tensor],
|
306 |
+
opts=...,
|
307 |
+
) -> Work: ...
|
308 |
+
@overload
|
309 |
+
def gather(
|
310 |
+
self,
|
311 |
+
output_tensors: List[Tensor],
|
312 |
+
input_tensor: Tensor,
|
313 |
+
root: int,
|
314 |
+
) -> Work: ...
|
315 |
+
@overload
|
316 |
+
def scatter(
|
317 |
+
self,
|
318 |
+
output_tensors: List[Tensor],
|
319 |
+
input_tensors: List[List[Tensor]],
|
320 |
+
opts=...,
|
321 |
+
) -> Work: ...
|
322 |
+
@overload
|
323 |
+
def scatter(
|
324 |
+
self,
|
325 |
+
output_tensor: Tensor,
|
326 |
+
input_tensors: List[Tensor],
|
327 |
+
root: int,
|
328 |
+
) -> Work: ...
|
329 |
+
@overload
|
330 |
+
def reduce_scatter(
|
331 |
+
self,
|
332 |
+
output_tensors: List[Tensor],
|
333 |
+
input_tensors: List[List[Tensor]],
|
334 |
+
opts=...,
|
335 |
+
) -> Work: ...
|
336 |
+
@overload
|
337 |
+
def reduce_scatter(
|
338 |
+
self,
|
339 |
+
output_tensors: Tensor,
|
340 |
+
input_tensor: List[Tensor],
|
341 |
+
) -> Work: ...
|
342 |
+
def _reduce_scatter_base(
|
343 |
+
self,
|
344 |
+
outputTensor: Tensor,
|
345 |
+
inputTensor: Tensor,
|
346 |
+
) -> Work: ...
|
347 |
+
@overload
|
348 |
+
def alltoall_base(
|
349 |
+
self,
|
350 |
+
output_tensor: Tensor,
|
351 |
+
input_tensor: Tensor,
|
352 |
+
output_split_sizes: List[int],
|
353 |
+
input_split_sizes: List[int],
|
354 |
+
opts=...,
|
355 |
+
) -> Work: ...
|
356 |
+
@overload
|
357 |
+
def alltoall_base(
|
358 |
+
self,
|
359 |
+
output: Tensor,
|
360 |
+
input: Tensor,
|
361 |
+
output_split_sizes: List[int],
|
362 |
+
input_split_sizes: List[int],
|
363 |
+
) -> Work: ...
|
364 |
+
@overload
|
365 |
+
def alltoall(
|
366 |
+
self,
|
367 |
+
output_tensor: List[Tensor],
|
368 |
+
input_tensor: List[Tensor],
|
369 |
+
opts=...,
|
370 |
+
) -> Work: ...
|
371 |
+
@overload
|
372 |
+
def alltoall(
|
373 |
+
self,
|
374 |
+
output: List[Tensor],
|
375 |
+
input: List[Tensor],
|
376 |
+
) -> Work: ...
|
377 |
+
def send(
|
378 |
+
self,
|
379 |
+
tensors: List[Tensor],
|
380 |
+
dstRank: int,
|
381 |
+
tag: int,
|
382 |
+
) -> Work: ...
|
383 |
+
def recv(
|
384 |
+
self,
|
385 |
+
tensors: List[Tensor],
|
386 |
+
srcRank: int,
|
387 |
+
tag: int,
|
388 |
+
) -> Work: ...
|
389 |
+
def recv_anysource(self, tensors: List[Tensor], tag: int) -> Work: ...
|
390 |
+
def barrier(self, opts=...) -> Work: ...
|
391 |
+
def boxed(self) -> ScriptObject: ...
|
392 |
+
@staticmethod
|
393 |
+
def unbox(obj: ScriptObject) -> ProcessGroup: ...
|
394 |
+
|
395 |
+
class ProcessGroupRoundRobin(ProcessGroup): ...
|
396 |
+
|
397 |
+
def _round_robin_process_groups(
|
398 |
+
process_groups: List[ProcessGroup],
|
399 |
+
) -> ProcessGroupRoundRobin: ...
|
400 |
+
|
401 |
+
class ProcessGroupGloo(ProcessGroup):
|
402 |
+
class Device: ...
|
403 |
+
class Options: ...
|
404 |
+
|
405 |
+
def __init__(
|
406 |
+
self,
|
407 |
+
store: Store,
|
408 |
+
rank: int,
|
409 |
+
size: int,
|
410 |
+
timeout: timedelta,
|
411 |
+
): ...
|
412 |
+
@staticmethod
|
413 |
+
def create_device(hostname="", interface="") -> Device: ...
|
414 |
+
@staticmethod
|
415 |
+
def create_default_device() -> Device: ...
|
416 |
+
|
417 |
+
class _ProcessGroupWrapper(ProcessGroup):
|
418 |
+
def __init__(self, pg: ProcessGroup, gloo_pg: ProcessGroupGloo): ...
|
419 |
+
wrapped_pg: ProcessGroup
|
420 |
+
|
421 |
+
class ProcessGroupNCCL(ProcessGroup):
|
422 |
+
class Options: ...
|
423 |
+
|
424 |
+
def __init__(
|
425 |
+
self,
|
426 |
+
store: Store,
|
427 |
+
rank: int,
|
428 |
+
size: int,
|
429 |
+
timeout: timedelta,
|
430 |
+
): ...
|
431 |
+
def _group_start(self) -> None: ...
|
432 |
+
def _group_end(self) -> None: ...
|
433 |
+
|
434 |
+
class ProcessGroupUCC(ProcessGroup):
|
435 |
+
def __init__(
|
436 |
+
self,
|
437 |
+
store: Store,
|
438 |
+
rank: int,
|
439 |
+
size: int,
|
440 |
+
timeout: timedelta,
|
441 |
+
): ...
|
442 |
+
|
443 |
+
class ProcessGroupMPI(ProcessGroup):
|
444 |
+
def __init__(
|
445 |
+
self,
|
446 |
+
rank: int,
|
447 |
+
size: int,
|
448 |
+
pgComm: int,
|
449 |
+
): ...
|
450 |
+
@staticmethod
|
451 |
+
def create(ranks: List[int]) -> ProcessGroupMPI: ...
|
452 |
+
|
453 |
+
def _compute_bucket_assignment_by_size(
|
454 |
+
tensors: List[Tensor],
|
455 |
+
bucket_size_limits: List[int],
|
456 |
+
expect_sparse_gradient: List[bool] = ...,
|
457 |
+
tensor_indices: List[int] = ...,
|
458 |
+
) -> Tuple[List[List[int]], List[int]]: ...
|
459 |
+
def _broadcast_coalesced(
|
460 |
+
process_group: ProcessGroup,
|
461 |
+
tensors: List[Tensor],
|
462 |
+
buffer_size: int,
|
463 |
+
src: int,
|
464 |
+
): ...
|
465 |
+
def _test_python_store(store: Store): ...
|
466 |
+
def _verify_params_across_processes(
|
467 |
+
process_group: ProcessGroup,
|
468 |
+
params: List[Tensor],
|
469 |
+
logger: Optional[Logger],
|
470 |
+
): ...
|
471 |
+
def _make_nccl_premul_sum(factor: Union[float, List[Tensor]]) -> ReduceOp: ...
|
472 |
+
|
473 |
+
class Backend:
|
474 |
+
def __init__(
|
475 |
+
self,
|
476 |
+
rank: int,
|
477 |
+
size: int,
|
478 |
+
): ...
|
torch/_C/_distributed_rpc.pyi
ADDED
@@ -0,0 +1,188 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# mypy: disable-error-code="type-arg"
|
2 |
+
from datetime import timedelta
|
3 |
+
from typing import Any, Dict, Generic, List, Optional, overload, Tuple, Type, TypeVar
|
4 |
+
|
5 |
+
import torch
|
6 |
+
|
7 |
+
from . import Future
|
8 |
+
from ._autograd import ProfilerEvent
|
9 |
+
from ._distributed_c10d import Store
|
10 |
+
from ._profiler import ProfilerConfig
|
11 |
+
|
12 |
+
# This module is defined in torch/csrc/distributed/rpc/init.cpp
|
13 |
+
|
14 |
+
_DEFAULT_INIT_METHOD: str
|
15 |
+
_DEFAULT_NUM_WORKER_THREADS: int
|
16 |
+
_UNSET_RPC_TIMEOUT: float
|
17 |
+
_DEFAULT_RPC_TIMEOUT_SEC: float
|
18 |
+
|
19 |
+
_T = TypeVar("_T")
|
20 |
+
|
21 |
+
class RpcBackendOptions:
|
22 |
+
rpc_timeout: float
|
23 |
+
init_method: str
|
24 |
+
def __init__(
|
25 |
+
self,
|
26 |
+
rpc_timeout: float = ...,
|
27 |
+
init_method: str = ...,
|
28 |
+
): ...
|
29 |
+
|
30 |
+
class WorkerInfo:
|
31 |
+
def __init__(self, name: str, worker_id: int): ...
|
32 |
+
@property
|
33 |
+
def name(self) -> str: ...
|
34 |
+
@property
|
35 |
+
def id(self) -> int: ...
|
36 |
+
def __eq__(self, other: object) -> bool: ...
|
37 |
+
|
38 |
+
class RpcAgent:
|
39 |
+
def join(self, shutdown: bool = False, timeout: float = 0): ...
|
40 |
+
def sync(self): ...
|
41 |
+
def shutdown(self): ...
|
42 |
+
@overload
|
43 |
+
def get_worker_info(self) -> WorkerInfo: ...
|
44 |
+
@overload
|
45 |
+
def get_worker_info(self, workerName: str) -> WorkerInfo: ...
|
46 |
+
def get_worker_infos(self) -> List[WorkerInfo]: ...
|
47 |
+
def _get_device_map(self, dst: WorkerInfo) -> Dict[torch.device, torch.device]: ...
|
48 |
+
def get_debug_info(self) -> Dict[str, str]: ...
|
49 |
+
def get_metrics(self) -> Dict[str, str]: ...
|
50 |
+
|
51 |
+
class PyRRef(Generic[_T]):
|
52 |
+
def __init__(self, value: _T, type_hint: Any = None) -> None: ...
|
53 |
+
def is_owner(self) -> bool: ...
|
54 |
+
def confirmed_by_owner(self) -> bool: ...
|
55 |
+
def owner(self) -> WorkerInfo: ...
|
56 |
+
def owner_name(self) -> str: ...
|
57 |
+
def to_here(self, timeout: float = ...) -> _T: ...
|
58 |
+
def local_value(self) -> Any: ...
|
59 |
+
def rpc_sync(self, timeout: float = ...) -> Any: ...
|
60 |
+
def rpc_async(self, timeout: float = ...) -> Any: ...
|
61 |
+
def remote(self, timeout: float = ...) -> Any: ...
|
62 |
+
def _serialize(self) -> Tuple: ...
|
63 |
+
@staticmethod
|
64 |
+
def _deserialize(tp: Tuple) -> PyRRef: ...
|
65 |
+
def _get_type(self) -> Type[_T]: ...
|
66 |
+
def _get_future(self) -> Future[_T]: ...
|
67 |
+
def _get_profiling_future(self) -> Future[_T]: ...
|
68 |
+
def _set_profiling_future(self, profilingFuture: Future[_T]): ...
|
69 |
+
|
70 |
+
class _TensorPipeRpcBackendOptionsBase(RpcBackendOptions):
|
71 |
+
num_worker_threads: int
|
72 |
+
device_maps: Dict[str, Dict[torch.device, torch.device]]
|
73 |
+
devices: List[torch.device]
|
74 |
+
def __init__(
|
75 |
+
self,
|
76 |
+
num_worker_threads: int,
|
77 |
+
_transports: Optional[List],
|
78 |
+
_channels: Optional[List],
|
79 |
+
rpc_timeout: float = ...,
|
80 |
+
init_method: str = ...,
|
81 |
+
device_maps: Dict[str, Dict[torch.device, torch.device]] = {}, # noqa: B006
|
82 |
+
devices: List[torch.device] = [], # noqa: B006
|
83 |
+
): ...
|
84 |
+
def _set_device_map(
|
85 |
+
self,
|
86 |
+
to: str,
|
87 |
+
device_map: Dict[torch.device, torch.device],
|
88 |
+
): ...
|
89 |
+
|
90 |
+
class TensorPipeAgent(RpcAgent):
|
91 |
+
def __init__(
|
92 |
+
self,
|
93 |
+
store: Store,
|
94 |
+
name: str,
|
95 |
+
worker_id: int,
|
96 |
+
world_size: Optional[int],
|
97 |
+
opts: _TensorPipeRpcBackendOptionsBase,
|
98 |
+
reverse_device_maps: Dict[str, Dict[torch.device, torch.device]],
|
99 |
+
devices: List[torch.device],
|
100 |
+
): ...
|
101 |
+
def join(self, shutdown: bool = False, timeout: float = 0): ...
|
102 |
+
def shutdown(self): ...
|
103 |
+
@overload
|
104 |
+
def get_worker_info(self) -> WorkerInfo: ...
|
105 |
+
@overload
|
106 |
+
def get_worker_info(self, workerName: str) -> WorkerInfo: ...
|
107 |
+
@overload
|
108 |
+
def get_worker_info(self, id: int) -> WorkerInfo: ...
|
109 |
+
def get_worker_infos(self) -> List[WorkerInfo]: ...
|
110 |
+
def _get_device_map(self, dst: WorkerInfo) -> Dict[torch.device, torch.device]: ...
|
111 |
+
def _update_group_membership(
|
112 |
+
self,
|
113 |
+
worker_info: WorkerInfo,
|
114 |
+
my_devices: List[torch.device],
|
115 |
+
reverse_device_map: Dict[str, Dict[torch.device, torch.device]],
|
116 |
+
is_join: bool,
|
117 |
+
): ...
|
118 |
+
def _get_backend_options(self) -> _TensorPipeRpcBackendOptionsBase: ...
|
119 |
+
@property
|
120 |
+
def is_static_group(self) -> bool: ...
|
121 |
+
@property
|
122 |
+
def store(self) -> Store: ...
|
123 |
+
|
124 |
+
def _is_current_rpc_agent_set() -> bool: ...
|
125 |
+
def _get_current_rpc_agent() -> RpcAgent: ...
|
126 |
+
def _set_and_start_rpc_agent(agent: RpcAgent): ...
|
127 |
+
def _reset_current_rpc_agent(): ...
|
128 |
+
def _delete_all_user_and_unforked_owner_rrefs(timeout: timedelta = ...): ...
|
129 |
+
def _destroy_rref_context(ignoreRRefLeak: bool): ...
|
130 |
+
def _rref_context_get_debug_info() -> Dict[str, str]: ...
|
131 |
+
def _cleanup_python_rpc_handler(): ...
|
132 |
+
def _invoke_rpc_builtin(
|
133 |
+
dst: WorkerInfo,
|
134 |
+
opName: str,
|
135 |
+
rpcTimeoutSeconds: float,
|
136 |
+
*args: Any,
|
137 |
+
**kwargs: Any,
|
138 |
+
): ...
|
139 |
+
def _invoke_rpc_python_udf(
|
140 |
+
dst: WorkerInfo,
|
141 |
+
pickledPythonUDF: str,
|
142 |
+
tensors: List[torch.Tensor],
|
143 |
+
rpcTimeoutSeconds: float,
|
144 |
+
isAsyncExecution: bool,
|
145 |
+
): ...
|
146 |
+
def _invoke_rpc_torchscript(
|
147 |
+
dstWorkerName: str,
|
148 |
+
qualifiedNameStr: str,
|
149 |
+
argsTuple: Tuple,
|
150 |
+
kwargsDict: Dict,
|
151 |
+
rpcTimeoutSeconds: float,
|
152 |
+
isAsyncExecution: bool,
|
153 |
+
): ...
|
154 |
+
def _invoke_remote_builtin(
|
155 |
+
dst: WorkerInfo,
|
156 |
+
opName: str,
|
157 |
+
rpcTimeoutSeconds: float,
|
158 |
+
*args: Any,
|
159 |
+
**kwargs: Any,
|
160 |
+
): ...
|
161 |
+
def _invoke_remote_python_udf(
|
162 |
+
dst: WorkerInfo,
|
163 |
+
pickledPythonUDF: str,
|
164 |
+
tensors: List[torch.Tensor],
|
165 |
+
rpcTimeoutSeconds: float,
|
166 |
+
isAsyncExecution: bool,
|
167 |
+
): ...
|
168 |
+
# Stub for the remote (RRef-returning) TorchScript invocation binding.
# dstWorkerName is the destination worker's *name*, so it is typed as str,
# matching the dstWorkerName parameter of _invoke_rpc_torchscript.
def _invoke_remote_torchscript(
    dstWorkerName: str,
    qualifiedNameStr: str,
    rpcTimeoutSeconds: float,
    isAsyncExecution: bool,
    *args: Any,
    **kwargs: Any,
): ...
|
176 |
+
def get_rpc_timeout() -> float: ...
|
177 |
+
def enable_gil_profiling(flag: bool): ...
|
178 |
+
def _set_rpc_timeout(rpcTimeoutSeconds: float): ...
|
179 |
+
|
180 |
+
class RemoteProfilerManager:
|
181 |
+
@staticmethod
|
182 |
+
def set_current_profiling_key(key: str): ...
|
183 |
+
|
184 |
+
def _enable_server_process_global_profiler(new_config: ProfilerConfig): ...
|
185 |
+
def _disable_server_process_global_profiler() -> List[List[List[ProfilerEvent]]]: ...
|
186 |
+
def _set_profiler_node_id(default_node_id: int): ...
|
187 |
+
def _enable_jit_rref_pickle(): ...
|
188 |
+
def _disable_jit_rref_pickle(): ...
|
torch/_C/_distributed_rpc_testing.pyi
ADDED
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Dict, List
|
2 |
+
|
3 |
+
import torch
|
4 |
+
|
5 |
+
from ._distributed_c10d import Store
|
6 |
+
from ._distributed_rpc import _TensorPipeRpcBackendOptionsBase, TensorPipeAgent
|
7 |
+
|
8 |
+
# This module is defined in torch/csrc/distributed/rpc/testing/init.cpp
|
9 |
+
|
10 |
+
class FaultyTensorPipeRpcBackendOptions(_TensorPipeRpcBackendOptionsBase):
|
11 |
+
def __init__(
|
12 |
+
self,
|
13 |
+
num_worker_threads: int,
|
14 |
+
rpc_timeout: float,
|
15 |
+
init_method: str,
|
16 |
+
messages_to_fail: List[str],
|
17 |
+
messages_to_delay: Dict[str, float],
|
18 |
+
num_fail_sends: int,
|
19 |
+
): ...
|
20 |
+
num_send_recv_threads: int
|
21 |
+
messages_to_fail: List[str]
|
22 |
+
messages_to_delay: Dict[str, float]
|
23 |
+
num_fail_sends: int
|
24 |
+
|
25 |
+
class FaultyTensorPipeAgent(TensorPipeAgent):
|
26 |
+
def __init__(
|
27 |
+
self,
|
28 |
+
store: Store,
|
29 |
+
name: str,
|
30 |
+
rank: int,
|
31 |
+
world_size: int,
|
32 |
+
options: FaultyTensorPipeRpcBackendOptions,
|
33 |
+
reverse_device_maps: Dict[str, Dict[torch.device, torch.device]],
|
34 |
+
devices: List[torch.device],
|
35 |
+
): ...
|
torch/_C/_functions.pyi
ADDED
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import AnyStr, List
|
2 |
+
|
3 |
+
from torch import Tensor
|
4 |
+
|
5 |
+
class UndefinedGrad:
|
6 |
+
def __init__(self) -> None: ...
|
7 |
+
def __call__(self, *inputs: Tensor) -> List[Tensor]: ...
|
8 |
+
|
9 |
+
class DelayedError:
|
10 |
+
def __init__(self, msg: AnyStr, num_inputs: int) -> None: ...
|
11 |
+
def __call__(self, inputs: List[Tensor]) -> List[Tensor]: ...
|
torch/_C/_functorch.pyi
ADDED
@@ -0,0 +1,71 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from enum import Enum
|
2 |
+
from typing import Optional, Tuple
|
3 |
+
|
4 |
+
from torch import Tensor
|
5 |
+
|
6 |
+
# Defined in torch/csrc/functorch/init.cpp
|
7 |
+
|
8 |
+
def _set_dynamic_layer_keys_included(included: bool) -> None: ...
|
9 |
+
def get_unwrapped(tensor: Tensor) -> Tensor: ...
|
10 |
+
def is_batchedtensor(tensor: Tensor) -> bool: ...
|
11 |
+
def is_functionaltensor(tensor: Tensor) -> bool: ...
|
12 |
+
def is_functorch_wrapped_tensor(tensor: Tensor) -> bool: ...
|
13 |
+
def is_gradtrackingtensor(tensor: Tensor) -> bool: ...
|
14 |
+
def maybe_get_bdim(tensor: Tensor) -> int: ...
|
15 |
+
def maybe_get_level(tensor: Tensor) -> int: ...
|
16 |
+
def unwrap_if_dead(tensor: Tensor) -> Tensor: ...
|
17 |
+
def _unwrap_for_grad(tensor: Tensor, level: int) -> Tensor: ...
|
18 |
+
def _wrap_for_grad(tensor: Tensor, level: int) -> Tensor: ...
|
19 |
+
def _unwrap_batched(tensor: Tensor, level: int) -> Tuple[Tensor, Optional[int]]: ...
|
20 |
+
def current_level() -> int: ...
|
21 |
+
def _add_batch_dim(tensor: Tensor, bdim: int, level: int) -> Tensor: ...
|
22 |
+
def set_single_level_autograd_function_allowed(allowed: bool) -> None: ...
|
23 |
+
def get_single_level_autograd_function_allowed() -> bool: ...
|
24 |
+
def _unwrap_functional_tensor(tensor: Tensor, reapply_views: bool) -> Tensor: ...
|
25 |
+
def _wrap_functional_tensor(tensor: Tensor, level: int) -> Tensor: ...
|
26 |
+
|
27 |
+
# Defined in aten/src/ATen/functorch/Interpreter.h
|
28 |
+
class TransformType(Enum):
    # Enum members are declared with a bare `...` placeholder and no explicit
    # annotation, so type checkers infer each one as a member of this enum.
    Torch = ...
    Vmap = ...
    Grad = ...
    Jvp = ...
    Functionalize = ...
|
34 |
+
|
35 |
+
class RandomnessType(Enum):
    # Members were previously annotated with an unrelated enum type; they are
    # declared without annotations so each is inferred as a RandomnessType member.
    Error = ...
    Same = ...
    Different = ...
|
39 |
+
|
40 |
+
class CInterpreter:
|
41 |
+
def key(self) -> TransformType: ...
|
42 |
+
def level(self) -> int: ...
|
43 |
+
|
44 |
+
class CGradInterpreterPtr:
|
45 |
+
def __init__(self, interpreter: CInterpreter): ...
|
46 |
+
def lift(self, Tensor) -> Tensor: ...
|
47 |
+
def prevGradMode(self) -> bool: ...
|
48 |
+
|
49 |
+
class CJvpInterpreterPtr:
|
50 |
+
def __init__(self, interpreter: CInterpreter): ...
|
51 |
+
def lift(self, Tensor) -> Tensor: ...
|
52 |
+
def prevFwdGradMode(self) -> bool: ...
|
53 |
+
|
54 |
+
class CFunctionalizeInterpreterPtr:
|
55 |
+
def __init__(self, interpreter: CInterpreter): ...
|
56 |
+
def key(self) -> TransformType: ...
|
57 |
+
def level(self) -> int: ...
|
58 |
+
def functionalizeAddBackViews(self) -> bool: ...
|
59 |
+
|
60 |
+
class CVmapInterpreterPtr:
|
61 |
+
def __init__(self, interpreter: CInterpreter): ...
|
62 |
+
def key(self) -> TransformType: ...
|
63 |
+
def level(self) -> int: ...
|
64 |
+
def batchSize(self) -> int: ...
|
65 |
+
def randomness(self) -> RandomnessType: ...
|
66 |
+
|
67 |
+
class DynamicLayer: ...
|
68 |
+
|
69 |
+
def peek_interpreter_stack() -> CInterpreter: ...
|
70 |
+
def pop_dynamic_layer_stack() -> DynamicLayer: ...
|
71 |
+
def push_dynamic_layer_stack(dl: DynamicLayer) -> int: ...
|
torch/_C/_itt.pyi
ADDED
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Defined in torch/csrc/itt.cpp
|
2 |
+
# Availability query for ITT support: the result is a boolean flag, not None.
def is_available() -> bool: ...
|
3 |
+
def rangePush(message: str) -> None: ...
|
4 |
+
def rangePop() -> None: ...
|
5 |
+
def mark(message: str) -> None: ...
|
torch/_C/_lazy.pyi
ADDED
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import List
|
2 |
+
|
3 |
+
from torch import Tensor
|
4 |
+
|
5 |
+
# defined in torch/csrc/lazy/python/init.cpp
|
6 |
+
def _mark_step(device: str, devices: List[str], wait: bool): ...
|
7 |
+
def _wait_device_ops(devices: List[str]): ...
|
8 |
+
def _reset_metrics(): ...
|
9 |
+
def _counter_names() -> List[str]: ...
|
10 |
+
def _counter_value(name: str) -> int: ...
|
11 |
+
def _metrics_report() -> str: ...
|
12 |
+
def _get_graph_hash(tensors: List[Tensor]) -> str: ...
|
13 |
+
def _sync_multi(
|
14 |
+
tensors: List[Tensor],
|
15 |
+
devices: List[str],
|
16 |
+
wait: bool = True,
|
17 |
+
sync_ltc_data: bool = True,
|
18 |
+
): ...
|
19 |
+
def _get_tensor_id(tensor: Tensor) -> int: ...
|
20 |
+
def _get_tensors_text(tensors: List[Tensor]) -> str: ...
|
21 |
+
def _get_tensors_dot(tensors: List[Tensor]) -> str: ...
|
22 |
+
def _get_tensors_backend(tensors: List[Tensor]) -> str: ...
|
23 |
+
def _get_force_fallback() -> str: ...
|
24 |
+
def _set_force_fallback(newval: str): ...
|
25 |
+
def _clear_ir_cache(): ...
|
26 |
+
def _dump_ir_cache(filename: str): ...
|
27 |
+
def _set_reuse_ir(val: bool): ...
|
28 |
+
# Returns the lazy backend's default device type rendered as a string
# (NOTE(review): return type taken from the C++ binding -- confirm).
def _get_default_device_type() -> str: ...
|
torch/_C/_lazy_ts_backend.pyi
ADDED
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# defined in torch/csrc/lazy/python/init.cpp
|
2 |
+
|
3 |
+
from typing import Any, List, Tuple
|
4 |
+
|
5 |
+
from torch import Tensor
|
6 |
+
|
7 |
+
def _init(): ...
|
8 |
+
def _get_tensors_ts_device_data_node(
|
9 |
+
tensors: List[Tensor],
|
10 |
+
) -> Tuple[List[int], List[Any]]: ...
|
11 |
+
def _run_cached_graph(hash_str: str, graph_inputs: List[Any]) -> List[Tensor]: ...
|
torch/_C/_monitor.pyi
ADDED
@@ -0,0 +1,44 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Defined in torch/csrc/monitor/python_init.cpp
|
2 |
+
|
3 |
+
import datetime
|
4 |
+
from enum import Enum
|
5 |
+
from typing import Callable, Dict, List, Union
|
6 |
+
|
7 |
+
class Aggregation(Enum):
|
8 |
+
VALUE = ...
|
9 |
+
MEAN = ...
|
10 |
+
COUNT = ...
|
11 |
+
SUM = ...
|
12 |
+
MAX = ...
|
13 |
+
MIN = ...
|
14 |
+
|
15 |
+
class Stat:
    # Stub for the monitor Stat type: a named statistic that accepts float
    # samples via add() and reports one float per Aggregation via get().
    name: str
    count: int
    def __init__(
        self,
        name: str,
        aggregations: List[Aggregation],
        # The window is a duration, so it is typed as a timedelta rather than int.
        window_size: datetime.timedelta,
        max_samples: int = -1,
    ) -> None: ...
    def add(self, v: float) -> None: ...
    def get(self) -> Dict[Aggregation, float]: ...
|
27 |
+
|
28 |
+
class Event:
|
29 |
+
name: str
|
30 |
+
timestamp: datetime.datetime
|
31 |
+
data: Dict[str, Union[int, float, bool, str]]
|
32 |
+
def __init__(
|
33 |
+
self,
|
34 |
+
name: str,
|
35 |
+
timestamp: datetime.datetime,
|
36 |
+
data: Dict[str, Union[int, float, bool, str]],
|
37 |
+
) -> None: ...
|
38 |
+
|
39 |
+
def log_event(e: Event) -> None: ...
|
40 |
+
|
41 |
+
class EventHandlerHandle: ...
|
42 |
+
|
43 |
+
def register_event_handler(handler: Callable[[Event], None]) -> EventHandlerHandle: ...
|
44 |
+
def unregister_event_handler(handle: EventHandlerHandle) -> None: ...
|
torch/_C/_nn.pyi
ADDED
@@ -0,0 +1,86 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# mypy: disable-error-code="type-arg"
|
2 |
+
from typing import List, Optional, overload, Sequence, Tuple, Union
|
3 |
+
|
4 |
+
from torch import memory_format, Tensor
|
5 |
+
from torch.types import _bool, _device, _dtype, _int, _size
|
6 |
+
|
7 |
+
# Defined in tools/autograd/templates/python_nn_functions.cpp
|
8 |
+
|
9 |
+
def adaptive_max_pool2d(input: Tensor, output_size: Union[_int, _size]) -> Tuple[Tensor, Tensor]: ...
|
10 |
+
def adaptive_max_pool3d(input: Tensor, output_size: Union[_int, _size]) -> Tuple[Tensor, Tensor]: ...
|
11 |
+
def avg_pool2d(input: Tensor, kernel_size: Union[_int, _size], stride: Optional[Union[_int, _size]] = None, padding: Union[_int, _size] = 0, ceil_mode: bool = False, count_include_pad: bool = True, divisor_override: Optional[int] = None) -> Tensor: ...
|
12 |
+
def avg_pool3d(input: Tensor, kernel_size: Union[_int, _size], stride: Optional[Union[_int, _size]] = None, padding: Union[_int, _size] = 0, ceil_mode: bool = False, count_include_pad: bool = True, divisor_override: Optional[int] = None) -> Tensor: ...
|
13 |
+
def elu_(input: Tensor, alpha: float = ...) -> Tensor: ...
|
14 |
+
def fractional_max_pool2d(input: Tensor, kernel_size: Union[_int, _size], output_size: Union[_int, _size], _random_samples: Tensor) -> Tuple[Tensor, Tensor]: ...
|
15 |
+
def fractional_max_pool3d(input: Tensor, kernel_size: Union[_int, _size], output_size: Union[_int, _size], _random_samples: Tensor) -> Tuple[Tensor, Tensor]: ...
|
16 |
+
def gelu(input: Tensor, approximate: str = ...) -> Tensor: ...
|
17 |
+
def hardsigmoid(input: Tensor, *, out: Optional[Tensor] = None) -> Tensor: ...
|
18 |
+
def hardtanh(input: Tensor, min_val: float = ..., max_val: float = ..., *, out: Optional[Tensor] = None) -> Tensor: ...
|
19 |
+
def hardtanh_(input: Tensor, min_val: float = ..., max_val: float = ...) -> Tensor: ...
|
20 |
+
def leaky_relu(input: Tensor, negative_slope: float = ..., *, out: Optional[Tensor] = None) -> Tensor: ...
|
21 |
+
def leaky_relu_(input: Tensor, negative_slope: float = ...) -> Tensor: ...
|
22 |
+
def linear(input: Tensor, weight: Tensor, bias: Optional[Tensor] = None) -> Tensor: ...
|
23 |
+
def log_sigmoid(input: Tensor) -> Tensor: ...
|
24 |
+
def one_hot(tensor: Tensor, num_classes: int = ...) -> Tensor: ...
|
25 |
+
def pad(input: Tensor, pad: Sequence[int], mode: str = ..., value: Optional[float] = None) -> Tensor: ...
|
26 |
+
def scaled_dot_product_attention(query: Tensor, key: Tensor, value: Tensor, attn_mask: Optional[Tensor] = None, dropout_p: float = 0.0, is_causal: bool = False, scale: Optional[float] = None) -> Tensor: ...
|
27 |
+
# beta and threshold are scalar hyperparameters; typing them as float lets
# non-integer values type-check while integer arguments remain acceptable.
def softplus(input: Tensor, beta: float = ..., threshold: float = ...) -> Tensor: ...
|
28 |
+
def softshrink(input: Tensor, lambd: float = ...) -> Tensor: ...
|
29 |
+
|
30 |
+
# Defined in aten/src/ATen/native/mkldnn/Linear.cpp
|
31 |
+
def mkldnn_linear(input: Tensor, weight: Tensor, bias: Optional[Tensor]) -> Tensor: ...
|
32 |
+
|
33 |
+
# Defined at aten/src/ATen/native/mkldnn/MKLDNNConversions.cpp
|
34 |
+
def mkldnn_reorder_conv2d_weight(
|
35 |
+
self: Tensor,
|
36 |
+
padding: List,
|
37 |
+
stride: List,
|
38 |
+
dilatation: List,
|
39 |
+
groups: int,
|
40 |
+
) -> Tensor: ...
|
41 |
+
def mkldnn_reorder_conv3d_weight(
|
42 |
+
self: Tensor,
|
43 |
+
padding: List,
|
44 |
+
stride: List,
|
45 |
+
dilatation: List,
|
46 |
+
groups: int,
|
47 |
+
) -> Tensor: ...
|
48 |
+
|
49 |
+
# Defined in aten/src/ATen/native/mkldnn/Prelu.cpp
|
50 |
+
def mkldnn_prelu(input: Tensor, weight: Tensor) -> Tensor: ...
|
51 |
+
|
52 |
+
# Defined at tools/autograd/templates/python_nn_functions.cpp
|
53 |
+
@overload
|
54 |
+
def _parse_to(
|
55 |
+
device: _device,
|
56 |
+
dtype: _dtype,
|
57 |
+
non_blocking: _bool,
|
58 |
+
copy: _bool,
|
59 |
+
*,
|
60 |
+
memory_format: memory_format,
|
61 |
+
) -> Tuple[_device, _dtype, _bool, memory_format]: ...
|
62 |
+
@overload
|
63 |
+
def _parse_to(
|
64 |
+
dtype: _dtype,
|
65 |
+
non_blocking: _bool,
|
66 |
+
copy: _bool,
|
67 |
+
*,
|
68 |
+
memory_format: memory_format,
|
69 |
+
) -> Tuple[_device, _dtype, _bool, memory_format]: ...
|
70 |
+
@overload
|
71 |
+
def _parse_to(
|
72 |
+
tensor: Tensor,
|
73 |
+
non_blocking: _bool,
|
74 |
+
copy: _bool,
|
75 |
+
*,
|
76 |
+
memory_format: memory_format,
|
77 |
+
) -> Tuple[_device, _dtype, _bool, memory_format]: ...
|
78 |
+
|
79 |
+
# Defined in aten/src/ATen/native/PadSequence.cpp
|
80 |
+
def pad_sequence(
|
81 |
+
sequences: List[Tensor],
|
82 |
+
batch_first: bool = False,
|
83 |
+
padding_value: float = ...,
|
84 |
+
) -> Tensor: ...
|
85 |
+
def flatten_dense_tensors(tensors: List[Tensor]) -> Tensor: ...
|
86 |
+
def unflatten_dense_tensors(flat: Tensor, tensors: List[Tensor]) -> List[Tensor]: ...
|
torch/_C/_nvtx.pyi
ADDED
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Defined in torch/csrc/cuda/shared/nvtx.cpp
|
2 |
+
def rangePushA(message: str) -> int: ...
|
3 |
+
def rangePop() -> int: ...
|
4 |
+
def rangeStartA(message: str) -> int: ...
|
5 |
+
# NOTE(review): the single parameter is literally named `int` and carries no
# annotation, so it is untyped. Presumably the int range id returned by
# rangeStartA was intended -- confirm against the binding before renaming.
def rangeEnd(int) -> None: ...
|
6 |
+
def markA(message: str) -> None: ...
|
torch/_C/_onnx.pyi
ADDED
@@ -0,0 +1,38 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Defined in torch/csrc/onnx/init.cpp
|
2 |
+
|
3 |
+
from enum import Enum
|
4 |
+
|
5 |
+
_CAFFE2_ATEN_FALLBACK: bool
|
6 |
+
PRODUCER_VERSION: str
|
7 |
+
|
8 |
+
class TensorProtoDataType(Enum):
|
9 |
+
UNDEFINED = ...
|
10 |
+
FLOAT = ...
|
11 |
+
UINT8 = ...
|
12 |
+
INT8 = ...
|
13 |
+
UINT16 = ...
|
14 |
+
INT16 = ...
|
15 |
+
INT32 = ...
|
16 |
+
INT64 = ...
|
17 |
+
STRING = ...
|
18 |
+
BOOL = ...
|
19 |
+
FLOAT16 = ...
|
20 |
+
DOUBLE = ...
|
21 |
+
UINT32 = ...
|
22 |
+
UINT64 = ...
|
23 |
+
COMPLEX64 = ...
|
24 |
+
COMPLEX128 = ...
|
25 |
+
BFLOAT16 = ...
|
26 |
+
FLOAT8E5M2 = ...
|
27 |
+
FLOAT8E4M3FN = ...
|
28 |
+
|
29 |
+
class OperatorExportTypes(Enum):
|
30 |
+
ONNX = ...
|
31 |
+
ONNX_ATEN = ...
|
32 |
+
ONNX_ATEN_FALLBACK = ...
|
33 |
+
ONNX_FALLTHROUGH = ...
|
34 |
+
|
35 |
+
class TrainingMode(Enum):
|
36 |
+
EVAL = ...
|
37 |
+
PRESERVE = ...
|
38 |
+
TRAINING = ...
|
torch/_C/_profiler.pyi
ADDED
@@ -0,0 +1,238 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from enum import Enum
|
2 |
+
from typing import Any, Dict, List, Literal, Optional, Tuple, Union
|
3 |
+
|
4 |
+
from torch._C import device, dtype, layout
|
5 |
+
from typing_extensions import TypeAlias
|
6 |
+
|
7 |
+
# defined in torch/csrc/profiler/python/init.cpp
|
8 |
+
|
9 |
+
class RecordScope(Enum):
|
10 |
+
FUNCTION = ...
|
11 |
+
BACKWARD_FUNCTION = ...
|
12 |
+
TORCHSCRIPT_FUNCTION = ...
|
13 |
+
KERNEL_FUNCTION_DTYPE = ...
|
14 |
+
CUSTOM_CLASS = ...
|
15 |
+
BUILD_FEATURE = ...
|
16 |
+
LITE_INTERPRETER = ...
|
17 |
+
USER_SCOPE = ...
|
18 |
+
STATIC_RUNTIME_OP = ...
|
19 |
+
STATIC_RUNTIME_MODEL = ...
|
20 |
+
|
21 |
+
class ProfilerState(Enum):
|
22 |
+
Disable = ...
|
23 |
+
CPU = ...
|
24 |
+
CUDA = ...
|
25 |
+
NVTX = ...
|
26 |
+
ITT = ...
|
27 |
+
KINETO = ...
|
28 |
+
KINETO_GPU_FALLBACK = ...
|
29 |
+
KINETO_PRIVATEUSE1_FALLBACK = ...
|
30 |
+
KINETO_PRIVATEUSE1 = ...
|
31 |
+
|
32 |
+
class ActiveProfilerType(Enum):
|
33 |
+
NONE = ...
|
34 |
+
LEGACY = ...
|
35 |
+
KINETO = ...
|
36 |
+
NVTX = ...
|
37 |
+
ITT = ...
|
38 |
+
|
39 |
+
class ProfilerActivity(Enum):
|
40 |
+
CPU = ...
|
41 |
+
CUDA = ...
|
42 |
+
MTIA = ...
|
43 |
+
PrivateUse1 = ...
|
44 |
+
|
45 |
+
class _EventType(Enum):
|
46 |
+
TorchOp = ...
|
47 |
+
Backend = ...
|
48 |
+
Allocation = ...
|
49 |
+
OutOfMemory = ...
|
50 |
+
PyCall = ...
|
51 |
+
PyCCall = ...
|
52 |
+
Kineto = ...
|
53 |
+
|
54 |
+
class _ExperimentalConfig:
|
55 |
+
def __init__(
|
56 |
+
self,
|
57 |
+
profiler_metrics: List[str] = ...,
|
58 |
+
profiler_measure_per_kernel: bool = ...,
|
59 |
+
verbose: bool = ...,
|
60 |
+
performance_events: List[str] = ...,
|
61 |
+
enable_cuda_sync_events: bool = ...,
|
62 |
+
) -> None: ...
|
63 |
+
|
64 |
+
class ProfilerConfig:
|
65 |
+
def __init__(
|
66 |
+
self,
|
67 |
+
state: ProfilerState,
|
68 |
+
report_input_shapes: bool,
|
69 |
+
profile_memory: bool,
|
70 |
+
with_stack: bool,
|
71 |
+
with_flops: bool,
|
72 |
+
with_modules: bool,
|
73 |
+
experimental_config: _ExperimentalConfig,
|
74 |
+
) -> None: ...
|
75 |
+
|
76 |
+
class _ProfilerEvent:
|
77 |
+
start_tid: int
|
78 |
+
start_time_ns: int
|
79 |
+
children: List[_ProfilerEvent]
|
80 |
+
|
81 |
+
# TODO(robieta): remove in favor of `self.typed`
|
82 |
+
extra_fields: Union[
|
83 |
+
_ExtraFields_TorchOp,
|
84 |
+
_ExtraFields_Backend,
|
85 |
+
_ExtraFields_Allocation,
|
86 |
+
_ExtraFields_OutOfMemory,
|
87 |
+
_ExtraFields_PyCall,
|
88 |
+
_ExtraFields_PyCCall,
|
89 |
+
_ExtraFields_Kineto,
|
90 |
+
]
|
91 |
+
|
92 |
+
@property
|
93 |
+
def typed(
|
94 |
+
self,
|
95 |
+
) -> Union[
|
96 |
+
Tuple[Literal[_EventType.TorchOp], _ExtraFields_TorchOp],
|
97 |
+
Tuple[Literal[_EventType.Backend], _ExtraFields_Backend],
|
98 |
+
Tuple[Literal[_EventType.Allocation], _ExtraFields_Allocation],
|
99 |
+
Tuple[Literal[_EventType.OutOfMemory], _ExtraFields_OutOfMemory],
|
100 |
+
Tuple[Literal[_EventType.PyCall], _ExtraFields_PyCall],
|
101 |
+
Tuple[Literal[_EventType.PyCCall], _ExtraFields_PyCCall],
|
102 |
+
Tuple[Literal[_EventType.Kineto], _ExtraFields_Kineto],
|
103 |
+
]: ...
|
104 |
+
@property
|
105 |
+
def name(self) -> str: ...
|
106 |
+
@property
|
107 |
+
def tag(self) -> _EventType: ...
|
108 |
+
@property
|
109 |
+
def id(self) -> int: ...
|
110 |
+
@property
|
111 |
+
def parent(self) -> Optional[_ProfilerEvent]: ...
|
112 |
+
@property
|
113 |
+
def correlation_id(self) -> int: ...
|
114 |
+
@property
|
115 |
+
def end_time_ns(self) -> int: ...
|
116 |
+
@property
|
117 |
+
def duration_time_ns(self) -> int: ...
|
118 |
+
|
119 |
+
class _TensorMetadata:
|
120 |
+
impl_ptr: Optional[int]
|
121 |
+
storage_data_ptr: Optional[int]
|
122 |
+
id: Optional[int]
|
123 |
+
|
124 |
+
@property
|
125 |
+
def allocation_id(self) -> Optional[int]: ...
|
126 |
+
@property
|
127 |
+
def layout(self) -> layout: ...
|
128 |
+
@property
|
129 |
+
def device(self) -> device: ...
|
130 |
+
@property
|
131 |
+
def dtype(self) -> dtype: ...
|
132 |
+
@property
|
133 |
+
def sizes(self) -> List[int]: ...
|
134 |
+
@property
|
135 |
+
def strides(self) -> List[int]: ...
|
136 |
+
|
137 |
+
Scalar: TypeAlias = Union[int, float, bool, complex]
|
138 |
+
Input: TypeAlias = Optional[Union[_TensorMetadata, List[_TensorMetadata], Scalar]]
|
139 |
+
|
140 |
+
class _ExtraFields_TorchOp:
|
141 |
+
name: str
|
142 |
+
sequence_number: int
|
143 |
+
allow_tf32_cublas: bool
|
144 |
+
|
145 |
+
@property
|
146 |
+
def inputs(self) -> List[Input]: ...
|
147 |
+
@property
|
148 |
+
def scope(self) -> RecordScope: ...
|
149 |
+
|
150 |
+
class _ExtraFields_Backend: ...
|
151 |
+
|
152 |
+
class _ExtraFields_Allocation:
|
153 |
+
ptr: int
|
154 |
+
id: Optional[int]
|
155 |
+
alloc_size: int
|
156 |
+
total_allocated: int
|
157 |
+
total_reserved: int
|
158 |
+
|
159 |
+
@property
|
160 |
+
def allocation_id(self) -> Optional[int]: ...
|
161 |
+
@property
|
162 |
+
def device(self) -> device: ...
|
163 |
+
|
164 |
+
class _ExtraFields_OutOfMemory: ...
|
165 |
+
|
166 |
+
class _PyFrameState:
|
167 |
+
line_number: int
|
168 |
+
function_name: str
|
169 |
+
|
170 |
+
@property
|
171 |
+
def file_name(self) -> str: ...
|
172 |
+
|
173 |
+
class _NNModuleInfo:
|
174 |
+
@property
|
175 |
+
def self_ptr(self) -> int: ...
|
176 |
+
@property
|
177 |
+
def cls_ptr(self) -> int: ...
|
178 |
+
@property
|
179 |
+
def cls_name(self) -> str: ...
|
180 |
+
@property
|
181 |
+
def parameters(
|
182 |
+
self,
|
183 |
+
) -> List[Tuple[str, _TensorMetadata, Optional[_TensorMetadata]]]: ...
|
184 |
+
|
185 |
+
class _OptimizerInfo:
|
186 |
+
@property
|
187 |
+
def parameters(
|
188 |
+
self,
|
189 |
+
) -> List[
|
190 |
+
Tuple[
|
191 |
+
# Parameter
|
192 |
+
_TensorMetadata,
|
193 |
+
#
|
194 |
+
# Gradient (if present during optimizer.step())
|
195 |
+
Optional[_TensorMetadata],
|
196 |
+
#
|
197 |
+
# Optimizer state for Parameter as (name, tensor) pairs
|
198 |
+
List[Tuple[str, _TensorMetadata]],
|
199 |
+
]
|
200 |
+
]: ...
|
201 |
+
|
202 |
+
class _ExtraFields_PyCCall:
|
203 |
+
@property
|
204 |
+
def caller(self) -> _PyFrameState: ...
|
205 |
+
|
206 |
+
class _ExtraFields_PyCall:
|
207 |
+
@property
|
208 |
+
def callsite(self) -> _PyFrameState: ...
|
209 |
+
@property
|
210 |
+
def caller(self) -> _PyFrameState: ...
|
211 |
+
@property
|
212 |
+
def module(self) -> Optional[_NNModuleInfo]: ...
|
213 |
+
@property
|
214 |
+
def optimizer(self) -> Optional[_OptimizerInfo]: ...
|
215 |
+
|
216 |
+
class _ExtraFields_Kineto: ...
|
217 |
+
|
218 |
+
def _add_execution_trace_observer(output_file_path: str) -> bool: ...
|
219 |
+
def _remove_execution_trace_observer() -> None: ...
|
220 |
+
def _enable_execution_trace_observer() -> None: ...
|
221 |
+
def _disable_execution_trace_observer() -> None: ...
|
222 |
+
def _set_record_concrete_inputs_enabled_val(val: bool) -> None: ...
|
223 |
+
def _set_fwd_bwd_enabled_val(val: bool) -> None: ...
|
224 |
+
def _set_cuda_sync_enabled_val(val: bool) -> None: ...
|
225 |
+
|
226 |
+
class CapturedTraceback: ...
|
227 |
+
|
228 |
+
def gather_traceback(python: bool, script: bool, cpp: bool) -> CapturedTraceback: ...
|
229 |
+
|
230 |
+
# The Dict has name, filename, line
|
231 |
+
def symbolize_tracebacks(
|
232 |
+
to_symbolize: List[CapturedTraceback],
|
233 |
+
) -> List[List[Dict[str, str]]]: ...
|
234 |
+
|
235 |
+
class _RecordFunctionFast:
|
236 |
+
def __init__(self, name: str) -> None: ...
|
237 |
+
def __enter__(self) -> None: ...
|
238 |
+
def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None: ...
|
torch/_C/_verbose.pyi
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
# Defined in torch/csrc/utils/verbose.cpp
|
2 |
+
def mkl_set_verbose(enable: int) -> int: ...
|
3 |
+
def mkldnn_set_verbose(level: int) -> int: ...
|
torch/_VF.py
ADDED
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
This makes the functions in torch._C._VariableFunctions available as
|
3 |
+
torch._VF.<funcname>
|
4 |
+
without mypy being able to find them.
|
5 |
+
|
6 |
+
A subset of those functions are mapped to ATen functions in
|
7 |
+
torch/jit/_builtins.py
|
8 |
+
|
9 |
+
See https://github.com/pytorch/pytorch/issues/21478 for the reason for
|
10 |
+
introducing torch._VF
|
11 |
+
|
12 |
+
"""
|
13 |
+
import sys
|
14 |
+
import types
|
15 |
+
|
16 |
+
import torch
|
17 |
+
|
18 |
+
|
19 |
+
class VFModule(types.ModuleType):
    """Module object that resolves missing attributes on torch._C._VariableFunctions."""

    vf: types.ModuleType

    def __init__(self, name):
        """Create the module called ``name`` and remember the backing namespace."""
        super().__init__(name)
        self.vf = torch._C._VariableFunctions

    def __getattr__(self, attr):
        """Fallback lookup: fetch ``attr`` from the backing namespace."""
        backing = self.vf
        return getattr(backing, attr)
|
28 |
+
|
29 |
+
|
30 |
+
sys.modules[__name__] = VFModule(__name__)
|
torch/_VF.pyi
ADDED
The diff for this file is too large to render.
See raw diff
|
|
torch/__config__.py
ADDED
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
|
3 |
+
|
4 |
+
def show():
    """Return whatever ``torch._C._show_config()`` reports.

    This is the human-readable description of how this PyTorch was configured.
    """
    config_text = torch._C._show_config()
    return config_text
|
10 |
+
|
11 |
+
|
12 |
+
# TODO: In principle, we could provide more structured version/config
|
13 |
+
# information here. For now only CXX_FLAGS is exposed, as Timer
|
14 |
+
# uses them.
|
15 |
+
def _cxx_flags():
    """Return the CXX_FLAGS reported by ``torch._C._cxx_flags()`` for this build."""
    flags = torch._C._cxx_flags()
    return flags
|
18 |
+
|
19 |
+
|
20 |
+
def parallel_info():
    """Return the parallelization-settings report from ``torch._C._parallel_info()``."""
    report = torch._C._parallel_info()
    return report
|
torch/__future__.py
ADDED
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
This global flag controls whether to assign new tensors to the parameters
|
3 |
+
instead of changing the existing parameters in-place when converting an `nn.Module`
|
4 |
+
using the following methods:
|
5 |
+
1. `module.cuda()` / `.cpu()` (for moving `module` between devices)
|
6 |
+
2. `module.float()` / `.double()` / `.half()` (for converting `module` to a different dtype)
|
7 |
+
3. `module.to()` / `.type()` (for changing `module`'s device or dtype)
|
8 |
+
4. `module._apply(fn)` (for generic functions applied to `module`)
|
9 |
+
|
10 |
+
Default: False
|
11 |
+
"""
|
12 |
+
_overwrite_module_params_on_conversion = False
|
13 |
+
|
14 |
+
|
15 |
+
def set_overwrite_module_params_on_conversion(value):
    """Store ``value`` in the module-level ``_overwrite_module_params_on_conversion`` flag."""
    # Rebind the module global rather than creating a function-local name.
    global _overwrite_module_params_on_conversion
    _overwrite_module_params_on_conversion = value
|
18 |
+
|
19 |
+
|
20 |
+
def get_overwrite_module_params_on_conversion():
    """Return the current value of the module-level ``_overwrite_module_params_on_conversion`` flag."""
    current = _overwrite_module_params_on_conversion
    return current
|
torch/_appdirs.py
ADDED
@@ -0,0 +1,666 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python3
|
2 |
+
# -*- coding: utf-8 -*-
|
3 |
+
# Copyright (c) 2005-2010 ActiveState Software Inc.
|
4 |
+
# Copyright (c) 2013 Eddy Petrișor
|
5 |
+
|
6 |
+
# flake8: noqa
|
7 |
+
|
8 |
+
"""
|
9 |
+
This file is directly from
|
10 |
+
https://github.com/ActiveState/appdirs/blob/3fe6a83776843a46f20c2e5587afcffe05e03b39/appdirs.py
|
11 |
+
|
12 |
+
The license of https://github.com/ActiveState/appdirs copied below:
|
13 |
+
|
14 |
+
|
15 |
+
# This is the MIT license
|
16 |
+
|
17 |
+
Copyright (c) 2010 ActiveState Software Inc.
|
18 |
+
|
19 |
+
Permission is hereby granted, free of charge, to any person obtaining a
|
20 |
+
copy of this software and associated documentation files (the
|
21 |
+
"Software"), to deal in the Software without restriction, including
|
22 |
+
without limitation the rights to use, copy, modify, merge, publish,
|
23 |
+
distribute, sublicense, and/or sell copies of the Software, and to
|
24 |
+
permit persons to whom the Software is furnished to do so, subject to
|
25 |
+
the following conditions:
|
26 |
+
|
27 |
+
The above copyright notice and this permission notice shall be included
|
28 |
+
in all copies or substantial portions of the Software.
|
29 |
+
|
30 |
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
|
31 |
+
OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
|
32 |
+
MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
|
33 |
+
IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
|
34 |
+
CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
|
35 |
+
TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
|
36 |
+
SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
37 |
+
"""
|
38 |
+
|
39 |
+
"""Utilities for determining application-specific dirs.
|
40 |
+
|
41 |
+
See <https://github.com/ActiveState/appdirs> for details and usage.
|
42 |
+
"""
|
43 |
+
# Dev Notes:
|
44 |
+
# - MSDN on where to store app data files:
|
45 |
+
# http://support.microsoft.com/default.aspx?scid=kb;en-us;310294#XSLTH3194121123120121120120
|
46 |
+
# - Mac OS X: http://developer.apple.com/documentation/MacOSX/Conceptual/BPFileSystem/index.html
|
47 |
+
# - XDG spec for Un*x: https://standards.freedesktop.org/basedir-spec/basedir-spec-latest.html
|
48 |
+
|
49 |
+
__version__ = "1.4.4"
# Numeric tuple form of __version__, one int per dotted segment.
__version_info__ = tuple(map(int, __version__.split(".")))
|
51 |
+
|
52 |
+
|
53 |
+
import os
|
54 |
+
import sys
|
55 |
+
|
56 |
+
unicode = str
|
57 |
+
|
58 |
+
if not sys.platform.startswith("java"):
    # Regular interpreters: sys.platform is already the string we want.
    system = sys.platform
else:
    # On a "java*" platform sys.platform does not identify the host OS, so
    # derive a *sys.platform* style string from the JVM-reported OS name.
    import platform

    os_name = platform.java_ver()[3][0]
    if os_name.startswith("Windows"):  # "Windows XP", "Windows 7", etc.
        system = "win32"
    elif os_name.startswith("Mac"):  # "Mac OS X", etc.
        system = "darwin"
    else:  # "Linux", "SunOS", "FreeBSD", etc.
        # Setting this to "linux2" is not ideal, but only Windows or Mac
        # are actually checked for and the rest of the module expects
        # *sys.platform* style strings.
        system = "linux2"
|
73 |
+
|
74 |
+
|
75 |
+
def user_data_dir(appname=None, appauthor=None, version=None, roaming=False):
    r"""Return full path to the user-specific data dir for this application.

        "appname" is the name of application.
            If None, just the system directory is returned.
        "appauthor" (only used on Windows) is the name of the
            appauthor or distributing body for this application. Typically
            it is the owning company name. This falls back to appname. You may
            pass False to disable it.
        "version" is an optional version path element to append to the
            path. You might want to use this if you want multiple versions
            of your app to be able to run independently. If used, this
            would typically be "<major>.<minor>".
            Only applied when appname is present.
        "roaming" (boolean, default False) can be set True to use the Windows
            roaming appdata directory. That means that for users on a Windows
            network setup for roaming profiles, this user data will be
            sync'd on login. See
            <http://technet.microsoft.com/en-us/library/cc766489(WS.10).aspx>
            for a discussion of issues.

    Typical user data directories are:
        Mac OS X:               ~/Library/Application Support/<AppName>
        Unix:                   ~/.local/share/<AppName>    # or in $XDG_DATA_HOME, if defined
        Win XP (not roaming):   C:\Documents and Settings\<username>\Local Settings\Application Data\<AppAuthor>\<AppName>
        Win XP (roaming):       C:\Documents and Settings\<username>\Application Data\<AppAuthor>\<AppName>
        Win 7  (not roaming):   C:\Users\<username>\AppData\Local\<AppAuthor>\<AppName>
        Win 7  (roaming):       C:\Users\<username>\AppData\Roaming\<AppAuthor>\<AppName>

    For Unix, we follow the XDG spec and support $XDG_DATA_HOME.
    That means, by default "~/.local/share/<AppName>". An empty
    $XDG_DATA_HOME is treated the same as an unset one.
    """
    if system == "win32":
        if appauthor is None:
            appauthor = appname
        const = "CSIDL_APPDATA" if roaming else "CSIDL_LOCAL_APPDATA"
        path = os.path.normpath(_get_win_folder(const))
        if appname:
            if appauthor is not False:
                path = os.path.join(path, appauthor, appname)
            else:
                path = os.path.join(path, appname)
    elif system == "darwin":
        path = os.path.expanduser("~/Library/Application Support/")
        if appname:
            path = os.path.join(path, appname)
    else:
        # The XDG spec says an empty variable must be handled like an unset
        # one; using "" verbatim would yield a path relative to the cwd.
        path = os.getenv("XDG_DATA_HOME") or os.path.expanduser("~/.local/share")
        if appname:
            path = os.path.join(path, appname)
    if appname and version:
        path = os.path.join(path, version)
    return path
|
128 |
+
|
129 |
+
|
130 |
+
def site_data_dir(appname=None, appauthor=None, version=None, multipath=False):
    r"""Return full path to the user-shared data dir for this application.

        "appname" is the name of application.
            If None, just the system directory is returned.
        "appauthor" (only used on Windows) is the name of the
            appauthor or distributing body for this application. Typically
            it is the owning company name. This falls back to appname. You may
            pass False to disable it.
        "version" is an optional version path element to append to the
            path. You might want to use this if you want multiple versions
            of your app to be able to run independently. If used, this
            would typically be "<major>.<minor>".
            Only applied when appname is present.
        "multipath" is an optional parameter only applicable to *nix
            which indicates that the entire list of data dirs should be
            returned. By default, the first item from XDG_DATA_DIRS is
            returned, or '/usr/local/share/<AppName>',
            if XDG_DATA_DIRS is not set (or is empty)

    Typical site data directories are:
        Mac OS X:   /Library/Application Support/<AppName>
        Unix:       /usr/local/share/<AppName> or /usr/share/<AppName>
        Win XP:     C:\Documents and Settings\All Users\Application Data\<AppAuthor>\<AppName>
        Vista:      (Fail! "C:\ProgramData" is a hidden *system* directory on Vista.)
        Win 7:      C:\ProgramData\<AppAuthor>\<AppName>   # Hidden, but writeable on Win 7.

    For Unix, this is using the $XDG_DATA_DIRS[0] default.

    WARNING: Do not use this on Windows. See the Vista-Fail note above for why.
    """
    if system == "win32":
        if appauthor is None:
            appauthor = appname
        path = os.path.normpath(_get_win_folder("CSIDL_COMMON_APPDATA"))
        if appname:
            if appauthor is not False:
                path = os.path.join(path, appauthor, appname)
            else:
                path = os.path.join(path, appname)
    elif system == "darwin":
        path = "/Library/Application Support"
        if appname:
            path = os.path.join(path, appname)
    else:
        # XDG default for $XDG_DATA_DIRS; an empty variable counts as unset
        # per the XDG spec (otherwise "" would become a bogus base dir).
        path = os.getenv("XDG_DATA_DIRS") or os.pathsep.join(
            ["/usr/local/share", "/usr/share"]
        )
        pathlist = [
            os.path.expanduser(x.rstrip(os.sep)) for x in path.split(os.pathsep)
        ]
        if appname:
            if version:
                appname = os.path.join(appname, version)
            pathlist = [os.sep.join([x, appname]) for x in pathlist]

        # Only the first entry is returned unless multipath is requested.
        if multipath:
            path = os.pathsep.join(pathlist)
        else:
            path = pathlist[0]
        # The version was already folded into appname above.
        return path

    if appname and version:
        path = os.path.join(path, version)
    return path
|
197 |
+
|
198 |
+
|
199 |
+
def user_config_dir(appname=None, appauthor=None, version=None, roaming=False):
    r"""Return full path to the user-specific config dir for this application.

        "appname" is the name of application.
            If None, just the system directory is returned.
        "appauthor" (only used on Windows) is the name of the
            appauthor or distributing body for this application. Typically
            it is the owning company name. This falls back to appname. You may
            pass False to disable it.
        "version" is an optional version path element to append to the
            path. You might want to use this if you want multiple versions
            of your app to be able to run independently. If used, this
            would typically be "<major>.<minor>".
            Only applied when appname is present.
        "roaming" (boolean, default False) can be set True to use the Windows
            roaming appdata directory. That means that for users on a Windows
            network setup for roaming profiles, this user data will be
            sync'd on login. See
            <http://technet.microsoft.com/en-us/library/cc766489(WS.10).aspx>
            for a discussion of issues.

    Typical user config directories are:
        Mac OS X:               ~/Library/Preferences/<AppName>
        Unix:                   ~/.config/<AppName>     # or in $XDG_CONFIG_HOME, if defined
        Win *:                  same as user_data_dir

    For Unix, we follow the XDG spec and support $XDG_CONFIG_HOME.
    That means, by default "~/.config/<AppName>". An empty
    $XDG_CONFIG_HOME is treated the same as an unset one.
    """
    if system == "win32":
        # version is withheld here and appended once at the end.
        path = user_data_dir(appname, appauthor, None, roaming)
    elif system == "darwin":
        path = os.path.expanduser("~/Library/Preferences/")
        if appname:
            path = os.path.join(path, appname)
    else:
        # Empty counts as unset per the XDG spec; "" would give a relative path.
        path = os.getenv("XDG_CONFIG_HOME") or os.path.expanduser("~/.config")
        if appname:
            path = os.path.join(path, appname)
    if appname and version:
        path = os.path.join(path, version)
    return path
|
241 |
+
|
242 |
+
|
243 |
+
def site_config_dir(appname=None, appauthor=None, version=None, multipath=False):
    r"""Return full path to the user-shared config dir for this application.

        "appname" is the name of application.
            If None, just the system directory is returned.
        "appauthor" (only used on Windows) is the name of the
            appauthor or distributing body for this application. Typically
            it is the owning company name. This falls back to appname. You may
            pass False to disable it.
        "version" is an optional version path element to append to the
            path. You might want to use this if you want multiple versions
            of your app to be able to run independently. If used, this
            would typically be "<major>.<minor>".
            Only applied when appname is present.
        "multipath" is an optional parameter only applicable to *nix
            which indicates that the entire list of config dirs should be
            returned. By default, the first item from XDG_CONFIG_DIRS is
            returned, or '/etc/xdg/<AppName>', if XDG_CONFIG_DIRS is not set
            (or is empty)

    Typical site config directories are:
        Mac OS X:   /Library/Preferences/<AppName>
        Unix:       /etc/xdg/<AppName> or $XDG_CONFIG_DIRS[i]/<AppName> for each value in
                    $XDG_CONFIG_DIRS
        Win *:      same as site_data_dir
        Vista:      (Fail! "C:\ProgramData" is a hidden *system* directory on Vista.)

    For Unix, this is using the $XDG_CONFIG_DIRS[0] default, if multipath=False

    WARNING: Do not use this on Windows. See the Vista-Fail note above for why.
    """
    if system == "win32":
        path = site_data_dir(appname, appauthor)
    elif system == "darwin":
        path = "/Library/Preferences"
        if appname:
            path = os.path.join(path, appname)
    else:
        # XDG default for $XDG_CONFIG_DIRS; an empty variable counts as unset
        # per the XDG spec (otherwise "" would become a bogus base dir).
        path = os.getenv("XDG_CONFIG_DIRS") or "/etc/xdg"
        pathlist = [
            os.path.expanduser(x.rstrip(os.sep)) for x in path.split(os.pathsep)
        ]
        if appname:
            if version:
                appname = os.path.join(appname, version)
            pathlist = [os.sep.join([x, appname]) for x in pathlist]

        # Only the first entry is returned unless multipath is requested.
        if multipath:
            path = os.pathsep.join(pathlist)
        else:
            path = pathlist[0]
        # The version was already folded into appname above.
        return path

    # Shared tail for win32 and darwin: honour "version" on both, as every
    # sibling *_dir function does.
    if appname and version:
        path = os.path.join(path, version)
    return path
|
298 |
+
|
299 |
+
|
300 |
+
def user_cache_dir(appname=None, appauthor=None, version=None, opinion=True):
    r"""Return full path to the user-specific cache dir for this application.

        "appname" is the name of application.
            If None, just the system directory is returned.
        "appauthor" (only used on Windows) is the name of the
            appauthor or distributing body for this application. Typically
            it is the owning company name. This falls back to appname. You may
            pass False to disable it.
        "version" is an optional version path element to append to the
            path. You might want to use this if you want multiple versions
            of your app to be able to run independently. If used, this
            would typically be "<major>.<minor>".
            Only applied when appname is present.
        "opinion" (boolean) can be False to disable the appending of
            "Cache" to the base app data dir for Windows. See
            discussion below.

    Typical user cache directories are:
        Mac OS X:   ~/Library/Caches/<AppName>
        Unix:       ~/.cache/<AppName> (XDG default)
        Win XP:     C:\Documents and Settings\<username>\Local Settings\Application Data\<AppAuthor>\<AppName>\Cache
        Vista:      C:\Users\<username>\AppData\Local\<AppAuthor>\<AppName>\Cache

    On Windows the only suggestion in the MSDN docs is that local settings go in
    the `CSIDL_LOCAL_APPDATA` directory. This is identical to the non-roaming
    app data dir (the default returned by `user_data_dir` above). Apps typically
    put cache data somewhere *under* the given dir here. Some examples:
        ...\Mozilla\Firefox\Profiles\<ProfileName>\Cache
        ...\Acme\SuperApp\Cache\1.0
    OPINION: This function appends "Cache" to the `CSIDL_LOCAL_APPDATA` value.
    This can be disabled with the `opinion=False` option.

    For Unix, an empty $XDG_CACHE_HOME is treated the same as an unset one.
    """
    if system == "win32":
        if appauthor is None:
            appauthor = appname
        path = os.path.normpath(_get_win_folder("CSIDL_LOCAL_APPDATA"))
        if appname:
            if appauthor is not False:
                path = os.path.join(path, appauthor, appname)
            else:
                path = os.path.join(path, appname)
            # "Cache" goes before the version: ...\<AppName>\Cache\<version>
            if opinion:
                path = os.path.join(path, "Cache")
    elif system == "darwin":
        path = os.path.expanduser("~/Library/Caches")
        if appname:
            path = os.path.join(path, appname)
    else:
        # Empty counts as unset per the XDG spec; "" would give a relative path.
        path = os.getenv("XDG_CACHE_HOME") or os.path.expanduser("~/.cache")
        if appname:
            path = os.path.join(path, appname)
    if appname and version:
        path = os.path.join(path, version)
    return path
|
355 |
+
|
356 |
+
|
357 |
+
def user_state_dir(appname=None, appauthor=None, version=None, roaming=False):
    r"""Return full path to the user-specific state dir for this application.

        "appname" is the name of application.
            If None, just the system directory is returned.
        "appauthor" (only used on Windows) is the name of the
            appauthor or distributing body for this application. Typically
            it is the owning company name. This falls back to appname. You may
            pass False to disable it.
        "version" is an optional version path element to append to the
            path. You might want to use this if you want multiple versions
            of your app to be able to run independently. If used, this
            would typically be "<major>.<minor>".
            Only applied when appname is present.
        "roaming" (boolean, default False) can be set True to use the Windows
            roaming appdata directory. That means that for users on a Windows
            network setup for roaming profiles, this user data will be
            sync'd on login. See
            <http://technet.microsoft.com/en-us/library/cc766489(WS.10).aspx>
            for a discussion of issues.

    Typical user state directories are:
        Mac OS X:  same as user_data_dir
        Unix:      ~/.local/state/<AppName>   # or in $XDG_STATE_HOME, if defined
        Win *:     same as user_data_dir

    For Unix, we follow this Debian proposal <https://wiki.debian.org/XDGBaseDirectorySpecification#state>
    to extend the XDG spec and support $XDG_STATE_HOME.

    That means, by default "~/.local/state/<AppName>". An empty
    $XDG_STATE_HOME is treated the same as an unset one.
    """
    if system in ["win32", "darwin"]:
        # version is withheld here and appended once at the end.
        path = user_data_dir(appname, appauthor, None, roaming)
    else:
        # Empty counts as unset, as for the other XDG_* variables; "" would
        # give a relative path.
        path = os.getenv("XDG_STATE_HOME") or os.path.expanduser("~/.local/state")
        if appname:
            path = os.path.join(path, appname)
    if appname and version:
        path = os.path.join(path, version)
    return path
|
397 |
+
|
398 |
+
|
399 |
+
def user_log_dir(appname=None, appauthor=None, version=None, opinion=True):
    r"""Return full path to the user-specific log dir for this application.

        "appname" is the name of application.
            If None, just the system directory is returned.
        "appauthor" (only used on Windows) is the name of the
            appauthor or distributing body for this application. Typically
            it is the owning company name. This falls back to appname. You may
            pass False to disable it.
        "version" is an optional version path element to append to the
            path. You might want to use this if you want multiple versions
            of your app to be able to run independently. If used, this
            would typically be "<major>.<minor>".
            Only applied when appname is present.
        "opinion" (boolean) can be False to disable the appending of
            "Logs" to the base app data dir for Windows, and "log" to the
            base cache dir for Unix. See discussion below.

    Typical user log directories are:
        Mac OS X:   ~/Library/Logs/<AppName>
        Unix:       ~/.cache/<AppName>/log  # or under $XDG_CACHE_HOME if defined
        Win XP:     C:\Documents and Settings\<username>\Local Settings\Application Data\<AppAuthor>\<AppName>\Logs
        Vista:      C:\Users\<username>\AppData\Local\<AppAuthor>\<AppName>\Logs

    On Windows the only suggestion in the MSDN docs is that local settings
    go in the `CSIDL_LOCAL_APPDATA` directory. (Note: I'm interested in
    examples of what some windows apps use for a logs dir.)

    OPINION: This function appends "Logs" to the `CSIDL_LOCAL_APPDATA`
    value for Windows and appends "log" to the user cache dir for Unix.
    This can be disabled with the `opinion=False` option.
    """
    if system == "darwin":
        path = os.path.expanduser("~/Library/Logs")
        # appname may be None ("just the system directory"); joining None
        # would raise TypeError.
        if appname:
            path = os.path.join(path, appname)
    elif system == "win32":
        path = user_data_dir(appname, appauthor, version)
        # The version is already part of the delegated path; clear it so the
        # shared tail below does not append it a second time.
        version = False
        if opinion:
            path = os.path.join(path, "Logs")
    else:
        path = user_cache_dir(appname, appauthor, version)
        version = False  # already applied by user_cache_dir
        if opinion:
            path = os.path.join(path, "log")
    if appname and version:
        path = os.path.join(path, version)
    return path
|
446 |
+
|
447 |
+
|
448 |
+
class AppDirs(object):
    """Convenience wrapper for getting application dirs.

    The constructor stores its arguments; every property forwards them to
    the module-level function of the same name on each access.
    """

    def __init__(
        self, appname=None, appauthor=None, version=None, roaming=False, multipath=False
    ):
        self.appname = appname
        self.appauthor = appauthor
        self.version = version
        self.roaming = roaming
        self.multipath = multipath

    @property
    def user_data_dir(self):
        """Result of the module-level ``user_data_dir`` for the stored settings."""
        return user_data_dir(
            self.appname, self.appauthor, version=self.version, roaming=self.roaming
        )

    @property
    def site_data_dir(self):
        """Result of the module-level ``site_data_dir`` for the stored settings."""
        return site_data_dir(
            self.appname, self.appauthor, version=self.version, multipath=self.multipath
        )

    @property
    def user_config_dir(self):
        """Result of the module-level ``user_config_dir`` for the stored settings."""
        return user_config_dir(
            self.appname, self.appauthor, version=self.version, roaming=self.roaming
        )

    @property
    def site_config_dir(self):
        """Result of the module-level ``site_config_dir`` for the stored settings."""
        return site_config_dir(
            self.appname, self.appauthor, version=self.version, multipath=self.multipath
        )

    @property
    def user_cache_dir(self):
        """Result of the module-level ``user_cache_dir`` for the stored settings."""
        return user_cache_dir(self.appname, self.appauthor, version=self.version)

    @property
    def user_state_dir(self):
        """Result of the module-level ``user_state_dir`` for the stored settings."""
        # roaming is forwarded like in user_data_dir / user_config_dir; the
        # function accepts it and it was previously dropped here.
        return user_state_dir(
            self.appname, self.appauthor, version=self.version, roaming=self.roaming
        )

    @property
    def user_log_dir(self):
        """Result of the module-level ``user_log_dir`` for the stored settings."""
        return user_log_dir(self.appname, self.appauthor, version=self.version)
|
495 |
+
|
496 |
+
|
497 |
+
# ---- internal support stuff
|
498 |
+
|
499 |
+
|
500 |
+
def _get_win_folder_from_registry(csidl_name):
    """This is a fallback technique at best. I'm not sure if using the
    registry for this guarantees us the correct answer for all CSIDL_*
    names.

    "csidl_name" must be one of "CSIDL_APPDATA", "CSIDL_COMMON_APPDATA" or
    "CSIDL_LOCAL_APPDATA"; any other name raises KeyError. Returns the
    value stored under the matching "Shell Folders" registry entry.
    """
    import winreg as _winreg

    shell_folder_name = {
        "CSIDL_APPDATA": "AppData",
        "CSIDL_COMMON_APPDATA": "Common AppData",
        "CSIDL_LOCAL_APPDATA": "Local AppData",
    }[csidl_name]

    # Registry handles are context managers; close the key deterministically
    # instead of leaking it until garbage collection.
    with _winreg.OpenKey(
        _winreg.HKEY_CURRENT_USER,
        r"Software\Microsoft\Windows\CurrentVersion\Explorer\Shell Folders",
    ) as key:
        folder, _value_type = _winreg.QueryValueEx(key, shell_folder_name)
    return folder
|
519 |
+
|
520 |
+
|
521 |
+
def _get_win_folder_with_pywin32(csidl_name):
    """Look up the folder for ``csidl_name`` through pywin32's shell API.

    ``csidl_name`` is resolved as an attribute of ``win32com.shell.shellcon``.
    If the resulting path contains characters above code point 255 and
    ``win32api`` is importable, the short path name is returned instead.
    """
    from win32com.shell import shell, shellcon

    folder = shell.SHGetFolderPath(0, getattr(shellcon, csidl_name), 0, 0)
    # Try to make this a unicode path because SHGetFolderPath does
    # not return unicode strings when there is unicode data in the
    # path.
    try:
        folder = unicode(folder)

        # Downgrade to short path name if have highbit chars. See
        # <http://bugs.activestate.com/show_bug.cgi?id=85099>.
        if any(ord(ch) > 255 for ch in folder):
            try:
                import win32api

                folder = win32api.GetShortPathName(folder)
            except ImportError:
                pass
    except UnicodeError:
        pass
    return folder
|
548 |
+
|
549 |
+
|
550 |
+
def _get_win_folder_with_ctypes(csidl_name):
    """Look up the folder for ``csidl_name`` via ``SHGetFolderPathW`` (ctypes).

    ``csidl_name`` must be one of "CSIDL_APPDATA", "CSIDL_COMMON_APPDATA" or
    "CSIDL_LOCAL_APPDATA"; any other name raises KeyError. If the path
    contains characters above code point 255 and ``GetShortPathNameW``
    succeeds, the short path name is returned instead.
    """
    import ctypes

    csidl_const = {
        "CSIDL_APPDATA": 26,
        "CSIDL_COMMON_APPDATA": 35,
        "CSIDL_LOCAL_APPDATA": 28,
    }[csidl_name]

    buf = ctypes.create_unicode_buffer(1024)
    ctypes.windll.shell32.SHGetFolderPathW(None, csidl_const, None, 0, buf)

    # Downgrade to short path name if have highbit chars. See
    # <http://bugs.activestate.com/show_bug.cgi?id=85099>.
    if any(ord(ch) > 255 for ch in buf):
        short_buf = ctypes.create_unicode_buffer(1024)
        # Keep the long path when the short-name lookup reports failure (0).
        if ctypes.windll.kernel32.GetShortPathNameW(buf.value, short_buf, 1024):
            buf = short_buf

    return buf.value
|
575 |
+
|
576 |
+
|
577 |
+
def _get_win_folder_with_jna(csidl_name):
    """Look up the folder for ``csidl_name`` through JNA's win32 bindings.

    ``csidl_name`` is resolved as an attribute of ``win32.ShlObj``. If the
    path contains characters above code point 255 and ``GetShortPathName``
    succeeds, the short path name is returned instead.
    """
    import array

    from com.sun import jna
    from com.sun.jna.platform import win32

    buf_size = win32.WinDef.MAX_PATH * 2
    buf = array.zeros("c", buf_size)
    win32.Shell32.INSTANCE.SHGetFolderPath(
        None,
        getattr(win32.ShlObj, csidl_name),
        None,
        win32.ShlObj.SHGFP_TYPE_CURRENT,
        buf,
    )
    folder = jna.Native.toString(buf.tostring()).rstrip("\0")

    # Downgrade to short path name if have highbit chars. See
    # <http://bugs.activestate.com/show_bug.cgi?id=85099>.
    if any(ord(ch) > 255 for ch in folder):
        short_buf = array.zeros("c", buf_size)
        # Keep the long path when the short-name lookup reports failure.
        if win32.Kernel32.INSTANCE.GetShortPathName(folder, short_buf, buf_size):
            folder = jna.Native.toString(short_buf.tostring()).rstrip("\0")

    return folder
|
609 |
+
|
610 |
+
|
611 |
+
if system == "win32":
    # Pick the first available backend, in order of preference:
    # pywin32, ctypes, JNA, and finally the registry fallback.
    try:
        import win32com.shell
    except ImportError:
        try:
            from ctypes import windll
        except ImportError:
            try:
                import com.sun.jna
            except ImportError:
                _get_win_folder = _get_win_folder_from_registry
            else:
                _get_win_folder = _get_win_folder_with_jna
        else:
            _get_win_folder = _get_win_folder_with_ctypes
    else:
        _get_win_folder = _get_win_folder_with_pywin32
|
628 |
+
|
629 |
+
|
630 |
+
# ---- self test code
|
631 |
+
|
632 |
+
if __name__ == "__main__":
    appname = "MyApp"
    appauthor = "MyCompany"

    props = (
        "user_data_dir",
        "user_config_dir",
        "user_cache_dir",
        "user_state_dir",
        "user_log_dir",
        "site_data_dir",
        "site_config_dir",
    )

    print(f"-- app dirs {__version__} --")

    # Each scenario: heading to print, then the AppDirs positional and
    # keyword arguments to demonstrate.
    scenarios = (
        ("-- app dirs (with optional 'version')", (appname, appauthor), {"version": "1.0"}),
        ("\n-- app dirs (without optional 'version')", (appname, appauthor), {}),
        ("\n-- app dirs (without optional 'appauthor')", (appname,), {}),
        ("\n-- app dirs (with disabled 'appauthor')", (appname,), {"appauthor": False}),
    )
    for heading, args, kwargs in scenarios:
        print(heading)
        dirs = AppDirs(*args, **kwargs)
        for prop in props:
            print(f"{prop}: {getattr(dirs, prop)}")
|
torch/_awaits/__init__.py
ADDED
@@ -0,0 +1,54 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from __future__ import annotations
|
2 |
+
|
3 |
+
from typing import cast, Callable, Generic, Type, TypeVar
|
4 |
+
|
5 |
+
import torch
|
6 |
+
|
7 |
+
__all__ = ['Await']
|
8 |
+
|
9 |
+
W = TypeVar("W")
|
10 |
+
|
11 |
+
class _PyAwaitMeta(type(torch._C._Await), type(Generic)):  # type: ignore[misc, no-redef]
    """Metaclass combining the metaclasses of ``torch._C._Await`` and ``Generic``."""
|
13 |
+
|
14 |
+
class _Await(torch._C._Await, Generic[W], metaclass=_PyAwaitMeta):
    r"""
    Wrapper around a ``torch._C.Await`` which encapsulates delayed execution
    of a callable. All manipulations happen with functions ``torch.jit._awaitable``,
    ``torch.jit._awaitable_wait``, ``torch.jit._awaitable_nowait``.

    Torch scriptable manipulations:

    ``torch.jit._awaitable(func, *args)``
        Returns:
            ``Await[W]`` object, where W is return type of func.

    ``torch.jit._awaitable_wait(Await[W])``
        Returns the result of the function, specified at ``_awaitable``, with specified arguments.

        Returns:
            The result of type ``W`` of the function call. The result is owned by ``Await[W]``
            and returned on all following ``_awaitable_wait`` calls.

    ``torch.jit._awaitable_nowait(W)``
        Returns:
            Trivial ``Await[W]`` with specified result.

    Only in eager mode:

    ``fn() -> Callable[Tuple[Any], W]``
        Returns:
            Specified at ``_awaitable`` python function ``func``.

    ``args() -> Tuple[Any]``
        Returns:
            Specified at ``_awaitable`` python args.

    ``is_nowait() -> _bool``
        Returns:
            ``True`` if this object was created via ``_awaitable_nowait`` call (trivial ``Await[W]``).

    In eager mode ``Await[W]`` can be used as ``W`` i.e. attributes of W can be called on ``Await[W]``,
    ``_awaitable_wait()`` call will be transparently added.
    """
|
torch/_awaits/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (2.08 kB). View file
|
|
torch/_classes.py
ADDED
@@ -0,0 +1,55 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import types
|
2 |
+
|
3 |
+
import torch._C
|
4 |
+
|
5 |
+
|
6 |
+
class _ClassNamespace(types.ModuleType):
    """Module-like object representing one namespace of registered custom classes.

    Attribute access looks the class up through
    ``torch._C._get_custom_class_python_wrapper(namespace, attr)``.
    """

    def __init__(self, name):
        # Separate the namespace from the "torch.classes" prefix with a dot so
        # the resulting module name is a proper dotted path.
        super().__init__("torch.classes." + name)
        self.name = name

    def __getattr__(self, attr):
        """Return the wrapper for class ``attr`` in this namespace.

        Raises:
            RuntimeError: if the lookup returns ``None`` (class not registered).
        """
        proxy = torch._C._get_custom_class_python_wrapper(self.name, attr)
        if proxy is None:
            raise RuntimeError(f"Class {self.name}.{attr} not registered!")
        return proxy
|
16 |
+
|
17 |
+
|
18 |
+
class _Classes(types.ModuleType):
    """Module-like object whose attributes are lazily created ``_ClassNamespace`` objects."""

    __file__ = "_classes.py"

    def __init__(self):
        super().__init__("torch.classes")

    def __getattr__(self, name):
        """Create, cache and return the namespace object called ``name``.

        Raises:
            AttributeError: for dunder names (``__path__``, ``__wrapped__``, ...).
        """
        # Introspection tools (module watchers, ``inspect``, ``hasattr`` checks)
        # probe for dunder attributes. Fabricating a namespace for those makes
        # every such probe succeed and fails later with a confusing error, so
        # report them as missing instead.
        if name.startswith("__") and name.endswith("__"):
            raise AttributeError(
                f"module 'torch.classes' has no attribute '{name}'"
            )
        namespace = _ClassNamespace(name)
        # Cache on the instance so later lookups bypass __getattr__.
        setattr(self, name, namespace)
        return namespace

    @property
    def loaded_libraries(self):
        """The value of ``torch.ops.loaded_libraries``."""
        return torch.ops.loaded_libraries

    def load_library(self, path):
        """
        Loads a shared library from the given path into the current process.

        The library being loaded may run global initialization code to register
        custom classes with the PyTorch JIT runtime. This allows dynamically
        loading custom classes. For this, you should compile your class
        and the static registration code into a shared library object, and then
        call ``torch.classes.load_library('path/to/libcustom.so')`` to load the
        shared object.

        After the library is loaded, it is added to the
        ``torch.classes.loaded_libraries`` attribute, a set that may be inspected
        for the paths of all libraries loaded using this function.

        Args:
            path (str): A path to a shared library to load.
        """
        torch.ops.load_library(path)
|
52 |
+
|
53 |
+
|
54 |
+
# The classes "namespace": the module-like singleton through which custom-class namespaces are looked up.
|
55 |
+
classes = _Classes()
|
torch/_compile.py
ADDED
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
APIs related to torch.compile which lazily import torch._dynamo to avoid
|
3 |
+
circular dependencies.
|
4 |
+
"""
|
5 |
+
import functools
|
6 |
+
|
7 |
+
|
8 |
+
def _disable_dynamo(fn=None, recursive=True):
    """
    This API should be only used inside torch, external users should still use
    torch._dynamo.disable. The main goal of this API is to avoid the circular
    import issues that are common when using _dynamo.disable inside torch
    itself.

    It avoids them by deferring the import of torch._dynamo from import time
    to the first invocation of the decorated function.

    Args:
        fn: the function to wrap. When omitted, a decorator is returned, which
            supports usage such as ``@_disable_dynamo(recursive=False)``.
        recursive: forwarded unchanged to ``torch._dynamo.disable``.

    Returns:
        The wrapped function, or a ``functools.partial`` awaiting ``fn``.
    """
    if fn is None:
        # decorator usage like @_disable_dynamo(recursive=False). The resulting
        # object expects the original decorated function as the arg.
        return functools.partial(_disable_dynamo, recursive=recursive)

    # Built lazily on the first call and then reused, so the import and the
    # torch._dynamo.disable wrapping are not repeated on every invocation.
    disabled_fn = None

    @functools.wraps(fn)
    def inner(*args, **kwargs):
        nonlocal disabled_fn
        if disabled_fn is None:
            import torch._dynamo

            disabled_fn = torch._dynamo.disable(fn, recursive)
        return disabled_fn(*args, **kwargs)

    return inner
|
torch/_custom_op/__init__.py
ADDED
File without changes
|
torch/_custom_op/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (156 Bytes). View file
|
|
torch/_custom_op/__pycache__/autograd.cpython-310.pyc
ADDED
Binary file (8.86 kB). View file
|
|
torch/_custom_op/__pycache__/functional.cpython-310.pyc
ADDED
Binary file (5.95 kB). View file
|
|
torch/_custom_op/__pycache__/impl.cpython-310.pyc
ADDED
Binary file (33.5 kB). View file
|
|
torch/_custom_op/autograd.py
ADDED
@@ -0,0 +1,274 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.utils._pytree as pytree
|
3 |
+
from collections import namedtuple
|
4 |
+
import functools
|
5 |
+
|
6 |
+
|
7 |
+
# NOTE [CustomOp autograd kernel indirection]
|
8 |
+
# We register `inner` as the autograd kernel for this custom_op.
|
9 |
+
# `inner` either calls the autograd formula registered by the user,
|
10 |
+
# or goes into an `autograd_not_implemented` kernel.
|
11 |
+
#
|
12 |
+
# The reason why this indirection exists is
|
13 |
+
# so that we can swap out the autograd kernel (the PyTorch dispatcher
|
14 |
+
# doesn't actually allow us to do this). By default, we want
|
15 |
+
# the `autograd_not_implemented` behavior, but then the user may come
|
16 |
+
# and register something that is actually a backward formula
|
17 |
+
def autograd_kernel_indirection(custom_op):
    """Return the kernel registered as the autograd kernel for ``custom_op``.

    On each call the returned kernel uses the 'autograd' impl of ``custom_op``
    if one exists, raises if only one of 'backward' / 'save_for_backward' was
    registered, and otherwise runs the autograd-not-implemented fallback.
    """
    fallback_kernel = autograd_not_implemented(custom_op)

    def inner(*args, **kwargs):
        if custom_op._has_impl('autograd'):
            return custom_op._get_impl('autograd').func(*args, **kwargs)

        # Nothing autograd-related registered at all: use the fallback.
        if not (custom_op._has_impl('save_for_backward') or custom_op._has_impl('backward')):
            return fallback_kernel(*args, **kwargs)

        # As explained in NOTE ["backward", "save_for_backward", and "autograd"],
        # the "autograd" impl is generated once the user has given us both
        # "backward" and "save_for_backward". Getting here means a piece is
        # missing, so tell the user they've done something wrong.
        if custom_op._has_impl('backward'):
            found, missing = 'backward', 'save_for_backward'
        else:
            found, missing = 'save_for_backward', 'backward'
        loc = custom_op._get_impl(found).location
        raise RuntimeError(
            f"We found a '{found}' registration for {custom_op} at "
            f"{loc} but were unable to find a '{missing}' registration. "
            f"To use the CustomOp API to register a backward formula, "
            f"please provide us both a backward function and a "
            f"'save for backward' function via `impl_backward` and "
            f"`impl_save_for_backward` respectively.")

    return inner
|
44 |
+
|
45 |
+
|
46 |
+
# TODO(#101191): Use the actual C++ autograd not implemented fallback,
|
47 |
+
# or change the default autograd fallback to the autograd not implemented fallback.
|
48 |
+
def autograd_not_implemented(custom_op):
|
49 |
+
def kernel(*args, **kwargs):
|
50 |
+
if torch.is_grad_enabled() and pytree.tree_any(
|
51 |
+
lambda x: isinstance(x, torch.Tensor) and x.requires_grad, (args, kwargs)
|
52 |
+
):
|
53 |
+
raise RuntimeError("Autograd has not been implemented for operator")
|
54 |
+
with torch._C._AutoDispatchBelowAutograd():
|
55 |
+
return custom_op(*args, **kwargs)
|
56 |
+
return kernel
|
57 |
+
|
58 |
+
|
59 |
+
def mark_non_differentiable(ctx, output, output_differentiability):
    """Mark outputs flagged as non-differentiable on ``ctx``.

    Args:
        ctx: object exposing ``mark_non_differentiable(*tensors)``.
        output: the op output; a non-tuple is treated as a 1-tuple.
        output_differentiability: ``None`` (do nothing) or one flag per output.

    Raises:
        RuntimeError: if an output that is neither a Tensor nor a list is
            flagged as differentiable.
    """
    # Output types are restricted to be:
    # - Tensor
    # - Tensor[]
    # - int, bool, Scalar, float
    # See _check_can_register_backward
    if output_differentiability is None:
        return

    tuple_output = output if isinstance(output, tuple) else (output,)
    assert len(output_differentiability) == len(tuple_output)

    non_differentiable_tensors = []
    for idx, (differentiable, out) in enumerate(zip(output_differentiability, tuple_output)):
        if isinstance(out, torch.Tensor):
            if not differentiable:
                non_differentiable_tensors.append(out)
        elif isinstance(out, list):
            if not differentiable:
                non_differentiable_tensors.extend(out)
        elif differentiable:
            raise RuntimeError(
                f"With output_differentiability={output_differentiability}. "
                f"At idx {idx}, we received an object of type {type(out)} that "
                f"is not a Tensor, so it cannot be marked as differentiable in "
                f"output_differentiability.")

    if non_differentiable_tensors:
        ctx.mark_non_differentiable(*non_differentiable_tensors)
|
89 |
+
|
90 |
+
|
91 |
+
def construct_autograd_kernel(
        schema,
        output_differentiability,
        custom_op,
        op_overload,
        save_for_backward_fn,
        backward_fn):
    """Build the autograd kernel from the save-for-backward and backward functions.

    The returned ``apply(*args)`` flattens its arguments, runs them through a
    freshly generated ``torch.autograd.Function`` subclass, and unflattens the
    result back into the structure ``op_overload`` produced.
    """

    def apply(*args):
        flat_args, spec = pytree.tree_flatten(args)
        # Filled in by ``forward``; read by ``backward`` and after ``.apply``.
        out_spec = None

        def forward(ctx, *flat_inputs):
            nonlocal out_spec
            ctx.set_materialize_grads(True)
            inputs = pytree.tree_unflatten(list(flat_inputs), spec)
            with torch._C._AutoDispatchBelowAutograd():
                output = op_overload(*inputs)

            # We use the info about args to give better error messages in backward
            args_info = namedtuple_args(schema, pytree.tree_map(type, inputs))

            to_save = save_for_backward_fn(namedtuple_args(schema, inputs), output)

            save_pytree_for_backward(ctx, (to_save, args_info))
            mark_non_differentiable(ctx, output, output_differentiability)

            flat_output, out_spec = pytree.tree_flatten(output)
            return tuple(flat_output)

        def backward(ctx, *flat_grad_output):
            assert out_spec is not None
            grads = pytree.tree_unflatten(list(flat_grad_output), out_spec)
            saved, args_info = unpack_saved(ctx)
            # There is nothing on the ctx object for now, it is just there so
            # that we can add additional things in the future.
            inner_ctx = object()
            if not isinstance(grads, tuple):
                grads = (grads,)
            grad_inputs_dict = backward_fn(inner_ctx, saved, *grads)

            # Massage the grad_inputs_dict to a form acceptable by
            # autograd.Function.
            validate_grad_inputs_dict(grad_inputs_dict, custom_op, args_info)
            return grad_inputs_dict_to_flat_tuple(grad_inputs_dict, args_info)

        function_cls = gen_autograd_function(
            custom_op._opname + '_customop', forward, backward)

        flat_result = function_cls.apply(*flat_args)
        assert out_spec is not None
        return pytree.tree_unflatten(list(flat_result), out_spec)

    return apply
|
146 |
+
|
147 |
+
|
148 |
+
def gen_autograd_function(name, forward, backward):
    """Create a ``torch.autograd.Function`` subclass called ``name``.

    ``forward`` and ``backward`` are installed as static methods.
    """
    namespace = {
        'forward': staticmethod(forward),
        'backward': staticmethod(backward),
    }
    return type(name, (torch.autograd.Function,), namespace)
|
158 |
+
|
159 |
+
|
160 |
+
@functools.lru_cache
def namedtuple_args_cls(schema):
    """Return (and cache per schema) a namedtuple class with one field per schema argument."""
    field_names = [argument.name for argument in schema.arguments.flat_all]
    # mypy doesn't support dynamic namedtuple name
    return namedtuple(f"{schema.name!s}_args", field_names)  # type: ignore[misc]
|
167 |
+
|
168 |
+
|
169 |
+
def namedtuple_args(schema, args):
    """Pack the tuple ``args`` into the namedtuple class generated for ``schema``."""
    assert isinstance(args, tuple)
    return namedtuple_args_cls(schema)(*args)
|
173 |
+
|
174 |
+
|
175 |
+
def validate_grad_inputs_dict(grad_inputs_dict, forward_op, args_info):
    """Check the dict returned by a user-provided backward function.

    Args:
        grad_inputs_dict: value returned by the backward function.
        forward_op: the op; supplies the schema and the location of the
            'backward' registration used in error messages.
        args_info: namedtuple whose field for each argument holds that
            argument's type, or a list of types for a list argument.

    Raises:
        RuntimeError: if the value is not a dict, its keys are not exactly the
            tensor-like arguments, or a gradient has the wrong kind or length.
    """
    def error(what):
        backward = forward_op._get_impl('backward')
        raise RuntimeError(
            f"In the backward function defined for {forward_op} at "
            f"{backward.location} using the CustomOp API, {what}")

    if not isinstance(grad_inputs_dict, dict):
        error(f"expected the output of the backward function to be a dict but "
              f"got {type(grad_inputs_dict)}")

    expected_keys = {arg.name for arg in forward_op._schema.arguments.flat_all
                     if arg.type.is_tensor_like()}
    actual_keys = grad_inputs_dict.keys()
    if expected_keys != actual_keys:
        error(f"expected the returned grad_input dict to have keys "
              f"{expected_keys} but got {actual_keys}. The backward "
              f"function must return a gradient (can be None) for each arg "
              f"to the CustomOp that may be a Tensor or Sequence[Tensor]. "
              f"Args declared to be non-Tensor-like types should not appear "
              f"in the grad_input dict")

    for name, grad in grad_inputs_dict.items():
        arg_info = getattr(args_info, name)

        if isinstance(arg_info, list):
            # List argument: expect one gradient (or None) per element.
            if not isinstance(grad, (tuple, list)):
                error(f"for input '{name}' expected the grad_input dict to "
                      f"hold a list of gradients but got object of type "
                      f"{type(grad)}.")
            if len(grad) != len(arg_info):
                error(f"for input '{name}' expected the grad_input dict to "
                      f"hold a list of {len(arg_info)} gradients but got "
                      f"{len(grad)}")
            for idx, (g, info) in enumerate(zip(grad, arg_info)):
                if g is None:
                    continue
                if not isinstance(g, torch.Tensor):
                    error(f"for input '{name}' expected the grad_input dict to "
                          f"hold a list of None or Tensor gradients but got "
                          f"object of {type(g)} at index {idx}")
                if not issubclass(info, torch.Tensor):
                    # Report the type of the offending element, not the whole list.
                    error(f"for input '{name}', got a Tensor as the gradient "
                          f"for the {idx}-th value but expected None because "
                          f"the {idx}-th value was not a Tensor (it was "
                          f"type {info})")
            continue

        if grad is None:
            continue
        if not isinstance(grad, torch.Tensor):
            error(f"got object of type {type(grad)} as the gradient for input "
                  f"'{name}', "
                  f"but expected the gradient to be either None or a Tensor")
        if not issubclass(arg_info, torch.Tensor):
            error(f"got a Tensor as the gradient for input '{name}' but "
                  f"expected None as the gradient because input '{name}' "
                  f"was not a Tensor (it was type {arg_info}).")
|
233 |
+
|
234 |
+
|
235 |
+
def grad_inputs_dict_to_flat_tuple(grad_inputs_dict, args_info):
|
236 |
+
result = []
|
237 |
+
for name, arg_info in args_info._asdict().items():
|
238 |
+
if name not in grad_inputs_dict:
|
239 |
+
result.append(pytree.tree_map(lambda x: None, arg_info))
|
240 |
+
continue
|
241 |
+
result.append(grad_inputs_dict[name])
|
242 |
+
return tuple(pytree.tree_leaves(result))
|
243 |
+
|
244 |
+
# Saves "stuff" (a pytree) onto the ctx object. Use unpack_saved to unpack it.
|
245 |
+
# autograd.Function prefers that users use ctx.save_for_backward to
|
246 |
+
# save Tensors (to avoid reference cycles) and for non-Tensors to go onto the
|
247 |
+
# ctx object.
|
248 |
+
def save_pytree_for_backward(ctx, stuff):
    """Store the pytree ``stuff`` on ``ctx``.

    Tensor leaves go through ``ctx.save_for_backward``; all other leaves, the
    leaf positions and the tree spec are stored as plain attributes on ``ctx``.
    """
    leaves, spec = pytree.tree_flatten(stuff)

    tensor_idxs, tensors = [], []
    non_tensor_idxs, non_tensors = [], []
    for position, leaf in enumerate(leaves):
        if isinstance(leaf, torch.Tensor):
            tensor_idxs.append(position)
            tensors.append(leaf)
        else:
            non_tensor_idxs.append(position)
            non_tensors.append(leaf)

    ctx.spec = spec
    ctx.num_elts = len(leaves)
    ctx.save_for_backward(*tensors)
    ctx.tensor_idxs = tensor_idxs
    ctx.saved_non_tensors = non_tensors
    ctx.non_tensor_idxs = non_tensor_idxs
|
264 |
+
|
265 |
+
|
266 |
+
# Inverse operation to save_pytree_for_backward
|
267 |
+
def unpack_saved(ctx):
    """Rebuild the pytree previously stored on ``ctx`` by ``save_pytree_for_backward``."""
    leaves = [None] * ctx.num_elts
    for position, tensor in zip(ctx.tensor_idxs, ctx.saved_tensors):
        leaves[position] = tensor
    for position, value in zip(ctx.non_tensor_idxs, ctx.saved_non_tensors):
        leaves[position] = value
    return pytree.tree_unflatten(leaves, ctx.spec)
|
torch/_custom_op/functional.py
ADDED
@@ -0,0 +1,187 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import weakref
|
2 |
+
|
3 |
+
import torch
|
4 |
+
import torch.utils._pytree as pytree
|
5 |
+
from torch._C import _ExcludeDispatchKeyGuard, DispatchKey, DispatchKeySet
|
6 |
+
from torch._ops import OpOverload
|
7 |
+
from torch.library import Library
|
8 |
+
from torchgen.model import (
|
9 |
+
BaseTy,
|
10 |
+
BaseType,
|
11 |
+
FunctionSchema,
|
12 |
+
OperatorName,
|
13 |
+
OptionalType,
|
14 |
+
SchemaKind,
|
15 |
+
)
|
16 |
+
|
17 |
+
from .autograd import autograd_not_implemented
|
18 |
+
|
19 |
+
|
20 |
+
def register_functional_op(
    lib: Library,
    new_op_name: str,
    mutable_op: OpOverload,
) -> None:
    """Given a mutable operator, registers the functional variant.

    This API also correctly links the functional variant with the mutable
    operator for the purposes of functionalization.

    All of the new registrations are performed on the ``lib`` passed in.

    Arguments:
        lib (Library): Should be a torch.library.Library object that has
            the same namespace as ``mutable_op``'s namespace.
            lib will be used to register the new functional op as well
            as a functionalization kernel for the ``mutable_op``
            If you don't have a library handy, use
            ``torch.library.Library(ns, 'FRAGMENT')`` to construct one.
        new_op_name (str): The name of the functional operator (without the
            namespace). If no namespace, the new functional variant will be
            accessible under ``torch.ops.{lib.ns}.new_op_name``.
        mutable_op (OpOverload): The mutable custom operator. Note
            that you may need to add a `.default` to it, like
            `torch.ops.aten.abs_.default`.

    """
    validate(mutable_op)
    lib.define(functional_schema(new_op_name, mutable_op))

    lib.impl(new_op_name, construct_functional_impl(mutable_op), 'CompositeExplicitAutograd')

    op_namespace = getattr(torch.ops, lib.ns)
    functional_op = getattr(op_namespace, new_op_name).default

    # There's no easy way for us to generate the autograd kernel, so we
    # use autograd_not_implemented. Also, this makes it so that the user
    # is unable to register an autograd formula themselves. This shouldn't
    # be a problem if the user doesn't use the functional op directly
    # in their program, but we may need to revisit this in the future.
    lib.impl(new_op_name, autograd_not_implemented(functional_op), 'Autograd')

    # The kernel receives a weak proxy of the mutable op rather than the op itself.
    functionalize_kernel = construct_functionalization_kernel(
        weakref.proxy(mutable_op), functional_op)

    lib.impl(mutable_op, functionalize_kernel, 'Functionalize')
|
66 |
+
|
67 |
+
|
68 |
+
def construct_functional_impl(mutable_op):
    """Return the implementation of the functional variant of ``mutable_op``."""
    def functional_impl(*args):
        # Strategy:
        # - clone args that would have been mutated
        # - run mutable_op
        # - return the cloned args as additional outputs
        call_args = []
        cloned_outputs = []
        for is_write, arg in zip(mutable_args(mutable_op), args):
            if not is_write:
                call_args.append(arg)
                continue
            clone = None if arg is None else arg.clone()
            call_args.append(clone)
            cloned_outputs.append(clone)

        result = mutable_op(*call_args)
        if result is None:
            return tuple(cloned_outputs)
        if not isinstance(result, tuple):
            result = (result,)
        return (*result, *cloned_outputs)

    return functional_impl
|
90 |
+
|
91 |
+
|
92 |
+
def construct_functionalization_kernel(mutable_op, functional_op):
    """Return the 'Functionalize' kernel for ``mutable_op``.

    The kernel unwraps functional tensor arguments, calls ``functional_op`` on
    them, writes the returned new values back into the arguments that
    ``mutable_op`` declares as written, and returns the remaining outputs
    re-wrapped as functional tensors.

    Raises:
        RuntimeError: from the kernel, when functional and non-functional
            tensors are mixed in the arguments.
    """
    def kernel(*args):
        # There's nothing to be functionalized!
        # We can still end up here because DispatchKey::Functionalize is a mode key
        if pytree.tree_all_only(torch.Tensor, lambda x: not torch._is_functional_tensor(x), args):
            with _ExcludeDispatchKeyGuard(DispatchKeySet(DispatchKey.Functionalize)):
                return mutable_op(*args)

        # NB: This differs from the codegen -- codegen handles cases where there
        # are mixed FunctionalTensorWrapper and non-FunctionalTensorWrapper.
        # This only really matters for XLA (mixed CPU-XLA tensors) and
        # running functionalization without the PT2 stack (which guarantees to us that
        # all tensors are FunctionalTensorWrapper).
        if not pytree.tree_all_only(torch.Tensor, torch._is_functional_tensor, args):
            # f-string so the message names the op instead of printing "{mutable_op}".
            raise RuntimeError(f"{mutable_op}: expected all args to be FunctionalTensorWrapper")

        unwrapped_args = []
        for arg in args:
            if isinstance(arg, torch.Tensor) and torch._is_functional_tensor(arg):
                torch._sync(arg)
                unwrapped_args.append(torch._from_functional_tensor(arg))
            else:
                unwrapped_args.append(arg)

        with _ExcludeDispatchKeyGuard(DispatchKeySet(DispatchKey.Functionalize)):
            output = functional_op(*unwrapped_args)

        # The functional op returns the mutable op's outputs followed by one
        # new value per written argument.
        num_actual_output = len(mutable_op._schema.returns)
        actual_output = pytree.tree_map(
            torch._to_functional_tensor, output[:num_actual_output])

        new_values_to_propagate = output[num_actual_output:]
        inputs_to_replace = [arg for is_write, arg in zip(mutable_args(mutable_op), args)
                             if is_write]
        assert len(new_values_to_propagate) == len(inputs_to_replace)
        for new_value, arg in zip(new_values_to_propagate, inputs_to_replace):
            # Skip only when both are None (presumably an optional tensor that
            # was not passed). The previous condition also skipped when both
            # were present, so new values were never written back.
            if arg is None and new_value is None:
                continue
            torch._C._propagate_xla_data(arg, new_value)
            torch._C._replace_(arg, new_value)
            torch._C._commit_update(arg)
            torch._sync(arg)

        if len(actual_output) == 1:
            return actual_output[0]
        elif len(actual_output) == 0:
            return None
        return actual_output

    return kernel
|
143 |
+
|
144 |
+
|
145 |
+
def validate(mutable_op: OpOverload):
    """Check that ``mutable_op`` is supported by ``register_functional_op``.

    Raises:
        TypeError: if ``mutable_op`` is not an ``OpOverload``.
        RuntimeError: if the schema kind is not ``SchemaKind.mutable``.
        NotImplementedError: if a return carries an alias annotation, or a
            tensor-like argument is neither ``Tensor`` nor ``Tensor?``.
    """
    if not isinstance(mutable_op, OpOverload):
        raise TypeError(
            f"register_functional_op(mutable_op): expected mutable_op to be instance of "
            f"OpOverload but got {type(mutable_op)}")

    # There are generally three types of "in-place" or "mutable" ops.
    # Each of them have their own conventions:
    # - inplace (first input modified in-place and returned as only output)
    # - out= (some args modified in-place and returned as outputs)
    # - mutable (some args modified in-place but none of those returned as outputs)
    # In theory we can support all three, but we'll just support the last
    # option right now for simplicity.
    schema = FunctionSchema.parse(str(mutable_op._schema))
    if schema.kind() != SchemaKind.mutable:
        raise RuntimeError("Expected op to be mutable (as opposed to functional, inplace or out)")
    for ret in schema.returns:
        # construct_functionalization_kernel assumes this for simplicity
        if ret.annotation is not None:
            raise NotImplementedError(
                "NYI: register_functional_op(op) where op returns a mutated or aliased value. "
                "Please file an issue (and as a workaround, modify your operator to "
                "not return the mutated value or aliases)")
    for arg in schema.arguments.flat_all:
        # construct_functionalization_kernel assumes this for simplicity
        if arg.type.is_tensor_like() and (
            arg.type != BaseType(BaseTy.Tensor)
            and arg.type != OptionalType(BaseType(BaseTy.Tensor))
        ):
            # Trailing space keeps the two sentences from running together.
            raise NotImplementedError(
                "NYI: register_functional_op(op) where op has a List[Tensor] input. "
                "Please file an issue.")
|
177 |
+
|
178 |
+
|
179 |
+
def functional_schema(new_op_name, op: OpOverload):
    """Return, as a string, the signature of ``op``'s schema renamed to ``new_op_name``."""
    parsed = FunctionSchema.parse(str(op._schema))
    renamed = parsed.signature().with_name(OperatorName.parse(new_op_name))
    return str(renamed)
|
183 |
+
|
184 |
+
|
185 |
+
def mutable_args(op: OpOverload):
    """Return one flag per schema argument of ``op``.

    The flag is ``False`` when the argument has no alias info, otherwise the
    alias info's ``is_write`` value.
    """
    def _is_written(argument):
        alias = argument.alias_info
        return False if alias is None else alias.is_write

    return tuple(map(_is_written, op._schema.arguments))
|
torch/_custom_op/impl.py
ADDED
@@ -0,0 +1,976 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import dataclasses
|
2 |
+
import functools
|
3 |
+
import inspect
|
4 |
+
import sys
|
5 |
+
import typing
|
6 |
+
import weakref
|
7 |
+
|
8 |
+
from torchgen.model import FunctionSchema, OperatorName, SchemaKind, BaseType, ListType, BaseTy
|
9 |
+
|
10 |
+
import torch
|
11 |
+
import torch._C as _C
|
12 |
+
import torch.library as library
|
13 |
+
from torch._library.abstract_impl import AbstractImplCtx
|
14 |
+
from torch.library import get_ctx
|
15 |
+
|
16 |
+
from .autograd import autograd_kernel_indirection, construct_autograd_kernel
|
17 |
+
|
18 |
+
"""
|
19 |
+
For a detailed guide on custom ops, please see
|
20 |
+
https://docs.google.com/document/d/1aGWtgxV3HppuxQAdddyPrs74_aEntpkYt9MalnCKnhk
|
21 |
+
|
22 |
+
This file includes pieces of the implementation of our custom operator API.
|
23 |
+
"""
|
24 |
+
|
25 |
+
# Names exported by `from <this module> import *`.
__all__ = ["custom_op", "CustomOp", "get_ctx", "AbstractImplCtx"]


# Maps a device-type string (as accepted by CustomOp.impl) to the name of the
# dispatch key that the implementation gets registered under.
SUPPORTED_DEVICE_TYPE_TO_KEY = {
    "cpu": "CPU",
    "cuda": "CUDA",
}

# We will not let users register CustomOps with anything that could look like
# PyTorch internals to avoid confusion.
RESERVED_NS = {
    "prim",
    "prims",
    "aten",
    "at",
    "torch",
    "pytorch",
}
|
43 |
+
|
44 |
+
|
45 |
+
def custom_op(
    qualname: str, manual_schema: typing.Optional[str] = None
) -> typing.Callable:
    r"""Return a decorator that defines a new operator and wraps it in a CustomOp.

    WARNING: if you're a user, please do not use this directly
    (instead use the torch._custom_ops APIs).
    A detailed guide on custom ops is available at
    https://docs.google.com/document/d/1aGWtgxV3HppuxQAdddyPrs74_aEntpkYt9MalnCKnhk

    Defining an op (short for "operator") in PyTorch takes two steps:
    - the op itself has to be defined (created);
    - behavior has to be implemented for how the op interacts with the
      various PyTorch subsystems (CPU/CUDA Tensors, Autograd, etc.).

    This entrypoint performs the first step and hands back a CustomOp
    object; the second step is done by calling methods on that object.

    This API is used as a decorator (see examples).

    Arguments:
        qualname (str): A string of the form "namespace::operator_name".
            The namespace avoids name collisions; a given operator may only
            be created once. For a Python library, the recommended namespace
            is the name of the top-level module. The operator_name must
            equal the name of the decorated function (see examples).
        manual_schema (Optional[str]): The operator schema, describing the
            types of the inputs/outputs. If None (default), the schema is
            inferred from the type annotations of the decorated function.
            Otherwise the given schema string is used and the decorated
            function is checked against it.

    Raises:
        ValueError: if the decorated object is not a Python function, or if
            its name differs from the operator name in ``qualname``.
            Further errors may come from the validation helpers called here.

    Example::
        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA)
        >>> import numpy as np
        >>> from torch import Tensor
        >>>
        >>> # Step 1: define the CustomOp.
        >>> # We need to provide the decorator a "prototype function"
        >>> # (a function with Python ellipses as the body).
        >>> @custom_op("my_library::numpy_sin")
        >>> def numpy_sin(x: Tensor) -> Tensor:
        >>>     ...
        >>>
        >>> # numpy_sin is now an instance of class CustomOp
        >>> print(type(numpy_sin))
        >>>
        >>> # Step 2: Register an implementation for various PyTorch subsystems
        >>>
        >>> # Register an implementation for CPU tensors
        >>> @numpy_sin.impl('cpu')
        >>> def numpy_sin_impl_cpu(x):
        >>>     return torch.from_numpy(np.sin(x.numpy()))
        >>>
        >>> # Register an implementation for CUDA tensors
        >>> @numpy_sin.impl('cuda')
        >>> def numpy_sin_impl_cuda(x):
        >>>     return torch.from_numpy(np.sin(x.cpu().numpy())).to(x.device)
        >>>
        >>> x = torch.randn(3)
        >>> numpy_sin(x)  # calls numpy_sin_impl_cpu
        >>>
        >>> x_cuda = x.cuda()
        >>> numpy_sin(x_cuda)  # calls numpy_sin_impl_cuda

    """

    def decorator(func):
        if not inspect.isfunction(func):
            raise ValueError(
                f"custom_op(...)(func): Expected `func` to be a Python "
                f"function, got: {type(func)}"
            )

        namespace, opname = parse_qualname(qualname)
        validate_namespace(namespace)
        if opname != func.__name__:
            raise ValueError(
                f"custom_op(qualname='{qualname}', ...)(func): expected `func` "
                f"to have name '{opname}' but got '{func.__name__}'. "
                f"Please either change the name of `func` or the qualname that "
                f"is passed to `custom_op`"
            )

        # Either infer the signature part of the schema from the annotations
        # or take the caller-provided one verbatim.
        schema_is_inferred = manual_schema is None
        if schema_is_inferred:
            signature_str = infer_schema(func)
        else:
            signature_str = manual_schema
        schema_str = f"{opname}{signature_str}"
        parsed_schema = FunctionSchema.parse(schema_str)
        validate_schema(parsed_schema)
        if not schema_is_inferred:
            validate_function_matches_schema(parsed_schema, func)

        # Define the operator and look up its dispatcher handle.
        lib = library.Library(namespace, "FRAGMENT")
        lib.define(schema_str)
        ophandle = find_ophandle_or_throw(namespace, parsed_schema.name)
        op = CustomOp(
            lib, namespace, parsed_schema, opname, ophandle, _private_access=True
        )

        # Make the CustomOp present itself like the prototype function.
        for attr in ("__name__", "__module__", "__doc__"):
            setattr(op, attr, getattr(func, attr))

        # Only weak proxies are captured by the registered callbacks.
        register_autograd = library.impl(lib, op._opname, "Autograd")
        register_autograd(autograd_kernel_indirection(weakref.proxy(op)))

        error_callback = functools.partial(report_error_callback, weakref.proxy(op))
        torch._C._dispatch_set_report_error_callback(ophandle, error_callback)

        return op

    return decorator
|
159 |
+
|
160 |
+
|
161 |
+
# Global dictionary holding references to all CustomOp objects
# Yes, it keeps all CustomOps alive (see NOTE [CustomOp lifetime])
# Used to query the CustomOp associated with a specific C++ dispatcher operator.
# An example usage is FakeTensor: FakeTensor checks if a specific operator
# has an implementation registered via the CustomOp API.
# Indexed by qualname (e.g. aten::foo)
# Entries are inserted by CustomOp.__init__ and deleted by CustomOp._destroy.
global_registry: typing.Dict[str, "CustomOp"] = {}
|
168 |
+
|
169 |
+
|
170 |
+
class CustomOp:
    r"""Class for custom operators in PyTorch.

    Use the CustomOp API to create user-defined custom operators that behave
    just like regular PyTorch operators (e.g. torch.sin, torch.mm) when it
    comes to various PyTorch subsystems (like torch.compile).

    To construct a `CustomOp`, use `custom_op`.
    """

    def __init__(self, lib, cpp_ns, schema, operator_name, ophandle, *, _private_access=False):
        """Store the operator's identity and add it to ``global_registry``.

        The constructor is private: it raises RuntimeError unless
        ``_private_access=True`` is passed.
        """
        super().__init__()
        if not _private_access:
            raise RuntimeError(
                "The CustomOp constructor is private and we do not guarantee "
                "BC for it. Please use custom_op(...) to create a CustomOp object"
            )
        name = f"{cpp_ns}::{operator_name}"
        self._schema = schema
        self._cpp_ns = cpp_ns
        self._lib: library.Library = lib
        self._ophandle: _C._DispatchOperatorHandle = ophandle
        # Has the name of the op, e.g. "foo". We cache here for convenience.
        self._opname: str = operator_name
        # this is _opname but with namespace. e.g. "custom::foo"
        self._qualname: str = name
        self.__name__ = None  # mypy requires this
        # NB: Some of these impls are registered as kernels to DispatchKeys.
        # Modifying the _impls dict directly won't do anything in that case.
        self._impls: typing.Dict[str, typing.Optional[FuncAndLocation]] = {}
        # See NOTE [CustomOp autograd kernel indirection]
        self._registered_autograd_kernel_indirection = False
        # Set by impl_backward; initialized here so the attribute always exists.
        self._output_differentiability = None

        global_registry[self._qualname] = self

    def _register_autograd_kernel_indirection(self):
        """Register the autograd indirection kernel under the "Autograd" key (once)."""
        assert not self._registered_autograd_kernel_indirection
        self._lib.impl(self._opname, autograd_kernel_indirection(weakref.proxy(self)), "Autograd")
        self._registered_autograd_kernel_indirection = True

    # Records the impl and the source location in self._impls
    # Note that this doesn't cause torch.library to use the impl, that
    # needs to be done in a separate self._lib.impl call.
    def _register_impl(self, kind, func, stacklevel=2):
        """Record ``func`` as the ``kind`` impl together with the caller's location.

        ``stacklevel`` selects which frame (relative to this one) is recorded
        as the registration site. Raises RuntimeError if a ``kind`` impl has
        already been recorded.
        """
        if self._has_impl(kind):
            func_and_location = self._impls[kind]
            assert func_and_location is not None  # Pacify mypy
            location = func_and_location.location
            raise RuntimeError(
                f"Attempting to register a {kind} impl for operator {self._qualname} "
                f"that already has a {kind} impl registered from Python at "
                f"{location}. This is not supported."
            )
        frame = inspect.getframeinfo(sys._getframe(stacklevel))
        location = f"{frame.filename}:{frame.lineno}"
        self._impls[kind] = FuncAndLocation(func, location)

    def _get_impl(self, kind):
        """Return the recorded entry for ``kind`` (KeyError if absent)."""
        return self._impls[kind]

    def _has_impl(self, kind):
        """Return whether an impl of the given ``kind`` has been recorded."""
        return kind in self._impls

    def _destroy(self):
        """Tear down this CustomOp; the object is unusable afterwards."""
        # NOTE: [CustomOp lifetime]
        # A CustomOp, once created, lives forever. The mechanism is that the
        # global registry holds a reference to it. However, to make testing
        # easier, we want to be able to destroy CustomOp objects.
        # CustomOp._destroy does the job, though it leaves the CustomOp
        # in a garbage state.
        del self._lib

        opnamespace = getattr(torch.ops, self._cpp_ns)
        if hasattr(opnamespace, self._opname):
            delattr(opnamespace, self._opname)

        del global_registry[self._qualname]

    def __repr__(self):
        return f'<CustomOp(op="{self._qualname}")>'

    def __call__(self, *args, **kwargs):
        """Invoke the operator through its dispatcher handle."""
        # Bypass torch.ops.* and directly do OperatorHandle::callBoxed.
        # Using torch.ops.* is a bit of a pain (it can be slow and it has lifetime
        # issues from caching operators that make testing CustomOp difficult).
        result = _C._dispatch_call_boxed(self._ophandle, *args, **kwargs)
        return result

    def impl(
        self, device_types: typing.Union[str, typing.Iterable[str]], _stacklevel=2,
    ) -> typing.Callable:
        r"""Register an implementation for a device type for this CustomOp object.

        WARNING: if you're a user, please do not use this directly
        (instead use the torch._custom_ops APIs).
        Also please see the following for a detailed guide on custom ops.
        https://docs.google.com/document/d/1aGWtgxV3HppuxQAdddyPrs74_aEntpkYt9MalnCKnhk

        If the CustomOp is passed multiple Tensor inputs with different device
        types, it will dispatch to the registered implementation for the highest
        priority device type among those present.
        The supported device types, in order of priority, are {'cuda', 'cpu'}.

        This API is used as a decorator (see examples).

        Arguments:
            device_types (str or Iterable[str]): the device type(s) to register the function for.

        Returns:
            A decorator that registers the function and returns it unchanged.

        Examples::
            >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA)
            >>> import numpy as np
            >>> from torch import Tensor
            >>>
            >>> @custom_op("my_library::numpy_cos")
            >>> def numpy_cos(x: Tensor) -> Tensor:
            >>>     ...
            >>>
            >>> # Register an implementation for CPU Tensors
            >>> @numpy_cos.impl('cpu')
            >>> def numpy_cos_impl_cpu(x):
            >>>     return torch.from_numpy(np.cos(x.numpy()))
            >>>
            >>> # Register an implementation for CUDA Tensors
            >>> @numpy_cos.impl('cuda')
            >>> def numpy_cos_impl_cuda(x):
            >>>     return torch.from_numpy(np.cos(x.cpu().numpy())).to(x.device)
            >>>
            >>> x = torch.randn(3)
            >>> numpy_cos(x)  # calls numpy_cos_impl_cpu
            >>>
            >>> x_cuda = x.cuda()
            >>> numpy_cos(x_cuda)  # calls numpy_cos_impl_cuda

        """
        if isinstance(device_types, str):
            device_types = [device_types]
        else:
            # Materialize once: the values are iterated here for validation
            # and again inside the decorator, so a one-shot iterable
            # (e.g. a generator) would otherwise register nothing.
            device_types = list(device_types)
        for device_type in device_types:
            validate_device_type(device_type)

        def inner(f):
            # set(...) de-duplicates so the same key is not registered twice.
            for device_type in set(device_types):
                self._check_doesnt_have_library_impl(device_type)
                self._register_impl(device_type, f, stacklevel=_stacklevel)
                dispatch_key = SUPPORTED_DEVICE_TYPE_TO_KEY[device_type]
                library.impl(self._lib, self._opname, dispatch_key)(f)
            return f

        return inner

    def _check_doesnt_have_library_impl(self, device_type):
        """Raise if a kernel for ``device_type`` already exists outside this API."""
        if self._has_impl(device_type):
            # Already recorded here; _register_impl reports the duplicate.
            return
        key = SUPPORTED_DEVICE_TYPE_TO_KEY[device_type]
        if _C._dispatch_has_computed_kernel_for_dispatch_key(self._qualname, key):
            raise RuntimeError(
                f"impl(..., device_types={device_type}): the operator {self._qualname} "
                f"already has an implementation for this device type via a "
                f"pre-existing torch.library or TORCH_LIBRARY registration.")

    def impl_factory(self) -> typing.Callable:
        r"""Register an implementation for a factory function."""

        def inner(f):
            self._register_impl("factory", f)
            library.impl(self._lib, self._opname, "BackendSelect")(f)
            return f

        return inner

    def impl_abstract(self, _stacklevel=2) -> typing.Callable:
        r"""Register an abstract implementation for this operator.

        WARNING: please do not use this directly (and instead use the torch._custom_ops
        APIs). Also please see the following for a detailed guide on custom ops.
        https://docs.google.com/document/d/1aGWtgxV3HppuxQAdddyPrs74_aEntpkYt9MalnCKnhk

        An "abstract implementation" specifies the behavior of this operator on
        Tensors that carry no data. Given some input Tensors with certain properties
        (sizes/strides/storage_offset/device), it specifies what the properties of
        the output Tensors are.

        The abstract implementation has the same signature as the operator.
        It is run for both FakeTensors and meta tensors. To write an abstract
        implementation, assume that all Tensor inputs to the operator are
        regular CPU/CUDA/Meta tensors, but they do not have storage, and
        you are trying to return regular CPU/CUDA/Meta tensor(s) as output.
        The abstract implementation must consist of only PyTorch operations
        (and may not directly access the storage or data of any input or
        intermediate Tensors).

        This API is used as a decorator (see examples).

        Examples::
            >>> import numpy as np
            >>> from torch import Tensor
            >>>
            >>> # Example 1: an operator without data-dependent output shape
            >>> @custom_op('my_library::custom_linear')
            >>> def custom_linear(x: Tensor, weight: Tensor, bias: Tensor) -> Tensor:
            >>>     ...
            >>>
            >>> @custom_linear.impl_abstract()
            >>> def custom_linear_abstract(x, weight, bias):
            >>>     assert x.dim() == 2
            >>>     assert weight.dim() == 2
            >>>     assert bias.dim() == 1
            >>>     assert x.shape[1] == weight.shape[1]
            >>>     assert weight.shape[0] == bias.shape[0]
            >>>     assert x.device == weight.device
            >>>
            >>>     return (x @ weight.t()) + bias
            >>>
            >>> # Example 2: an operator with data-dependent output shape
            >>> @custom_op('my_library::custom_nonzero')
            >>> def custom_nonzero(x: Tensor) -> Tensor:
            >>>     ...
            >>>
            >>> @custom_nonzero.impl_abstract()
            >>> def custom_nonzero_abstract(x):
            >>>     # Number of nonzero-elements is data-dependent.
            >>>     # Since we cannot peek at the data in an abstract impl,
            >>>     # we use the ctx object to construct a new symint that
            >>>     # represents the data-dependent size.
            >>>     ctx = torch._custom_op.get_ctx()
            >>>     nnz = ctx.create_unbacked_symint()
            >>>     shape = [x.dim(), nnz]
            >>>     result = x.new_empty(shape, dtype=torch.long)
            >>>     return result
            >>>
            >>> @custom_nonzero.impl(['cpu', 'cuda'])
            >>> def custom_nonzero_impl(x):
            >>>     x_np = to_numpy(x)
            >>>     res = np.stack(np.nonzero(x_np), axis=1)
            >>>     # unbacked symbolic ints in PyTorch must be >= 2, so we
            >>>     # constrain the range to at least 2
            >>>     if res.shape[0] <= 1:
            >>>         raise RuntimeError("not supported")
            >>>     return torch.tensor(res, device=x.device)

        """

        def inner(f):
            self._check_doesnt_have_library_meta_impl()
            self._register_impl("abstract", f, stacklevel=_stacklevel)
            location = self._get_impl("abstract").location

            qualname = self._qualname

            # Handle DispatchKey.Meta registration
            @functools.wraps(f)
            def f_with_ctx(*args, **kwargs):
                # While f runs as the Meta kernel, asking for the ctx raises.
                def error_on_ctx():
                    raise RuntimeError(
                        f"Attempted to call get_ctx() for the meta implementation "
                        f"for {qualname}. "
                        f"You have presumably called get_ctx() because the operator "
                        f"has a data-dependent output shape; if so, there is no "
                        f"such meta implementation and this error is the correct "
                        f"behavior. Otherwise, please remove the call to get_ctx() "
                        f"in the implementation registered with impl_abstract "
                        f"at {location}"
                    )

                with torch._library.abstract_impl.set_ctx_getter(error_on_ctx):
                    return f(*args, **kwargs)

            self._lib.impl(self._opname, f_with_ctx, "Meta")
            return f

        return inner

    def _check_can_register_backward(self):
        """Raise RuntimeError if the schema does not permit a backward formula."""
        def error(detail):
            raise RuntimeError(
                f"Cannot use torch._custom_ops APIs to register backward "
                f"formula for {detail}. Got operator "
                f"{self._qualname} with schema: {schema}"
            )

        schema = self._schema
        if schema.kind() != SchemaKind.functional:
            error("non-functional operator")

        rets = schema.returns
        if not schema.returns:
            error("operator with no returns")

        assert len(rets) > 0
        is_non_mutating_view = any(
            r.annotation is not None and not r.annotation.is_write for r in rets
        )
        if is_non_mutating_view:
            error("operator that returns views")

        # We make assumptions about the schema's return types.
        allowed_return_types = {
            BaseType(BaseTy.int): "int",
            BaseType(BaseTy.SymInt): "SymInt",
            BaseType(BaseTy.bool): "bool",
            BaseType(BaseTy.float): "float",
            BaseType(BaseTy.Tensor): "Tensor",
            ListType(BaseType(BaseTy.Tensor), None): "List[Tensor]",
        }
        for ret in schema.returns:
            if ret.type in allowed_return_types:
                continue
            error(f"operator with return not in {list(allowed_return_types.values())} (got {ret.type})")

    def _check_doesnt_have_library_autograd_impl(self):
        """Raise if an autograd kernel was registered outside of this API."""
        if self._registered_autograd_kernel_indirection:
            return

        if _C._dispatch_has_kernel_for_dispatch_key(self._qualname, "CompositeImplicitAutograd"):
            raise RuntimeError(
                f"impl_backward/impl_save_for_backward: the operator {self._qualname} "
                f"already has an implementation for this device type via a "
                f"pre-existing registration to DispatchKey::CompositeImplicitAutograd. "
                f"CompositeImplicitAutograd operators do not need an autograd formula; "
                f"instead, the operator will decompose into its constituents and those "
                f"can have autograd formulas defined on them.")

        # We can improve this by adding "all Autograd<BACKEND> keys", but
        # realistically people will just be using this API for CPU/CUDA for now.
        for key in ["Autograd", "AutogradCPU", "AutogradCUDA"]:
            if _C._dispatch_has_kernel_for_dispatch_key(self._qualname, key):
                raise RuntimeError(
                    f"impl_backward/impl_save_for_backward: "
                    f"the operator {self._qualname} already has an Autograd kernel "
                    f"registered to DispatchKey::{key} via a pre-existing "
                    f"torch.library or TORCH_LIBRARY registration. Please either "
                    f"remove those registrations or don't use the torch._custom_ops APIs")

    def _check_doesnt_have_library_meta_impl(self):
        """Raise if an abstract impl must not be added on top of existing kernels."""
        if self._has_impl("abstract"):
            return

        # If the user's operator is CompositeExplicitAutograd,
        # allow them to impl_abstract. This is being pragmatic
        # (existing custom ops may have CompositeExplicitAutograd
        # registration that don't work with Meta kernels, so this
        # gives them an escape hatch).
        if (
            _C._dispatch_has_kernel_for_dispatch_key(self._qualname, "CompositeExplicitAutograd")
            and not _C._dispatch_has_kernel_for_dispatch_key(self._qualname, "Meta")
        ):
            return

        # Otherwise, if the user's already has a Meta kernel or their
        # op is CompositeImplicitAutograd or some other alias dispatch key,
        # raise.

        # Special case for CompositeImplicitAutograd
        if _C._dispatch_has_kernel_for_dispatch_key(self._qualname, "CompositeImplicitAutograd"):
            raise RuntimeError(
                f"impl_abstract(...): the operator {self._qualname} "
                f"already has an implementation for this device type via a "
                f"pre-existing registration to DispatchKey::CompositeImplicitAutograd. "
                f"CompositeImplicitAutograd operators do not need an abstract impl; "
                f"instead, the operator will decompose into its constituents and those "
                f"can have abstract impls defined on them.")

        if _C._dispatch_has_kernel_for_dispatch_key(self._qualname, "Meta"):
            raise RuntimeError(
                f"impl_abstract(...): the operator {self._qualname} "
                f"already has an DispatchKey::Meta implementation via a "
                f"pre-existing torch.library or TORCH_LIBRARY registration. "
                f"Please either remove that registration or don't call impl_abstract.")

    # NOTE ["backward", "save_for_backward", and "autograd"]
    # As a part of the explicit autograd API, a user must provide us
    # a "save_for_backward" function and a "backward" function.
    # When both of these have been provided, then we automatically
    # construct the "autograd" kernel.
    def _register_autograd_kernel(self):
        """Build the "autograd" impl from the recorded backward/save_for_backward."""
        assert self._has_impl("backward")
        assert self._has_impl("save_for_backward")
        kernel = construct_autograd_kernel(
            self._schema,
            self._output_differentiability,
            self,
            get_op(self._qualname),
            self._get_impl("save_for_backward").func,
            self._get_impl("backward").func)
        self._register_impl("autograd", kernel)

    def impl_save_for_backward(self, _stacklevel=2):
        r"""Register a function that tells us what to save for backward.

        Please see impl_backward for more details.

        Returns:
            A decorator that registers the function and returns it unchanged.
        """
        def inner(f):
            self._check_can_register_backward()
            self._check_doesnt_have_library_autograd_impl()
            if not self._registered_autograd_kernel_indirection:
                self._register_autograd_kernel_indirection()
            self._register_impl("save_for_backward", f, stacklevel=_stacklevel)
            if self._has_impl("backward"):
                self._register_autograd_kernel()
            # Return f so the decorated name is not rebound to None
            # (consistent with impl, impl_factory and impl_abstract).
            return f
        return inner

    def impl_backward(self, output_differentiability=None, _stacklevel=2):
        r"""Registers a backward formula.

        WARNING: if you're a user, please do not use this directly
        (instead use the torch._custom_ops APIs).
        Also please see the following for a detailed guide on custom ops.
        https://docs.google.com/document/d/1aGWtgxV3HppuxQAdddyPrs74_aEntpkYt9MalnCKnhk

        In order for the CustomOp to work with autograd, you need to register
        a backward formula. There are two pieces to this:
        1. You must give us a function to specify what to save for backward.
           Call this the "save for backward" function.
        2. You must give us a function that computes gradients. Call this the
           "backward" function.

        Use `impl_save_for_backward` to define a "save for backward" function
        that specifies what gets saved for backward. The function should accept
        two arguments ``(inputs, output)`` and return the quantities to be saved
        for backward.

        During runtime, when you call the CustomOp, PyTorch will invoke the
        "save for backward" function with the inputs and output of the CustomOp.

        Use `impl_backward` to define the "backward" function. The backward
        function must accept ``(ctx, saved, *grads)``:
        - ``ctx`` is a context object where we may provide information
        - ``saved`` is exactly what gets returned from the "save for backward"
          function
        - ``grads`` is one or more gradients. The number of gradients matches
          the number of outputs of the CustomOp.

        The backward function must return a dict that maps the name of
        an input to the CustomOp to its corresponding gradient. All inputs that
        were declared to be Tensors in the CustomOp definition must be accounted
        for in the dict. The gradient may be a Tensor or None.

        Arguments:
            output_differentiability (Optional[List[bool]]): if given, must be
                a list of bools with one entry per output of this CustomOp;
                otherwise a RuntimeError is raised.

        Returns:
            A decorator that registers the function and returns it unchanged.

        """
        if output_differentiability is not None:
            def yell():
                raise RuntimeError(
                    f"impl_backward(output_differentiability): expected "
                    f"output_differentiability to be a list of bools with "
                    f"length equal to the number of outputs of this CustomOp "
                    f"got: {output_differentiability}")

            if not isinstance(output_differentiability, list):
                yell()
            for diff in output_differentiability:
                if not isinstance(diff, bool):
                    yell()
            if len(self._schema.returns) != len(output_differentiability):
                yell()

        def inner(f):
            self._check_can_register_backward()
            self._check_doesnt_have_library_autograd_impl()
            if not self._registered_autograd_kernel_indirection:
                self._register_autograd_kernel_indirection()
            self._register_impl("backward", f, stacklevel=_stacklevel)
            self._output_differentiability = output_differentiability
            if self._has_impl("save_for_backward"):
                self._register_autograd_kernel()
            # Return f so the decorated name is not rebound to None
            # (consistent with impl, impl_factory and impl_abstract).
            return f
        return inner
|
633 |
+
|
634 |
+
|
635 |
+
@dataclasses.dataclass
class FuncAndLocation:
    """Pairs a registered function with a string describing where it was registered."""

    # The registered callable.
    func: typing.Callable
    # Registration site as text; callers in this file build it as "filename:lineno".
    location: str
|
639 |
+
|
640 |
+
|
641 |
+
def find_ophandle_or_throw(cpp_ns: str, operator_name: OperatorName):
    """Look up the dispatcher handle for ``cpp_ns::<operator_name>``.

    A missing overload name (``None``) is passed to the dispatcher as the
    empty string. The lookup itself is delegated to
    ``_C._dispatch_find_schema_or_throw``.
    """
    overload_name = operator_name.overload_name
    if overload_name is None:
        overload_name = ""
    qualified_name = f"{cpp_ns}::{str(operator_name.name)}"
    return _C._dispatch_find_schema_or_throw(qualified_name, overload_name)
|
648 |
+
|
649 |
+
|
650 |
+
def validate_namespace(ns: str) -> None:
    """Reject namespaces that contain a dot or that are reserved.

    Raises:
        ValueError: if ``ns`` contains ``.`` or is a member of RESERVED_NS.
    """
    if "." in ns:
        dotted_msg = (
            f'custom_op(..., ns="{ns}"): expected ns to not contain any . (and be a '
            f"valid variable name)"
        )
        raise ValueError(dotted_msg)

    if ns not in RESERVED_NS:
        return

    reserved_msg = (
        f"custom_op(..., ns='{ns}'): '{ns}' is a reserved namespace, "
        f"please choose something else. "
    )
    raise ValueError(reserved_msg)
|
661 |
+
|
662 |
+
def validate_schema(schema: FunctionSchema) -> None:
    """Check that ``schema`` is acceptable for ``custom_op``.

    Raises:
        ValueError: if ``torch._library.utils.is_functional_schema`` rejects
            the schema, or if the schema has a ``self`` argument.
    """
    is_functional = torch._library.utils.is_functional_schema(schema)
    if not is_functional:
        raise ValueError(
            f"custom_op only supports functional operators "
            f"(ops that do not mutate any inputs, do not return "
            f"views of the inputs, and has at least one return). "
            f"Got the following non-functional schema: {schema}"
        )

    # For simplicity: don't allow self arguments
    if schema.arguments.self_arg is None:
        return
    raise ValueError(
        f"custom_op does not support arguments named 'self'. Please "
        f"rename your argument. Got: {schema}"
    )
|
677 |
+
|
678 |
+
|
679 |
+
def parse_qualname(qualname: str) -> typing.Tuple[str, str]:
    """Split ``"ns::name"`` into ``(ns, name)`` at the first ``::``.

    Raises:
        ValueError: if there is no ``::`` separator, or if the operator name
            part contains a ``.`` (overloads are not handled).
    """
    namespace, separator, opname = qualname.partition("::")
    if not separator:
        raise ValueError(f"Expected there to be a namespace in {qualname}, i.e. The "
                         f"operator name should look something like ns::foo")
    if opname.find(".") >= 0:
        raise ValueError(f"The torch.custom_ops APIs do not handle overloads, "
                         f"i.e. operator names with '.' in them. "
                         f"Please name your operator something like ns::foo. "
                         f"Got: {qualname}")
    return namespace, opname
|
690 |
+
|
691 |
+
|
692 |
+
def validate_device_type(device_type: str) -> None:
    """Raise ValueError unless ``device_type`` is a key of SUPPORTED_DEVICE_TYPE_TO_KEY."""
    if device_type in SUPPORTED_DEVICE_TYPE_TO_KEY:
        return
    supported = SUPPORTED_DEVICE_TYPE_TO_KEY.keys()
    raise ValueError(
        f"CustomOp.impl(device_types=[{device_type}, ...]): we only support device_type "
        f"in {supported}."
    )
|
698 |
+
|
699 |
+
|
700 |
+
def supported_param(param: inspect.Parameter) -> bool:
    """Return True for positional-or-keyword and keyword-only parameters.

    Positional-only parameters, ``*args`` and ``**kwargs`` yield False.
    """
    kind = param.kind
    if kind == inspect.Parameter.POSITIONAL_OR_KEYWORD:
        return True
    return kind == inspect.Parameter.KEYWORD_ONLY
|
705 |
+
|
706 |
+
|
707 |
+
def validate_function_matches_schema(
    schema: FunctionSchema, func: typing.Callable
) -> None:
    """Check that ``func``'s signature agrees with a manually supplied schema.

    ``func`` must use only positional-or-keyword and keyword-only
    parameters, carry no type annotations, and have the same parameter
    names, in the same order and grouping, as ``schema``. Neither side may
    declare default values.

    Raises:
        ValueError: on any of the violations above.
    """
    sig = inspect.signature(func)
    params = list(sig.parameters.values())

    for p in params:
        if not supported_param(p):
            raise ValueError(
                f"custom_op(..., manual_schema)(func): positional-only args, "
                f"varargs, and kwargs are not supported. Please rewrite `func` "
                f"to not have them. Got `func` with signature: {sig}"
            )

    return_is_annotated = sig.return_annotation is not inspect.Signature.empty
    params_are_annotated = any(
        p.annotation is not inspect.Parameter.empty for p in params
    )
    if params_are_annotated or return_is_annotated:
        raise ValueError(
            f"custom_op(..., manual_schema)(func): When passing in a manual "
            f"schema, we expect `func` to have no type annotations to avoid "
            f"ambiguity. Got `func` with signature: {sig}"
        )

    def raise_mismatch():
        raise ValueError(
            f"custom_op(..., manual_schema)(func): When passing in a manual "
            f"schema, we expect `func`'s signature to match `manual_schema` "
            f"(aside from type annotations). "
            f"func's signature: {sig}, manual_schema: {schema}"
        )

    def raise_has_defaults():
        raise ValueError(
            f"custom_op(..., manual_schema)(func): "
            f"neither func nor manual_schema should have default "
            f"arguments. Got "
            f"func's signature: {sig}, manual_schema: {schema}"
        )

    def check_group(kind, schema_args):
        # Compare the parameters of one kind against the schema's arguments
        # of the matching group, pairwise and in order.
        group = [p for p in params if p.kind == kind]
        if len(group) != len(schema_args):
            raise_mismatch()
        for p, arg in zip(group, schema_args):
            if p.name != arg.name:
                raise_mismatch()
            no_defaults = p.default is inspect.Parameter.empty and arg.default is None
            if not no_defaults:
                raise_has_defaults()

    check_group(inspect.Parameter.POSITIONAL_OR_KEYWORD, schema.arguments.flat_positional)
    check_group(inspect.Parameter.KEYWORD_ONLY, schema.arguments.flat_kwarg_only)
|
770 |
+
|
771 |
+
|
772 |
+
def infer_schema(prototype_function: typing.Callable) -> str:
    """Build a schema string from the type annotations of ``prototype_function``.

    Each parameter is translated by ``parse_param`` and the return annotation
    by ``parse_return``; the pieces are joined as ``"(<params>) -> <ret>"``.

    Args:
        prototype_function: the annotated function to inspect.

    Returns:
        The schema string, without an operator name in front.

    Raises:
        ValueError: raised through ``error_fn`` when a parameter or the return
            annotation cannot be translated.
    """
    sig = inspect.signature(prototype_function)

    def error_fn(what):
        # The message previously ended with a stray, unbalanced ")".
        raise ValueError(
            f"custom_op(...)(func): {what} " f"Got func with signature {sig}"
        )

    params = [
        parse_param(name, param, error_fn) for name, param in sig.parameters.items()
    ]
    ret = parse_return(sig.return_annotation, error_fn)
    return f"({', '.join(params)}) -> {ret}"
|
785 |
+
|
786 |
+
|
787 |
+
def parse_param(name, param, error_fn):
    """Translate one annotated parameter into its ``"<schema type> <name>"`` form.

    ``error_fn`` is called with a description of the problem when the
    parameter kind is unsupported, the annotation is missing or is not a key
    of ``SUPPORTED_PARAM_TYPES``, or the parameter has a default value.
    """
    if not supported_param(param):
        error_fn("We do not support positional-only args, varargs, or varkwargs.")

    annotation = param.annotation
    if annotation is inspect.Parameter.empty:
        error_fn(f"Parameter {name} must have a type annotation.")

    if annotation not in SUPPORTED_PARAM_TYPES:
        error_fn(
            f"Parameter {name} has unsupported type {param.annotation}. "
            f"The valid types are: {SUPPORTED_PARAM_TYPES.keys()}."
        )

    if param.default is not inspect.Parameter.empty:
        error_fn(
            f"Parameter {name} has a default value; this is not supported. "
            f"If you want to use default values then create a function with "
            f"default values that calls the CustomOp"
        )

    schema_type = SUPPORTED_PARAM_TYPES[annotation]
    return f"{schema_type} {name}"
|
808 |
+
|
809 |
+
|
810 |
+
def derived_types(
    base_type, cpp_type, list_base, optional_base_list, optional_list_base
):
    """List the (Python annotation, schema type string) pairs for one base type.

    The base type and its ``Optional`` form are always included. The three
    flags add, in this order, the ``T[]``, ``T?[]`` and ``T[]?`` variants.
    """
    pairs = [(base_type, cpp_type)]
    pairs.append((typing.Optional[base_type], f"{cpp_type}?"))
    if list_base:
        pairs.append((typing.Sequence[base_type], f"{cpp_type}[]"))  # type: ignore[valid-type]
    if optional_base_list:
        pairs.append((typing.Sequence[typing.Optional[base_type]], f"{cpp_type}?[]"))  # type: ignore[valid-type]
    if optional_list_base:
        pairs.append((typing.Optional[typing.Sequence[base_type]], f"{cpp_type}[]?"))  # type: ignore[valid-type]
    return pairs
|
824 |
+
|
825 |
+
|
826 |
+
def get_supported_param_types():
    """Build the mapping from supported parameter annotations to schema types.

    Each table row is expanded by ``derived_types`` and the resulting pairs
    are collected into a single dict.
    """
    # Columns: python type, schema type, then whether the type[], type?[] and
    # type[]? variants are also supported.
    table = (
        (torch.Tensor, "Tensor", True, True, False),
        (int, "SymInt", True, False, True),
        (float, "float", True, False, True),
        (bool, "bool", True, False, True),
        (str, "str", False, False, False),
        (torch.types.Number, "Scalar", True, False, False),
        (torch.dtype, "ScalarType", False, False, False),
        (torch.device, "Device", False, False, False),
    )
    supported = {}
    for row in table:
        supported.update(derived_types(*row))
    return supported
|
842 |
+
|
843 |
+
|
844 |
+
# Maps a Python return annotation to the schema type string emitted for it.
# parse_return looks up both plain annotations and the individual elements of
# a tuple annotation in this table.
SUPPORTED_RETURN_TYPES = {
    torch.Tensor: "Tensor",
    typing.List[torch.Tensor]: "Tensor[]",
    int: "SymInt",
    float: "float",
    bool: "bool",
    torch.types.Number: "Scalar",
}
|
852 |
+
|
853 |
+
|
854 |
+
def parse_return(annotation, error_fn):
    """Translate a return annotation into its schema return string.

    A tuple annotation becomes ``"(T1, T2, ...)"`` with every element looked
    up in ``SUPPORTED_RETURN_TYPES``; any other annotation is looked up
    directly. ``error_fn`` is called for annotations that are not in the table.
    """
    if typing.get_origin(annotation) is tuple:
        elements = typing.get_args(annotation)
        for element in elements:
            if element not in SUPPORTED_RETURN_TYPES:
                error_fn(
                    f"Return has unsupported type {annotation}. "
                    f"The valid types are: {SUPPORTED_RETURN_TYPES}."
                )
        translated = [SUPPORTED_RETURN_TYPES[element] for element in elements]
        return "(" + ", ".join(translated) + ")"

    if annotation not in SUPPORTED_RETURN_TYPES:
        error_fn(
            f"Return has unsupported type {annotation}. "
            f"The valid types are: {SUPPORTED_RETURN_TYPES}."
        )
    return SUPPORTED_RETURN_TYPES[annotation]
|
873 |
+
|
874 |
+
|
875 |
+
# Built once at import time; parse_param consults this table to translate a
# parameter annotation into its schema type string.
SUPPORTED_PARAM_TYPES = get_supported_param_types()
|
876 |
+
|
877 |
+
|
878 |
+
def report_error_callback(custom_op: typing.Any, key: str) -> None:
    """Raise a descriptive error for a dispatch key with no registered kernel.

    Args:
        custom_op: the operator being reported; only interpolated into the
            error message.
        key: name of the dispatch key that had no implementation.

    Raises:
        NotImplementedError: always; the message depends on ``key``.
    """
    if key == "Undefined":
        raise NotImplementedError(
            f"{custom_op}: There were no Tensor inputs to this operator "
            f"(e.g. you passed an empty list of Tensors). If your operator is a "
            f"factory function (that is, it takes no Tensors and constructs "
            f"a new one), then please use CustomOp.impl_factory to register "
            f"an implementation for it"
        )
    if key == "Meta":
        raise NotImplementedError(
            f"{custom_op}: when running with device='Meta' tensors: there is no "
            f"abstract impl registered for this CustomOp. Please register one via "
            f"CustomOp.impl_abstract to get this CustomOp to work with Meta tensors"
        )
    if key in ("CPU", "CUDA"):
        device = key.lower()
        # The keyword accepted by the impl API is `device_types` (plural);
        # the hint previously pointed users at a non-existent `device_type`.
        raise NotImplementedError(
            f"{custom_op}: when running with device='{device}' tensors: there is no "
            f"{device} impl registered for this CustomOp. Please register one via "
            f"CustomOp.impl(device_types='{device}')"
        )
    raise NotImplementedError(
        f"{custom_op}: No implementation for dispatch key {key}. It is likely "
        f"that we have not added this functionality yet, please either open an "
        f"issue or if you're feeling adventurous, use the low-level "
        f"torch.library API"
    )
|
906 |
+
|
907 |
+
|
908 |
+
def custom_op_from_existing(op):
    """Wrap an already-defined operator overload ``op`` in a ``CustomOp``.

    The wrapper gets a new ``"FRAGMENT"`` library for the operator's
    namespace and a schema parsed from ``op._schema``.
    """
    namespace = op.namespace
    fragment = torch.library.Library(namespace, "FRAGMENT")
    op_name = op.name().split("::")[-1]
    # CustomOp expects the schema string without the namespace
    unqualified_schema = str(op._schema).split("::")[-1]
    parsed_schema = FunctionSchema.parse(unqualified_schema)
    return CustomOp(
        fragment, namespace, parsed_schema, op_name, op, _private_access=True
    )
|
917 |
+
|
918 |
+
|
919 |
+
def get_op(qualname):
    """Look up ``torch.ops.<ns>.<name>.default`` for ``"ns::name"``.

    Raises:
        ValueError: if the namespace, the operator, or its ``default``
            overload cannot be found under ``torch.ops``.
    """

    def error_not_found():
        raise ValueError(
            f"Could not find the operator {qualname}. Please make sure you have "
            f"already registered the operator and (if registered from C++) "
            f"loaded it via torch.ops.load_library.")

    ns, name = parse_qualname(qualname)
    # Walk torch.ops -> namespace -> packet -> default overload, checking that
    # each attribute exists before stepping into it.
    node = torch.ops
    for attr in (ns, name, 'default'):
        if not hasattr(node, attr):
            error_not_found()
        node = getattr(node, attr)
    return node
|
936 |
+
|
937 |
+
|
938 |
+
def _find_custom_op(qualname, also_check_torch_library=False):
    """Return the ``CustomOp`` registered for ``qualname``.

    ``global_registry`` is consulted first. On a miss, and only when
    ``also_check_torch_library`` is true, the operator is looked up with
    ``get_op`` and wrapped with ``custom_op_from_existing``.

    Raises:
        RuntimeError: on a registry miss when ``also_check_torch_library``
            is false.
    """
    if qualname in global_registry:
        return global_registry[qualname]
    if also_check_torch_library:
        existing_overload = get_op(qualname)
        return custom_op_from_existing(existing_overload)
    raise RuntimeError(
        f"Could not find custom op \"{qualname}\". Did you register it via "
        f"the torch._custom_ops API?")
|
948 |
+
|
949 |
+
|
950 |
+
def get_abstract_impl(qualname):
|
951 |
+
if qualname not in torch._custom_op.impl.global_registry:
|
952 |
+
return None
|
953 |
+
custom_op = torch._custom_op.impl.global_registry[qualname]
|
954 |
+
if custom_op is None:
|
955 |
+
return None
|
956 |
+
if not custom_op._has_impl("abstract"):
|
957 |
+
return None
|
958 |
+
return custom_op._get_impl("abstract").func
|
959 |
+
|
960 |
+
|
961 |
+
def _custom_op_with_schema(qualname, schema):
    """Define the operator ``qualname`` with ``schema`` and return its overload.

    ``schema`` is the schema string without the operator name. The steps run
    in a fixed order: parse and validate the schema, define it on a
    ``"FRAGMENT"`` library, build the ``CustomOp``, install the error
    callback, and finally look the operator up again via ``get_op``.
    """
    namespace, op_name = qualname.split("::")
    full_schema = f"{op_name}{schema}"
    parsed = FunctionSchema.parse(full_schema)
    validate_schema(parsed)

    fragment = library.Library(namespace, "FRAGMENT")
    fragment.define(full_schema)
    handle = find_ophandle_or_throw(namespace, parsed.name)
    custom = CustomOp(
        fragment, namespace, parsed, op_name, handle, _private_access=True
    )
    custom._register_autograd_kernel_indirection()

    # The callback only receives a weak proxy to the CustomOp.
    on_missing_kernel = functools.partial(report_error_callback, weakref.proxy(custom))
    torch._C._dispatch_set_report_error_callback(handle, on_missing_kernel)
    return get_op(qualname)
|
torch/_custom_ops.py
ADDED
@@ -0,0 +1,322 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import inspect
|
2 |
+
|
3 |
+
from torch._custom_op.impl import (
|
4 |
+
_custom_op_with_schema,
|
5 |
+
_find_custom_op,
|
6 |
+
infer_schema,
|
7 |
+
parse_qualname,
|
8 |
+
validate_namespace,
|
9 |
+
)
|
10 |
+
from torch.library import get_ctx
|
11 |
+
|
12 |
+
__all__ = [
|
13 |
+
"custom_op",
|
14 |
+
"impl",
|
15 |
+
"impl_abstract",
|
16 |
+
"get_ctx",
|
17 |
+
"impl_save_for_backward",
|
18 |
+
"impl_backward",
|
19 |
+
]
|
20 |
+
|
21 |
+
|
22 |
+
def custom_op(qualname, func_or_schema=None):
|
23 |
+
r"""Register a new custom operator
|
24 |
+
|
25 |
+
In PyTorch, defining an op (short for "operator") is a two step-process:
|
26 |
+
- we need to define the op (by providing an operator name and schema)
|
27 |
+
- we need to implement behavior for how the operator interacts with
|
28 |
+
various PyTorch subsystems, like CPU/CUDA Tensors, Autograd, etc.
|
29 |
+
|
30 |
+
This entrypoint defines the custom operator (the first step)
|
31 |
+
you must then perform the second step by calling various
|
32 |
+
``impl_*`` APIs.
|
33 |
+
|
34 |
+
This API may be used as a decorator (see examples).
|
35 |
+
|
36 |
+
For a detailed guide on custom ops, please see
|
37 |
+
https://docs.google.com/document/d/1aGWtgxV3HppuxQAdddyPrs74_aEntpkYt9MalnCKnhk
|
38 |
+
|
39 |
+
Arguments:
|
40 |
+
qualname (str): Should be a string that looks like
|
41 |
+
"namespace::operator_name". Operators in PyTorch need a namespace to
|
42 |
+
avoid name collisions; a given operator may only be created once.
|
43 |
+
If you are writing a Python library, we recommend the namespace to
|
44 |
+
be the name of your top-level module.
|
45 |
+
func_or_schema (Union[Callable, str]): Each PyTorch operator needs a
|
46 |
+
schema that tells PyTorch the types of the inputs/outputs.
|
47 |
+
If this is a Callable, we will automatically infer the schema from
|
48 |
+
the type annotations on the function (see examples). Otherwise,
|
49 |
+
if you don't want to use type annotations, you may provide us the
|
50 |
+
schema string.
|
51 |
+
|
52 |
+
Example::
|
53 |
+
>>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA)
|
54 |
+
>>> import torch
|
55 |
+
>>> import numpy as np
|
56 |
+
>>> from torch import Tensor
|
57 |
+
>>>
|
58 |
+
>>> # Step 1: define the custom op.
|
59 |
+
>>> # We need to provide the API a "prototype function"
|
60 |
+
>>> # (a function that returns NotImplementedError), from which
|
61 |
+
>>> # we will infer the types of the inputs and outputs.
|
62 |
+
>>> @torch._custom_ops.custom_op("mylibrary::numpy_sin")
|
63 |
+
>>> def numpy_sin(x: Tensor) -> Tensor:
|
64 |
+
>>> raise NotImplementedError()
|
65 |
+
>>>
|
66 |
+
>>> # The custom op is now accessible via the torch.ops module:
|
67 |
+
>>> torch.ops.mylibrary.numpy_sin
|
68 |
+
>>>
|
69 |
+
>>> # Step 2: Register an implementation for various PyTorch subsystems
|
70 |
+
>>>
|
71 |
+
>>> # Register an implementation for CPU tensors
|
72 |
+
>>> @torch._custom_ops.impl("mylibrary::numpy_sin", device_types="cpu")
|
73 |
+
>>> def numpy_sin_impl_cpu(x):
|
74 |
+
>>> return torch.from_numpy(np.sin(x.numpy()))
|
75 |
+
>>>
|
76 |
+
>>> # Register an implementation for CUDA tensors
|
77 |
+
>>> @torch._custom_ops.impl("mylibrary::numpy_sin", device_types="cuda")
|
78 |
+
>>> def numpy_sin_impl_cuda(x):
|
79 |
+
>>> return torch.from_numpy(np.sin(x.cpu().numpy())).to(x.device)
|
80 |
+
>>>
|
81 |
+
>>> x = torch.randn(3)
|
82 |
+
>>> torch.ops.mylibrary.numpy_sin(x) # calls numpy_sin_impl_cpu
|
83 |
+
>>>
|
84 |
+
>>> x_cuda = x.cuda()
|
85 |
+
>>> torch.ops.mylibrary.numpy_sin(x) # calls numpy_sin_impl_cuda
|
86 |
+
|
87 |
+
"""
|
88 |
+
ns, name = parse_qualname(qualname)
|
89 |
+
validate_namespace(ns)
|
90 |
+
|
91 |
+
def inner(func):
|
92 |
+
if not inspect.isfunction(func):
|
93 |
+
raise ValueError(
|
94 |
+
f"custom_op(...)(func): Expected `func` to be a Python "
|
95 |
+
f"function, got: {type(func)}"
|
96 |
+
)
|
97 |
+
|
98 |
+
if func.__name__ != name:
|
99 |
+
raise ValueError(
|
100 |
+
f"custom_op(qualname='{qualname}', ...)(func): expected `func` "
|
101 |
+
f"to have name '{name}' but got '{func.__name__}'. "
|
102 |
+
f"Please either change the name of `func` or the qualname that "
|
103 |
+
f"is passed to `custom_op`"
|
104 |
+
)
|
105 |
+
|
106 |
+
schema = infer_schema(func)
|
107 |
+
_custom_op_with_schema(qualname, schema)
|
108 |
+
return func
|
109 |
+
|
110 |
+
if func_or_schema is None:
|
111 |
+
return inner
|
112 |
+
if isinstance(func_or_schema, str):
|
113 |
+
_custom_op_with_schema(qualname, func_or_schema)
|
114 |
+
else:
|
115 |
+
return inner(func_or_schema)
|
116 |
+
|
117 |
+
|
118 |
+
def impl(qualname, *, device_types=("cpu", "cuda"), func=None):
    r"""Register an implementation for a device type for this custom op.

    If the op is passed multiple Tensor inputs with different device
    types, it will dispatch to the registered implementation for the highest
    priority device type among those present.
    The supported device types, in order of priority, are {'cuda', 'cpu'}.

    This API may be used as a decorator (see examples).

    For a detailed guide on custom ops, please see
    https://docs.google.com/document/d/1aGWtgxV3HppuxQAdddyPrs74_aEntpkYt9MalnCKnhk

    Arguments:
        qualname (str): the "namespace::operator_name" of the operator to
            register the implementation for.
        device_types (str or Iterable[str]): the device type(s) to register the function for.
        func (Callable, optional): the implementation to register. When
            omitted, a decorator is returned instead; otherwise ``func`` is
            registered and returned.

    Example::
        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA)
        >>> import torch
        >>> import numpy as np
        >>> from torch import Tensor
        >>>
        >>> # Step 1: define the custom op.
        >>> # We need to provide the API a "prototype function"
        >>> # (a function that raises NotImplementedError), from which
        >>> # we will infer the types of the inputs and outputs.
        >>> @torch._custom_ops.custom_op("mylibrary::numpy_cos")
        >>> def numpy_cos(x: Tensor) -> Tensor:
        >>>     raise NotImplementedError()
        >>>
        >>> # The custom op is now accessible via the torch.ops module:
        >>> torch.ops.mylibrary.numpy_cos
        >>>
        >>> # Step 2: Register an implementation for various PyTorch subsystems
        >>>
        >>> # Register an implementation for CPU tensors
        >>> @torch._custom_ops.impl("mylibrary::numpy_cos", device_types="cpu")
        >>> def numpy_cos_impl_cpu(x):
        >>>     return torch.from_numpy(np.cos(x.numpy()))
        >>>
        >>> # Register an implementation for CUDA tensors
        >>> @torch._custom_ops.impl("mylibrary::numpy_cos", device_types="cuda")
        >>> def numpy_cos_impl_cuda(x):
        >>>     return torch.from_numpy(np.cos(x.cpu().numpy())).to(x.device)
        >>>
        >>> x = torch.randn(3)
        >>> torch.ops.mylibrary.numpy_cos(x)  # calls numpy_cos_impl_cpu
        >>>
        >>> x_cuda = x.cuda()
        >>> torch.ops.mylibrary.numpy_cos(x_cuda)  # calls numpy_cos_impl_cuda

    """

    def inner(func):
        # Falls back to operators found under torch.ops when the qualname is
        # not in the custom-op registry.
        custom_op = _find_custom_op(qualname, also_check_torch_library=True)
        # NOTE(review): _stacklevel=3 presumably accounts for this wrapper's
        # extra frames; keep the call nesting unchanged - confirm against
        # CustomOp.impl.
        custom_op.impl(device_types, _stacklevel=3)(func)
        return func

    if func is None:
        return inner
    return inner(func)
|
179 |
+
|
180 |
+
|
181 |
+
def impl_abstract(qualname, *, func=None):
    r"""Register an abstract implementation for this operator.

    An "abstract implementation" specifies the behavior of this operator on
    Tensors that carry no data. Given some input Tensors with certain properties
    (sizes/strides/storage_offset/device), it specifies what the properties of
    the output Tensors are.

    The abstract implementation has the same signature as the operator.
    It is run for both FakeTensors and meta tensors. To write an abstract
    implementation, assume that all Tensor inputs to the operator are
    regular CPU/CUDA/Meta tensors, but they do not have storage, and
    you are trying to return regular CPU/CUDA/Meta tensor(s) as output.
    The abstract implementation must consist of only PyTorch operations
    (and may not directly access the storage or data of any input or
    intermediate Tensors).

    This API may be used as a decorator (see examples).

    For a detailed guide on custom ops, please see
    https://docs.google.com/document/d/1aGWtgxV3HppuxQAdddyPrs74_aEntpkYt9MalnCKnhk

    Arguments:
        qualname (str): the "namespace::operator_name" of the operator.
        func (Callable, optional): the abstract implementation; forwarded
            as-is to ``torch.library.impl_abstract``.

    Returns:
        Whatever ``torch.library.impl_abstract`` returns for these arguments.

    Examples::
        >>> import torch
        >>> import numpy as np
        >>> from torch import Tensor
        >>>
        >>> # Example 1: an operator without data-dependent output shape
        >>> @torch._custom_ops.custom_op("mylibrary::custom_linear")
        >>> def custom_linear(x: Tensor, weight: Tensor, bias: Tensor) -> Tensor:
        >>>     raise NotImplementedError()
        >>>
        >>> @torch._custom_ops.impl_abstract("mylibrary::custom_linear")
        >>> def custom_linear_abstract(x, weight, bias):
        >>>     assert x.dim() == 2
        >>>     assert weight.dim() == 2
        >>>     assert bias.dim() == 1
        >>>     assert x.shape[1] == weight.shape[1]
        >>>     assert weight.shape[0] == bias.shape[0]
        >>>     assert x.device == weight.device
        >>>
        >>>     return (x @ weight.t()) + bias
        >>>
        >>> # Example 2: an operator with data-dependent output shape
        >>> @torch._custom_ops.custom_op('mylibrary::custom_nonzero')
        >>> def custom_nonzero(x: Tensor) -> Tensor:
        >>>     ...
        >>>
        >>> @torch._custom_ops.impl_abstract("mylibrary::custom_nonzero")
        >>> def custom_nonzero_abstract(x):
        >>>     # Number of nonzero-elements is data-dependent.
        >>>     # Since we cannot peek at the data in an abstract impl,
        >>>     # we use the ctx object to construct a new symint that
        >>>     # represents the data-dependent size.
        >>>     ctx = torch._custom_ops.get_ctx()
        >>>     nnz = ctx.create_unbacked_symint()
        >>>     shape = [x.dim(), nnz]
        >>>     result = x.new_empty(shape, dtype=torch.long)
        >>>     return result
        >>>
        >>> @torch._custom_ops.impl("mylibrary::custom_nonzero")
        >>> def custom_nonzero_impl(x):
        >>>     x_np = x.cpu().numpy()
        >>>     res = np.stack(np.nonzero(x_np), axis=1)
        >>>     # unbacked symbolic ints in PyTorch must be >= 2, so we
        >>>     # constrain the range to at least 2
        >>>     if res.shape[0] <= 1:
        >>>         raise RuntimeError("not supported")
        >>>     return torch.tensor(res, device=x.device)

    """
    # Imported here rather than at module top; only `get_ctx` is imported from
    # torch.library at file level.
    import torch.library

    # NOTE(review): _stacklevel=2 presumably makes torch.library attribute the
    # registration to this function's caller - confirm before restructuring.
    return torch.library.impl_abstract(qualname, func, _stacklevel=2)
|
254 |
+
|
255 |
+
|
256 |
+
def impl_save_for_backward(qualname, *, func=None):
    r"""Register a function that tells us what to save for backward.

    Please see :func:`impl_backward` for more details.

    Arguments:
        qualname (str): the "namespace::operator_name" of the operator.
        func (Callable, optional): the "save for backward" function. When
            omitted, a decorator is returned instead; otherwise ``func`` is
            registered and returned.
    """

    def inner(func):
        # Falls back to operators found under torch.ops when the qualname is
        # not in the custom-op registry.
        custom_op = _find_custom_op(qualname, also_check_torch_library=True)
        # NOTE(review): _stacklevel=3 presumably accounts for this wrapper's
        # extra frames; keep the call nesting unchanged.
        custom_op.impl_save_for_backward(_stacklevel=3)(func)
        return func

    if func is None:
        return inner
    return inner(func)
|
270 |
+
|
271 |
+
|
272 |
+
def impl_backward(qualname, output_differentiability=None, *, func=None):
    r"""Registers a backward formula for an operator.

    In order for an operator to work with autograd, you need to register
    a backward formula. There are two pieces to this:
    1. You must give us a function to specify what to save for backward.
       Call this the "save for backward" function.
    2. You must give us a function that computes gradients. Call this the
       "backward" function.

    Use `impl_save_for_backward` to define a "save for backward" function
    that specifies what gets saved for backward. The function should accept
    two arguments ``(inputs, output)`` and return the quantities to be saved
    for backward.

    During runtime, when you call the operator in a forwards pass, PyTorch
    will invoke the "save for backward" function with the inputs and output
    of the operator.

    Use `impl_backward` to define the "backward" function. The backward
    function must accept ``(ctx, saved, *grads)``:
    - ``ctx`` is a context object where we may provide information
    - ``saved`` is exactly what gets returned from the "save for backward"
      function
    - ``grads`` is one or more gradients. The number of gradients matches
      the number of outputs of the operator.

    The backward function must return a dict that maps the name of
    an input to the operator to its corresponding gradient. All inputs that
    were declared to be Tensors in the operator definition must be accounted
    for in the dict. The gradient may be a Tensor or None.

    For a detailed guide on custom ops, please see
    https://docs.google.com/document/d/1aGWtgxV3HppuxQAdddyPrs74_aEntpkYt9MalnCKnhk

    Arguments:
        qualname (str): the "namespace::operator_name" of the operator.
        output_differentiability: forwarded unchanged to
            ``CustomOp.impl_backward``. Presumably one flag per operator
            output - confirm against that method.
        func (Callable, optional): the "backward" function. When omitted, a
            decorator is returned instead; otherwise ``func`` is registered
            and returned.

    """

    def inner(func):
        # Falls back to operators found under torch.ops when the qualname is
        # not in the custom-op registry.
        custom_op = _find_custom_op(qualname, also_check_torch_library=True)
        # NOTE(review): _stacklevel=3 presumably accounts for this wrapper's
        # extra frames; keep the call nesting unchanged.
        custom_op.impl_backward(output_differentiability, _stacklevel=3)(func)
        return func

    if func is None:
        return inner
    return inner(func)
|
317 |
+
|
318 |
+
|
319 |
+
def _destroy(qualname):
|
320 |
+
"""De-registers a custom op. For testing purposes only"""
|
321 |
+
custom_op = _find_custom_op(qualname)
|
322 |
+
custom_op._destroy()
|
torch/_decomp/__init__.py
ADDED
@@ -0,0 +1,444 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import inspect
|
2 |
+
from collections import defaultdict
|
3 |
+
from functools import wraps
|
4 |
+
from itertools import chain
|
5 |
+
from typing import Callable, Dict, List, Sequence, Union
|
6 |
+
|
7 |
+
import torch
|
8 |
+
import torch.library
|
9 |
+
from torch._ops import HigherOrderOperator, OpOverload, OpOverloadPacket
|
10 |
+
from torch._prims_common import CustomOutParamAnnotation
|
11 |
+
from torch.utils import _pytree as pytree
|
12 |
+
|
13 |
+
__all__ = [
|
14 |
+
"decomposition_table",
|
15 |
+
"pre_autograd_decomposition_table",
|
16 |
+
"meta_table",
|
17 |
+
"register_decomposition",
|
18 |
+
"get_decompositions",
|
19 |
+
"core_aten_decompositions",
|
20 |
+
]
|
21 |
+
|
22 |
+
|
23 |
+
# TODO: relax key type here; torch registrations should be possible too, but
# right now this type is accurate
global_decomposition_table: Dict[
    str, Dict[torch._ops.OperatorBase, Callable]
] = defaultdict(dict)

# Indexing the defaultdict creates each named sub-table on first access; the
# three names below are aliases of those sub-dicts, not copies.
decomposition_table = global_decomposition_table["post_autograd"]
pre_autograd_decomposition_table = global_decomposition_table["pre_autograd"]
meta_table = global_decomposition_table["meta"]
|
32 |
+
|
33 |
+
|
34 |
+
def _add_op_to_registry(registry, op, fn):
|
35 |
+
"""
|
36 |
+
This is an internal API for adding an op to the decomposition table.
|
37 |
+
|
38 |
+
If op is OpOverload, it will be added to the registry directly.
|
39 |
+
If op is OpOverloadPacket, all the valid op_overloads in the packet will be added to the registry.
|
40 |
+
"""
|
41 |
+
overloads: List[Union[torch._ops.OperatorBase]] = []
|
42 |
+
if isinstance(op, HigherOrderOperator):
|
43 |
+
# There's no concept of overloads for HigherOrderOperator
|
44 |
+
registry[op] = fn
|
45 |
+
return
|
46 |
+
elif isinstance(op, OpOverload):
|
47 |
+
overloads.append(op)
|
48 |
+
else:
|
49 |
+
assert isinstance(op, OpOverloadPacket)
|
50 |
+
for ol in op.overloads():
|
51 |
+
overloads.append(getattr(op, ol))
|
52 |
+
|
53 |
+
for op_overload in overloads:
|
54 |
+
if op_overload in registry:
|
55 |
+
raise RuntimeError(f"duplicate registrations for {op_overload}")
|
56 |
+
# TorchScript dumps a bunch of extra nonsense overloads
|
57 |
+
# which don't have corresponding dispatcher entries, we need
|
58 |
+
# to filter those out, e.g aten.add.float_int
|
59 |
+
if torch._C._dispatch_has_kernel(op_overload.name()):
|
60 |
+
registry[op_overload] = fn
|
61 |
+
|
62 |
+
|
63 |
+
def _convert_out_params(f):
    """Adapt how ``f``'s ``out`` parameter is exposed, based on its annotations.

    Three outcomes are possible:

    * ``f`` has no (truthy) ``out`` annotation: ``f`` is returned unchanged.
    * the ``out`` annotation is a tuple type: a wrapper is returned that
      accepts one keyword-only argument per field of the return annotation
      and repacks them into the tuple passed to ``f`` as ``out``.
    * ``f.__annotations__`` carries ``CustomOutParamAnnotation``: a wrapper is
      returned that accepts the out tensor under that custom name and
      forwards it to ``f`` as ``out``.

    In the wrapping cases ``__signature__`` and ``__annotations__`` of the
    wrapper are rewritten so that ``out`` is replaced by the new keyword-only
    parameter(s).
    """
    out_annotation = f.__annotations__.get("out")

    # If there are no out params, do not wrap the function.
    if not out_annotation:
        return f

    # Hack to detect when out is a Tuple. There seems to be no pretty way of doing this
    if getattr(out_annotation, "__origin__", None) is tuple:
        sig = inspect.signature(f)
        # Reads `_fields` from the return annotation, so that annotation is
        # presumably a NamedTuple type - TODO confirm with callers.
        out_names = sig.return_annotation._fields
        # If out is a tuple, we need to register a function that unpacks all the out
        # elements as this is what native_functions.yaml expects

        @wraps(f)
        def _fn(*args, **kwargs):
            out_kwargs = tuple(kwargs.pop(o, None) for o in out_names)
            # Either all of the out kwargs are set or none of them
            is_none = out_kwargs[0] is None
            assert all((o is None) == is_none for o in out_kwargs)
            return f(*args, **kwargs, out=None if is_none else out_kwargs)

        # One keyword-only, default-None parameter per out field, typed by the
        # matching element of the tuple annotation.
        out_params = [
            inspect.Parameter(
                o,
                kind=inspect.Parameter.KEYWORD_ONLY,
                default=None,
                annotation=t,
            )
            for o, t in zip(out_names, out_annotation.__args__)
        ]
        # Drop the out parameter and concatenate the new kwargs in the signature
        params = chain((v for k, v in sig.parameters.items() if k != "out"), out_params)
        _fn.__signature__ = inspect.Signature(  # type: ignore[attr-defined]
            parameters=params, return_annotation=sig.return_annotation  # type: ignore[arg-type]
        )
        # Drop the out parameter and concatenate the new kwargs in the annotations
        _fn.__annotations__ = {k: v for k, v in f.__annotations__.items() if k != "out"}
        for o in out_params:
            _fn.__annotations__[o.name] = o.annotation

        # Propagate that this function is wrapped by `out_wrapper`
        _fn._torch_decompositions_out_wrapper = f._torch_decompositions_out_wrapper  # type: ignore[attr-defined]

        return _fn

    # Alternatively, there may be a single tensor out parameter with a name
    # other than "out". This will need special treatment and is indicated by an
    # annotation, which we will remove here so it is not exposed after wrapping.
    # Note that `pop` mutates f.__annotations__ in place.
    custom_out_param_name = f.__annotations__.pop(CustomOutParamAnnotation, None)
    if custom_out_param_name:

        @wraps(f)
        def _fn(*args, **kwargs):
            out_kwarg = kwargs.pop(custom_out_param_name, None)
            return f(*args, **kwargs, out=out_kwarg)

        out_param = inspect.Parameter(
            custom_out_param_name,
            kind=inspect.Parameter.KEYWORD_ONLY,
            default=None,
            annotation=out_annotation,
        )

        # Drop the out parameter and concatenate the new kwarg in the signature
        sig = inspect.signature(f)
        params = chain(
            (v for k, v in sig.parameters.items() if k != "out"), (out_param,)
        )
        _fn.__signature__ = inspect.Signature(  # type: ignore[attr-defined]
            parameters=params, return_annotation=sig.return_annotation  # type: ignore[arg-type]
        )

        # Drop the out parameter and concatenate the new kwargs in the annotations
        _fn.__annotations__ = {k: v for k, v in f.__annotations__.items() if k != "out"}
        _fn.__annotations__[out_param.name] = out_param.annotation

        return _fn

    return f
|
143 |
+
|
144 |
+
|
145 |
+
def register_decomposition(
    aten_op, registry=None, *, type="post_autograd", unsafe=False
):
    """
    Build a decorator that registers a function as the decomposition of
    ``aten_op`` in a Python decomposition table. Use it like this::

        @register_decomposition(torch.ops.aten.clamp_min)
        def clamp_min(x, min):
            return torch.clamp(x, min=min)

    If you are writing a new decomposition, consider contributing it
    directly to PyTorch in torch._decomp.decompositions.

    This API is experimental; we are almost certainly going to extend
    the API when we make decompositions eligible for use in transforms (e.g.,
    autograd) and not just backend tracing, where we then need to know if a
    decomposition can be used to simulate a transform.

    Args:
        aten_op: the operator to register for, or a pytree (e.g. a list) of
            operators; every leaf is registered with the same function.
        registry: table to register into. When ``None``, the table
            ``global_decomposition_table[type]`` is used.
        type: one of ``"post_autograd"``, ``"pre_autograd"`` or ``"meta"``.
        unsafe: when true, the decorated object is registered as-is, without
            being passed through ``_convert_out_params``. This allows reuse of
            this function for registering non-function things.

    Returns:
        A decorator. It returns the object that was actually registered, i.e.
        the result of ``_convert_out_params(fn)`` unless ``unsafe`` is set.

    NOTE(review): an earlier version of this docstring stated that the
    decomposition is also registered to the Meta key of the dispatcher. That
    is not visible in this function; confirm against ``_add_op_to_registry``.
    """

    assert type in {"post_autograd", "pre_autograd", "meta"}

    def decomposition_decorator(fn: Callable) -> Callable:
        nonlocal registry

        target = fn if unsafe else _convert_out_params(fn)

        # Resolve the default table lazily; the resolved table is kept for
        # any further use of this same decorator object.
        if registry is None:
            registry = global_decomposition_table[type]

        # tree_map_ visits every leaf, which is what allows several aten ops
        # to be registered at once.
        pytree.tree_map_(
            lambda op: _add_op_to_registry(registry, op, target), aten_op
        )
        return target

    return decomposition_decorator
|
189 |
+
|
190 |
+
|
191 |
+
def get_decompositions(
    aten_ops: Sequence[Union[torch._ops.OperatorBase, OpOverloadPacket]],
    type: str = "post_autograd",
) -> Dict[torch._ops.OperatorBase, Callable]:
    """
    Look up the registered decompositions for the requested operators.

    ``aten_ops`` may mix operator overloads and overload packets. A packet
    expands to every overload of that packet that has a registered
    decomposition. Operators without a decomposition are silently ignored.

    This API is experimental; we are almost certainly going to give an alternate,
    more recommended formulation, where a user provides the set of operators
    they know how to implement, and we provide decompositions for everything
    not in this set.

    Args:
        aten_ops: operators and/or overload packets to look up.
        type: which table of ``global_decomposition_table`` to read; one of
            ``"post_autograd"``, ``"pre_autograd"`` or ``"meta"``.

    Returns:
        A new dict mapping each matched operator to its decomposition.
    """
    assert type in {"post_autograd", "pre_autograd", "meta"}

    table = global_decomposition_table[type]

    # Index the registered entries by their overload packet so that a packet
    # request can be expanded without rescanning the table.
    overloads_by_packet = defaultdict(list)
    for registered in table:
        if isinstance(registered, (OpOverload, OpOverloadPacket)):
            overloads_by_packet[registered.overloadpacket].append(registered)

    selected: Dict[torch._ops.OperatorBase, Callable] = {}
    for op in aten_ops:
        if isinstance(op, OpOverloadPacket) and op in overloads_by_packet:
            selected.update(
                {overload: table[overload] for overload in overloads_by_packet[op]}
            )
        elif isinstance(op, (torch._ops.OperatorBase)) and op in table:
            selected[op] = table[op]
    return selected
|
221 |
+
|
222 |
+
|
223 |
+
def remove_decompositions(
|
224 |
+
decompositions: Dict[torch._ops.OperatorBase, Callable],
|
225 |
+
aten_ops: Sequence[Union[OpOverload, OpOverloadPacket]],
|
226 |
+
) -> None:
|
227 |
+
"""
|
228 |
+
Given a dictionary of decompositions obtained from get_decompositions(), removes
|
229 |
+
operators associated with a list of operator overloads and overload packets passed
|
230 |
+
as input. If the decomposition dictionary does not contain a decomposition that is
|
231 |
+
specified to be removed, it is silently ignored.
|
232 |
+
"""
|
233 |
+
for op in aten_ops:
|
234 |
+
if isinstance(op, OpOverloadPacket):
|
235 |
+
for overload_name in op.overloads():
|
236 |
+
opo = getattr(op, overload_name)
|
237 |
+
decompositions.pop(opo, None)
|
238 |
+
elif isinstance(op, OpOverload):
|
239 |
+
decompositions.pop(op, None)
|
240 |
+
|
241 |
+
|
242 |
+
# populate the table
|
243 |
+
import torch._decomp.decompositions
|
244 |
+
import torch._refs
|
245 |
+
|
246 |
+
|
247 |
+
# See NOTE [Core ATen Ops]
|
248 |
+
#
|
249 |
+
# list was copied from torch/_inductor/decomposition.py
|
250 |
+
# excluding decompositions that result in prim ops
|
251 |
+
# Resulting opset of decomposition is core aten ops
|
252 |
+
def core_aten_decompositions() -> Dict[torch._ops.OperatorBase, Callable]:
    """
    Return the decompositions registered for a fixed list of aten operators.

    The list below is passed to ``get_decompositions`` with its default
    arguments; operators in the list that have no registered decomposition
    are therefore absent from the result.
    """
    aten = torch.ops.aten
    core_ops = [
        aten.addcdiv,
        aten.addcdiv_,
        aten.addcmul,
        aten.addcmul_,
        aten.addr,
        aten.affine_grid_generator,
        aten.all,
        aten.aminmax,
        aten.arange.default,
        aten.arange.start,
        aten.avg_pool2d_backward,
        aten.baddbmm,
        aten.binary_cross_entropy,
        aten.binary_cross_entropy_backward,
        aten.binary_cross_entropy_with_logits,
        aten.celu,
        aten.celu_,
        aten.clamp_max,
        aten.clamp_min,
        aten.col2im,
        aten.count_nonzero,
        aten.cudnn_batch_norm,
        aten.cudnn_batch_norm_backward,
        aten.deg2rad,
        aten.deg2rad_,
        aten.detach,
        aten.diag_embed,
        aten.diagonal_backward,
        aten.dot,
        aten.vdot,
        aten.elu,
        aten.elu_,
        aten.elu_backward,
        aten._embedding_bag,
        aten.embedding_dense_backward,
        aten.empty_like,
        aten._euclidean_dist.default,
        aten.expand_as,
        aten.eye,
        aten.fill,
        aten.fill_,
        aten.floor_divide,
        aten.frac,
        aten.frac_,
        aten._fused_moving_avg_obs_fq_helper,
        aten.gelu_,
        aten.gelu_backward,
        aten.glu,
        aten.glu_backward,
        aten.hardshrink,
        aten.hardsigmoid,
        aten.hardsigmoid_,
        aten.hardsigmoid_backward,
        aten.hardswish,
        aten.hardswish_,
        aten.hardswish_backward,
        aten.hardtanh_,
        aten.hardtanh_backward,
        aten.heaviside,
        aten.heaviside_,
        aten.huber_loss,
        aten.huber_loss_backward,
        aten.im2col,
        aten.index_add,
        aten.index_add_,
        aten.index_copy,
        aten.index_copy_,
        aten.index_fill,
        aten.index_fill_,
        aten.isneginf,
        aten.isposinf,
        aten.l1_loss,
        aten.leaky_relu_,
        aten.leaky_relu_backward,
        aten.lerp,
        aten.lerp_,
        aten.linspace,
        aten.logaddexp,
        aten.logaddexp2,
        aten.logit,
        aten.logit_,
        aten.logit_backward,
        aten.log_sigmoid_backward,
        aten.log_sigmoid_forward,
        aten._log_softmax_backward_data,
        aten.logspace,
        aten.logsumexp.default,
        aten.masked_fill,
        aten.masked_fill_,
        aten.mish,
        aten.mish_,
        aten.mse_loss,
        aten.mse_loss_backward,
        aten.multi_margin_loss,
        aten.multilabel_margin_loss_forward,
        aten.mv,
        aten.mvlgamma,
        aten.mvlgamma_,
        aten.nansum,
        aten.nan_to_num,
        aten.nan_to_num_,
        aten.narrow,
        aten.native_batch_norm_backward,
        aten.native_dropout_backward,
        aten.native_group_norm_backward,
        aten.native_layer_norm_backward,
        aten.new_empty,
        aten.new_full,
        aten.new_ones,
        aten.new_zeros,
        aten.nll_loss_backward,
        aten.nll_loss_forward,
        aten.norm,
        aten.ones,
        aten.ones_like,
        aten._prelu_kernel,
        aten._prelu_kernel_backward,
        aten._reshape_alias,
        aten.rad2deg,
        aten.rad2deg_,
        aten.renorm,
        aten.renorm_,
        aten.replication_pad2d,
        aten.rot90,
        aten.rrelu_with_noise,
        aten.rrelu_with_noise_,
        aten.rsub.Scalar,
        aten.rsub.Tensor,
        aten._scaled_dot_product_flash_attention.default,
        aten.select_backward,
        aten.select_scatter,
        aten.sgn,
        aten.sgn_,
        aten.sigmoid_backward,
        aten.silu,
        aten.silu_,
        aten.silu_backward,
        aten.sinc,
        aten.sinc_,
        aten.slice_backward,
        aten.smooth_l1_loss,
        aten.smooth_l1_loss_backward,
        aten.soft_margin_loss,
        aten.soft_margin_loss_backward,
        aten._softmax_backward_data,
        aten.softplus,
        aten.softplus_backward,
        aten.softshrink,
        aten.special_entr,
        aten.special_log_ndtr,
        aten.special_xlog1py,
        aten.split.Tensor,
        aten.squeeze.default,
        aten.squeeze.dim,
        aten.std,
        aten.std_mean,
        aten.stack,
        aten.sum.default,
        aten.sum.out,
        aten.t,
        aten.tanh_backward,
        aten.threshold,
        aten.threshold_,
        aten.threshold_backward,
        aten.trace,
        aten.transpose.int,
        aten.tril,
        aten.tril_,
        aten.triu,
        aten.triu_,
        aten.unbind,
        aten.unfold_backward,
        aten.unfold_copy,
        aten._unsafe_index,
        aten.unsafe_split.Tensor,
        aten.unsafe_split_with_sizes,
        aten._unsafe_view,
        aten.upsample_bilinear2d,
        aten.upsample_nearest2d_backward,
        aten.view_as_complex,
        aten.xlogy,
        aten.xlogy_,
        aten.zero,
        aten.zero_,
        aten.zeros,
        aten.zeros_like,
        aten._weight_norm_interface,
    ]
    return get_decompositions(core_ops)
|
torch/_decomp/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (12.2 kB). View file
|
|
torch/_decomp/__pycache__/decompositions.cpython-310.pyc
ADDED
Binary file (102 kB). View file
|
|
torch/_decomp/__pycache__/decompositions_for_jvp.cpython-310.pyc
ADDED
Binary file (6.27 kB). View file
|
|
torch/_decomp/__pycache__/decompositions_for_rng.cpython-310.pyc
ADDED
Binary file (7.99 kB). View file
|
|
torch/_decomp/decompositions.py
ADDED
The diff for this file is too large to render.
See raw diff
|
|
torch/_decomp/decompositions_for_jvp.py
ADDED
@@ -0,0 +1,302 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import inspect
|
2 |
+
from typing import Callable, Dict, List, Optional, Tuple
|
3 |
+
|
4 |
+
import torch
|
5 |
+
import torch._decomp
|
6 |
+
from torch import Tensor
|
7 |
+
from torch._prims_common.wrappers import _maybe_remove_out_wrapper
|
8 |
+
|
9 |
+
decomposition_table = torch._decomp.decomposition_table
|
10 |
+
decomposition_table_for_jvp: Dict[torch._ops.OperatorBase, Callable] = {}
|
11 |
+
register_decomposition = torch._decomp.register_decomposition
|
12 |
+
aten = torch.ops.aten
|
13 |
+
|
14 |
+
# NOTE: [forward-mode AD decompositions mechanism]
|
15 |
+
#
|
16 |
+
# The mechanism is in VariableType,
|
17 |
+
# IF any inputs have forward grad
|
18 |
+
# AND there is no forward AD formula implemented
|
19 |
+
# AND the function is actually differentiable
|
20 |
+
# run the decomposition
|
21 |
+
# See run_jit_decomposition_with_args_for_jvp
|
22 |
+
# We currently use python decompositions that we torchscript.
|
23 |
+
#
|
24 |
+
# Note that we would be building the backward graph at the decomposed level
|
25 |
+
# too, but that is OK, because we would've errored out otherwise anyway.
|
26 |
+
#
|
27 |
+
# TODO: The mechanism we are using to register decompositions doesn't
|
28 |
+
# seem to be exclusively used for jvp. So open question here is whether
|
29 |
+
# torch/csrc/jit/runtime/decomposition_registry.cpp is being used for other things.
|
30 |
+
# If that is the case, we may go down the decomposition path unexpectedly
|
31 |
+
# (and possibly produce an unintelligible error) vs erroring out earlier and
|
32 |
+
# printing that the forward AD formula is not implemented.
|
33 |
+
#
|
34 |
+
# The solution to this may be to have an explicit whitelist that controls when
|
35 |
+
# to enable the decomposition.
|
36 |
+
|
37 |
+
|
38 |
+
def maybe_register_decomposition(op):
    """
    Best-effort variant of ``register_decomposition(op)``.

    The returned decorator tries to register ``f`` as the decomposition of
    ``op``. If registration raises any ``Exception``, the error is swallowed
    and ``f`` is returned unchanged and unregistered.
    """

    def decorator(f):
        try:
            registered = register_decomposition(op)(f)
        except Exception:
            # Deliberate fallback: keep the plain function usable even when
            # it could not be registered.
            return f
        return registered

    return decorator
|
46 |
+
|
47 |
+
|
48 |
+
# Functions where we need a special decomposition for jvp but there's another version that
|
49 |
+
# should be used more generally (ex. for jvp we need to recompute the mean and variance for
|
50 |
+
# the backwards of a normalization function. Without jvp, it should use the saved value)
|
51 |
+
decomposition_table_for_jvp = {}
|
52 |
+
|
53 |
+
|
54 |
+
def register_decomposition_for_jvp(fn):
    """
    Return a decorator that registers into ``decomposition_table_for_jvp``.

    Despite its name, ``fn`` is forwarded as the first argument of
    ``register_decomposition``, i.e. it is the operator being decomposed,
    not the decomposition function.
    """
    jvp_registrar = register_decomposition(fn, registry=decomposition_table_for_jvp)
    return jvp_registrar
|
56 |
+
|
57 |
+
|
58 |
+
def _register_jit_decomposition_for_jvp(decomp, use_python=False):
    """
    Register the TorchScript graph of the decomposition for ``decomp``.

    The decomposition is looked up in ``decomposition_table_for_jvp`` first
    and in ``decomposition_table`` second.

    Args:
        decomp: the operator whose decomposition should be registered.
        use_python: when true, the decomposition itself is not scripted.
            It is marked with ``torch.jit.ignore`` and only a thin generated
            wrapper that forwards to it is compiled.

    Raises:
        RuntimeError: if neither table contains ``decomp``.
    """
    if decomp in decomposition_table_for_jvp:
        source_table = decomposition_table_for_jvp
    elif decomp in decomposition_table:
        source_table = decomposition_table
    else:
        raise RuntimeError(f"could not find decomposition for {decomp}")

    # `out_wrapper` extends a decompositions signature with
    # an `out` parameter. However jit will use the unwrapped function's
    # signature instead so we need to unwrap here to prevent an error
    decomp_fn = _maybe_remove_out_wrapper(source_table[decomp])

    if not use_python:
        graph = torch.jit.script(decomp_fn).graph
    else:
        # NOTE: the generated source below refers to the name `decomp_fn`,
        # so this local must keep that exact name (presumably it is resolved
        # from this frame when the source is compiled -- TODO confirm).
        decomp_fn = torch.jit.ignore(decomp_fn)
        parameters = inspect.signature(decomp_fn).parameters

        # Build the source of a wrapper from the signature, for example:
        # def wrapped_decomp(x: torch.Tensor, y: int, z: int):
        #   return decomp_fn(x, y, z)
        declared = ", ".join(f"{param}" for param in parameters.values())
        forwarded = ", ".join(f"{name}" for name in parameters.keys())
        wrapper_source = (
            f"def wrapped_decomp({declared}):\n return decomp_fn({forwarded})\n"
        )
        graph = torch.jit.CompilationUnit(wrapper_source).wrapped_decomp.graph

    torch.jit._register_decomposition(decomp, graph)
|
92 |
+
|
93 |
+
|
94 |
+
# The only decompositions here are temporary or hacks for the purposes of jvp
|
95 |
+
|
96 |
+
|
97 |
+
# TODO: do these also belong here?
|
98 |
+
@maybe_register_decomposition(aten.trace.default)
def trace(self: Tensor) -> Tensor:
    """Return the sum of the entries that ``torch.diag`` extracts from ``self``."""
    diagonal = torch.diag(self)
    return torch.sum(diagonal)
|
101 |
+
|
102 |
+
|
103 |
+
@maybe_register_decomposition(aten.log_sigmoid_forward.default)
def log_sigmoid_forward(self: Tensor) -> Tuple[Tensor, Tensor]:
    """
    Compute ``min(0, x) - log1p(exp(-|x|))`` together with a buffer tensor.

    Returns:
        A pair ``(output, buffer)``. For CUDA inputs the buffer is an empty
        tensor of shape ``(0,)``; otherwise it is ``exp(-|x|)``.
    """
    exp_neg_abs = torch.exp(-torch.abs(self))
    clamped = torch.minimum(self.new_zeros(()), self)
    if not self.is_cuda:
        buffer = exp_neg_abs
    else:
        buffer = self.new_zeros((0,))
    return clamped - torch.log1p(exp_neg_abs), buffer
|
112 |
+
|
113 |
+
|
114 |
+
def recompute_mean_var(
    input: Tensor, rstd: Tensor, inner_dim_indices: List[int], keepdim: bool
):
    """
    Recompute the mean and the reciprocal standard deviation from ``input``.

    For most norm decompositions this is the only part that differs from the
    core version: the statistics are rebuilt from ``input`` so that they
    track gradients through it, instead of reusing the saved values.

    Args:
        input: tensor the statistics are computed from.
        rstd: previously saved reciprocal standard deviation; only used to
            recover the epsilon that was added to the variance.
        inner_dim_indices: dimensions to reduce over.
        keepdim: forwarded to ``torch.mean`` / ``torch.var``.

    Returns:
        ``(mean, rstd)`` recomputed from ``input``.
    """
    var = torch.var(input, dim=inner_dim_indices, unbiased=False, keepdim=keepdim)
    mean = torch.mean(input, dim=inner_dim_indices, keepdim=keepdim)

    # Recover eps as (1 / rstd)^2 - var and detach it, so that it acts as a
    # constant and only `var` carries the dependency on `input`.
    eps = (torch.pow(1 / rstd, 2) - var).detach()  # this makes me so sad inside
    rstd = 1 / torch.sqrt(var + eps)
    return mean, rstd
|
126 |
+
|
127 |
+
|
128 |
+
@register_decomposition_for_jvp(aten.native_layer_norm_backward)
def native_layer_norm_backward(
    grad_out: Tensor,
    input: Tensor,
    normalized_shape: List[int],
    mean: Tensor,
    rstd: Tensor,
    weight: Optional[Tensor],
    bias: Optional[Tensor],
    output_mask: List[bool],
) -> Tuple[Optional[Tensor], Optional[Tensor], Optional[Tensor]]:
    """
    jvp-specific decomposition of the layer norm backward.

    The saved ``mean`` is not read: mean and rstd are recomputed from
    ``input`` via ``recompute_mean_var``. ``bias`` is only used to decide the
    shape of the returned bias gradient.

    Returns:
        ``(d_input, d_weight, d_bias)``. Gradients that are masked out by
        ``output_mask`` (or whose parameter is ``None``) are returned as
        zero tensors rather than ``None``.
    """
    input_shape = input.shape
    input_ndim = input.dim()

    # The trailing len(normalized_shape) dimensions are the normalized ones.
    axis = input_ndim - len(normalized_shape)
    inner_dim_indices = list(range(axis, input_ndim))
    outer_dim_indices = list(range(0, axis))

    # N: product of the normalized (inner) extents; M: product of the rest.
    N = 1
    for extent in input_shape[axis:]:
        N *= extent
    M = 1
    for extent in input_shape[:axis]:
        M *= extent
    if M <= 0 or N <= 0:
        # Empty input: every gradient is an all-zero tensor.
        normalized_dims = input_shape[axis:]
        return (
            input.new_zeros(input_shape),
            input.new_zeros(normalized_dims),
            input.new_zeros(normalized_dims),
        )

    mean_, rstd_ = recompute_mean_var(input, rstd, inner_dim_indices, keepdim=True)

    x_hat = (input - mean_) * rstd_
    grad_x_hat = grad_out
    if weight is not None:
        grad_x_hat = grad_out * weight

    grad_sum = torch.sum(grad_x_hat, inner_dim_indices, True)
    grad_x_hat_sum = torch.sum(grad_x_hat * x_hat, inner_dim_indices, True)
    inner = grad_x_hat * N - grad_sum - x_hat * grad_x_hat_sum

    if output_mask[0]:
        d_input: Optional[Tensor] = (rstd_ / N) * inner
    else:
        d_input = torch.zeros_like(input)  # should be None but doesn't work with vjp

    if output_mask[1] and weight is not None:
        if len(outer_dim_indices) > 0:
            d_weight: Optional[Tensor] = torch.sum(
                grad_out * x_hat, outer_dim_indices, False
            )
        else:
            d_weight = grad_out * x_hat
    elif weight is not None:
        d_weight = torch.zeros_like(weight)  # should be None but doesn't work with vjp
    else:
        d_weight = torch.zeros(())  # should be None but doesn't work with vjp

    if output_mask[2] and bias is not None:
        if len(outer_dim_indices) > 0:
            d_bias: Optional[Tensor] = torch.sum(grad_out, outer_dim_indices, False)
        else:
            d_bias = grad_out.clone()
    elif bias is not None:
        d_bias = torch.zeros_like(bias)  # should be None but doesn't work with vjp
    else:
        d_bias = torch.zeros(())  # should be None but doesn't work with vjp

    return (d_input, d_weight, d_bias)
|
203 |
+
|
204 |
+
|
205 |
+
def prod(x: List[int]):
    """Return the product of the integers in ``x`` (1 for an empty list)."""
    # A plain loop is kept on purpose: this helper is called from
    # decompositions that are compiled with torch.jit.script.
    product = 1
    for factor in x:
        product *= factor
    return product
|
210 |
+
|
211 |
+
|
212 |
+
@register_decomposition_for_jvp(aten.native_batch_norm_backward)
def native_batch_norm_backward(
    grad_out: Tensor,
    input: Tensor,
    weight: Optional[Tensor],
    running_mean: Optional[Tensor],
    running_var: Optional[Tensor],
    save_mean: Optional[Tensor],
    save_invstd: Optional[Tensor],
    train: bool,
    eps: float,
    output_mask: List[bool],
) -> Tuple[Tensor, Optional[Tensor], Optional[Tensor]]:
    """
    jvp-specific decomposition of the batch norm backward.

    In training mode the mean and invstd are recomputed from ``input`` via
    ``recompute_mean_var`` (``save_invstd`` is only used to recover eps). In
    eval mode the running statistics and ``eps`` are used instead.

    Returns:
        ``(grad_input, grad_weight, grad_bias)``. ``grad_input`` is always
        computed. Weight and bias gradients masked out by ``output_mask``
        are returned as zero tensors rather than ``None``.
    """
    input_shape = input.shape
    input_rank = input.dim()
    assert input_rank >= 2, "rank of the input must be at least 2"

    # Dimension 1 is the feature dimension; all others are reduced over.
    axis = 1
    num_features = prod(input_shape) / input_shape[axis]  # type: ignore[arg-type]
    mean = save_mean
    invstd = save_invstd
    if train:
        assert (
            save_mean is not None and save_invstd is not None
        ), "when train=True, save_mean and save_invstd are required"

        reduction_dims = [0] + list(range(2, input.dim()))
        assert invstd is not None  # for typing
        mean, invstd = recompute_mean_var(input, invstd, reduction_dims, keepdim=False)
    else:
        assert running_mean is not None and running_var is not None
        mean = running_mean
        invstd = torch.rsqrt(running_var + eps)

    assert invstd is not None and mean is not None

    # Shape that broadcasts a per-feature vector against `input`.
    broadcast_mask = [1] * input_rank
    broadcast_mask[axis] = input_shape[axis]

    # Every dimension except `axis` (which is 1).
    reduction_axes: List[int] = [0] + list(range(2, input_rank))

    mean = torch.reshape(mean, broadcast_mask)
    norm = 1.0 / num_features
    centered = input - mean
    grad_output_sum = torch.sum(grad_out, reduction_axes)
    dot_p = torch.sum(grad_out * centered, reduction_axes)

    grad_mean = torch.reshape(grad_output_sum * norm, broadcast_mask)
    proj_scale = torch.reshape(torch.mul(dot_p * norm, invstd * invstd), broadcast_mask)

    if weight is None:
        grad_scale = torch.reshape(invstd, broadcast_mask) * 1.0
    else:
        grad_scale = torch.reshape(invstd * weight, broadcast_mask)

    if train:
        proj = centered * proj_scale
        grad_input = ((grad_out - proj) - grad_mean) * grad_scale
    else:
        grad_input = grad_out * grad_scale

    if output_mask[1]:
        grad_weight = dot_p * invstd
    elif weight is not None:
        grad_weight = torch.zeros_like(
            weight
        )  # should be None but doesn't work with vjp
    else:
        grad_weight = torch.zeros(())  # should be None but doesn't work with vjp

    if output_mask[2]:
        grad_bias = grad_output_sum
    else:
        grad_bias = torch.zeros_like(
            grad_output_sum
        )  # should be None but doesn't work with vjp

    return (grad_input, grad_weight, grad_bias)
|
292 |
+
|
293 |
+
|
294 |
+
_register_jit_decomposition_for_jvp(torch.ops.aten.trace.default, use_python=True)
|
295 |
+
_register_jit_decomposition_for_jvp(torch.ops.aten.nll_loss_backward.default)
|
296 |
+
_register_jit_decomposition_for_jvp(torch.ops.aten.nll_loss2d_backward.default)
|
297 |
+
_register_jit_decomposition_for_jvp(torch.ops.aten._log_softmax_backward_data.default)
|
298 |
+
_register_jit_decomposition_for_jvp(torch.ops.aten._softmax_backward_data.default)
|
299 |
+
_register_jit_decomposition_for_jvp(torch.ops.aten.log_sigmoid_forward.default)
|
300 |
+
_register_jit_decomposition_for_jvp(torch.ops.aten.native_layer_norm_backward.default)
|
301 |
+
_register_jit_decomposition_for_jvp(torch.ops.aten.native_batch_norm_backward.default)
|
302 |
+
_register_jit_decomposition_for_jvp(torch.ops.aten.cudnn_batch_norm_backward.default)
|
torch/_decomp/decompositions_for_rng.py
ADDED
@@ -0,0 +1,263 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import functools
|
2 |
+
from collections import defaultdict
|
3 |
+
from typing import Callable, Dict
|
4 |
+
|
5 |
+
import torch
|
6 |
+
import torch._decomp as decomp
|
7 |
+
from torch._decomp import get_decompositions
|
8 |
+
from torch._ops import OpOverload
|
9 |
+
|
10 |
+
aten = torch.ops.aten
|
11 |
+
|
12 |
+
rng_decompositions: Dict[str, Dict[OpOverload, Callable]] = defaultdict(dict)
|
13 |
+
|
14 |
+
|
15 |
+
def register_rng_decomposition(aten_op):
    """
    Return a decorator that registers a decomposition for ``aten_op`` with
    ``rng_decompositions`` as the target registry.
    """
    rng_registrar = decomp.register_decomposition(aten_op, rng_decompositions)
    return rng_registrar
|
17 |
+
|
18 |
+
|
19 |
+
def throw_on_non_cuda(device):
    """
    Unconditionally raise a ``RuntimeError`` explaining that RNG operators on
    ``device.type`` cannot be functionalized.

    The caller is responsible for checking the device; this function does
    not inspect it beyond reading ``device.type`` for the message.
    """
    kind = device.type
    message = (
        f"You are trying to functionalize a {kind} RNG operator but {kind} does not "
        f"use Philox/counter-based RNG. Therefore, functionalizing a {kind} RNG operator is "
        "not supported. We are discussing the possibility of a Philox-based RNG implementation for CPU."
    )
    raise RuntimeError(message)
|
25 |
+
|
26 |
+
|
27 |
+
# TODO - We have to register many more distributions here, and also higher level
|
28 |
+
# ops like dropout which have fused implementation and can hide the rand inside.
|
29 |
+
@register_rng_decomposition(aten.rand)
def rand(shape, dtype=None, layout=torch.strided, device=None, pin_memory=False):
    """
    Functional decomposition of ``aten.rand`` on top of
    ``rngprims.philox_rand``.

    The current ``(seed, offset)`` is read from ``PhiloxStateTracker`` and the
    tracker is advanced by the offset jump that ``philox_rand`` reports.
    ``layout`` and ``pin_memory`` are accepted but not used.

    Raises:
        RuntimeError: if ``device`` is given and its type is not ``"cuda"``.
    """
    if device and device.type != "cuda":
        throw_on_non_cuda(device)

    result_dtype = dtype or torch.float32
    seed, offset = PhiloxStateTracker.get_state_as_tuple()
    out, offset_jump = torch.ops.rngprims.philox_rand(
        shape, seed, offset, None, device, result_dtype
    )
    PhiloxStateTracker.advance_offset(offset_jump)
    return out
|
40 |
+
|
41 |
+
|
42 |
+
@register_rng_decomposition(aten.rand_like)
def rand_like(
    x: torch.Tensor,
    dtype=None,
    layout=None,
    device=None,
    pin_memory=False,
    memory_format=torch.preserve_format,
):
    """
    Functional decomposition of ``aten.rand_like`` on top of
    ``rngprims.philox_rand``.

    Device and dtype default to those of ``x``; the output shape is
    ``x.shape``. The current ``(seed, offset)`` is read from
    ``PhiloxStateTracker`` and the tracker is advanced by the reported offset
    jump. ``layout``, ``pin_memory`` and ``memory_format`` are accepted but
    not used.

    Raises:
        RuntimeError: if the resolved device type is not ``"cuda"``.
    """
    target_device = device or x.device
    if target_device.type != "cuda":
        throw_on_non_cuda(target_device)

    result_dtype = dtype or x.dtype
    seed, offset = PhiloxStateTracker.get_state_as_tuple()
    out, offset_jump = torch.ops.rngprims.philox_rand(
        x.shape, seed, offset, None, target_device, result_dtype
    )
    PhiloxStateTracker.advance_offset(offset_jump)
    return out
|
61 |
+
|
62 |
+
|
63 |
+
class PhiloxState:
    """
    Represents a PhiloxRngState - (seed, offset) where offset = base_offset +
    relative_offset. seed and base_offset basically point to the rng state just
    before tracing starts. relative offset tracks the totally consumed offset at
    trace time.
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Return to the unset state: empty seed/base_offset, zero relative offset."""
        self.seed, self.base_offset = torch.tensor(()), torch.tensor(())
        self.relative_offset = 0
        self.offset_advanced_alteast_once = False

    def validate_state(self):
        """Assert that both seed and base_offset hold at least one element."""
        assert not (self.seed.numel() == 0 or self.base_offset.numel() == 0)

    def advance_offset(self, consumed_offset):
        """Add ``consumed_offset`` to the relative offset and flag the advance."""
        self.offset_advanced_alteast_once = True
        # Out-of-place add on purpose: rebinding avoids mutating an offset
        # object that may be shared.
        self.relative_offset = self.relative_offset + consumed_offset

    def set_state(self, seed, base_offset, relative_offset=0):
        """Overwrite seed, base offset and relative offset with the given values."""
        self.seed = seed
        self.base_offset = base_offset
        self.relative_offset = relative_offset

    def get_state_as_tuple(self):
        """Return ``(seed, base_offset + relative_offset)`` after validation."""
        self.validate_state()
        current_offset = self.base_offset + self.relative_offset
        return (self.seed, current_offset)

    def get_state_as_tensor(self):
        """Return seed and current offset stacked into a single tensor."""
        # Only needed because we override get_rng_state.
        self.validate_state()
        components = [self.seed, self.base_offset + self.relative_offset]
        return torch.stack(components)

    def set_state_from_tensor(self, state):
        """Unpack a stacked ``(seed, base_offset)`` tensor; relative offset becomes 0."""
        # Only needed because we override set_rng_state.
        seed, base_offset = torch.unbind(state)
        self.seed = seed
        self.base_offset = base_offset
        self.relative_offset = 0
|
105 |
+
|
106 |
+
|
107 |
+
class PhiloxStateTracker:
    """
    Singleton-style tracker of the philox rng state during AOT Autograd
    tracing. All state lives on the class itself (``running_state``,
    ``fwd_state``, ``bwd_state``), and every operation is a class/static
    method.

    For each aot tracing instance, AOT Autograd resets this tracker and keeps
    track of both forward and backward offsets. At runtime, we only care about
    the total consumed forward and backward offsets. For dynamic shapes, these
    offsets are a function of input shapes. Therefore, the AOT generated graphs
    have additional outputs that compute total consumed forward and backward
    offsets.

    Used as a context manager, the tracker is reset both on entry and on exit.
    """

    running_state: PhiloxState
    fwd_state: PhiloxState
    bwd_state: PhiloxState

    def __enter__(self):
        PhiloxStateTracker.reset()
        return self

    def __exit__(self, exc_type, exc_cal, exc_tb):
        # Returns None, so an in-flight exception is never suppressed.
        PhiloxStateTracker.reset()

    @classmethod
    def reset(cls):
        """Replace the running, forward and backward states with fresh ones."""
        cls.running_state, cls.fwd_state, cls.bwd_state = (
            PhiloxState(),
            PhiloxState(),
            PhiloxState(),
        )

    @classmethod
    def mark_beginning_of_forward(cls):
        """Make ``fwd_state`` the running state."""
        cls.running_state = cls.fwd_state

    @classmethod
    def mark_beginning_of_backward(cls):
        """Make ``bwd_state`` the running state."""
        cls.running_state = cls.bwd_state

    @classmethod
    def record_state(cls, seed, offset, mode):
        """Record the seed and offset tensors for ``mode``.

        These tensors are used to invoke the philox_rand functional
        primitives. ``mode`` must be "forward" or "backward". Recording the
        forward state also switches the running state to forward; recording
        the backward state does not switch the running state.
        """
        if mode != "forward":
            assert mode == "backward"
            cls.bwd_state.set_state(seed, offset)
            return
        cls.fwd_state.set_state(seed, offset)
        cls.mark_beginning_of_forward()

    @classmethod
    def get_state_as_tensor(cls):
        """Return the running state packed into a single tensor."""
        # This exists only because get_rng_state and set_rng_state are
        # overridden during tracing. get_rng_state expects a tensor output, so
        # returning a (seed, offset) tuple would upset other parts of the
        # program, such as ctx.saved_tensors.

        # A bad consequence is that if the user saves and restores the rng
        # state, the generated code is a little ugly: we first concat
        # (seed, offset) into a tensor for get_rng_state, and then split it
        # back into a (seed, offset) tuple in set_rng_state.

        # TODO: Investigate whether there is a better way to wrap the tuple in
        # a false Tensor object, and then desugar it later on.
        running = cls.running_state
        return running.get_state_as_tensor()

    @classmethod
    def get_state_as_tuple(cls):
        """Return the running state via its ``get_state_as_tuple()``."""
        running = cls.running_state
        return running.get_state_as_tuple()

    @classmethod
    def set_state_from_tensor(cls, x):
        """Load the running state from the tensor ``x``."""
        # Only needed because set_rng_state is overridden; see the comment in
        # the get_state_as_tensor method.
        running = cls.running_state
        running.set_state_from_tensor(x)

    @classmethod
    def advance_offset(cls, consumed_offset):
        """Advance the running state's offset by ``consumed_offset``."""
        running = cls.running_state
        running.advance_offset(consumed_offset)

    @classmethod
    def get_current_relative_offset(cls):
        """Return the running state's ``relative_offset``."""
        running = cls.running_state
        return running.relative_offset

    @staticmethod
    def multiple_of_4(offset):
        """Round ``offset`` up to the next multiple of 4."""
        # torch cuda rng state offset must be a multiple of 4. For inductor,
        # as we sum up all the numel, the result might not be a multiple of 4.
        quads = (offset + 3) // 4
        return quads * 4

    @classmethod
    def _updated_offset(cls, state):
        # Shared by the fwd/bwd getters: total offset rounded up to a multiple
        # of 4, or the untouched base offset if no rand ops were observed.
        if state.offset_advanced_alteast_once:
            return cls.multiple_of_4(state.base_offset + state.relative_offset)
        return state.base_offset

    @classmethod
    def get_updated_fwd_offset(cls):
        """Return the forward offset after the traced rand ops were consumed."""
        return cls._updated_offset(cls.fwd_state)

    @classmethod
    def get_updated_bwd_offset(cls):
        """Return the backward offset after the traced rand ops were consumed."""
        return cls._updated_offset(cls.bwd_state)
|
214 |
+
|
215 |
+
|
216 |
+
# Extra decompositions that eventually call rand_like inside their decomps.
# Registering them in rng_decompositions ensures that the rand_like ops used by
# these decomps get functionalized. The list is copied from the inductor
# codebase, which uses it for a similar purpose.
#
# Caution - these decomps do not have the same accuracy as eager. However, we
# can't just disable them with a config flag like fallback_random, because
# functionalizing rng ops requires decomposing these ops.
extra_random_decomps = get_decompositions(
    [
        aten.cauchy, aten.cauchy_,
        aten.exponential, aten.exponential_,
        aten.geometric, aten.geometric_,
        aten.native_dropout,
        aten.normal, aten.normal_, aten.normal_functional,
        aten.log_normal, aten.log_normal_,
        aten.rrelu_with_noise, aten.rrelu_with_noise_,
        aten.uniform_,
    ]
)

# Decorator factory that registers a decomposition into the table above.
register_extra_random_decomp = functools.partial(
    decomp.register_decomposition,
    registry=extra_random_decomps,
)
|
246 |
+
|
247 |
+
|
248 |
+
@register_extra_random_decomp([aten.bernoulli_])
def bernoulli_(self, p=0.5):
    """In-place Bernoulli decomposition built on ``rand_like``.

    Returns ``NotImplemented`` for CPU tensors. Otherwise draws float32
    uniform samples shaped like ``self``, compares them against ``p``, and
    copies the resulting mask into ``self``.
    """
    on_cpu = self.device == torch.device("cpu")
    if on_cpu:
        return NotImplemented
    draws = torch.rand_like(self, dtype=torch.float32)
    return self.copy_(draws < p)
|
253 |
+
|
254 |
+
|
255 |
+
@register_extra_random_decomp([aten.bernoulli.p])
def bernoulli_p(self, p=0.5, *, generator=None):
    """Out-of-place Bernoulli decomposition built on ``rand_like``.

    Returns ``NotImplemented`` for CPU tensors. Otherwise asserts that no
    explicit ``generator`` was passed and returns the boolean mask of float32
    uniform samples (shaped like ``self``) that fall below ``p``.
    """
    on_cpu = self.device == torch.device("cpu")
    if on_cpu:
        return NotImplemented
    assert generator is None
    draws = torch.rand_like(self, dtype=torch.float32)
    return draws < p
|
261 |
+
|
262 |
+
|
263 |
+
# Merge the extra random-op decompositions (including the bernoulli ones
# registered above) into rng_decompositions, which is defined outside this view.
rng_decompositions.update(extra_random_decomps) # type: ignore[arg-type]
|
torch/_deploy.py
ADDED
@@ -0,0 +1,105 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import io
|
2 |
+
|
3 |
+
import torch
|
4 |
+
from torch.package import Importer, OrderedImporter, PackageImporter, sys_importer
|
5 |
+
from torch.package._package_pickler import create_pickler
|
6 |
+
from torch.package._package_unpickler import PackageUnpickler
|
7 |
+
from torch.serialization import _maybe_decode_ascii
|
8 |
+
|
9 |
+
|
10 |
+
def _save_storages(importer, obj):
    """Pickle ``obj`` while keeping storages out of the pickle byte stream.

    Storages met during pickling are collected in a side list and replaced in
    the stream by a persistent id ``("storage", index)``. Objects exposing
    ``__reduce_deploy__`` are replaced by a memoized ``("reduce_deploy", ...)``
    persistent id.

    Args:
        importer: A ``torch.package.PackageImporter`` when ``obj`` comes from
            a package; any other value is treated as "no package".
        obj: The object to pickle.

    Returns:
        A 4-tuple ``(pickle_bytes, serialized_storages, serialized_dtypes,
        zip_reader)``. ``zip_reader`` is ``importer.zip_reader`` for a
        ``PackageImporter`` and ``None`` otherwise.
    """
    serialized_storages = []
    serialized_dtypes = []

    # Decide once whether we are in the packaged case, rather than rebinding
    # `importer` and re-testing it in different ways later on.
    package_importer = (
        importer if isinstance(importer, torch.package.PackageImporter) else None
    )
    importers: Importer
    if package_importer is not None:
        importers = OrderedImporter(package_importer, sys_importer)
    else:
        importers = sys_importer

    def persistent_id(candidate):
        is_typed = isinstance(candidate, torch.storage.TypedStorage)
        if is_typed or torch.is_storage(candidate):
            # TODO: Once we decide to break serialization FC, we can
            # remove the TypedStorage case
            dtype = candidate.dtype if is_typed else torch.uint8

            # The object is recorded exactly as encountered (typed or
            # untyped); _load_storages calls .untyped() on it when rebuilding.
            serialized_storages.append(candidate)
            serialized_dtypes.append(dtype)
            return ("storage", len(serialized_storages) - 1)

        if hasattr(candidate, "__reduce_deploy__"):
            # NOTE(review): the memo is keyed by id(), which is only unique
            # among live objects - confirm callers keep these objects alive.
            key = id(candidate)
            if _serialized_reduces.get(key) is None:
                _serialized_reduces[key] = (
                    "reduce_deploy",
                    key,
                    *candidate.__reduce_deploy__(importers),
                )
            return _serialized_reduces[key]

        return None

    # Write the pickle data for `obj`
    data_buf = io.BytesIO()
    pickler = create_pickler(data_buf, importers)
    pickler.persistent_id = persistent_id
    pickler.dump(obj)
    data_value = data_buf.getvalue()
    return (
        data_value,
        serialized_storages,
        serialized_dtypes,
        package_importer.zip_reader if package_importer is not None else None,
    )
|
59 |
+
|
60 |
+
|
61 |
+
def _load_storages(id, zip_reader, obj_bytes, serialized_storages, serialized_dtypes):
    """Unpickle ``obj_bytes`` and resolve its persistent ids.

    Args:
        id: Key under which the loaded object is stored in ``_deploy_objects``.
        zip_reader: Reader identifying the source package, or ``None`` to
            resolve modules through ``sys_importer`` only.
        obj_bytes: Pickle bytes to load.
        serialized_storages: Storages indexed by ``("storage", index)`` ids.
        serialized_dtypes: Dtypes parallel to ``serialized_storages``.

    Returns:
        The unpickled object (also cached in ``_deploy_objects[id]``).
    """

    def persistent_load(saved_id):
        assert isinstance(saved_id, tuple)
        tag = _maybe_decode_ascii(saved_id[0])
        payload = saved_id[1:]

        if tag == "storage":
            # TODO: Once we decide to break serialization FC, we can
            # stop wrapping with TypedStorage
            index = payload[0]
            raw = serialized_storages[index]
            dtype = serialized_dtypes[index]
            return torch.storage.TypedStorage(wrap_storage=raw.untyped(), dtype=dtype)

        if tag == "reduce_deploy":
            reduce_id, func, args = payload
            # Rebuild each reduced object at most once per reduce_id.
            if reduce_id not in _loaded_reduces:
                package = _raw_packages[zip_reader]
                _loaded_reduces[reduce_id] = func(package, *args)
            return _loaded_reduces[reduce_id]

        return None

    importer: Importer = (
        OrderedImporter(_get_package(zip_reader), sys_importer)
        if zip_reader is not None
        else sys_importer
    )

    unpickler = PackageUnpickler(importer, io.BytesIO(obj_bytes))
    unpickler.persistent_load = persistent_load  # type: ignore[assignment]
    loaded = unpickler.load()
    _deploy_objects[id] = loaded
    return loaded
|
94 |
+
|
95 |
+
|
96 |
+
def _get_package(zip_reader):
    """Return the ``PackageImporter`` for ``zip_reader``, creating it on first use.

    Importers are cached in ``_raw_packages`` keyed by ``zip_reader``.
    """
    package = _raw_packages.get(zip_reader)
    if package is None:
        package = PackageImporter(zip_reader)
        _raw_packages[zip_reader] = package
    return package
|
100 |
+
|
101 |
+
|
102 |
+
# Module-level caches shared by the helpers in this file.
# zip_reader -> PackageImporter, filled by _get_package.
_raw_packages: dict = {}
# id passed to _load_storages -> object it unpickled.
_deploy_objects: dict = {}
# id(obj) -> "reduce_deploy" persistent-id tuple built in _save_storages.
_serialized_reduces: dict = {}
# reduce_id -> object rebuilt by _load_storages.
_loaded_reduces: dict = {}
|
torch/_dispatch/__init__.py
ADDED
File without changes
|