Adi-69s commited on
Commit
420d654
·
verified ·
1 Parent(s): a3f9db9

Delete torch

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. torch/_C.cp310-win_amd64.pyd +0 -0
  2. torch/_C/_VariableFunctions.pyi +0 -0
  3. torch/_C/__init__.pyi +0 -0
  4. torch/_C/_autograd.pyi +0 -123
  5. torch/_C/_cpu.pyi +0 -5
  6. torch/_C/_cudnn.pyi +0 -17
  7. torch/_C/_distributed_autograd.pyi +0 -26
  8. torch/_C/_distributed_c10d.pyi +0 -478
  9. torch/_C/_distributed_rpc.pyi +0 -188
  10. torch/_C/_distributed_rpc_testing.pyi +0 -35
  11. torch/_C/_functions.pyi +0 -11
  12. torch/_C/_functorch.pyi +0 -71
  13. torch/_C/_itt.pyi +0 -5
  14. torch/_C/_lazy.pyi +0 -28
  15. torch/_C/_lazy_ts_backend.pyi +0 -11
  16. torch/_C/_monitor.pyi +0 -44
  17. torch/_C/_nn.pyi +0 -86
  18. torch/_C/_nvtx.pyi +0 -6
  19. torch/_C/_onnx.pyi +0 -38
  20. torch/_C/_profiler.pyi +0 -238
  21. torch/_C/_verbose.pyi +0 -3
  22. torch/_VF.py +0 -30
  23. torch/_VF.pyi +0 -0
  24. torch/__config__.py +0 -22
  25. torch/__future__.py +0 -21
  26. torch/_appdirs.py +0 -666
  27. torch/_awaits/__init__.py +0 -54
  28. torch/_awaits/__pycache__/__init__.cpython-310.pyc +0 -0
  29. torch/_classes.py +0 -55
  30. torch/_compile.py +0 -30
  31. torch/_custom_op/__init__.py +0 -0
  32. torch/_custom_op/__pycache__/__init__.cpython-310.pyc +0 -0
  33. torch/_custom_op/__pycache__/autograd.cpython-310.pyc +0 -0
  34. torch/_custom_op/__pycache__/functional.cpython-310.pyc +0 -0
  35. torch/_custom_op/__pycache__/impl.cpython-310.pyc +0 -0
  36. torch/_custom_op/autograd.py +0 -274
  37. torch/_custom_op/functional.py +0 -187
  38. torch/_custom_op/impl.py +0 -976
  39. torch/_custom_ops.py +0 -322
  40. torch/_decomp/__init__.py +0 -444
  41. torch/_decomp/__pycache__/__init__.cpython-310.pyc +0 -0
  42. torch/_decomp/__pycache__/decompositions.cpython-310.pyc +0 -0
  43. torch/_decomp/__pycache__/decompositions_for_jvp.cpython-310.pyc +0 -0
  44. torch/_decomp/__pycache__/decompositions_for_rng.cpython-310.pyc +0 -0
  45. torch/_decomp/decompositions.py +0 -0
  46. torch/_decomp/decompositions_for_jvp.py +0 -302
  47. torch/_decomp/decompositions_for_rng.py +0 -263
  48. torch/_deploy.py +0 -105
  49. torch/_dispatch/__init__.py +0 -0
  50. torch/_dispatch/__pycache__/__init__.cpython-310.pyc +0 -0
torch/_C.cp310-win_amd64.pyd DELETED
Binary file (10.2 kB)
 
torch/_C/_VariableFunctions.pyi DELETED
The diff for this file is too large to render. See raw diff
 
torch/_C/__init__.pyi DELETED
The diff for this file is too large to render. See raw diff
 
torch/_C/_autograd.pyi DELETED
@@ -1,123 +0,0 @@
1
- from enum import Enum
2
- from typing import Any, Callable, List, Optional, Set
3
-
4
- import torch
5
-
6
- from ._profiler import (
7
- _ProfilerEvent,
8
- ActiveProfilerType,
9
- ProfilerActivity,
10
- ProfilerConfig,
11
- )
12
-
13
- # Defined in tools/autograd/init.cpp
14
-
15
- class DeviceType(Enum):
16
- CPU = ...
17
- CUDA = ...
18
- MKLDNN = ...
19
- OPENGL = ...
20
- OPENCL = ...
21
- IDEEP = ...
22
- HIP = ...
23
- FPGA = ...
24
- ORT = ...
25
- XLA = ...
26
- MPS = ...
27
- HPU = ...
28
- Meta = ...
29
- Vulkan = ...
30
- Metal = ...
31
- PrivateUse1 = ...
32
-
33
- class ProfilerEvent:
34
- def cpu_elapsed_us(self, other: ProfilerEvent) -> float: ...
35
- def cpu_memory_usage(self) -> int: ...
36
- def cuda_elapsed_us(self, other: ProfilerEvent) -> float: ...
37
- def privateuse1_elapsed_us(self, other: ProfilerEvent) -> float: ...
38
- def cuda_memory_usage(self) -> int: ...
39
- def device(self) -> int: ...
40
- def handle(self) -> int: ...
41
- def has_cuda(self) -> bool: ...
42
- def is_remote(self) -> bool: ...
43
- def kind(self) -> int: ...
44
- def name(self) -> str: ...
45
- def node_id(self) -> int: ...
46
- def sequence_nr(self) -> int: ...
47
- def shapes(self) -> List[List[int]]: ...
48
- def thread_id(self) -> int: ...
49
- def flops(self) -> float: ...
50
- def is_async(self) -> bool: ...
51
-
52
- class _KinetoEvent:
53
- def name(self) -> str: ...
54
- def device_index(self) -> int: ...
55
- def start_us(self) -> int: ...
56
- def duration_us(self) -> int: ...
57
- def is_async(self) -> bool: ...
58
- def linked_correlation_id(self) -> int: ...
59
- def shapes(self) -> List[List[int]]: ...
60
- def dtypes(self) -> List[str]: ...
61
- def concrete_inputs(self) -> List[Any]: ...
62
- def device_type(self) -> DeviceType: ...
63
- def start_thread_id(self) -> int: ...
64
- def end_thread_id(self) -> int: ...
65
- def correlation_id(self) -> int: ...
66
- def fwd_thread_id(self) -> int: ...
67
- def stack(self) -> List[str]: ...
68
- def scope(self) -> int: ...
69
- def sequence_nr(self) -> int: ...
70
- def flops(self) -> int: ...
71
- def cuda_elapsed_us(self) -> int: ...
72
- def privateuse1_elapsed_us(self) -> int: ...
73
-
74
- class _ProfilerResult:
75
- def events(self) -> List[_KinetoEvent]: ...
76
- def legacy_events(self) -> List[List[ProfilerEvent]]: ...
77
- def save(self, path: str) -> None: ...
78
- def experimental_event_tree(self) -> List[_ProfilerEvent]: ...
79
- def trace_start_us(self) -> int: ...
80
-
81
- class SavedTensor: ...
82
-
83
- def _enable_profiler(
84
- config: ProfilerConfig,
85
- activities: Set[ProfilerActivity],
86
- ) -> None: ...
87
- def _prepare_profiler(
88
- config: ProfilerConfig,
89
- activities: Set[ProfilerActivity],
90
- ) -> None: ...
91
- def _disable_profiler() -> _ProfilerResult: ...
92
- def _profiler_enabled() -> bool: ...
93
- def _add_metadata_json(key: str, value: str) -> None: ...
94
- def _kineto_step() -> None: ...
95
- def _get_sequence_nr() -> int: ...
96
- def kineto_available() -> bool: ...
97
- def _record_function_with_args_enter(name: str, *args) -> torch.Tensor: ...
98
- def _record_function_with_args_exit(handle: torch.Tensor) -> None: ...
99
- def _supported_activities() -> Set[ProfilerActivity]: ...
100
- def _enable_record_function(enable: bool) -> None: ...
101
- def _set_empty_test_observer(is_global: bool, sampling_prob: float) -> None: ...
102
- def _push_saved_tensors_default_hooks(
103
- pack_hook: Callable[[torch.Tensor], Any],
104
- unpack_hook: Callable[[Any], torch.Tensor],
105
- ) -> None: ...
106
- def _pop_saved_tensors_default_hooks() -> None: ...
107
- def _unsafe_set_version_counter(t: torch.Tensor, prev_version: int) -> None: ...
108
- def _enable_profiler_legacy(config: ProfilerConfig) -> None: ...
109
- def _disable_profiler_legacy() -> List[List[ProfilerEvent]]: ...
110
- def _profiler_type() -> ActiveProfilerType: ...
111
- def _saved_tensors_hooks_enable() -> None: ...
112
- def _saved_tensors_hooks_disable(message: str) -> None: ...
113
- def _saved_tensors_hooks_get_disabled_error_message() -> Optional[str]: ...
114
-
115
- class CreationMeta(Enum):
116
- DEFAULT = ...
117
- IN_CUSTOM_FUNCTION = ...
118
- MULTI_OUTPUT_NODE = ...
119
- NO_GRAD_MODE = ...
120
- INFERENCE_MODE = ...
121
-
122
- def _set_creation_meta(t: torch.Tensor, creation_meta: CreationMeta) -> None: ...
123
- def _get_creation_meta(t: torch.Tensor) -> CreationMeta: ...
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
torch/_C/_cpu.pyi DELETED
@@ -1,5 +0,0 @@
1
- from torch.types import _bool
2
-
3
- # Defined in torch/csrc/cpu/Module.cpp
4
-
5
- def _is_cpu_support_vnni() -> _bool: ...
 
 
 
 
 
 
torch/_C/_cudnn.pyi DELETED
@@ -1,17 +0,0 @@
1
- from enum import Enum
2
-
3
- from torch.types import _bool, Tuple
4
-
5
- # Defined in torch/csrc/cuda/shared/cudnn.cpp
6
- is_cuda: _bool
7
-
8
- def getRuntimeVersion() -> Tuple[int, int, int]: ...
9
- def getCompileVersion() -> Tuple[int, int, int]: ...
10
- def getVersionInt() -> int: ...
11
-
12
- class RNNMode(int, Enum):
13
- value: int
14
- rnn_relu = ...
15
- rnn_tanh = ...
16
- lstm = ...
17
- gru = ...
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
torch/_C/_distributed_autograd.pyi DELETED
@@ -1,26 +0,0 @@
1
- from typing import Any, Dict, List, Set
2
-
3
- import torch
4
-
5
- # This module is defined in torch/csrc/distributed/autograd/init.cpp
6
-
7
- class DistAutogradContext:
8
- def _context_id(self) -> int: ...
9
- def _recv_functions(self) -> Dict[int, Any]: ...
10
- def _send_functions(self) -> Dict[int, Any]: ...
11
- def _known_worker_ids(self) -> Set[int]: ...
12
-
13
- def _new_context() -> DistAutogradContext: ...
14
- def _release_context(context_id: int) -> None: ...
15
- def _get_max_id() -> int: ...
16
- def _is_valid_context(worker_id: int) -> bool: ...
17
- def _retrieve_context(context_id: int) -> DistAutogradContext: ...
18
- def _current_context() -> DistAutogradContext: ...
19
- def _init(worker_id: int) -> None: ...
20
- def _get_debug_info() -> Dict[str, str]: ...
21
- def backward(
22
- context_id: int,
23
- roots: List[torch.Tensor],
24
- retain_graph=False,
25
- ) -> None: ...
26
- def get_gradients(context_id: int) -> Dict[torch.Tensor, torch.Tensor]: ...
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
torch/_C/_distributed_c10d.pyi DELETED
@@ -1,478 +0,0 @@
1
- # mypy: disable-error-code="type-arg"
2
- from datetime import timedelta
3
- from enum import Enum
4
- from typing import Any, Dict, List, Optional, overload, Tuple, Union
5
-
6
- from torch import Tensor
7
- from torch._C import ScriptObject
8
- from torch.futures import Future
9
-
10
- # This module is defined in torch/csrc/distributed/c10d/init.cpp
11
-
12
- _DEFAULT_FIRST_BUCKET_BYTES: int
13
- _DEFAULT_NO_TIMEOUT: timedelta
14
- _DEFAULT_PG_TIMEOUT: timedelta
15
- _DEFAULT_PG_NCCL_TIMEOUT: timedelta
16
-
17
- class BuiltinCommHookType(Enum):
18
- ALLREDUCE = ...
19
- FP16_COMPRESS = ...
20
-
21
- def _register_comm_hook(reducer: Reducer, state: Any, comm_hook: Any): ...
22
- def _register_builtin_comm_hook(
23
- reducer: Reducer,
24
- comm_hook_type: BuiltinCommHookType,
25
- ): ...
26
-
27
- class GradBucket:
28
- def index(self) -> int: ...
29
- def buffer(self) -> Tensor: ...
30
- def gradients(self) -> List[Tensor]: ...
31
- def is_last(self) -> bool: ...
32
- def set_buffer(self, tensor: Tensor) -> None: ...
33
- def parameters(self) -> List[Tensor]: ...
34
-
35
- class Reducer:
36
- def __init__(
37
- self,
38
- params: List[Tensor],
39
- bucket_indices: List[List[int]],
40
- per_bucket_size_limits: List[int],
41
- process_group: ProcessGroup,
42
- expect_sparse_gradients: List[bool] = ...,
43
- bucket_bytes_cap: int = ..., # kDefaultBucketBytesCap in reducer.hpp
44
- find_unused_parameters: bool = ...,
45
- gradient_as_bucket_view: bool = ...,
46
- param_to_name_mapping: Dict[int, str] = ...,
47
- first_bucket_types_cap: int = ..., # kDefaultFirstBucketBytes in reducer.hpp
48
- ): ...
49
- def prepare_for_forward(self) -> None: ...
50
- def prepare_for_backward(self, output: List[Tensor]) -> None: ...
51
- def get_backward_stats(self) -> List[int]: ...
52
- def _install_post_backward_futures(self, futures: List[Future]) -> None: ...
53
- def _rebuild_buckets(self) -> bool: ...
54
- def _get_zeros_like_grad_buckets(self) -> List[GradBucket]: ...
55
- def _push_all_rebuilt_params(self) -> None: ...
56
- def _set_forward_pass_work_handle(
57
- self,
58
- work: Work,
59
- use_static_world_size: bool,
60
- ): ...
61
- def _get_local_used_map(self) -> Tensor: ...
62
- def _set_ddp_runtime_logging_sample_rate(self, sample_rate: int) -> None: ...
63
- def _set_static_graph(self) -> None: ...
64
- def _run_comm_hook(self, bucket: GradBucket) -> Future: ...
65
- def set_logger(self, logger: Logger) -> None: ...
66
- def _remove_autograd_hooks(self) -> None: ...
67
- def _check_reducer_finalized(self) -> None: ...
68
- def _set_sparse_metadata(self, global_unique_ids: Dict[str, Tensor]) -> None: ...
69
- def _reset_state(self) -> None: ...
70
- def _update_process_group(self, new_process_group: ProcessGroup) -> None: ...
71
-
72
- class DDPLoggingData:
73
- strs_map: Dict[str, str]
74
- ints_map: Dict[str, int]
75
-
76
- class Logger:
77
- def __init__(self, reducer: Reducer): ...
78
- def set_construction_data_and_log(
79
- self,
80
- module_name: str,
81
- device_ids: List[int],
82
- output_device: int,
83
- broadcast_buffers: bool,
84
- has_sync_bn: bool,
85
- static_graph: bool,
86
- ): ...
87
- def set_runtime_stats_and_log(self) -> None: ...
88
- def set_error_and_log(self, error: str) -> None: ...
89
- def _get_ddp_logging_data(self) -> DDPLoggingData: ...
90
- def _set_comm_hook_name(self, comm_hook: str) -> None: ...
91
- def _set_uneven_input_join(self) -> None: ...
92
- def _set_static_graph(self) -> None: ...
93
-
94
- def get_debug_level(): ...
95
- def set_debug_level(): ...
96
- def set_debug_level_from_env(): ...
97
-
98
- class DebugLevel(Enum):
99
- OFF = ...
100
- INFO = ...
101
- DETAIL = ...
102
-
103
- class ReduceOp:
104
- def __init__(self, op: RedOpType): ...
105
-
106
- SUM: RedOpType = ...
107
- AVG: RedOpType = ...
108
- PRODUCT: RedOpType = ...
109
- MIN: RedOpType = ...
110
- MAX: RedOpType = ...
111
- BAND: RedOpType = ...
112
- BOR: RedOpType = ...
113
- BXOR: RedOpType = ...
114
- PREMUL_SUM: RedOpType = ...
115
- UNUSED: RedOpType = ...
116
-
117
- class RedOpType(Enum): ...
118
-
119
- class BroadcastOptions:
120
- rootRank: int
121
- rootTensor: int
122
- timeout: timedelta
123
- asyncOp: bool
124
-
125
- class AllreduceOptions:
126
- reduceOp: ReduceOp
127
- timeout: timedelta
128
-
129
- class AllreduceCoalescedOptions(AllreduceOptions): ...
130
-
131
- class ReduceOptions:
132
- reduceOp: ReduceOp
133
- rootRank: int
134
- rootTensor: int
135
- timeout: timedelta
136
-
137
- class AllgatherOptions:
138
- timeout: timedelta
139
- asyncOp: bool
140
-
141
- class GatherOptions:
142
- rootRank: int
143
- timeout: timedelta
144
-
145
- class ScatterOptions:
146
- rootRank: int
147
- timeout: timedelta
148
- asyncOp: bool
149
-
150
- class ReduceScatterOptions:
151
- reduceOp: ReduceOp
152
- timeout: timedelta
153
- asyncOp: bool
154
-
155
- class BarrierOptions:
156
- device_ids: List[int]
157
- timeout: timedelta
158
-
159
- class AllToAllOptions:
160
- timeout: timedelta
161
-
162
- class Store:
163
- def set(self, key: str, value: str): ...
164
- def get(self, key: str) -> bytes: ...
165
- def add(self, key: str, value: int) -> int: ...
166
- def compare_set(
167
- self,
168
- key: str,
169
- expected_value: str,
170
- desired_value: str,
171
- ) -> bytes: ...
172
- def delete_key(self, key: str) -> bool: ...
173
- def num_keys(self) -> int: ...
174
- def set_timeout(self, timeout: timedelta): ...
175
- @overload
176
- def wait(self, keys: List[str]): ...
177
- @overload
178
- def wait(self, keys: List[str], timeout: timedelta): ...
179
-
180
- class FileStore(Store):
181
- def __init__(self, path: str, numWorkers: int = ...): ...
182
-
183
- class HashStore(Store):
184
- def __init__(self): ...
185
-
186
- class TCPStore(Store):
187
- def __init__(
188
- self,
189
- host_name: str,
190
- port: int,
191
- world_size: Optional[int] = ...,
192
- is_master: bool = ...,
193
- timeout: timedelta = ...,
194
- wait_for_workers: bool = ...,
195
- multi_tenant: bool = ...,
196
- master_listen_fd: Optional[int] = ...,
197
- use_libuv: Optional[bool] = ...,
198
- ): ...
199
- @property
200
- def host(self) -> str: ...
201
- @property
202
- def port(self) -> int: ...
203
-
204
- class PrefixStore(Store):
205
- def __init__(self, prefix: str, store: Store): ...
206
- @property
207
- def underlying_store(self) -> Store: ...
208
-
209
- class Work:
210
- def is_completed(self) -> bool: ...
211
- def is_success(self) -> bool: ...
212
- def exception(self) -> Any: ...
213
- def wait(self, timeout: timedelta = ...) -> bool: ...
214
- def source_rank(self) -> int: ...
215
- def _source_rank(self) -> int: ...
216
- def result(self) -> List[Tensor]: ...
217
- def synchronize(self): ...
218
- def boxed(self) -> ScriptObject: ...
219
- @staticmethod
220
- def unbox(obj: ScriptObject) -> Work: ...
221
-
222
- class ProcessGroup:
223
- class Options: ...
224
-
225
- def __init__(self): ...
226
- def rank(self) -> int: ...
227
- def size(self) -> int: ...
228
- @overload
229
- def broadcast(
230
- self,
231
- tensors: List[Tensor],
232
- opts=...,
233
- ) -> Work: ...
234
- @overload
235
- def broadcast(
236
- self,
237
- tensor: Tensor,
238
- root: int,
239
- ) -> Work: ...
240
- @overload
241
- def allreduce(
242
- self,
243
- tensors: List[Tensor],
244
- opts: AllreduceOptions = ...,
245
- ) -> Work: ...
246
- @overload
247
- def allreduce(
248
- self,
249
- tensors: List[Tensor],
250
- op=...,
251
- ) -> Work: ...
252
- @overload
253
- def allreduce(
254
- self,
255
- tensor: Tensor,
256
- op=...,
257
- ) -> Work: ...
258
- def allreduce_coalesced(
259
- self,
260
- tensors: List[Tensor],
261
- opts=...,
262
- ) -> Work: ...
263
- @overload
264
- def reduce(
265
- self,
266
- tensors: List[Tensor],
267
- opts=...,
268
- ) -> Work: ...
269
- @overload
270
- def reduce(
271
- self,
272
- tensor: Tensor,
273
- root: int,
274
- op=...,
275
- ) -> Work: ...
276
- @overload
277
- def allgather(
278
- self,
279
- output_tensors: List[List[Tensor]],
280
- input_tensors: List[Tensor],
281
- opts=...,
282
- ) -> Work: ...
283
- @overload
284
- def allgather(
285
- self,
286
- output_tensors: List[Tensor],
287
- input_tensor: Tensor,
288
- ) -> Work: ...
289
- def _allgather_base(
290
- self,
291
- output: Tensor,
292
- input: Tensor,
293
- opts=...,
294
- ) -> Work: ...
295
- def allgather_coalesced(
296
- self,
297
- output_lists: List[List[Tensor]],
298
- input_list: List[Tensor],
299
- opts=...,
300
- ) -> Work: ...
301
- @overload
302
- def gather(
303
- self,
304
- output_tensors: List[List[Tensor]],
305
- input_tensors: List[Tensor],
306
- opts=...,
307
- ) -> Work: ...
308
- @overload
309
- def gather(
310
- self,
311
- output_tensors: List[Tensor],
312
- input_tensor: Tensor,
313
- root: int,
314
- ) -> Work: ...
315
- @overload
316
- def scatter(
317
- self,
318
- output_tensors: List[Tensor],
319
- input_tensors: List[List[Tensor]],
320
- opts=...,
321
- ) -> Work: ...
322
- @overload
323
- def scatter(
324
- self,
325
- output_tensor: Tensor,
326
- input_tensors: List[Tensor],
327
- root: int,
328
- ) -> Work: ...
329
- @overload
330
- def reduce_scatter(
331
- self,
332
- output_tensors: List[Tensor],
333
- input_tensors: List[List[Tensor]],
334
- opts=...,
335
- ) -> Work: ...
336
- @overload
337
- def reduce_scatter(
338
- self,
339
- output_tensors: Tensor,
340
- input_tensor: List[Tensor],
341
- ) -> Work: ...
342
- def _reduce_scatter_base(
343
- self,
344
- outputTensor: Tensor,
345
- inputTensor: Tensor,
346
- ) -> Work: ...
347
- @overload
348
- def alltoall_base(
349
- self,
350
- output_tensor: Tensor,
351
- input_tensor: Tensor,
352
- output_split_sizes: List[int],
353
- input_split_sizes: List[int],
354
- opts=...,
355
- ) -> Work: ...
356
- @overload
357
- def alltoall_base(
358
- self,
359
- output: Tensor,
360
- input: Tensor,
361
- output_split_sizes: List[int],
362
- input_split_sizes: List[int],
363
- ) -> Work: ...
364
- @overload
365
- def alltoall(
366
- self,
367
- output_tensor: List[Tensor],
368
- input_tensor: List[Tensor],
369
- opts=...,
370
- ) -> Work: ...
371
- @overload
372
- def alltoall(
373
- self,
374
- output: List[Tensor],
375
- input: List[Tensor],
376
- ) -> Work: ...
377
- def send(
378
- self,
379
- tensors: List[Tensor],
380
- dstRank: int,
381
- tag: int,
382
- ) -> Work: ...
383
- def recv(
384
- self,
385
- tensors: List[Tensor],
386
- srcRank: int,
387
- tag: int,
388
- ) -> Work: ...
389
- def recv_anysource(self, tensors: List[Tensor], tag: int) -> Work: ...
390
- def barrier(self, opts=...) -> Work: ...
391
- def boxed(self) -> ScriptObject: ...
392
- @staticmethod
393
- def unbox(obj: ScriptObject) -> ProcessGroup: ...
394
-
395
- class ProcessGroupRoundRobin(ProcessGroup): ...
396
-
397
- def _round_robin_process_groups(
398
- process_groups: List[ProcessGroup],
399
- ) -> ProcessGroupRoundRobin: ...
400
-
401
- class ProcessGroupGloo(ProcessGroup):
402
- class Device: ...
403
- class Options: ...
404
-
405
- def __init__(
406
- self,
407
- store: Store,
408
- rank: int,
409
- size: int,
410
- timeout: timedelta,
411
- ): ...
412
- @staticmethod
413
- def create_device(hostname="", interface="") -> Device: ...
414
- @staticmethod
415
- def create_default_device() -> Device: ...
416
-
417
- class _ProcessGroupWrapper(ProcessGroup):
418
- def __init__(self, pg: ProcessGroup, gloo_pg: ProcessGroupGloo): ...
419
- wrapped_pg: ProcessGroup
420
-
421
- class ProcessGroupNCCL(ProcessGroup):
422
- class Options: ...
423
-
424
- def __init__(
425
- self,
426
- store: Store,
427
- rank: int,
428
- size: int,
429
- timeout: timedelta,
430
- ): ...
431
- def _group_start(self) -> None: ...
432
- def _group_end(self) -> None: ...
433
-
434
- class ProcessGroupUCC(ProcessGroup):
435
- def __init__(
436
- self,
437
- store: Store,
438
- rank: int,
439
- size: int,
440
- timeout: timedelta,
441
- ): ...
442
-
443
- class ProcessGroupMPI(ProcessGroup):
444
- def __init__(
445
- self,
446
- rank: int,
447
- size: int,
448
- pgComm: int,
449
- ): ...
450
- @staticmethod
451
- def create(ranks: List[int]) -> ProcessGroupMPI: ...
452
-
453
- def _compute_bucket_assignment_by_size(
454
- tensors: List[Tensor],
455
- bucket_size_limits: List[int],
456
- expect_sparse_gradient: List[bool] = ...,
457
- tensor_indices: List[int] = ...,
458
- ) -> Tuple[List[List[int]], List[int]]: ...
459
- def _broadcast_coalesced(
460
- process_group: ProcessGroup,
461
- tensors: List[Tensor],
462
- buffer_size: int,
463
- src: int,
464
- ): ...
465
- def _test_python_store(store: Store): ...
466
- def _verify_params_across_processes(
467
- process_group: ProcessGroup,
468
- params: List[Tensor],
469
- logger: Optional[Logger],
470
- ): ...
471
- def _make_nccl_premul_sum(factor: Union[float, List[Tensor]]) -> ReduceOp: ...
472
-
473
- class Backend:
474
- def __init__(
475
- self,
476
- rank: int,
477
- size: int,
478
- ): ...
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
torch/_C/_distributed_rpc.pyi DELETED
@@ -1,188 +0,0 @@
1
- # mypy: disable-error-code="type-arg"
2
- from datetime import timedelta
3
- from typing import Any, Dict, Generic, List, Optional, overload, Tuple, Type, TypeVar
4
-
5
- import torch
6
-
7
- from . import Future
8
- from ._autograd import ProfilerEvent
9
- from ._distributed_c10d import Store
10
- from ._profiler import ProfilerConfig
11
-
12
- # This module is defined in torch/csrc/distributed/rpc/init.cpp
13
-
14
- _DEFAULT_INIT_METHOD: str
15
- _DEFAULT_NUM_WORKER_THREADS: int
16
- _UNSET_RPC_TIMEOUT: float
17
- _DEFAULT_RPC_TIMEOUT_SEC: float
18
-
19
- _T = TypeVar("_T")
20
-
21
- class RpcBackendOptions:
22
- rpc_timeout: float
23
- init_method: str
24
- def __init__(
25
- self,
26
- rpc_timeout: float = ...,
27
- init_method: str = ...,
28
- ): ...
29
-
30
- class WorkerInfo:
31
- def __init__(self, name: str, worker_id: int): ...
32
- @property
33
- def name(self) -> str: ...
34
- @property
35
- def id(self) -> int: ...
36
- def __eq__(self, other: object) -> bool: ...
37
-
38
- class RpcAgent:
39
- def join(self, shutdown: bool = False, timeout: float = 0): ...
40
- def sync(self): ...
41
- def shutdown(self): ...
42
- @overload
43
- def get_worker_info(self) -> WorkerInfo: ...
44
- @overload
45
- def get_worker_info(self, workerName: str) -> WorkerInfo: ...
46
- def get_worker_infos(self) -> List[WorkerInfo]: ...
47
- def _get_device_map(self, dst: WorkerInfo) -> Dict[torch.device, torch.device]: ...
48
- def get_debug_info(self) -> Dict[str, str]: ...
49
- def get_metrics(self) -> Dict[str, str]: ...
50
-
51
- class PyRRef(Generic[_T]):
52
- def __init__(self, value: _T, type_hint: Any = None) -> None: ...
53
- def is_owner(self) -> bool: ...
54
- def confirmed_by_owner(self) -> bool: ...
55
- def owner(self) -> WorkerInfo: ...
56
- def owner_name(self) -> str: ...
57
- def to_here(self, timeout: float = ...) -> _T: ...
58
- def local_value(self) -> Any: ...
59
- def rpc_sync(self, timeout: float = ...) -> Any: ...
60
- def rpc_async(self, timeout: float = ...) -> Any: ...
61
- def remote(self, timeout: float = ...) -> Any: ...
62
- def _serialize(self) -> Tuple: ...
63
- @staticmethod
64
- def _deserialize(tp: Tuple) -> PyRRef: ...
65
- def _get_type(self) -> Type[_T]: ...
66
- def _get_future(self) -> Future[_T]: ...
67
- def _get_profiling_future(self) -> Future[_T]: ...
68
- def _set_profiling_future(self, profilingFuture: Future[_T]): ...
69
-
70
- class _TensorPipeRpcBackendOptionsBase(RpcBackendOptions):
71
- num_worker_threads: int
72
- device_maps: Dict[str, Dict[torch.device, torch.device]]
73
- devices: List[torch.device]
74
- def __init__(
75
- self,
76
- num_worker_threads: int,
77
- _transports: Optional[List],
78
- _channels: Optional[List],
79
- rpc_timeout: float = ...,
80
- init_method: str = ...,
81
- device_maps: Dict[str, Dict[torch.device, torch.device]] = {}, # noqa: B006
82
- devices: List[torch.device] = [], # noqa: B006
83
- ): ...
84
- def _set_device_map(
85
- self,
86
- to: str,
87
- device_map: Dict[torch.device, torch.device],
88
- ): ...
89
-
90
- class TensorPipeAgent(RpcAgent):
91
- def __init__(
92
- self,
93
- store: Store,
94
- name: str,
95
- worker_id: int,
96
- world_size: Optional[int],
97
- opts: _TensorPipeRpcBackendOptionsBase,
98
- reverse_device_maps: Dict[str, Dict[torch.device, torch.device]],
99
- devices: List[torch.device],
100
- ): ...
101
- def join(self, shutdown: bool = False, timeout: float = 0): ...
102
- def shutdown(self): ...
103
- @overload
104
- def get_worker_info(self) -> WorkerInfo: ...
105
- @overload
106
- def get_worker_info(self, workerName: str) -> WorkerInfo: ...
107
- @overload
108
- def get_worker_info(self, id: int) -> WorkerInfo: ...
109
- def get_worker_infos(self) -> List[WorkerInfo]: ...
110
- def _get_device_map(self, dst: WorkerInfo) -> Dict[torch.device, torch.device]: ...
111
- def _update_group_membership(
112
- self,
113
- worker_info: WorkerInfo,
114
- my_devices: List[torch.device],
115
- reverse_device_map: Dict[str, Dict[torch.device, torch.device]],
116
- is_join: bool,
117
- ): ...
118
- def _get_backend_options(self) -> _TensorPipeRpcBackendOptionsBase: ...
119
- @property
120
- def is_static_group(self) -> bool: ...
121
- @property
122
- def store(self) -> Store: ...
123
-
124
- def _is_current_rpc_agent_set() -> bool: ...
125
- def _get_current_rpc_agent() -> RpcAgent: ...
126
- def _set_and_start_rpc_agent(agent: RpcAgent): ...
127
- def _reset_current_rpc_agent(): ...
128
- def _delete_all_user_and_unforked_owner_rrefs(timeout: timedelta = ...): ...
129
- def _destroy_rref_context(ignoreRRefLeak: bool): ...
130
- def _rref_context_get_debug_info() -> Dict[str, str]: ...
131
- def _cleanup_python_rpc_handler(): ...
132
- def _invoke_rpc_builtin(
133
- dst: WorkerInfo,
134
- opName: str,
135
- rpcTimeoutSeconds: float,
136
- *args: Any,
137
- **kwargs: Any,
138
- ): ...
139
- def _invoke_rpc_python_udf(
140
- dst: WorkerInfo,
141
- pickledPythonUDF: str,
142
- tensors: List[torch.Tensor],
143
- rpcTimeoutSeconds: float,
144
- isAsyncExecution: bool,
145
- ): ...
146
- def _invoke_rpc_torchscript(
147
- dstWorkerName: str,
148
- qualifiedNameStr: str,
149
- argsTuple: Tuple,
150
- kwargsDict: Dict,
151
- rpcTimeoutSeconds: float,
152
- isAsyncExecution: bool,
153
- ): ...
154
- def _invoke_remote_builtin(
155
- dst: WorkerInfo,
156
- opName: str,
157
- rpcTimeoutSeconds: float,
158
- *args: Any,
159
- **kwargs: Any,
160
- ): ...
161
- def _invoke_remote_python_udf(
162
- dst: WorkerInfo,
163
- pickledPythonUDF: str,
164
- tensors: List[torch.Tensor],
165
- rpcTimeoutSeconds: float,
166
- isAsyncExecution: bool,
167
- ): ...
168
- def _invoke_remote_torchscript(
169
- dstWorkerName: WorkerInfo,
170
- qualifiedNameStr: str,
171
- rpcTimeoutSeconds: float,
172
- isAsyncExecution: bool,
173
- *args: Any,
174
- **kwargs: Any,
175
- ): ...
176
- def get_rpc_timeout() -> float: ...
177
- def enable_gil_profiling(flag: bool): ...
178
- def _set_rpc_timeout(rpcTimeoutSeconds: float): ...
179
-
180
- class RemoteProfilerManager:
181
- @staticmethod
182
- def set_current_profiling_key(key: str): ...
183
-
184
- def _enable_server_process_global_profiler(new_config: ProfilerConfig): ...
185
- def _disable_server_process_global_profiler() -> List[List[List[ProfilerEvent]]]: ...
186
- def _set_profiler_node_id(default_node_id: int): ...
187
- def _enable_jit_rref_pickle(): ...
188
- def _disable_jit_rref_pickle(): ...
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
torch/_C/_distributed_rpc_testing.pyi DELETED
@@ -1,35 +0,0 @@
1
- from typing import Dict, List
2
-
3
- import torch
4
-
5
- from ._distributed_c10d import Store
6
- from ._distributed_rpc import _TensorPipeRpcBackendOptionsBase, TensorPipeAgent
7
-
8
- # This module is defined in torch/csrc/distributed/rpc/testing/init.cpp
9
-
10
- class FaultyTensorPipeRpcBackendOptions(_TensorPipeRpcBackendOptionsBase):
11
- def __init__(
12
- self,
13
- num_worker_threads: int,
14
- rpc_timeout: float,
15
- init_method: str,
16
- messages_to_fail: List[str],
17
- messages_to_delay: Dict[str, float],
18
- num_fail_sends: int,
19
- ): ...
20
- num_send_recv_threads: int
21
- messages_to_fail: List[str]
22
- messages_to_delay: Dict[str, float]
23
- num_fail_sends: int
24
-
25
- class FaultyTensorPipeAgent(TensorPipeAgent):
26
- def __init__(
27
- self,
28
- store: Store,
29
- name: str,
30
- rank: int,
31
- world_size: int,
32
- options: FaultyTensorPipeRpcBackendOptions,
33
- reverse_device_maps: Dict[str, Dict[torch.device, torch.device]],
34
- devices: List[torch.device],
35
- ): ...
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
torch/_C/_functions.pyi DELETED
@@ -1,11 +0,0 @@
1
- from typing import AnyStr, List
2
-
3
- from torch import Tensor
4
-
5
- class UndefinedGrad:
6
- def __init__(self) -> None: ...
7
- def __call__(self, *inputs: Tensor) -> List[Tensor]: ...
8
-
9
- class DelayedError:
10
- def __init__(self, msg: AnyStr, num_inputs: int) -> None: ...
11
- def __call__(self, inputs: List[Tensor]) -> List[Tensor]: ...
 
 
 
 
 
 
 
 
 
 
 
 
torch/_C/_functorch.pyi DELETED
@@ -1,71 +0,0 @@
1
- from enum import Enum
2
- from typing import Optional, Tuple
3
-
4
- from torch import Tensor
5
-
6
- # Defined in torch/csrc/functorch/init.cpp
7
-
8
- def _set_dynamic_layer_keys_included(included: bool) -> None: ...
9
- def get_unwrapped(tensor: Tensor) -> Tensor: ...
10
- def is_batchedtensor(tensor: Tensor) -> bool: ...
11
- def is_functionaltensor(tensor: Tensor) -> bool: ...
12
- def is_functorch_wrapped_tensor(tensor: Tensor) -> bool: ...
13
- def is_gradtrackingtensor(tensor: Tensor) -> bool: ...
14
- def maybe_get_bdim(tensor: Tensor) -> int: ...
15
- def maybe_get_level(tensor: Tensor) -> int: ...
16
- def unwrap_if_dead(tensor: Tensor) -> Tensor: ...
17
- def _unwrap_for_grad(tensor: Tensor, level: int) -> Tensor: ...
18
- def _wrap_for_grad(tensor: Tensor, level: int) -> Tensor: ...
19
- def _unwrap_batched(tensor: Tensor, level: int) -> Tuple[Tensor, Optional[int]]: ...
20
- def current_level() -> int: ...
21
- def _add_batch_dim(tensor: Tensor, bdim: int, level: int) -> Tensor: ...
22
- def set_single_level_autograd_function_allowed(allowed: bool) -> None: ...
23
- def get_single_level_autograd_function_allowed() -> bool: ...
24
- def _unwrap_functional_tensor(tensor: Tensor, reapply_views: bool) -> Tensor: ...
25
- def _wrap_functional_tensor(tensor: Tensor, level: int) -> Tensor: ...
26
-
27
- # Defined in aten/src/ATen/functorch/Interpreter.h
28
- class TransformType(Enum):
29
- Torch: TransformType = ...
30
- Vmap: TransformType = ...
31
- Grad: TransformType = ...
32
- Jvp: TransformType = ...
33
- Functionalize: TransformType = ...
34
-
35
- class RandomnessType(Enum):
36
- Error: TransformType = ...
37
- Same: TransformType = ...
38
- Different: TransformType = ...
39
-
40
- class CInterpreter:
41
- def key(self) -> TransformType: ...
42
- def level(self) -> int: ...
43
-
44
- class CGradInterpreterPtr:
45
- def __init__(self, interpreter: CInterpreter): ...
46
- def lift(self, Tensor) -> Tensor: ...
47
- def prevGradMode(self) -> bool: ...
48
-
49
- class CJvpInterpreterPtr:
50
- def __init__(self, interpreter: CInterpreter): ...
51
- def lift(self, Tensor) -> Tensor: ...
52
- def prevFwdGradMode(self) -> bool: ...
53
-
54
- class CFunctionalizeInterpreterPtr:
55
- def __init__(self, interpreter: CInterpreter): ...
56
- def key(self) -> TransformType: ...
57
- def level(self) -> int: ...
58
- def functionalizeAddBackViews(self) -> bool: ...
59
-
60
- class CVmapInterpreterPtr:
61
- def __init__(self, interpreter: CInterpreter): ...
62
- def key(self) -> TransformType: ...
63
- def level(self) -> int: ...
64
- def batchSize(self) -> int: ...
65
- def randomness(self) -> RandomnessType: ...
66
-
67
- class DynamicLayer: ...
68
-
69
- def peek_interpreter_stack() -> CInterpreter: ...
70
- def pop_dynamic_layer_stack() -> DynamicLayer: ...
71
- def push_dynamic_layer_stack(dl: DynamicLayer) -> int: ...
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
torch/_C/_itt.pyi DELETED
@@ -1,5 +0,0 @@
1
- # Defined in torch/csrc/itt.cpp
2
- def is_available() -> None: ...
3
- def rangePush(message: str) -> None: ...
4
- def rangePop() -> None: ...
5
- def mark(message: str) -> None: ...
 
 
 
 
 
 
torch/_C/_lazy.pyi DELETED
@@ -1,28 +0,0 @@
1
- from typing import List
2
-
3
- from torch import Tensor
4
-
5
- # defined in torch/csrc/lazy/python/init.cpp
6
- def _mark_step(device: str, devices: List[str], wait: bool): ...
7
- def _wait_device_ops(devices: List[str]): ...
8
- def _reset_metrics(): ...
9
- def _counter_names() -> List[str]: ...
10
- def _counter_value(name: str) -> int: ...
11
- def _metrics_report() -> str: ...
12
- def _get_graph_hash(tensors: List[Tensor]) -> str: ...
13
- def _sync_multi(
14
- tensors: List[Tensor],
15
- devices: List[str],
16
- wait: bool = True,
17
- sync_ltc_data: bool = True,
18
- ): ...
19
- def _get_tensor_id(tensor: Tensor) -> int: ...
20
- def _get_tensors_text(tensors: List[Tensor]) -> str: ...
21
- def _get_tensors_dot(tensors: List[Tensor]) -> str: ...
22
- def _get_tensors_backend(tensors: List[Tensor]) -> str: ...
23
- def _get_force_fallback() -> str: ...
24
- def _set_force_fallback(newval: str): ...
25
- def _clear_ir_cache(): ...
26
- def _dump_ir_cache(filename: str): ...
27
- def _set_reuse_ir(val: bool): ...
28
- def _get_default_device_type(): ...
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
torch/_C/_lazy_ts_backend.pyi DELETED
@@ -1,11 +0,0 @@
1
- # defined in torch/csrc/lazy/python/init.cpp
2
-
3
- from typing import Any, List, Tuple
4
-
5
- from torch import Tensor
6
-
7
- def _init(): ...
8
- def _get_tensors_ts_device_data_node(
9
- tensors: List[Tensor],
10
- ) -> Tuple[List[int], List[Any]]: ...
11
- def _run_cached_graph(hash_str: str, graph_inputs: List[Any]) -> List[Tensor]: ...
 
 
 
 
 
 
 
 
 
 
 
 
torch/_C/_monitor.pyi DELETED
@@ -1,44 +0,0 @@
1
- # Defined in torch/csrc/monitor/python_init.cpp
2
-
3
- import datetime
4
- from enum import Enum
5
- from typing import Callable, Dict, List, Union
6
-
7
- class Aggregation(Enum):
8
- VALUE = ...
9
- MEAN = ...
10
- COUNT = ...
11
- SUM = ...
12
- MAX = ...
13
- MIN = ...
14
-
15
- class Stat:
16
- name: str
17
- count: int
18
- def __init__(
19
- self,
20
- name: str,
21
- aggregations: List[Aggregation],
22
- window_size: int,
23
- max_samples: int = -1,
24
- ) -> None: ...
25
- def add(self, v: float) -> None: ...
26
- def get(self) -> Dict[Aggregation, float]: ...
27
-
28
- class Event:
29
- name: str
30
- timestamp: datetime.datetime
31
- data: Dict[str, Union[int, float, bool, str]]
32
- def __init__(
33
- self,
34
- name: str,
35
- timestamp: datetime.datetime,
36
- data: Dict[str, Union[int, float, bool, str]],
37
- ) -> None: ...
38
-
39
- def log_event(e: Event) -> None: ...
40
-
41
- class EventHandlerHandle: ...
42
-
43
- def register_event_handler(handler: Callable[[Event], None]) -> EventHandlerHandle: ...
44
- def unregister_event_handler(handle: EventHandlerHandle) -> None: ...
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
torch/_C/_nn.pyi DELETED
@@ -1,86 +0,0 @@
1
- # mypy: disable-error-code="type-arg"
2
- from typing import List, Optional, overload, Sequence, Tuple, Union
3
-
4
- from torch import memory_format, Tensor
5
- from torch.types import _bool, _device, _dtype, _int, _size
6
-
7
- # Defined in tools/autograd/templates/python_nn_functions.cpp
8
-
9
- def adaptive_max_pool2d(input: Tensor, output_size: Union[_int, _size]) -> Tuple[Tensor, Tensor]: ...
10
- def adaptive_max_pool3d(input: Tensor, output_size: Union[_int, _size]) -> Tuple[Tensor, Tensor]: ...
11
- def avg_pool2d(input: Tensor, kernel_size: Union[_int, _size], stride: Optional[Union[_int, _size]] = None, padding: Union[_int, _size] = 0, ceil_mode: bool = False, count_include_pad: bool = True, divisor_override: Optional[int] = None) -> Tensor: ...
12
- def avg_pool3d(input: Tensor, kernel_size: Union[_int, _size], stride: Optional[Union[_int, _size]] = None, padding: Union[_int, _size] = 0, ceil_mode: bool = False, count_include_pad: bool = True, divisor_override: Optional[int] = None) -> Tensor: ...
13
- def elu_(input: Tensor, alpha: float = ...) -> Tensor: ...
14
- def fractional_max_pool2d(input: Tensor, kernel_size: Union[_int, _size], output_size: Union[_int, _size], _random_samples: Tensor) -> Tuple[Tensor, Tensor]: ...
15
- def fractional_max_pool3d(input: Tensor, kernel_size: Union[_int, _size], output_size: Union[_int, _size], _random_samples: Tensor) -> Tuple[Tensor, Tensor]: ...
16
- def gelu(input: Tensor, approximate: str = ...) -> Tensor: ...
17
- def hardsigmoid(input: Tensor, *, out: Optional[Tensor] = None) -> Tensor: ...
18
- def hardtanh(input: Tensor, min_val: float = ..., max_val: float = ..., *, out: Optional[Tensor] = None) -> Tensor: ...
19
- def hardtanh_(input: Tensor, min_val: float = ..., max_val: float = ...) -> Tensor: ...
20
- def leaky_relu(input: Tensor, negative_slope: float = ..., *, out: Optional[Tensor] = None) -> Tensor: ...
21
- def leaky_relu_(input: Tensor, negative_slope: float = ...) -> Tensor: ...
22
- def linear(input: Tensor, weight: Tensor, bias: Optional[Tensor] = None) -> Tensor: ...
23
- def log_sigmoid(input: Tensor) -> Tensor: ...
24
- def one_hot(tensor: Tensor, num_classes: int = ...) -> Tensor: ...
25
- def pad(input: Tensor, pad: Sequence[int], mode: str = ..., value: Optional[float] = None) -> Tensor: ...
26
- def scaled_dot_product_attention(query: Tensor, key: Tensor, value: Tensor, attn_mask: Optional[Tensor] = None, dropout_p: float = 0.0, is_causal: bool = False, scale: Optional[float] = None) -> Tensor: ...
27
- def softplus(input: Tensor, beta: int = ..., threshold: int = ...) -> Tensor: ...
28
- def softshrink(input: Tensor, lambd: float = ...) -> Tensor: ...
29
-
30
- # Defined in aten/src/ATen/native/mkldnn/Linear.cpp
31
- def mkldnn_linear(input: Tensor, weight: Tensor, bias: Optional[Tensor]) -> Tensor: ...
32
-
33
- # Defined at aten/src/ATen/native/mkldnn/MKLDNNConversions.cpp
34
- def mkldnn_reorder_conv2d_weight(
35
- self: Tensor,
36
- padding: List,
37
- stride: List,
38
- dilatation: List,
39
- groups: int,
40
- ) -> Tensor: ...
41
- def mkldnn_reorder_conv3d_weight(
42
- self: Tensor,
43
- padding: List,
44
- stride: List,
45
- dilatation: List,
46
- groups: int,
47
- ) -> Tensor: ...
48
-
49
- # Defined in aten/src/ATen/native/mkldnn/Prelu.cpp
50
- def mkldnn_prelu(input: Tensor, weight: Tensor) -> Tensor: ...
51
-
52
- # Defined at tools/autograd/templates/python_nn_functions.cpp
53
- @overload
54
- def _parse_to(
55
- device: _device,
56
- dtype: _dtype,
57
- non_blocking: _bool,
58
- copy: _bool,
59
- *,
60
- memory_format: memory_format,
61
- ) -> Tuple[_device, _dtype, _bool, memory_format]: ...
62
- @overload
63
- def _parse_to(
64
- dtype: _dtype,
65
- non_blocking: _bool,
66
- copy: _bool,
67
- *,
68
- memory_format: memory_format,
69
- ) -> Tuple[_device, _dtype, _bool, memory_format]: ...
70
- @overload
71
- def _parse_to(
72
- tensor: Tensor,
73
- non_blocking: _bool,
74
- copy: _bool,
75
- *,
76
- memory_format: memory_format,
77
- ) -> Tuple[_device, _dtype, _bool, memory_format]: ...
78
-
79
- # Defined in aten/src/ATen/native/PadSequence.cpp
80
- def pad_sequence(
81
- sequences: List[Tensor],
82
- batch_first: bool = False,
83
- padding_value: float = ...,
84
- ) -> Tensor: ...
85
- def flatten_dense_tensors(tensors: List[Tensor]) -> Tensor: ...
86
- def unflatten_dense_tensors(flat: Tensor, tensors: List[Tensor]) -> List[Tensor]: ...
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
torch/_C/_nvtx.pyi DELETED
@@ -1,6 +0,0 @@
1
- # Defined in torch/csrc/cuda/shared/nvtx.cpp
2
- def rangePushA(message: str) -> int: ...
3
- def rangePop() -> int: ...
4
- def rangeStartA(message: str) -> int: ...
5
- def rangeEnd(int) -> None: ...
6
- def markA(message: str) -> None: ...
 
 
 
 
 
 
 
torch/_C/_onnx.pyi DELETED
@@ -1,38 +0,0 @@
1
- # Defined in torch/csrc/onnx/init.cpp
2
-
3
- from enum import Enum
4
-
5
- _CAFFE2_ATEN_FALLBACK: bool
6
- PRODUCER_VERSION: str
7
-
8
- class TensorProtoDataType(Enum):
9
- UNDEFINED = ...
10
- FLOAT = ...
11
- UINT8 = ...
12
- INT8 = ...
13
- UINT16 = ...
14
- INT16 = ...
15
- INT32 = ...
16
- INT64 = ...
17
- STRING = ...
18
- BOOL = ...
19
- FLOAT16 = ...
20
- DOUBLE = ...
21
- UINT32 = ...
22
- UINT64 = ...
23
- COMPLEX64 = ...
24
- COMPLEX128 = ...
25
- BFLOAT16 = ...
26
- FLOAT8E5M2 = ...
27
- FLOAT8E4M3FN = ...
28
-
29
- class OperatorExportTypes(Enum):
30
- ONNX = ...
31
- ONNX_ATEN = ...
32
- ONNX_ATEN_FALLBACK = ...
33
- ONNX_FALLTHROUGH = ...
34
-
35
- class TrainingMode(Enum):
36
- EVAL = ...
37
- PRESERVE = ...
38
- TRAINING = ...
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
torch/_C/_profiler.pyi DELETED
@@ -1,238 +0,0 @@
1
- from enum import Enum
2
- from typing import Any, Dict, List, Literal, Optional, Tuple, Union
3
-
4
- from torch._C import device, dtype, layout
5
- from typing_extensions import TypeAlias
6
-
7
- # defined in torch/csrc/profiler/python/init.cpp
8
-
9
- class RecordScope(Enum):
10
- FUNCTION = ...
11
- BACKWARD_FUNCTION = ...
12
- TORCHSCRIPT_FUNCTION = ...
13
- KERNEL_FUNCTION_DTYPE = ...
14
- CUSTOM_CLASS = ...
15
- BUILD_FEATURE = ...
16
- LITE_INTERPRETER = ...
17
- USER_SCOPE = ...
18
- STATIC_RUNTIME_OP = ...
19
- STATIC_RUNTIME_MODEL = ...
20
-
21
- class ProfilerState(Enum):
22
- Disable = ...
23
- CPU = ...
24
- CUDA = ...
25
- NVTX = ...
26
- ITT = ...
27
- KINETO = ...
28
- KINETO_GPU_FALLBACK = ...
29
- KINETO_PRIVATEUSE1_FALLBACK = ...
30
- KINETO_PRIVATEUSE1 = ...
31
-
32
- class ActiveProfilerType(Enum):
33
- NONE = ...
34
- LEGACY = ...
35
- KINETO = ...
36
- NVTX = ...
37
- ITT = ...
38
-
39
- class ProfilerActivity(Enum):
40
- CPU = ...
41
- CUDA = ...
42
- MTIA = ...
43
- PrivateUse1 = ...
44
-
45
- class _EventType(Enum):
46
- TorchOp = ...
47
- Backend = ...
48
- Allocation = ...
49
- OutOfMemory = ...
50
- PyCall = ...
51
- PyCCall = ...
52
- Kineto = ...
53
-
54
- class _ExperimentalConfig:
55
- def __init__(
56
- self,
57
- profiler_metrics: List[str] = ...,
58
- profiler_measure_per_kernel: bool = ...,
59
- verbose: bool = ...,
60
- performance_events: List[str] = ...,
61
- enable_cuda_sync_events: bool = ...,
62
- ) -> None: ...
63
-
64
- class ProfilerConfig:
65
- def __init__(
66
- self,
67
- state: ProfilerState,
68
- report_input_shapes: bool,
69
- profile_memory: bool,
70
- with_stack: bool,
71
- with_flops: bool,
72
- with_modules: bool,
73
- experimental_config: _ExperimentalConfig,
74
- ) -> None: ...
75
-
76
- class _ProfilerEvent:
77
- start_tid: int
78
- start_time_ns: int
79
- children: List[_ProfilerEvent]
80
-
81
- # TODO(robieta): remove in favor of `self.typed`
82
- extra_fields: Union[
83
- _ExtraFields_TorchOp,
84
- _ExtraFields_Backend,
85
- _ExtraFields_Allocation,
86
- _ExtraFields_OutOfMemory,
87
- _ExtraFields_PyCall,
88
- _ExtraFields_PyCCall,
89
- _ExtraFields_Kineto,
90
- ]
91
-
92
- @property
93
- def typed(
94
- self,
95
- ) -> Union[
96
- Tuple[Literal[_EventType.TorchOp], _ExtraFields_TorchOp],
97
- Tuple[Literal[_EventType.Backend], _ExtraFields_Backend],
98
- Tuple[Literal[_EventType.Allocation], _ExtraFields_Allocation],
99
- Tuple[Literal[_EventType.OutOfMemory], _ExtraFields_OutOfMemory],
100
- Tuple[Literal[_EventType.PyCall], _ExtraFields_PyCall],
101
- Tuple[Literal[_EventType.PyCCall], _ExtraFields_PyCCall],
102
- Tuple[Literal[_EventType.Kineto], _ExtraFields_Kineto],
103
- ]: ...
104
- @property
105
- def name(self) -> str: ...
106
- @property
107
- def tag(self) -> _EventType: ...
108
- @property
109
- def id(self) -> int: ...
110
- @property
111
- def parent(self) -> Optional[_ProfilerEvent]: ...
112
- @property
113
- def correlation_id(self) -> int: ...
114
- @property
115
- def end_time_ns(self) -> int: ...
116
- @property
117
- def duration_time_ns(self) -> int: ...
118
-
119
- class _TensorMetadata:
120
- impl_ptr: Optional[int]
121
- storage_data_ptr: Optional[int]
122
- id: Optional[int]
123
-
124
- @property
125
- def allocation_id(self) -> Optional[int]: ...
126
- @property
127
- def layout(self) -> layout: ...
128
- @property
129
- def device(self) -> device: ...
130
- @property
131
- def dtype(self) -> dtype: ...
132
- @property
133
- def sizes(self) -> List[int]: ...
134
- @property
135
- def strides(self) -> List[int]: ...
136
-
137
- Scalar: TypeAlias = Union[int, float, bool, complex]
138
- Input: TypeAlias = Optional[Union[_TensorMetadata, List[_TensorMetadata], Scalar]]
139
-
140
- class _ExtraFields_TorchOp:
141
- name: str
142
- sequence_number: int
143
- allow_tf32_cublas: bool
144
-
145
- @property
146
- def inputs(self) -> List[Input]: ...
147
- @property
148
- def scope(self) -> RecordScope: ...
149
-
150
- class _ExtraFields_Backend: ...
151
-
152
- class _ExtraFields_Allocation:
153
- ptr: int
154
- id: Optional[int]
155
- alloc_size: int
156
- total_allocated: int
157
- total_reserved: int
158
-
159
- @property
160
- def allocation_id(self) -> Optional[int]: ...
161
- @property
162
- def device(self) -> device: ...
163
-
164
- class _ExtraFields_OutOfMemory: ...
165
-
166
- class _PyFrameState:
167
- line_number: int
168
- function_name: str
169
-
170
- @property
171
- def file_name(self) -> str: ...
172
-
173
- class _NNModuleInfo:
174
- @property
175
- def self_ptr(self) -> int: ...
176
- @property
177
- def cls_ptr(self) -> int: ...
178
- @property
179
- def cls_name(self) -> str: ...
180
- @property
181
- def parameters(
182
- self,
183
- ) -> List[Tuple[str, _TensorMetadata, Optional[_TensorMetadata]]]: ...
184
-
185
- class _OptimizerInfo:
186
- @property
187
- def parameters(
188
- self,
189
- ) -> List[
190
- Tuple[
191
- # Parameter
192
- _TensorMetadata,
193
- #
194
- # Gradient (if present during optimizer.step())
195
- Optional[_TensorMetadata],
196
- #
197
- # Optimizer state for Parameter as (name, tensor) pairs
198
- List[Tuple[str, _TensorMetadata]],
199
- ]
200
- ]: ...
201
-
202
- class _ExtraFields_PyCCall:
203
- @property
204
- def caller(self) -> _PyFrameState: ...
205
-
206
- class _ExtraFields_PyCall:
207
- @property
208
- def callsite(self) -> _PyFrameState: ...
209
- @property
210
- def caller(self) -> _PyFrameState: ...
211
- @property
212
- def module(self) -> Optional[_NNModuleInfo]: ...
213
- @property
214
- def optimizer(self) -> Optional[_OptimizerInfo]: ...
215
-
216
- class _ExtraFields_Kineto: ...
217
-
218
- def _add_execution_trace_observer(output_file_path: str) -> bool: ...
219
- def _remove_execution_trace_observer() -> None: ...
220
- def _enable_execution_trace_observer() -> None: ...
221
- def _disable_execution_trace_observer() -> None: ...
222
- def _set_record_concrete_inputs_enabled_val(val: bool) -> None: ...
223
- def _set_fwd_bwd_enabled_val(val: bool) -> None: ...
224
- def _set_cuda_sync_enabled_val(val: bool) -> None: ...
225
-
226
- class CapturedTraceback: ...
227
-
228
- def gather_traceback(python: bool, script: bool, cpp: bool) -> CapturedTraceback: ...
229
-
230
- # The Dict has name, filename, line
231
- def symbolize_tracebacks(
232
- to_symbolize: List[CapturedTraceback],
233
- ) -> List[List[Dict[str, str]]]: ...
234
-
235
- class _RecordFunctionFast:
236
- def __init__(self, name: str) -> None: ...
237
- def __enter__(self) -> None: ...
238
- def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None: ...
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
torch/_C/_verbose.pyi DELETED
@@ -1,3 +0,0 @@
1
- # Defined in torch/csrc/utils/verbose.cpp
2
- def mkl_set_verbose(enable: int) -> int: ...
3
- def mkldnn_set_verbose(level: int) -> int: ...
 
 
 
 
torch/_VF.py DELETED
@@ -1,30 +0,0 @@
1
- """
2
- This makes the functions in torch._C._VariableFunctions available as
3
- torch._VF.<funcname>
4
- without mypy being able to find them.
5
-
6
- A subset of those functions are mapped to ATen functions in
7
- torch/jit/_builtins.py
8
-
9
- See https://github.com/pytorch/pytorch/issues/21478 for the reason for
10
- introducing torch._VF
11
-
12
- """
13
- import sys
14
- import types
15
-
16
- import torch
17
-
18
-
19
- class VFModule(types.ModuleType):
20
- vf: types.ModuleType
21
-
22
- def __init__(self, name):
23
- super().__init__(name)
24
- self.vf = torch._C._VariableFunctions
25
-
26
- def __getattr__(self, attr):
27
- return getattr(self.vf, attr)
28
-
29
-
30
- sys.modules[__name__] = VFModule(__name__)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
torch/_VF.pyi DELETED
The diff for this file is too large to render. See raw diff
 
torch/__config__.py DELETED
@@ -1,22 +0,0 @@
1
- import torch
2
-
3
-
4
- def show():
5
- """
6
- Return a human-readable string with descriptions of the
7
- configuration of PyTorch.
8
- """
9
- return torch._C._show_config()
10
-
11
-
12
- # TODO: In principle, we could provide more structured version/config
13
- # information here. For now only CXX_FLAGS is exposed, as Timer
14
- # uses them.
15
- def _cxx_flags():
16
- """Returns the CXX_FLAGS used when building PyTorch."""
17
- return torch._C._cxx_flags()
18
-
19
-
20
- def parallel_info():
21
- r"""Returns detailed string with parallelization settings"""
22
- return torch._C._parallel_info()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
torch/__future__.py DELETED
@@ -1,21 +0,0 @@
1
- """
2
- This global flag controls whether to assign new tensors to the parameters
3
- instead of changing the existing parameters in-place when converting an `nn.Module`
4
- using the following methods:
5
- 1. `module.cuda()` / `.cpu()` (for moving `module` between devices)
6
- 2. `module.float()` / `.double()` / `.half()` (for converting `module` to a different dtype)
7
- 3. `module.to()` / `.type()` (for changing `module`'s device or dtype)
8
- 4. `module._apply(fn)` (for generic functions applied to `module`)
9
-
10
- Default: False
11
- """
12
- _overwrite_module_params_on_conversion = False
13
-
14
-
15
- def set_overwrite_module_params_on_conversion(value):
16
- global _overwrite_module_params_on_conversion
17
- _overwrite_module_params_on_conversion = value
18
-
19
-
20
- def get_overwrite_module_params_on_conversion():
21
- return _overwrite_module_params_on_conversion
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
torch/_appdirs.py DELETED
@@ -1,666 +0,0 @@
1
- #!/usr/bin/env python3
2
- # -*- coding: utf-8 -*-
3
- # Copyright (c) 2005-2010 ActiveState Software Inc.
4
- # Copyright (c) 2013 Eddy Petrișor
5
-
6
- # flake8: noqa
7
-
8
- """
9
- This file is directly from
10
- https://github.com/ActiveState/appdirs/blob/3fe6a83776843a46f20c2e5587afcffe05e03b39/appdirs.py
11
-
12
- The license of https://github.com/ActiveState/appdirs copied below:
13
-
14
-
15
- # This is the MIT license
16
-
17
- Copyright (c) 2010 ActiveState Software Inc.
18
-
19
- Permission is hereby granted, free of charge, to any person obtaining a
20
- copy of this software and associated documentation files (the
21
- "Software"), to deal in the Software without restriction, including
22
- without limitation the rights to use, copy, modify, merge, publish,
23
- distribute, sublicense, and/or sell copies of the Software, and to
24
- permit persons to whom the Software is furnished to do so, subject to
25
- the following conditions:
26
-
27
- The above copyright notice and this permission notice shall be included
28
- in all copies or substantial portions of the Software.
29
-
30
- THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
31
- OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
32
- MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
33
- IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
34
- CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
35
- TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
36
- SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
37
- """
38
-
39
- """Utilities for determining application-specific dirs.
40
-
41
- See <https://github.com/ActiveState/appdirs> for details and usage.
42
- """
43
- # Dev Notes:
44
- # - MSDN on where to store app data files:
45
- # http://support.microsoft.com/default.aspx?scid=kb;en-us;310294#XSLTH3194121123120121120120
46
- # - Mac OS X: http://developer.apple.com/documentation/MacOSX/Conceptual/BPFileSystem/index.html
47
- # - XDG spec for Un*x: https://standards.freedesktop.org/basedir-spec/basedir-spec-latest.html
48
-
49
- __version__ = "1.4.4"
50
- __version_info__ = tuple(int(segment) for segment in __version__.split("."))
51
-
52
-
53
- import os
54
- import sys
55
-
56
- unicode = str
57
-
58
- if sys.platform.startswith("java"):
59
- import platform
60
-
61
- os_name = platform.java_ver()[3][0]
62
- if os_name.startswith("Windows"): # "Windows XP", "Windows 7", etc.
63
- system = "win32"
64
- elif os_name.startswith("Mac"): # "Mac OS X", etc.
65
- system = "darwin"
66
- else: # "Linux", "SunOS", "FreeBSD", etc.
67
- # Setting this to "linux2" is not ideal, but only Windows or Mac
68
- # are actually checked for and the rest of the module expects
69
- # *sys.platform* style strings.
70
- system = "linux2"
71
- else:
72
- system = sys.platform
73
-
74
-
75
- def user_data_dir(appname=None, appauthor=None, version=None, roaming=False):
76
- r"""Return full path to the user-specific data dir for this application.
77
-
78
- "appname" is the name of application.
79
- If None, just the system directory is returned.
80
- "appauthor" (only used on Windows) is the name of the
81
- appauthor or distributing body for this application. Typically
82
- it is the owning company name. This falls back to appname. You may
83
- pass False to disable it.
84
- "version" is an optional version path element to append to the
85
- path. You might want to use this if you want multiple versions
86
- of your app to be able to run independently. If used, this
87
- would typically be "<major>.<minor>".
88
- Only applied when appname is present.
89
- "roaming" (boolean, default False) can be set True to use the Windows
90
- roaming appdata directory. That means that for users on a Windows
91
- network setup for roaming profiles, this user data will be
92
- sync'd on login. See
93
- <http://technet.microsoft.com/en-us/library/cc766489(WS.10).aspx>
94
- for a discussion of issues.
95
-
96
- Typical user data directories are:
97
- Mac OS X: ~/Library/Application Support/<AppName>
98
- Unix: ~/.local/share/<AppName> # or in $XDG_DATA_HOME, if defined
99
- Win XP (not roaming): C:\Documents and Settings\<username>\Application Data\<AppAuthor>\<AppName>
100
- Win XP (roaming): C:\Documents and Settings\<username>\Local Settings\Application Data\<AppAuthor>\<AppName>
101
- Win 7 (not roaming): C:\Users\<username>\AppData\Local\<AppAuthor>\<AppName>
102
- Win 7 (roaming): C:\Users\<username>\AppData\Roaming\<AppAuthor>\<AppName>
103
-
104
- For Unix, we follow the XDG spec and support $XDG_DATA_HOME.
105
- That means, by default "~/.local/share/<AppName>".
106
- """
107
- if system == "win32":
108
- if appauthor is None:
109
- appauthor = appname
110
- const = roaming and "CSIDL_APPDATA" or "CSIDL_LOCAL_APPDATA"
111
- path = os.path.normpath(_get_win_folder(const))
112
- if appname:
113
- if appauthor is not False:
114
- path = os.path.join(path, appauthor, appname)
115
- else:
116
- path = os.path.join(path, appname)
117
- elif system == "darwin":
118
- path = os.path.expanduser("~/Library/Application Support/")
119
- if appname:
120
- path = os.path.join(path, appname)
121
- else:
122
- path = os.getenv("XDG_DATA_HOME", os.path.expanduser("~/.local/share"))
123
- if appname:
124
- path = os.path.join(path, appname)
125
- if appname and version:
126
- path = os.path.join(path, version)
127
- return path
128
-
129
-
130
- def site_data_dir(appname=None, appauthor=None, version=None, multipath=False):
131
- r"""Return full path to the user-shared data dir for this application.
132
-
133
- "appname" is the name of application.
134
- If None, just the system directory is returned.
135
- "appauthor" (only used on Windows) is the name of the
136
- appauthor or distributing body for this application. Typically
137
- it is the owning company name. This falls back to appname. You may
138
- pass False to disable it.
139
- "version" is an optional version path element to append to the
140
- path. You might want to use this if you want multiple versions
141
- of your app to be able to run independently. If used, this
142
- would typically be "<major>.<minor>".
143
- Only applied when appname is present.
144
- "multipath" is an optional parameter only applicable to *nix
145
- which indicates that the entire list of data dirs should be
146
- returned. By default, the first item from XDG_DATA_DIRS is
147
- returned, or '/usr/local/share/<AppName>',
148
- if XDG_DATA_DIRS is not set
149
-
150
- Typical site data directories are:
151
- Mac OS X: /Library/Application Support/<AppName>
152
- Unix: /usr/local/share/<AppName> or /usr/share/<AppName>
153
- Win XP: C:\Documents and Settings\All Users\Application Data\<AppAuthor>\<AppName>
154
- Vista: (Fail! "C:\ProgramData" is a hidden *system* directory on Vista.)
155
- Win 7: C:\ProgramData\<AppAuthor>\<AppName> # Hidden, but writeable on Win 7.
156
-
157
- For Unix, this is using the $XDG_DATA_DIRS[0] default.
158
-
159
- WARNING: Do not use this on Windows. See the Vista-Fail note above for why.
160
- """
161
- if system == "win32":
162
- if appauthor is None:
163
- appauthor = appname
164
- path = os.path.normpath(_get_win_folder("CSIDL_COMMON_APPDATA"))
165
- if appname:
166
- if appauthor is not False:
167
- path = os.path.join(path, appauthor, appname)
168
- else:
169
- path = os.path.join(path, appname)
170
- elif system == "darwin":
171
- path = os.path.expanduser("/Library/Application Support")
172
- if appname:
173
- path = os.path.join(path, appname)
174
- else:
175
- # XDG default for $XDG_DATA_DIRS
176
- # only first, if multipath is False
177
- path = os.getenv(
178
- "XDG_DATA_DIRS", os.pathsep.join(["/usr/local/share", "/usr/share"])
179
- )
180
- pathlist = [
181
- os.path.expanduser(x.rstrip(os.sep)) for x in path.split(os.pathsep)
182
- ]
183
- if appname:
184
- if version:
185
- appname = os.path.join(appname, version)
186
- pathlist = [os.sep.join([x, appname]) for x in pathlist]
187
-
188
- if multipath:
189
- path = os.pathsep.join(pathlist)
190
- else:
191
- path = pathlist[0]
192
- return path
193
-
194
- if appname and version:
195
- path = os.path.join(path, version)
196
- return path
197
-
198
-
199
- def user_config_dir(appname=None, appauthor=None, version=None, roaming=False):
200
- r"""Return full path to the user-specific config dir for this application.
201
-
202
- "appname" is the name of application.
203
- If None, just the system directory is returned.
204
- "appauthor" (only used on Windows) is the name of the
205
- appauthor or distributing body for this application. Typically
206
- it is the owning company name. This falls back to appname. You may
207
- pass False to disable it.
208
- "version" is an optional version path element to append to the
209
- path. You might want to use this if you want multiple versions
210
- of your app to be able to run independently. If used, this
211
- would typically be "<major>.<minor>".
212
- Only applied when appname is present.
213
- "roaming" (boolean, default False) can be set True to use the Windows
214
- roaming appdata directory. That means that for users on a Windows
215
- network setup for roaming profiles, this user data will be
216
- sync'd on login. See
217
- <http://technet.microsoft.com/en-us/library/cc766489(WS.10).aspx>
218
- for a discussion of issues.
219
-
220
- Typical user config directories are:
221
- Mac OS X: ~/Library/Preferences/<AppName>
222
- Unix: ~/.config/<AppName> # or in $XDG_CONFIG_HOME, if defined
223
- Win *: same as user_data_dir
224
-
225
- For Unix, we follow the XDG spec and support $XDG_CONFIG_HOME.
226
- That means, by default "~/.config/<AppName>".
227
- """
228
- if system == "win32":
229
- path = user_data_dir(appname, appauthor, None, roaming)
230
- elif system == "darwin":
231
- path = os.path.expanduser("~/Library/Preferences/")
232
- if appname:
233
- path = os.path.join(path, appname)
234
- else:
235
- path = os.getenv("XDG_CONFIG_HOME", os.path.expanduser("~/.config"))
236
- if appname:
237
- path = os.path.join(path, appname)
238
- if appname and version:
239
- path = os.path.join(path, version)
240
- return path
241
-
242
-
243
- def site_config_dir(appname=None, appauthor=None, version=None, multipath=False):
244
- r"""Return full path to the user-shared data dir for this application.
245
-
246
- "appname" is the name of application.
247
- If None, just the system directory is returned.
248
- "appauthor" (only used on Windows) is the name of the
249
- appauthor or distributing body for this application. Typically
250
- it is the owning company name. This falls back to appname. You may
251
- pass False to disable it.
252
- "version" is an optional version path element to append to the
253
- path. You might want to use this if you want multiple versions
254
- of your app to be able to run independently. If used, this
255
- would typically be "<major>.<minor>".
256
- Only applied when appname is present.
257
- "multipath" is an optional parameter only applicable to *nix
258
- which indicates that the entire list of config dirs should be
259
- returned. By default, the first item from XDG_CONFIG_DIRS is
260
- returned, or '/etc/xdg/<AppName>', if XDG_CONFIG_DIRS is not set
261
-
262
- Typical site config directories are:
263
- Mac OS X: same as site_data_dir
264
- Unix: /etc/xdg/<AppName> or $XDG_CONFIG_DIRS[i]/<AppName> for each value in
265
- $XDG_CONFIG_DIRS
266
- Win *: same as site_data_dir
267
- Vista: (Fail! "C:\ProgramData" is a hidden *system* directory on Vista.)
268
-
269
- For Unix, this is using the $XDG_CONFIG_DIRS[0] default, if multipath=False
270
-
271
- WARNING: Do not use this on Windows. See the Vista-Fail note above for why.
272
- """
273
- if system == "win32":
274
- path = site_data_dir(appname, appauthor)
275
- if appname and version:
276
- path = os.path.join(path, version)
277
- elif system == "darwin":
278
- path = os.path.expanduser("/Library/Preferences")
279
- if appname:
280
- path = os.path.join(path, appname)
281
- else:
282
- # XDG default for $XDG_CONFIG_DIRS
283
- # only first, if multipath is False
284
- path = os.getenv("XDG_CONFIG_DIRS", "/etc/xdg")
285
- pathlist = [
286
- os.path.expanduser(x.rstrip(os.sep)) for x in path.split(os.pathsep)
287
- ]
288
- if appname:
289
- if version:
290
- appname = os.path.join(appname, version)
291
- pathlist = [os.sep.join([x, appname]) for x in pathlist]
292
-
293
- if multipath:
294
- path = os.pathsep.join(pathlist)
295
- else:
296
- path = pathlist[0]
297
- return path
298
-
299
-
300
- def user_cache_dir(appname=None, appauthor=None, version=None, opinion=True):
301
- r"""Return full path to the user-specific cache dir for this application.
302
-
303
- "appname" is the name of application.
304
- If None, just the system directory is returned.
305
- "appauthor" (only used on Windows) is the name of the
306
- appauthor or distributing body for this application. Typically
307
- it is the owning company name. This falls back to appname. You may
308
- pass False to disable it.
309
- "version" is an optional version path element to append to the
310
- path. You might want to use this if you want multiple versions
311
- of your app to be able to run independently. If used, this
312
- would typically be "<major>.<minor>".
313
- Only applied when appname is present.
314
- "opinion" (boolean) can be False to disable the appending of
315
- "Cache" to the base app data dir for Windows. See
316
- discussion below.
317
-
318
- Typical user cache directories are:
319
- Mac OS X: ~/Library/Caches/<AppName>
320
- Unix: ~/.cache/<AppName> (XDG default)
321
- Win XP: C:\Documents and Settings\<username>\Local Settings\Application Data\<AppAuthor>\<AppName>\Cache
322
- Vista: C:\Users\<username>\AppData\Local\<AppAuthor>\<AppName>\Cache
323
-
324
- On Windows the only suggestion in the MSDN docs is that local settings go in
325
- the `CSIDL_LOCAL_APPDATA` directory. This is identical to the non-roaming
326
- app data dir (the default returned by `user_data_dir` above). Apps typically
327
- put cache data somewhere *under* the given dir here. Some examples:
328
- ...\Mozilla\Firefox\Profiles\<ProfileName>\Cache
329
- ...\Acme\SuperApp\Cache\1.0
330
- OPINION: This function appends "Cache" to the `CSIDL_LOCAL_APPDATA` value.
331
- This can be disabled with the `opinion=False` option.
332
- """
333
- if system == "win32":
334
- if appauthor is None:
335
- appauthor = appname
336
- path = os.path.normpath(_get_win_folder("CSIDL_LOCAL_APPDATA"))
337
- if appname:
338
- if appauthor is not False:
339
- path = os.path.join(path, appauthor, appname)
340
- else:
341
- path = os.path.join(path, appname)
342
- if opinion:
343
- path = os.path.join(path, "Cache")
344
- elif system == "darwin":
345
- path = os.path.expanduser("~/Library/Caches")
346
- if appname:
347
- path = os.path.join(path, appname)
348
- else:
349
- path = os.getenv("XDG_CACHE_HOME", os.path.expanduser("~/.cache"))
350
- if appname:
351
- path = os.path.join(path, appname)
352
- if appname and version:
353
- path = os.path.join(path, version)
354
- return path
355
-
356
-
357
- def user_state_dir(appname=None, appauthor=None, version=None, roaming=False):
358
- r"""Return full path to the user-specific state dir for this application.
359
-
360
- "appname" is the name of application.
361
- If None, just the system directory is returned.
362
- "appauthor" (only used on Windows) is the name of the
363
- appauthor or distributing body for this application. Typically
364
- it is the owning company name. This falls back to appname. You may
365
- pass False to disable it.
366
- "version" is an optional version path element to append to the
367
- path. You might want to use this if you want multiple versions
368
- of your app to be able to run independently. If used, this
369
- would typically be "<major>.<minor>".
370
- Only applied when appname is present.
371
- "roaming" (boolean, default False) can be set True to use the Windows
372
- roaming appdata directory. That means that for users on a Windows
373
- network setup for roaming profiles, this user data will be
374
- sync'd on login. See
375
- <http://technet.microsoft.com/en-us/library/cc766489(WS.10).aspx>
376
- for a discussion of issues.
377
-
378
- Typical user state directories are:
379
- Mac OS X: same as user_data_dir
380
- Unix: ~/.local/state/<AppName> # or in $XDG_STATE_HOME, if defined
381
- Win *: same as user_data_dir
382
-
383
- For Unix, we follow this Debian proposal <https://wiki.debian.org/XDGBaseDirectorySpecification#state>
384
- to extend the XDG spec and support $XDG_STATE_HOME.
385
-
386
- That means, by default "~/.local/state/<AppName>".
387
- """
388
- if system in ["win32", "darwin"]:
389
- path = user_data_dir(appname, appauthor, None, roaming)
390
- else:
391
- path = os.getenv("XDG_STATE_HOME", os.path.expanduser("~/.local/state"))
392
- if appname:
393
- path = os.path.join(path, appname)
394
- if appname and version:
395
- path = os.path.join(path, version)
396
- return path
397
-
398
-
399
- def user_log_dir(appname=None, appauthor=None, version=None, opinion=True):
400
- r"""Return full path to the user-specific log dir for this application.
401
-
402
- "appname" is the name of application.
403
- If None, just the system directory is returned.
404
- "appauthor" (only used on Windows) is the name of the
405
- appauthor or distributing body for this application. Typically
406
- it is the owning company name. This falls back to appname. You may
407
- pass False to disable it.
408
- "version" is an optional version path element to append to the
409
- path. You might want to use this if you want multiple versions
410
- of your app to be able to run independently. If used, this
411
- would typically be "<major>.<minor>".
412
- Only applied when appname is present.
413
- "opinion" (boolean) can be False to disable the appending of
414
- "Logs" to the base app data dir for Windows, and "log" to the
415
- base cache dir for Unix. See discussion below.
416
-
417
- Typical user log directories are:
418
- Mac OS X: ~/Library/Logs/<AppName>
419
- Unix: ~/.cache/<AppName>/log # or under $XDG_CACHE_HOME if defined
420
- Win XP: C:\Documents and Settings\<username>\Local Settings\Application Data\<AppAuthor>\<AppName>\Logs
421
- Vista: C:\Users\<username>\AppData\Local\<AppAuthor>\<AppName>\Logs
422
-
423
- On Windows the only suggestion in the MSDN docs is that local settings
424
- go in the `CSIDL_LOCAL_APPDATA` directory. (Note: I'm interested in
425
- examples of what some windows apps use for a logs dir.)
426
-
427
- OPINION: This function appends "Logs" to the `CSIDL_LOCAL_APPDATA`
428
- value for Windows and appends "log" to the user cache dir for Unix.
429
- This can be disabled with the `opinion=False` option.
430
- """
431
- if system == "darwin":
432
- path = os.path.join(os.path.expanduser("~/Library/Logs"), appname)
433
- elif system == "win32":
434
- path = user_data_dir(appname, appauthor, version)
435
- version = False
436
- if opinion:
437
- path = os.path.join(path, "Logs")
438
- else:
439
- path = user_cache_dir(appname, appauthor, version)
440
- version = False
441
- if opinion:
442
- path = os.path.join(path, "log")
443
- if appname and version:
444
- path = os.path.join(path, version)
445
- return path
446
-
447
-
448
- class AppDirs(object):
449
- """Convenience wrapper for getting application dirs."""
450
-
451
- def __init__(
452
- self, appname=None, appauthor=None, version=None, roaming=False, multipath=False
453
- ):
454
- self.appname = appname
455
- self.appauthor = appauthor
456
- self.version = version
457
- self.roaming = roaming
458
- self.multipath = multipath
459
-
460
- @property
461
- def user_data_dir(self):
462
- return user_data_dir(
463
- self.appname, self.appauthor, version=self.version, roaming=self.roaming
464
- )
465
-
466
- @property
467
- def site_data_dir(self):
468
- return site_data_dir(
469
- self.appname, self.appauthor, version=self.version, multipath=self.multipath
470
- )
471
-
472
- @property
473
- def user_config_dir(self):
474
- return user_config_dir(
475
- self.appname, self.appauthor, version=self.version, roaming=self.roaming
476
- )
477
-
478
- @property
479
- def site_config_dir(self):
480
- return site_config_dir(
481
- self.appname, self.appauthor, version=self.version, multipath=self.multipath
482
- )
483
-
484
- @property
485
- def user_cache_dir(self):
486
- return user_cache_dir(self.appname, self.appauthor, version=self.version)
487
-
488
- @property
489
- def user_state_dir(self):
490
- return user_state_dir(self.appname, self.appauthor, version=self.version)
491
-
492
- @property
493
- def user_log_dir(self):
494
- return user_log_dir(self.appname, self.appauthor, version=self.version)
495
-
496
-
497
- # ---- internal support stuff
498
-
499
-
500
- def _get_win_folder_from_registry(csidl_name):
501
- """This is a fallback technique at best. I'm not sure if using the
502
- registry for this guarantees us the correct answer for all CSIDL_*
503
- names.
504
- """
505
- import winreg as _winreg
506
-
507
- shell_folder_name = {
508
- "CSIDL_APPDATA": "AppData",
509
- "CSIDL_COMMON_APPDATA": "Common AppData",
510
- "CSIDL_LOCAL_APPDATA": "Local AppData",
511
- }[csidl_name]
512
-
513
- key = _winreg.OpenKey(
514
- _winreg.HKEY_CURRENT_USER,
515
- r"Software\Microsoft\Windows\CurrentVersion\Explorer\Shell Folders",
516
- )
517
- dir, type = _winreg.QueryValueEx(key, shell_folder_name)
518
- return dir
519
-
520
-
521
- def _get_win_folder_with_pywin32(csidl_name):
522
- from win32com.shell import shell, shellcon
523
-
524
- dir = shell.SHGetFolderPath(0, getattr(shellcon, csidl_name), 0, 0)
525
- # Try to make this a unicode path because SHGetFolderPath does
526
- # not return unicode strings when there is unicode data in the
527
- # path.
528
- try:
529
- dir = unicode(dir)
530
-
531
- # Downgrade to short path name if have highbit chars. See
532
- # <http://bugs.activestate.com/show_bug.cgi?id=85099>.
533
- has_high_char = False
534
- for c in dir:
535
- if ord(c) > 255:
536
- has_high_char = True
537
- break
538
- if has_high_char:
539
- try:
540
- import win32api
541
-
542
- dir = win32api.GetShortPathName(dir)
543
- except ImportError:
544
- pass
545
- except UnicodeError:
546
- pass
547
- return dir
548
-
549
-
550
- def _get_win_folder_with_ctypes(csidl_name):
551
- import ctypes
552
-
553
- csidl_const = {
554
- "CSIDL_APPDATA": 26,
555
- "CSIDL_COMMON_APPDATA": 35,
556
- "CSIDL_LOCAL_APPDATA": 28,
557
- }[csidl_name]
558
-
559
- buf = ctypes.create_unicode_buffer(1024)
560
- ctypes.windll.shell32.SHGetFolderPathW(None, csidl_const, None, 0, buf)
561
-
562
- # Downgrade to short path name if have highbit chars. See
563
- # <http://bugs.activestate.com/show_bug.cgi?id=85099>.
564
- has_high_char = False
565
- for c in buf:
566
- if ord(c) > 255:
567
- has_high_char = True
568
- break
569
- if has_high_char:
570
- buf2 = ctypes.create_unicode_buffer(1024)
571
- if ctypes.windll.kernel32.GetShortPathNameW(buf.value, buf2, 1024):
572
- buf = buf2
573
-
574
- return buf.value
575
-
576
-
577
- def _get_win_folder_with_jna(csidl_name):
578
- import array
579
-
580
- from com.sun import jna
581
- from com.sun.jna.platform import win32
582
-
583
- buf_size = win32.WinDef.MAX_PATH * 2
584
- buf = array.zeros("c", buf_size)
585
- shell = win32.Shell32.INSTANCE
586
- shell.SHGetFolderPath(
587
- None,
588
- getattr(win32.ShlObj, csidl_name),
589
- None,
590
- win32.ShlObj.SHGFP_TYPE_CURRENT,
591
- buf,
592
- )
593
- dir = jna.Native.toString(buf.tostring()).rstrip("\0")
594
-
595
- # Downgrade to short path name if have highbit chars. See
596
- # <http://bugs.activestate.com/show_bug.cgi?id=85099>.
597
- has_high_char = False
598
- for c in dir:
599
- if ord(c) > 255:
600
- has_high_char = True
601
- break
602
- if has_high_char:
603
- buf = array.zeros("c", buf_size)
604
- kernel = win32.Kernel32.INSTANCE
605
- if kernel.GetShortPathName(dir, buf, buf_size):
606
- dir = jna.Native.toString(buf.tostring()).rstrip("\0")
607
-
608
- return dir
609
-
610
-
611
- if system == "win32":
612
- try:
613
- import win32com.shell
614
-
615
- _get_win_folder = _get_win_folder_with_pywin32
616
- except ImportError:
617
- try:
618
- from ctypes import windll
619
-
620
- _get_win_folder = _get_win_folder_with_ctypes
621
- except ImportError:
622
- try:
623
- import com.sun.jna
624
-
625
- _get_win_folder = _get_win_folder_with_jna
626
- except ImportError:
627
- _get_win_folder = _get_win_folder_from_registry
628
-
629
-
630
- # ---- self test code
631
-
632
- if __name__ == "__main__":
633
- appname = "MyApp"
634
- appauthor = "MyCompany"
635
-
636
- props = (
637
- "user_data_dir",
638
- "user_config_dir",
639
- "user_cache_dir",
640
- "user_state_dir",
641
- "user_log_dir",
642
- "site_data_dir",
643
- "site_config_dir",
644
- )
645
-
646
- print(f"-- app dirs {__version__} --")
647
-
648
- print("-- app dirs (with optional 'version')")
649
- dirs = AppDirs(appname, appauthor, version="1.0")
650
- for prop in props:
651
- print(f"{prop}: {getattr(dirs, prop)}")
652
-
653
- print("\n-- app dirs (without optional 'version')")
654
- dirs = AppDirs(appname, appauthor)
655
- for prop in props:
656
- print(f"{prop}: {getattr(dirs, prop)}")
657
-
658
- print("\n-- app dirs (without optional 'appauthor')")
659
- dirs = AppDirs(appname)
660
- for prop in props:
661
- print(f"{prop}: {getattr(dirs, prop)}")
662
-
663
- print("\n-- app dirs (with disabled 'appauthor')")
664
- dirs = AppDirs(appname, appauthor=False)
665
- for prop in props:
666
- print(f"{prop}: {getattr(dirs, prop)}")
 
torch/_awaits/__init__.py DELETED
@@ -1,54 +0,0 @@
1
- from __future__ import annotations
2
-
3
- from typing import cast, Callable, Generic, Type, TypeVar
4
-
5
- import torch
6
-
7
- __all__ = ['Await']
8
-
9
- W = TypeVar("W")
10
-
11
- class _PyAwaitMeta(type(torch._C._Await), type(Generic)): # type: ignore[misc, no-redef]
12
- pass
13
-
14
- class _Await(torch._C._Await, Generic[W], metaclass=_PyAwaitMeta):
15
- r"""
16
- Wrapper around a ``torch._C.Await`` which encapsulates delayed execution
17
- of a callable. All manipulations happen with functions ``torch.jit._awaitable``,
18
- ``torch.jit._awaitable_wait``, ``torch.jit._awaitable_nowait``.
19
-
20
- Torch scriptable manipulations:
21
- ``torch.jit._awaitable(func, *args)``
22
- Creates ``Await[W]`` object, where W is return type of func.
23
-
24
- Returns:
25
- ``torch.jit._awaitable_wait(Await[W])``
26
- Returns the result of the function, specified at ``_awaitable``, with specified arguments.
27
-
28
- Returns:
29
- The result of type ``W`` of the function call. The result is owned by ``Await[W]``
30
- and returned on all following ``_awaitable_wait`` calls.
31
-
32
-
33
- ``torch.jit._awaitable_nowait(W)``
34
- Returns:
35
- Trivial ``Await[W]`` with specified result.
36
-
37
-
38
- Only in eager mode:
39
- ``fn() -> Callable[Tuple[Any], W]``
40
- Returns:
41
- Specified at ``_awaitable`` python function ``func``.
42
-
43
- ``args() -> Tuple[Any]``
44
- Returns:
45
- Specified at ``_awaitable`` python args.
46
-
47
- ``is_nowait() -> _bool``
48
- Returns:
49
- ``True`` if this object was created via ``_awaitable_nowait`` call (trivial `Await[W]`).
50
-
51
- In eager mode, ``Await[W]`` can be used as ``W``, i.e. attributes of W can be called on ``Await[W]``;
52
- a ``_awaitable_wait()`` call will be added transparently.
53
- """
54
- pass
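
A small eager-mode sketch of the three manipulations named in the docstring above (the add_one function is illustrative):

import torch

def add_one(x):
    return x + 1

aw = torch.jit._awaitable(add_one, torch.ones(2))   # Await object; add_one has not run yet
out = torch.jit._awaitable_wait(aw)                 # runs add_one and caches the result
done = torch.jit._awaitable_nowait(torch.zeros(2))  # trivial, already-resolved Await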
 
torch/_awaits/__pycache__/__init__.cpython-310.pyc DELETED
Binary file (2.08 kB)
 
torch/_classes.py DELETED
@@ -1,55 +0,0 @@
1
- import types
2
-
3
- import torch._C
4
-
5
-
6
- class _ClassNamespace(types.ModuleType):
7
- def __init__(self, name):
8
- super().__init__("torch.classes" + name)
9
- self.name = name
10
-
11
- def __getattr__(self, attr):
12
- proxy = torch._C._get_custom_class_python_wrapper(self.name, attr)
13
- if proxy is None:
14
- raise RuntimeError(f"Class {self.name}.{attr} not registered!")
15
- return proxy
16
-
17
-
18
- class _Classes(types.ModuleType):
19
- __file__ = "_classes.py"
20
-
21
- def __init__(self):
22
- super().__init__("torch.classes")
23
-
24
- def __getattr__(self, name):
25
- namespace = _ClassNamespace(name)
26
- setattr(self, name, namespace)
27
- return namespace
28
-
29
- @property
30
- def loaded_libraries(self):
31
- return torch.ops.loaded_libraries
32
-
33
- def load_library(self, path):
34
- """
35
- Loads a shared library from the given path into the current process.
36
-
37
- The library being loaded may run global initialization code to register
38
- custom classes with the PyTorch JIT runtime. This allows dynamically
39
- loading custom classes. For this, you should compile your class
40
- and the static registration code into a shared library object, and then
41
- call ``torch.classes.load_library('path/to/libcustom.so')`` to load the
42
- shared object.
43
-
44
- After the library is loaded, it is added to the
45
- ``torch.classes.loaded_libraries`` attribute, a set that may be inspected
46
- for the paths of all libraries loaded using this function.
47
-
48
- Args:
49
- path (str): A path to a shared library to load.
50
- """
51
- torch.ops.load_library(path)
52
-
53
-
54
- # The classes "namespace"
55
- classes = _Classes()
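
A short sketch of how this namespace is exercised (the library path and class name are illustrative; the shared object must register its classes with the JIT runtime):

import torch

torch.classes.load_library("path/to/libcustom_classes.so")
print(torch.classes.loaded_libraries)

# Registered classes appear as attributes of the namespace, e.g.:
# stack = torch.classes.my_namespace.MyStackClass(["hi"])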
 
torch/_compile.py DELETED
@@ -1,30 +0,0 @@
1
- """
2
- APIs related to torch.compile which lazily import torch._dynamo to avoid
3
- circular dependencies.
4
- """
5
- import functools
6
-
7
-
8
- def _disable_dynamo(fn=None, recursive=True):
9
- """
10
- This API should only be used inside torch; external users should still use
11
- torch._dynamo.disable. The main goal of this API is to avoid the circular
12
- import issues that are common when using _dynamo.disable inside torch
13
- itself.
14
-
15
- This API avoids them by deferring the import of torch._dynamo from import
16
- time to the invocation of the decorated function.
17
- """
18
- if fn is not None:
19
-
20
- @functools.wraps(fn)
21
- def inner(*args, **kwargs):
22
- import torch._dynamo
23
-
24
- return torch._dynamo.disable(fn, recursive)(*args, **kwargs)
25
-
26
- return inner
27
- else:
28
- # decorator usage like @_disable_dynamo(recursive=False). The resulting
29
- # object expects the original decorated function as the arg.
30
- return functools.partial(_disable_dynamo, recursive=recursive)
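
A quick sketch of both decorator forms handled above (internal API; the wrapped functions are illustrative):

from torch._compile import _disable_dynamo

@_disable_dynamo                      # bare form: fn is passed directly
def fn(x):
    return x + 1                      # torch._dynamo is imported only when fn is called

@_disable_dynamo(recursive=False)     # argument form: returns a partial that expects fn
def gn(x):
    return x - 1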
 
torch/_custom_op/__init__.py DELETED
File without changes
torch/_custom_op/__pycache__/__init__.cpython-310.pyc DELETED
Binary file (156 Bytes)
 
torch/_custom_op/__pycache__/autograd.cpython-310.pyc DELETED
Binary file (8.86 kB)
 
torch/_custom_op/__pycache__/functional.cpython-310.pyc DELETED
Binary file (5.95 kB)
 
torch/_custom_op/__pycache__/impl.cpython-310.pyc DELETED
Binary file (33.5 kB)
 
torch/_custom_op/autograd.py DELETED
@@ -1,274 +0,0 @@
1
- import torch
2
- import torch.utils._pytree as pytree
3
- from collections import namedtuple
4
- import functools
5
-
6
-
7
- # NOTE [CustomOp autograd kernel indirection]
8
- # We register `inner` as the autograd kernel for this custom_op.
9
- # `inner` either calls the autograd formula registered by the user,
10
- # or goes into an `autograd_not_implemented` kernel.
11
- #
12
- # The reason why this indirection exists is
13
- # so that we can swap out the autograd kernel (the PyTorch dispatcher
14
- # doesn't actually allow us to do this). By default, we want
15
- # the `autograd_not_implemented` behavior, but then the user may come
16
- # and register something that is actually a backward formula
17
- def autograd_kernel_indirection(custom_op):
18
- autograd_fallback = autograd_not_implemented(custom_op)
19
-
20
- def inner(*args, **kwargs):
21
- if custom_op._has_impl('autograd'):
22
- kernel = custom_op._get_impl('autograd').func
23
- return kernel(*args, **kwargs)
24
- # As explained in NOTE ["backward", "save_for_backward", and "autograd"],
25
- # after the user gives us "backward" and "save_for_backward", we generate
26
- # the "autograd" impl. If the user only provided one, then we tell
27
- # the user they've done something wrong.
28
- if custom_op._has_impl('save_for_backward') or custom_op._has_impl('backward'):
29
- missing = (
30
- 'save_for_backward' if custom_op._has_impl('backward')
31
- else 'backward'
32
- )
33
- found = 'save_for_backward' if missing == 'backward' else 'backward'
34
- loc = custom_op._get_impl(found).location
35
- raise RuntimeError(
36
- f"We found a '{found}' registration for {custom_op} at "
37
- f"{loc} but were unable to find a '{missing}' registration. "
38
- f"To use the CustomOp API to register a backward formula, "
39
- f"please provide us both a backward function and a "
40
- f"'save for backward' function via `impl_backward` and "
41
- f"`impl_save_for_backward` respectively.")
42
- return autograd_fallback(*args, **kwargs)
43
- return inner
44
-
45
-
46
- # TODO(#101191): Use the actual C++ autograd not implemented fallback,
47
- # or change the default autograd fallback to the autograd not implemented fallback.
48
- def autograd_not_implemented(custom_op):
49
- def kernel(*args, **kwargs):
50
- if torch.is_grad_enabled() and pytree.tree_any(
51
- lambda x: isinstance(x, torch.Tensor) and x.requires_grad, (args, kwargs)
52
- ):
53
- raise RuntimeError("Autograd has not been implemented for operator")
54
- with torch._C._AutoDispatchBelowAutograd():
55
- return custom_op(*args, **kwargs)
56
- return kernel
57
-
58
-
59
- def mark_non_differentiable(ctx, output, output_differentiability):
60
- # Output types are restricted to be:
61
- # - Tensor
62
- # - Tensor[]
63
- # - int, bool, Scalar, float
64
- # See _check_can_register_backward
65
- if output_differentiability is not None:
66
- if not isinstance(output, tuple):
67
- tuple_output = (output,)
68
- else:
69
- tuple_output = output # type: ignore[assignment]
70
- assert len(output_differentiability) == len(tuple_output)
71
- non_differentiable_tensors = []
72
- for idx, (differentiable, out) in enumerate(zip(output_differentiability, tuple_output)):
73
- if isinstance(out, torch.Tensor):
74
- if not differentiable:
75
- non_differentiable_tensors.append(out)
76
- continue
77
- if isinstance(out, list):
78
- if not differentiable:
79
- non_differentiable_tensors.extend(out)
80
- continue
81
- if differentiable:
82
- raise RuntimeError(
83
- f"With output_differentiability={output_differentiability}. "
84
- f"At idx {idx}, we received an object of type {type(out)} that "
85
- f"is not a Tensor, so it cannot have be marked as differentiable in "
86
- f"output_differentiability.")
87
- if non_differentiable_tensors:
88
- ctx.mark_non_differentiable(*non_differentiable_tensors)
89
-
90
-
91
- def construct_autograd_kernel(
92
- schema,
93
- output_differentiability,
94
- custom_op,
95
- op_overload,
96
- save_for_backward_fn,
97
- backward_fn):
98
-
99
- def apply(*args):
100
- flat_args, spec = pytree.tree_flatten(args)
101
- out_spec = None
102
-
103
- def forward(ctx, *flat_args):
104
- ctx.set_materialize_grads(True)
105
- args = pytree.tree_unflatten(list(flat_args), spec)
106
- with torch._C._AutoDispatchBelowAutograd():
107
- output = op_overload(*args)
108
-
109
- # We use the info about args to give better error messages in backward
110
- args_info = namedtuple_args(
111
- schema, pytree.tree_map(type, args))
112
-
113
- save_for_backward_fn_inputs = namedtuple_args(schema, args)
114
- to_save = save_for_backward_fn(save_for_backward_fn_inputs, output)
115
-
116
- save_pytree_for_backward(ctx, (to_save, args_info))
117
- mark_non_differentiable(ctx, output, output_differentiability)
118
-
119
- nonlocal out_spec
120
- flat_output, out_spec = pytree.tree_flatten(output)
121
- return tuple(flat_output)
122
-
123
- def backward(ctx, *flat_grad_output):
124
- assert out_spec is not None
125
- grads = pytree.tree_unflatten(list(flat_grad_output), out_spec)
126
- saved, args_info = unpack_saved(ctx)
127
- # There is nothing on the ctx object for now, it is just there so
128
- # that we can add additional things in the future.
129
- inner_ctx = object()
130
- if not isinstance(grads, tuple):
131
- grads = (grads,)
132
- grad_inputs_dict = backward_fn(inner_ctx, saved, *grads)
133
-
134
- # Massage the grad_inputs_dict to a form acceptable by
135
- # autograd.Function.
136
- validate_grad_inputs_dict(grad_inputs_dict, custom_op, args_info)
137
- return grad_inputs_dict_to_flat_tuple(grad_inputs_dict, args_info)
138
-
139
- generated_cls = gen_autograd_function(
140
- custom_op._opname + '_customop', forward, backward)
141
-
142
- flat_output = generated_cls.apply(*flat_args)
143
- assert out_spec is not None
144
- return pytree.tree_unflatten(list(flat_output), out_spec)
145
- return apply
146
-
147
-
148
- def gen_autograd_function(name, forward, backward):
149
- generated_cls = type(
150
- name,
151
- (torch.autograd.Function,),
152
- {
153
- 'forward': staticmethod(forward),
154
- 'backward': staticmethod(backward),
155
- }
156
- )
157
- return generated_cls
158
-
159
-
160
- @functools.lru_cache
161
- def namedtuple_args_cls(schema):
162
- attribs = [arg.name for arg in schema.arguments.flat_all]
163
- name = str(schema.name) + "_args"
164
- # mypy doesn't support dynamic namedtuple name
165
- tuple_cls = namedtuple(name, attribs) # type: ignore[misc]
166
- return tuple_cls
167
-
168
-
169
- def namedtuple_args(schema, args):
170
- assert isinstance(args, tuple)
171
- tuple_cls = namedtuple_args_cls(schema)
172
- return tuple_cls(*args)
173
-
174
-
175
- def validate_grad_inputs_dict(grad_inputs_dict, forward_op, args_info):
176
- def error(what):
177
- backward = forward_op._get_impl('backward')
178
- raise RuntimeError(
179
- f"In the backward function defined for {forward_op} at "
180
- f"{backward.location} using the CustomOp API, {what}")
181
-
182
- if not isinstance(grad_inputs_dict, dict):
183
- error(f"expected the output of the backward function to be a dict but "
184
- f"got {type(grad_inputs_dict)}")
185
-
186
- expected_keys = {arg.name for arg in forward_op._schema.arguments.flat_all
187
- if arg.type.is_tensor_like()}
188
- actual_keys = grad_inputs_dict.keys()
189
- if expected_keys != actual_keys:
190
- error(f"expected the returned grad_input dict to have keys "
191
- f"{expected_keys} but got {actual_keys}. The backward "
192
- f"function must return a gradient (can be None) for each arg "
193
- f"to the CustomOp that may be a Tensor or Sequence[Tensor]. "
194
- f"Args declared to be non-Tensor-like types should not appear "
195
- f"in the grad_input dict")
196
-
197
- for name, grad in grad_inputs_dict.items():
198
- arg_info = getattr(args_info, name)
199
-
200
- if isinstance(arg_info, list):
201
- if not isinstance(grad, (tuple, list)):
202
- error(f"for input '{name}' expected the grad_input dict to "
203
- f"hold a list of gradients but got object of type "
204
- f"{type(grad)}.")
205
- if not len(grad) == len(arg_info):
206
- error(f"for input '{name}' expected the grad_input dict to "
207
- f"hold a list of {len(arg_info)} gradients but got "
208
- f"{len(grad)}")
209
- for idx, (g, info) in enumerate(zip(grad, arg_info)):
210
- if g is None:
211
- continue
212
- if not isinstance(g, torch.Tensor):
213
- error(f"for input '{name}' expected the grad_input dict to "
214
- f"hold a list of None or Tensor gradients but got "
215
- f"object of {type(g)} at index {idx}")
216
- if not issubclass(info, torch.Tensor):
217
- error(f"for input '{name}', got a Tensor as the gradient "
218
- f"for the {idx}-th value but expected None because "
219
- f"the {idx}-th value was not a Tensor (it was "
220
- f"type {arg_info}")
221
- continue
222
-
223
- if grad is None:
224
- continue
225
- if not isinstance(grad, torch.Tensor):
226
- error(f"got object of type {type(grad)} as the gradient for input "
227
- f"'{name}', "
228
- f"but expected the gradient to be either None or a Tensor")
229
- if not issubclass(arg_info, torch.Tensor):
230
- error(f"got a Tensor as the gradient for input '{name}' but "
231
- f"expected None as the gradient because input '{name}' "
232
- f"was not a Tensor (it was type {arg_info}).")
233
-
234
-
235
- def grad_inputs_dict_to_flat_tuple(grad_inputs_dict, args_info):
236
- result = []
237
- for name, arg_info in args_info._asdict().items():
238
- if name not in grad_inputs_dict:
239
- result.append(pytree.tree_map(lambda x: None, arg_info))
240
- continue
241
- result.append(grad_inputs_dict[name])
242
- return tuple(pytree.tree_leaves(result))
243
-
244
- # Saves "stuff" (a pytree) onto the ctx object. Use unpack_saved to unpack it.
245
- # autograd.Function prefers that users use ctx.save_for_backward to
246
- # save Tensors (to avoid reference cycles) and for non-Tensors to go onto the
247
- # ctx object.
248
- def save_pytree_for_backward(ctx, stuff):
249
- flat_stuff, spec = pytree.tree_flatten(stuff)
250
- num_elts = len(flat_stuff)
251
- tensor_idxs = [idx for idx, thing in enumerate(flat_stuff)
252
- if isinstance(thing, torch.Tensor)]
253
- non_tensor_idxs = [idx for idx, thing in enumerate(flat_stuff)
254
- if not isinstance(thing, torch.Tensor)]
255
- tensors = [thing for thing in flat_stuff if isinstance(thing, torch.Tensor)]
256
- non_tensors = [thing for thing in flat_stuff if not isinstance(thing, torch.Tensor)]
257
-
258
- ctx.spec = spec
259
- ctx.num_elts = num_elts
260
- ctx.save_for_backward(*tensors)
261
- ctx.tensor_idxs = tensor_idxs
262
- ctx.saved_non_tensors = non_tensors
263
- ctx.non_tensor_idxs = non_tensor_idxs
264
-
265
-
266
- # Inverse operation to save_pytree_for_backward
267
- def unpack_saved(ctx):
268
- flat_stuff = [None] * ctx.num_elts
269
- for tensor, idx in zip(ctx.saved_tensors, ctx.tensor_idxs):
270
- flat_stuff[idx] = tensor
271
- for non_tensor, idx in zip(ctx.saved_non_tensors, ctx.non_tensor_idxs):
272
- flat_stuff[idx] = non_tensor
273
- stuff = pytree.tree_unflatten(flat_stuff, ctx.spec)
274
- return stuff
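
A rough sketch of how the pieces above fit together from user code: save_for_backward_fn receives the namedtuple built by namedtuple_args (fields addressed by argument name), and backward_fn must return the grad-inputs dict that validate_grad_inputs_dict checks. The sketch uses the torch._custom_ops wrappers (also deleted in this commit); the op name is made up:

import torch
from torch import Tensor
import torch._custom_ops as custom_ops

@custom_ops.custom_op("mylib::numpy_sin")
def numpy_sin(x: Tensor) -> Tensor:
    ...

@custom_ops.impl("mylib::numpy_sin")
def numpy_sin_impl(x):
    return torch.sin(x)

@custom_ops.impl_save_for_backward("mylib::numpy_sin")
def numpy_sin_save(inputs, output):
    return inputs.x                        # `inputs` is the generated namedtuple of args

@custom_ops.impl_backward("mylib::numpy_sin")
def numpy_sin_backward(ctx, saved, grad_out):
    return {"x": grad_out * saved.cos()}   # dict keyed by input names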
 
torch/_custom_op/functional.py DELETED
@@ -1,187 +0,0 @@
1
- import weakref
2
-
3
- import torch
4
- import torch.utils._pytree as pytree
5
- from torch._C import _ExcludeDispatchKeyGuard, DispatchKey, DispatchKeySet
6
- from torch._ops import OpOverload
7
- from torch.library import Library
8
- from torchgen.model import (
9
- BaseTy,
10
- BaseType,
11
- FunctionSchema,
12
- OperatorName,
13
- OptionalType,
14
- SchemaKind,
15
- )
16
-
17
- from .autograd import autograd_not_implemented
18
-
19
-
20
- def register_functional_op(
21
- lib: Library,
22
- new_op_name: str,
23
- mutable_op: OpOverload,
24
- ) -> None:
25
- """Given a mutable operator, registers the functional variant.
26
-
27
- This API also correctly links the functional variant with the mutable
28
- operator for the purposes of functionalization.
29
-
30
- All of the new registrations are performed on the ``lib`` passed in.
31
-
32
- Arguments:
33
- lib (Library): Should be a torch.library.Library object that has
34
- the same namespace as ``mutable_op``'s namespace.
35
- lib will be used to register the new functional op as well
36
- as a functionalization kernel for the ``mutable_op``
37
- If you don't have a library handy, use
38
- ``torch.library.Library(ns, 'FRAGMENT')`` to construct one.
39
- new_op_name (str): The name of the functional operator (without the
40
- namespace). If no namespace, the new functional variant will be
41
- accessible under ``torch.ops.{lib.ns}.new_op_name``.
42
- mutable_op (OpOverload): The mutable custom operator. Note
43
- that you may need to add a `.default` to it, like
44
- `torch.ops.aten.abs_.default`.
45
-
46
- """
47
- validate(mutable_op)
48
- schema = functional_schema(new_op_name, mutable_op)
49
- lib.define(schema)
50
-
51
- functional_impl = construct_functional_impl(mutable_op)
52
- lib.impl(new_op_name, functional_impl, 'CompositeExplicitAutograd')
53
-
54
- functional_op = getattr(getattr(torch.ops, lib.ns), new_op_name).default
55
-
56
- # There's no easy way for us to generate the autograd kernel, so we
57
- # use autograd_not_implemented. Also, this makes it so that the user
58
- # is unable to register an autograd formula themselves. This shouldn't
59
- # be a problem if the user doesn't use the functional op direclty
60
- # in their program, but we may need to revist this in the future.
61
- lib.impl(new_op_name, autograd_not_implemented(functional_op), 'Autograd')
62
-
63
- f_kernel = construct_functionalization_kernel(weakref.proxy(mutable_op), functional_op)
64
-
65
- lib.impl(mutable_op, f_kernel, 'Functionalize')
66
-
67
-
68
- def construct_functional_impl(mutable_op):
69
- def functional_impl(*args):
70
- # Strategy:
71
- # - clone args that would have been mutated
72
- # - run mutable_op
73
- # - return the cloned args as additional outputs
74
- new_args = []
75
- extra_rets = []
76
- for is_write, arg in zip(mutable_args(mutable_op), args):
77
- if is_write:
78
- cloned = arg.clone() if arg is not None else None
79
- new_args.append(cloned)
80
- extra_rets.append(cloned)
81
- else:
82
- new_args.append(arg)
83
- result = mutable_op(*new_args)
84
- if result is None:
85
- return tuple(extra_rets)
86
- if isinstance(result, tuple):
87
- return (*result, *extra_rets)
88
- return (result, *extra_rets)
89
- return functional_impl
90
-
91
-
92
- def construct_functionalization_kernel(mutable_op, functional_op):
93
- def kernel(*args):
94
- # There's nothing to be functionalized!
95
- # We can still end up here because DispatchKey::Functionalize is a mode key
96
- if pytree.tree_all_only(torch.Tensor, lambda x: not torch._is_functional_tensor(x), args):
97
- with _ExcludeDispatchKeyGuard(DispatchKeySet(DispatchKey.Functionalize)):
98
- return mutable_op(*args)
99
-
100
- # NB: This differs from the codegen -- codegen handles cases where there
101
- # are mixed FunctionalTensorWrapper and non-FunctionalTensorWrapper.
102
- # This only really matters for XLA (mixed CPU-XLA tensors) and
103
- # running functionalization without the PT2 stack (which guarantees to us that
104
- # all tensors are FunctionalTensorWrapper).
105
- if not pytree.tree_all_only(torch.Tensor, torch._is_functional_tensor, args):
106
- raise RuntimeError("{mutable_op}: expected all args to be FunctionalTensorWrapper")
107
-
108
- unwrapped_args = []
109
- for arg in args:
110
- if isinstance(arg, torch.Tensor) and torch._is_functional_tensor(arg):
111
- torch._sync(arg)
112
- unwrapped = torch._from_functional_tensor(arg)
113
- unwrapped_args.append(unwrapped)
114
- else:
115
- unwrapped_args.append(arg)
116
-
117
- with _ExcludeDispatchKeyGuard(DispatchKeySet(DispatchKey.Functionalize)):
118
- output = functional_op(*unwrapped_args)
119
-
120
- num_actual_output = len(mutable_op._schema.returns)
121
- actual_output = pytree.tree_map(
122
- torch._to_functional_tensor, output[:num_actual_output])
123
-
124
- new_values_to_propagate = output[num_actual_output:]
125
- inputs_to_replace = [arg for is_write, arg in zip(mutable_args(mutable_op), args)
126
- if is_write]
127
- assert len(new_values_to_propagate) == len(inputs_to_replace)
128
- for new_value, arg in zip(new_values_to_propagate, inputs_to_replace):
129
- if (arg is None and new_value is None) or (arg is not None and new_value is not None):
130
- continue
131
- torch._C._propagate_xla_data(arg, new_value)
132
- torch._C._replace_(arg, new_value)
133
- torch._C._commit_update(arg)
134
- torch._sync(arg)
135
-
136
- if len(actual_output) == 1:
137
- return actual_output[0]
138
- elif len(actual_output) == 0:
139
- return None
140
- return actual_output
141
-
142
- return kernel
143
-
144
-
145
- def validate(mutable_op: OpOverload):
146
- if not isinstance(mutable_op, OpOverload):
147
- raise TypeError(
148
- f"register_functional_op(mutable_op): expected mutable_op to be instance of "
149
- f"OpOverload but got {type(mutable_op)}")
150
-
151
- # There are generally three types of "in-place" or "mutable" ops.
152
- # Each of them have their own conventions:
153
- # - inplace (first input modified in-place and returned as only output)
154
- # - out= (some args modified in-place and returned as outputs)
155
- # - mutable (some args modified in-place but none of those returned as outputs)
156
- # In theory we can support all three, but we'll just support the last
157
- # option right now for simplicity.
158
- schema = FunctionSchema.parse(str(mutable_op._schema))
159
- if not schema.kind() == SchemaKind.mutable:
160
- raise RuntimeError("Expected op to be mutable (as opposed to functional, inplace or out)")
161
- for ret in schema.returns:
162
- # construct_functionalization_kernel assumes this for simplicity
163
- if ret.annotation is not None:
164
- raise NotImplementedError(
165
- "NYI: register_functional_op(op) where op returns a mutated or aliased value. "
166
- "Please file an issue (and as a workaround, modify your operator to "
167
- "not return the mutated value or aliases)")
168
- for arg in schema.arguments.flat_all:
169
- # construct_functionalization_kernel assumes this for simplicity
170
- if arg.type.is_tensor_like() and (
171
- arg.type != BaseType(BaseTy.Tensor)
172
- and arg.type != OptionalType(BaseType(BaseTy.Tensor))
173
- ):
174
- raise NotImplementedError(
175
- "NYI: register_functional_op(op) where op has a List[Tensor] input."
176
- "Please file an issue.")
177
-
178
-
179
- def functional_schema(new_op_name, op: OpOverload):
180
- schema = FunctionSchema.parse(str(op._schema))
181
- schema = schema.signature().with_name(OperatorName.parse(new_op_name))
182
- return str(schema)
183
-
184
-
185
- def mutable_args(op: OpOverload):
186
- return tuple(False if arg.alias_info is None else arg.alias_info.is_write
187
- for arg in op._schema.arguments)
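
A hedged sketch of calling register_functional_op above; the namespace, op name, and schema are made up, and the op is of the "mutable" kind that validate() requires:

import torch
from torch.library import Library
from torch._custom_op.functional import register_functional_op

lib = Library("mylib", "FRAGMENT")
lib.define("bump(Tensor(a!) x) -> ()")        # mutates x in place, returns nothing

def bump_impl(x):
    x.add_(1)

lib.impl("bump", bump_impl, "CompositeExplicitAutograd")

# Defines mylib::bump_functional and links it to mylib::bump for functionalization.
register_functional_op(lib, "bump_functional", torch.ops.mylib.bump.default)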
 
torch/_custom_op/impl.py DELETED
@@ -1,976 +0,0 @@
1
- import dataclasses
2
- import functools
3
- import inspect
4
- import sys
5
- import typing
6
- import weakref
7
-
8
- from torchgen.model import FunctionSchema, OperatorName, SchemaKind, BaseType, ListType, BaseTy
9
-
10
- import torch
11
- import torch._C as _C
12
- import torch.library as library
13
- from torch._library.abstract_impl import AbstractImplCtx
14
- from torch.library import get_ctx
15
-
16
- from .autograd import autograd_kernel_indirection, construct_autograd_kernel
17
-
18
- """
19
- For a detailed guide on custom ops, please see
20
- https://docs.google.com/document/d/1aGWtgxV3HppuxQAdddyPrs74_aEntpkYt9MalnCKnhk
21
-
22
- This file includes pieces of the implementation of our custom operator API.
23
- """
24
-
25
- __all__ = ["custom_op", "CustomOp", "get_ctx", "AbstractImplCtx"]
26
-
27
-
28
- SUPPORTED_DEVICE_TYPE_TO_KEY = {
29
- "cpu": "CPU",
30
- "cuda": "CUDA",
31
- }
32
-
33
- # We will not let users register CustomOps with anything that could look like
34
- # PyTorch internals to avoid confusion.
35
- RESERVED_NS = {
36
- "prim",
37
- "prims",
38
- "aten",
39
- "at",
40
- "torch",
41
- "pytorch",
42
- }
43
-
44
-
45
- def custom_op(
46
- qualname: str, manual_schema: typing.Optional[str] = None
47
- ) -> typing.Callable:
48
- r"""Creates a new CustomOp object.
49
-
50
- WARNING: if you're a user, please do not use this directly
51
- (instead use the torch._custom_ops APIs).
52
- Also please see the following for a detailed guide on custom ops.
53
- https://docs.google.com/document/d/1aGWtgxV3HppuxQAdddyPrs74_aEntpkYt9MalnCKnhk
54
-
55
- In PyTorch, defining an op (short for "operator") is a two-step process:
56
- - we need to define (create) the op
57
- - we need to implement behavior for how the operator interacts with
58
- various PyTorch subsystems, like CPU/CUDA Tensors, Autograd, etc.
59
-
60
- This entrypoint defines the CustomOp object (the first step);
61
- you must then perform the second step by calling various methods on
62
- the CustomOp object.
63
-
64
- This API is used as a decorator (see examples).
65
-
66
- Arguments:
67
- qualname (str): Should be a string that looks like
68
- "namespace::operator_name". Operators in PyTorch need a namespace to
69
- avoid name collisions; a given operator may only be created once.
70
- If you are writing a Python library, we recommend the namespace to
71
- be the name of your top-level module. The operator_name must be
72
- the same as the name of the function you pass to custom_op
73
- (see examples).
74
- manual_schema (Optional[str]): Each PyTorch operator needs a schema that
75
- tells PyTorch the types of the inputs/outputs. If None (default),
76
- we will infer the schema from the type annotations on the function
77
- (see examples). Otherwise, if you don't want to use type annotations,
78
- you may provide us the schema string.
79
-
80
- Example::
81
- >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA)
82
- >>> import numpy as np
83
- >>> from torch import Tensor
84
- >>>
85
- >>> # Step 1: define the CustomOp.
86
- >>> # We need to provide the decorator a "prototype function"
87
- >>> # (a function with Python ellipses as the body).
88
- >>> @custom_op("my_library::numpy_sin")
89
- >>> def numpy_sin(x: Tensor) -> Tensor:
90
- >>> ...
91
- >>>
92
- >>> # numpy_sin is now an instance of class CustomOp
93
- >>> print(type(numpy_sin))
94
- >>>
95
- >>> # Step 2: Register an implementation for various PyTorch subsystems
96
- >>>
97
- >>> # Register an implementation for CPU tensors
98
- >>> @numpy_sin.impl('cpu')
99
- >>> def numpy_sin_impl_cpu(x):
100
- >>> return torch.from_numpy(np.sin(x.numpy()))
101
- >>>
102
- >>> # Register an implementation for CUDA tensors
103
- >>> @numpy_sin.impl('cuda')
104
- >>> def numpy_sin_impl_cuda(x):
105
- >>> return torch.from_numpy(np.sin(x.cpu().numpy())).to(x.device)
106
- >>>
107
- >>> x = torch.randn(3)
108
- >>> numpy_sin(x) # calls numpy_sin_impl_cpu
109
- >>>
110
- >>> x_cuda = x.cuda()
111
- >>> numpy_sin(x_cuda) # calls numpy_sin_impl_cuda
112
-
113
- """
114
-
115
- def inner(func):
116
- if not inspect.isfunction(func):
117
- raise ValueError(
118
- f"custom_op(...)(func): Expected `func` to be a Python "
119
- f"function, got: {type(func)}"
120
- )
121
-
122
- ns, name = parse_qualname(qualname)
123
- validate_namespace(ns)
124
- if func.__name__ != name:
125
- raise ValueError(
126
- f"custom_op(qualname='{qualname}', ...)(func): expected `func` "
127
- f"to have name '{name}' but got '{func.__name__}'. "
128
- f"Please either change the name of `func` or the qualname that "
129
- f"is passed to `custom_op`"
130
- )
131
-
132
- schema = infer_schema(func) if manual_schema is None else manual_schema
133
- schema_str = f"{name}{schema}"
134
- function_schema = FunctionSchema.parse(schema_str)
135
- validate_schema(function_schema)
136
- if manual_schema is not None:
137
- validate_function_matches_schema(function_schema, func)
138
-
139
- lib = library.Library(ns, "FRAGMENT")
140
- lib.define(schema_str)
141
- ophandle = find_ophandle_or_throw(ns, function_schema.name)
142
- result = CustomOp(lib, ns, function_schema, name, ophandle, _private_access=True)
143
-
144
- result.__name__ = func.__name__
145
- result.__module__ = func.__module__
146
- result.__doc__ = func.__doc__
147
-
148
- library.impl(lib, result._opname, "Autograd")(
149
- autograd_kernel_indirection(weakref.proxy(result))
150
- )
151
-
152
- torch._C._dispatch_set_report_error_callback(
153
- ophandle, functools.partial(report_error_callback, weakref.proxy(result))
154
- )
155
-
156
- return result
157
-
158
- return inner
159
-
160
-
161
- # Global dictionary holding references to all CustomOp objects
162
- # Yes, it keeps all CustomOps alive (see NOTE [CustomOp lifetime])
163
- # Used to query the CustomOp associated with a specific C++ dispatcher operator.
164
- # An example usage is FakeTensor: FakeTensor checks if a specific operator
165
- # has an implementation registered via the CustomOp API.
166
- # Indexed by qualname (e.g. aten::foo)
167
- global_registry: typing.Dict[str, "CustomOp"] = {}
168
-
169
-
170
- class CustomOp:
171
- r"""Class for custom operators in PyTorch.
172
-
173
- Use the CustomOp API to create user-defined custom operators that behave
174
- just like regular PyTorch operators (e.g. torch.sin, torch.mm) when it
175
- comes to various PyTorch subsystems (like torch.compile).
176
-
177
- To construct a `CustomOp`, use `custom_op`.
178
- """
179
-
180
- def __init__(self, lib, cpp_ns, schema, operator_name, ophandle, *, _private_access=False):
181
- super().__init__()
182
- if not _private_access:
183
- raise RuntimeError(
184
- "The CustomOp constructor is private and we do not guarantee "
185
- "BC for it. Please use custom_op(...) to create a CustomOp object"
186
- )
187
- name = f"{cpp_ns}::{operator_name}"
188
- self._schema = schema
189
- self._cpp_ns = cpp_ns
190
- self._lib: library.Library = lib
191
- self._ophandle: _C._DispatchOperatorHandle = ophandle
192
- # Has the name of the op, e.g. "foo". We cache here for convenience.
193
- self._opname: str = operator_name
194
- # this is _opname but with namespace. e.g. "custom::foo"
195
- self._qualname: str = name
196
- self.__name__ = None # mypy requires this
197
- # NB: Some of these impls are registered as kernels to DispatchKeys.
198
- # Modifying the _impls dict directly won't do anything in that case.
199
- self._impls: typing.Dict[str, typing.Optional[FuncAndLocation]] = {}
200
- # See NOTE [CustomOp autograd kernel indirection]
201
- self._registered_autograd_kernel_indirection = False
202
-
203
- global_registry[self._qualname] = self
204
-
205
- def _register_autograd_kernel_indirection(self):
206
- assert not self._registered_autograd_kernel_indirection
207
- self._lib.impl(self._opname, autograd_kernel_indirection(weakref.proxy(self)), "Autograd")
208
- self._registered_autograd_kernel_indirection = True
209
-
210
- # Records the impl and the source location in self._impls
211
- # Note that this doesn't cause torch.library to use the impl; that
212
- # needs to be done in a separate self._lib.impl call.
213
- def _register_impl(self, kind, func, stacklevel=2):
214
- if self._has_impl(kind):
215
- func_and_location = self._impls[kind]
216
- assert func_and_location is not None # Pacify mypy
217
- location = func_and_location.location
218
- raise RuntimeError(
219
- f"Attempting to register a {kind} impl for operator {self._qualname} "
220
- f"that already has a {kind} impl registered from Python at "
221
- f"{location}. This is not supported."
222
- )
223
- frame = inspect.getframeinfo(sys._getframe(stacklevel))
224
- location = f"{frame.filename}:{frame.lineno}"
225
- self._impls[kind] = FuncAndLocation(func, location)
226
-
227
- def _get_impl(self, kind):
228
- return self._impls[kind]
229
-
230
- def _has_impl(self, kind):
231
- return kind in self._impls
232
-
233
- def _destroy(self):
234
- # NOTE: [CustomOp lifetime]
235
- # A CustomOp, once created, lives forever. The mechanism is that the
236
- # global registry holds a reference to it. However, to make testing
237
- # easier, we want to be able to destroy CustomOp objects.
238
- # CustomOp._destroy does the job, though it leaves the CustomOp
239
- # in a garbage state.
240
- del self._lib
241
-
242
- opnamespace = getattr(torch.ops, self._cpp_ns)
243
- if hasattr(opnamespace, self._opname):
244
- delattr(opnamespace, self._opname)
245
-
246
- del global_registry[self._qualname]
247
-
248
- def __repr__(self):
249
- return f'<CustomOp(op="{self._qualname}")>'
250
-
251
- def __call__(self, *args, **kwargs):
252
- # Bypass torch.ops.* and directly do OperatorHandle::callBoxed.
253
- # Using torch.ops.* is a bit of a pain (it can be slow and it has lifetime
254
- # issues from caching operators that make testing CustomOp difficult).
255
- result = _C._dispatch_call_boxed(self._ophandle, *args, **kwargs)
256
- return result
257
-
258
- def impl(
259
- self, device_types: typing.Union[str, typing.Iterable[str]], _stacklevel=2,
260
- ) -> typing.Callable:
261
- r"""Register an implementation for a device type for this CustomOp object.
262
-
263
- WARNING: if you're a user, please do not use this directly
264
- (instead use the torch._custom_ops APIs).
265
- Also please see the following for a detailed guide on custom ops.
266
- https://docs.google.com/document/d/1aGWtgxV3HppuxQAdddyPrs74_aEntpkYt9MalnCKnhk
267
-
268
- If the CustomOp is passed multiple Tensor inputs with different device
269
- types, it will dispatch to the registered implementation for the highest
270
- priority device type among those present.
271
- The supported device types, in order of priority, are {'cuda', 'cpu'}.
272
-
273
- This API is used as a decorator (see examples).
274
-
275
- Arguments:
276
- device_types (str or Iterable[str]): the device type(s) to register the function for.
277
-
278
- Examples::
279
- >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA)
280
- >>> import numpy as np
281
- >>> from torch import Tensor
282
- >>>
283
- >>> @custom_op("my_library::numpy_cos")
284
- >>> def numpy_cos(x: Tensor) -> Tensor:
285
- >>> ...
286
- >>>
287
- >>> # Register an implementation for CPU Tensors
288
- >>> @numpy_cos.impl('cpu')
289
- >>> def numpy_cos_impl_cpu(x):
290
- >>> return torch.from_numpy(np.cos(x.numpy()))
291
- >>>
292
- >>> # Register an implementation for CUDA Tensors
293
- >>> @numpy_cos.impl('cuda')
294
- >>> def numpy_cos_impl_cuda(x):
295
- >>> return torch.from_numpy(np.cos(x.cpu().numpy())).to(x.device)
296
- >>>
297
- >>> x = torch.randn(3)
298
- >>> numpy_cos(x) # calls numpy_cos_impl_cpu
299
- >>>
300
- >>> x_cuda = x.cuda()
301
- >>> numpy_cos(x) # calls numpy_cos_impl_cuda
302
-
303
- """
304
- if isinstance(device_types, str):
305
- device_types = [device_types]
306
- for device_type in device_types:
307
- validate_device_type(device_type)
308
-
309
- def inner(f):
310
- for device_type in set(device_types):
311
- self._check_doesnt_have_library_impl(device_type)
312
- self._register_impl(device_type, f, stacklevel=_stacklevel)
313
- dispatch_key = SUPPORTED_DEVICE_TYPE_TO_KEY[device_type]
314
- library.impl(self._lib, self._opname, dispatch_key)(f)
315
- return f
316
-
317
- return inner
318
-
319
- def _check_doesnt_have_library_impl(self, device_type):
320
- if self._has_impl(device_type):
321
- return
322
- key = SUPPORTED_DEVICE_TYPE_TO_KEY[device_type]
323
- if _C._dispatch_has_computed_kernel_for_dispatch_key(self._qualname, key):
324
- raise RuntimeError(
325
- f"impl(..., device_types={device_type}): the operator {self._qualname} "
326
- f"already has an implementation for this device type via a "
327
- f"pre-existing torch.library or TORCH_LIBRARY registration.")
328
-
329
- def impl_factory(self) -> typing.Callable:
330
- r"""Register an implementation for a factory function."""
331
-
332
- def inner(f):
333
- self._register_impl("factory", f)
334
- library.impl(self._lib, self._opname, "BackendSelect")(f)
335
- return f
336
-
337
- return inner
338
-
339
- def impl_abstract(self, _stacklevel=2) -> typing.Callable:
340
- r"""Register an abstract implementation for this operator.
341
-
342
- WARNING: please do not use this directly (and instead use the torch._custom_ops
343
- APIs). Also please see the following for a detailed guide on custom ops.
344
- https://docs.google.com/document/d/1aGWtgxV3HppuxQAdddyPrs74_aEntpkYt9MalnCKnhk
345
-
346
- An "abstract implementation" specifies the behavior of this operator on
347
- Tensors that carry no data. Given some input Tensors with certain properties
348
- (sizes/strides/storage_offset/device), it specifies what the properties of
349
- the output Tensors are.
350
-
351
- The abstract implementation has the same signature as the operator.
352
- It is run for both FakeTensors and meta tensors. To write an abstract
353
- implementation, assume that all Tensor inputs to the operator are
354
- regular CPU/CUDA/Meta tensors, but they do not have storage, and
355
- you are trying to return regular CPU/CUDA/Meta tensor(s) as output.
356
- The abstract implementation must consist of only PyTorch operations
357
- (and may not directly access the storage or data of any input or
358
- intermediate Tensors).
359
-
360
- This API is used as a decorator (see examples).
361
-
362
- Examples::
363
- >>> import numpy as np
364
- >>> from torch import Tensor
365
- >>>
366
- >>> # Example 1: an operator without data-dependent output shape
367
- >>> @custom_op('my_library::custom_linear')
368
- >>> def custom_linear(x: Tensor, weight: Tensor, bias: Tensor) -> Tensor:
369
- >>> ...
370
- >>>
371
- >>> @custom_linear.impl_abstract()
372
- >>> def custom_linear_abstract(x, weight, bias):
373
- >>> assert x.dim() == 2
374
- >>> assert weight.dim() == 2
375
- >>> assert bias.dim() == 1
376
- >>> assert x.shape[1] == weight.shape[1]
377
- >>> assert weight.shape[0] == bias.shape[0]
378
- >>> assert x.device == weight.device
379
- >>>
380
- >>> return (x @ weight.t()) + bias
381
- >>>
382
- >>> # Example 2: an operator with data-dependent output shape
383
- >>> @custom_op('my_library::custom_nonzero')
384
- >>> def custom_nonzero(x: Tensor) -> Tensor:
385
- >>> ...
386
- >>>
387
- >>> @custom_nonzero.impl_abstract()
388
- >>> def custom_nonzero_abstract(x):
389
- >>> # Number of nonzero-elements is data-dependent.
390
- >>> # Since we cannot peek at the data in an abstract impl,
391
- >>> # we use the ctx object to construct a new symint that
392
- >>> # represents the data-dependent size.
393
- >>> ctx = torch._custom_op.get_ctx()
394
- >>> nnz = ctx.create_unbacked_symint()
395
- >>> shape = [x.dim(), nnz]
396
- >>> result = x.new_empty(shape, dtype=torch.long)
397
- >>> return result
398
- >>>
399
- >>> @custom_nonzero.impl(['cpu', 'cuda'])
400
- >>> def custom_nonzero_impl(x):
401
- >>> x_np = to_numpy(x)
402
- >>> res = np.stack(np.nonzero(x_np), axis=1)
403
- >>> # unbacked symbolic ints in PyTorch must be >= 2, so we
404
- >>> # constrain the range to at least 2
405
- >>> if res.shape[0] <= 1:
406
- >>> raise RuntimeError("not supported")
407
- >>> return torch.tensor(res, device=x.device)
408
-
409
- """
410
-
411
- def inner(f):
412
- self._check_doesnt_have_library_meta_impl()
413
- self._register_impl("abstract", f, stacklevel=_stacklevel)
414
- location = self._get_impl("abstract").location
415
-
416
- qualname = self._qualname
417
-
418
- # Handle DispatchKey.Meta registration
419
- @functools.wraps(f)
420
- def f_with_ctx(*args, **kwargs):
421
- def error_on_ctx():
422
- raise RuntimeError(
423
- f"Attempted to call get_ctx() for the meta implementation "
424
- f"for {qualname}. "
425
- f"You have presumably called get_ctx() because the operator "
426
- f"has a data-dependent output shape; if so, there is no "
427
- f"such meta implementation and this error is the correct "
428
- f"behavior. Otherwise, please remove the call to get_ctx() "
429
- f"in the implementation registered with impl_abstract "
430
- f"at {location}"
431
- )
432
-
433
- with torch._library.abstract_impl.set_ctx_getter(error_on_ctx):
434
- return f(*args, **kwargs)
435
-
436
- self._lib.impl(self._opname, f_with_ctx, "Meta")
437
- return f
438
-
439
- return inner
440
-
441
- def _check_can_register_backward(self):
442
- def error(detail):
443
- raise RuntimeError(
444
- f"Cannot use torch._custom_ops APIs to register backward "
445
- f"formula for {detail}. Got operator "
446
- f"{self._qualname} with schema: {schema}"
447
- )
448
-
449
- schema = self._schema
450
- if schema.kind() != SchemaKind.functional:
451
- error("non-functional operator")
452
-
453
- rets = schema.returns
454
- if not schema.returns:
455
- error("operator with no returns")
456
-
457
- assert len(rets) > 0
458
- is_non_mutating_view = any(
459
- r.annotation is not None and not r.annotation.is_write for r in rets
460
- )
461
- if is_non_mutating_view:
462
- error("operator that returns views")
463
-
464
- # We make assumptions about the schema's return types.
465
- allowed_return_types = {
466
- BaseType(BaseTy.int): "int",
467
- BaseType(BaseTy.SymInt): "SymInt",
468
- BaseType(BaseTy.bool): "bool",
469
- BaseType(BaseTy.float): "float",
470
- BaseType(BaseTy.Tensor): "Tensor",
471
- ListType(BaseType(BaseTy.Tensor), None): "List[Tensor]",
472
- }
473
- for ret in schema.returns:
474
- if ret.type in allowed_return_types:
475
- continue
476
- error(f"operator with return not in {list(allowed_return_types.values())} (got {ret.type})")
477
-
478
- def _check_doesnt_have_library_autograd_impl(self):
479
- if self._registered_autograd_kernel_indirection:
480
- return
481
-
482
- if _C._dispatch_has_kernel_for_dispatch_key(self._qualname, "CompositeImplicitAutograd"):
483
- raise RuntimeError(
484
- f"impl_backward/impl_save_for_backward: the operator {self._qualname} "
485
- f"already has an implementation for this device type via a "
486
- f"pre-existing registration to DispatchKey::CompositeImplicitAutograd. "
487
- f"CompositeImplicitAutograd operators do not need an autograd formula; "
488
- f"instead, the operator will decompose into its constituents and those "
489
- f"can have autograd formulas defined on them.")
490
-
491
- # We can improve this by adding "all Autograd<BACKEND> keys", but
492
- # realistically people will just be using this API for CPU/CUDA for now.
493
- for key in ["Autograd", "AutogradCPU", "AutogradCUDA"]:
494
- if _C._dispatch_has_kernel_for_dispatch_key(self._qualname, key):
495
- raise RuntimeError(
496
- f"impl_backward/impl_save_for_backward: "
497
- f"the operator {self._qualname} already has an Autograd kernel "
498
- f"registered to DispatchKey::{key} via a pre-existing "
499
- f"torch.library or TORCH_LIBRARY registration. Please either "
500
- f"remove those registrations or don't use the torch._custom_ops APIs")
501
-
502
- def _check_doesnt_have_library_meta_impl(self):
503
- if self._has_impl("abstract"):
504
- return
505
-
506
- # If the user's operator is CompositeExplicitAutograd,
507
- # allow them to impl_abstract. This is being pragmatic
508
- # (existing custom ops may have CompositeExplicitAutograd
509
- # registrations that don't work with Meta kernels, so this
510
- # gives them an escape hatch).
511
- if (
512
- _C._dispatch_has_kernel_for_dispatch_key(self._qualname, "CompositeExplicitAutograd")
513
- and not _C._dispatch_has_kernel_for_dispatch_key(self._qualname, "Meta")
514
- ):
515
- return
516
-
517
- # Otherwise, if the user's already has a Meta kernel or their
518
- # op is CompositeImplicitAutograd or some other alias dispatch key,
519
- # raise.
520
-
521
- # Special case for CompositeImplicitAutograd
522
- if _C._dispatch_has_kernel_for_dispatch_key(self._qualname, "CompositeImplicitAutograd"):
523
- raise RuntimeError(
524
- f"impl_abstract(...): the operator {self._qualname} "
525
- f"already has an implementation for this device type via a "
526
- f"pre-existing registration to DispatchKey::CompositeImplicitAutograd. "
527
- f"CompositeImplicitAutograd operators do not need an abstract impl; "
528
- f"instead, the operator will decompose into its constituents and those "
529
- f"can have abstract impls defined on them.")
530
-
531
- if _C._dispatch_has_kernel_for_dispatch_key(self._qualname, "Meta"):
532
- raise RuntimeError(
533
- f"impl_abstract(...): the operator {self._qualname} "
534
- f"already has a DispatchKey::Meta implementation via a "
535
- f"pre-existing torch.library or TORCH_LIBRARY registration. "
536
- f"Please either remove that registration or don't call impl_abstract.")
537
-
538
- # NOTE ["backward", "save_for_backward", and "autograd"]
539
- # As a part of the explicit autograd API, a user must provide us
540
- # a "save_for_backward" function and a "backward" function.
541
- # When both of these have been provided, then we automatically
542
- # construct the "autograd" kernel.
543
- def _register_autograd_kernel(self):
544
- assert self._has_impl("backward")
545
- assert self._has_impl("save_for_backward")
546
- kernel = construct_autograd_kernel(
547
- self._schema,
548
- self._output_differentiability,
549
- self,
550
- get_op(self._qualname),
551
- self._get_impl("save_for_backward").func,
552
- self._get_impl("backward").func)
553
- self._register_impl("autograd", kernel)
554
-
555
- def impl_save_for_backward(self, _stacklevel=2):
556
- r"""Register a function that tells us what to save for backward.
557
-
558
- Please see impl_backward for more details.
559
- """
560
- def inner(f):
561
- self._check_can_register_backward()
562
- self._check_doesnt_have_library_autograd_impl()
563
- if not self._registered_autograd_kernel_indirection:
564
- self._register_autograd_kernel_indirection()
565
- self._register_impl("save_for_backward", f, stacklevel=_stacklevel)
566
- if self._has_impl("backward"):
567
- self._register_autograd_kernel()
568
- return inner
569
-
570
- def impl_backward(self, output_differentiability=None, _stacklevel=2):
571
- r"""Registers a backward formula.
572
-
573
- WARNING: if you're a user, please do not use this directly
574
- (instead use the torch._custom_ops APIs).
575
- Also please see the following for a detailed guide on custom ops.
576
- https://docs.google.com/document/d/1aGWtgxV3HppuxQAdddyPrs74_aEntpkYt9MalnCKnhk
577
-
578
- In order for the CustomOp to work with autograd, you need to register
579
- a backward formula. There are two pieces to this:
580
- 1. You must give us a function to specify what to save for backward.
581
- Call this the "save for backward" function.
582
- 2. You must give us a function that computes gradients. Call this the
583
- "backward" function.
584
-
585
- Use `impl_save_for_backward` to define a "save for backward" function
586
- that specifies what gets saved for backward. The function should accept
587
- two arguments ``(inputs, output)`` and return the quantities to be saved
588
- for backward.
589
-
590
- During runtime, when you call the CustomOp, PyTorch will invoke the
591
- "save for backward" function with the inputs and output of the CustomOp.
592
-
593
- Use `impl_backward` to define the "backward" function. The backward
594
- function must accept ``(ctx, saved, *grads)``:
595
- - ``ctx`` is a context object where we may provide information
596
- - ``saved`` is exactly what gets returned from the "save for backward"
597
- function
598
- - ``grads`` is one or more gradients. The number of gradients matches
599
- the number of outputs of the CustomOp.
600
-
601
- The backward function must return a dict that maps the name of
602
- an input to the CustomOp to its corresponding gradient. All inputs that
603
- were declared to be Tensors in the CustomOp definition must be accounted
604
- for in the dict. The gradient may be a Tensor or None.
605
-
606
- """
607
- if output_differentiability is not None:
608
- def yell():
609
- raise RuntimeError(
610
- f"impl_backward(output_differentiability): expected "
611
- f"output_differentiability to be a list of bools with "
612
- f"length equal to the number of outputs of this CustomOp. "
613
- f"Got: {output_differentiability}")
614
-
615
- if not isinstance(output_differentiability, list):
616
- yell()
617
- for diff in output_differentiability:
618
- if not isinstance(diff, bool):
619
- yell()
620
- if len(self._schema.returns) != len(output_differentiability):
621
- yell()
622
-
623
- def inner(f):
624
- self._check_can_register_backward()
625
- self._check_doesnt_have_library_autograd_impl()
626
- if not self._registered_autograd_kernel_indirection:
627
- self._register_autograd_kernel_indirection()
628
- self._register_impl("backward", f, stacklevel=_stacklevel)
629
- self._output_differentiability = output_differentiability
630
- if self._has_impl("save_for_backward"):
631
- self._register_autograd_kernel()
632
- return inner
633
-
634
-
635
- @dataclasses.dataclass
636
- class FuncAndLocation:
637
- func: typing.Callable
638
- location: str
639
-
640
-
641
- def find_ophandle_or_throw(cpp_ns: str, operator_name: OperatorName):
642
- overload_name = (
643
- "" if operator_name.overload_name is None else operator_name.overload_name
644
- )
645
- return _C._dispatch_find_schema_or_throw(
646
- f"{cpp_ns}::{str(operator_name.name)}", overload_name
647
- )
648
-
649
-
650
- def validate_namespace(ns: str) -> None:
651
- if "." in ns:
652
- raise ValueError(
653
- f'custom_op(..., ns="{ns}"): expected ns to not contain any . (and be a '
654
- f"valid variable name)"
655
- )
656
- if ns in RESERVED_NS:
657
- raise ValueError(
658
- f"custom_op(..., ns='{ns}'): '{ns}' is a reserved namespace, "
659
- f"please choose something else. "
660
- )
661
-
662
- def validate_schema(schema: FunctionSchema) -> None:
663
- if not torch._library.utils.is_functional_schema(schema):
664
- raise ValueError(
665
- f"custom_op only supports functional operators "
666
- f"(ops that do not mutate any inputs, do not return "
667
- f"views of the inputs, and have at least one return). "
668
- f"Got the following non-functional schema: {schema}"
669
- )
670
-
671
- # For simplicity: don't allow self arguments
672
- if schema.arguments.self_arg is not None:
673
- raise ValueError(
674
- f"custom_op does not support arguments named 'self'. Please "
675
- f"rename your argument. Got: {schema}"
676
- )
677
-
678
-
679
- def parse_qualname(qualname: str) -> typing.Tuple[str, str]:
680
- names = qualname.split("::", 1)
681
- if len(names) != 2:
682
- raise ValueError(f"Expected there to be a namespace in {qualname}, i.e. the "
683
- f"operator name should look something like ns::foo")
684
- if '.' in names[1]:
685
- raise ValueError(f"The torch.custom_ops APIs do not handle overloads, "
686
- f"i.e. operator names with '.' in them. "
687
- f"Please name your operator something like ns::foo. "
688
- f"Got: {qualname}")
689
- return names[0], names[1]
690
-
691
-
692
- def validate_device_type(device_type: str) -> None:
693
- if device_type not in SUPPORTED_DEVICE_TYPE_TO_KEY:
694
- raise ValueError(
695
- f"CustomOp.impl(device_types=[{device_type}, ...]): we only support device_type "
696
- f"in {SUPPORTED_DEVICE_TYPE_TO_KEY.keys()}."
697
- )
698
-
699
-
700
- def supported_param(param: inspect.Parameter) -> bool:
701
- return param.kind in (
702
- inspect.Parameter.POSITIONAL_OR_KEYWORD,
703
- inspect.Parameter.KEYWORD_ONLY,
704
- )
705
-
706
-
707
- def validate_function_matches_schema(
708
- schema: FunctionSchema, func: typing.Callable
709
- ) -> None:
710
- sig = inspect.signature(func)
711
-
712
- if not all(supported_param(p) for _, p in sig.parameters.items()):
713
- raise ValueError(
714
- f"custom_op(..., manual_schema)(func): positional-only args, "
715
- f"varargs, and kwargs are not supported. Please rewrite `func` "
716
- f"to not have them. Got `func` with signature: {sig}"
717
- )
718
-
719
- if (
720
- any(
721
- p.annotation is not inspect.Parameter.empty
722
- for _, p in sig.parameters.items()
723
- )
724
- or sig.return_annotation is not inspect.Signature.empty
725
- ):
726
- raise ValueError(
727
- f"custom_op(..., manual_schema)(func): When passing in a manual "
728
- f"schema, we expect `func` to have no type annotations to avoid "
729
- f"ambiguity. Got `func` with signature: {sig}"
730
- )
731
-
732
- positional = [
733
- (name, param)
734
- for name, param in sig.parameters.items()
735
- if param.kind == inspect.Parameter.POSITIONAL_OR_KEYWORD
736
- ]
737
- kwargonly = [
738
- (name, param)
739
- for name, param in sig.parameters.items()
740
- if param.kind == inspect.Parameter.KEYWORD_ONLY
741
- ]
742
-
743
- def error():
744
- raise ValueError(
745
- f"custom_op(..., manual_schema)(func): When passing in a manual "
746
- f"schema, we expect `func`'s signature to match `manual_schema` "
747
- f"(aside from type annotations). "
748
- f"func's signature: {sig}, manual_schema: {schema}"
749
- )
750
-
751
- def error_default_args():
752
- raise ValueError(
753
- f"custom_op(..., manual_schema)(func): "
754
- f"neither func nor manual_schema should have default "
755
- f"arguments. Got "
756
- f"func's signature: {sig}, manual_schema: {schema}"
757
- )
758
-
759
- def compare(sig_args, schema_args):
760
- if len(sig_args) != len(schema_args):
761
- error()
762
- for (name, param), arg in zip(sig_args, schema_args):
763
- if name != arg.name:
764
- error()
765
- if param.default is not inspect.Parameter.empty or arg.default is not None:
766
- error_default_args()
767
-
768
- compare(positional, schema.arguments.flat_positional)
769
- compare(kwargonly, schema.arguments.flat_kwarg_only)
770
-
771
-
772
- def infer_schema(prototype_function: typing.Callable) -> str:
773
- sig = inspect.signature(prototype_function)
774
-
775
- def error_fn(what):
776
- raise ValueError(
777
- f"custom_op(...)(func): {what} " f"Got func with signature {sig}"
778
- )
779
-
780
- params = [
781
- parse_param(name, param, error_fn) for name, param in sig.parameters.items()
782
- ]
783
- ret = parse_return(sig.return_annotation, error_fn)
784
- return f"({', '.join(params)}) -> {ret}"
785
-
786
-
787
- def parse_param(name, param, error_fn):
788
- if not supported_param(param):
789
- error_fn("We do not support positional-only args, varargs, or varkwargs.")
790
-
791
- if param.annotation is inspect.Parameter.empty:
792
- error_fn(f"Parameter {name} must have a type annotation.")
793
-
794
- if param.annotation not in SUPPORTED_PARAM_TYPES.keys():
795
- error_fn(
796
- f"Parameter {name} has unsupported type {param.annotation}. "
797
- f"The valid types are: {SUPPORTED_PARAM_TYPES.keys()}."
798
- )
799
-
800
- if param.default is not inspect.Parameter.empty:
801
- error_fn(
802
- f"Parameter {name} has a default value; this is not supported. "
803
- f"If you want to use default values then create a function with "
804
- f"default values that calls the CustomOp"
805
- )
806
-
807
- return f"{SUPPORTED_PARAM_TYPES[param.annotation]} {name}"
808
-
809
-
810
- def derived_types(
811
- base_type, cpp_type, list_base, optional_base_list, optional_list_base
812
- ):
813
- result = [
814
- (base_type, cpp_type),
815
- (typing.Optional[base_type], f"{cpp_type}?"),
816
- ]
817
- if list_base:
818
- result.append((typing.Sequence[base_type], f"{cpp_type}[]")) # type: ignore[valid-type]
819
- if optional_base_list:
820
- result.append((typing.Sequence[typing.Optional[base_type]], f"{cpp_type}?[]")) # type: ignore[valid-type]
821
- if optional_list_base:
822
- result.append((typing.Optional[typing.Sequence[base_type]], f"{cpp_type}[]?")) # type: ignore[valid-type]
823
- return result
824
-
825
-
826
- def get_supported_param_types():
827
- data = [
828
- # (python type, schema type, type[] variant, type?[] variant, type[]? variant)
829
- (torch.Tensor, "Tensor", True, True, False),
830
- (int, "SymInt", True, False, True),
831
- (float, "float", True, False, True),
832
- (bool, "bool", True, False, True),
833
- (str, "str", False, False, False),
834
- (torch.types.Number, "Scalar", True, False, False),
835
- (torch.dtype, "ScalarType", False, False, False),
836
- (torch.device, "Device", False, False, False),
837
- ]
838
- result = []
839
- for line in data:
840
- result.extend(derived_types(*line))
841
- return dict(result)
842
-
843
-
844
- SUPPORTED_RETURN_TYPES = {
845
- torch.Tensor: "Tensor",
846
- typing.List[torch.Tensor]: "Tensor[]",
847
- int: "SymInt",
848
- float: "float",
849
- bool: "bool",
850
- torch.types.Number: "Scalar",
851
- }
852
-
853
-
854
- def parse_return(annotation, error_fn):
855
- origin = typing.get_origin(annotation)
856
- if origin is not tuple:
857
- if annotation not in SUPPORTED_RETURN_TYPES.keys():
858
- error_fn(
859
- f"Return has unsupported type {annotation}. "
860
- f"The valid types are: {SUPPORTED_RETURN_TYPES}."
861
- )
862
- return SUPPORTED_RETURN_TYPES[annotation]
863
-
864
- args = typing.get_args(annotation)
865
- for arg in args:
866
- if arg not in SUPPORTED_RETURN_TYPES:
867
- error_fn(
868
- f"Return has unsupported type {annotation}. "
869
- f"The valid types are: {SUPPORTED_RETURN_TYPES}."
870
- )
871
-
872
- return "(" + ", ".join([SUPPORTED_RETURN_TYPES[arg] for arg in args]) + ")"
873
-
874
-
875
- SUPPORTED_PARAM_TYPES = get_supported_param_types()
876
-
877
-
878
- def report_error_callback(custom_op: typing.Any, key: str) -> None:
879
- if key == "Undefined":
880
- raise NotImplementedError(
881
- f"{custom_op}: There were no Tensor inputs to this operator "
882
- f"(e.g. you passed an empty list of Tensors). If your operator is a "
883
- f"factory function (that is, it takes no Tensors and constructs "
884
- f"a new one), then please use CustomOp.impl_factory to register "
885
- f"an implementation for it"
886
- )
887
- if key == "Meta":
888
- raise NotImplementedError(
889
- f"{custom_op}: when running with device='Meta' tensors: there is no "
890
- f"abstract impl registered for this CustomOp. Please register one via "
891
- f"CustomOp.impl_abstract to get this CustomOp to work with Meta tensors"
892
- )
893
- if key in ("CPU", "CUDA"):
894
- device = key.lower()
895
- raise NotImplementedError(
896
- f"{custom_op}: when running with device='{device}' tensors: there is no "
897
- f"{device} impl registered for this CustomOp. Please register one via "
898
- f"CustomOp.impl(device_type='{device}')"
899
- )
900
- raise NotImplementedError(
901
- f"{custom_op}: No implementation for dispatch key {key}. It is likely "
902
- f"that we have not added this functionality yet, please either open an "
903
- f"issue or if you're feeling adventurous, use the low-level "
904
- f"torch.library API"
905
- )
906
-
907
-
908
- def custom_op_from_existing(op):
909
- ns = op.namespace
910
- lib = torch.library.Library(ns, "FRAGMENT")
911
- name = op.name().split("::")[-1]
912
- schema_str = str(op._schema)
913
- # CustomOp expects the schema string without the namespace
914
- schema_str = schema_str.split("::")[-1]
915
- schema = FunctionSchema.parse(schema_str)
916
- return CustomOp(lib, ns, schema, name, op, _private_access=True)
917
-
918
-
919
- def get_op(qualname):
920
- def error_not_found():
921
- raise ValueError(
922
- f"Could not find the operator {qualname}. Please make sure you have "
923
- f"already registered the operator and (if registered from C++) "
924
- f"loaded it via torch.ops.load_library.")
925
-
926
- ns, name = parse_qualname(qualname)
927
- if not hasattr(torch.ops, ns):
928
- error_not_found()
929
- opnamespace = getattr(torch.ops, ns)
930
- if not hasattr(opnamespace, name):
931
- error_not_found()
932
- packet = getattr(opnamespace, name)
933
- if not hasattr(packet, 'default'):
934
- error_not_found()
935
- return packet.default
936
-
937
-
938
- def _find_custom_op(qualname, also_check_torch_library=False):
939
- if qualname in global_registry:
940
- return global_registry[qualname]
941
- if not also_check_torch_library:
942
- raise RuntimeError(
943
- f"Could not find custom op \"{qualname}\". Did you register it via "
944
- f"the torch._custom_ops API?")
945
- overload = get_op(qualname)
946
- result = custom_op_from_existing(overload)
947
- return result
948
-
949
-
950
- def get_abstract_impl(qualname):
951
- if qualname not in torch._custom_op.impl.global_registry:
952
- return None
953
- custom_op = torch._custom_op.impl.global_registry[qualname]
954
- if custom_op is None:
955
- return None
956
- if not custom_op._has_impl("abstract"):
957
- return None
958
- return custom_op._get_impl("abstract").func
959
-
960
-
961
- def _custom_op_with_schema(qualname, schema):
962
- ns, name = qualname.split("::")
963
- schema_str = f"{name}{schema}"
964
- function_schema = FunctionSchema.parse(schema_str)
965
- validate_schema(function_schema)
966
-
967
- lib = library.Library(ns, "FRAGMENT")
968
- lib.define(schema_str)
969
- ophandle = find_ophandle_or_throw(ns, function_schema.name)
970
- result = CustomOp(lib, ns, function_schema, name, ophandle, _private_access=True)
971
- result._register_autograd_kernel_indirection()
972
-
973
- torch._C._dispatch_set_report_error_callback(
974
- ophandle, functools.partial(report_error_callback, weakref.proxy(result))
975
- )
976
- return get_op(qualname)
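
A minimal illustration of the schema inference implemented in the file above, assuming a build where torch._custom_op.impl is still importable (this commit deletes it); the name numpy_sin is purely illustrative. infer_schema maps annotated parameters through SUPPORTED_PARAM_TYPES and the return annotation through SUPPORTED_RETURN_TYPES:

import torch
from torch import Tensor
from torch._custom_op.impl import infer_schema

# Prototype with annotations only; parse_param rejects default values,
# varargs, and positional-only parameters.
def numpy_sin(x: Tensor) -> Tensor:
    raise NotImplementedError()

# Tensor maps to "Tensor" on both sides, so this prints "(Tensor x) -> Tensor",
# the string later handed to FunctionSchema.parse by custom_op.
print(infer_schema(numpy_sin))
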
torch/_custom_ops.py DELETED
@@ -1,322 +0,0 @@
1
- import inspect
2
-
3
- from torch._custom_op.impl import (
4
- _custom_op_with_schema,
5
- _find_custom_op,
6
- infer_schema,
7
- parse_qualname,
8
- validate_namespace,
9
- )
10
- from torch.library import get_ctx
11
-
12
- __all__ = [
13
- "custom_op",
14
- "impl",
15
- "impl_abstract",
16
- "get_ctx",
17
- "impl_save_for_backward",
18
- "impl_backward",
19
- ]
20
-
21
-
22
- def custom_op(qualname, func_or_schema=None):
23
- r"""Register a new custom operator
24
-
25
- In PyTorch, defining an op (short for "operator") is a two-step process:
26
- - we need to define the op (by providing an operator name and schema)
27
- - we need to implement behavior for how the operator interacts with
28
- various PyTorch subsystems, like CPU/CUDA Tensors, Autograd, etc.
29
-
30
- This entrypoint defines the custom operator (the first step);
31
- you must then perform the second step by calling various
32
- ``impl_*`` APIs.
33
-
34
- This API may be used as a decorator (see examples).
35
-
36
- For a detailed guide on custom ops, please see
37
- https://docs.google.com/document/d/1aGWtgxV3HppuxQAdddyPrs74_aEntpkYt9MalnCKnhk
38
-
39
- Arguments:
40
- qualname (str): Should be a string that looks like
41
- "namespace::operator_name". Operators in PyTorch need a namespace to
42
- avoid name collisions; a given operator may only be created once.
43
- If you are writing a Python library, we recommend the namespace to
44
- be the name of your top-level module.
45
- func_or_schema (Union[Callable, str]): Each PyTorch operator needs a
46
- schema that tells PyTorch the types of the inputs/outputs.
47
- If this is a Callable, we will automatically infer the schema from
48
- the type annotations on the function (see examples). Otherwise,
49
- if you don't want to use type annotations, you may provide us the
50
- schema string.
51
-
52
- Example::
53
- >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA)
54
- >>> import torch
55
- >>> import numpy as np
56
- >>> from torch import Tensor
57
- >>>
58
- >>> # Step 1: define the custom op.
59
- >>> # We need to provide the API a "prototype function"
60
- >>> # (a function that raises NotImplementedError), from which
61
- >>> # we will infer the types of the inputs and outputs.
62
- >>> @torch._custom_ops.custom_op("mylibrary::numpy_sin")
63
- >>> def numpy_sin(x: Tensor) -> Tensor:
64
- >>> raise NotImplementedError()
65
- >>>
66
- >>> # The custom op is now accessible via the torch.ops module:
67
- >>> torch.ops.mylibrary.numpy_sin
68
- >>>
69
- >>> # Step 2: Register an implementation for various PyTorch subsystems
70
- >>>
71
- >>> # Register an implementation for CPU tensors
72
- >>> @torch._custom_ops.impl("mylibrary::numpy_sin", device_types="cpu")
73
- >>> def numpy_sin_impl_cpu(x):
74
- >>> return torch.from_numpy(np.sin(x.numpy()))
75
- >>>
76
- >>> # Register an implementation for CUDA tensors
77
- >>> @torch._custom_ops.impl("mylibrary::numpy_sin", device_types="cuda")
78
- >>> def numpy_sin_impl_cuda(x):
79
- >>> return torch.from_numpy(np.sin(x.cpu().numpy())).to(x.device)
80
- >>>
81
- >>> x = torch.randn(3)
82
- >>> torch.ops.mylibrary.numpy_sin(x) # calls numpy_sin_impl_cpu
83
- >>>
84
- >>> x_cuda = x.cuda()
85
- >>> torch.ops.mylibrary.numpy_sin(x_cuda) # calls numpy_sin_impl_cuda
86
-
87
- """
88
- ns, name = parse_qualname(qualname)
89
- validate_namespace(ns)
90
-
91
- def inner(func):
92
- if not inspect.isfunction(func):
93
- raise ValueError(
94
- f"custom_op(...)(func): Expected `func` to be a Python "
95
- f"function, got: {type(func)}"
96
- )
97
-
98
- if func.__name__ != name:
99
- raise ValueError(
100
- f"custom_op(qualname='{qualname}', ...)(func): expected `func` "
101
- f"to have name '{name}' but got '{func.__name__}'. "
102
- f"Please either change the name of `func` or the qualname that "
103
- f"is passed to `custom_op`"
104
- )
105
-
106
- schema = infer_schema(func)
107
- _custom_op_with_schema(qualname, schema)
108
- return func
109
-
110
- if func_or_schema is None:
111
- return inner
112
- if isinstance(func_or_schema, str):
113
- _custom_op_with_schema(qualname, func_or_schema)
114
- else:
115
- return inner(func_or_schema)
116
-
117
-
118
- def impl(qualname, *, device_types=("cpu", "cuda"), func=None):
119
- r"""Register an implementation for a device type for this custom op.
120
-
121
- If the op is passed multiple Tensor inputs with different device
122
- types, it will dispatch to the registered implementation for the highest
123
- priority device type among those present.
124
- The supported device types, in order of priority, are {'cuda', 'cpu'}.
125
-
126
- This API may be used as a decorator (see examples).
127
-
128
- For a detailed guide on custom ops, please see
129
- https://docs.google.com/document/d/1aGWtgxV3HppuxQAdddyPrs74_aEntpkYt9MalnCKnhk
130
-
131
- Arguments:
132
- device_types (str or Iterable[str]): the device type(s) to register the function for.
133
-
134
- Example::
135
- >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA)
136
- >>> import torch
137
- >>> import numpy as np
138
- >>> from torch import Tensor
139
- >>>
140
- >>> # Step 1: define the custom op.
141
- >>> # We need to provide the API a "prototype function"
142
- >>> # (a function that raises NotImplementedError), from which
143
- >>> # we will infer the types of the inputs and outputs.
144
- >>> @torch._custom_ops.custom_op("mylibrary::numpy_cos")
145
- >>> def numpy_cos(x: Tensor) -> Tensor:
146
- >>> raise NotImplementedError()
147
- >>>
148
- >>> # The custom op is now accessible via the torch.ops module:
149
- >>> torch.ops.mylibrary.numpy_cos
150
- >>>
151
- >>> # Step 2: Register an implementation for various PyTorch subsystems
152
- >>>
153
- >>> # Register an implementation for CPU tensors
154
- >>> @torch._custom_ops.impl("mylibrary::numpy_cos", device_types="cpu")
155
- >>> def numpy_cos_impl_cpu(x):
156
- >>> return torch.from_numpy(np.cos(x.numpy()))
157
- >>>
158
- >>> # Register an implementation for CUDA tensors
159
- >>> @torch._custom_ops.impl("mylibrary::numpy_cos", device_types="cuda")
160
- >>> def numpy_cos_impl_cuda(x):
161
- >>> return torch.from_numpy(np.cos(x.cpu().numpy())).to(x.device)
162
- >>>
163
- >>> x = torch.randn(3)
164
- >>> torch.ops.mylibrary.numpy_cos(x) # calls numpy_cos_impl_cpu
165
- >>>
166
- >>> x_cuda = x.cuda()
167
- >>> torch.ops.mylibrary.numpy_cos(x_cuda) # calls numpy_cos_impl_cuda
168
-
169
- """
170
-
171
- def inner(func):
172
- custom_op = _find_custom_op(qualname, also_check_torch_library=True)
173
- custom_op.impl(device_types, _stacklevel=3)(func)
174
- return func
175
-
176
- if func is None:
177
- return inner
178
- return inner(func)
179
-
180
-
181
- def impl_abstract(qualname, *, func=None):
182
- r"""Register an abstract implementation for this operator.
183
-
184
- An "abstract implementation" specifies the behavior of this operator on
185
- Tensors that carry no data. Given some input Tensors with certain properties
186
- (sizes/strides/storage_offset/device), it specifies what the properties of
187
- the output Tensors are.
188
-
189
- The abstract implementation has the same signature as the operator.
190
- It is run for both FakeTensors and meta tensors. To write an abstract
191
- implementation, assume that all Tensor inputs to the operator are
192
- regular CPU/CUDA/Meta tensors, but they do not have storage, and
193
- you are trying to return regular CPU/CUDA/Meta tensor(s) as output.
194
- The abstract implementation must consist of only PyTorch operations
195
- (and may not directly access the storage or data of any input or
196
- intermediate Tensors).
197
-
198
- This API may be used as a decorator (see examples).
199
-
200
- For a detailed guide on custom ops, please see
201
- https://docs.google.com/document/d/1aGWtgxV3HppuxQAdddyPrs74_aEntpkYt9MalnCKnhk
202
-
203
- Examples::
204
- >>> import numpy as np
205
- >>> from torch import Tensor
206
- >>>
207
- >>> # Example 1: an operator without data-dependent output shape
208
- >>> @torch._custom_ops.custom_op("mylibrary::custom_linear")
209
- >>> def custom_linear(x: Tensor, weight: Tensor, bias: Tensor) -> Tensor:
210
- >>> raise NotImplementedError()
211
- >>>
212
- >>> @torch._custom_ops.impl_abstract("mylibrary::custom_linear")
213
- >>> def custom_linear_abstract(x, weight, bias):
214
- >>> assert x.dim() == 2
215
- >>> assert weight.dim() == 2
216
- >>> assert bias.dim() == 1
217
- >>> assert x.shape[1] == weight.shape[1]
218
- >>> assert weight.shape[0] == bias.shape[0]
219
- >>> assert x.device == weight.device
220
- >>>
221
- >>> return (x @ weight.t()) + bias
222
- >>>
223
- >>> # Example 2: an operator with data-dependent output shape
224
- >>> @torch._custom_ops.custom_op('mylibrary::custom_nonzero')
225
- >>> def custom_nonzero(x: Tensor) -> Tensor:
226
- >>> ...
227
- >>>
228
- >>> @torch._custom_ops.impl_abstract("mylibrary::custom_nonzero")
229
- >>> def custom_nonzero_abstract(x):
230
- >>> # Number of nonzero-elements is data-dependent.
231
- >>> # Since we cannot peek at the data in an abstract impl,
232
- >>> # we use the ctx object to construct a new symint that
233
- >>> # represents the data-dependent size.
234
- >>> ctx = torch._custom_ops.get_ctx()
235
- >>> nnz = ctx.create_unbacked_symint()
236
- >>> shape = [x.dim(), nnz]
237
- >>> result = x.new_empty(shape, dtype=torch.long)
238
- >>> return result
239
- >>>
240
- >>> @torch._custom_ops.impl("mylibrary::custom_nonzero")
241
- >>> def custom_nonzero_impl(x):
242
- >>> x_np = to_numpy(x)
243
- >>> res = np.stack(np.nonzero(x_np), axis=1)
244
- >>> # unbacked symbolic ints in PyTorch must be >= 2, so we
245
- >>> # constrain the range to at least 2
246
- >>> if res.shape[0] <= 1:
247
- >>> raise RuntimeError("not supported")
248
- >>> return torch.tensor(res, device=x.device)
249
-
250
- """
251
- import torch.library
252
-
253
- return torch.library.impl_abstract(qualname, func, _stacklevel=2)
254
-
255
-
256
- def impl_save_for_backward(qualname, *, func=None):
257
- r"""Register a function that tells us what to save for backward.
258
-
259
- Please see :func:`impl_backward` for more details.
260
- """
261
-
262
- def inner(func):
263
- custom_op = _find_custom_op(qualname, also_check_torch_library=True)
264
- custom_op.impl_save_for_backward(_stacklevel=3)(func)
265
- return func
266
-
267
- if func is None:
268
- return inner
269
- return inner(func)
270
-
271
-
272
- def impl_backward(qualname, output_differentiability=None, *, func=None):
273
- r"""Registers a backward formula for an operator.
274
-
275
- In order for an operator to work with autograd, you need to register
276
- a backward formula. There are two pieces to this:
277
- 1. You must give us a function to specify what to save for backward.
278
- Call this the "save for backward" function.
279
- 2. You must give us a function that computes gradients. Call this the
280
- "backward" function.
281
-
282
- Use `impl_save_for_backward` to define a "save for backward" function
283
- that specifies what gets saved for backward. The function should accept
284
- two arguments ``(inputs, output)`` and return the quantities to be saved
285
- for backward.
286
-
287
- During runtime, when you call the operator in a forwards pass, PyTorch
288
- will invoke the "save for backward" function with the inputs and output
289
- of the operator.
290
-
291
- Use `impl_backward` to define the "backward" function. The backward
292
- function must accept ``(ctx, saved, *grads)``:
293
- - ``ctx`` is a context object where we may provide information
294
- - ``saved`` is exactly what gets returned from the "save for backward"
295
- function
296
- - ``grads`` is one or more gradients. The number of gradients matches
297
- the number of outputs of the operator.
298
-
299
- The backward function must return a dict that maps the name of
300
- an input to the operator to its corresponding gradient. All inputs that
301
- were declared to be Tensors in the operator definition must be accounted
302
- for in the dict. The gradient may be a Tensor or None.
303
-
304
- For a detailed guide on custom ops, please see
305
- https://docs.google.com/document/d/1aGWtgxV3HppuxQAdddyPrs74_aEntpkYt9MalnCKnhk
306
-
307
- """
308
-
309
- def inner(func):
310
- custom_op = _find_custom_op(qualname, also_check_torch_library=True)
311
- custom_op.impl_backward(output_differentiability, _stacklevel=3)(func)
312
- return func
313
-
314
- if func is None:
315
- return inner
316
- return inner(func)
317
-
318
-
319
- def _destroy(qualname):
320
- """De-registers a custom op. For testing purposes only"""
321
- custom_op = _find_custom_op(qualname)
322
- custom_op._destroy()
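
The impl_backward docstring above spells out the save-for-backward/backward contract only in prose. A hedged sketch of the full registration flow, again assuming a build that still ships these APIs; the namespace "mylib" and the attribute-style access inputs.x are illustrative assumptions rather than documented guarantees:

import numpy as np
import torch
import torch._custom_ops

@torch._custom_ops.custom_op("mylib::numpy_sin")
def numpy_sin(x: torch.Tensor) -> torch.Tensor:
    raise NotImplementedError()

@torch._custom_ops.impl("mylib::numpy_sin")
def numpy_sin_impl(x):
    return torch.from_numpy(np.sin(x.cpu().numpy())).to(x.device)

# "Save for backward": receives (inputs, output), returns whatever backward needs.
@torch._custom_ops.impl_save_for_backward("mylib::numpy_sin")
def numpy_sin_save_for_backward(inputs, output):
    return inputs.x  # assumed attribute-style access to the op's arguments

# Backward: receives (ctx, saved, *grads) and returns a dict mapping each
# Tensor input name to its gradient (or None).
@torch._custom_ops.impl_backward("mylib::numpy_sin")
def numpy_sin_backward(ctx, saved, grad_out):
    return {"x": grad_out * saved.cos()}

x = torch.randn(3, requires_grad=True)
torch.ops.mylib.numpy_sin(x).sum().backward()
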
torch/_decomp/__init__.py DELETED
@@ -1,444 +0,0 @@
1
- import inspect
2
- from collections import defaultdict
3
- from functools import wraps
4
- from itertools import chain
5
- from typing import Callable, Dict, List, Sequence, Union
6
-
7
- import torch
8
- import torch.library
9
- from torch._ops import HigherOrderOperator, OpOverload, OpOverloadPacket
10
- from torch._prims_common import CustomOutParamAnnotation
11
- from torch.utils import _pytree as pytree
12
-
13
- __all__ = [
14
- "decomposition_table",
15
- "pre_autograd_decomposition_table",
16
- "meta_table",
17
- "register_decomposition",
18
- "get_decompositions",
19
- "core_aten_decompositions",
20
- ]
21
-
22
-
23
- # TODO: relax key type here; torch registrations should be possible too; but
24
- # right now this type is accurate
25
- global_decomposition_table: Dict[
26
- str, Dict[torch._ops.OperatorBase, Callable]
27
- ] = defaultdict(dict)
28
-
29
- decomposition_table = global_decomposition_table["post_autograd"]
30
- pre_autograd_decomposition_table = global_decomposition_table["pre_autograd"]
31
- meta_table = global_decomposition_table["meta"]
32
-
33
-
34
- def _add_op_to_registry(registry, op, fn):
35
- """
36
- This is an internal API for adding an op to the decomposition table.
37
-
38
- If op is OpOverload, it will be added to the registry directly.
39
- If op is OpOverloadPacket, all the valid op_overloads in the packet will be added to the registry.
40
- """
41
- overloads: List[torch._ops.OperatorBase] = []
42
- if isinstance(op, HigherOrderOperator):
43
- # There's no concept of overloads for HigherOrderOperator
44
- registry[op] = fn
45
- return
46
- elif isinstance(op, OpOverload):
47
- overloads.append(op)
48
- else:
49
- assert isinstance(op, OpOverloadPacket)
50
- for ol in op.overloads():
51
- overloads.append(getattr(op, ol))
52
-
53
- for op_overload in overloads:
54
- if op_overload in registry:
55
- raise RuntimeError(f"duplicate registrations for {op_overload}")
56
- # TorchScript dumps a bunch of extra nonsense overloads
57
- # which don't have corresponding dispatcher entries; we need
58
- # to filter those out, e.g. aten.add.float_int
59
- if torch._C._dispatch_has_kernel(op_overload.name()):
60
- registry[op_overload] = fn
61
-
62
-
63
- def _convert_out_params(f):
64
- out_annotation = f.__annotations__.get("out")
65
-
66
- # If there are no out params, do not wrap the function.
67
- if not out_annotation:
68
- return f
69
-
70
- # Hack to detect when out is a Tuple. There seems to be no pretty way of doing this
71
- if getattr(out_annotation, "__origin__", None) is tuple:
72
- sig = inspect.signature(f)
73
- out_names = sig.return_annotation._fields
74
- # If out is a tuple, we need to register a function that unpacks all the out
75
- # elements as this is what native_functions.yaml expects
76
-
77
- @wraps(f)
78
- def _fn(*args, **kwargs):
79
- out_kwargs = tuple(kwargs.pop(o, None) for o in out_names)
80
- # Either all of the out kwargs are set or none of them
81
- is_none = out_kwargs[0] is None
82
- assert all((o is None) == is_none for o in out_kwargs)
83
- return f(*args, **kwargs, out=None if is_none else out_kwargs)
84
-
85
- out_params = [
86
- inspect.Parameter(
87
- o,
88
- kind=inspect.Parameter.KEYWORD_ONLY,
89
- default=None,
90
- annotation=t,
91
- )
92
- for o, t in zip(out_names, out_annotation.__args__)
93
- ]
94
- # Drop the out parameter and concatenate the new kwargs in the signature
95
- params = chain((v for k, v in sig.parameters.items() if k != "out"), out_params)
96
- _fn.__signature__ = inspect.Signature( # type: ignore[attr-defined]
97
- parameters=params, return_annotation=sig.return_annotation # type: ignore[arg-type]
98
- )
99
- # Drop the out parameter and concatenate the new kwargs in the annotations
100
- _fn.__annotations__ = {k: v for k, v in f.__annotations__.items() if k != "out"}
101
- for o in out_params:
102
- _fn.__annotations__[o.name] = o.annotation
103
-
104
- # Propagate that this function is wrapped by `out_wrapper`
105
- _fn._torch_decompositions_out_wrapper = f._torch_decompositions_out_wrapper # type: ignore[attr-defined]
106
-
107
- return _fn
108
-
109
- # Alternatively, there may be a single tensor out parameter with a name
110
- # other than "out". This will need special treatment and is indicated by an
111
- # annotation, which we will remove here so it is not exposed after wrapping.
112
- custom_out_param_name = f.__annotations__.pop(CustomOutParamAnnotation, None)
113
- if custom_out_param_name:
114
-
115
- @wraps(f)
116
- def _fn(*args, **kwargs):
117
- out_kwarg = kwargs.pop(custom_out_param_name, None)
118
- return f(*args, **kwargs, out=out_kwarg)
119
-
120
- out_param = inspect.Parameter(
121
- custom_out_param_name,
122
- kind=inspect.Parameter.KEYWORD_ONLY,
123
- default=None,
124
- annotation=out_annotation,
125
- )
126
-
127
- # Drop the out parameter and concatenate the new kwarg in the signature
128
- sig = inspect.signature(f)
129
- params = chain(
130
- (v for k, v in sig.parameters.items() if k != "out"), (out_param,)
131
- )
132
- _fn.__signature__ = inspect.Signature( # type: ignore[attr-defined]
133
- parameters=params, return_annotation=sig.return_annotation # type: ignore[arg-type]
134
- )
135
-
136
- # Drop the out parameter and concatenate the new kwargs in the annotations
137
- _fn.__annotations__ = {k: v for k, v in f.__annotations__.items() if k != "out"}
138
- _fn.__annotations__[out_param.name] = out_param.annotation
139
-
140
- return _fn
141
-
142
- return f
143
-
144
-
145
- def register_decomposition(
146
- aten_op, registry=None, *, type="post_autograd", unsafe=False
147
- ):
148
- """
149
- A decorator to register a function as a decomposition to the Python
150
- decomposition table. Use it like this::
151
-
152
- @register_decomposition(torch.ops.aten.clamp_min)
153
- def clamp_min(x, min):
154
- return torch.clamp(x, min=min)
155
-
156
- If you are writing a new decomposition, consider contributing it
157
- directly to PyTorch in torch._decomp.decompositions.
158
-
159
- This API is experimental; we are almost certainly going to extend
160
- the API when we make decompositions eligible for use in transforms (e.g.,
161
- autograd) and not just backend tracing, where we then need to know if a
162
- decomposition can be used to simulate a transform.
163
-
164
- By default, we will also register it to the Meta key of the dispatcher,
165
- and replace the C++ Meta implementation if there is already one.
166
-
167
- The unsafe kwarg is for reusing this function to register non-function
168
- things
169
- """
170
-
171
- assert type in {"post_autograd", "pre_autograd", "meta"}
172
-
173
- def decomposition_decorator(fn: Callable) -> Callable:
174
- if not unsafe:
175
- fn = _convert_out_params(fn)
176
-
177
- nonlocal registry
178
- if registry is None:
179
- registry = global_decomposition_table[type]
180
-
181
- def register(op):
182
- _add_op_to_registry(registry, op, fn)
183
-
184
- # To handle allowing multiple aten_ops at once
185
- pytree.tree_map_(register, aten_op)
186
- return fn
187
-
188
- return decomposition_decorator
189
-
190
-
191
- def get_decompositions(
192
- aten_ops: Sequence[Union[torch._ops.OperatorBase, OpOverloadPacket]],
193
- type: str = "post_autograd",
194
- ) -> Dict[torch._ops.OperatorBase, Callable]:
195
- """
196
- Retrieve a dictionary of decompositions corresponding to the list of
197
- operator overloads and overload packets passed as input. Overload
198
- packets will include all decomposed overloads in the packet. If there is
199
- no decomposition for a requested operator, it is silently ignored.
200
-
201
- This API is experimental; we are almost certainly going to give an alternate,
202
- more recommended formulation, where a user provides the set of operators
203
- they know how to implement, and we provide decompositions for everything
204
- not in this set.
205
- """
206
- assert type in {"post_autograd", "pre_autograd", "meta"}
207
-
208
- registry = global_decomposition_table[type]
209
- packets_to_overloads = defaultdict(list)
210
- for opo in registry:
211
- if isinstance(opo, (OpOverload, OpOverloadPacket)):
212
- packets_to_overloads[opo.overloadpacket].append(opo)
213
- decompositions: Dict[torch._ops.OperatorBase, Callable] = {}
214
- for op in aten_ops:
215
- if isinstance(op, OpOverloadPacket) and op in packets_to_overloads:
216
- for op_overload in packets_to_overloads[op]:
217
- decompositions[op_overload] = registry[op_overload]
218
- elif isinstance(op, (torch._ops.OperatorBase)) and op in registry:
219
- decompositions[op] = registry[op]
220
- return decompositions
221
-
222
-
223
- def remove_decompositions(
224
- decompositions: Dict[torch._ops.OperatorBase, Callable],
225
- aten_ops: Sequence[Union[OpOverload, OpOverloadPacket]],
226
- ) -> None:
227
- """
228
- Given a dictionary of decompositions obtained from get_decompositions(), removes
229
- operators associated with a list of operator overloads and overload packets passed
230
- as input. If the decomposition dictionary does not contain a decomposition that is
231
- specified to be removed, it is silently ignored.
232
- """
233
- for op in aten_ops:
234
- if isinstance(op, OpOverloadPacket):
235
- for overload_name in op.overloads():
236
- opo = getattr(op, overload_name)
237
- decompositions.pop(opo, None)
238
- elif isinstance(op, OpOverload):
239
- decompositions.pop(op, None)
240
-
241
-
242
- # populate the table
243
- import torch._decomp.decompositions
244
- import torch._refs
245
-
246
-
247
- # See NOTE [Core ATen Ops]
248
- #
249
- # list was copied from torch/_inductor/decomposition.py
250
- # excluding decompositions that result in prim ops
251
- # The resulting opset of the decompositions is core ATen ops
252
- def core_aten_decompositions() -> Dict[torch._ops.OperatorBase, Callable]:
253
- aten = torch.ops.aten
254
- return get_decompositions(
255
- [
256
- aten.addcdiv,
257
- aten.addcdiv_,
258
- aten.addcmul,
259
- aten.addcmul_,
260
- aten.addr,
261
- aten.affine_grid_generator,
262
- aten.all,
263
- aten.aminmax,
264
- aten.arange.default,
265
- aten.arange.start,
266
- aten.avg_pool2d_backward,
267
- aten.baddbmm,
268
- aten.binary_cross_entropy,
269
- aten.binary_cross_entropy_backward,
270
- aten.binary_cross_entropy_with_logits,
271
- aten.celu,
272
- aten.celu_,
273
- aten.clamp_max,
274
- aten.clamp_min,
275
- aten.col2im,
276
- aten.count_nonzero,
277
- aten.cudnn_batch_norm,
278
- aten.cudnn_batch_norm_backward,
279
- aten.deg2rad,
280
- aten.deg2rad_,
281
- aten.detach,
282
- aten.diag_embed,
283
- aten.diagonal_backward,
284
- aten.dot,
285
- aten.vdot,
286
- aten.elu,
287
- aten.elu_,
288
- aten.elu_backward,
289
- aten._embedding_bag,
290
- aten.embedding_dense_backward,
291
- aten.empty_like,
292
- aten._euclidean_dist.default,
293
- aten.expand_as,
294
- aten.eye,
295
- aten.fill,
296
- aten.fill_,
297
- aten.floor_divide,
298
- aten.frac,
299
- aten.frac_,
300
- aten._fused_moving_avg_obs_fq_helper,
301
- aten.gelu_,
302
- aten.gelu_backward,
303
- aten.glu,
304
- aten.glu_backward,
305
- aten.hardshrink,
306
- aten.hardsigmoid,
307
- aten.hardsigmoid_,
308
- aten.hardsigmoid_backward,
309
- aten.hardswish,
310
- aten.hardswish_,
311
- aten.hardswish_backward,
312
- aten.hardtanh_,
313
- aten.hardtanh_backward,
314
- aten.heaviside,
315
- aten.heaviside_,
316
- aten.huber_loss,
317
- aten.huber_loss_backward,
318
- aten.im2col,
319
- aten.index_add,
320
- aten.index_add_,
321
- aten.index_copy,
322
- aten.index_copy_,
323
- aten.index_fill,
324
- aten.index_fill_,
325
- aten.isneginf,
326
- aten.isposinf,
327
- aten.l1_loss,
328
- aten.leaky_relu_,
329
- aten.leaky_relu_backward,
330
- aten.lerp,
331
- aten.lerp_,
332
- aten.linspace,
333
- aten.logaddexp,
334
- aten.logaddexp2,
335
- aten.logit,
336
- aten.logit_,
337
- aten.logit_backward,
338
- aten.log_sigmoid_backward,
339
- aten.log_sigmoid_forward,
340
- aten._log_softmax_backward_data,
341
- aten.logspace,
342
- aten.logsumexp.default,
343
- aten.masked_fill,
344
- aten.masked_fill_,
345
- aten.mish,
346
- aten.mish_,
347
- aten.mse_loss,
348
- aten.mse_loss_backward,
349
- aten.multi_margin_loss,
350
- aten.multilabel_margin_loss_forward,
351
- aten.mv,
352
- aten.mvlgamma,
353
- aten.mvlgamma_,
354
- aten.nansum,
355
- aten.nan_to_num,
356
- aten.nan_to_num_,
357
- aten.narrow,
358
- aten.native_batch_norm_backward,
359
- aten.native_dropout_backward,
360
- aten.native_group_norm_backward,
361
- aten.native_layer_norm_backward,
362
- aten.new_empty,
363
- aten.new_full,
364
- aten.new_ones,
365
- aten.new_zeros,
366
- aten.nll_loss_backward,
367
- aten.nll_loss_forward,
368
- aten.norm,
369
- aten.ones,
370
- aten.ones_like,
371
- aten._prelu_kernel,
372
- aten._prelu_kernel_backward,
373
- aten._reshape_alias,
374
- aten.rad2deg,
375
- aten.rad2deg_,
376
- aten.renorm,
377
- aten.renorm_,
378
- aten.replication_pad2d,
379
- aten.rot90,
380
- aten.rrelu_with_noise,
381
- aten.rrelu_with_noise_,
382
- aten.rsub.Scalar,
383
- aten.rsub.Tensor,
384
- aten._scaled_dot_product_flash_attention.default,
385
- aten.select_backward,
386
- aten.select_scatter,
387
- aten.sgn,
388
- aten.sgn_,
389
- aten.sigmoid_backward,
390
- aten.silu,
391
- aten.silu_,
392
- aten.silu_backward,
393
- aten.sinc,
394
- aten.sinc_,
395
- aten.slice_backward,
396
- aten.smooth_l1_loss,
397
- aten.smooth_l1_loss_backward,
398
- aten.soft_margin_loss,
399
- aten.soft_margin_loss_backward,
400
- aten._softmax_backward_data,
401
- aten.softplus,
402
- aten.softplus_backward,
403
- aten.softshrink,
404
- aten.special_entr,
405
- aten.special_log_ndtr,
406
- aten.special_xlog1py,
407
- aten.split.Tensor,
408
- aten.squeeze.default,
409
- aten.squeeze.dim,
410
- aten.std,
411
- aten.std_mean,
412
- aten.stack,
413
- aten.sum.default,
414
- aten.sum.out,
415
- aten.t,
416
- aten.tanh_backward,
417
- aten.threshold,
418
- aten.threshold_,
419
- aten.threshold_backward,
420
- aten.trace,
421
- aten.transpose.int,
422
- aten.tril,
423
- aten.tril_,
424
- aten.triu,
425
- aten.triu_,
426
- aten.unbind,
427
- aten.unfold_backward,
428
- aten.unfold_copy,
429
- aten._unsafe_index,
430
- aten.unsafe_split.Tensor,
431
- aten.unsafe_split_with_sizes,
432
- aten._unsafe_view,
433
- aten.upsample_bilinear2d,
434
- aten.upsample_nearest2d_backward,
435
- aten.view_as_complex,
436
- aten.xlogy,
437
- aten.xlogy_,
438
- aten.zero,
439
- aten.zero_,
440
- aten.zeros,
441
- aten.zeros_like,
442
- aten._weight_norm_interface,
443
- ]
444
- )
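The hunk above removes the public helpers of torch._decomp (get_decompositions, remove_decompositions, core_aten_decompositions). For reference, a minimal usage sketch of that API, assuming a PyTorch build that still ships torch._decomp; the specific ops picked here are illustrative only:

```python
# Minimal sketch of the torch._decomp API removed above (assumes a PyTorch
# build that still provides torch._decomp; the chosen ops are illustrative).
import torch
from torch._decomp import (
    core_aten_decompositions,
    get_decompositions,
    remove_decompositions,
)

aten = torch.ops.aten

# Start from the Core ATen table built by core_aten_decompositions().
table = core_aten_decompositions()

# Add decompositions for extra overloads or whole overload packets.
table.update(get_decompositions([aten.addmm, aten.native_layer_norm_backward]))

# Drop entries we prefer to keep as opaque ops (missing entries are ignored).
remove_decompositions(table, [aten.addmm])

print(f"{len(table)} decompositions selected")
```

As the deleted docstring notes, remove_decompositions mutates the dictionary in place and silently skips operators that have no entry.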
 
torch/_decomp/__pycache__/__init__.cpython-310.pyc DELETED
Binary file (12.2 kB)
 
torch/_decomp/__pycache__/decompositions.cpython-310.pyc DELETED
Binary file (102 kB)
 
torch/_decomp/__pycache__/decompositions_for_jvp.cpython-310.pyc DELETED
Binary file (6.27 kB)
 
torch/_decomp/__pycache__/decompositions_for_rng.cpython-310.pyc DELETED
Binary file (7.99 kB)
 
torch/_decomp/decompositions.py DELETED
The diff for this file is too large to render. See raw diff
 
torch/_decomp/decompositions_for_jvp.py DELETED
@@ -1,302 +0,0 @@
1
- import inspect
2
- from typing import Callable, Dict, List, Optional, Tuple
3
-
4
- import torch
5
- import torch._decomp
6
- from torch import Tensor
7
- from torch._prims_common.wrappers import _maybe_remove_out_wrapper
8
-
9
- decomposition_table = torch._decomp.decomposition_table
10
- decomposition_table_for_jvp: Dict[torch._ops.OperatorBase, Callable] = {}
11
- register_decomposition = torch._decomp.register_decomposition
12
- aten = torch.ops.aten
13
-
14
- # NOTE: [forward-mode AD decompositions mechanism]
15
- #
16
- # The mechanism is in VariableType,
17
- # IF any inputs have forward grad
18
- # AND there is no forward AD formula implemented
19
- # AND the functions is actually differentiable
20
- # run the decomposition
21
- # See run_jit_decomposition_with_args_for_jvp
22
- # We currently use python decompositions that we torchscript.
23
- #
24
- # Note that we would be building the backward graph at the decomposed level
25
- # too, but that is OK, because we would've errored out otherwise anyway.
26
- #
27
- # TODO: The mechanism we are using to register decompositions doesn't
28
- # seem to be exclusively used for jvp. So open question here is whether
29
- # torch/csrc/jit/runtime/decomposition_registry.cpp is being used for other things.
30
- # If that is the case, we may go down the decomposition path unexpectedly
31
- # (and possibly produce an unintelligible error) vs erroring out earlier and
32
- # printing that the forward AD formula is not implemented.
33
- #
34
- # The solution to this may be to have a explicitly white list control when
35
- # to enable the decomposition.
36
-
37
-
38
- def maybe_register_decomposition(op):
39
- def decorator(f):
40
- try:
41
- return register_decomposition(op)(f)
42
- except Exception:
43
- return f
44
-
45
- return decorator
46
-
47
-
48
- # Functions where we need a special decomposition for jvp but there's another version that
49
- # should be used more generally (ex. for jvp we need to recompute the mean and variance for
50
- # the backwards of a normalization function. Without jvp, it should use the saved value)
51
- decomposition_table_for_jvp = {}
52
-
53
-
54
- def register_decomposition_for_jvp(fn):
55
- return register_decomposition(fn, registry=decomposition_table_for_jvp)
56
-
57
-
58
- def _register_jit_decomposition_for_jvp(decomp, use_python=False):
59
- if decomp in decomposition_table_for_jvp:
60
- decomposition_table_used = decomposition_table_for_jvp
61
- elif decomp in decomposition_table:
62
- decomposition_table_used = decomposition_table
63
- else:
64
- raise RuntimeError(f"could not find decomposition for {decomp}")
65
- decomp_fn = decomposition_table_used[decomp]
66
-
67
- # `out_wrapper` extends a decompositions signature with
68
- # an `out` parameter. However jit will use the unwrapped function's
69
- # signature instead so we need to unwrap here to prevent an error
70
- decomp_fn = _maybe_remove_out_wrapper(decomp_fn)
71
-
72
- if use_python:
73
- decomp_fn = torch.jit.ignore(decomp_fn)
74
- sig = inspect.signature(decomp_fn)
75
-
76
- # Create a string wrapping the function from the signature
77
- # example output:
78
- # def wrapped_decomp(x: torch.Tensor, y: int, z: int):
79
- # return decomp_fn(x, y, z)
80
- # Thanks copilot!
81
- def get_function_def(sig):
82
- param_def = [f"{param_str}" for param_str in sig.parameters.values()]
83
- param_use = [f"{param_str}" for param_str in sig.parameters.keys()]
84
-
85
- return f"def wrapped_decomp({', '.join(param_def)}):\n return decomp_fn({', '.join(param_use)})\n"
86
-
87
- f_str = get_function_def(sig)
88
- graph = torch.jit.CompilationUnit(f_str).wrapped_decomp.graph
89
- else:
90
- graph = torch.jit.script(decomp_fn).graph
91
- torch.jit._register_decomposition(decomp, graph)
92
-
93
-
94
- # The only decompositions here are temporary or hacks for the purposes of jvp
95
-
96
-
97
- # TODO: do these also belong here?
98
- @maybe_register_decomposition(aten.trace.default)
99
- def trace(self: Tensor) -> Tensor:
100
- return torch.sum(torch.diag(self))
101
-
102
-
103
- @maybe_register_decomposition(aten.log_sigmoid_forward.default)
104
- def log_sigmoid_forward(self: Tensor) -> Tuple[Tensor, Tensor]:
105
- min = torch.minimum(self.new_zeros(()), self)
106
- z = torch.exp(-torch.abs(self))
107
- if self.is_cuda:
108
- buffer = self.new_zeros((0,))
109
- else:
110
- buffer = z
111
- return min - torch.log1p(z), buffer
112
-
113
-
114
- def recompute_mean_var(
115
- input: Tensor, rstd: Tensor, inner_dim_indices: List[int], keepdim: bool
116
- ):
117
- # for most norm decompositions, it will be the same as the core version except for here.
118
- # We recompute the mean and variance so that they track gradients through input
119
-
120
- mean = torch.mean(input, dim=inner_dim_indices, keepdim=keepdim)
121
- var = torch.var(input, dim=inner_dim_indices, unbiased=False, keepdim=keepdim)
122
- eps = torch.pow(1 / rstd, 2) - var # this makes me so sad inside
123
- eps = eps.detach()
124
- rstd = 1 / torch.sqrt(var + eps)
125
- return mean, rstd
126
-
127
-
128
- @register_decomposition_for_jvp(aten.native_layer_norm_backward)
129
- def native_layer_norm_backward(
130
- grad_out: Tensor,
131
- input: Tensor,
132
- normalized_shape: List[int],
133
- mean: Tensor,
134
- rstd: Tensor,
135
- weight: Optional[Tensor],
136
- bias: Optional[Tensor],
137
- output_mask: List[bool],
138
- ) -> Tuple[Optional[Tensor], Optional[Tensor], Optional[Tensor]]:
139
- input_shape = input.shape
140
- input_ndim = input.dim()
141
-
142
- axis = input_ndim - len(normalized_shape)
143
- inner_dims = input_shape[axis:]
144
- outer_dims = input_shape[:axis]
145
- inner_dim_indices = list(range(axis, input_ndim))
146
- outer_dim_indices = list(range(0, axis))
147
-
148
- N = 1
149
- for i in inner_dims:
150
- N *= i
151
- M = 1
152
- for i in outer_dims:
153
- M *= i
154
- if M <= 0 or N <= 0:
155
- return (
156
- input.new_zeros(input_shape),
157
- input.new_zeros(input_shape[axis:]),
158
- input.new_zeros(input_shape[axis:]),
159
- )
160
-
161
- mean_, rstd_ = recompute_mean_var(input, rstd, inner_dim_indices, keepdim=True)
162
-
163
- x_hat = (input - mean_) * rstd_
164
- if weight is not None:
165
- grad_x_hat = grad_out * weight
166
- else:
167
- grad_x_hat = grad_out
168
- a = grad_x_hat * N
169
- b = torch.sum(grad_x_hat, inner_dim_indices, True)
170
- c1 = torch.mul(grad_x_hat, x_hat)
171
- c2 = torch.sum(c1, inner_dim_indices, True)
172
- c3 = torch.mul(x_hat, c2)
173
- inner = a - b - c3
174
-
175
- if output_mask[0]:
176
- d_input: Optional[Tensor] = (rstd_ / N) * inner
177
- else:
178
- d_input = torch.zeros_like(input) # should be None but doesn't work with vjp
179
-
180
- if output_mask[1] and weight is not None:
181
- if len(outer_dim_indices) > 0:
182
- d_weight: Optional[Tensor] = torch.sum(
183
- grad_out * x_hat, outer_dim_indices, False
184
- )
185
- else:
186
- d_weight = grad_out * x_hat
187
- elif weight is not None:
188
- d_weight = torch.zeros_like(weight) # should be None but doesn't work with vjp
189
- else:
190
- d_weight = torch.zeros(()) # should be None but doesn't work with vjp
191
-
192
- if output_mask[2] and bias is not None:
193
- if len(outer_dim_indices) > 0:
194
- d_bias: Optional[Tensor] = torch.sum(grad_out, outer_dim_indices, False)
195
- else:
196
- d_bias = grad_out.clone()
197
- elif bias is not None:
198
- d_bias = torch.zeros_like(bias) # should be None but doesn't work with vjp
199
- else:
200
- d_bias = torch.zeros(()) # should be None but doesn't work with vjp
201
-
202
- return (d_input, d_weight, d_bias)
203
-
204
-
205
- def prod(x: List[int]):
206
- r = 1
207
- for i in x:
208
- r *= i
209
- return r
210
-
211
-
212
- @register_decomposition_for_jvp(aten.native_batch_norm_backward)
213
- def native_batch_norm_backward(
214
- grad_out: Tensor,
215
- input: Tensor,
216
- weight: Optional[Tensor],
217
- running_mean: Optional[Tensor],
218
- running_var: Optional[Tensor],
219
- save_mean: Optional[Tensor],
220
- save_invstd: Optional[Tensor],
221
- train: bool,
222
- eps: float,
223
- output_mask: List[bool],
224
- ) -> Tuple[Tensor, Optional[Tensor], Optional[Tensor]]:
225
- input_shape = input.shape
226
- input_rank = input.dim()
227
- assert input_rank >= 2, "rank of the input must be at least 2"
228
-
229
- axis = 1
230
- num_features = prod(input_shape) / input_shape[axis] # type: ignore[arg-type]
231
- mean = save_mean
232
- invstd = save_invstd
233
- if train:
234
- assert (
235
- save_mean is not None and save_invstd is not None
236
- ), "when train=True, save_mean and save_invstd are required"
237
-
238
- reduciton_dims = [0] + list(range(2, input.dim()))
239
- assert invstd is not None # for typing
240
- mean, invstd = recompute_mean_var(input, invstd, reduciton_dims, keepdim=False)
241
- else:
242
- assert running_mean is not None and running_var is not None
243
- mean = running_mean
244
- invstd = torch.rsqrt(running_var + eps)
245
-
246
- assert invstd is not None and mean is not None
247
-
248
- broadcast_mask = [1] * input_rank
249
- broadcast_mask[axis] = input_shape[axis]
250
-
251
- reduction_axes: List[int] = []
252
- for i in range(input_rank):
253
- if i != axis:
254
- reduction_axes.append(i)
255
-
256
- mean = torch.reshape(mean, broadcast_mask)
257
- norm = 1.0 / num_features
258
- grad_output_sum = torch.sum(grad_out, reduction_axes)
259
- dot_p = torch.sum(grad_out * (input - mean), reduction_axes)
260
-
261
- grad_mean = torch.reshape(grad_output_sum * norm, broadcast_mask)
262
- proj_scale = torch.reshape(torch.mul(dot_p * norm, invstd * invstd), broadcast_mask)
263
-
264
- if weight is None:
265
- grad_scale = torch.reshape(invstd, broadcast_mask) * 1.0
266
- else:
267
- grad_scale = torch.reshape(invstd * weight, broadcast_mask)
268
-
269
- if train:
270
- proj = (input - mean) * proj_scale
271
- grad_input = ((grad_out - proj) - grad_mean) * grad_scale
272
- else:
273
- grad_input = grad_out * grad_scale
274
-
275
- if output_mask[1]:
276
- grad_weight = dot_p * invstd
277
- elif weight is not None:
278
- grad_weight = torch.zeros_like(
279
- weight
280
- ) # should be None but doesn't work with vjp
281
- else:
282
- grad_weight = torch.zeros(()) # should be None but doesn't work with vjp
283
-
284
- if output_mask[2]:
285
- grad_bias = grad_output_sum
286
- else:
287
- grad_bias = torch.zeros_like(
288
- grad_output_sum
289
- ) # should be None but doesn't work with vjp
290
-
291
- return (grad_input, grad_weight, grad_bias)
292
-
293
-
294
- _register_jit_decomposition_for_jvp(torch.ops.aten.trace.default, use_python=True)
295
- _register_jit_decomposition_for_jvp(torch.ops.aten.nll_loss_backward.default)
296
- _register_jit_decomposition_for_jvp(torch.ops.aten.nll_loss2d_backward.default)
297
- _register_jit_decomposition_for_jvp(torch.ops.aten._log_softmax_backward_data.default)
298
- _register_jit_decomposition_for_jvp(torch.ops.aten._softmax_backward_data.default)
299
- _register_jit_decomposition_for_jvp(torch.ops.aten.log_sigmoid_forward.default)
300
- _register_jit_decomposition_for_jvp(torch.ops.aten.native_layer_norm_backward.default)
301
- _register_jit_decomposition_for_jvp(torch.ops.aten.native_batch_norm_backward.default)
302
- _register_jit_decomposition_for_jvp(torch.ops.aten.cudnn_batch_norm_backward.default)
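As the NOTE in the deleted file explains, these Python decompositions back forward-mode AD for ops that lack a dedicated forward AD formula. A small sketch of that behaviour from the user side, assuming torch.func.jvp is available in the build; aten.trace is used because the file above registered it for the JVP path:

```python
# Sketch: forward-mode AD through aten.trace, which the deleted file
# re-expressed as torch.sum(torch.diag(x)) for the JVP path.
# Assumes a PyTorch build where torch.func.jvp is available.
import torch
from torch.func import jvp

x = torch.randn(3, 3)
tangent = torch.randn(3, 3)

# d/dt trace(x + t * tangent) at t = 0 is trace(tangent).
primal_out, tangent_out = jvp(torch.trace, (x,), (tangent,))
assert torch.allclose(primal_out, torch.trace(x))
assert torch.allclose(tangent_out, torch.trace(tangent))
```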
 
torch/_decomp/decompositions_for_rng.py DELETED
@@ -1,263 +0,0 @@
1
- import functools
2
- from collections import defaultdict
3
- from typing import Callable, Dict
4
-
5
- import torch
6
- import torch._decomp as decomp
7
- from torch._decomp import get_decompositions
8
- from torch._ops import OpOverload
9
-
10
- aten = torch.ops.aten
11
-
12
- rng_decompositions: Dict[str, Dict[OpOverload, Callable]] = defaultdict(dict)
13
-
14
-
15
- def register_rng_decomposition(aten_op):
16
- return decomp.register_decomposition(aten_op, rng_decompositions)
17
-
18
-
19
- def throw_on_non_cuda(device):
20
- raise RuntimeError(
21
- f"You are trying to functionalize a {device.type} RNG operator but {device.type} does not "
22
- f"use Philox/counter-based RNG. Therefore, functionalizing a {device.type} RNG operator is "
23
- "not supported. We are discussing the possibility of a Philox-based RNG implementation for CPU."
24
- )
25
-
26
-
27
- # TODO - We have to register many more distributions here, and also higher level
28
- # ops like dropout which have fused implementation and can hide the rand inside.
29
- @register_rng_decomposition(aten.rand)
30
- def rand(shape, dtype=None, layout=torch.strided, device=None, pin_memory=False):
31
- if device and device.type != "cuda":
32
- throw_on_non_cuda(device)
33
- seed, offset = PhiloxStateTracker.get_state_as_tuple()
34
- dtype = dtype or torch.float32
35
- out, offset_jump = torch.ops.rngprims.philox_rand(
36
- shape, seed, offset, None, device, dtype
37
- )
38
- PhiloxStateTracker.advance_offset(offset_jump)
39
- return out
40
-
41
-
42
- @register_rng_decomposition(aten.rand_like)
43
- def rand_like(
44
- x: torch.Tensor,
45
- dtype=None,
46
- layout=None,
47
- device=None,
48
- pin_memory=False,
49
- memory_format=torch.preserve_format,
50
- ):
51
- device = device or x.device
52
- if device.type != "cuda":
53
- throw_on_non_cuda(device)
54
- dtype = dtype or x.dtype
55
- seed, offset = PhiloxStateTracker.get_state_as_tuple()
56
- out, offset_jump = torch.ops.rngprims.philox_rand(
57
- x.shape, seed, offset, None, device, dtype
58
- )
59
- PhiloxStateTracker.advance_offset(offset_jump)
60
- return out
61
-
62
-
63
- class PhiloxState:
64
- """
65
- Represents a PhiloxRngState - (seed, offset) where offset = base_offset +
66
- relative_offset. seed and base_offset basically point to the rng state just
67
- before tracing starts. relative offset tracks the totally consumed offset at
68
- trace time.
69
- """
70
-
71
- def __init__(self):
72
- self.reset()
73
-
74
- def reset(self):
75
- self.seed = torch.tensor(())
76
- self.base_offset = torch.tensor(())
77
- self.relative_offset = 0
78
- self.offset_advanced_alteast_once = False
79
-
80
- def validate_state(self):
81
- assert self.seed.numel() != 0 and self.base_offset.numel() != 0
82
-
83
- def advance_offset(self, consumed_offset):
84
- self.offset_advanced_alteast_once = True
85
- self.relative_offset = self.relative_offset + consumed_offset
86
-
87
- def set_state(self, seed, base_offset, relative_offset=0):
88
- self.seed = seed
89
- self.base_offset = base_offset
90
- self.relative_offset = relative_offset
91
-
92
- def get_state_as_tuple(self):
93
- self.validate_state()
94
- return (self.seed, self.base_offset + self.relative_offset)
95
-
96
- def get_state_as_tensor(self):
97
- # Only needed because we override get_rng_state.
98
- self.validate_state()
99
- return torch.stack([self.seed, self.base_offset + self.relative_offset])
100
-
101
- def set_state_from_tensor(self, state):
102
- # Only needed because we override set_rng_state.
103
- self.seed, self.base_offset = torch.unbind(state)
104
- self.relative_offset = 0
105
-
106
-
107
- class PhiloxStateTracker:
108
- """
109
- Singleton class to track the philox rng state during AOT Autograd tracing.
110
- For each aot tracing instance, AOT Autograd resets this tracker and keeps
111
- track of both forward and backward offsets. At runtime, we only care about
112
- the total consumed forward and backward offsets. For dynamic shapes, these
113
- offsets are a function of input shapes. Therefore, the AOT generated graphs
114
- have additional outputs that compute total consumed forward and backward
115
- offsets.
116
- """
117
-
118
- running_state: PhiloxState
119
- fwd_state: PhiloxState
120
- bwd_state: PhiloxState
121
-
122
- def __enter__(self):
123
- PhiloxStateTracker.reset()
124
- return self
125
-
126
- def __exit__(self, exc_type, exc_cal, exc_tb):
127
- PhiloxStateTracker.reset()
128
-
129
- @classmethod
130
- def reset(cls):
131
- cls.running_state = PhiloxState()
132
- cls.fwd_state = PhiloxState()
133
- cls.bwd_state = PhiloxState()
134
-
135
- @classmethod
136
- def mark_beginning_of_forward(cls):
137
- # Tells the tracker to use fwd_state as the running state
138
- cls.running_state = cls.fwd_state
139
-
140
- @classmethod
141
- def mark_beginning_of_backward(cls):
142
- # Tells the tracker to use bwd_state as the running state
143
- cls.running_state = cls.bwd_state
144
-
145
- @classmethod
146
- def record_state(cls, seed, offset, mode):
147
- # Records the seed and offset tensors. These tensors are used to invoke
148
- # the philox_rand functional primitives.
149
- if mode == "forward":
150
- cls.fwd_state.set_state(seed, offset)
151
- cls.mark_beginning_of_forward()
152
- else:
153
- assert mode == "backward"
154
- cls.bwd_state.set_state(seed, offset)
155
-
156
- @classmethod
157
- def get_state_as_tensor(cls):
158
- # The only reason this exists is because we override get_rng_state and
159
- # set_rng_state during tracing. get_rng_state expects a tensor output,
160
- # so return (seed, offset) tuple upset other parts of the program like
161
- # ctx.saved_tensors.
162
-
163
- # A bad consequence is that if user saves and restores rng state, we
164
- # have little bit of ugliness in the generated code, where we first
165
- # concat the (seed, offset) to create a tensor for get_rng_state, and
166
- # then split it back to get (seed, offset) tuple in set_rng_state.
167
-
168
- # TODO: Investigate if there is be a better way to wrap the tuple in a
169
- # false Tensor object, and then desugar it later on.
170
- return cls.running_state.get_state_as_tensor()
171
-
172
- @classmethod
173
- def get_state_as_tuple(cls):
174
- return cls.running_state.get_state_as_tuple()
175
-
176
- @classmethod
177
- def set_state_from_tensor(cls, x):
178
- # This is only needed because we override set_rng_state. Look at the
179
- # comment in get_state_from_tensor method.
180
- cls.running_state.set_state_from_tensor(x)
181
-
182
- @classmethod
183
- def advance_offset(cls, consumed_offset):
184
- cls.running_state.advance_offset(consumed_offset)
185
-
186
- @classmethod
187
- def get_current_relative_offset(cls):
188
- return cls.running_state.relative_offset
189
-
190
- @staticmethod
191
- def multiple_of_4(offset):
192
- # torch cuda rng state offset must be a multiple of 4. For inductor, as
193
- # we sum up all the numel, the result might not be a multiple of 4. This
194
- # method achieves that.
195
- return (offset + 3) // 4 * 4
196
-
197
- @classmethod
198
- def get_updated_fwd_offset(cls):
199
- # Short circuit if no rand ops were observed
200
- if not cls.fwd_state.offset_advanced_alteast_once:
201
- return cls.fwd_state.base_offset
202
- return cls.multiple_of_4(
203
- cls.fwd_state.base_offset + cls.fwd_state.relative_offset
204
- )
205
-
206
- @classmethod
207
- def get_updated_bwd_offset(cls):
208
- # Short circuit if no rand ops were observed
209
- if not cls.bwd_state.offset_advanced_alteast_once:
210
- return cls.bwd_state.base_offset
211
- return cls.multiple_of_4(
212
- cls.bwd_state.base_offset + cls.bwd_state.relative_offset
213
- )
214
-
215
-
216
- # Adding more decompositions which eventually use rand_like inside decomps.
217
- # Adding these in rng_decompositions ensures the functionalization of rand_like
218
- # ops used in these decomps. The list is copied from inductor codebase, which
219
- # uses it for similar purpose.
220
- #
221
- # Caution - These decomps do not have same accuracy as that of eager. However,
222
- # we can't just disable them with a config flag like fallback_random, because
223
- # for functionalization of rng ops, we have to decompose these ops.
224
- extra_random_decomps = get_decompositions(
225
- [
226
- aten.cauchy,
227
- aten.cauchy_,
228
- aten.exponential,
229
- aten.exponential_,
230
- aten.geometric,
231
- aten.geometric_,
232
- aten.native_dropout,
233
- aten.normal,
234
- aten.normal_,
235
- aten.normal_functional,
236
- aten.log_normal,
237
- aten.log_normal_,
238
- aten.rrelu_with_noise,
239
- aten.rrelu_with_noise_,
240
- aten.uniform_,
241
- ]
242
- )
243
- register_extra_random_decomp = functools.partial(
244
- decomp.register_decomposition, registry=extra_random_decomps
245
- )
246
-
247
-
248
- @register_extra_random_decomp([aten.bernoulli_])
249
- def bernoulli_(self, p=0.5):
250
- if self.device == torch.device("cpu"):
251
- return NotImplemented
252
- return self.copy_(torch.rand_like(self, dtype=torch.float32) < p)
253
-
254
-
255
- @register_extra_random_decomp([aten.bernoulli.p])
256
- def bernoulli_p(self, p=0.5, *, generator=None):
257
- if self.device == torch.device("cpu"):
258
- return NotImplemented
259
- assert generator is None
260
- return torch.rand_like(self, dtype=torch.float32) < p
261
-
262
-
263
- rng_decompositions.update(extra_random_decomps) # type: ignore[arg-type]
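The subtle piece of the deleted tracker is the offset bookkeeping: the effective offset handed to philox_rand is always base_offset + relative_offset, and CUDA's Philox offset must advance in multiples of 4, so the total consumed offset is rounded up before being applied. A standalone sketch of that rounding rule (hypothetical helper name, independent of the deleted class):

```python
def round_up_to_multiple_of_4(offset: int) -> int:
    # Mirrors PhiloxStateTracker.multiple_of_4 from the deleted file: the CUDA
    # RNG state offset must be a multiple of 4, so the summed per-op offsets
    # are rounded up to the next multiple. Helper name is hypothetical.
    return (offset + 3) // 4 * 4


assert round_up_to_multiple_of_4(0) == 0
assert round_up_to_multiple_of_4(5) == 8
assert round_up_to_multiple_of_4(8) == 8
```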
 
torch/_deploy.py DELETED
@@ -1,105 +0,0 @@
1
- import io
2
-
3
- import torch
4
- from torch.package import Importer, OrderedImporter, PackageImporter, sys_importer
5
- from torch.package._package_pickler import create_pickler
6
- from torch.package._package_unpickler import PackageUnpickler
7
- from torch.serialization import _maybe_decode_ascii
8
-
9
-
10
- def _save_storages(importer, obj):
11
- serialized_storages = []
12
- serialized_dtypes = []
13
-
14
- importer = importer if isinstance(importer, torch.package.PackageImporter) else None
15
- importers: Importer
16
- if importer is not None:
17
- importers = OrderedImporter(importer, sys_importer)
18
- else:
19
- importers = sys_importer
20
-
21
- def persistent_id(obj):
22
- if torch.is_storage(obj) or isinstance(obj, torch.storage.TypedStorage):
23
- if isinstance(obj, torch.storage.TypedStorage):
24
- # TODO: Once we decide to break serialization FC, we can
25
- # remove this case
26
- storage = obj._untyped_storage
27
- dtype = obj.dtype
28
- else:
29
- storage = obj
30
- dtype = torch.uint8
31
-
32
- serialized_storages.append(obj)
33
- serialized_dtypes.append(dtype)
34
- return ("storage", len(serialized_storages) - 1)
35
-
36
- if hasattr(obj, "__reduce_deploy__"):
37
- if _serialized_reduces.get(id(obj)) is None:
38
- _serialized_reduces[id(obj)] = (
39
- "reduce_deploy",
40
- id(obj),
41
- *obj.__reduce_deploy__(importers),
42
- )
43
- return _serialized_reduces[id(obj)]
44
-
45
- return None
46
-
47
- # Write the pickle data for `obj`
48
- data_buf = io.BytesIO()
49
- pickler = create_pickler(data_buf, importers)
50
- pickler.persistent_id = persistent_id
51
- pickler.dump(obj)
52
- data_value = data_buf.getvalue()
53
- return (
54
- data_value,
55
- serialized_storages,
56
- serialized_dtypes,
57
- importer.zip_reader if importer else None,
58
- )
59
-
60
-
61
- def _load_storages(id, zip_reader, obj_bytes, serialized_storages, serialized_dtypes):
62
- def persistent_load(saved_id):
63
- assert isinstance(saved_id, tuple)
64
- typename = _maybe_decode_ascii(saved_id[0])
65
- data = saved_id[1:]
66
-
67
- if typename == "storage":
68
- # TODO: Once we decide to break serialization FC, we can
69
- # stop wrapping with TypedStorage
70
- storage = serialized_storages[data[0]]
71
- dtype = serialized_dtypes[data[0]]
72
- return torch.storage.TypedStorage(
73
- wrap_storage=storage.untyped(), dtype=dtype
74
- )
75
-
76
- if typename == "reduce_deploy":
77
- reduce_id, func, args = data
78
- if reduce_id not in _loaded_reduces:
79
- _loaded_reduces[reduce_id] = func(_raw_packages[zip_reader], *args)
80
- return _loaded_reduces[reduce_id]
81
-
82
- return None
83
-
84
- importer: Importer
85
- if zip_reader is not None:
86
- importer = OrderedImporter(_get_package(zip_reader), sys_importer)
87
- else:
88
- importer = sys_importer
89
-
90
- unpickler = PackageUnpickler(importer, io.BytesIO(obj_bytes))
91
- unpickler.persistent_load = persistent_load # type: ignore[assignment]
92
- result = _deploy_objects[id] = unpickler.load()
93
- return result
94
-
95
-
96
- def _get_package(zip_reader):
97
- if zip_reader not in _raw_packages:
98
- _raw_packages[zip_reader] = PackageImporter(zip_reader)
99
- return _raw_packages[zip_reader]
100
-
101
-
102
- _raw_packages: dict = {}
103
- _deploy_objects: dict = {}
104
- _serialized_reduces: dict = {}
105
- _loaded_reduces: dict = {}
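The deleted _deploy.py helpers wrap torch.package's pickling machinery so that storages travel alongside the pickled object. For context, a minimal round trip through the public torch.package API that these helpers build on might look like the sketch below; the file and package names are made up for illustration.

```python
# Sketch of a torch.package round trip, the machinery that _save_storages /
# _load_storages above builds on. File and package names are illustrative.
import torch
from torch.package import PackageExporter, PackageImporter

tensor = torch.arange(4, dtype=torch.float32)

with PackageExporter("example_package.pt") as exporter:  # hypothetical path
    # Externing torch explicitly is harmless even on builds that already do it.
    exporter.extern(["torch", "torch.**"])
    exporter.save_pickle("payload", "tensor.pkl", tensor)

importer = PackageImporter("example_package.pt")
restored = importer.load_pickle("payload", "tensor.pkl")
assert torch.equal(tensor, restored)
```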
 
torch/_dispatch/__init__.py DELETED
File without changes
torch/_dispatch/__pycache__/__init__.cpython-310.pyc DELETED
Binary file (155 Bytes)