Adi-69s commited on
Commit
b2659ad
1 Parent(s): 0000c2e

Upload 5061 files

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .gitattributes +15 -0
  2. torch/_C.cp310-win_amd64.pyd +0 -0
  3. torch/_C/_VariableFunctions.pyi +0 -0
  4. torch/_C/__init__.pyi +0 -0
  5. torch/_C/_autograd.pyi +123 -0
  6. torch/_C/_cpu.pyi +5 -0
  7. torch/_C/_cudnn.pyi +17 -0
  8. torch/_C/_distributed_autograd.pyi +26 -0
  9. torch/_C/_distributed_c10d.pyi +478 -0
  10. torch/_C/_distributed_rpc.pyi +188 -0
  11. torch/_C/_distributed_rpc_testing.pyi +35 -0
  12. torch/_C/_functions.pyi +11 -0
  13. torch/_C/_functorch.pyi +71 -0
  14. torch/_C/_itt.pyi +5 -0
  15. torch/_C/_lazy.pyi +28 -0
  16. torch/_C/_lazy_ts_backend.pyi +11 -0
  17. torch/_C/_monitor.pyi +44 -0
  18. torch/_C/_nn.pyi +86 -0
  19. torch/_C/_nvtx.pyi +6 -0
  20. torch/_C/_onnx.pyi +38 -0
  21. torch/_C/_profiler.pyi +238 -0
  22. torch/_C/_verbose.pyi +3 -0
  23. torch/_VF.py +30 -0
  24. torch/_VF.pyi +0 -0
  25. torch/__config__.py +22 -0
  26. torch/__future__.py +21 -0
  27. torch/_appdirs.py +666 -0
  28. torch/_awaits/__init__.py +54 -0
  29. torch/_awaits/__pycache__/__init__.cpython-310.pyc +0 -0
  30. torch/_classes.py +55 -0
  31. torch/_compile.py +30 -0
  32. torch/_custom_op/__init__.py +0 -0
  33. torch/_custom_op/__pycache__/__init__.cpython-310.pyc +0 -0
  34. torch/_custom_op/__pycache__/autograd.cpython-310.pyc +0 -0
  35. torch/_custom_op/__pycache__/functional.cpython-310.pyc +0 -0
  36. torch/_custom_op/__pycache__/impl.cpython-310.pyc +0 -0
  37. torch/_custom_op/autograd.py +274 -0
  38. torch/_custom_op/functional.py +187 -0
  39. torch/_custom_op/impl.py +976 -0
  40. torch/_custom_ops.py +322 -0
  41. torch/_decomp/__init__.py +444 -0
  42. torch/_decomp/__pycache__/__init__.cpython-310.pyc +0 -0
  43. torch/_decomp/__pycache__/decompositions.cpython-310.pyc +0 -0
  44. torch/_decomp/__pycache__/decompositions_for_jvp.cpython-310.pyc +0 -0
  45. torch/_decomp/__pycache__/decompositions_for_rng.cpython-310.pyc +0 -0
  46. torch/_decomp/decompositions.py +0 -0
  47. torch/_decomp/decompositions_for_jvp.py +302 -0
  48. torch/_decomp/decompositions_for_rng.py +263 -0
  49. torch/_deploy.py +105 -0
  50. torch/_dispatch/__init__.py +0 -0
.gitattributes CHANGED
@@ -33,3 +33,18 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ torch/bin/fbgemm.dll filter=lfs diff=lfs merge=lfs -text
37
+ torch/bin/protoc.exe filter=lfs diff=lfs merge=lfs -text
38
+ torch/lib/dnnl.lib filter=lfs diff=lfs merge=lfs -text
39
+ torch/lib/fbgemm.dll filter=lfs diff=lfs merge=lfs -text
40
+ torch/lib/fbgemm.lib filter=lfs diff=lfs merge=lfs -text
41
+ torch/lib/fmt.lib filter=lfs diff=lfs merge=lfs -text
42
+ torch/lib/kineto.lib filter=lfs diff=lfs merge=lfs -text
43
+ torch/lib/libiomp5md.dll filter=lfs diff=lfs merge=lfs -text
44
+ torch/lib/libprotobuf-lite.lib filter=lfs diff=lfs merge=lfs -text
45
+ torch/lib/libprotobuf.lib filter=lfs diff=lfs merge=lfs -text
46
+ torch/lib/libprotoc.lib filter=lfs diff=lfs merge=lfs -text
47
+ torch/lib/torch_cpu.dll filter=lfs diff=lfs merge=lfs -text
48
+ torch/lib/torch_cpu.lib filter=lfs diff=lfs merge=lfs -text
49
+ torch/lib/torch_python.dll filter=lfs diff=lfs merge=lfs -text
50
+ torch/lib/XNNPACK.lib filter=lfs diff=lfs merge=lfs -text
torch/_C.cp310-win_amd64.pyd ADDED
Binary file (10.2 kB). View file
 
torch/_C/_VariableFunctions.pyi ADDED
The diff for this file is too large to render. See raw diff
 
torch/_C/__init__.pyi ADDED
The diff for this file is too large to render. See raw diff
 
torch/_C/_autograd.pyi ADDED
@@ -0,0 +1,123 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from enum import Enum
2
+ from typing import Any, Callable, List, Optional, Set
3
+
4
+ import torch
5
+
6
+ from ._profiler import (
7
+ _ProfilerEvent,
8
+ ActiveProfilerType,
9
+ ProfilerActivity,
10
+ ProfilerConfig,
11
+ )
12
+
13
+ # Defined in tools/autograd/init.cpp
14
+
15
+ class DeviceType(Enum):
16
+ CPU = ...
17
+ CUDA = ...
18
+ MKLDNN = ...
19
+ OPENGL = ...
20
+ OPENCL = ...
21
+ IDEEP = ...
22
+ HIP = ...
23
+ FPGA = ...
24
+ ORT = ...
25
+ XLA = ...
26
+ MPS = ...
27
+ HPU = ...
28
+ Meta = ...
29
+ Vulkan = ...
30
+ Metal = ...
31
+ PrivateUse1 = ...
32
+
33
+ class ProfilerEvent:
34
+ def cpu_elapsed_us(self, other: ProfilerEvent) -> float: ...
35
+ def cpu_memory_usage(self) -> int: ...
36
+ def cuda_elapsed_us(self, other: ProfilerEvent) -> float: ...
37
+ def privateuse1_elapsed_us(self, other: ProfilerEvent) -> float: ...
38
+ def cuda_memory_usage(self) -> int: ...
39
+ def device(self) -> int: ...
40
+ def handle(self) -> int: ...
41
+ def has_cuda(self) -> bool: ...
42
+ def is_remote(self) -> bool: ...
43
+ def kind(self) -> int: ...
44
+ def name(self) -> str: ...
45
+ def node_id(self) -> int: ...
46
+ def sequence_nr(self) -> int: ...
47
+ def shapes(self) -> List[List[int]]: ...
48
+ def thread_id(self) -> int: ...
49
+ def flops(self) -> float: ...
50
+ def is_async(self) -> bool: ...
51
+
52
+ class _KinetoEvent:
53
+ def name(self) -> str: ...
54
+ def device_index(self) -> int: ...
55
+ def start_us(self) -> int: ...
56
+ def duration_us(self) -> int: ...
57
+ def is_async(self) -> bool: ...
58
+ def linked_correlation_id(self) -> int: ...
59
+ def shapes(self) -> List[List[int]]: ...
60
+ def dtypes(self) -> List[str]: ...
61
+ def concrete_inputs(self) -> List[Any]: ...
62
+ def device_type(self) -> DeviceType: ...
63
+ def start_thread_id(self) -> int: ...
64
+ def end_thread_id(self) -> int: ...
65
+ def correlation_id(self) -> int: ...
66
+ def fwd_thread_id(self) -> int: ...
67
+ def stack(self) -> List[str]: ...
68
+ def scope(self) -> int: ...
69
+ def sequence_nr(self) -> int: ...
70
+ def flops(self) -> int: ...
71
+ def cuda_elapsed_us(self) -> int: ...
72
+ def privateuse1_elapsed_us(self) -> int: ...
73
+
74
+ class _ProfilerResult:
75
+ def events(self) -> List[_KinetoEvent]: ...
76
+ def legacy_events(self) -> List[List[ProfilerEvent]]: ...
77
+ def save(self, path: str) -> None: ...
78
+ def experimental_event_tree(self) -> List[_ProfilerEvent]: ...
79
+ def trace_start_us(self) -> int: ...
80
+
81
+ class SavedTensor: ...
82
+
83
+ def _enable_profiler(
84
+ config: ProfilerConfig,
85
+ activities: Set[ProfilerActivity],
86
+ ) -> None: ...
87
+ def _prepare_profiler(
88
+ config: ProfilerConfig,
89
+ activities: Set[ProfilerActivity],
90
+ ) -> None: ...
91
+ def _disable_profiler() -> _ProfilerResult: ...
92
+ def _profiler_enabled() -> bool: ...
93
+ def _add_metadata_json(key: str, value: str) -> None: ...
94
+ def _kineto_step() -> None: ...
95
+ def _get_sequence_nr() -> int: ...
96
+ def kineto_available() -> bool: ...
97
+ def _record_function_with_args_enter(name: str, *args) -> torch.Tensor: ...
98
+ def _record_function_with_args_exit(handle: torch.Tensor) -> None: ...
99
+ def _supported_activities() -> Set[ProfilerActivity]: ...
100
+ def _enable_record_function(enable: bool) -> None: ...
101
+ def _set_empty_test_observer(is_global: bool, sampling_prob: float) -> None: ...
102
+ def _push_saved_tensors_default_hooks(
103
+ pack_hook: Callable[[torch.Tensor], Any],
104
+ unpack_hook: Callable[[Any], torch.Tensor],
105
+ ) -> None: ...
106
+ def _pop_saved_tensors_default_hooks() -> None: ...
107
+ def _unsafe_set_version_counter(t: torch.Tensor, prev_version: int) -> None: ...
108
+ def _enable_profiler_legacy(config: ProfilerConfig) -> None: ...
109
+ def _disable_profiler_legacy() -> List[List[ProfilerEvent]]: ...
110
+ def _profiler_type() -> ActiveProfilerType: ...
111
+ def _saved_tensors_hooks_enable() -> None: ...
112
+ def _saved_tensors_hooks_disable(message: str) -> None: ...
113
+ def _saved_tensors_hooks_get_disabled_error_message() -> Optional[str]: ...
114
+
115
+ class CreationMeta(Enum):
116
+ DEFAULT = ...
117
+ IN_CUSTOM_FUNCTION = ...
118
+ MULTI_OUTPUT_NODE = ...
119
+ NO_GRAD_MODE = ...
120
+ INFERENCE_MODE = ...
121
+
122
+ def _set_creation_meta(t: torch.Tensor, creation_meta: CreationMeta) -> None: ...
123
+ def _get_creation_meta(t: torch.Tensor) -> CreationMeta: ...
torch/_C/_cpu.pyi ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ from torch.types import _bool
2
+
3
+ # Defined in torch/csrc/cpu/Module.cpp
4
+
5
+ def _is_cpu_support_vnni() -> _bool: ...
torch/_C/_cudnn.pyi ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from enum import Enum
2
+
3
+ from torch.types import _bool, Tuple
4
+
5
+ # Defined in torch/csrc/cuda/shared/cudnn.cpp
6
+ is_cuda: _bool
7
+
8
+ def getRuntimeVersion() -> Tuple[int, int, int]: ...
9
+ def getCompileVersion() -> Tuple[int, int, int]: ...
10
+ def getVersionInt() -> int: ...
11
+
12
+ class RNNMode(int, Enum):
13
+ value: int
14
+ rnn_relu = ...
15
+ rnn_tanh = ...
16
+ lstm = ...
17
+ gru = ...
torch/_C/_distributed_autograd.pyi ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Any, Dict, List, Set
2
+
3
+ import torch
4
+
5
+ # This module is defined in torch/csrc/distributed/autograd/init.cpp
6
+
7
+ class DistAutogradContext:
8
+ def _context_id(self) -> int: ...
9
+ def _recv_functions(self) -> Dict[int, Any]: ...
10
+ def _send_functions(self) -> Dict[int, Any]: ...
11
+ def _known_worker_ids(self) -> Set[int]: ...
12
+
13
+ def _new_context() -> DistAutogradContext: ...
14
+ def _release_context(context_id: int) -> None: ...
15
+ def _get_max_id() -> int: ...
16
+ def _is_valid_context(worker_id: int) -> bool: ...
17
+ def _retrieve_context(context_id: int) -> DistAutogradContext: ...
18
+ def _current_context() -> DistAutogradContext: ...
19
+ def _init(worker_id: int) -> None: ...
20
+ def _get_debug_info() -> Dict[str, str]: ...
21
+ def backward(
22
+ context_id: int,
23
+ roots: List[torch.Tensor],
24
+ retain_graph=False,
25
+ ) -> None: ...
26
+ def get_gradients(context_id: int) -> Dict[torch.Tensor, torch.Tensor]: ...
torch/_C/_distributed_c10d.pyi ADDED
@@ -0,0 +1,478 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # mypy: disable-error-code="type-arg"
2
+ from datetime import timedelta
3
+ from enum import Enum
4
+ from typing import Any, Dict, List, Optional, overload, Tuple, Union
5
+
6
+ from torch import Tensor
7
+ from torch._C import ScriptObject
8
+ from torch.futures import Future
9
+
10
+ # This module is defined in torch/csrc/distributed/c10d/init.cpp
11
+
12
+ _DEFAULT_FIRST_BUCKET_BYTES: int
13
+ _DEFAULT_NO_TIMEOUT: timedelta
14
+ _DEFAULT_PG_TIMEOUT: timedelta
15
+ _DEFAULT_PG_NCCL_TIMEOUT: timedelta
16
+
17
+ class BuiltinCommHookType(Enum):
18
+ ALLREDUCE = ...
19
+ FP16_COMPRESS = ...
20
+
21
+ def _register_comm_hook(reducer: Reducer, state: Any, comm_hook: Any): ...
22
+ def _register_builtin_comm_hook(
23
+ reducer: Reducer,
24
+ comm_hook_type: BuiltinCommHookType,
25
+ ): ...
26
+
27
+ class GradBucket:
28
+ def index(self) -> int: ...
29
+ def buffer(self) -> Tensor: ...
30
+ def gradients(self) -> List[Tensor]: ...
31
+ def is_last(self) -> bool: ...
32
+ def set_buffer(self, tensor: Tensor) -> None: ...
33
+ def parameters(self) -> List[Tensor]: ...
34
+
35
+ class Reducer:
36
+ def __init__(
37
+ self,
38
+ params: List[Tensor],
39
+ bucket_indices: List[List[int]],
40
+ per_bucket_size_limits: List[int],
41
+ process_group: ProcessGroup,
42
+ expect_sparse_gradients: List[bool] = ...,
43
+ bucket_bytes_cap: int = ..., # kDefaultBucketBytesCap in reducer.hpp
44
+ find_unused_parameters: bool = ...,
45
+ gradient_as_bucket_view: bool = ...,
46
+ param_to_name_mapping: Dict[int, str] = ...,
47
+ first_bucket_types_cap: int = ..., # kDefaultFirstBucketBytes in reducer.hpp
48
+ ): ...
49
+ def prepare_for_forward(self) -> None: ...
50
+ def prepare_for_backward(self, output: List[Tensor]) -> None: ...
51
+ def get_backward_stats(self) -> List[int]: ...
52
+ def _install_post_backward_futures(self, futures: List[Future]) -> None: ...
53
+ def _rebuild_buckets(self) -> bool: ...
54
+ def _get_zeros_like_grad_buckets(self) -> List[GradBucket]: ...
55
+ def _push_all_rebuilt_params(self) -> None: ...
56
+ def _set_forward_pass_work_handle(
57
+ self,
58
+ work: Work,
59
+ use_static_world_size: bool,
60
+ ): ...
61
+ def _get_local_used_map(self) -> Tensor: ...
62
+ def _set_ddp_runtime_logging_sample_rate(self, sample_rate: int) -> None: ...
63
+ def _set_static_graph(self) -> None: ...
64
+ def _run_comm_hook(self, bucket: GradBucket) -> Future: ...
65
+ def set_logger(self, logger: Logger) -> None: ...
66
+ def _remove_autograd_hooks(self) -> None: ...
67
+ def _check_reducer_finalized(self) -> None: ...
68
+ def _set_sparse_metadata(self, global_unique_ids: Dict[str, Tensor]) -> None: ...
69
+ def _reset_state(self) -> None: ...
70
+ def _update_process_group(self, new_process_group: ProcessGroup) -> None: ...
71
+
72
+ class DDPLoggingData:
73
+ strs_map: Dict[str, str]
74
+ ints_map: Dict[str, int]
75
+
76
+ class Logger:
77
+ def __init__(self, reducer: Reducer): ...
78
+ def set_construction_data_and_log(
79
+ self,
80
+ module_name: str,
81
+ device_ids: List[int],
82
+ output_device: int,
83
+ broadcast_buffers: bool,
84
+ has_sync_bn: bool,
85
+ static_graph: bool,
86
+ ): ...
87
+ def set_runtime_stats_and_log(self) -> None: ...
88
+ def set_error_and_log(self, error: str) -> None: ...
89
+ def _get_ddp_logging_data(self) -> DDPLoggingData: ...
90
+ def _set_comm_hook_name(self, comm_hook: str) -> None: ...
91
+ def _set_uneven_input_join(self) -> None: ...
92
+ def _set_static_graph(self) -> None: ...
93
+
94
+ def get_debug_level(): ...
95
+ def set_debug_level(): ...
96
+ def set_debug_level_from_env(): ...
97
+
98
+ class DebugLevel(Enum):
99
+ OFF = ...
100
+ INFO = ...
101
+ DETAIL = ...
102
+
103
+ class ReduceOp:
104
+ def __init__(self, op: RedOpType): ...
105
+
106
+ SUM: RedOpType = ...
107
+ AVG: RedOpType = ...
108
+ PRODUCT: RedOpType = ...
109
+ MIN: RedOpType = ...
110
+ MAX: RedOpType = ...
111
+ BAND: RedOpType = ...
112
+ BOR: RedOpType = ...
113
+ BXOR: RedOpType = ...
114
+ PREMUL_SUM: RedOpType = ...
115
+ UNUSED: RedOpType = ...
116
+
117
+ class RedOpType(Enum): ...
118
+
119
+ class BroadcastOptions:
120
+ rootRank: int
121
+ rootTensor: int
122
+ timeout: timedelta
123
+ asyncOp: bool
124
+
125
+ class AllreduceOptions:
126
+ reduceOp: ReduceOp
127
+ timeout: timedelta
128
+
129
+ class AllreduceCoalescedOptions(AllreduceOptions): ...
130
+
131
+ class ReduceOptions:
132
+ reduceOp: ReduceOp
133
+ rootRank: int
134
+ rootTensor: int
135
+ timeout: timedelta
136
+
137
+ class AllgatherOptions:
138
+ timeout: timedelta
139
+ asyncOp: bool
140
+
141
+ class GatherOptions:
142
+ rootRank: int
143
+ timeout: timedelta
144
+
145
+ class ScatterOptions:
146
+ rootRank: int
147
+ timeout: timedelta
148
+ asyncOp: bool
149
+
150
+ class ReduceScatterOptions:
151
+ reduceOp: ReduceOp
152
+ timeout: timedelta
153
+ asyncOp: bool
154
+
155
+ class BarrierOptions:
156
+ device_ids: List[int]
157
+ timeout: timedelta
158
+
159
+ class AllToAllOptions:
160
+ timeout: timedelta
161
+
162
+ class Store:
163
+ def set(self, key: str, value: str): ...
164
+ def get(self, key: str) -> bytes: ...
165
+ def add(self, key: str, value: int) -> int: ...
166
+ def compare_set(
167
+ self,
168
+ key: str,
169
+ expected_value: str,
170
+ desired_value: str,
171
+ ) -> bytes: ...
172
+ def delete_key(self, key: str) -> bool: ...
173
+ def num_keys(self) -> int: ...
174
+ def set_timeout(self, timeout: timedelta): ...
175
+ @overload
176
+ def wait(self, keys: List[str]): ...
177
+ @overload
178
+ def wait(self, keys: List[str], timeout: timedelta): ...
179
+
180
+ class FileStore(Store):
181
+ def __init__(self, path: str, numWorkers: int = ...): ...
182
+
183
+ class HashStore(Store):
184
+ def __init__(self): ...
185
+
186
+ class TCPStore(Store):
187
+ def __init__(
188
+ self,
189
+ host_name: str,
190
+ port: int,
191
+ world_size: Optional[int] = ...,
192
+ is_master: bool = ...,
193
+ timeout: timedelta = ...,
194
+ wait_for_workers: bool = ...,
195
+ multi_tenant: bool = ...,
196
+ master_listen_fd: Optional[int] = ...,
197
+ use_libuv: Optional[bool] = ...,
198
+ ): ...
199
+ @property
200
+ def host(self) -> str: ...
201
+ @property
202
+ def port(self) -> int: ...
203
+
204
+ class PrefixStore(Store):
205
+ def __init__(self, prefix: str, store: Store): ...
206
+ @property
207
+ def underlying_store(self) -> Store: ...
208
+
209
+ class Work:
210
+ def is_completed(self) -> bool: ...
211
+ def is_success(self) -> bool: ...
212
+ def exception(self) -> Any: ...
213
+ def wait(self, timeout: timedelta = ...) -> bool: ...
214
+ def source_rank(self) -> int: ...
215
+ def _source_rank(self) -> int: ...
216
+ def result(self) -> List[Tensor]: ...
217
+ def synchronize(self): ...
218
+ def boxed(self) -> ScriptObject: ...
219
+ @staticmethod
220
+ def unbox(obj: ScriptObject) -> Work: ...
221
+
222
+ class ProcessGroup:
223
+ class Options: ...
224
+
225
+ def __init__(self): ...
226
+ def rank(self) -> int: ...
227
+ def size(self) -> int: ...
228
+ @overload
229
+ def broadcast(
230
+ self,
231
+ tensors: List[Tensor],
232
+ opts=...,
233
+ ) -> Work: ...
234
+ @overload
235
+ def broadcast(
236
+ self,
237
+ tensor: Tensor,
238
+ root: int,
239
+ ) -> Work: ...
240
+ @overload
241
+ def allreduce(
242
+ self,
243
+ tensors: List[Tensor],
244
+ opts: AllreduceOptions = ...,
245
+ ) -> Work: ...
246
+ @overload
247
+ def allreduce(
248
+ self,
249
+ tensors: List[Tensor],
250
+ op=...,
251
+ ) -> Work: ...
252
+ @overload
253
+ def allreduce(
254
+ self,
255
+ tensor: Tensor,
256
+ op=...,
257
+ ) -> Work: ...
258
+ def allreduce_coalesced(
259
+ self,
260
+ tensors: List[Tensor],
261
+ opts=...,
262
+ ) -> Work: ...
263
+ @overload
264
+ def reduce(
265
+ self,
266
+ tensors: List[Tensor],
267
+ opts=...,
268
+ ) -> Work: ...
269
+ @overload
270
+ def reduce(
271
+ self,
272
+ tensor: Tensor,
273
+ root: int,
274
+ op=...,
275
+ ) -> Work: ...
276
+ @overload
277
+ def allgather(
278
+ self,
279
+ output_tensors: List[List[Tensor]],
280
+ input_tensors: List[Tensor],
281
+ opts=...,
282
+ ) -> Work: ...
283
+ @overload
284
+ def allgather(
285
+ self,
286
+ output_tensors: List[Tensor],
287
+ input_tensor: Tensor,
288
+ ) -> Work: ...
289
+ def _allgather_base(
290
+ self,
291
+ output: Tensor,
292
+ input: Tensor,
293
+ opts=...,
294
+ ) -> Work: ...
295
+ def allgather_coalesced(
296
+ self,
297
+ output_lists: List[List[Tensor]],
298
+ input_list: List[Tensor],
299
+ opts=...,
300
+ ) -> Work: ...
301
+ @overload
302
+ def gather(
303
+ self,
304
+ output_tensors: List[List[Tensor]],
305
+ input_tensors: List[Tensor],
306
+ opts=...,
307
+ ) -> Work: ...
308
+ @overload
309
+ def gather(
310
+ self,
311
+ output_tensors: List[Tensor],
312
+ input_tensor: Tensor,
313
+ root: int,
314
+ ) -> Work: ...
315
+ @overload
316
+ def scatter(
317
+ self,
318
+ output_tensors: List[Tensor],
319
+ input_tensors: List[List[Tensor]],
320
+ opts=...,
321
+ ) -> Work: ...
322
+ @overload
323
+ def scatter(
324
+ self,
325
+ output_tensor: Tensor,
326
+ input_tensors: List[Tensor],
327
+ root: int,
328
+ ) -> Work: ...
329
+ @overload
330
+ def reduce_scatter(
331
+ self,
332
+ output_tensors: List[Tensor],
333
+ input_tensors: List[List[Tensor]],
334
+ opts=...,
335
+ ) -> Work: ...
336
+ @overload
337
+ def reduce_scatter(
338
+ self,
339
+ output_tensors: Tensor,
340
+ input_tensor: List[Tensor],
341
+ ) -> Work: ...
342
+ def _reduce_scatter_base(
343
+ self,
344
+ outputTensor: Tensor,
345
+ inputTensor: Tensor,
346
+ ) -> Work: ...
347
+ @overload
348
+ def alltoall_base(
349
+ self,
350
+ output_tensor: Tensor,
351
+ input_tensor: Tensor,
352
+ output_split_sizes: List[int],
353
+ input_split_sizes: List[int],
354
+ opts=...,
355
+ ) -> Work: ...
356
+ @overload
357
+ def alltoall_base(
358
+ self,
359
+ output: Tensor,
360
+ input: Tensor,
361
+ output_split_sizes: List[int],
362
+ input_split_sizes: List[int],
363
+ ) -> Work: ...
364
+ @overload
365
+ def alltoall(
366
+ self,
367
+ output_tensor: List[Tensor],
368
+ input_tensor: List[Tensor],
369
+ opts=...,
370
+ ) -> Work: ...
371
+ @overload
372
+ def alltoall(
373
+ self,
374
+ output: List[Tensor],
375
+ input: List[Tensor],
376
+ ) -> Work: ...
377
+ def send(
378
+ self,
379
+ tensors: List[Tensor],
380
+ dstRank: int,
381
+ tag: int,
382
+ ) -> Work: ...
383
+ def recv(
384
+ self,
385
+ tensors: List[Tensor],
386
+ srcRank: int,
387
+ tag: int,
388
+ ) -> Work: ...
389
+ def recv_anysource(self, tensors: List[Tensor], tag: int) -> Work: ...
390
+ def barrier(self, opts=...) -> Work: ...
391
+ def boxed(self) -> ScriptObject: ...
392
+ @staticmethod
393
+ def unbox(obj: ScriptObject) -> ProcessGroup: ...
394
+
395
+ class ProcessGroupRoundRobin(ProcessGroup): ...
396
+
397
+ def _round_robin_process_groups(
398
+ process_groups: List[ProcessGroup],
399
+ ) -> ProcessGroupRoundRobin: ...
400
+
401
+ class ProcessGroupGloo(ProcessGroup):
402
+ class Device: ...
403
+ class Options: ...
404
+
405
+ def __init__(
406
+ self,
407
+ store: Store,
408
+ rank: int,
409
+ size: int,
410
+ timeout: timedelta,
411
+ ): ...
412
+ @staticmethod
413
+ def create_device(hostname="", interface="") -> Device: ...
414
+ @staticmethod
415
+ def create_default_device() -> Device: ...
416
+
417
+ class _ProcessGroupWrapper(ProcessGroup):
418
+ def __init__(self, pg: ProcessGroup, gloo_pg: ProcessGroupGloo): ...
419
+ wrapped_pg: ProcessGroup
420
+
421
+ class ProcessGroupNCCL(ProcessGroup):
422
+ class Options: ...
423
+
424
+ def __init__(
425
+ self,
426
+ store: Store,
427
+ rank: int,
428
+ size: int,
429
+ timeout: timedelta,
430
+ ): ...
431
+ def _group_start(self) -> None: ...
432
+ def _group_end(self) -> None: ...
433
+
434
+ class ProcessGroupUCC(ProcessGroup):
435
+ def __init__(
436
+ self,
437
+ store: Store,
438
+ rank: int,
439
+ size: int,
440
+ timeout: timedelta,
441
+ ): ...
442
+
443
+ class ProcessGroupMPI(ProcessGroup):
444
+ def __init__(
445
+ self,
446
+ rank: int,
447
+ size: int,
448
+ pgComm: int,
449
+ ): ...
450
+ @staticmethod
451
+ def create(ranks: List[int]) -> ProcessGroupMPI: ...
452
+
453
+ def _compute_bucket_assignment_by_size(
454
+ tensors: List[Tensor],
455
+ bucket_size_limits: List[int],
456
+ expect_sparse_gradient: List[bool] = ...,
457
+ tensor_indices: List[int] = ...,
458
+ ) -> Tuple[List[List[int]], List[int]]: ...
459
+ def _broadcast_coalesced(
460
+ process_group: ProcessGroup,
461
+ tensors: List[Tensor],
462
+ buffer_size: int,
463
+ src: int,
464
+ ): ...
465
+ def _test_python_store(store: Store): ...
466
+ def _verify_params_across_processes(
467
+ process_group: ProcessGroup,
468
+ params: List[Tensor],
469
+ logger: Optional[Logger],
470
+ ): ...
471
+ def _make_nccl_premul_sum(factor: Union[float, List[Tensor]]) -> ReduceOp: ...
472
+
473
+ class Backend:
474
+ def __init__(
475
+ self,
476
+ rank: int,
477
+ size: int,
478
+ ): ...
torch/_C/_distributed_rpc.pyi ADDED
@@ -0,0 +1,188 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # mypy: disable-error-code="type-arg"
2
+ from datetime import timedelta
3
+ from typing import Any, Dict, Generic, List, Optional, overload, Tuple, Type, TypeVar
4
+
5
+ import torch
6
+
7
+ from . import Future
8
+ from ._autograd import ProfilerEvent
9
+ from ._distributed_c10d import Store
10
+ from ._profiler import ProfilerConfig
11
+
12
+ # This module is defined in torch/csrc/distributed/rpc/init.cpp
13
+
14
+ _DEFAULT_INIT_METHOD: str
15
+ _DEFAULT_NUM_WORKER_THREADS: int
16
+ _UNSET_RPC_TIMEOUT: float
17
+ _DEFAULT_RPC_TIMEOUT_SEC: float
18
+
19
+ _T = TypeVar("_T")
20
+
21
+ class RpcBackendOptions:
22
+ rpc_timeout: float
23
+ init_method: str
24
+ def __init__(
25
+ self,
26
+ rpc_timeout: float = ...,
27
+ init_method: str = ...,
28
+ ): ...
29
+
30
+ class WorkerInfo:
31
+ def __init__(self, name: str, worker_id: int): ...
32
+ @property
33
+ def name(self) -> str: ...
34
+ @property
35
+ def id(self) -> int: ...
36
+ def __eq__(self, other: object) -> bool: ...
37
+
38
+ class RpcAgent:
39
+ def join(self, shutdown: bool = False, timeout: float = 0): ...
40
+ def sync(self): ...
41
+ def shutdown(self): ...
42
+ @overload
43
+ def get_worker_info(self) -> WorkerInfo: ...
44
+ @overload
45
+ def get_worker_info(self, workerName: str) -> WorkerInfo: ...
46
+ def get_worker_infos(self) -> List[WorkerInfo]: ...
47
+ def _get_device_map(self, dst: WorkerInfo) -> Dict[torch.device, torch.device]: ...
48
+ def get_debug_info(self) -> Dict[str, str]: ...
49
+ def get_metrics(self) -> Dict[str, str]: ...
50
+
51
+ class PyRRef(Generic[_T]):
52
+ def __init__(self, value: _T, type_hint: Any = None) -> None: ...
53
+ def is_owner(self) -> bool: ...
54
+ def confirmed_by_owner(self) -> bool: ...
55
+ def owner(self) -> WorkerInfo: ...
56
+ def owner_name(self) -> str: ...
57
+ def to_here(self, timeout: float = ...) -> _T: ...
58
+ def local_value(self) -> Any: ...
59
+ def rpc_sync(self, timeout: float = ...) -> Any: ...
60
+ def rpc_async(self, timeout: float = ...) -> Any: ...
61
+ def remote(self, timeout: float = ...) -> Any: ...
62
+ def _serialize(self) -> Tuple: ...
63
+ @staticmethod
64
+ def _deserialize(tp: Tuple) -> PyRRef: ...
65
+ def _get_type(self) -> Type[_T]: ...
66
+ def _get_future(self) -> Future[_T]: ...
67
+ def _get_profiling_future(self) -> Future[_T]: ...
68
+ def _set_profiling_future(self, profilingFuture: Future[_T]): ...
69
+
70
+ class _TensorPipeRpcBackendOptionsBase(RpcBackendOptions):
71
+ num_worker_threads: int
72
+ device_maps: Dict[str, Dict[torch.device, torch.device]]
73
+ devices: List[torch.device]
74
+ def __init__(
75
+ self,
76
+ num_worker_threads: int,
77
+ _transports: Optional[List],
78
+ _channels: Optional[List],
79
+ rpc_timeout: float = ...,
80
+ init_method: str = ...,
81
+ device_maps: Dict[str, Dict[torch.device, torch.device]] = {}, # noqa: B006
82
+ devices: List[torch.device] = [], # noqa: B006
83
+ ): ...
84
+ def _set_device_map(
85
+ self,
86
+ to: str,
87
+ device_map: Dict[torch.device, torch.device],
88
+ ): ...
89
+
90
+ class TensorPipeAgent(RpcAgent):
91
+ def __init__(
92
+ self,
93
+ store: Store,
94
+ name: str,
95
+ worker_id: int,
96
+ world_size: Optional[int],
97
+ opts: _TensorPipeRpcBackendOptionsBase,
98
+ reverse_device_maps: Dict[str, Dict[torch.device, torch.device]],
99
+ devices: List[torch.device],
100
+ ): ...
101
+ def join(self, shutdown: bool = False, timeout: float = 0): ...
102
+ def shutdown(self): ...
103
+ @overload
104
+ def get_worker_info(self) -> WorkerInfo: ...
105
+ @overload
106
+ def get_worker_info(self, workerName: str) -> WorkerInfo: ...
107
+ @overload
108
+ def get_worker_info(self, id: int) -> WorkerInfo: ...
109
+ def get_worker_infos(self) -> List[WorkerInfo]: ...
110
+ def _get_device_map(self, dst: WorkerInfo) -> Dict[torch.device, torch.device]: ...
111
+ def _update_group_membership(
112
+ self,
113
+ worker_info: WorkerInfo,
114
+ my_devices: List[torch.device],
115
+ reverse_device_map: Dict[str, Dict[torch.device, torch.device]],
116
+ is_join: bool,
117
+ ): ...
118
+ def _get_backend_options(self) -> _TensorPipeRpcBackendOptionsBase: ...
119
+ @property
120
+ def is_static_group(self) -> bool: ...
121
+ @property
122
+ def store(self) -> Store: ...
123
+
124
+ def _is_current_rpc_agent_set() -> bool: ...
125
+ def _get_current_rpc_agent() -> RpcAgent: ...
126
+ def _set_and_start_rpc_agent(agent: RpcAgent): ...
127
+ def _reset_current_rpc_agent(): ...
128
+ def _delete_all_user_and_unforked_owner_rrefs(timeout: timedelta = ...): ...
129
+ def _destroy_rref_context(ignoreRRefLeak: bool): ...
130
+ def _rref_context_get_debug_info() -> Dict[str, str]: ...
131
+ def _cleanup_python_rpc_handler(): ...
132
+ def _invoke_rpc_builtin(
133
+ dst: WorkerInfo,
134
+ opName: str,
135
+ rpcTimeoutSeconds: float,
136
+ *args: Any,
137
+ **kwargs: Any,
138
+ ): ...
139
+ def _invoke_rpc_python_udf(
140
+ dst: WorkerInfo,
141
+ pickledPythonUDF: str,
142
+ tensors: List[torch.Tensor],
143
+ rpcTimeoutSeconds: float,
144
+ isAsyncExecution: bool,
145
+ ): ...
146
+ def _invoke_rpc_torchscript(
147
+ dstWorkerName: str,
148
+ qualifiedNameStr: str,
149
+ argsTuple: Tuple,
150
+ kwargsDict: Dict,
151
+ rpcTimeoutSeconds: float,
152
+ isAsyncExecution: bool,
153
+ ): ...
154
+ def _invoke_remote_builtin(
155
+ dst: WorkerInfo,
156
+ opName: str,
157
+ rpcTimeoutSeconds: float,
158
+ *args: Any,
159
+ **kwargs: Any,
160
+ ): ...
161
+ def _invoke_remote_python_udf(
162
+ dst: WorkerInfo,
163
+ pickledPythonUDF: str,
164
+ tensors: List[torch.Tensor],
165
+ rpcTimeoutSeconds: float,
166
+ isAsyncExecution: bool,
167
+ ): ...
168
+ def _invoke_remote_torchscript(
169
+ dstWorkerName: WorkerInfo,
170
+ qualifiedNameStr: str,
171
+ rpcTimeoutSeconds: float,
172
+ isAsyncExecution: bool,
173
+ *args: Any,
174
+ **kwargs: Any,
175
+ ): ...
176
+ def get_rpc_timeout() -> float: ...
177
+ def enable_gil_profiling(flag: bool): ...
178
+ def _set_rpc_timeout(rpcTimeoutSeconds: float): ...
179
+
180
+ class RemoteProfilerManager:
181
+ @staticmethod
182
+ def set_current_profiling_key(key: str): ...
183
+
184
+ def _enable_server_process_global_profiler(new_config: ProfilerConfig): ...
185
+ def _disable_server_process_global_profiler() -> List[List[List[ProfilerEvent]]]: ...
186
+ def _set_profiler_node_id(default_node_id: int): ...
187
+ def _enable_jit_rref_pickle(): ...
188
+ def _disable_jit_rref_pickle(): ...
torch/_C/_distributed_rpc_testing.pyi ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Dict, List
2
+
3
+ import torch
4
+
5
+ from ._distributed_c10d import Store
6
+ from ._distributed_rpc import _TensorPipeRpcBackendOptionsBase, TensorPipeAgent
7
+
8
+ # This module is defined in torch/csrc/distributed/rpc/testing/init.cpp
9
+
10
+ class FaultyTensorPipeRpcBackendOptions(_TensorPipeRpcBackendOptionsBase):
11
+ def __init__(
12
+ self,
13
+ num_worker_threads: int,
14
+ rpc_timeout: float,
15
+ init_method: str,
16
+ messages_to_fail: List[str],
17
+ messages_to_delay: Dict[str, float],
18
+ num_fail_sends: int,
19
+ ): ...
20
+ num_send_recv_threads: int
21
+ messages_to_fail: List[str]
22
+ messages_to_delay: Dict[str, float]
23
+ num_fail_sends: int
24
+
25
+ class FaultyTensorPipeAgent(TensorPipeAgent):
26
+ def __init__(
27
+ self,
28
+ store: Store,
29
+ name: str,
30
+ rank: int,
31
+ world_size: int,
32
+ options: FaultyTensorPipeRpcBackendOptions,
33
+ reverse_device_maps: Dict[str, Dict[torch.device, torch.device]],
34
+ devices: List[torch.device],
35
+ ): ...
torch/_C/_functions.pyi ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import AnyStr, List
2
+
3
+ from torch import Tensor
4
+
5
+ class UndefinedGrad:
6
+ def __init__(self) -> None: ...
7
+ def __call__(self, *inputs: Tensor) -> List[Tensor]: ...
8
+
9
+ class DelayedError:
10
+ def __init__(self, msg: AnyStr, num_inputs: int) -> None: ...
11
+ def __call__(self, inputs: List[Tensor]) -> List[Tensor]: ...
torch/_C/_functorch.pyi ADDED
@@ -0,0 +1,71 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from enum import Enum
2
+ from typing import Optional, Tuple
3
+
4
+ from torch import Tensor
5
+
6
+ # Defined in torch/csrc/functorch/init.cpp
7
+
8
+ def _set_dynamic_layer_keys_included(included: bool) -> None: ...
9
+ def get_unwrapped(tensor: Tensor) -> Tensor: ...
10
+ def is_batchedtensor(tensor: Tensor) -> bool: ...
11
+ def is_functionaltensor(tensor: Tensor) -> bool: ...
12
+ def is_functorch_wrapped_tensor(tensor: Tensor) -> bool: ...
13
+ def is_gradtrackingtensor(tensor: Tensor) -> bool: ...
14
+ def maybe_get_bdim(tensor: Tensor) -> int: ...
15
+ def maybe_get_level(tensor: Tensor) -> int: ...
16
+ def unwrap_if_dead(tensor: Tensor) -> Tensor: ...
17
+ def _unwrap_for_grad(tensor: Tensor, level: int) -> Tensor: ...
18
+ def _wrap_for_grad(tensor: Tensor, level: int) -> Tensor: ...
19
+ def _unwrap_batched(tensor: Tensor, level: int) -> Tuple[Tensor, Optional[int]]: ...
20
+ def current_level() -> int: ...
21
+ def _add_batch_dim(tensor: Tensor, bdim: int, level: int) -> Tensor: ...
22
+ def set_single_level_autograd_function_allowed(allowed: bool) -> None: ...
23
+ def get_single_level_autograd_function_allowed() -> bool: ...
24
+ def _unwrap_functional_tensor(tensor: Tensor, reapply_views: bool) -> Tensor: ...
25
+ def _wrap_functional_tensor(tensor: Tensor, level: int) -> Tensor: ...
26
+
27
+ # Defined in aten/src/ATen/functorch/Interpreter.h
28
+ class TransformType(Enum):
29
+ Torch: TransformType = ...
30
+ Vmap: TransformType = ...
31
+ Grad: TransformType = ...
32
+ Jvp: TransformType = ...
33
+ Functionalize: TransformType = ...
34
+
35
+ class RandomnessType(Enum):
36
+ Error: TransformType = ...
37
+ Same: TransformType = ...
38
+ Different: TransformType = ...
39
+
40
+ class CInterpreter:
41
+ def key(self) -> TransformType: ...
42
+ def level(self) -> int: ...
43
+
44
+ class CGradInterpreterPtr:
45
+ def __init__(self, interpreter: CInterpreter): ...
46
+ def lift(self, Tensor) -> Tensor: ...
47
+ def prevGradMode(self) -> bool: ...
48
+
49
+ class CJvpInterpreterPtr:
50
+ def __init__(self, interpreter: CInterpreter): ...
51
+ def lift(self, Tensor) -> Tensor: ...
52
+ def prevFwdGradMode(self) -> bool: ...
53
+
54
+ class CFunctionalizeInterpreterPtr:
55
+ def __init__(self, interpreter: CInterpreter): ...
56
+ def key(self) -> TransformType: ...
57
+ def level(self) -> int: ...
58
+ def functionalizeAddBackViews(self) -> bool: ...
59
+
60
+ class CVmapInterpreterPtr:
61
+ def __init__(self, interpreter: CInterpreter): ...
62
+ def key(self) -> TransformType: ...
63
+ def level(self) -> int: ...
64
+ def batchSize(self) -> int: ...
65
+ def randomness(self) -> RandomnessType: ...
66
+
67
+ class DynamicLayer: ...
68
+
69
+ def peek_interpreter_stack() -> CInterpreter: ...
70
+ def pop_dynamic_layer_stack() -> DynamicLayer: ...
71
+ def push_dynamic_layer_stack(dl: DynamicLayer) -> int: ...
torch/_C/_itt.pyi ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ # Defined in torch/csrc/itt.cpp
2
+ def is_available() -> None: ...
3
+ def rangePush(message: str) -> None: ...
4
+ def rangePop() -> None: ...
5
+ def mark(message: str) -> None: ...
torch/_C/_lazy.pyi ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import List
2
+
3
+ from torch import Tensor
4
+
5
+ # defined in torch/csrc/lazy/python/init.cpp
6
+ def _mark_step(device: str, devices: List[str], wait: bool): ...
7
+ def _wait_device_ops(devices: List[str]): ...
8
+ def _reset_metrics(): ...
9
+ def _counter_names() -> List[str]: ...
10
+ def _counter_value(name: str) -> int: ...
11
+ def _metrics_report() -> str: ...
12
+ def _get_graph_hash(tensors: List[Tensor]) -> str: ...
13
+ def _sync_multi(
14
+ tensors: List[Tensor],
15
+ devices: List[str],
16
+ wait: bool = True,
17
+ sync_ltc_data: bool = True,
18
+ ): ...
19
+ def _get_tensor_id(tensor: Tensor) -> int: ...
20
+ def _get_tensors_text(tensors: List[Tensor]) -> str: ...
21
+ def _get_tensors_dot(tensors: List[Tensor]) -> str: ...
22
+ def _get_tensors_backend(tensors: List[Tensor]) -> str: ...
23
+ def _get_force_fallback() -> str: ...
24
+ def _set_force_fallback(newval: str): ...
25
+ def _clear_ir_cache(): ...
26
+ def _dump_ir_cache(filename: str): ...
27
+ def _set_reuse_ir(val: bool): ...
28
+ def _get_default_device_type(): ...
torch/_C/_lazy_ts_backend.pyi ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # defined in torch/csrc/lazy/python/init.cpp
2
+
3
+ from typing import Any, List, Tuple
4
+
5
+ from torch import Tensor
6
+
7
+ def _init(): ...
8
+ def _get_tensors_ts_device_data_node(
9
+ tensors: List[Tensor],
10
+ ) -> Tuple[List[int], List[Any]]: ...
11
+ def _run_cached_graph(hash_str: str, graph_inputs: List[Any]) -> List[Tensor]: ...
torch/_C/_monitor.pyi ADDED
@@ -0,0 +1,44 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Defined in torch/csrc/monitor/python_init.cpp
2
+
3
+ import datetime
4
+ from enum import Enum
5
+ from typing import Callable, Dict, List, Union
6
+
7
+ class Aggregation(Enum):
8
+ VALUE = ...
9
+ MEAN = ...
10
+ COUNT = ...
11
+ SUM = ...
12
+ MAX = ...
13
+ MIN = ...
14
+
15
+ class Stat:
16
+ name: str
17
+ count: int
18
+ def __init__(
19
+ self,
20
+ name: str,
21
+ aggregations: List[Aggregation],
22
+ window_size: int,
23
+ max_samples: int = -1,
24
+ ) -> None: ...
25
+ def add(self, v: float) -> None: ...
26
+ def get(self) -> Dict[Aggregation, float]: ...
27
+
28
+ class Event:
29
+ name: str
30
+ timestamp: datetime.datetime
31
+ data: Dict[str, Union[int, float, bool, str]]
32
+ def __init__(
33
+ self,
34
+ name: str,
35
+ timestamp: datetime.datetime,
36
+ data: Dict[str, Union[int, float, bool, str]],
37
+ ) -> None: ...
38
+
39
+ def log_event(e: Event) -> None: ...
40
+
41
+ class EventHandlerHandle: ...
42
+
43
+ def register_event_handler(handler: Callable[[Event], None]) -> EventHandlerHandle: ...
44
+ def unregister_event_handler(handle: EventHandlerHandle) -> None: ...
torch/_C/_nn.pyi ADDED
@@ -0,0 +1,86 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # mypy: disable-error-code="type-arg"
2
+ from typing import List, Optional, overload, Sequence, Tuple, Union
3
+
4
+ from torch import memory_format, Tensor
5
+ from torch.types import _bool, _device, _dtype, _int, _size
6
+
7
+ # Defined in tools/autograd/templates/python_nn_functions.cpp
8
+
9
+ def adaptive_max_pool2d(input: Tensor, output_size: Union[_int, _size]) -> Tuple[Tensor, Tensor]: ...
10
+ def adaptive_max_pool3d(input: Tensor, output_size: Union[_int, _size]) -> Tuple[Tensor, Tensor]: ...
11
+ def avg_pool2d(input: Tensor, kernel_size: Union[_int, _size], stride: Optional[Union[_int, _size]] = None, padding: Union[_int, _size] = 0, ceil_mode: bool = False, count_include_pad: bool = True, divisor_override: Optional[int] = None) -> Tensor: ...
12
+ def avg_pool3d(input: Tensor, kernel_size: Union[_int, _size], stride: Optional[Union[_int, _size]] = None, padding: Union[_int, _size] = 0, ceil_mode: bool = False, count_include_pad: bool = True, divisor_override: Optional[int] = None) -> Tensor: ...
13
+ def elu_(input: Tensor, alpha: float = ...) -> Tensor: ...
14
+ def fractional_max_pool2d(input: Tensor, kernel_size: Union[_int, _size], output_size: Union[_int, _size], _random_samples: Tensor) -> Tuple[Tensor, Tensor]: ...
15
+ def fractional_max_pool3d(input: Tensor, kernel_size: Union[_int, _size], output_size: Union[_int, _size], _random_samples: Tensor) -> Tuple[Tensor, Tensor]: ...
16
+ def gelu(input: Tensor, approximate: str = ...) -> Tensor: ...
17
+ def hardsigmoid(input: Tensor, *, out: Optional[Tensor] = None) -> Tensor: ...
18
+ def hardtanh(input: Tensor, min_val: float = ..., max_val: float = ..., *, out: Optional[Tensor] = None) -> Tensor: ...
19
+ def hardtanh_(input: Tensor, min_val: float = ..., max_val: float = ...) -> Tensor: ...
20
+ def leaky_relu(input: Tensor, negative_slope: float = ..., *, out: Optional[Tensor] = None) -> Tensor: ...
21
+ def leaky_relu_(input: Tensor, negative_slope: float = ...) -> Tensor: ...
22
+ def linear(input: Tensor, weight: Tensor, bias: Optional[Tensor] = None) -> Tensor: ...
23
+ def log_sigmoid(input: Tensor) -> Tensor: ...
24
+ def one_hot(tensor: Tensor, num_classes: int = ...) -> Tensor: ...
25
+ def pad(input: Tensor, pad: Sequence[int], mode: str = ..., value: Optional[float] = None) -> Tensor: ...
26
+ def scaled_dot_product_attention(query: Tensor, key: Tensor, value: Tensor, attn_mask: Optional[Tensor] = None, dropout_p: float = 0.0, is_causal: bool = False, scale: Optional[float] = None) -> Tensor: ...
27
+ def softplus(input: Tensor, beta: int = ..., threshold: int = ...) -> Tensor: ...
28
+ def softshrink(input: Tensor, lambd: float = ...) -> Tensor: ...
29
+
30
+ # Defined in aten/src/ATen/native/mkldnn/Linear.cpp
31
+ def mkldnn_linear(input: Tensor, weight: Tensor, bias: Optional[Tensor]) -> Tensor: ...
32
+
33
+ # Defined at aten/src/ATen/native/mkldnn/MKLDNNConversions.cpp
34
+ def mkldnn_reorder_conv2d_weight(
35
+ self: Tensor,
36
+ padding: List,
37
+ stride: List,
38
+ dilatation: List,
39
+ groups: int,
40
+ ) -> Tensor: ...
41
+ def mkldnn_reorder_conv3d_weight(
42
+ self: Tensor,
43
+ padding: List,
44
+ stride: List,
45
+ dilatation: List,
46
+ groups: int,
47
+ ) -> Tensor: ...
48
+
49
+ # Defined in aten/src/ATen/native/mkldnn/Prelu.cpp
50
+ def mkldnn_prelu(input: Tensor, weight: Tensor) -> Tensor: ...
51
+
52
+ # Defined at tools/autograd/templates/python_nn_functions.cpp
53
+ @overload
54
+ def _parse_to(
55
+ device: _device,
56
+ dtype: _dtype,
57
+ non_blocking: _bool,
58
+ copy: _bool,
59
+ *,
60
+ memory_format: memory_format,
61
+ ) -> Tuple[_device, _dtype, _bool, memory_format]: ...
62
+ @overload
63
+ def _parse_to(
64
+ dtype: _dtype,
65
+ non_blocking: _bool,
66
+ copy: _bool,
67
+ *,
68
+ memory_format: memory_format,
69
+ ) -> Tuple[_device, _dtype, _bool, memory_format]: ...
70
+ @overload
71
+ def _parse_to(
72
+ tensor: Tensor,
73
+ non_blocking: _bool,
74
+ copy: _bool,
75
+ *,
76
+ memory_format: memory_format,
77
+ ) -> Tuple[_device, _dtype, _bool, memory_format]: ...
78
+
79
+ # Defined in aten/src/ATen/native/PadSequence.cpp
80
+ def pad_sequence(
81
+ sequences: List[Tensor],
82
+ batch_first: bool = False,
83
+ padding_value: float = ...,
84
+ ) -> Tensor: ...
85
+ def flatten_dense_tensors(tensors: List[Tensor]) -> Tensor: ...
86
+ def unflatten_dense_tensors(flat: Tensor, tensors: List[Tensor]) -> List[Tensor]: ...
torch/_C/_nvtx.pyi ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ # Defined in torch/csrc/cuda/shared/nvtx.cpp
2
+ def rangePushA(message: str) -> int: ...
3
+ def rangePop() -> int: ...
4
+ def rangeStartA(message: str) -> int: ...
5
+ def rangeEnd(int) -> None: ...
6
+ def markA(message: str) -> None: ...
torch/_C/_onnx.pyi ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Defined in torch/csrc/onnx/init.cpp
2
+
3
+ from enum import Enum
4
+
5
+ _CAFFE2_ATEN_FALLBACK: bool
6
+ PRODUCER_VERSION: str
7
+
8
+ class TensorProtoDataType(Enum):
9
+ UNDEFINED = ...
10
+ FLOAT = ...
11
+ UINT8 = ...
12
+ INT8 = ...
13
+ UINT16 = ...
14
+ INT16 = ...
15
+ INT32 = ...
16
+ INT64 = ...
17
+ STRING = ...
18
+ BOOL = ...
19
+ FLOAT16 = ...
20
+ DOUBLE = ...
21
+ UINT32 = ...
22
+ UINT64 = ...
23
+ COMPLEX64 = ...
24
+ COMPLEX128 = ...
25
+ BFLOAT16 = ...
26
+ FLOAT8E5M2 = ...
27
+ FLOAT8E4M3FN = ...
28
+
29
+ class OperatorExportTypes(Enum):
30
+ ONNX = ...
31
+ ONNX_ATEN = ...
32
+ ONNX_ATEN_FALLBACK = ...
33
+ ONNX_FALLTHROUGH = ...
34
+
35
+ class TrainingMode(Enum):
36
+ EVAL = ...
37
+ PRESERVE = ...
38
+ TRAINING = ...
torch/_C/_profiler.pyi ADDED
@@ -0,0 +1,238 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from enum import Enum
2
+ from typing import Any, Dict, List, Literal, Optional, Tuple, Union
3
+
4
+ from torch._C import device, dtype, layout
5
+ from typing_extensions import TypeAlias
6
+
7
+ # defined in torch/csrc/profiler/python/init.cpp
8
+
9
+ class RecordScope(Enum):
10
+ FUNCTION = ...
11
+ BACKWARD_FUNCTION = ...
12
+ TORCHSCRIPT_FUNCTION = ...
13
+ KERNEL_FUNCTION_DTYPE = ...
14
+ CUSTOM_CLASS = ...
15
+ BUILD_FEATURE = ...
16
+ LITE_INTERPRETER = ...
17
+ USER_SCOPE = ...
18
+ STATIC_RUNTIME_OP = ...
19
+ STATIC_RUNTIME_MODEL = ...
20
+
21
+ class ProfilerState(Enum):
22
+ Disable = ...
23
+ CPU = ...
24
+ CUDA = ...
25
+ NVTX = ...
26
+ ITT = ...
27
+ KINETO = ...
28
+ KINETO_GPU_FALLBACK = ...
29
+ KINETO_PRIVATEUSE1_FALLBACK = ...
30
+ KINETO_PRIVATEUSE1 = ...
31
+
32
+ class ActiveProfilerType(Enum):
33
+ NONE = ...
34
+ LEGACY = ...
35
+ KINETO = ...
36
+ NVTX = ...
37
+ ITT = ...
38
+
39
+ class ProfilerActivity(Enum):
40
+ CPU = ...
41
+ CUDA = ...
42
+ MTIA = ...
43
+ PrivateUse1 = ...
44
+
45
+ class _EventType(Enum):
46
+ TorchOp = ...
47
+ Backend = ...
48
+ Allocation = ...
49
+ OutOfMemory = ...
50
+ PyCall = ...
51
+ PyCCall = ...
52
+ Kineto = ...
53
+
54
+ class _ExperimentalConfig:
55
+ def __init__(
56
+ self,
57
+ profiler_metrics: List[str] = ...,
58
+ profiler_measure_per_kernel: bool = ...,
59
+ verbose: bool = ...,
60
+ performance_events: List[str] = ...,
61
+ enable_cuda_sync_events: bool = ...,
62
+ ) -> None: ...
63
+
64
+ class ProfilerConfig:
65
+ def __init__(
66
+ self,
67
+ state: ProfilerState,
68
+ report_input_shapes: bool,
69
+ profile_memory: bool,
70
+ with_stack: bool,
71
+ with_flops: bool,
72
+ with_modules: bool,
73
+ experimental_config: _ExperimentalConfig,
74
+ ) -> None: ...
75
+
76
+ class _ProfilerEvent:
77
+ start_tid: int
78
+ start_time_ns: int
79
+ children: List[_ProfilerEvent]
80
+
81
+ # TODO(robieta): remove in favor of `self.typed`
82
+ extra_fields: Union[
83
+ _ExtraFields_TorchOp,
84
+ _ExtraFields_Backend,
85
+ _ExtraFields_Allocation,
86
+ _ExtraFields_OutOfMemory,
87
+ _ExtraFields_PyCall,
88
+ _ExtraFields_PyCCall,
89
+ _ExtraFields_Kineto,
90
+ ]
91
+
92
+ @property
93
+ def typed(
94
+ self,
95
+ ) -> Union[
96
+ Tuple[Literal[_EventType.TorchOp], _ExtraFields_TorchOp],
97
+ Tuple[Literal[_EventType.Backend], _ExtraFields_Backend],
98
+ Tuple[Literal[_EventType.Allocation], _ExtraFields_Allocation],
99
+ Tuple[Literal[_EventType.OutOfMemory], _ExtraFields_OutOfMemory],
100
+ Tuple[Literal[_EventType.PyCall], _ExtraFields_PyCall],
101
+ Tuple[Literal[_EventType.PyCCall], _ExtraFields_PyCCall],
102
+ Tuple[Literal[_EventType.Kineto], _ExtraFields_Kineto],
103
+ ]: ...
104
+ @property
105
+ def name(self) -> str: ...
106
+ @property
107
+ def tag(self) -> _EventType: ...
108
+ @property
109
+ def id(self) -> int: ...
110
+ @property
111
+ def parent(self) -> Optional[_ProfilerEvent]: ...
112
+ @property
113
+ def correlation_id(self) -> int: ...
114
+ @property
115
+ def end_time_ns(self) -> int: ...
116
+ @property
117
+ def duration_time_ns(self) -> int: ...
118
+
119
+ class _TensorMetadata:
120
+ impl_ptr: Optional[int]
121
+ storage_data_ptr: Optional[int]
122
+ id: Optional[int]
123
+
124
+ @property
125
+ def allocation_id(self) -> Optional[int]: ...
126
+ @property
127
+ def layout(self) -> layout: ...
128
+ @property
129
+ def device(self) -> device: ...
130
+ @property
131
+ def dtype(self) -> dtype: ...
132
+ @property
133
+ def sizes(self) -> List[int]: ...
134
+ @property
135
+ def strides(self) -> List[int]: ...
136
+
137
+ Scalar: TypeAlias = Union[int, float, bool, complex]
138
+ Input: TypeAlias = Optional[Union[_TensorMetadata, List[_TensorMetadata], Scalar]]
139
+
140
+ class _ExtraFields_TorchOp:
141
+ name: str
142
+ sequence_number: int
143
+ allow_tf32_cublas: bool
144
+
145
+ @property
146
+ def inputs(self) -> List[Input]: ...
147
+ @property
148
+ def scope(self) -> RecordScope: ...
149
+
150
+ class _ExtraFields_Backend: ...
151
+
152
+ class _ExtraFields_Allocation:
153
+ ptr: int
154
+ id: Optional[int]
155
+ alloc_size: int
156
+ total_allocated: int
157
+ total_reserved: int
158
+
159
+ @property
160
+ def allocation_id(self) -> Optional[int]: ...
161
+ @property
162
+ def device(self) -> device: ...
163
+
164
+ class _ExtraFields_OutOfMemory: ...
165
+
166
+ class _PyFrameState:
167
+ line_number: int
168
+ function_name: str
169
+
170
+ @property
171
+ def file_name(self) -> str: ...
172
+
173
+ class _NNModuleInfo:
174
+ @property
175
+ def self_ptr(self) -> int: ...
176
+ @property
177
+ def cls_ptr(self) -> int: ...
178
+ @property
179
+ def cls_name(self) -> str: ...
180
+ @property
181
+ def parameters(
182
+ self,
183
+ ) -> List[Tuple[str, _TensorMetadata, Optional[_TensorMetadata]]]: ...
184
+
185
+ class _OptimizerInfo:
186
+ @property
187
+ def parameters(
188
+ self,
189
+ ) -> List[
190
+ Tuple[
191
+ # Parameter
192
+ _TensorMetadata,
193
+ #
194
+ # Gradient (if present during optimizer.step())
195
+ Optional[_TensorMetadata],
196
+ #
197
+ # Optimizer state for Parameter as (name, tensor) pairs
198
+ List[Tuple[str, _TensorMetadata]],
199
+ ]
200
+ ]: ...
201
+
202
+ class _ExtraFields_PyCCall:
203
+ @property
204
+ def caller(self) -> _PyFrameState: ...
205
+
206
+ class _ExtraFields_PyCall:
207
+ @property
208
+ def callsite(self) -> _PyFrameState: ...
209
+ @property
210
+ def caller(self) -> _PyFrameState: ...
211
+ @property
212
+ def module(self) -> Optional[_NNModuleInfo]: ...
213
+ @property
214
+ def optimizer(self) -> Optional[_OptimizerInfo]: ...
215
+
216
+ class _ExtraFields_Kineto: ...
217
+
218
+ def _add_execution_trace_observer(output_file_path: str) -> bool: ...
219
+ def _remove_execution_trace_observer() -> None: ...
220
+ def _enable_execution_trace_observer() -> None: ...
221
+ def _disable_execution_trace_observer() -> None: ...
222
+ def _set_record_concrete_inputs_enabled_val(val: bool) -> None: ...
223
+ def _set_fwd_bwd_enabled_val(val: bool) -> None: ...
224
+ def _set_cuda_sync_enabled_val(val: bool) -> None: ...
225
+
226
+ class CapturedTraceback: ...
227
+
228
+ def gather_traceback(python: bool, script: bool, cpp: bool) -> CapturedTraceback: ...
229
+
230
+ # The Dict has name, filename, line
231
+ def symbolize_tracebacks(
232
+ to_symbolize: List[CapturedTraceback],
233
+ ) -> List[List[Dict[str, str]]]: ...
234
+
235
+ class _RecordFunctionFast:
236
+ def __init__(self, name: str) -> None: ...
237
+ def __enter__(self) -> None: ...
238
+ def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None: ...
torch/_C/_verbose.pyi ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ # Defined in torch/csrc/utils/verbose.cpp
2
+ def mkl_set_verbose(enable: int) -> int: ...
3
+ def mkldnn_set_verbose(level: int) -> int: ...
torch/_VF.py ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ This makes the functions in torch._C._VariableFunctions available as
3
+ torch._VF.<funcname>
4
+ without mypy being able to find them.
5
+
6
+ A subset of those functions are mapped to ATen functions in
7
+ torch/jit/_builtins.py
8
+
9
+ See https://github.com/pytorch/pytorch/issues/21478 for the reason for
10
+ introducing torch._VF
11
+
12
+ """
13
+ import sys
14
+ import types
15
+
16
+ import torch
17
+
18
+
19
+ class VFModule(types.ModuleType):
20
+ vf: types.ModuleType
21
+
22
+ def __init__(self, name):
23
+ super().__init__(name)
24
+ self.vf = torch._C._VariableFunctions
25
+
26
+ def __getattr__(self, attr):
27
+ return getattr(self.vf, attr)
28
+
29
+
30
+ sys.modules[__name__] = VFModule(__name__)
torch/_VF.pyi ADDED
The diff for this file is too large to render. See raw diff
 
torch/__config__.py ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+
3
+
4
+ def show():
5
+ """
6
+ Return a human-readable string with descriptions of the
7
+ configuration of PyTorch.
8
+ """
9
+ return torch._C._show_config()
10
+
11
+
12
+ # TODO: In principle, we could provide more structured version/config
13
+ # information here. For now only CXX_FLAGS is exposed, as Timer
14
+ # uses them.
15
+ def _cxx_flags():
16
+ """Returns the CXX_FLAGS used when building PyTorch."""
17
+ return torch._C._cxx_flags()
18
+
19
+
20
+ def parallel_info():
21
+ r"""Returns detailed string with parallelization settings"""
22
+ return torch._C._parallel_info()
torch/__future__.py ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ This global flag controls whether to assign new tensors to the parameters
3
+ instead of changing the existing parameters in-place when converting an `nn.Module`
4
+ using the following methods:
5
+ 1. `module.cuda()` / `.cpu()` (for moving `module` between devices)
6
+ 2. `module.float()` / `.double()` / `.half()` (for converting `module` to a different dtype)
7
+ 3. `module.to()` / `.type()` (for changing `module`'s device or dtype)
8
+ 4. `module._apply(fn)` (for generic functions applied to `module`)
9
+
10
+ Default: False
11
+ """
12
+ _overwrite_module_params_on_conversion = False
13
+
14
+
15
+ def set_overwrite_module_params_on_conversion(value):
16
+ global _overwrite_module_params_on_conversion
17
+ _overwrite_module_params_on_conversion = value
18
+
19
+
20
+ def get_overwrite_module_params_on_conversion():
21
+ return _overwrite_module_params_on_conversion
torch/_appdirs.py ADDED
@@ -0,0 +1,666 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ # -*- coding: utf-8 -*-
3
+ # Copyright (c) 2005-2010 ActiveState Software Inc.
4
+ # Copyright (c) 2013 Eddy Petrișor
5
+
6
+ # flake8: noqa
7
+
8
+ """
9
+ This file is directly from
10
+ https://github.com/ActiveState/appdirs/blob/3fe6a83776843a46f20c2e5587afcffe05e03b39/appdirs.py
11
+
12
+ The license of https://github.com/ActiveState/appdirs copied below:
13
+
14
+
15
+ # This is the MIT license
16
+
17
+ Copyright (c) 2010 ActiveState Software Inc.
18
+
19
+ Permission is hereby granted, free of charge, to any person obtaining a
20
+ copy of this software and associated documentation files (the
21
+ "Software"), to deal in the Software without restriction, including
22
+ without limitation the rights to use, copy, modify, merge, publish,
23
+ distribute, sublicense, and/or sell copies of the Software, and to
24
+ permit persons to whom the Software is furnished to do so, subject to
25
+ the following conditions:
26
+
27
+ The above copyright notice and this permission notice shall be included
28
+ in all copies or substantial portions of the Software.
29
+
30
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
31
+ OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
32
+ MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
33
+ IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
34
+ CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
35
+ TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
36
+ SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
37
+ """
38
+
39
+ """Utilities for determining application-specific dirs.
40
+
41
+ See <https://github.com/ActiveState/appdirs> for details and usage.
42
+ """
43
+ # Dev Notes:
44
+ # - MSDN on where to store app data files:
45
+ # http://support.microsoft.com/default.aspx?scid=kb;en-us;310294#XSLTH3194121123120121120120
46
+ # - Mac OS X: http://developer.apple.com/documentation/MacOSX/Conceptual/BPFileSystem/index.html
47
+ # - XDG spec for Un*x: https://standards.freedesktop.org/basedir-spec/basedir-spec-latest.html
48
+
49
+ __version__ = "1.4.4"
50
+ __version_info__ = tuple(int(segment) for segment in __version__.split("."))
51
+
52
+
53
+ import os
54
+ import sys
55
+
56
+ unicode = str
57
+
58
+ if sys.platform.startswith("java"):
59
+ import platform
60
+
61
+ os_name = platform.java_ver()[3][0]
62
+ if os_name.startswith("Windows"): # "Windows XP", "Windows 7", etc.
63
+ system = "win32"
64
+ elif os_name.startswith("Mac"): # "Mac OS X", etc.
65
+ system = "darwin"
66
+ else: # "Linux", "SunOS", "FreeBSD", etc.
67
+ # Setting this to "linux2" is not ideal, but only Windows or Mac
68
+ # are actually checked for and the rest of the module expects
69
+ # *sys.platform* style strings.
70
+ system = "linux2"
71
+ else:
72
+ system = sys.platform
73
+
74
+
75
+ def user_data_dir(appname=None, appauthor=None, version=None, roaming=False):
76
+ r"""Return full path to the user-specific data dir for this application.
77
+
78
+ "appname" is the name of application.
79
+ If None, just the system directory is returned.
80
+ "appauthor" (only used on Windows) is the name of the
81
+ appauthor or distributing body for this application. Typically
82
+ it is the owning company name. This falls back to appname. You may
83
+ pass False to disable it.
84
+ "version" is an optional version path element to append to the
85
+ path. You might want to use this if you want multiple versions
86
+ of your app to be able to run independently. If used, this
87
+ would typically be "<major>.<minor>".
88
+ Only applied when appname is present.
89
+ "roaming" (boolean, default False) can be set True to use the Windows
90
+ roaming appdata directory. That means that for users on a Windows
91
+ network setup for roaming profiles, this user data will be
92
+ sync'd on login. See
93
+ <http://technet.microsoft.com/en-us/library/cc766489(WS.10).aspx>
94
+ for a discussion of issues.
95
+
96
+ Typical user data directories are:
97
+ Mac OS X: ~/Library/Application Support/<AppName>
98
+ Unix: ~/.local/share/<AppName> # or in $XDG_DATA_HOME, if defined
99
+ Win XP (not roaming): C:\Documents and Settings\<username>\Application Data\<AppAuthor>\<AppName>
100
+ Win XP (roaming): C:\Documents and Settings\<username>\Local Settings\Application Data\<AppAuthor>\<AppName>
101
+ Win 7 (not roaming): C:\Users\<username>\AppData\Local\<AppAuthor>\<AppName>
102
+ Win 7 (roaming): C:\Users\<username>\AppData\Roaming\<AppAuthor>\<AppName>
103
+
104
+ For Unix, we follow the XDG spec and support $XDG_DATA_HOME.
105
+ That means, by default "~/.local/share/<AppName>".
106
+ """
107
+ if system == "win32":
108
+ if appauthor is None:
109
+ appauthor = appname
110
+ const = roaming and "CSIDL_APPDATA" or "CSIDL_LOCAL_APPDATA"
111
+ path = os.path.normpath(_get_win_folder(const))
112
+ if appname:
113
+ if appauthor is not False:
114
+ path = os.path.join(path, appauthor, appname)
115
+ else:
116
+ path = os.path.join(path, appname)
117
+ elif system == "darwin":
118
+ path = os.path.expanduser("~/Library/Application Support/")
119
+ if appname:
120
+ path = os.path.join(path, appname)
121
+ else:
122
+ path = os.getenv("XDG_DATA_HOME", os.path.expanduser("~/.local/share"))
123
+ if appname:
124
+ path = os.path.join(path, appname)
125
+ if appname and version:
126
+ path = os.path.join(path, version)
127
+ return path
128
+
129
+
130
+ def site_data_dir(appname=None, appauthor=None, version=None, multipath=False):
131
+ r"""Return full path to the user-shared data dir for this application.
132
+
133
+ "appname" is the name of application.
134
+ If None, just the system directory is returned.
135
+ "appauthor" (only used on Windows) is the name of the
136
+ appauthor or distributing body for this application. Typically
137
+ it is the owning company name. This falls back to appname. You may
138
+ pass False to disable it.
139
+ "version" is an optional version path element to append to the
140
+ path. You might want to use this if you want multiple versions
141
+ of your app to be able to run independently. If used, this
142
+ would typically be "<major>.<minor>".
143
+ Only applied when appname is present.
144
+ "multipath" is an optional parameter only applicable to *nix
145
+ which indicates that the entire list of data dirs should be
146
+ returned. By default, the first item from XDG_DATA_DIRS is
147
+ returned, or '/usr/local/share/<AppName>',
148
+ if XDG_DATA_DIRS is not set
149
+
150
+ Typical site data directories are:
151
+ Mac OS X: /Library/Application Support/<AppName>
152
+ Unix: /usr/local/share/<AppName> or /usr/share/<AppName>
153
+ Win XP: C:\Documents and Settings\All Users\Application Data\<AppAuthor>\<AppName>
154
+ Vista: (Fail! "C:\ProgramData" is a hidden *system* directory on Vista.)
155
+ Win 7: C:\ProgramData\<AppAuthor>\<AppName> # Hidden, but writeable on Win 7.
156
+
157
+ For Unix, this is using the $XDG_DATA_DIRS[0] default.
158
+
159
+ WARNING: Do not use this on Windows. See the Vista-Fail note above for why.
160
+ """
161
+ if system == "win32":
162
+ if appauthor is None:
163
+ appauthor = appname
164
+ path = os.path.normpath(_get_win_folder("CSIDL_COMMON_APPDATA"))
165
+ if appname:
166
+ if appauthor is not False:
167
+ path = os.path.join(path, appauthor, appname)
168
+ else:
169
+ path = os.path.join(path, appname)
170
+ elif system == "darwin":
171
+ path = os.path.expanduser("/Library/Application Support")
172
+ if appname:
173
+ path = os.path.join(path, appname)
174
+ else:
175
+ # XDG default for $XDG_DATA_DIRS
176
+ # only first, if multipath is False
177
+ path = os.getenv(
178
+ "XDG_DATA_DIRS", os.pathsep.join(["/usr/local/share", "/usr/share"])
179
+ )
180
+ pathlist = [
181
+ os.path.expanduser(x.rstrip(os.sep)) for x in path.split(os.pathsep)
182
+ ]
183
+ if appname:
184
+ if version:
185
+ appname = os.path.join(appname, version)
186
+ pathlist = [os.sep.join([x, appname]) for x in pathlist]
187
+
188
+ if multipath:
189
+ path = os.pathsep.join(pathlist)
190
+ else:
191
+ path = pathlist[0]
192
+ return path
193
+
194
+ if appname and version:
195
+ path = os.path.join(path, version)
196
+ return path
197
+
198
+
199
+ def user_config_dir(appname=None, appauthor=None, version=None, roaming=False):
200
+ r"""Return full path to the user-specific config dir for this application.
201
+
202
+ "appname" is the name of application.
203
+ If None, just the system directory is returned.
204
+ "appauthor" (only used on Windows) is the name of the
205
+ appauthor or distributing body for this application. Typically
206
+ it is the owning company name. This falls back to appname. You may
207
+ pass False to disable it.
208
+ "version" is an optional version path element to append to the
209
+ path. You might want to use this if you want multiple versions
210
+ of your app to be able to run independently. If used, this
211
+ would typically be "<major>.<minor>".
212
+ Only applied when appname is present.
213
+ "roaming" (boolean, default False) can be set True to use the Windows
214
+ roaming appdata directory. That means that for users on a Windows
215
+ network setup for roaming profiles, this user data will be
216
+ sync'd on login. See
217
+ <http://technet.microsoft.com/en-us/library/cc766489(WS.10).aspx>
218
+ for a discussion of issues.
219
+
220
+ Typical user config directories are:
221
+ Mac OS X: ~/Library/Preferences/<AppName>
222
+ Unix: ~/.config/<AppName> # or in $XDG_CONFIG_HOME, if defined
223
+ Win *: same as user_data_dir
224
+
225
+ For Unix, we follow the XDG spec and support $XDG_CONFIG_HOME.
226
+ That means, by default "~/.config/<AppName>".
227
+ """
228
+ if system == "win32":
229
+ path = user_data_dir(appname, appauthor, None, roaming)
230
+ elif system == "darwin":
231
+ path = os.path.expanduser("~/Library/Preferences/")
232
+ if appname:
233
+ path = os.path.join(path, appname)
234
+ else:
235
+ path = os.getenv("XDG_CONFIG_HOME", os.path.expanduser("~/.config"))
236
+ if appname:
237
+ path = os.path.join(path, appname)
238
+ if appname and version:
239
+ path = os.path.join(path, version)
240
+ return path
241
+
242
+
243
+ def site_config_dir(appname=None, appauthor=None, version=None, multipath=False):
244
+ r"""Return full path to the user-shared data dir for this application.
245
+
246
+ "appname" is the name of application.
247
+ If None, just the system directory is returned.
248
+ "appauthor" (only used on Windows) is the name of the
249
+ appauthor or distributing body for this application. Typically
250
+ it is the owning company name. This falls back to appname. You may
251
+ pass False to disable it.
252
+ "version" is an optional version path element to append to the
253
+ path. You might want to use this if you want multiple versions
254
+ of your app to be able to run independently. If used, this
255
+ would typically be "<major>.<minor>".
256
+ Only applied when appname is present.
257
+ "multipath" is an optional parameter only applicable to *nix
258
+ which indicates that the entire list of config dirs should be
259
+ returned. By default, the first item from XDG_CONFIG_DIRS is
260
+ returned, or '/etc/xdg/<AppName>', if XDG_CONFIG_DIRS is not set
261
+
262
+ Typical site config directories are:
263
+ Mac OS X: same as site_data_dir
264
+ Unix: /etc/xdg/<AppName> or $XDG_CONFIG_DIRS[i]/<AppName> for each value in
265
+ $XDG_CONFIG_DIRS
266
+ Win *: same as site_data_dir
267
+ Vista: (Fail! "C:\ProgramData" is a hidden *system* directory on Vista.)
268
+
269
+ For Unix, this is using the $XDG_CONFIG_DIRS[0] default, if multipath=False
270
+
271
+ WARNING: Do not use this on Windows. See the Vista-Fail note above for why.
272
+ """
273
+ if system == "win32":
274
+ path = site_data_dir(appname, appauthor)
275
+ if appname and version:
276
+ path = os.path.join(path, version)
277
+ elif system == "darwin":
278
+ path = os.path.expanduser("/Library/Preferences")
279
+ if appname:
280
+ path = os.path.join(path, appname)
281
+ else:
282
+ # XDG default for $XDG_CONFIG_DIRS
283
+ # only first, if multipath is False
284
+ path = os.getenv("XDG_CONFIG_DIRS", "/etc/xdg")
285
+ pathlist = [
286
+ os.path.expanduser(x.rstrip(os.sep)) for x in path.split(os.pathsep)
287
+ ]
288
+ if appname:
289
+ if version:
290
+ appname = os.path.join(appname, version)
291
+ pathlist = [os.sep.join([x, appname]) for x in pathlist]
292
+
293
+ if multipath:
294
+ path = os.pathsep.join(pathlist)
295
+ else:
296
+ path = pathlist[0]
297
+ return path
298
+
299
+
300
+ def user_cache_dir(appname=None, appauthor=None, version=None, opinion=True):
301
+ r"""Return full path to the user-specific cache dir for this application.
302
+
303
+ "appname" is the name of application.
304
+ If None, just the system directory is returned.
305
+ "appauthor" (only used on Windows) is the name of the
306
+ appauthor or distributing body for this application. Typically
307
+ it is the owning company name. This falls back to appname. You may
308
+ pass False to disable it.
309
+ "version" is an optional version path element to append to the
310
+ path. You might want to use this if you want multiple versions
311
+ of your app to be able to run independently. If used, this
312
+ would typically be "<major>.<minor>".
313
+ Only applied when appname is present.
314
+ "opinion" (boolean) can be False to disable the appending of
315
+ "Cache" to the base app data dir for Windows. See
316
+ discussion below.
317
+
318
+ Typical user cache directories are:
319
+ Mac OS X: ~/Library/Caches/<AppName>
320
+ Unix: ~/.cache/<AppName> (XDG default)
321
+ Win XP: C:\Documents and Settings\<username>\Local Settings\Application Data\<AppAuthor>\<AppName>\Cache
322
+ Vista: C:\Users\<username>\AppData\Local\<AppAuthor>\<AppName>\Cache
323
+
324
+ On Windows the only suggestion in the MSDN docs is that local settings go in
325
+ the `CSIDL_LOCAL_APPDATA` directory. This is identical to the non-roaming
326
+ app data dir (the default returned by `user_data_dir` above). Apps typically
327
+ put cache data somewhere *under* the given dir here. Some examples:
328
+ ...\Mozilla\Firefox\Profiles\<ProfileName>\Cache
329
+ ...\Acme\SuperApp\Cache\1.0
330
+ OPINION: This function appends "Cache" to the `CSIDL_LOCAL_APPDATA` value.
331
+ This can be disabled with the `opinion=False` option.
332
+ """
333
+ if system == "win32":
334
+ if appauthor is None:
335
+ appauthor = appname
336
+ path = os.path.normpath(_get_win_folder("CSIDL_LOCAL_APPDATA"))
337
+ if appname:
338
+ if appauthor is not False:
339
+ path = os.path.join(path, appauthor, appname)
340
+ else:
341
+ path = os.path.join(path, appname)
342
+ if opinion:
343
+ path = os.path.join(path, "Cache")
344
+ elif system == "darwin":
345
+ path = os.path.expanduser("~/Library/Caches")
346
+ if appname:
347
+ path = os.path.join(path, appname)
348
+ else:
349
+ path = os.getenv("XDG_CACHE_HOME", os.path.expanduser("~/.cache"))
350
+ if appname:
351
+ path = os.path.join(path, appname)
352
+ if appname and version:
353
+ path = os.path.join(path, version)
354
+ return path
355
+
356
+
357
+ def user_state_dir(appname=None, appauthor=None, version=None, roaming=False):
358
+ r"""Return full path to the user-specific state dir for this application.
359
+
360
+ "appname" is the name of application.
361
+ If None, just the system directory is returned.
362
+ "appauthor" (only used on Windows) is the name of the
363
+ appauthor or distributing body for this application. Typically
364
+ it is the owning company name. This falls back to appname. You may
365
+ pass False to disable it.
366
+ "version" is an optional version path element to append to the
367
+ path. You might want to use this if you want multiple versions
368
+ of your app to be able to run independently. If used, this
369
+ would typically be "<major>.<minor>".
370
+ Only applied when appname is present.
371
+ "roaming" (boolean, default False) can be set True to use the Windows
372
+ roaming appdata directory. That means that for users on a Windows
373
+ network setup for roaming profiles, this user data will be
374
+ sync'd on login. See
375
+ <http://technet.microsoft.com/en-us/library/cc766489(WS.10).aspx>
376
+ for a discussion of issues.
377
+
378
+ Typical user state directories are:
379
+ Mac OS X: same as user_data_dir
380
+ Unix: ~/.local/state/<AppName> # or in $XDG_STATE_HOME, if defined
381
+ Win *: same as user_data_dir
382
+
383
+ For Unix, we follow this Debian proposal <https://wiki.debian.org/XDGBaseDirectorySpecification#state>
384
+ to extend the XDG spec and support $XDG_STATE_HOME.
385
+
386
+ That means, by default "~/.local/state/<AppName>".
387
+ """
388
+ if system in ["win32", "darwin"]:
389
+ path = user_data_dir(appname, appauthor, None, roaming)
390
+ else:
391
+ path = os.getenv("XDG_STATE_HOME", os.path.expanduser("~/.local/state"))
392
+ if appname:
393
+ path = os.path.join(path, appname)
394
+ if appname and version:
395
+ path = os.path.join(path, version)
396
+ return path
397
+
398
+
399
+ def user_log_dir(appname=None, appauthor=None, version=None, opinion=True):
400
+ r"""Return full path to the user-specific log dir for this application.
401
+
402
+ "appname" is the name of application.
403
+ If None, just the system directory is returned.
404
+ "appauthor" (only used on Windows) is the name of the
405
+ appauthor or distributing body for this application. Typically
406
+ it is the owning company name. This falls back to appname. You may
407
+ pass False to disable it.
408
+ "version" is an optional version path element to append to the
409
+ path. You might want to use this if you want multiple versions
410
+ of your app to be able to run independently. If used, this
411
+ would typically be "<major>.<minor>".
412
+ Only applied when appname is present.
413
+ "opinion" (boolean) can be False to disable the appending of
414
+ "Logs" to the base app data dir for Windows, and "log" to the
415
+ base cache dir for Unix. See discussion below.
416
+
417
+ Typical user log directories are:
418
+ Mac OS X: ~/Library/Logs/<AppName>
419
+ Unix: ~/.cache/<AppName>/log # or under $XDG_CACHE_HOME if defined
420
+ Win XP: C:\Documents and Settings\<username>\Local Settings\Application Data\<AppAuthor>\<AppName>\Logs
421
+ Vista: C:\Users\<username>\AppData\Local\<AppAuthor>\<AppName>\Logs
422
+
423
+ On Windows the only suggestion in the MSDN docs is that local settings
424
+ go in the `CSIDL_LOCAL_APPDATA` directory. (Note: I'm interested in
425
+ examples of what some windows apps use for a logs dir.)
426
+
427
+ OPINION: This function appends "Logs" to the `CSIDL_LOCAL_APPDATA`
428
+ value for Windows and appends "log" to the user cache dir for Unix.
429
+ This can be disabled with the `opinion=False` option.
430
+ """
431
+ if system == "darwin":
432
+ path = os.path.join(os.path.expanduser("~/Library/Logs"), appname)
433
+ elif system == "win32":
434
+ path = user_data_dir(appname, appauthor, version)
435
+ version = False
436
+ if opinion:
437
+ path = os.path.join(path, "Logs")
438
+ else:
439
+ path = user_cache_dir(appname, appauthor, version)
440
+ version = False
441
+ if opinion:
442
+ path = os.path.join(path, "log")
443
+ if appname and version:
444
+ path = os.path.join(path, version)
445
+ return path
446
+
447
+
448
+ class AppDirs(object):
449
+ """Convenience wrapper for getting application dirs."""
450
+
451
+ def __init__(
452
+ self, appname=None, appauthor=None, version=None, roaming=False, multipath=False
453
+ ):
454
+ self.appname = appname
455
+ self.appauthor = appauthor
456
+ self.version = version
457
+ self.roaming = roaming
458
+ self.multipath = multipath
459
+
460
+ @property
461
+ def user_data_dir(self):
462
+ return user_data_dir(
463
+ self.appname, self.appauthor, version=self.version, roaming=self.roaming
464
+ )
465
+
466
+ @property
467
+ def site_data_dir(self):
468
+ return site_data_dir(
469
+ self.appname, self.appauthor, version=self.version, multipath=self.multipath
470
+ )
471
+
472
+ @property
473
+ def user_config_dir(self):
474
+ return user_config_dir(
475
+ self.appname, self.appauthor, version=self.version, roaming=self.roaming
476
+ )
477
+
478
+ @property
479
+ def site_config_dir(self):
480
+ return site_config_dir(
481
+ self.appname, self.appauthor, version=self.version, multipath=self.multipath
482
+ )
483
+
484
+ @property
485
+ def user_cache_dir(self):
486
+ return user_cache_dir(self.appname, self.appauthor, version=self.version)
487
+
488
+ @property
489
+ def user_state_dir(self):
490
+ return user_state_dir(self.appname, self.appauthor, version=self.version)
491
+
492
+ @property
493
+ def user_log_dir(self):
494
+ return user_log_dir(self.appname, self.appauthor, version=self.version)
495
+
496
+
497
+ # ---- internal support stuff
498
+
499
+
500
+ def _get_win_folder_from_registry(csidl_name):
501
+ """This is a fallback technique at best. I'm not sure if using the
502
+ registry for this guarantees us the correct answer for all CSIDL_*
503
+ names.
504
+ """
505
+ import winreg as _winreg
506
+
507
+ shell_folder_name = {
508
+ "CSIDL_APPDATA": "AppData",
509
+ "CSIDL_COMMON_APPDATA": "Common AppData",
510
+ "CSIDL_LOCAL_APPDATA": "Local AppData",
511
+ }[csidl_name]
512
+
513
+ key = _winreg.OpenKey(
514
+ _winreg.HKEY_CURRENT_USER,
515
+ r"Software\Microsoft\Windows\CurrentVersion\Explorer\Shell Folders",
516
+ )
517
+ dir, type = _winreg.QueryValueEx(key, shell_folder_name)
518
+ return dir
519
+
520
+
521
+ def _get_win_folder_with_pywin32(csidl_name):
522
+ from win32com.shell import shell, shellcon
523
+
524
+ dir = shell.SHGetFolderPath(0, getattr(shellcon, csidl_name), 0, 0)
525
+ # Try to make this a unicode path because SHGetFolderPath does
526
+ # not return unicode strings when there is unicode data in the
527
+ # path.
528
+ try:
529
+ dir = unicode(dir)
530
+
531
+ # Downgrade to short path name if have highbit chars. See
532
+ # <http://bugs.activestate.com/show_bug.cgi?id=85099>.
533
+ has_high_char = False
534
+ for c in dir:
535
+ if ord(c) > 255:
536
+ has_high_char = True
537
+ break
538
+ if has_high_char:
539
+ try:
540
+ import win32api
541
+
542
+ dir = win32api.GetShortPathName(dir)
543
+ except ImportError:
544
+ pass
545
+ except UnicodeError:
546
+ pass
547
+ return dir
548
+
549
+
550
+ def _get_win_folder_with_ctypes(csidl_name):
551
+ import ctypes
552
+
553
+ csidl_const = {
554
+ "CSIDL_APPDATA": 26,
555
+ "CSIDL_COMMON_APPDATA": 35,
556
+ "CSIDL_LOCAL_APPDATA": 28,
557
+ }[csidl_name]
558
+
559
+ buf = ctypes.create_unicode_buffer(1024)
560
+ ctypes.windll.shell32.SHGetFolderPathW(None, csidl_const, None, 0, buf)
561
+
562
+ # Downgrade to short path name if have highbit chars. See
563
+ # <http://bugs.activestate.com/show_bug.cgi?id=85099>.
564
+ has_high_char = False
565
+ for c in buf:
566
+ if ord(c) > 255:
567
+ has_high_char = True
568
+ break
569
+ if has_high_char:
570
+ buf2 = ctypes.create_unicode_buffer(1024)
571
+ if ctypes.windll.kernel32.GetShortPathNameW(buf.value, buf2, 1024):
572
+ buf = buf2
573
+
574
+ return buf.value
575
+
576
+
577
+ def _get_win_folder_with_jna(csidl_name):
578
+ import array
579
+
580
+ from com.sun import jna
581
+ from com.sun.jna.platform import win32
582
+
583
+ buf_size = win32.WinDef.MAX_PATH * 2
584
+ buf = array.zeros("c", buf_size)
585
+ shell = win32.Shell32.INSTANCE
586
+ shell.SHGetFolderPath(
587
+ None,
588
+ getattr(win32.ShlObj, csidl_name),
589
+ None,
590
+ win32.ShlObj.SHGFP_TYPE_CURRENT,
591
+ buf,
592
+ )
593
+ dir = jna.Native.toString(buf.tostring()).rstrip("\0")
594
+
595
+ # Downgrade to short path name if have highbit chars. See
596
+ # <http://bugs.activestate.com/show_bug.cgi?id=85099>.
597
+ has_high_char = False
598
+ for c in dir:
599
+ if ord(c) > 255:
600
+ has_high_char = True
601
+ break
602
+ if has_high_char:
603
+ buf = array.zeros("c", buf_size)
604
+ kernel = win32.Kernel32.INSTANCE
605
+ if kernel.GetShortPathName(dir, buf, buf_size):
606
+ dir = jna.Native.toString(buf.tostring()).rstrip("\0")
607
+
608
+ return dir
609
+
610
+
611
+ if system == "win32":
612
+ try:
613
+ import win32com.shell
614
+
615
+ _get_win_folder = _get_win_folder_with_pywin32
616
+ except ImportError:
617
+ try:
618
+ from ctypes import windll
619
+
620
+ _get_win_folder = _get_win_folder_with_ctypes
621
+ except ImportError:
622
+ try:
623
+ import com.sun.jna
624
+
625
+ _get_win_folder = _get_win_folder_with_jna
626
+ except ImportError:
627
+ _get_win_folder = _get_win_folder_from_registry
628
+
629
+
630
+ # ---- self test code
631
+
632
+ if __name__ == "__main__":
633
+ appname = "MyApp"
634
+ appauthor = "MyCompany"
635
+
636
+ props = (
637
+ "user_data_dir",
638
+ "user_config_dir",
639
+ "user_cache_dir",
640
+ "user_state_dir",
641
+ "user_log_dir",
642
+ "site_data_dir",
643
+ "site_config_dir",
644
+ )
645
+
646
+ print(f"-- app dirs {__version__} --")
647
+
648
+ print("-- app dirs (with optional 'version')")
649
+ dirs = AppDirs(appname, appauthor, version="1.0")
650
+ for prop in props:
651
+ print(f"{prop}: {getattr(dirs, prop)}")
652
+
653
+ print("\n-- app dirs (without optional 'version')")
654
+ dirs = AppDirs(appname, appauthor)
655
+ for prop in props:
656
+ print(f"{prop}: {getattr(dirs, prop)}")
657
+
658
+ print("\n-- app dirs (without optional 'appauthor')")
659
+ dirs = AppDirs(appname)
660
+ for prop in props:
661
+ print(f"{prop}: {getattr(dirs, prop)}")
662
+
663
+ print("\n-- app dirs (with disabled 'appauthor')")
664
+ dirs = AppDirs(appname, appauthor=False)
665
+ for prop in props:
666
+ print(f"{prop}: {getattr(dirs, prop)}")
torch/_awaits/__init__.py ADDED
@@ -0,0 +1,54 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ from typing import cast, Callable, Generic, Type, TypeVar
4
+
5
+ import torch
6
+
7
+ __all__ = ['Await']
8
+
9
+ W = TypeVar("W")
10
+
11
+ class _PyAwaitMeta(type(torch._C._Await), type(Generic)): # type: ignore[misc, no-redef]
12
+ pass
13
+
14
+ class _Await(torch._C._Await, Generic[W], metaclass=_PyAwaitMeta):
15
+ r"""
16
+ Wrapper around a ``torch._C.Await`` which encapsulates delayed execution
17
+ of a callable. All manipulations happen with functions ``torch.jit._awaitable``,
18
+ ``torch.jit._awaitable_wait``, ``torch.jit._awaitable_nowait``.
19
+
20
+ Torch scriptable manipulations:
21
+ ``torch.jit._awaitable(func, *args)``
22
+ Creates ``Await[W]`` object, where W is return type of func.
23
+
24
+ Returns:
25
+ ``torch.jit._awaitable_wait(Await[W])``
26
+ Returns the result of the function, specified at ``_awaitable``, with specified arguments.
27
+
28
+ Returns:
29
+ The result of type ``W`` of the function call. The result is owned by ``Await[W]``
30
+ and returned on all following ``_awaitable_wait`` calls.
31
+
32
+
33
+ ``torch.jit._awaitable_nowait(W)``
34
+ Returns:
35
+ Trivial ``Await[W]`` with specified result.
36
+
37
+
38
+ Only in eager mode:
39
+ ``fn() -> Callable[Tuple[Any], W]``
40
+ Returns:
41
+ Specified at ``_awaitable`` python function ``func``.
42
+
43
+ ``args() -> Tuple[Any]``
44
+ Returns:
45
+ Specified at ``_awaitable`` python args.
46
+
47
+ ``is_nowait() -> _bool``
48
+ Returns:
49
+ ``True`` if this object was created via ``_awaitable_nowait`` call (trivial `Await[W]`).
50
+
51
+ In eager mode ``Await[W]`` can be used as ``W`` i.e. attributes of W can be called on ``Await[W]``,
52
+ ``_awaitable_wait()`` call will be transparently added.
53
+ """
54
+ pass
torch/_awaits/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (2.08 kB). View file
 
torch/_classes.py ADDED
@@ -0,0 +1,55 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import types
2
+
3
+ import torch._C
4
+
5
+
6
+ class _ClassNamespace(types.ModuleType):
7
+ def __init__(self, name):
8
+ super().__init__("torch.classes" + name)
9
+ self.name = name
10
+
11
+ def __getattr__(self, attr):
12
+ proxy = torch._C._get_custom_class_python_wrapper(self.name, attr)
13
+ if proxy is None:
14
+ raise RuntimeError(f"Class {self.name}.{attr} not registered!")
15
+ return proxy
16
+
17
+
18
+ class _Classes(types.ModuleType):
19
+ __file__ = "_classes.py"
20
+
21
+ def __init__(self):
22
+ super().__init__("torch.classes")
23
+
24
+ def __getattr__(self, name):
25
+ namespace = _ClassNamespace(name)
26
+ setattr(self, name, namespace)
27
+ return namespace
28
+
29
+ @property
30
+ def loaded_libraries(self):
31
+ return torch.ops.loaded_libraries
32
+
33
+ def load_library(self, path):
34
+ """
35
+ Loads a shared library from the given path into the current process.
36
+
37
+ The library being loaded may run global initialization code to register
38
+ custom classes with the PyTorch JIT runtime. This allows dynamically
39
+ loading custom classes. For this, you should compile your class
40
+ and the static registration code into a shared library object, and then
41
+ call ``torch.classes.load_library('path/to/libcustom.so')`` to load the
42
+ shared object.
43
+
44
+ After the library is loaded, it is added to the
45
+ ``torch.classes.loaded_libraries`` attribute, a set that may be inspected
46
+ for the paths of all libraries loaded using this function.
47
+
48
+ Args:
49
+ path (str): A path to a shared library to load.
50
+ """
51
+ torch.ops.load_library(path)
52
+
53
+
54
+ # The classes "namespace"
55
+ classes = _Classes()
torch/_compile.py ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ APIs related to torch.compile which lazily import torch._dynamo to avoid
3
+ circular dependencies.
4
+ """
5
+ import functools
6
+
7
+
8
+ def _disable_dynamo(fn=None, recursive=True):
9
+ """
10
+ This API should be only used inside torch, external users should still use
11
+ torch._dynamo.disable. The main goal of this API is to avoid circular
12
+ imports issues that is common while using _dynamo.disable inside torch
13
+ itself.
14
+
15
+ This API avoids it by lazily importing torch._dynamo from the import time to
16
+ the invocation of the decorated function.
17
+ """
18
+ if fn is not None:
19
+
20
+ @functools.wraps(fn)
21
+ def inner(*args, **kwargs):
22
+ import torch._dynamo
23
+
24
+ return torch._dynamo.disable(fn, recursive)(*args, **kwargs)
25
+
26
+ return inner
27
+ else:
28
+ # decorator usage like @_disable_dynamo(recursive=False). The resulting
29
+ # object expects the original decorated function as the arg.
30
+ return functools.partial(_disable_dynamo, recursive=recursive)
torch/_custom_op/__init__.py ADDED
File without changes
torch/_custom_op/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (156 Bytes). View file
 
torch/_custom_op/__pycache__/autograd.cpython-310.pyc ADDED
Binary file (8.86 kB). View file
 
torch/_custom_op/__pycache__/functional.cpython-310.pyc ADDED
Binary file (5.95 kB). View file
 
torch/_custom_op/__pycache__/impl.cpython-310.pyc ADDED
Binary file (33.5 kB). View file
 
torch/_custom_op/autograd.py ADDED
@@ -0,0 +1,274 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.utils._pytree as pytree
3
+ from collections import namedtuple
4
+ import functools
5
+
6
+
7
+ # NOTE [CustomOp autograd kernel indirection]
8
+ # We register `inner` as the autograd kernel for this custom_op.
9
+ # `inner` either calls the autograd formula registered by the user,
10
+ # or goes into an `autograd_not_implemented` kernel.
11
+ #
12
+ # The reason why this indirection exists is
13
+ # so that we can swap out the autograd kernel (the PyTorch dispatcher
14
+ # doesn't actually allow us to do this). By default, we want
15
+ # the `autograd_not_implemented` behavior, but then the user may come
16
+ # and register something that is actually a backward formula
17
+ def autograd_kernel_indirection(custom_op):
18
+ autograd_fallback = autograd_not_implemented(custom_op)
19
+
20
+ def inner(*args, **kwargs):
21
+ if custom_op._has_impl('autograd'):
22
+ kernel = custom_op._get_impl('autograd').func
23
+ return kernel(*args, **kwargs)
24
+ # As explained in NOTE ["backward", "save_for_backward", and "autograd"],
25
+ # after the user gives us "backward" and "save_for_backward", we generate
26
+ # the "autograd" impl. If the user only provided one, then we tell
27
+ # the user they've done something wrong.
28
+ if custom_op._has_impl('save_for_backward') or custom_op._has_impl('backward'):
29
+ missing = (
30
+ 'save_for_backward' if custom_op._has_impl('backward')
31
+ else 'backward'
32
+ )
33
+ found = 'save_for_backward' if missing == 'backward' else 'backward'
34
+ loc = custom_op._get_impl(found).location
35
+ raise RuntimeError(
36
+ f"We found a '{found}' registration for {custom_op} at "
37
+ f"{loc} but were unable to find a '{missing}' registration. "
38
+ f"To use the CustomOp API to register a backward formula, "
39
+ f"please provide us both a backward function and a "
40
+ f"'save for backward' function via `impl_backward` and "
41
+ f"`impl_save_for_backward` respectively.")
42
+ return autograd_fallback(*args, **kwargs)
43
+ return inner
44
+
45
+
46
+ # TODO(#101191): Use the actual C++ autograd not implemented fallback,
47
+ # or change the default autograd fallback to the autograd not implemented fallback.
48
+ def autograd_not_implemented(custom_op):
49
+ def kernel(*args, **kwargs):
50
+ if torch.is_grad_enabled() and pytree.tree_any(
51
+ lambda x: isinstance(x, torch.Tensor) and x.requires_grad, (args, kwargs)
52
+ ):
53
+ raise RuntimeError("Autograd has not been implemented for operator")
54
+ with torch._C._AutoDispatchBelowAutograd():
55
+ return custom_op(*args, **kwargs)
56
+ return kernel
57
+
58
+
59
+ def mark_non_differentiable(ctx, output, output_differentiability):
60
+ # Output types are restricted to be:
61
+ # - Tensor
62
+ # - Tensor[]
63
+ # - int, bool, Scalar, float
64
+ # See _check_can_register_backward
65
+ if output_differentiability is not None:
66
+ if not isinstance(output, tuple):
67
+ tuple_output = (output,)
68
+ else:
69
+ tuple_output = output # type: ignore[assignment]
70
+ assert len(output_differentiability) == len(tuple_output)
71
+ non_differentiable_tensors = []
72
+ for idx, (differentiable, out) in enumerate(zip(output_differentiability, tuple_output)):
73
+ if isinstance(out, torch.Tensor):
74
+ if not differentiable:
75
+ non_differentiable_tensors.append(out)
76
+ continue
77
+ if isinstance(out, list):
78
+ if not differentiable:
79
+ non_differentiable_tensors.extend(out)
80
+ continue
81
+ if differentiable:
82
+ raise RuntimeError(
83
+ f"With output_differentiability={output_differentiability}. "
84
+ f"At idx {idx}, we received an object of type {type(out)} that "
85
+ f"is not a Tensor, so it cannot have be marked as differentiable in "
86
+ f"output_differentiability.")
87
+ if non_differentiable_tensors:
88
+ ctx.mark_non_differentiable(*non_differentiable_tensors)
89
+
90
+
91
+ def construct_autograd_kernel(
92
+ schema,
93
+ output_differentiability,
94
+ custom_op,
95
+ op_overload,
96
+ save_for_backward_fn,
97
+ backward_fn):
98
+
99
+ def apply(*args):
100
+ flat_args, spec = pytree.tree_flatten(args)
101
+ out_spec = None
102
+
103
+ def forward(ctx, *flat_args):
104
+ ctx.set_materialize_grads(True)
105
+ args = pytree.tree_unflatten(list(flat_args), spec)
106
+ with torch._C._AutoDispatchBelowAutograd():
107
+ output = op_overload(*args)
108
+
109
+ # We use the info about args to give better error messages in backward
110
+ args_info = namedtuple_args(
111
+ schema, pytree.tree_map(type, args))
112
+
113
+ save_for_backward_fn_inputs = namedtuple_args(schema, args)
114
+ to_save = save_for_backward_fn(save_for_backward_fn_inputs, output)
115
+
116
+ save_pytree_for_backward(ctx, (to_save, args_info))
117
+ mark_non_differentiable(ctx, output, output_differentiability)
118
+
119
+ nonlocal out_spec
120
+ flat_output, out_spec = pytree.tree_flatten(output)
121
+ return tuple(flat_output)
122
+
123
+ def backward(ctx, *flat_grad_output):
124
+ assert out_spec is not None
125
+ grads = pytree.tree_unflatten(list(flat_grad_output), out_spec)
126
+ saved, args_info = unpack_saved(ctx)
127
+ # There is nothing on the ctx object for now, it is just there so
128
+ # that we can add additional things in the future.
129
+ inner_ctx = object()
130
+ if not isinstance(grads, tuple):
131
+ grads = (grads,)
132
+ grad_inputs_dict = backward_fn(inner_ctx, saved, *grads)
133
+
134
+ # Massage the grad_inputs_dict to a form acceptable by
135
+ # autograd.Function.
136
+ validate_grad_inputs_dict(grad_inputs_dict, custom_op, args_info)
137
+ return grad_inputs_dict_to_flat_tuple(grad_inputs_dict, args_info)
138
+
139
+ generated_cls = gen_autograd_function(
140
+ custom_op._opname + '_customop', forward, backward)
141
+
142
+ flat_output = generated_cls.apply(*flat_args)
143
+ assert out_spec is not None
144
+ return pytree.tree_unflatten(list(flat_output), out_spec)
145
+ return apply
146
+
147
+
148
+ def gen_autograd_function(name, forward, backward):
149
+ generated_cls = type(
150
+ name,
151
+ (torch.autograd.Function,),
152
+ {
153
+ 'forward': staticmethod(forward),
154
+ 'backward': staticmethod(backward),
155
+ }
156
+ )
157
+ return generated_cls
158
+
159
+
160
+ @functools.lru_cache
161
+ def namedtuple_args_cls(schema):
162
+ attribs = [arg.name for arg in schema.arguments.flat_all]
163
+ name = str(schema.name) + "_args"
164
+ # mypy doesn't support dynamic namedtuple name
165
+ tuple_cls = namedtuple(name, attribs) # type: ignore[misc]
166
+ return tuple_cls
167
+
168
+
169
+ def namedtuple_args(schema, args):
170
+ assert isinstance(args, tuple)
171
+ tuple_cls = namedtuple_args_cls(schema)
172
+ return tuple_cls(*args)
173
+
174
+
175
+ def validate_grad_inputs_dict(grad_inputs_dict, forward_op, args_info):
176
+ def error(what):
177
+ backward = forward_op._get_impl('backward')
178
+ raise RuntimeError(
179
+ f"In the backward function defined for {forward_op} at "
180
+ f"{backward.location} using the CustomOp API, {what}")
181
+
182
+ if not isinstance(grad_inputs_dict, dict):
183
+ error(f"expected the output of the backward function to be a dict but "
184
+ f"got {type(grad_inputs_dict)}")
185
+
186
+ expected_keys = {arg.name for arg in forward_op._schema.arguments.flat_all
187
+ if arg.type.is_tensor_like()}
188
+ actual_keys = grad_inputs_dict.keys()
189
+ if expected_keys != actual_keys:
190
+ error(f"expected the returned grad_input dict to have keys "
191
+ f"{expected_keys} but got {actual_keys}. The backward "
192
+ f"function must return a gradient (can be None) for each arg "
193
+ f"to the CustomOp that may be a Tensor or Sequence[Tensor]. "
194
+ f"Args declared to be non-Tensor-like types should not appear "
195
+ f"in the grad_input dict")
196
+
197
+ for name, grad in grad_inputs_dict.items():
198
+ arg_info = getattr(args_info, name)
199
+
200
+ if isinstance(arg_info, list):
201
+ if not isinstance(grad, (tuple, list)):
202
+ error(f"for input '{name}' expected the grad_input dict to "
203
+ f"hold a list of gradients but got object of type "
204
+ f"{type(grad)}.")
205
+ if not len(grad) == len(arg_info):
206
+ error(f"for input '{name}' expected the grad_input dict to "
207
+ f"hold a list of {len(arg_info)} gradients but got "
208
+ f"{len(grad)}")
209
+ for idx, (g, info) in enumerate(zip(grad, arg_info)):
210
+ if g is None:
211
+ continue
212
+ if not isinstance(g, torch.Tensor):
213
+ error(f"for input '{name}' expected the grad_input dict to "
214
+ f"hold a list of None or Tensor gradients but got "
215
+ f"object of {type(g)} at index {idx}")
216
+ if not issubclass(info, torch.Tensor):
217
+ error(f"for input '{name}', got a Tensor as the gradient "
218
+ f"for the {idx}-th value but expected None because "
219
+ f"the {idx}-th value was not a Tensor (it was "
220
+ f"type {arg_info}")
221
+ continue
222
+
223
+ if grad is None:
224
+ continue
225
+ if not isinstance(grad, torch.Tensor):
226
+ error(f"got object of type {type(grad)} as the gradient for input "
227
+ f"'{name}', "
228
+ f"but expected the gradient to be either None or a Tensor")
229
+ if not issubclass(arg_info, torch.Tensor):
230
+ error(f"got a Tensor as the gradient for input '{name}' but "
231
+ f"expected None as the gradient because input '{name}' "
232
+ f"was not a Tensor (it was type {arg_info}).")
233
+
234
+
235
+ def grad_inputs_dict_to_flat_tuple(grad_inputs_dict, args_info):
236
+ result = []
237
+ for name, arg_info in args_info._asdict().items():
238
+ if name not in grad_inputs_dict:
239
+ result.append(pytree.tree_map(lambda x: None, arg_info))
240
+ continue
241
+ result.append(grad_inputs_dict[name])
242
+ return tuple(pytree.tree_leaves(result))
243
+
244
+ # Saves "stuff" (a pytree) onto the ctx object. Use unpack_saved to unpack it.
245
+ # autograd.Function prefers that users use ctx.save_for_backward to
246
+ # save Tensors (to avoid reference cycles) and for non-Tensors to go onto the
247
+ # ctx object.
248
+ def save_pytree_for_backward(ctx, stuff):
249
+ flat_stuff, spec = pytree.tree_flatten(stuff)
250
+ num_elts = len(flat_stuff)
251
+ tensor_idxs = [idx for idx, thing in enumerate(flat_stuff)
252
+ if isinstance(thing, torch.Tensor)]
253
+ non_tensor_idxs = [idx for idx, thing in enumerate(flat_stuff)
254
+ if not isinstance(thing, torch.Tensor)]
255
+ tensors = [thing for thing in flat_stuff if isinstance(thing, torch.Tensor)]
256
+ non_tensors = [thing for thing in flat_stuff if not isinstance(thing, torch.Tensor)]
257
+
258
+ ctx.spec = spec
259
+ ctx.num_elts = num_elts
260
+ ctx.save_for_backward(*tensors)
261
+ ctx.tensor_idxs = tensor_idxs
262
+ ctx.saved_non_tensors = non_tensors
263
+ ctx.non_tensor_idxs = non_tensor_idxs
264
+
265
+
266
+ # Inverse operation to save_pytree_for_backward
267
+ def unpack_saved(ctx):
268
+ flat_stuff = [None] * ctx.num_elts
269
+ for tensor, idx in zip(ctx.saved_tensors, ctx.tensor_idxs):
270
+ flat_stuff[idx] = tensor
271
+ for non_tensor, idx in zip(ctx.saved_non_tensors, ctx.non_tensor_idxs):
272
+ flat_stuff[idx] = non_tensor
273
+ stuff = pytree.tree_unflatten(flat_stuff, ctx.spec)
274
+ return stuff
torch/_custom_op/functional.py ADDED
@@ -0,0 +1,187 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import weakref
2
+
3
+ import torch
4
+ import torch.utils._pytree as pytree
5
+ from torch._C import _ExcludeDispatchKeyGuard, DispatchKey, DispatchKeySet
6
+ from torch._ops import OpOverload
7
+ from torch.library import Library
8
+ from torchgen.model import (
9
+ BaseTy,
10
+ BaseType,
11
+ FunctionSchema,
12
+ OperatorName,
13
+ OptionalType,
14
+ SchemaKind,
15
+ )
16
+
17
+ from .autograd import autograd_not_implemented
18
+
19
+
20
+ def register_functional_op(
21
+ lib: Library,
22
+ new_op_name: str,
23
+ mutable_op: OpOverload,
24
+ ) -> None:
25
+ """Given a mutable operator, registers the functional variant.
26
+
27
+ This API also correctly links the functional variant with the mutable
28
+ operator for the purposes of functionalization.
29
+
30
+ All of the new registrations are performed on the ``lib`` passed in.
31
+
32
+ Arguments:
33
+ lib (Library): Should be a torch.library.Library object that has
34
+ the same namespace as ``mutable_op``'s namespace.
35
+ lib will be used to register the new functional op as well
36
+ as a functionalization kernel for the ``mutable_op``
37
+ If you don't have a library handy, use
38
+ ``torch.library.Library(ns, 'FRAGMENT')`` to construct one.
39
+ new_op_name (str): The name of the functional operator (without the
40
+ namespace). If no namespace, the new functional variant will be
41
+ accessible under ``torch.ops.{lib.ns}.new_op_name``.
42
+ mutable_op (OpOverload): The mutable custom operator. Note
43
+ that you may need to add a `.default` to it, like
44
+ `torch.ops.aten.abs_.default`.
45
+
46
+ """
47
+ validate(mutable_op)
48
+ schema = functional_schema(new_op_name, mutable_op)
49
+ lib.define(schema)
50
+
51
+ functional_impl = construct_functional_impl(mutable_op)
52
+ lib.impl(new_op_name, functional_impl, 'CompositeExplicitAutograd')
53
+
54
+ functional_op = getattr(getattr(torch.ops, lib.ns), new_op_name).default
55
+
56
+ # There's no easy way for us to generate the autograd kernel, so we
57
+ # use autograd_not_implemented. Also, this makes it so that the user
58
+ # is unable to register an autograd formula themselves. This shouldn't
59
+ # be a problem if the user doesn't use the functional op direclty
60
+ # in their program, but we may need to revist this in the future.
61
+ lib.impl(new_op_name, autograd_not_implemented(functional_op), 'Autograd')
62
+
63
+ f_kernel = construct_functionalization_kernel(weakref.proxy(mutable_op), functional_op)
64
+
65
+ lib.impl(mutable_op, f_kernel, 'Functionalize')
66
+
67
+
68
+ def construct_functional_impl(mutable_op):
69
+ def functional_impl(*args):
70
+ # Strategy:
71
+ # - clone args that would have been mutated
72
+ # - run mutable_op
73
+ # - return the cloned args as additional outputs
74
+ new_args = []
75
+ extra_rets = []
76
+ for is_write, arg in zip(mutable_args(mutable_op), args):
77
+ if is_write:
78
+ cloned = arg.clone() if arg is not None else None
79
+ new_args.append(cloned)
80
+ extra_rets.append(cloned)
81
+ else:
82
+ new_args.append(arg)
83
+ result = mutable_op(*new_args)
84
+ if result is None:
85
+ return tuple(extra_rets)
86
+ if isinstance(result, tuple):
87
+ return (*result, *extra_rets)
88
+ return (result, *extra_rets)
89
+ return functional_impl
90
+
91
+
92
+ def construct_functionalization_kernel(mutable_op, functional_op):
93
+ def kernel(*args):
94
+ # There's nothing to be functionalized!
95
+ # We can still end up here because DispatchKey::Functionalize is a mode key
96
+ if pytree.tree_all_only(torch.Tensor, lambda x: not torch._is_functional_tensor(x), args):
97
+ with _ExcludeDispatchKeyGuard(DispatchKeySet(DispatchKey.Functionalize)):
98
+ return mutable_op(*args)
99
+
100
+ # NB: This differs from the codegen -- codegen handles cases where there
101
+ # are mixed FunctionalTensorWrapper and non-FunctionalTensorWrapper.
102
+ # This only really matters for XLA (mixed CPU-XLA tensors) and
103
+ # running functionalization without the PT2 stack (which guarantees to us that
104
+ # all tensors are FunctionalTensorWrapper).
105
+ if not pytree.tree_all_only(torch.Tensor, torch._is_functional_tensor, args):
106
+ raise RuntimeError("{mutable_op}: expected all args to be FunctionalTensorWrapper")
107
+
108
+ unwrapped_args = []
109
+ for arg in args:
110
+ if isinstance(arg, torch.Tensor) and torch._is_functional_tensor(arg):
111
+ torch._sync(arg)
112
+ unwrapped = torch._from_functional_tensor(arg)
113
+ unwrapped_args.append(unwrapped)
114
+ else:
115
+ unwrapped_args.append(arg)
116
+
117
+ with _ExcludeDispatchKeyGuard(DispatchKeySet(DispatchKey.Functionalize)):
118
+ output = functional_op(*unwrapped_args)
119
+
120
+ num_actual_output = len(mutable_op._schema.returns)
121
+ actual_output = pytree.tree_map(
122
+ torch._to_functional_tensor, output[:num_actual_output])
123
+
124
+ new_values_to_propagate = output[num_actual_output:]
125
+ inputs_to_replace = [arg for is_write, arg in zip(mutable_args(mutable_op), args)
126
+ if is_write]
127
+ assert len(new_values_to_propagate) == len(inputs_to_replace)
128
+ for new_value, arg in zip(new_values_to_propagate, inputs_to_replace):
129
+ if (arg is None and new_value is None) or (arg is not None and new_value is not None):
130
+ continue
131
+ torch._C._propagate_xla_data(arg, new_value)
132
+ torch._C._replace_(arg, new_value)
133
+ torch._C._commit_update(arg)
134
+ torch._sync(arg)
135
+
136
+ if len(actual_output) == 1:
137
+ return actual_output[0]
138
+ elif len(actual_output) == 0:
139
+ return None
140
+ return actual_output
141
+
142
+ return kernel
143
+
144
+
145
+ def validate(mutable_op: OpOverload):
146
+ if not isinstance(mutable_op, OpOverload):
147
+ raise TypeError(
148
+ f"register_functional_op(mutable_op): expected mutable_op to be instance of "
149
+ f"OpOverload but got {type(mutable_op)}")
150
+
151
+ # There are generally three types of "in-place" or "mutable" ops.
152
+ # Each of them have their own conventions:
153
+ # - inplace (first input modified in-place and returned as only output)
154
+ # - out= (some args modified in-place and returned as outputs)
155
+ # - mutable (some args modified in-place but none of those returned as outputs)
156
+ # In theory we can support all three, but we'll just support the last
157
+ # option right now for simplicity.
158
+ schema = FunctionSchema.parse(str(mutable_op._schema))
159
+ if not schema.kind() == SchemaKind.mutable:
160
+ raise RuntimeError("Expected op to be mutable (as opposed to functional, inplace or out)")
161
+ for ret in schema.returns:
162
+ # construct_functionalization_kernel assumes this for simplicity
163
+ if ret.annotation is not None:
164
+ raise NotImplementedError(
165
+ "NYI: register_functional_op(op) where op returns a mutated or aliased value. "
166
+ "Please file an issue (and as a workaround, modify your operator to "
167
+ "not return the mutated value or aliases)")
168
+ for arg in schema.arguments.flat_all:
169
+ # construct_functionalization_kernel assumes this for simplicity
170
+ if arg.type.is_tensor_like() and (
171
+ arg.type != BaseType(BaseTy.Tensor)
172
+ and arg.type != OptionalType(BaseType(BaseTy.Tensor))
173
+ ):
174
+ raise NotImplementedError(
175
+ "NYI: register_functional_op(op) where op has a List[Tensor] input."
176
+ "Please file an issue.")
177
+
178
+
179
+ def functional_schema(new_op_name, op: OpOverload):
180
+ schema = FunctionSchema.parse(str(op._schema))
181
+ schema = schema.signature().with_name(OperatorName.parse(new_op_name))
182
+ return str(schema)
183
+
184
+
185
+ def mutable_args(op: OpOverload):
186
+ return tuple(False if arg.alias_info is None else arg.alias_info.is_write
187
+ for arg in op._schema.arguments)
torch/_custom_op/impl.py ADDED
@@ -0,0 +1,976 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import dataclasses
2
+ import functools
3
+ import inspect
4
+ import sys
5
+ import typing
6
+ import weakref
7
+
8
+ from torchgen.model import FunctionSchema, OperatorName, SchemaKind, BaseType, ListType, BaseTy
9
+
10
+ import torch
11
+ import torch._C as _C
12
+ import torch.library as library
13
+ from torch._library.abstract_impl import AbstractImplCtx
14
+ from torch.library import get_ctx
15
+
16
+ from .autograd import autograd_kernel_indirection, construct_autograd_kernel
17
+
18
+ """
19
+ For a detailed guide on custom ops, please see
20
+ https://docs.google.com/document/d/1aGWtgxV3HppuxQAdddyPrs74_aEntpkYt9MalnCKnhk
21
+
22
+ This file includes pieces of the implementation of our custom operator API.
23
+ """
24
+
25
+ __all__ = ["custom_op", "CustomOp", "get_ctx", "AbstractImplCtx"]
26
+
27
+
28
+ SUPPORTED_DEVICE_TYPE_TO_KEY = {
29
+ "cpu": "CPU",
30
+ "cuda": "CUDA",
31
+ }
32
+
33
+ # We will not let users register CustomOps with anything that could look like
34
+ # PyTorch internals to avoid confusion.
35
+ RESERVED_NS = {
36
+ "prim",
37
+ "prims",
38
+ "aten",
39
+ "at",
40
+ "torch",
41
+ "pytorch",
42
+ }
43
+
44
+
45
+ def custom_op(
46
+ qualname: str, manual_schema: typing.Optional[str] = None
47
+ ) -> typing.Callable:
48
+ r"""Creates a new CustomOp object.
49
+
50
+ WARNING: if you're a user, please do not use this directly
51
+ (instead use the torch._custom_ops APIs).
52
+ Also please see the following for a detailed guide on custom ops.
53
+ https://docs.google.com/document/d/1aGWtgxV3HppuxQAdddyPrs74_aEntpkYt9MalnCKnhk
54
+
55
+ In PyTorch, defining an op (short for "operator") is a two step-process:
56
+ - we need to define (create) the op
57
+ - we need to implement behavior for how the operator interacts with
58
+ various PyTorch subsystems, like CPU/CUDA Tensors, Autograd, etc.
59
+
60
+ This entrypoint defines the CustomOp object (the first step);
61
+ you must then perform the second step by calling various methods on
62
+ the CustomOp object.
63
+
64
+ This API is used as a decorator (see examples).
65
+
66
+ Arguments:
67
+ qualname (str): Should be a string that looks like
68
+ "namespace::operator_name". Operators in PyTorch need a namespace to
69
+ avoid name collisions; a given operator may only be created once.
70
+ If you are writing a Python library, we recommend the namespace to
71
+ be the name of your top-level module. The operator_name must be
72
+ the same as the name of the function you pass to custom_op
73
+ (see examples).
74
+ manual_schema (Optional[str]): Each PyTorch operator needs a schema that
75
+ tells PyTorch the types of the inputs/outputs. If None (default),
76
+ we will infer the schema from the type annotations on the function
77
+ (see examples). Otherwise, if you don't want to use type annotations,
78
+ you may provide us the schema string.
79
+
80
+ Example::
81
+ >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA)
82
+ >>> import numpy as np
83
+ >>> from torch import Tensor
84
+ >>>
85
+ >>> # Step 1: define the CustomOp.
86
+ >>> # We need to provide the decorator a "prototype function"
87
+ >>> # (a function with Python ellipses as the body).
88
+ >>> @custom_op("my_library::numpy_sin")
89
+ >>> def numpy_sin(x: Tensor) -> Tensor:
90
+ >>> ...
91
+ >>>
92
+ >>> # numpy_sin is now an instance of class CustomOp
93
+ >>> print(type(numpy_sin))
94
+ >>>
95
+ >>> # Step 2: Register an implementation for various PyTorch subsystems
96
+ >>>
97
+ >>> # Register an implementation for CPU tensors
98
+ >>> @numpy_sin.impl('cpu')
99
+ >>> def numpy_sin_impl_cpu(x):
100
+ >>> return torch.from_numpy(np.sin(x.numpy()))
101
+ >>>
102
+ >>> # Register an implementation for CUDA tensors
103
+ >>> @numpy_sin.impl('cuda')
104
+ >>> def numpy_sin_impl_cuda(x):
105
+ >>> return torch.from_numpy(np.sin(x.cpu().numpy())).to(x.device)
106
+ >>>
107
+ >>> x = torch.randn(3)
108
+ >>> numpy_sin(x) # calls numpy_sin_impl_cpu
109
+ >>>
110
+ >>> x_cuda = x.cuda()
111
+ >>> numpy_sin(x) # calls numpy_sin_impl_cuda
112
+
113
+ """
114
+
115
+ def inner(func):
116
+ if not inspect.isfunction(func):
117
+ raise ValueError(
118
+ f"custom_op(...)(func): Expected `func` to be a Python "
119
+ f"function, got: {type(func)}"
120
+ )
121
+
122
+ ns, name = parse_qualname(qualname)
123
+ validate_namespace(ns)
124
+ if func.__name__ != name:
125
+ raise ValueError(
126
+ f"custom_op(qualname='{qualname}', ...)(func): expected `func` "
127
+ f"to have name '{name}' but got '{func.__name__}'. "
128
+ f"Please either change the name of `func` or the qualname that "
129
+ f"is passed to `custom_op`"
130
+ )
131
+
132
+ schema = infer_schema(func) if manual_schema is None else manual_schema
133
+ schema_str = f"{name}{schema}"
134
+ function_schema = FunctionSchema.parse(schema_str)
135
+ validate_schema(function_schema)
136
+ if manual_schema is not None:
137
+ validate_function_matches_schema(function_schema, func)
138
+
139
+ lib = library.Library(ns, "FRAGMENT")
140
+ lib.define(schema_str)
141
+ ophandle = find_ophandle_or_throw(ns, function_schema.name)
142
+ result = CustomOp(lib, ns, function_schema, name, ophandle, _private_access=True)
143
+
144
+ result.__name__ = func.__name__
145
+ result.__module__ = func.__module__
146
+ result.__doc__ = func.__doc__
147
+
148
+ library.impl(lib, result._opname, "Autograd")(
149
+ autograd_kernel_indirection(weakref.proxy(result))
150
+ )
151
+
152
+ torch._C._dispatch_set_report_error_callback(
153
+ ophandle, functools.partial(report_error_callback, weakref.proxy(result))
154
+ )
155
+
156
+ return result
157
+
158
+ return inner
159
+
160
+
161
+ # Global dictionary holding references to all CustomOp objects
162
+ # Yes, it keeps all CustomOps alive (see NOTE [CustomOp lifetime])
163
+ # Used to query the CustomOp associated with a specific C++ dispatcher operator.
164
+ # An example usage is FakeTensor: FakeTensor checks if a specific operator
165
+ # has an implementation registered via the CustomOp API.
166
+ # Indexed by qualname (e.g. aten::foo)
167
+ global_registry: typing.Dict[str, "CustomOp"] = {}
168
+
169
+
170
+ class CustomOp:
171
+ r"""Class for custom operators in PyTorch.
172
+
173
+ Use the CustomOp API to create user-defined custom operators that behave
174
+ just like regular PyTorch operators (e.g. torch.sin, torch.mm) when it
175
+ comes to various PyTorch subsystems (like torch.compile).
176
+
177
+ To construct a `CustomOp`, use `custom_op`.
178
+ """
179
+
180
+ def __init__(self, lib, cpp_ns, schema, operator_name, ophandle, *, _private_access=False):
181
+ super().__init__()
182
+ if not _private_access:
183
+ raise RuntimeError(
184
+ "The CustomOp constructor is private and we do not guarantee "
185
+ "BC for it. Please use custom_op(...) to create a CustomOp object"
186
+ )
187
+ name = f"{cpp_ns}::{operator_name}"
188
+ self._schema = schema
189
+ self._cpp_ns = cpp_ns
190
+ self._lib: library.Library = lib
191
+ self._ophandle: _C._DispatchOperatorHandle = ophandle
192
+ # Has the name of the op, e.g. "foo". We cache here for convenience.
193
+ self._opname: str = operator_name
194
+ # this is _opname but with namespace. e.g. "custom::foo"
195
+ self._qualname: str = name
196
+ self.__name__ = None # mypy requires this
197
+ # NB: Some of these impls are registered as kernels to DispatchKeys.
198
+ # Modifying the _impls dict directly won't do anything in that case.
199
+ self._impls: typing.Dict[str, typing.Optional[FuncAndLocation]] = {}
200
+ # See NOTE [CustomOp autograd kernel indirection]
201
+ self._registered_autograd_kernel_indirection = False
202
+
203
+ global_registry[self._qualname] = self
204
+
205
+ def _register_autograd_kernel_indirection(self):
206
+ assert not self._registered_autograd_kernel_indirection
207
+ self._lib.impl(self._opname, autograd_kernel_indirection(weakref.proxy(self)), "Autograd")
208
+ self._registered_autograd_kernel_indirection = True
209
+
210
+ # Records the impl and the source location in self._impls
211
+ # Note that this doesn't cause torch.library to use the impl, that
212
+ # needs to be done in a separate self._lib.impl call.
213
+ def _register_impl(self, kind, func, stacklevel=2):
214
+ if self._has_impl(kind):
215
+ func_and_location = self._impls[kind]
216
+ assert func_and_location is not None # Pacify mypy
217
+ location = func_and_location.location
218
+ raise RuntimeError(
219
+ f"Attempting to register a {kind} impl for operator {self._qualname} "
220
+ f"that already has a {kind} impl registered from Python at "
221
+ f"{location}. This is not supported."
222
+ )
223
+ frame = inspect.getframeinfo(sys._getframe(stacklevel))
224
+ location = f"{frame.filename}:{frame.lineno}"
225
+ self._impls[kind] = FuncAndLocation(func, location)
226
+
227
+ def _get_impl(self, kind):
228
+ return self._impls[kind]
229
+
230
+ def _has_impl(self, kind):
231
+ return kind in self._impls
232
+
233
+ def _destroy(self):
234
+ # NOTE: [CustomOp lifetime]
235
+ # A CustomOp, once created, lives forever. The mechanism is that the
236
+ # global registry holds a reference to it. However, to make testing
237
+ # easier, we want to be able to destroy CustomOp objects.
238
+ # CustomOp._destroy does the job, though it leaves the CustomOp
239
+ # in a garbage state.
240
+ del self._lib
241
+
242
+ opnamespace = getattr(torch.ops, self._cpp_ns)
243
+ if hasattr(opnamespace, self._opname):
244
+ delattr(opnamespace, self._opname)
245
+
246
+ del global_registry[self._qualname]
247
+
248
+ def __repr__(self):
249
+ return f'<CustomOp(op="{self._qualname}")>'
250
+
251
+ def __call__(self, *args, **kwargs):
252
+ # Bypass torch.ops.* and directly do OperatorHandle::callBoxed.
253
+ # Using torch.ops.* is a bit of a pain (it can be slow and it has lifetime
254
+ # issues from caching operators that make testing CustomOp difficult).
255
+ result = _C._dispatch_call_boxed(self._ophandle, *args, **kwargs)
256
+ return result
257
+
258
+ def impl(
259
+ self, device_types: typing.Union[str, typing.Iterable[str]], _stacklevel=2,
260
+ ) -> typing.Callable:
261
+ r"""Register an implementation for a device type for this CustomOp object.
262
+
263
+ WARNING: if you're a user, please do not use this directly
264
+ (instead use the torch._custom_ops APIs).
265
+ Also please see the following for a detailed guide on custom ops.
266
+ https://docs.google.com/document/d/1aGWtgxV3HppuxQAdddyPrs74_aEntpkYt9MalnCKnhk
267
+
268
+ If the CustomOp is passed multiple Tensor inputs with different device
269
+ types, it will dispatch to the registered implementation for the highest
270
+ priority device type among those present.
271
+ The supported device types, in order of priority, are {'cuda', 'cpu'}.
272
+
273
+ This API is used as a decorator (see examples).
274
+
275
+ Arguments:
276
+ device_types (str or Iterable[str]): the device type(s) to register the function for.
277
+
278
+ Examples::
279
+ >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA)
280
+ >>> import numpy as np
281
+ >>> from torch import Tensor
282
+ >>>
283
+ >>> @custom_op("my_library::numpy_cos")
284
+ >>> def numpy_cos(x: Tensor) -> Tensor:
285
+ >>> ...
286
+ >>>
287
+ >>> # Register an implementation for CPU Tensors
288
+ >>> @numpy_cos.impl('cpu')
289
+ >>> def numpy_cos_impl_cpu(x):
290
+ >>> return torch.from_numpy(np.cos(x.numpy()))
291
+ >>>
292
+ >>> # Register an implementation for CUDA Tensors
293
+ >>> @numpy_cos.impl('cuda')
294
+ >>> def numpy_cos_impl_cuda(x):
295
+ >>> return torch.from_numpy(np.cos(x.cpu().numpy())).to(x.device)
296
+ >>>
297
+ >>> x = torch.randn(3)
298
+ >>> numpy_cos(x) # calls numpy_cos_impl_cpu
299
+ >>>
300
+ >>> x_cuda = x.cuda()
301
+ >>> numpy_cos(x) # calls numpy_cos_impl_cuda
302
+
303
+ """
304
+ if isinstance(device_types, str):
305
+ device_types = [device_types]
306
+ for device_type in device_types:
307
+ validate_device_type(device_type)
308
+
309
+ def inner(f):
310
+ for device_type in set(device_types):
311
+ self._check_doesnt_have_library_impl(device_type)
312
+ self._register_impl(device_type, f, stacklevel=_stacklevel)
313
+ dispatch_key = SUPPORTED_DEVICE_TYPE_TO_KEY[device_type]
314
+ library.impl(self._lib, self._opname, dispatch_key)(f)
315
+ return f
316
+
317
+ return inner
318
+
319
+ def _check_doesnt_have_library_impl(self, device_type):
320
+ if self._has_impl(device_type):
321
+ return
322
+ key = SUPPORTED_DEVICE_TYPE_TO_KEY[device_type]
323
+ if _C._dispatch_has_computed_kernel_for_dispatch_key(self._qualname, key):
324
+ raise RuntimeError(
325
+ f"impl(..., device_types={device_type}): the operator {self._qualname} "
326
+ f"already has an implementation for this device type via a "
327
+ f"pre-existing torch.library or TORCH_LIBRARY registration.")
328
+
329
+ def impl_factory(self) -> typing.Callable:
330
+ r"""Register an implementation for a factory function."""
331
+
332
+ def inner(f):
333
+ self._register_impl("factory", f)
334
+ library.impl(self._lib, self._opname, "BackendSelect")(f)
335
+ return f
336
+
337
+ return inner
338
+
339
+ def impl_abstract(self, _stacklevel=2) -> typing.Callable:
340
+ r"""Register an abstract implementation for this operator.
341
+
342
+ WARNING: please do not use this directly (and instead use the torch._custom_ops
343
+ APIs). Also please see the following for a detailed guide on custom ops.
344
+ https://docs.google.com/document/d/1aGWtgxV3HppuxQAdddyPrs74_aEntpkYt9MalnCKnhk
345
+
346
+ An "abstract implementation" specifies the behavior of this operator on
347
+ Tensors that carry no data. Given some input Tensors with certain properties
348
+ (sizes/strides/storage_offset/device), it specifies what the properties of
349
+ the output Tensors are.
350
+
351
+ The abstract implementation has the same signature as the operator.
352
+ It is run for both FakeTensors and meta tensors. To write an abstract
353
+ implementation, assume that all Tensor inputs to the operator are
354
+ regular CPU/CUDA/Meta tensors, but they do not have storage, and
355
+ you are trying to return regular CPU/CUDA/Meta tensor(s) as output.
356
+ The abstract implementation must consist of only PyTorch operations
357
+ (and may not directly access the storage or data of any input or
358
+ intermediate Tensors).
359
+
360
+ This API is used as a decorator (see examples).
361
+
362
+ Examples::
363
+ >>> import numpy as np
364
+ >>> from torch import Tensor
365
+ >>>
366
+ >>> # Example 1: an operator without data-dependent output shape
367
+ >>> @custom_op('my_library::custom_linear')
368
+ >>> def custom_linear(x: Tensor, weight: Tensor, bias: Tensor) -> Tensor:
369
+ >>> ...
370
+ >>>
371
+ >>> @custom_linear.impl_abstract()
372
+ >>> def custom_linear_abstract(x, weight):
373
+ >>> assert x.dim() == 2
374
+ >>> assert weight.dim() == 2
375
+ >>> assert bias.dim() == 1
376
+ >>> assert x.shape[1] == weight.shape[1]
377
+ >>> assert weight.shape[0] == bias.shape[0]
378
+ >>> assert x.device == weight.device
379
+ >>>
380
+ >>> return (x @ weight.t()) + bias
381
+ >>>
382
+ >>> # Example 2: an operator with data-dependent output shape
383
+ >>> @custom_op('my_library::custom_nonzero')
384
+ >>> def custom_nonzero(x: Tensor) -> Tensor:
385
+ >>> ...
386
+ >>>
387
+ >>> @custom_nonzero.impl_abstract()
388
+ >>> def custom_nonzero_abstract(x):
389
+ >>> # Number of nonzero-elements is data-dependent.
390
+ >>> # Since we cannot peek at the data in an abstract impl,
391
+ >>> # we use the ctx object to construct a new symint that
392
+ >>> # represents the data-dependent size.
393
+ >>> ctx = torch._custom_op.get_ctx()
394
+ >>> nnz = ctx.create_unbacked_symint()
395
+ >>> shape = [x.dim(), nnz]
396
+ >>> result = x.new_empty(shape, dtype=torch.long)
397
+ >>> return result
398
+ >>>
399
+ >>> @custom_nonzero.impl(['cpu', 'cuda'])
400
+ >>> def custom_nonzero_impl(x):
401
+ >>> x_np = to_numpy(x)
402
+ >>> res = np.stack(np.nonzero(x_np), axis=1)
403
+ >>> # unbacked symbolic ints in PyTorch must be >= 2, so we
404
+ >>> # constrain the range to at least 2
405
+ >>> if res.shape[0] <= 1:
406
+ >>> raise RuntimeError("not supported")
407
+ >>> return torch.tensor(res, device=x.device)
408
+
409
+ """
410
+
411
+ def inner(f):
412
+ self._check_doesnt_have_library_meta_impl()
413
+ self._register_impl("abstract", f, stacklevel=_stacklevel)
414
+ location = self._get_impl("abstract").location
415
+
416
+ qualname = self._qualname
417
+
418
+ # Handle DispatchKey.Meta registration
419
+ @functools.wraps(f)
420
+ def f_with_ctx(*args, **kwargs):
421
+ def error_on_ctx():
422
+ raise RuntimeError(
423
+ f"Attempted to call get_ctx() for the meta implementation "
424
+ f"for {qualname}."
425
+ f"You have presumably called get_ctx() because the operator "
426
+ f"has a data-dependent output shape; if so, there is no "
427
+ f"such meta implementation and this error is the correct "
428
+ f"behavior. Otherwise, please remove the call to get_ctx() "
429
+ f"in the implementation registered with impl_abstract "
430
+ f"at {location}"
431
+ )
432
+
433
+ with torch._library.abstract_impl.set_ctx_getter(error_on_ctx):
434
+ return f(*args, **kwargs)
435
+
436
+ self._lib.impl(self._opname, f_with_ctx, "Meta")
437
+ return f
438
+
439
+ return inner
440
+
441
+ def _check_can_register_backward(self):
442
+ def error(detail):
443
+ raise RuntimeError(
444
+ f"Cannot use torch._custom_ops APIs to register backward "
445
+ f"formula for {detail}. Got operator "
446
+ f"{self._qualname} with schema: {schema}"
447
+ )
448
+
449
+ schema = self._schema
450
+ if schema.kind() != SchemaKind.functional:
451
+ error("non-functional operator")
452
+
453
+ rets = schema.returns
454
+ if not schema.returns:
455
+ error("operator with no returns")
456
+
457
+ assert len(rets) > 0
458
+ is_non_mutating_view = any(
459
+ r.annotation is not None and not r.annotation.is_write for r in rets
460
+ )
461
+ if is_non_mutating_view:
462
+ error("operator that returns views")
463
+
464
+ # We make assumptions about the schema's return types.
465
+ allowed_return_types = {
466
+ BaseType(BaseTy.int): "int",
467
+ BaseType(BaseTy.SymInt): "SymInt",
468
+ BaseType(BaseTy.bool): "bool",
469
+ BaseType(BaseTy.float): "float",
470
+ BaseType(BaseTy.Tensor): "Tensor",
471
+ ListType(BaseType(BaseTy.Tensor), None): "List[Tensor]",
472
+ }
473
+ for ret in schema.returns:
474
+ if ret.type in allowed_return_types:
475
+ continue
476
+ error(f"operator with return not in {list(allowed_return_types.values())} (got {ret.type})")
477
+
478
+ def _check_doesnt_have_library_autograd_impl(self):
479
+ if self._registered_autograd_kernel_indirection:
480
+ return
481
+
482
+ if _C._dispatch_has_kernel_for_dispatch_key(self._qualname, "CompositeImplicitAutograd"):
483
+ raise RuntimeError(
484
+ f"impl_backward/impl_save_for_backward: the operator {self._qualname} "
485
+ f"already has an implementation for this device type via a "
486
+ f"pre-existing registration to DispatchKey::CompositeImplicitAutograd."
487
+ f"CompositeImplicitAutograd operators do not need an autograd formula; "
488
+ f"instead, the operator will decompose into its constituents and those "
489
+ f"can have autograd formulas defined on them.")
490
+
491
+ # We can improve this by adding "all Autograd<BACKEND> keys", but
492
+ # realistically people will just be using this API for CPU/CUDA for now.
493
+ for key in ["Autograd", "AutogradCPU", "AutogradCUDA"]:
494
+ if _C._dispatch_has_kernel_for_dispatch_key(self._qualname, key):
495
+ raise RuntimeError(
496
+ f"impl_backward/impl_save_for_backward: "
497
+ f"the operator {self._qualname} already has an Autograd kernel "
498
+ f"registered to DispatchKey::{key} vi a pre-existing "
499
+ f"torch.library or TORCH_LIBRARY registration. Please either "
500
+ f"remove those registrations or don't use the torch._custom_ops APIs")
501
+
502
+ def _check_doesnt_have_library_meta_impl(self):
503
+ if self._has_impl("abstract"):
504
+ return
505
+
506
+ # If the user's operator is CompositeExplicitAutograd,
507
+ # allow them to impl_abstract. This is being pragmatic
508
+ # (existing custom ops may have CompositeExplicitAutograd
509
+ # registration that don't work with Meta kernels, so this
510
+ # gives them an escape hatch).
511
+ if (
512
+ _C._dispatch_has_kernel_for_dispatch_key(self._qualname, "CompositeExplicitAutograd")
513
+ and not _C._dispatch_has_kernel_for_dispatch_key(self._qualname, "Meta")
514
+ ):
515
+ return
516
+
517
+ # Otherwise, if the user's already has a Meta kernel or their
518
+ # op is CompositeImplicitAutograd or some other alias dispatch key,
519
+ # raise.
520
+
521
+ # Special case for CompositeImplicitAutograd
522
+ if _C._dispatch_has_kernel_for_dispatch_key(self._qualname, "CompositeImplicitAutograd"):
523
+ raise RuntimeError(
524
+ f"impl_abstract(...): the operator {self._qualname} "
525
+ f"already has an implementation for this device type via a "
526
+ f"pre-existing registration to DispatchKey::CompositeImplicitAutograd."
527
+ f"CompositeImplicitAutograd operators do not need an abstract impl; "
528
+ f"instead, the operator will decompose into its constituents and those "
529
+ f"can have abstract impls defined on them.")
530
+
531
+ if _C._dispatch_has_kernel_for_dispatch_key(self._qualname, "Meta"):
532
+ raise RuntimeError(
533
+ f"impl_abstract(...): the operator {self._qualname} "
534
+ f"already has an DispatchKey::Meta implementation via a "
535
+ f"pre-existing torch.library or TORCH_LIBRARY registration. "
536
+ f"Please either remove that registration or don't call impl_abstract.")
537
+
538
+ # NOTE ["backward", "save_for_backward", and "autograd"]
539
+ # As a part of the explicit autograd API, a user must provide us
540
+ # a "save_for_backward" function and a "backward" function.
541
+ # When both of these have been provided, then we automatically
542
+ # construct the "autograd" kernel.
543
+ def _register_autograd_kernel(self):
544
+ assert self._has_impl("backward")
545
+ assert self._has_impl("save_for_backward")
546
+ kernel = construct_autograd_kernel(
547
+ self._schema,
548
+ self._output_differentiability,
549
+ self,
550
+ get_op(self._qualname),
551
+ self._get_impl("save_for_backward").func,
552
+ self._get_impl("backward").func)
553
+ self._register_impl("autograd", kernel)
554
+
555
+ def impl_save_for_backward(self, _stacklevel=2):
556
+ r"""Register a function that tells us what to save for backward.
557
+
558
+ Please see impl_backward for more details.
559
+ """
560
+ def inner(f):
561
+ self._check_can_register_backward()
562
+ self._check_doesnt_have_library_autograd_impl()
563
+ if not self._registered_autograd_kernel_indirection:
564
+ self._register_autograd_kernel_indirection()
565
+ self._register_impl("save_for_backward", f, stacklevel=_stacklevel)
566
+ if self._has_impl("backward"):
567
+ self._register_autograd_kernel()
568
+ return inner
569
+
570
+ def impl_backward(self, output_differentiability=None, _stacklevel=2):
571
+ r"""Registers a backward formula.
572
+
573
+ WARNING: if you're a user, please do not use this directly
574
+ (instead use the torch._custom_ops APIs).
575
+ Also please see the following for a detailed guide on custom ops.
576
+ https://docs.google.com/document/d/1aGWtgxV3HppuxQAdddyPrs74_aEntpkYt9MalnCKnhk
577
+
578
+ In order for the CustomOp to work with autograd, you need to register
579
+ a backward formula. There are two pieces to this:
580
+ 1. You must give us a function to specify what to save for backward.
581
+ Call this the "save for backward" function.
582
+ 2. You must give us a function that computes gradients. Call this the
583
+ "backward" function.
584
+
585
+ Use `impl_save_for_backward` to define a "save for backward" function
586
+ that specifies what gets saved for backward. The function should accept
587
+ two arguments ``(inputs, output)`` and return the quantities to be saved
588
+ for backward.
589
+
590
+ During runtime, when you call the CustomOp, PyTorch will invoke the
591
+ "save for backward" function with the inputs and output of the CustomOp.
592
+
593
+ Use `impl_backward` to define the "backward" function. The backward
594
+ function must accept ``(ctx, saved, *grads)``:
595
+ - ``ctx`` is a context object where we may provide information
596
+ - ``saved`` is exactly what gets returned from the "save for backward"
597
+ function
598
+ - ``grads`` is one or more gradients. The number of gradients matches
599
+ the number of outputs of the CustomOp.
600
+
601
+ The backward function must return a dict that maps the name of
602
+ an input to the CustomOp to its corresponding gradient. All inputs that
603
+ were declared to be Tensors in the CustomOp definition must be accounted
604
+ for in the dict. The gradient may be a Tensor or None.
605
+
606
+ """
607
+ if output_differentiability is not None:
608
+ def yell():
609
+ raise RuntimeError(
610
+ f"impl_backward(output_differentiability): expected "
611
+ f"output_differentiability to be a list of bools with "
612
+ f"length equal to the number of outputs of this CustomOp "
613
+ f"got: {output_differentiability}")
614
+
615
+ if not isinstance(output_differentiability, list):
616
+ yell()
617
+ for diff in output_differentiability:
618
+ if not isinstance(diff, bool):
619
+ yell()
620
+ if len(self._schema.returns) != len(output_differentiability):
621
+ yell()
622
+
623
+ def inner(f):
624
+ self._check_can_register_backward()
625
+ self._check_doesnt_have_library_autograd_impl()
626
+ if not self._registered_autograd_kernel_indirection:
627
+ self._register_autograd_kernel_indirection()
628
+ self._register_impl("backward", f, stacklevel=_stacklevel)
629
+ self._output_differentiability = output_differentiability
630
+ if self._has_impl("save_for_backward"):
631
+ self._register_autograd_kernel()
632
+ return inner
633
+
634
+
635
+ @dataclasses.dataclass
636
+ class FuncAndLocation:
637
+ func: typing.Callable
638
+ location: str
639
+
640
+
641
+ def find_ophandle_or_throw(cpp_ns: str, operator_name: OperatorName):
642
+ overload_name = (
643
+ "" if operator_name.overload_name is None else operator_name.overload_name
644
+ )
645
+ return _C._dispatch_find_schema_or_throw(
646
+ f"{cpp_ns}::{str(operator_name.name)}", overload_name
647
+ )
648
+
649
+
650
+ def validate_namespace(ns: str) -> None:
651
+ if "." in ns:
652
+ raise ValueError(
653
+ f'custom_op(..., ns="{ns}"): expected ns to not contain any . (and be a '
654
+ f"valid variable name)"
655
+ )
656
+ if ns in RESERVED_NS:
657
+ raise ValueError(
658
+ f"custom_op(..., ns='{ns}'): '{ns}' is a reserved namespace, "
659
+ f"please choose something else. "
660
+ )
661
+
662
+ def validate_schema(schema: FunctionSchema) -> None:
663
+ if not torch._library.utils.is_functional_schema(schema):
664
+ raise ValueError(
665
+ f"custom_op only supports functional operators "
666
+ f"(ops that do not mutate any inputs, do not return "
667
+ f"views of the inputs, and has at least one return). "
668
+ f"Got the following non-functional schema: {schema}"
669
+ )
670
+
671
+ # For simplicity: don't allow self arguments
672
+ if schema.arguments.self_arg is not None:
673
+ raise ValueError(
674
+ f"custom_op does not support arguments named 'self'. Please "
675
+ f"rename your argument. Got: {schema}"
676
+ )
677
+
678
+
679
+ def parse_qualname(qualname: str) -> typing.Tuple[str, str]:
680
+ names = qualname.split("::", 1)
681
+ if len(names) != 2:
682
+ raise ValueError(f"Expected there to be a namespace in {qualname}, i.e. The "
683
+ f"operator name should look something like ns::foo")
684
+ if '.' in names[1]:
685
+ raise ValueError(f"The torch.custom_ops APIs do not handle overloads, "
686
+ f"i.e. operator names with '.' in them. "
687
+ f"Please name your operator something like ns::foo. "
688
+ f"Got: {qualname}")
689
+ return names[0], names[1]
690
+
691
+
692
+ def validate_device_type(device_type: str) -> None:
693
+ if device_type not in SUPPORTED_DEVICE_TYPE_TO_KEY:
694
+ raise ValueError(
695
+ f"CustomOp.impl(device_types=[{device_type}, ...]): we only support device_type "
696
+ f"in {SUPPORTED_DEVICE_TYPE_TO_KEY.keys()}."
697
+ )
698
+
699
+
700
+ def supported_param(param: inspect.Parameter) -> bool:
701
+ return param.kind in (
702
+ inspect.Parameter.POSITIONAL_OR_KEYWORD,
703
+ inspect.Parameter.KEYWORD_ONLY,
704
+ )
705
+
706
+
707
+ def validate_function_matches_schema(
708
+ schema: FunctionSchema, func: typing.Callable
709
+ ) -> None:
710
+ sig = inspect.signature(func)
711
+
712
+ if not all(supported_param(p) for _, p in sig.parameters.items()):
713
+ raise ValueError(
714
+ f"custom_op(..., manual_schema)(func): positional-only args, "
715
+ f"varargs, and kwargs are not supported. Please rewrite `func` "
716
+ f"to not have them. Got `func` with signature: {sig}"
717
+ )
718
+
719
+ if (
720
+ any(
721
+ p.annotation is not inspect.Parameter.empty
722
+ for _, p in sig.parameters.items()
723
+ )
724
+ or sig.return_annotation is not inspect.Signature.empty
725
+ ):
726
+ raise ValueError(
727
+ f"custom_op(..., manual_schema)(func): When passing in a manual "
728
+ f"schema, we expect `func` to have no type annotations to avoid "
729
+ f"ambiguity. Got `func` with signature: {sig}"
730
+ )
731
+
732
+ positional = [
733
+ (name, param)
734
+ for name, param in sig.parameters.items()
735
+ if param.kind == inspect.Parameter.POSITIONAL_OR_KEYWORD
736
+ ]
737
+ kwargonly = [
738
+ (name, param)
739
+ for name, param in sig.parameters.items()
740
+ if param.kind == inspect.Parameter.KEYWORD_ONLY
741
+ ]
742
+
743
+ def error():
744
+ raise ValueError(
745
+ f"custom_op(..., manual_schema)(func): When passing in a manual "
746
+ f"schema, we expect `func`'s signature to match `manual_schema` "
747
+ f"(aside from type annotations). "
748
+ f"func's signature: {sig}, manual_schema: {schema}"
749
+ )
750
+
751
+ def error_default_args():
752
+ raise ValueError(
753
+ f"custom_op(..., manual_schema)(func): "
754
+ f"neither func nor manual_schema should have default "
755
+ f"arguments. Got "
756
+ f"func's signature: {sig}, manual_schema: {schema}"
757
+ )
758
+
759
+ def compare(sig_args, schema_args):
760
+ if len(sig_args) != len(schema_args):
761
+ error()
762
+ for (name, param), arg in zip(sig_args, schema_args):
763
+ if name != arg.name:
764
+ error()
765
+ if param.default is not inspect.Parameter.empty or arg.default is not None:
766
+ error_default_args()
767
+
768
+ compare(positional, schema.arguments.flat_positional)
769
+ compare(kwargonly, schema.arguments.flat_kwarg_only)
770
+
771
+
772
+ def infer_schema(prototype_function: typing.Callable) -> str:
773
+ sig = inspect.signature(prototype_function)
774
+
775
+ def error_fn(what):
776
+ raise ValueError(
777
+ f"custom_op(...)(func): {what} " f"Got func with signature {sig})"
778
+ )
779
+
780
+ params = [
781
+ parse_param(name, param, error_fn) for name, param in sig.parameters.items()
782
+ ]
783
+ ret = parse_return(sig.return_annotation, error_fn)
784
+ return f"({', '.join(params)}) -> {ret}"
785
+
786
+
787
+ def parse_param(name, param, error_fn):
788
+ if not supported_param(param):
789
+ error_fn("We do not support positional-only args, varargs, or varkwargs.")
790
+
791
+ if param.annotation is inspect.Parameter.empty:
792
+ error_fn(f"Parameter {name} must have a type annotation.")
793
+
794
+ if param.annotation not in SUPPORTED_PARAM_TYPES.keys():
795
+ error_fn(
796
+ f"Parameter {name} has unsupported type {param.annotation}. "
797
+ f"The valid types are: {SUPPORTED_PARAM_TYPES.keys()}."
798
+ )
799
+
800
+ if param.default is not inspect.Parameter.empty:
801
+ error_fn(
802
+ f"Parameter {name} has a default value; this is not supported. "
803
+ f"If you want to use default values then create a function with "
804
+ f"default values that calls the CustomOp"
805
+ )
806
+
807
+ return f"{SUPPORTED_PARAM_TYPES[param.annotation]} {name}"
808
+
809
+
810
+ def derived_types(
811
+ base_type, cpp_type, list_base, optional_base_list, optional_list_base
812
+ ):
813
+ result = [
814
+ (base_type, cpp_type),
815
+ (typing.Optional[base_type], f"{cpp_type}?"),
816
+ ]
817
+ if list_base:
818
+ result.append((typing.Sequence[base_type], f"{cpp_type}[]")) # type: ignore[valid-type]
819
+ if optional_base_list:
820
+ result.append((typing.Sequence[typing.Optional[base_type]], f"{cpp_type}?[]")) # type: ignore[valid-type]
821
+ if optional_list_base:
822
+ result.append((typing.Optional[typing.Sequence[base_type]], f"{cpp_type}[]?")) # type: ignore[valid-type]
823
+ return result
824
+
825
+
826
+ def get_supported_param_types():
827
+ data = [
828
+ # (python type, schema type, type[] variant, type?[] variant, type[]? variant
829
+ (torch.Tensor, "Tensor", True, True, False),
830
+ (int, "SymInt", True, False, True),
831
+ (float, "float", True, False, True),
832
+ (bool, "bool", True, False, True),
833
+ (str, "str", False, False, False),
834
+ (torch.types.Number, "Scalar", True, False, False),
835
+ (torch.dtype, "ScalarType", False, False, False),
836
+ (torch.device, "Device", False, False, False),
837
+ ]
838
+ result = []
839
+ for line in data:
840
+ result.extend(derived_types(*line))
841
+ return dict(result)
842
+
843
+
844
+ SUPPORTED_RETURN_TYPES = {
845
+ torch.Tensor: "Tensor",
846
+ typing.List[torch.Tensor]: "Tensor[]",
847
+ int: "SymInt",
848
+ float: "float",
849
+ bool: "bool",
850
+ torch.types.Number: "Scalar",
851
+ }
852
+
853
+
854
+ def parse_return(annotation, error_fn):
855
+ origin = typing.get_origin(annotation)
856
+ if origin is not tuple:
857
+ if annotation not in SUPPORTED_RETURN_TYPES.keys():
858
+ error_fn(
859
+ f"Return has unsupported type {annotation}. "
860
+ f"The valid types are: {SUPPORTED_RETURN_TYPES}."
861
+ )
862
+ return SUPPORTED_RETURN_TYPES[annotation]
863
+
864
+ args = typing.get_args(annotation)
865
+ for arg in args:
866
+ if arg not in SUPPORTED_RETURN_TYPES:
867
+ error_fn(
868
+ f"Return has unsupported type {annotation}. "
869
+ f"The valid types are: {SUPPORTED_RETURN_TYPES}."
870
+ )
871
+
872
+ return "(" + ", ".join([SUPPORTED_RETURN_TYPES[arg] for arg in args]) + ")"
873
+
874
+
875
+ SUPPORTED_PARAM_TYPES = get_supported_param_types()
876
+
877
+
878
+ def report_error_callback(custom_op: typing.Any, key: str) -> None:
879
+ if key == "Undefined":
880
+ raise NotImplementedError(
881
+ f"{custom_op}: There were no Tensor inputs to this operator "
882
+ f"(e.g. you passed an empty list of Tensors). If your operator is a "
883
+ f"factory function (that is, it takes no Tensors and constructs "
884
+ f"a new one), then please use CustomOp.impl_factory to register "
885
+ f"an implementation for it"
886
+ )
887
+ if key == "Meta":
888
+ raise NotImplementedError(
889
+ f"{custom_op}: when running with device='Meta' tensors: there is no "
890
+ f"abstract impl registered for this CustomOp. Please register one via "
891
+ f"CustomOp.impl_abstract to get this CustomOp to work with Meta tensors"
892
+ )
893
+ if key in ("CPU", "CUDA"):
894
+ device = key.lower()
895
+ raise NotImplementedError(
896
+ f"{custom_op}: when running with device='{device}' tensors: there is no "
897
+ f"{device} impl registered for this CustomOp. Please register one via "
898
+ f"CustomOp.impl(device_type='{device}')"
899
+ )
900
+ raise NotImplementedError(
901
+ f"{custom_op}: No implementation for dispatch key {key}. It is likely "
902
+ f"that we have not added this functionality yet, please either open an "
903
+ f"issue or if you're feeling adventurous, use the low-level "
904
+ f"torch.library API"
905
+ )
906
+
907
+
908
+ def custom_op_from_existing(op):
909
+ ns = op.namespace
910
+ lib = torch.library.Library(ns, "FRAGMENT")
911
+ name = op.name().split("::")[-1]
912
+ schema_str = str(op._schema)
913
+ # CustomOp expects the schema string without the namespace
914
+ schema_str = schema_str.split("::")[-1]
915
+ schema = FunctionSchema.parse(schema_str)
916
+ return CustomOp(lib, ns, schema, name, op, _private_access=True)
917
+
918
+
919
+ def get_op(qualname):
920
+ def error_not_found():
921
+ raise ValueError(
922
+ f"Could not find the operator {qualname}. Please make sure you have "
923
+ f"already registered the operator and (if registered from C++) "
924
+ f"loaded it via torch.ops.load_library.")
925
+
926
+ ns, name = parse_qualname(qualname)
927
+ if not hasattr(torch.ops, ns):
928
+ error_not_found()
929
+ opnamespace = getattr(torch.ops, ns)
930
+ if not hasattr(opnamespace, name):
931
+ error_not_found()
932
+ packet = getattr(opnamespace, name)
933
+ if not hasattr(packet, 'default'):
934
+ error_not_found()
935
+ return packet.default
936
+
937
+
938
+ def _find_custom_op(qualname, also_check_torch_library=False):
939
+ if qualname in global_registry:
940
+ return global_registry[qualname]
941
+ if not also_check_torch_library:
942
+ raise RuntimeError(
943
+ f"Could not find custom op \"{qualname}\". Did you register it via "
944
+ f"the torch._custom_ops API?")
945
+ overload = get_op(qualname)
946
+ result = custom_op_from_existing(overload)
947
+ return result
948
+
949
+
950
+ def get_abstract_impl(qualname):
951
+ if qualname not in torch._custom_op.impl.global_registry:
952
+ return None
953
+ custom_op = torch._custom_op.impl.global_registry[qualname]
954
+ if custom_op is None:
955
+ return None
956
+ if not custom_op._has_impl("abstract"):
957
+ return None
958
+ return custom_op._get_impl("abstract").func
959
+
960
+
961
+ def _custom_op_with_schema(qualname, schema):
962
+ ns, name = qualname.split("::")
963
+ schema_str = f"{name}{schema}"
964
+ function_schema = FunctionSchema.parse(schema_str)
965
+ validate_schema(function_schema)
966
+
967
+ lib = library.Library(ns, "FRAGMENT")
968
+ lib.define(schema_str)
969
+ ophandle = find_ophandle_or_throw(ns, function_schema.name)
970
+ result = CustomOp(lib, ns, function_schema, name, ophandle, _private_access=True)
971
+ result._register_autograd_kernel_indirection()
972
+
973
+ torch._C._dispatch_set_report_error_callback(
974
+ ophandle, functools.partial(report_error_callback, weakref.proxy(result))
975
+ )
976
+ return get_op(qualname)
torch/_custom_ops.py ADDED
@@ -0,0 +1,322 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import inspect
2
+
3
+ from torch._custom_op.impl import (
4
+ _custom_op_with_schema,
5
+ _find_custom_op,
6
+ infer_schema,
7
+ parse_qualname,
8
+ validate_namespace,
9
+ )
10
+ from torch.library import get_ctx
11
+
12
+ __all__ = [
13
+ "custom_op",
14
+ "impl",
15
+ "impl_abstract",
16
+ "get_ctx",
17
+ "impl_save_for_backward",
18
+ "impl_backward",
19
+ ]
20
+
21
+
22
+ def custom_op(qualname, func_or_schema=None):
23
+ r"""Register a new custom operator
24
+
25
+ In PyTorch, defining an op (short for "operator") is a two step-process:
26
+ - we need to define the op (by providing an operator name and schema)
27
+ - we need to implement behavior for how the operator interacts with
28
+ various PyTorch subsystems, like CPU/CUDA Tensors, Autograd, etc.
29
+
30
+ This entrypoint defines the custom operator (the first step)
31
+ you must then perform the second step by calling various
32
+ ``impl_*`` APIs.
33
+
34
+ This API may be used as a decorator (see examples).
35
+
36
+ For a detailed guide on custom ops, please see
37
+ https://docs.google.com/document/d/1aGWtgxV3HppuxQAdddyPrs74_aEntpkYt9MalnCKnhk
38
+
39
+ Arguments:
40
+ qualname (str): Should be a string that looks like
41
+ "namespace::operator_name". Operators in PyTorch need a namespace to
42
+ avoid name collisions; a given operator may only be created once.
43
+ If you are writing a Python library, we recommend the namespace to
44
+ be the name of your top-level module.
45
+ func_or_schema (Union[Callable, str]): Each PyTorch operator needs a
46
+ schema that tells PyTorch the types of the inputs/outputs.
47
+ If this is a Callable, we will automatically infer the schema from
48
+ the type annotations on the function (see examples). Otherwise,
49
+ if you don't want to use type annotations, you may provide us the
50
+ schema string.
51
+
52
+ Example::
53
+ >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA)
54
+ >>> import torch
55
+ >>> import numpy as np
56
+ >>> from torch import Tensor
57
+ >>>
58
+ >>> # Step 1: define the custom op.
59
+ >>> # We need to provide the API a "prototype function"
60
+ >>> # (a function that returns NotImplementedError), from which
61
+ >>> # we will infer the types of the inputs and outputs.
62
+ >>> @torch._custom_ops.custom_op("mylibrary::numpy_sin")
63
+ >>> def numpy_sin(x: Tensor) -> Tensor:
64
+ >>> raise NotImplementedError()
65
+ >>>
66
+ >>> # The custom op is now accessible via the torch.ops module:
67
+ >>> torch.ops.mylibrary.numpy_sin
68
+ >>>
69
+ >>> # Step 2: Register an implementation for various PyTorch subsystems
70
+ >>>
71
+ >>> # Register an implementation for CPU tensors
72
+ >>> @torch._custom_ops.impl("mylibrary::numpy_sin", device_types="cpu")
73
+ >>> def numpy_sin_impl_cpu(x):
74
+ >>> return torch.from_numpy(np.sin(x.numpy()))
75
+ >>>
76
+ >>> # Register an implementation for CUDA tensors
77
+ >>> @torch._custom_ops.impl("mylibrary::numpy_sin", device_types="cuda")
78
+ >>> def numpy_sin_impl_cuda(x):
79
+ >>> return torch.from_numpy(np.sin(x.cpu().numpy())).to(x.device)
80
+ >>>
81
+ >>> x = torch.randn(3)
82
+ >>> torch.ops.mylibrary.numpy_sin(x) # calls numpy_sin_impl_cpu
83
+ >>>
84
+ >>> x_cuda = x.cuda()
85
+ >>> torch.ops.mylibrary.numpy_sin(x) # calls numpy_sin_impl_cuda
86
+
87
+ """
88
+ ns, name = parse_qualname(qualname)
89
+ validate_namespace(ns)
90
+
91
+ def inner(func):
92
+ if not inspect.isfunction(func):
93
+ raise ValueError(
94
+ f"custom_op(...)(func): Expected `func` to be a Python "
95
+ f"function, got: {type(func)}"
96
+ )
97
+
98
+ if func.__name__ != name:
99
+ raise ValueError(
100
+ f"custom_op(qualname='{qualname}', ...)(func): expected `func` "
101
+ f"to have name '{name}' but got '{func.__name__}'. "
102
+ f"Please either change the name of `func` or the qualname that "
103
+ f"is passed to `custom_op`"
104
+ )
105
+
106
+ schema = infer_schema(func)
107
+ _custom_op_with_schema(qualname, schema)
108
+ return func
109
+
110
+ if func_or_schema is None:
111
+ return inner
112
+ if isinstance(func_or_schema, str):
113
+ _custom_op_with_schema(qualname, func_or_schema)
114
+ else:
115
+ return inner(func_or_schema)
116
+
117
+
118
+ def impl(qualname, *, device_types=("cpu", "cuda"), func=None):
119
+ r"""Register an implementation for a device type for this custom op.
120
+
121
+ If the op is passed multiple Tensor inputs with different device
122
+ types, it will dispatch to the registered implementation for the highest
123
+ priority device type among those present.
124
+ The supported device types, in order of priority, are {'cuda', 'cpu'}.
125
+
126
+ This API may be used as a decorator (see examples).
127
+
128
+ For a detailed guide on custom ops, please see
129
+ https://docs.google.com/document/d/1aGWtgxV3HppuxQAdddyPrs74_aEntpkYt9MalnCKnhk
130
+
131
+ Arguments:
132
+ device_types (str or Iterable[str]): the device type(s) to register the function for.
133
+
134
+ Example::
135
+ >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA)
136
+ >>> import torch
137
+ >>> import numpy as np
138
+ >>> from torch import Tensor
139
+ >>>
140
+ >>> # Step 1: define the custom op.
141
+ >>> # We need to provide the API a "prototype function"
142
+ >>> # (a function that returns NotImplementedError), from which
143
+ >>> # we will infer the types of the inputs and outputs.
144
+ >>> @torch._custom_ops.custom_op("mylibrary::numpy_cos")
145
+ >>> def numpy_cos(x: Tensor) -> Tensor:
146
+ >>> raise NotImplementedError()
147
+ >>>
148
+ >>> # The custom op is now accessible via the torch.ops module:
149
+ >>> torch.ops.mylibrary.numpy_cos
150
+ >>>
151
+ >>> # Step 2: Register an implementation for various PyTorch subsystems
152
+ >>>
153
+ >>> # Register an implementation for CPU tensors
154
+ >>> @torch._custom_ops.impl("mylibrary::numpy_cos", device_types="cpu")
155
+ >>> def numpy_cos_impl_cpu(x):
156
+ >>> return torch.from_numpy(np.cos(x.numpy()))
157
+ >>>
158
+ >>> # Register an implementation for CUDA tensors
159
+ >>> @torch._custom_ops.impl("mylibrary::numpy_cos", device_types="cuda")
160
+ >>> def numpy_cos_impl_cuda(x):
161
+ >>> return torch.from_numpy(np.cos(x.cpu().numpy())).to(x.device)
162
+ >>>
163
+ >>> x = torch.randn(3)
164
+ >>> torch.ops.mylibrary.numpy_cos(x) # calls numpy_cos_impl_cpu
165
+ >>>
166
+ >>> x_cuda = x.cuda()
167
+ >>> torch.ops.mylibrary.numpy_cos(x) # calls numpy_cos_impl_cuda
168
+
169
+ """
170
+
171
+ def inner(func):
172
+ custom_op = _find_custom_op(qualname, also_check_torch_library=True)
173
+ custom_op.impl(device_types, _stacklevel=3)(func)
174
+ return func
175
+
176
+ if func is None:
177
+ return inner
178
+ return inner(func)
179
+
180
+
181
+ def impl_abstract(qualname, *, func=None):
182
+ r"""Register an abstract implementation for this operator.
183
+
184
+ An "abstract implementation" specifies the behavior of this operator on
185
+ Tensors that carry no data. Given some input Tensors with certain properties
186
+ (sizes/strides/storage_offset/device), it specifies what the properties of
187
+ the output Tensors are.
188
+
189
+ The abstract implementation has the same signature as the operator.
190
+ It is run for both FakeTensors and meta tensors. To write an abstract
191
+ implementation, assume that all Tensor inputs to the operator are
192
+ regular CPU/CUDA/Meta tensors, but they do not have storage, and
193
+ you are trying to return regular CPU/CUDA/Meta tensor(s) as output.
194
+ The abstract implementation must consist of only PyTorch operations
195
+ (and may not directly access the storage or data of any input or
196
+ intermediate Tensors).
197
+
198
+ This API may be used as a decorator (see examples).
199
+
200
+ For a detailed guide on custom ops, please see
201
+ https://docs.google.com/document/d/1aGWtgxV3HppuxQAdddyPrs74_aEntpkYt9MalnCKnhk
202
+
203
+ Examples::
204
+ >>> import numpy as np
205
+ >>> from torch import Tensor
206
+ >>>
207
+ >>> # Example 1: an operator without data-dependent output shape
208
+ >>> @torch._custom_ops.custom_op("mylibrary::custom_linear")
209
+ >>> def custom_linear(x: Tensor, weight: Tensor, bias: Tensor) -> Tensor:
210
+ >>> raise NotImplementedError()
211
+ >>>
212
+ >>> @torch._custom_ops.impl_abstract("mylibrary::custom_linear")
213
+ >>> def custom_linear_abstract(x, weight):
214
+ >>> assert x.dim() == 2
215
+ >>> assert weight.dim() == 2
216
+ >>> assert bias.dim() == 1
217
+ >>> assert x.shape[1] == weight.shape[1]
218
+ >>> assert weight.shape[0] == bias.shape[0]
219
+ >>> assert x.device == weight.device
220
+ >>>
221
+ >>> return (x @ weight.t()) + bias
222
+ >>>
223
+ >>> # Example 2: an operator with data-dependent output shape
224
+ >>> @torch._custom_ops.custom_op('mylibrary::custom_nonzero')
225
+ >>> def custom_nonzero(x: Tensor) -> Tensor:
226
+ >>> ...
227
+ >>>
228
+ >>> @torch._custom_ops.impl_abstract("mylibrary::custom_nonzero")
229
+ >>> def custom_nonzero_abstract(x):
230
+ >>> # Number of nonzero-elements is data-dependent.
231
+ >>> # Since we cannot peek at the data in an abstract impl,
232
+ >>> # we use the ctx object to construct a new symint that
233
+ >>> # represents the data-dependent size.
234
+ >>> ctx = torch._custom_ops.get_ctx()
235
+ >>> nnz = ctx.create_unbacked_symint()
236
+ >>> shape = [x.dim(), nnz]
237
+ >>> result = x.new_empty(shape, dtype=torch.long)
238
+ >>> return result
239
+ >>>
240
+ >>> @torch._custom_ops.impl("mylibrary::custom_nonzero")
241
+ >>> def custom_nonzero_impl(x):
242
+ >>> x_np = to_numpy(x)
243
+ >>> res = np.stack(np.nonzero(x_np), axis=1)
244
+ >>> # unbacked symbolic ints in PyTorch must be >= 2, so we
245
+ >>> # constrain the range to at least 2
246
+ >>> if res.shape[0] <= 1:
247
+ >>> raise RuntimeError("not supported")
248
+ >>> return torch.tensor(res, device=x.device)
249
+
250
+ """
251
+ import torch.library
252
+
253
+ return torch.library.impl_abstract(qualname, func, _stacklevel=2)
254
+
255
+
256
+ def impl_save_for_backward(qualname, *, func=None):
257
+ r"""Register a function that tells us what to save for backward.
258
+
259
+ Please see :func:`impl_backward` for more details.
260
+ """
261
+
262
+ def inner(func):
263
+ custom_op = _find_custom_op(qualname, also_check_torch_library=True)
264
+ custom_op.impl_save_for_backward(_stacklevel=3)(func)
265
+ return func
266
+
267
+ if func is None:
268
+ return inner
269
+ return inner(func)
270
+
271
+
272
+ def impl_backward(qualname, output_differentiability=None, *, func=None):
273
+ r"""Registers a backward formula for an operator.
274
+
275
+ In order for an operator to work with autograd, you need to register
276
+ a backward formula. There are two pieces to this:
277
+ 1. You must give us a function to specify what to save for backward.
278
+ Call this the "save for backward" function.
279
+ 2. You must give us a function that computes gradients. Call this the
280
+ "backward" function.
281
+
282
+ Use `impl_save_for_backward` to define a "save for backward" function
283
+ that specifies what gets saved for backward. The function should accept
284
+ two arguments ``(inputs, output)`` and return the quantities to be saved
285
+ for backward.
286
+
287
+ During runtime, when you call the operator in a forwards pass, PyTorch
288
+ will invoke the "save for backward" function with the inputs and output
289
+ of the operator.
290
+
291
+ Use `impl_backward` to define the "backward" function. The backward
292
+ function must accept ``(ctx, saved, *grads)``:
293
+ - ``ctx`` is a context object where we may provide information
294
+ - ``saved`` is exactly what gets returned from the "save for backward"
295
+ function
296
+ - ``grads`` is one or more gradients. The number of gradients matches
297
+ the number of outputs of the operator.
298
+
299
+ The backward function must return a dict that maps the name of
300
+ an input to the operator to its corresponding gradient. All inputs that
301
+ were declared to be Tensors in the operator definition must be accounted
302
+ for in the dict. The gradient may be a Tensor or None.
303
+
304
+ For a detailed guide on custom ops, please see
305
+ https://docs.google.com/document/d/1aGWtgxV3HppuxQAdddyPrs74_aEntpkYt9MalnCKnhk
306
+
307
+ """
308
+
309
+ def inner(func):
310
+ custom_op = _find_custom_op(qualname, also_check_torch_library=True)
311
+ custom_op.impl_backward(output_differentiability, _stacklevel=3)(func)
312
+ return func
313
+
314
+ if func is None:
315
+ return inner
316
+ return inner(func)
317
+
318
+
319
+ def _destroy(qualname):
320
+ """De-registers a custom op. For testing purposes only"""
321
+ custom_op = _find_custom_op(qualname)
322
+ custom_op._destroy()
torch/_decomp/__init__.py ADDED
@@ -0,0 +1,444 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import inspect
2
+ from collections import defaultdict
3
+ from functools import wraps
4
+ from itertools import chain
5
+ from typing import Callable, Dict, List, Sequence, Union
6
+
7
+ import torch
8
+ import torch.library
9
+ from torch._ops import HigherOrderOperator, OpOverload, OpOverloadPacket
10
+ from torch._prims_common import CustomOutParamAnnotation
11
+ from torch.utils import _pytree as pytree
12
+
13
+ __all__ = [
14
+ "decomposition_table",
15
+ "pre_autograd_decomposition_table",
16
+ "meta_table",
17
+ "register_decomposition",
18
+ "get_decompositions",
19
+ "core_aten_decompositions",
20
+ ]
21
+
22
+
23
+ # TODO: relax key type here; torch registrations should be possible to; but
24
+ # right now this type is accurate
25
+ global_decomposition_table: Dict[
26
+ str, Dict[torch._ops.OperatorBase, Callable]
27
+ ] = defaultdict(dict)
28
+
29
+ decomposition_table = global_decomposition_table["post_autograd"]
30
+ pre_autograd_decomposition_table = global_decomposition_table["pre_autograd"]
31
+ meta_table = global_decomposition_table["meta"]
32
+
33
+
34
+ def _add_op_to_registry(registry, op, fn):
35
+ """
36
+ This is an internal API for adding an op to the decomposition table.
37
+
38
+ If op is OpOverload, it will be added to the registry directly.
39
+ If op is OpOverloadPacket, all the valid op_overloads in the packet will be added to the registry.
40
+ """
41
+ overloads: List[Union[torch._ops.OperatorBase]] = []
42
+ if isinstance(op, HigherOrderOperator):
43
+ # There's no concept of overloads for HigherOrderOperator
44
+ registry[op] = fn
45
+ return
46
+ elif isinstance(op, OpOverload):
47
+ overloads.append(op)
48
+ else:
49
+ assert isinstance(op, OpOverloadPacket)
50
+ for ol in op.overloads():
51
+ overloads.append(getattr(op, ol))
52
+
53
+ for op_overload in overloads:
54
+ if op_overload in registry:
55
+ raise RuntimeError(f"duplicate registrations for {op_overload}")
56
+ # TorchScript dumps a bunch of extra nonsense overloads
57
+ # which don't have corresponding dispatcher entries, we need
58
+ # to filter those out, e.g aten.add.float_int
59
+ if torch._C._dispatch_has_kernel(op_overload.name()):
60
+ registry[op_overload] = fn
61
+
62
+
63
+ def _convert_out_params(f):
64
+ out_annotation = f.__annotations__.get("out")
65
+
66
+ # If there are no out params, do not wrap the function.
67
+ if not out_annotation:
68
+ return f
69
+
70
+ # Hack to detect when out is a Tuple. There seems to be no pretty way of doing this
71
+ if getattr(out_annotation, "__origin__", None) is tuple:
72
+ sig = inspect.signature(f)
73
+ out_names = sig.return_annotation._fields
74
+ # If out is a tuple, we need to register a function that unpacks all the out
75
+ # elements as this is what native_functions.yaml expects
76
+
77
+ @wraps(f)
78
+ def _fn(*args, **kwargs):
79
+ out_kwargs = tuple(kwargs.pop(o, None) for o in out_names)
80
+ # Either all of the out kwargs are set or none of them
81
+ is_none = out_kwargs[0] is None
82
+ assert all((o is None) == is_none for o in out_kwargs)
83
+ return f(*args, **kwargs, out=None if is_none else out_kwargs)
84
+
85
+ out_params = [
86
+ inspect.Parameter(
87
+ o,
88
+ kind=inspect.Parameter.KEYWORD_ONLY,
89
+ default=None,
90
+ annotation=t,
91
+ )
92
+ for o, t in zip(out_names, out_annotation.__args__)
93
+ ]
94
+ # Drop the out parameter and concatenate the new kwargs in the signature
95
+ params = chain((v for k, v in sig.parameters.items() if k != "out"), out_params)
96
+ _fn.__signature__ = inspect.Signature( # type: ignore[attr-defined]
97
+ parameters=params, return_annotation=sig.return_annotation # type: ignore[arg-type]
98
+ )
99
+ # Drop the out parameter and concatenate the new kwargs in the annotations
100
+ _fn.__annotations__ = {k: v for k, v in f.__annotations__.items() if k != "out"}
101
+ for o in out_params:
102
+ _fn.__annotations__[o.name] = o.annotation
103
+
104
+ # Propagate that this function is wrapped by `out_wrapper`
105
+ _fn._torch_decompositions_out_wrapper = f._torch_decompositions_out_wrapper # type: ignore[attr-defined]
106
+
107
+ return _fn
108
+
109
+ # Alternatively, there may be a single tensor out parameter with a name
110
+ # other than "out". This will need special treatment and is indicated by an
111
+ # annotation, which we will remove here so it is not exposed after wrapping.
112
+ custom_out_param_name = f.__annotations__.pop(CustomOutParamAnnotation, None)
113
+ if custom_out_param_name:
114
+
115
+ @wraps(f)
116
+ def _fn(*args, **kwargs):
117
+ out_kwarg = kwargs.pop(custom_out_param_name, None)
118
+ return f(*args, **kwargs, out=out_kwarg)
119
+
120
+ out_param = inspect.Parameter(
121
+ custom_out_param_name,
122
+ kind=inspect.Parameter.KEYWORD_ONLY,
123
+ default=None,
124
+ annotation=out_annotation,
125
+ )
126
+
127
+ # Drop the out parameter and concatenate the new kwarg in the signature
128
+ sig = inspect.signature(f)
129
+ params = chain(
130
+ (v for k, v in sig.parameters.items() if k != "out"), (out_param,)
131
+ )
132
+ _fn.__signature__ = inspect.Signature( # type: ignore[attr-defined]
133
+ parameters=params, return_annotation=sig.return_annotation # type: ignore[arg-type]
134
+ )
135
+
136
+ # Drop the out parameter and concatenate the new kwargs in the annotations
137
+ _fn.__annotations__ = {k: v for k, v in f.__annotations__.items() if k != "out"}
138
+ _fn.__annotations__[out_param.name] = out_param.annotation
139
+
140
+ return _fn
141
+
142
+ return f
143
+
144
+
145
+ def register_decomposition(
146
+ aten_op, registry=None, *, type="post_autograd", unsafe=False
147
+ ):
148
+ """
149
+ A decorator to register a function as a decomposition to the Python
150
+ decomposition table. Use it like this::
151
+
152
+ @register_decomposition(torch.ops.aten.clamp_min)
153
+ def clamp_min(x):
154
+ return torch.clamp(self, min=min)
155
+
156
+ If you are writing a new decomposition, consider contributing it
157
+ directly to PyTorch in torch._decomp.decompositions.
158
+
159
+ This API is experimental; we are almost certainly going to extend
160
+ the API when we make decompositions eligible for use in transforms (e.g.,
161
+ autograd) and not just backend tracing, where we then need to know if a
162
+ decomposition can be used to simulate a transform.
163
+
164
+ By default, we also will register it to the Meta key of dispatcher,
165
+ and replace the c++ Meta implementation if there is already one.
166
+
167
+ unsafe kwarg is for reuse of this function for registering non-function
168
+ things
169
+ """
170
+
171
+ assert type in {"post_autograd", "pre_autograd", "meta"}
172
+
173
+ def decomposition_decorator(fn: Callable) -> Callable:
174
+ if not unsafe:
175
+ fn = _convert_out_params(fn)
176
+
177
+ nonlocal registry
178
+ if registry is None:
179
+ registry = global_decomposition_table[type]
180
+
181
+ def register(op):
182
+ _add_op_to_registry(registry, op, fn)
183
+
184
+ # To handle allowing multiple aten_ops at once
185
+ pytree.tree_map_(register, aten_op)
186
+ return fn
187
+
188
+ return decomposition_decorator
189
+
190
+
191
+ def get_decompositions(
192
+ aten_ops: Sequence[Union[torch._ops.OperatorBase, OpOverloadPacket]],
193
+ type: str = "post_autograd",
194
+ ) -> Dict[torch._ops.OperatorBase, Callable]:
195
+ """
196
+ Retrieve a dictionary of decompositions corresponding to the list of
197
+ operator overloads and overload packets passed as input. Overload
198
+ packets will include all decomposed overloads in the packet. If there is
199
+ no decomposition for a requested operator, it is silently ignored.
200
+
201
+ This API is experimental; we are almost certainly going to give an alternate,
202
+ more recommended formulation, where a user provides the set of operators
203
+ they know how to implement, and we provide decompositions for everything
204
+ not in this set.
205
+ """
206
+ assert type in {"post_autograd", "pre_autograd", "meta"}
207
+
208
+ registry = global_decomposition_table[type]
209
+ packets_to_overloads = defaultdict(list)
210
+ for opo in registry:
211
+ if isinstance(opo, (OpOverload, OpOverloadPacket)):
212
+ packets_to_overloads[opo.overloadpacket].append(opo)
213
+ decompositions: Dict[torch._ops.OperatorBase, Callable] = {}
214
+ for op in aten_ops:
215
+ if isinstance(op, OpOverloadPacket) and op in packets_to_overloads:
216
+ for op_overload in packets_to_overloads[op]:
217
+ decompositions[op_overload] = registry[op_overload]
218
+ elif isinstance(op, (torch._ops.OperatorBase)) and op in registry:
219
+ decompositions[op] = registry[op]
220
+ return decompositions
221
+
222
+
223
+ def remove_decompositions(
224
+ decompositions: Dict[torch._ops.OperatorBase, Callable],
225
+ aten_ops: Sequence[Union[OpOverload, OpOverloadPacket]],
226
+ ) -> None:
227
+ """
228
+ Given a dictionary of decompositions obtained from get_decompositions(), removes
229
+ operators associated with a list of operator overloads and overload packets passed
230
+ as input. If the decomposition dictionary does not contain a decomposition that is
231
+ specified to be removed, it is silently ignored.
232
+ """
233
+ for op in aten_ops:
234
+ if isinstance(op, OpOverloadPacket):
235
+ for overload_name in op.overloads():
236
+ opo = getattr(op, overload_name)
237
+ decompositions.pop(opo, None)
238
+ elif isinstance(op, OpOverload):
239
+ decompositions.pop(op, None)
240
+
241
+
242
+ # populate the table
243
+ import torch._decomp.decompositions
244
+ import torch._refs
245
+
246
+
247
+ # See NOTE [Core ATen Ops]
248
+ #
249
+ # list was copied from torch/_inductor/decomposition.py
250
+ # excluding decompositions that results in prim ops
251
+ # Resulting opset of decomposition is core aten ops
252
+ def core_aten_decompositions() -> Dict[torch._ops.OperatorBase, Callable]:
253
+ aten = torch.ops.aten
254
+ return get_decompositions(
255
+ [
256
+ aten.addcdiv,
257
+ aten.addcdiv_,
258
+ aten.addcmul,
259
+ aten.addcmul_,
260
+ aten.addr,
261
+ aten.affine_grid_generator,
262
+ aten.all,
263
+ aten.aminmax,
264
+ aten.arange.default,
265
+ aten.arange.start,
266
+ aten.avg_pool2d_backward,
267
+ aten.baddbmm,
268
+ aten.binary_cross_entropy,
269
+ aten.binary_cross_entropy_backward,
270
+ aten.binary_cross_entropy_with_logits,
271
+ aten.celu,
272
+ aten.celu_,
273
+ aten.clamp_max,
274
+ aten.clamp_min,
275
+ aten.col2im,
276
+ aten.count_nonzero,
277
+ aten.cudnn_batch_norm,
278
+ aten.cudnn_batch_norm_backward,
279
+ aten.deg2rad,
280
+ aten.deg2rad_,
281
+ aten.detach,
282
+ aten.diag_embed,
283
+ aten.diagonal_backward,
284
+ aten.dot,
285
+ aten.vdot,
286
+ aten.elu,
287
+ aten.elu_,
288
+ aten.elu_backward,
289
+ aten._embedding_bag,
290
+ aten.embedding_dense_backward,
291
+ aten.empty_like,
292
+ aten._euclidean_dist.default,
293
+ aten.expand_as,
294
+ aten.eye,
295
+ aten.fill,
296
+ aten.fill_,
297
+ aten.floor_divide,
298
+ aten.frac,
299
+ aten.frac_,
300
+ aten._fused_moving_avg_obs_fq_helper,
301
+ aten.gelu_,
302
+ aten.gelu_backward,
303
+ aten.glu,
304
+ aten.glu_backward,
305
+ aten.hardshrink,
306
+ aten.hardsigmoid,
307
+ aten.hardsigmoid_,
308
+ aten.hardsigmoid_backward,
309
+ aten.hardswish,
310
+ aten.hardswish_,
311
+ aten.hardswish_backward,
312
+ aten.hardtanh_,
313
+ aten.hardtanh_backward,
314
+ aten.heaviside,
315
+ aten.heaviside_,
316
+ aten.huber_loss,
317
+ aten.huber_loss_backward,
318
+ aten.im2col,
319
+ aten.index_add,
320
+ aten.index_add_,
321
+ aten.index_copy,
322
+ aten.index_copy_,
323
+ aten.index_fill,
324
+ aten.index_fill_,
325
+ aten.isneginf,
326
+ aten.isposinf,
327
+ aten.l1_loss,
328
+ aten.leaky_relu_,
329
+ aten.leaky_relu_backward,
330
+ aten.lerp,
331
+ aten.lerp_,
332
+ aten.linspace,
333
+ aten.logaddexp,
334
+ aten.logaddexp2,
335
+ aten.logit,
336
+ aten.logit_,
337
+ aten.logit_backward,
338
+ aten.log_sigmoid_backward,
339
+ aten.log_sigmoid_forward,
340
+ aten._log_softmax_backward_data,
341
+ aten.logspace,
342
+ aten.logsumexp.default,
343
+ aten.masked_fill,
344
+ aten.masked_fill_,
345
+ aten.mish,
346
+ aten.mish_,
347
+ aten.mse_loss,
348
+ aten.mse_loss_backward,
349
+ aten.multi_margin_loss,
350
+ aten.multilabel_margin_loss_forward,
351
+ aten.mv,
352
+ aten.mvlgamma,
353
+ aten.mvlgamma_,
354
+ aten.nansum,
355
+ aten.nan_to_num,
356
+ aten.nan_to_num_,
357
+ aten.narrow,
358
+ aten.native_batch_norm_backward,
359
+ aten.native_dropout_backward,
360
+ aten.native_group_norm_backward,
361
+ aten.native_layer_norm_backward,
362
+ aten.new_empty,
363
+ aten.new_full,
364
+ aten.new_ones,
365
+ aten.new_zeros,
366
+ aten.nll_loss_backward,
367
+ aten.nll_loss_forward,
368
+ aten.norm,
369
+ aten.ones,
370
+ aten.ones_like,
371
+ aten._prelu_kernel,
372
+ aten._prelu_kernel_backward,
373
+ aten._reshape_alias,
374
+ aten.rad2deg,
375
+ aten.rad2deg_,
376
+ aten.renorm,
377
+ aten.renorm_,
378
+ aten.replication_pad2d,
379
+ aten.rot90,
380
+ aten.rrelu_with_noise,
381
+ aten.rrelu_with_noise_,
382
+ aten.rsub.Scalar,
383
+ aten.rsub.Tensor,
384
+ aten._scaled_dot_product_flash_attention.default,
385
+ aten.select_backward,
386
+ aten.select_scatter,
387
+ aten.sgn,
388
+ aten.sgn_,
389
+ aten.sigmoid_backward,
390
+ aten.silu,
391
+ aten.silu_,
392
+ aten.silu_backward,
393
+ aten.sinc,
394
+ aten.sinc_,
395
+ aten.slice_backward,
396
+ aten.smooth_l1_loss,
397
+ aten.smooth_l1_loss_backward,
398
+ aten.soft_margin_loss,
399
+ aten.soft_margin_loss_backward,
400
+ aten._softmax_backward_data,
401
+ aten.softplus,
402
+ aten.softplus_backward,
403
+ aten.softshrink,
404
+ aten.special_entr,
405
+ aten.special_log_ndtr,
406
+ aten.special_xlog1py,
407
+ aten.split.Tensor,
408
+ aten.squeeze.default,
409
+ aten.squeeze.dim,
410
+ aten.std,
411
+ aten.std_mean,
412
+ aten.stack,
413
+ aten.sum.default,
414
+ aten.sum.out,
415
+ aten.t,
416
+ aten.tanh_backward,
417
+ aten.threshold,
418
+ aten.threshold_,
419
+ aten.threshold_backward,
420
+ aten.trace,
421
+ aten.transpose.int,
422
+ aten.tril,
423
+ aten.tril_,
424
+ aten.triu,
425
+ aten.triu_,
426
+ aten.unbind,
427
+ aten.unfold_backward,
428
+ aten.unfold_copy,
429
+ aten._unsafe_index,
430
+ aten.unsafe_split.Tensor,
431
+ aten.unsafe_split_with_sizes,
432
+ aten._unsafe_view,
433
+ aten.upsample_bilinear2d,
434
+ aten.upsample_nearest2d_backward,
435
+ aten.view_as_complex,
436
+ aten.xlogy,
437
+ aten.xlogy_,
438
+ aten.zero,
439
+ aten.zero_,
440
+ aten.zeros,
441
+ aten.zeros_like,
442
+ aten._weight_norm_interface,
443
+ ]
444
+ )
torch/_decomp/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (12.2 kB). View file
 
torch/_decomp/__pycache__/decompositions.cpython-310.pyc ADDED
Binary file (102 kB). View file
 
torch/_decomp/__pycache__/decompositions_for_jvp.cpython-310.pyc ADDED
Binary file (6.27 kB). View file
 
torch/_decomp/__pycache__/decompositions_for_rng.cpython-310.pyc ADDED
Binary file (7.99 kB). View file
 
torch/_decomp/decompositions.py ADDED
The diff for this file is too large to render. See raw diff
 
torch/_decomp/decompositions_for_jvp.py ADDED
@@ -0,0 +1,302 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import inspect
2
+ from typing import Callable, Dict, List, Optional, Tuple
3
+
4
+ import torch
5
+ import torch._decomp
6
+ from torch import Tensor
7
+ from torch._prims_common.wrappers import _maybe_remove_out_wrapper
8
+
9
+ decomposition_table = torch._decomp.decomposition_table
10
+ decomposition_table_for_jvp: Dict[torch._ops.OperatorBase, Callable] = {}
11
+ register_decomposition = torch._decomp.register_decomposition
12
+ aten = torch.ops.aten
13
+
14
+ # NOTE: [forward-mode AD decompositions mechanism]
15
+ #
16
+ # The mechanism is in VariableType,
17
+ # IF any inputs have forward grad
18
+ # AND there is no forward AD formula implemented
19
+ # AND the functions is actually differentiable
20
+ # run the decomposition
21
+ # See run_jit_decomposition_with_args_for_jvp
22
+ # We currently use python decompositions that we torchscript.
23
+ #
24
+ # Note that we would be building the backward graph at the decomposed level
25
+ # too, but that is OK, because we would've errored out otherwise anyway.
26
+ #
27
+ # TODO: The mechanism we are using to register decompositions doesn't
28
+ # seem to be exclusively used for jvp. So open question here is whether
29
+ # torch/csrc/jit/runtime/decomposition_registry.cpp is being used for other things.
30
+ # If that is the case, we may go down the decomposition path unexpectedly
31
+ # (and possibly produce an unintelligible error) vs erroring out earlier and
32
+ # printing that the forward AD formula is not implemented.
33
+ #
34
+ # The solution to this may be to have a explicitly white list control when
35
+ # to enable the decomposition.
36
+
37
+
38
+ def maybe_register_decomposition(op):
39
+ def decorator(f):
40
+ try:
41
+ return register_decomposition(op)(f)
42
+ except Exception:
43
+ return f
44
+
45
+ return decorator
46
+
47
+
48
+ # Functions where we need a special decomposition for jvp but there's another version that
49
+ # should be used more generally (ex. for jvp we need to recompute the mean and variance for
50
+ # the backwards of a normalization function. Without jvp, it should use the saved value)
51
+ decomposition_table_for_jvp = {}
52
+
53
+
54
+ def register_decomposition_for_jvp(fn):
55
+ return register_decomposition(fn, registry=decomposition_table_for_jvp)
56
+
57
+
58
+ def _register_jit_decomposition_for_jvp(decomp, use_python=False):
59
+ if decomp in decomposition_table_for_jvp:
60
+ decomposition_table_used = decomposition_table_for_jvp
61
+ elif decomp in decomposition_table:
62
+ decomposition_table_used = decomposition_table
63
+ else:
64
+ raise RuntimeError(f"could not find decomposition for {decomp}")
65
+ decomp_fn = decomposition_table_used[decomp]
66
+
67
+ # `out_wrapper` extends a decompositions signature with
68
+ # an `out` parameter. However jit will use the unwrapped function's
69
+ # signature instead so we need to unwrap here to prevent an error
70
+ decomp_fn = _maybe_remove_out_wrapper(decomp_fn)
71
+
72
+ if use_python:
73
+ decomp_fn = torch.jit.ignore(decomp_fn)
74
+ sig = inspect.signature(decomp_fn)
75
+
76
+ # Create a string wrapping the function from the signature
77
+ # example output:
78
+ # def wrapped_decomp(x: torch.Tensor, y: int, z: int):
79
+ # return decomp_fn(x, y, z)
80
+ # Thanks copilot!
81
+ def get_function_def(sig):
82
+ param_def = [f"{param_str}" for param_str in sig.parameters.values()]
83
+ param_use = [f"{param_str}" for param_str in sig.parameters.keys()]
84
+
85
+ return f"def wrapped_decomp({', '.join(param_def)}):\n return decomp_fn({', '.join(param_use)})\n"
86
+
87
+ f_str = get_function_def(sig)
88
+ graph = torch.jit.CompilationUnit(f_str).wrapped_decomp.graph
89
+ else:
90
+ graph = torch.jit.script(decomp_fn).graph
91
+ torch.jit._register_decomposition(decomp, graph)
92
+
93
+
94
+ # The only decompositions here are temporary or hacks for the purposes of jvp
95
+
96
+
97
+ # TODO: do these also belong here?
98
+ @maybe_register_decomposition(aten.trace.default)
99
+ def trace(self: Tensor) -> Tensor:
100
+ return torch.sum(torch.diag(self))
101
+
102
+
103
+ @maybe_register_decomposition(aten.log_sigmoid_forward.default)
104
+ def log_sigmoid_forward(self: Tensor) -> Tuple[Tensor, Tensor]:
105
+ min = torch.minimum(self.new_zeros(()), self)
106
+ z = torch.exp(-torch.abs(self))
107
+ if self.is_cuda:
108
+ buffer = self.new_zeros((0,))
109
+ else:
110
+ buffer = z
111
+ return min - torch.log1p(z), buffer
112
+
113
+
114
+ def recompute_mean_var(
115
+ input: Tensor, rstd: Tensor, inner_dim_indices: List[int], keepdim: bool
116
+ ):
117
+ # for most norm decompositions, it will be the same as the core version except for here.
118
+ # We recompute the mean and variance so that they track gradients through input
119
+
120
+ mean = torch.mean(input, dim=inner_dim_indices, keepdim=keepdim)
121
+ var = torch.var(input, dim=inner_dim_indices, unbiased=False, keepdim=keepdim)
122
+ eps = torch.pow(1 / rstd, 2) - var # this makes me so sad inside
123
+ eps = eps.detach()
124
+ rstd = 1 / torch.sqrt(var + eps)
125
+ return mean, rstd
126
+
127
+
128
+ @register_decomposition_for_jvp(aten.native_layer_norm_backward)
129
+ def native_layer_norm_backward(
130
+ grad_out: Tensor,
131
+ input: Tensor,
132
+ normalized_shape: List[int],
133
+ mean: Tensor,
134
+ rstd: Tensor,
135
+ weight: Optional[Tensor],
136
+ bias: Optional[Tensor],
137
+ output_mask: List[bool],
138
+ ) -> Tuple[Optional[Tensor], Optional[Tensor], Optional[Tensor]]:
139
+ input_shape = input.shape
140
+ input_ndim = input.dim()
141
+
142
+ axis = input_ndim - len(normalized_shape)
143
+ inner_dims = input_shape[axis:]
144
+ outer_dims = input_shape[:axis]
145
+ inner_dim_indices = list(range(axis, input_ndim))
146
+ outer_dim_indices = list(range(0, axis))
147
+
148
+ N = 1
149
+ for i in inner_dims:
150
+ N *= i
151
+ M = 1
152
+ for i in outer_dims:
153
+ M *= i
154
+ if M <= 0 or N <= 0:
155
+ return (
156
+ input.new_zeros(input_shape),
157
+ input.new_zeros(input_shape[axis:]),
158
+ input.new_zeros(input_shape[axis:]),
159
+ )
160
+
161
+ mean_, rstd_ = recompute_mean_var(input, rstd, inner_dim_indices, keepdim=True)
162
+
163
+ x_hat = (input - mean_) * rstd_
164
+ if weight is not None:
165
+ grad_x_hat = grad_out * weight
166
+ else:
167
+ grad_x_hat = grad_out
168
+ a = grad_x_hat * N
169
+ b = torch.sum(grad_x_hat, inner_dim_indices, True)
170
+ c1 = torch.mul(grad_x_hat, x_hat)
171
+ c2 = torch.sum(c1, inner_dim_indices, True)
172
+ c3 = torch.mul(x_hat, c2)
173
+ inner = a - b - c3
174
+
175
+ if output_mask[0]:
176
+ d_input: Optional[Tensor] = (rstd_ / N) * inner
177
+ else:
178
+ d_input = torch.zeros_like(input) # should be None but doesn't work with vjp
179
+
180
+ if output_mask[1] and weight is not None:
181
+ if len(outer_dim_indices) > 0:
182
+ d_weight: Optional[Tensor] = torch.sum(
183
+ grad_out * x_hat, outer_dim_indices, False
184
+ )
185
+ else:
186
+ d_weight = grad_out * x_hat
187
+ elif weight is not None:
188
+ d_weight = torch.zeros_like(weight) # should be None but doesn't work with vjp
189
+ else:
190
+ d_weight = torch.zeros(()) # should be None but doesn't work with vjp
191
+
192
+ if output_mask[2] and bias is not None:
193
+ if len(outer_dim_indices) > 0:
194
+ d_bias: Optional[Tensor] = torch.sum(grad_out, outer_dim_indices, False)
195
+ else:
196
+ d_bias = grad_out.clone()
197
+ elif bias is not None:
198
+ d_bias = torch.zeros_like(bias) # should be None but doesn't work with vjp
199
+ else:
200
+ d_bias = torch.zeros(()) # should be None but doesn't work with vjp
201
+
202
+ return (d_input, d_weight, d_bias)
203
+
204
+
205
+ def prod(x: List[int]):
206
+ r = 1
207
+ for i in x:
208
+ r *= i
209
+ return r
210
+
211
+
212
+ @register_decomposition_for_jvp(aten.native_batch_norm_backward)
213
+ def native_batch_norm_backward(
214
+ grad_out: Tensor,
215
+ input: Tensor,
216
+ weight: Optional[Tensor],
217
+ running_mean: Optional[Tensor],
218
+ running_var: Optional[Tensor],
219
+ save_mean: Optional[Tensor],
220
+ save_invstd: Optional[Tensor],
221
+ train: bool,
222
+ eps: float,
223
+ output_mask: List[bool],
224
+ ) -> Tuple[Tensor, Optional[Tensor], Optional[Tensor]]:
225
+ input_shape = input.shape
226
+ input_rank = input.dim()
227
+ assert input_rank >= 2, "rank of the input must be at least 2"
228
+
229
+ axis = 1
230
+ num_features = prod(input_shape) / input_shape[axis] # type: ignore[arg-type]
231
+ mean = save_mean
232
+ invstd = save_invstd
233
+ if train:
234
+ assert (
235
+ save_mean is not None and save_invstd is not None
236
+ ), "when train=True, save_mean and save_invstd are required"
237
+
238
+ reduciton_dims = [0] + list(range(2, input.dim()))
239
+ assert invstd is not None # for typing
240
+ mean, invstd = recompute_mean_var(input, invstd, reduciton_dims, keepdim=False)
241
+ else:
242
+ assert running_mean is not None and running_var is not None
243
+ mean = running_mean
244
+ invstd = torch.rsqrt(running_var + eps)
245
+
246
+ assert invstd is not None and mean is not None
247
+
248
+ broadcast_mask = [1] * input_rank
249
+ broadcast_mask[axis] = input_shape[axis]
250
+
251
+ reduction_axes: List[int] = []
252
+ for i in range(input_rank):
253
+ if i != axis:
254
+ reduction_axes.append(i)
255
+
256
+ mean = torch.reshape(mean, broadcast_mask)
257
+ norm = 1.0 / num_features
258
+ grad_output_sum = torch.sum(grad_out, reduction_axes)
259
+ dot_p = torch.sum(grad_out * (input - mean), reduction_axes)
260
+
261
+ grad_mean = torch.reshape(grad_output_sum * norm, broadcast_mask)
262
+ proj_scale = torch.reshape(torch.mul(dot_p * norm, invstd * invstd), broadcast_mask)
263
+
264
+ if weight is None:
265
+ grad_scale = torch.reshape(invstd, broadcast_mask) * 1.0
266
+ else:
267
+ grad_scale = torch.reshape(invstd * weight, broadcast_mask)
268
+
269
+ if train:
270
+ proj = (input - mean) * proj_scale
271
+ grad_input = ((grad_out - proj) - grad_mean) * grad_scale
272
+ else:
273
+ grad_input = grad_out * grad_scale
274
+
275
+ if output_mask[1]:
276
+ grad_weight = dot_p * invstd
277
+ elif weight is not None:
278
+ grad_weight = torch.zeros_like(
279
+ weight
280
+ ) # should be None but doesn't work with vjp
281
+ else:
282
+ grad_weight = torch.zeros(()) # should be None but doesn't work with vjp
283
+
284
+ if output_mask[2]:
285
+ grad_bias = grad_output_sum
286
+ else:
287
+ grad_bias = torch.zeros_like(
288
+ grad_output_sum
289
+ ) # should be None but doesn't work with vjp
290
+
291
+ return (grad_input, grad_weight, grad_bias)
292
+
293
+
294
+ _register_jit_decomposition_for_jvp(torch.ops.aten.trace.default, use_python=True)
295
+ _register_jit_decomposition_for_jvp(torch.ops.aten.nll_loss_backward.default)
296
+ _register_jit_decomposition_for_jvp(torch.ops.aten.nll_loss2d_backward.default)
297
+ _register_jit_decomposition_for_jvp(torch.ops.aten._log_softmax_backward_data.default)
298
+ _register_jit_decomposition_for_jvp(torch.ops.aten._softmax_backward_data.default)
299
+ _register_jit_decomposition_for_jvp(torch.ops.aten.log_sigmoid_forward.default)
300
+ _register_jit_decomposition_for_jvp(torch.ops.aten.native_layer_norm_backward.default)
301
+ _register_jit_decomposition_for_jvp(torch.ops.aten.native_batch_norm_backward.default)
302
+ _register_jit_decomposition_for_jvp(torch.ops.aten.cudnn_batch_norm_backward.default)
torch/_decomp/decompositions_for_rng.py ADDED
@@ -0,0 +1,263 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import functools
2
+ from collections import defaultdict
3
+ from typing import Callable, Dict
4
+
5
+ import torch
6
+ import torch._decomp as decomp
7
+ from torch._decomp import get_decompositions
8
+ from torch._ops import OpOverload
9
+
10
+ aten = torch.ops.aten
11
+
12
+ rng_decompositions: Dict[str, Dict[OpOverload, Callable]] = defaultdict(dict)
13
+
14
+
15
+ def register_rng_decomposition(aten_op):
16
+ return decomp.register_decomposition(aten_op, rng_decompositions)
17
+
18
+
19
+ def throw_on_non_cuda(device):
20
+ raise RuntimeError(
21
+ f"You are trying to functionalize a {device.type} RNG operator but {device.type} does not "
22
+ f"use Philox/counter-based RNG. Therefore, functionalizing a {device.type} RNG operator is "
23
+ "not supported. We are discussing the possibility of a Philox-based RNG implementation for CPU."
24
+ )
25
+
26
+
27
+ # TODO - We have to register many more distributions here, and also higher level
28
+ # ops like dropout which have fused implementation and can hide the rand inside.
29
+ @register_rng_decomposition(aten.rand)
30
+ def rand(shape, dtype=None, layout=torch.strided, device=None, pin_memory=False):
31
+ if device and device.type != "cuda":
32
+ throw_on_non_cuda(device)
33
+ seed, offset = PhiloxStateTracker.get_state_as_tuple()
34
+ dtype = dtype or torch.float32
35
+ out, offset_jump = torch.ops.rngprims.philox_rand(
36
+ shape, seed, offset, None, device, dtype
37
+ )
38
+ PhiloxStateTracker.advance_offset(offset_jump)
39
+ return out
40
+
41
+
42
+ @register_rng_decomposition(aten.rand_like)
43
+ def rand_like(
44
+ x: torch.Tensor,
45
+ dtype=None,
46
+ layout=None,
47
+ device=None,
48
+ pin_memory=False,
49
+ memory_format=torch.preserve_format,
50
+ ):
51
+ device = device or x.device
52
+ if device.type != "cuda":
53
+ throw_on_non_cuda(device)
54
+ dtype = dtype or x.dtype
55
+ seed, offset = PhiloxStateTracker.get_state_as_tuple()
56
+ out, offset_jump = torch.ops.rngprims.philox_rand(
57
+ x.shape, seed, offset, None, device, dtype
58
+ )
59
+ PhiloxStateTracker.advance_offset(offset_jump)
60
+ return out
61
+
62
+
63
+ class PhiloxState:
64
+ """
65
+ Represents a PhiloxRngState - (seed, offset) where offset = base_offset +
66
+ relative_offset. seed and base_offset basically point to the rng state just
67
+ before tracing starts. relative offset tracks the totally consumed offset at
68
+ trace time.
69
+ """
70
+
71
+ def __init__(self):
72
+ self.reset()
73
+
74
+ def reset(self):
75
+ self.seed = torch.tensor(())
76
+ self.base_offset = torch.tensor(())
77
+ self.relative_offset = 0
78
+ self.offset_advanced_alteast_once = False
79
+
80
+ def validate_state(self):
81
+ assert self.seed.numel() != 0 and self.base_offset.numel() != 0
82
+
83
+ def advance_offset(self, consumed_offset):
84
+ self.offset_advanced_alteast_once = True
85
+ self.relative_offset = self.relative_offset + consumed_offset
86
+
87
+ def set_state(self, seed, base_offset, relative_offset=0):
88
+ self.seed = seed
89
+ self.base_offset = base_offset
90
+ self.relative_offset = relative_offset
91
+
92
+ def get_state_as_tuple(self):
93
+ self.validate_state()
94
+ return (self.seed, self.base_offset + self.relative_offset)
95
+
96
+ def get_state_as_tensor(self):
97
+ # Only needed because we override get_rng_state.
98
+ self.validate_state()
99
+ return torch.stack([self.seed, self.base_offset + self.relative_offset])
100
+
101
+ def set_state_from_tensor(self, state):
102
+ # Only needed because we override set_rng_state.
103
+ self.seed, self.base_offset = torch.unbind(state)
104
+ self.relative_offset = 0
105
+
106
+
107
+ class PhiloxStateTracker:
108
+ """
109
+ Singleton class to track the philox rng state during AOT Autograd tracing.
110
+ For each aot tracing instance, AOT Autograd resets this tracker and keeps
111
+ track of both forward and backward offsets. At runtime, we only care about
112
+ the total consumed forward and backward offsets. For dynamic shapes, these
113
+ offsets are a function of input shapes. Therefore, the AOT generated graphs
114
+ have additional outputs that compute total consumed forward and backward
115
+ offsets.
116
+ """
117
+
118
+ running_state: PhiloxState
119
+ fwd_state: PhiloxState
120
+ bwd_state: PhiloxState
121
+
122
+ def __enter__(self):
123
+ PhiloxStateTracker.reset()
124
+ return self
125
+
126
+ def __exit__(self, exc_type, exc_cal, exc_tb):
127
+ PhiloxStateTracker.reset()
128
+
129
+ @classmethod
130
+ def reset(cls):
131
+ cls.running_state = PhiloxState()
132
+ cls.fwd_state = PhiloxState()
133
+ cls.bwd_state = PhiloxState()
134
+
135
+ @classmethod
136
+ def mark_beginning_of_forward(cls):
137
+ # Tells the tracker to use fwd_state as the running state
138
+ cls.running_state = cls.fwd_state
139
+
140
+ @classmethod
141
+ def mark_beginning_of_backward(cls):
142
+ # Tells the tracker to use bwd_state as the running state
143
+ cls.running_state = cls.bwd_state
144
+
145
+ @classmethod
146
+ def record_state(cls, seed, offset, mode):
147
+ # Records the seed and offset tensors. These tensors are used to invoke
148
+ # the philox_rand functional primitives.
149
+ if mode == "forward":
150
+ cls.fwd_state.set_state(seed, offset)
151
+ cls.mark_beginning_of_forward()
152
+ else:
153
+ assert mode == "backward"
154
+ cls.bwd_state.set_state(seed, offset)
155
+
156
+ @classmethod
157
+ def get_state_as_tensor(cls):
158
+ # The only reason this exists is because we override get_rng_state and
159
+ # set_rng_state during tracing. get_rng_state expects a tensor output,
160
+ # so return (seed, offset) tuple upset other parts of the program like
161
+ # ctx.saved_tensors.
162
+
163
+ # A bad consequence is that if user saves and restores rng state, we
164
+ # have little bit of ugliness in the generated code, where we first
165
+ # concat the (seed, offset) to create a tensor for get_rng_state, and
166
+ # then split it back to get (seed, offset) tuple in set_rng_state.
167
+
168
+ # TODO: Investigate if there is be a better way to wrap the tuple in a
169
+ # false Tensor object, and then desugar it later on.
170
+ return cls.running_state.get_state_as_tensor()
171
+
172
+ @classmethod
173
+ def get_state_as_tuple(cls):
174
+ return cls.running_state.get_state_as_tuple()
175
+
176
+ @classmethod
177
+ def set_state_from_tensor(cls, x):
178
+ # This is only needed because we override set_rng_state. Look at the
179
+ # comment in get_state_from_tensor method.
180
+ cls.running_state.set_state_from_tensor(x)
181
+
182
+ @classmethod
183
+ def advance_offset(cls, consumed_offset):
184
+ cls.running_state.advance_offset(consumed_offset)
185
+
186
+ @classmethod
187
+ def get_current_relative_offset(cls):
188
+ return cls.running_state.relative_offset
189
+
190
+ @staticmethod
191
+ def multiple_of_4(offset):
192
+ # torch cuda rng state offset must be a multiple of 4. For inductor, as
193
+ # we sum up all the numel, the result might not be a multiple of 4. This
194
+ # method achieves that.
195
+ return (offset + 3) // 4 * 4
196
+
197
+ @classmethod
198
+ def get_updated_fwd_offset(cls):
199
+ # Short circuit if no rand ops were observed
200
+ if not cls.fwd_state.offset_advanced_alteast_once:
201
+ return cls.fwd_state.base_offset
202
+ return cls.multiple_of_4(
203
+ cls.fwd_state.base_offset + cls.fwd_state.relative_offset
204
+ )
205
+
206
+ @classmethod
207
+ def get_updated_bwd_offset(cls):
208
+ # Short circuit if no rand ops were observed
209
+ if not cls.bwd_state.offset_advanced_alteast_once:
210
+ return cls.bwd_state.base_offset
211
+ return cls.multiple_of_4(
212
+ cls.bwd_state.base_offset + cls.bwd_state.relative_offset
213
+ )
214
+
215
+
216
+ # Adding more decompositions which eventually use rand_like inside decomps.
217
+ # Adding these in rng_decompositions ensures the functionalization of rand_like
218
+ # ops used in these decomps. The list is copied from inductor codebase, which
219
+ # uses it for similar purpose.
220
+ #
221
+ # Caution - These decomps do not have same accuracy as that of eager. However,
222
+ # we can't just disable them with a config flag like fallback_random, because
223
+ # for functionalization of rng ops, we have to decompose these ops.
224
+ extra_random_decomps = get_decompositions(
225
+ [
226
+ aten.cauchy,
227
+ aten.cauchy_,
228
+ aten.exponential,
229
+ aten.exponential_,
230
+ aten.geometric,
231
+ aten.geometric_,
232
+ aten.native_dropout,
233
+ aten.normal,
234
+ aten.normal_,
235
+ aten.normal_functional,
236
+ aten.log_normal,
237
+ aten.log_normal_,
238
+ aten.rrelu_with_noise,
239
+ aten.rrelu_with_noise_,
240
+ aten.uniform_,
241
+ ]
242
+ )
243
+ register_extra_random_decomp = functools.partial(
244
+ decomp.register_decomposition, registry=extra_random_decomps
245
+ )
246
+
247
+
248
+ @register_extra_random_decomp([aten.bernoulli_])
249
+ def bernoulli_(self, p=0.5):
250
+ if self.device == torch.device("cpu"):
251
+ return NotImplemented
252
+ return self.copy_(torch.rand_like(self, dtype=torch.float32) < p)
253
+
254
+
255
+ @register_extra_random_decomp([aten.bernoulli.p])
256
+ def bernoulli_p(self, p=0.5, *, generator=None):
257
+ if self.device == torch.device("cpu"):
258
+ return NotImplemented
259
+ assert generator is None
260
+ return torch.rand_like(self, dtype=torch.float32) < p
261
+
262
+
263
+ rng_decompositions.update(extra_random_decomps) # type: ignore[arg-type]
torch/_deploy.py ADDED
@@ -0,0 +1,105 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import io
2
+
3
+ import torch
4
+ from torch.package import Importer, OrderedImporter, PackageImporter, sys_importer
5
+ from torch.package._package_pickler import create_pickler
6
+ from torch.package._package_unpickler import PackageUnpickler
7
+ from torch.serialization import _maybe_decode_ascii
8
+
9
+
10
+ def _save_storages(importer, obj):
11
+ serialized_storages = []
12
+ serialized_dtypes = []
13
+
14
+ importer = importer if isinstance(importer, torch.package.PackageImporter) else None
15
+ importers: Importer
16
+ if importer is not None:
17
+ importers = OrderedImporter(importer, sys_importer)
18
+ else:
19
+ importers = sys_importer
20
+
21
+ def persistent_id(obj):
22
+ if torch.is_storage(obj) or isinstance(obj, torch.storage.TypedStorage):
23
+ if isinstance(obj, torch.storage.TypedStorage):
24
+ # TODO: Once we decide to break serialization FC, we can
25
+ # remove this case
26
+ storage = obj._untyped_storage
27
+ dtype = obj.dtype
28
+ else:
29
+ storage = obj
30
+ dtype = torch.uint8
31
+
32
+ serialized_storages.append(obj)
33
+ serialized_dtypes.append(dtype)
34
+ return ("storage", len(serialized_storages) - 1)
35
+
36
+ if hasattr(obj, "__reduce_deploy__"):
37
+ if _serialized_reduces.get(id(obj)) is None:
38
+ _serialized_reduces[id(obj)] = (
39
+ "reduce_deploy",
40
+ id(obj),
41
+ *obj.__reduce_deploy__(importers),
42
+ )
43
+ return _serialized_reduces[id(obj)]
44
+
45
+ return None
46
+
47
+ # Write the pickle data for `obj`
48
+ data_buf = io.BytesIO()
49
+ pickler = create_pickler(data_buf, importers)
50
+ pickler.persistent_id = persistent_id
51
+ pickler.dump(obj)
52
+ data_value = data_buf.getvalue()
53
+ return (
54
+ data_value,
55
+ serialized_storages,
56
+ serialized_dtypes,
57
+ importer.zip_reader if importer else None,
58
+ )
59
+
60
+
61
+ def _load_storages(id, zip_reader, obj_bytes, serialized_storages, serialized_dtypes):
62
+ def persistent_load(saved_id):
63
+ assert isinstance(saved_id, tuple)
64
+ typename = _maybe_decode_ascii(saved_id[0])
65
+ data = saved_id[1:]
66
+
67
+ if typename == "storage":
68
+ # TODO: Once we decide to break serialization FC, we can
69
+ # stop wrapping with TypedStorage
70
+ storage = serialized_storages[data[0]]
71
+ dtype = serialized_dtypes[data[0]]
72
+ return torch.storage.TypedStorage(
73
+ wrap_storage=storage.untyped(), dtype=dtype
74
+ )
75
+
76
+ if typename == "reduce_deploy":
77
+ reduce_id, func, args = data
78
+ if reduce_id not in _loaded_reduces:
79
+ _loaded_reduces[reduce_id] = func(_raw_packages[zip_reader], *args)
80
+ return _loaded_reduces[reduce_id]
81
+
82
+ return None
83
+
84
+ importer: Importer
85
+ if zip_reader is not None:
86
+ importer = OrderedImporter(_get_package(zip_reader), sys_importer)
87
+ else:
88
+ importer = sys_importer
89
+
90
+ unpickler = PackageUnpickler(importer, io.BytesIO(obj_bytes))
91
+ unpickler.persistent_load = persistent_load # type: ignore[assignment]
92
+ result = _deploy_objects[id] = unpickler.load()
93
+ return result
94
+
95
+
96
+ def _get_package(zip_reader):
97
+ if zip_reader not in _raw_packages:
98
+ _raw_packages[zip_reader] = PackageImporter(zip_reader)
99
+ return _raw_packages[zip_reader]
100
+
101
+
102
+ _raw_packages: dict = {}
103
+ _deploy_objects: dict = {}
104
+ _serialized_reduces: dict = {}
105
+ _loaded_reduces: dict = {}
torch/_dispatch/__init__.py ADDED
File without changes