Spaces:
Sleeping
Sleeping
Delete torch
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- torch/_C.cp310-win_amd64.pyd +0 -0
- torch/_C/_VariableFunctions.pyi +0 -0
- torch/_C/__init__.pyi +0 -0
- torch/_C/_autograd.pyi +0 -123
- torch/_C/_cpu.pyi +0 -5
- torch/_C/_cudnn.pyi +0 -17
- torch/_C/_distributed_autograd.pyi +0 -26
- torch/_C/_distributed_c10d.pyi +0 -478
- torch/_C/_distributed_rpc.pyi +0 -188
- torch/_C/_distributed_rpc_testing.pyi +0 -35
- torch/_C/_functions.pyi +0 -11
- torch/_C/_functorch.pyi +0 -71
- torch/_C/_itt.pyi +0 -5
- torch/_C/_lazy.pyi +0 -28
- torch/_C/_lazy_ts_backend.pyi +0 -11
- torch/_C/_monitor.pyi +0 -44
- torch/_C/_nn.pyi +0 -86
- torch/_C/_nvtx.pyi +0 -6
- torch/_C/_onnx.pyi +0 -38
- torch/_C/_profiler.pyi +0 -238
- torch/_C/_verbose.pyi +0 -3
- torch/_VF.py +0 -30
- torch/_VF.pyi +0 -0
- torch/__config__.py +0 -22
- torch/__future__.py +0 -21
- torch/_appdirs.py +0 -666
- torch/_awaits/__init__.py +0 -54
- torch/_awaits/__pycache__/__init__.cpython-310.pyc +0 -0
- torch/_classes.py +0 -55
- torch/_compile.py +0 -30
- torch/_custom_op/__init__.py +0 -0
- torch/_custom_op/__pycache__/__init__.cpython-310.pyc +0 -0
- torch/_custom_op/__pycache__/autograd.cpython-310.pyc +0 -0
- torch/_custom_op/__pycache__/functional.cpython-310.pyc +0 -0
- torch/_custom_op/__pycache__/impl.cpython-310.pyc +0 -0
- torch/_custom_op/autograd.py +0 -274
- torch/_custom_op/functional.py +0 -187
- torch/_custom_op/impl.py +0 -976
- torch/_custom_ops.py +0 -322
- torch/_decomp/__init__.py +0 -444
- torch/_decomp/__pycache__/__init__.cpython-310.pyc +0 -0
- torch/_decomp/__pycache__/decompositions.cpython-310.pyc +0 -0
- torch/_decomp/__pycache__/decompositions_for_jvp.cpython-310.pyc +0 -0
- torch/_decomp/__pycache__/decompositions_for_rng.cpython-310.pyc +0 -0
- torch/_decomp/decompositions.py +0 -0
- torch/_decomp/decompositions_for_jvp.py +0 -302
- torch/_decomp/decompositions_for_rng.py +0 -263
- torch/_deploy.py +0 -105
- torch/_dispatch/__init__.py +0 -0
- torch/_dispatch/__pycache__/__init__.cpython-310.pyc +0 -0
torch/_C.cp310-win_amd64.pyd
DELETED
Binary file (10.2 kB)
|
|
torch/_C/_VariableFunctions.pyi
DELETED
The diff for this file is too large to render.
See raw diff
|
|
torch/_C/__init__.pyi
DELETED
The diff for this file is too large to render.
See raw diff
|
|
torch/_C/_autograd.pyi
DELETED
@@ -1,123 +0,0 @@
|
|
1 |
-
from enum import Enum
|
2 |
-
from typing import Any, Callable, List, Optional, Set
|
3 |
-
|
4 |
-
import torch
|
5 |
-
|
6 |
-
from ._profiler import (
|
7 |
-
_ProfilerEvent,
|
8 |
-
ActiveProfilerType,
|
9 |
-
ProfilerActivity,
|
10 |
-
ProfilerConfig,
|
11 |
-
)
|
12 |
-
|
13 |
-
# Defined in tools/autograd/init.cpp
|
14 |
-
|
15 |
-
class DeviceType(Enum):
|
16 |
-
CPU = ...
|
17 |
-
CUDA = ...
|
18 |
-
MKLDNN = ...
|
19 |
-
OPENGL = ...
|
20 |
-
OPENCL = ...
|
21 |
-
IDEEP = ...
|
22 |
-
HIP = ...
|
23 |
-
FPGA = ...
|
24 |
-
ORT = ...
|
25 |
-
XLA = ...
|
26 |
-
MPS = ...
|
27 |
-
HPU = ...
|
28 |
-
Meta = ...
|
29 |
-
Vulkan = ...
|
30 |
-
Metal = ...
|
31 |
-
PrivateUse1 = ...
|
32 |
-
|
33 |
-
class ProfilerEvent:
|
34 |
-
def cpu_elapsed_us(self, other: ProfilerEvent) -> float: ...
|
35 |
-
def cpu_memory_usage(self) -> int: ...
|
36 |
-
def cuda_elapsed_us(self, other: ProfilerEvent) -> float: ...
|
37 |
-
def privateuse1_elapsed_us(self, other: ProfilerEvent) -> float: ...
|
38 |
-
def cuda_memory_usage(self) -> int: ...
|
39 |
-
def device(self) -> int: ...
|
40 |
-
def handle(self) -> int: ...
|
41 |
-
def has_cuda(self) -> bool: ...
|
42 |
-
def is_remote(self) -> bool: ...
|
43 |
-
def kind(self) -> int: ...
|
44 |
-
def name(self) -> str: ...
|
45 |
-
def node_id(self) -> int: ...
|
46 |
-
def sequence_nr(self) -> int: ...
|
47 |
-
def shapes(self) -> List[List[int]]: ...
|
48 |
-
def thread_id(self) -> int: ...
|
49 |
-
def flops(self) -> float: ...
|
50 |
-
def is_async(self) -> bool: ...
|
51 |
-
|
52 |
-
class _KinetoEvent:
|
53 |
-
def name(self) -> str: ...
|
54 |
-
def device_index(self) -> int: ...
|
55 |
-
def start_us(self) -> int: ...
|
56 |
-
def duration_us(self) -> int: ...
|
57 |
-
def is_async(self) -> bool: ...
|
58 |
-
def linked_correlation_id(self) -> int: ...
|
59 |
-
def shapes(self) -> List[List[int]]: ...
|
60 |
-
def dtypes(self) -> List[str]: ...
|
61 |
-
def concrete_inputs(self) -> List[Any]: ...
|
62 |
-
def device_type(self) -> DeviceType: ...
|
63 |
-
def start_thread_id(self) -> int: ...
|
64 |
-
def end_thread_id(self) -> int: ...
|
65 |
-
def correlation_id(self) -> int: ...
|
66 |
-
def fwd_thread_id(self) -> int: ...
|
67 |
-
def stack(self) -> List[str]: ...
|
68 |
-
def scope(self) -> int: ...
|
69 |
-
def sequence_nr(self) -> int: ...
|
70 |
-
def flops(self) -> int: ...
|
71 |
-
def cuda_elapsed_us(self) -> int: ...
|
72 |
-
def privateuse1_elapsed_us(self) -> int: ...
|
73 |
-
|
74 |
-
class _ProfilerResult:
|
75 |
-
def events(self) -> List[_KinetoEvent]: ...
|
76 |
-
def legacy_events(self) -> List[List[ProfilerEvent]]: ...
|
77 |
-
def save(self, path: str) -> None: ...
|
78 |
-
def experimental_event_tree(self) -> List[_ProfilerEvent]: ...
|
79 |
-
def trace_start_us(self) -> int: ...
|
80 |
-
|
81 |
-
class SavedTensor: ...
|
82 |
-
|
83 |
-
def _enable_profiler(
|
84 |
-
config: ProfilerConfig,
|
85 |
-
activities: Set[ProfilerActivity],
|
86 |
-
) -> None: ...
|
87 |
-
def _prepare_profiler(
|
88 |
-
config: ProfilerConfig,
|
89 |
-
activities: Set[ProfilerActivity],
|
90 |
-
) -> None: ...
|
91 |
-
def _disable_profiler() -> _ProfilerResult: ...
|
92 |
-
def _profiler_enabled() -> bool: ...
|
93 |
-
def _add_metadata_json(key: str, value: str) -> None: ...
|
94 |
-
def _kineto_step() -> None: ...
|
95 |
-
def _get_sequence_nr() -> int: ...
|
96 |
-
def kineto_available() -> bool: ...
|
97 |
-
def _record_function_with_args_enter(name: str, *args) -> torch.Tensor: ...
|
98 |
-
def _record_function_with_args_exit(handle: torch.Tensor) -> None: ...
|
99 |
-
def _supported_activities() -> Set[ProfilerActivity]: ...
|
100 |
-
def _enable_record_function(enable: bool) -> None: ...
|
101 |
-
def _set_empty_test_observer(is_global: bool, sampling_prob: float) -> None: ...
|
102 |
-
def _push_saved_tensors_default_hooks(
|
103 |
-
pack_hook: Callable[[torch.Tensor], Any],
|
104 |
-
unpack_hook: Callable[[Any], torch.Tensor],
|
105 |
-
) -> None: ...
|
106 |
-
def _pop_saved_tensors_default_hooks() -> None: ...
|
107 |
-
def _unsafe_set_version_counter(t: torch.Tensor, prev_version: int) -> None: ...
|
108 |
-
def _enable_profiler_legacy(config: ProfilerConfig) -> None: ...
|
109 |
-
def _disable_profiler_legacy() -> List[List[ProfilerEvent]]: ...
|
110 |
-
def _profiler_type() -> ActiveProfilerType: ...
|
111 |
-
def _saved_tensors_hooks_enable() -> None: ...
|
112 |
-
def _saved_tensors_hooks_disable(message: str) -> None: ...
|
113 |
-
def _saved_tensors_hooks_get_disabled_error_message() -> Optional[str]: ...
|
114 |
-
|
115 |
-
class CreationMeta(Enum):
|
116 |
-
DEFAULT = ...
|
117 |
-
IN_CUSTOM_FUNCTION = ...
|
118 |
-
MULTI_OUTPUT_NODE = ...
|
119 |
-
NO_GRAD_MODE = ...
|
120 |
-
INFERENCE_MODE = ...
|
121 |
-
|
122 |
-
def _set_creation_meta(t: torch.Tensor, creation_meta: CreationMeta) -> None: ...
|
123 |
-
def _get_creation_meta(t: torch.Tensor) -> CreationMeta: ...
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
torch/_C/_cpu.pyi
DELETED
@@ -1,5 +0,0 @@
|
|
1 |
-
from torch.types import _bool
|
2 |
-
|
3 |
-
# Defined in torch/csrc/cpu/Module.cpp
|
4 |
-
|
5 |
-
def _is_cpu_support_vnni() -> _bool: ...
|
|
|
|
|
|
|
|
|
|
|
|
torch/_C/_cudnn.pyi
DELETED
@@ -1,17 +0,0 @@
|
|
1 |
-
from enum import Enum
|
2 |
-
|
3 |
-
from torch.types import _bool, Tuple
|
4 |
-
|
5 |
-
# Defined in torch/csrc/cuda/shared/cudnn.cpp
|
6 |
-
is_cuda: _bool
|
7 |
-
|
8 |
-
def getRuntimeVersion() -> Tuple[int, int, int]: ...
|
9 |
-
def getCompileVersion() -> Tuple[int, int, int]: ...
|
10 |
-
def getVersionInt() -> int: ...
|
11 |
-
|
12 |
-
class RNNMode(int, Enum):
|
13 |
-
value: int
|
14 |
-
rnn_relu = ...
|
15 |
-
rnn_tanh = ...
|
16 |
-
lstm = ...
|
17 |
-
gru = ...
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
torch/_C/_distributed_autograd.pyi
DELETED
@@ -1,26 +0,0 @@
|
|
1 |
-
from typing import Any, Dict, List, Set
|
2 |
-
|
3 |
-
import torch
|
4 |
-
|
5 |
-
# This module is defined in torch/csrc/distributed/autograd/init.cpp
|
6 |
-
|
7 |
-
class DistAutogradContext:
|
8 |
-
def _context_id(self) -> int: ...
|
9 |
-
def _recv_functions(self) -> Dict[int, Any]: ...
|
10 |
-
def _send_functions(self) -> Dict[int, Any]: ...
|
11 |
-
def _known_worker_ids(self) -> Set[int]: ...
|
12 |
-
|
13 |
-
def _new_context() -> DistAutogradContext: ...
|
14 |
-
def _release_context(context_id: int) -> None: ...
|
15 |
-
def _get_max_id() -> int: ...
|
16 |
-
def _is_valid_context(worker_id: int) -> bool: ...
|
17 |
-
def _retrieve_context(context_id: int) -> DistAutogradContext: ...
|
18 |
-
def _current_context() -> DistAutogradContext: ...
|
19 |
-
def _init(worker_id: int) -> None: ...
|
20 |
-
def _get_debug_info() -> Dict[str, str]: ...
|
21 |
-
def backward(
|
22 |
-
context_id: int,
|
23 |
-
roots: List[torch.Tensor],
|
24 |
-
retain_graph=False,
|
25 |
-
) -> None: ...
|
26 |
-
def get_gradients(context_id: int) -> Dict[torch.Tensor, torch.Tensor]: ...
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
torch/_C/_distributed_c10d.pyi
DELETED
@@ -1,478 +0,0 @@
|
|
1 |
-
# mypy: disable-error-code="type-arg"
|
2 |
-
from datetime import timedelta
|
3 |
-
from enum import Enum
|
4 |
-
from typing import Any, Dict, List, Optional, overload, Tuple, Union
|
5 |
-
|
6 |
-
from torch import Tensor
|
7 |
-
from torch._C import ScriptObject
|
8 |
-
from torch.futures import Future
|
9 |
-
|
10 |
-
# This module is defined in torch/csrc/distributed/c10d/init.cpp
|
11 |
-
|
12 |
-
_DEFAULT_FIRST_BUCKET_BYTES: int
|
13 |
-
_DEFAULT_NO_TIMEOUT: timedelta
|
14 |
-
_DEFAULT_PG_TIMEOUT: timedelta
|
15 |
-
_DEFAULT_PG_NCCL_TIMEOUT: timedelta
|
16 |
-
|
17 |
-
class BuiltinCommHookType(Enum):
|
18 |
-
ALLREDUCE = ...
|
19 |
-
FP16_COMPRESS = ...
|
20 |
-
|
21 |
-
def _register_comm_hook(reducer: Reducer, state: Any, comm_hook: Any): ...
|
22 |
-
def _register_builtin_comm_hook(
|
23 |
-
reducer: Reducer,
|
24 |
-
comm_hook_type: BuiltinCommHookType,
|
25 |
-
): ...
|
26 |
-
|
27 |
-
class GradBucket:
|
28 |
-
def index(self) -> int: ...
|
29 |
-
def buffer(self) -> Tensor: ...
|
30 |
-
def gradients(self) -> List[Tensor]: ...
|
31 |
-
def is_last(self) -> bool: ...
|
32 |
-
def set_buffer(self, tensor: Tensor) -> None: ...
|
33 |
-
def parameters(self) -> List[Tensor]: ...
|
34 |
-
|
35 |
-
class Reducer:
|
36 |
-
def __init__(
|
37 |
-
self,
|
38 |
-
params: List[Tensor],
|
39 |
-
bucket_indices: List[List[int]],
|
40 |
-
per_bucket_size_limits: List[int],
|
41 |
-
process_group: ProcessGroup,
|
42 |
-
expect_sparse_gradients: List[bool] = ...,
|
43 |
-
bucket_bytes_cap: int = ..., # kDefaultBucketBytesCap in reducer.hpp
|
44 |
-
find_unused_parameters: bool = ...,
|
45 |
-
gradient_as_bucket_view: bool = ...,
|
46 |
-
param_to_name_mapping: Dict[int, str] = ...,
|
47 |
-
first_bucket_types_cap: int = ..., # kDefaultFirstBucketBytes in reducer.hpp
|
48 |
-
): ...
|
49 |
-
def prepare_for_forward(self) -> None: ...
|
50 |
-
def prepare_for_backward(self, output: List[Tensor]) -> None: ...
|
51 |
-
def get_backward_stats(self) -> List[int]: ...
|
52 |
-
def _install_post_backward_futures(self, futures: List[Future]) -> None: ...
|
53 |
-
def _rebuild_buckets(self) -> bool: ...
|
54 |
-
def _get_zeros_like_grad_buckets(self) -> List[GradBucket]: ...
|
55 |
-
def _push_all_rebuilt_params(self) -> None: ...
|
56 |
-
def _set_forward_pass_work_handle(
|
57 |
-
self,
|
58 |
-
work: Work,
|
59 |
-
use_static_world_size: bool,
|
60 |
-
): ...
|
61 |
-
def _get_local_used_map(self) -> Tensor: ...
|
62 |
-
def _set_ddp_runtime_logging_sample_rate(self, sample_rate: int) -> None: ...
|
63 |
-
def _set_static_graph(self) -> None: ...
|
64 |
-
def _run_comm_hook(self, bucket: GradBucket) -> Future: ...
|
65 |
-
def set_logger(self, logger: Logger) -> None: ...
|
66 |
-
def _remove_autograd_hooks(self) -> None: ...
|
67 |
-
def _check_reducer_finalized(self) -> None: ...
|
68 |
-
def _set_sparse_metadata(self, global_unique_ids: Dict[str, Tensor]) -> None: ...
|
69 |
-
def _reset_state(self) -> None: ...
|
70 |
-
def _update_process_group(self, new_process_group: ProcessGroup) -> None: ...
|
71 |
-
|
72 |
-
class DDPLoggingData:
|
73 |
-
strs_map: Dict[str, str]
|
74 |
-
ints_map: Dict[str, int]
|
75 |
-
|
76 |
-
class Logger:
|
77 |
-
def __init__(self, reducer: Reducer): ...
|
78 |
-
def set_construction_data_and_log(
|
79 |
-
self,
|
80 |
-
module_name: str,
|
81 |
-
device_ids: List[int],
|
82 |
-
output_device: int,
|
83 |
-
broadcast_buffers: bool,
|
84 |
-
has_sync_bn: bool,
|
85 |
-
static_graph: bool,
|
86 |
-
): ...
|
87 |
-
def set_runtime_stats_and_log(self) -> None: ...
|
88 |
-
def set_error_and_log(self, error: str) -> None: ...
|
89 |
-
def _get_ddp_logging_data(self) -> DDPLoggingData: ...
|
90 |
-
def _set_comm_hook_name(self, comm_hook: str) -> None: ...
|
91 |
-
def _set_uneven_input_join(self) -> None: ...
|
92 |
-
def _set_static_graph(self) -> None: ...
|
93 |
-
|
94 |
-
def get_debug_level(): ...
|
95 |
-
def set_debug_level(): ...
|
96 |
-
def set_debug_level_from_env(): ...
|
97 |
-
|
98 |
-
class DebugLevel(Enum):
|
99 |
-
OFF = ...
|
100 |
-
INFO = ...
|
101 |
-
DETAIL = ...
|
102 |
-
|
103 |
-
class ReduceOp:
|
104 |
-
def __init__(self, op: RedOpType): ...
|
105 |
-
|
106 |
-
SUM: RedOpType = ...
|
107 |
-
AVG: RedOpType = ...
|
108 |
-
PRODUCT: RedOpType = ...
|
109 |
-
MIN: RedOpType = ...
|
110 |
-
MAX: RedOpType = ...
|
111 |
-
BAND: RedOpType = ...
|
112 |
-
BOR: RedOpType = ...
|
113 |
-
BXOR: RedOpType = ...
|
114 |
-
PREMUL_SUM: RedOpType = ...
|
115 |
-
UNUSED: RedOpType = ...
|
116 |
-
|
117 |
-
class RedOpType(Enum): ...
|
118 |
-
|
119 |
-
class BroadcastOptions:
|
120 |
-
rootRank: int
|
121 |
-
rootTensor: int
|
122 |
-
timeout: timedelta
|
123 |
-
asyncOp: bool
|
124 |
-
|
125 |
-
class AllreduceOptions:
|
126 |
-
reduceOp: ReduceOp
|
127 |
-
timeout: timedelta
|
128 |
-
|
129 |
-
class AllreduceCoalescedOptions(AllreduceOptions): ...
|
130 |
-
|
131 |
-
class ReduceOptions:
|
132 |
-
reduceOp: ReduceOp
|
133 |
-
rootRank: int
|
134 |
-
rootTensor: int
|
135 |
-
timeout: timedelta
|
136 |
-
|
137 |
-
class AllgatherOptions:
|
138 |
-
timeout: timedelta
|
139 |
-
asyncOp: bool
|
140 |
-
|
141 |
-
class GatherOptions:
|
142 |
-
rootRank: int
|
143 |
-
timeout: timedelta
|
144 |
-
|
145 |
-
class ScatterOptions:
|
146 |
-
rootRank: int
|
147 |
-
timeout: timedelta
|
148 |
-
asyncOp: bool
|
149 |
-
|
150 |
-
class ReduceScatterOptions:
|
151 |
-
reduceOp: ReduceOp
|
152 |
-
timeout: timedelta
|
153 |
-
asyncOp: bool
|
154 |
-
|
155 |
-
class BarrierOptions:
|
156 |
-
device_ids: List[int]
|
157 |
-
timeout: timedelta
|
158 |
-
|
159 |
-
class AllToAllOptions:
|
160 |
-
timeout: timedelta
|
161 |
-
|
162 |
-
class Store:
|
163 |
-
def set(self, key: str, value: str): ...
|
164 |
-
def get(self, key: str) -> bytes: ...
|
165 |
-
def add(self, key: str, value: int) -> int: ...
|
166 |
-
def compare_set(
|
167 |
-
self,
|
168 |
-
key: str,
|
169 |
-
expected_value: str,
|
170 |
-
desired_value: str,
|
171 |
-
) -> bytes: ...
|
172 |
-
def delete_key(self, key: str) -> bool: ...
|
173 |
-
def num_keys(self) -> int: ...
|
174 |
-
def set_timeout(self, timeout: timedelta): ...
|
175 |
-
@overload
|
176 |
-
def wait(self, keys: List[str]): ...
|
177 |
-
@overload
|
178 |
-
def wait(self, keys: List[str], timeout: timedelta): ...
|
179 |
-
|
180 |
-
class FileStore(Store):
|
181 |
-
def __init__(self, path: str, numWorkers: int = ...): ...
|
182 |
-
|
183 |
-
class HashStore(Store):
|
184 |
-
def __init__(self): ...
|
185 |
-
|
186 |
-
class TCPStore(Store):
|
187 |
-
def __init__(
|
188 |
-
self,
|
189 |
-
host_name: str,
|
190 |
-
port: int,
|
191 |
-
world_size: Optional[int] = ...,
|
192 |
-
is_master: bool = ...,
|
193 |
-
timeout: timedelta = ...,
|
194 |
-
wait_for_workers: bool = ...,
|
195 |
-
multi_tenant: bool = ...,
|
196 |
-
master_listen_fd: Optional[int] = ...,
|
197 |
-
use_libuv: Optional[bool] = ...,
|
198 |
-
): ...
|
199 |
-
@property
|
200 |
-
def host(self) -> str: ...
|
201 |
-
@property
|
202 |
-
def port(self) -> int: ...
|
203 |
-
|
204 |
-
class PrefixStore(Store):
|
205 |
-
def __init__(self, prefix: str, store: Store): ...
|
206 |
-
@property
|
207 |
-
def underlying_store(self) -> Store: ...
|
208 |
-
|
209 |
-
class Work:
|
210 |
-
def is_completed(self) -> bool: ...
|
211 |
-
def is_success(self) -> bool: ...
|
212 |
-
def exception(self) -> Any: ...
|
213 |
-
def wait(self, timeout: timedelta = ...) -> bool: ...
|
214 |
-
def source_rank(self) -> int: ...
|
215 |
-
def _source_rank(self) -> int: ...
|
216 |
-
def result(self) -> List[Tensor]: ...
|
217 |
-
def synchronize(self): ...
|
218 |
-
def boxed(self) -> ScriptObject: ...
|
219 |
-
@staticmethod
|
220 |
-
def unbox(obj: ScriptObject) -> Work: ...
|
221 |
-
|
222 |
-
class ProcessGroup:
|
223 |
-
class Options: ...
|
224 |
-
|
225 |
-
def __init__(self): ...
|
226 |
-
def rank(self) -> int: ...
|
227 |
-
def size(self) -> int: ...
|
228 |
-
@overload
|
229 |
-
def broadcast(
|
230 |
-
self,
|
231 |
-
tensors: List[Tensor],
|
232 |
-
opts=...,
|
233 |
-
) -> Work: ...
|
234 |
-
@overload
|
235 |
-
def broadcast(
|
236 |
-
self,
|
237 |
-
tensor: Tensor,
|
238 |
-
root: int,
|
239 |
-
) -> Work: ...
|
240 |
-
@overload
|
241 |
-
def allreduce(
|
242 |
-
self,
|
243 |
-
tensors: List[Tensor],
|
244 |
-
opts: AllreduceOptions = ...,
|
245 |
-
) -> Work: ...
|
246 |
-
@overload
|
247 |
-
def allreduce(
|
248 |
-
self,
|
249 |
-
tensors: List[Tensor],
|
250 |
-
op=...,
|
251 |
-
) -> Work: ...
|
252 |
-
@overload
|
253 |
-
def allreduce(
|
254 |
-
self,
|
255 |
-
tensor: Tensor,
|
256 |
-
op=...,
|
257 |
-
) -> Work: ...
|
258 |
-
def allreduce_coalesced(
|
259 |
-
self,
|
260 |
-
tensors: List[Tensor],
|
261 |
-
opts=...,
|
262 |
-
) -> Work: ...
|
263 |
-
@overload
|
264 |
-
def reduce(
|
265 |
-
self,
|
266 |
-
tensors: List[Tensor],
|
267 |
-
opts=...,
|
268 |
-
) -> Work: ...
|
269 |
-
@overload
|
270 |
-
def reduce(
|
271 |
-
self,
|
272 |
-
tensor: Tensor,
|
273 |
-
root: int,
|
274 |
-
op=...,
|
275 |
-
) -> Work: ...
|
276 |
-
@overload
|
277 |
-
def allgather(
|
278 |
-
self,
|
279 |
-
output_tensors: List[List[Tensor]],
|
280 |
-
input_tensors: List[Tensor],
|
281 |
-
opts=...,
|
282 |
-
) -> Work: ...
|
283 |
-
@overload
|
284 |
-
def allgather(
|
285 |
-
self,
|
286 |
-
output_tensors: List[Tensor],
|
287 |
-
input_tensor: Tensor,
|
288 |
-
) -> Work: ...
|
289 |
-
def _allgather_base(
|
290 |
-
self,
|
291 |
-
output: Tensor,
|
292 |
-
input: Tensor,
|
293 |
-
opts=...,
|
294 |
-
) -> Work: ...
|
295 |
-
def allgather_coalesced(
|
296 |
-
self,
|
297 |
-
output_lists: List[List[Tensor]],
|
298 |
-
input_list: List[Tensor],
|
299 |
-
opts=...,
|
300 |
-
) -> Work: ...
|
301 |
-
@overload
|
302 |
-
def gather(
|
303 |
-
self,
|
304 |
-
output_tensors: List[List[Tensor]],
|
305 |
-
input_tensors: List[Tensor],
|
306 |
-
opts=...,
|
307 |
-
) -> Work: ...
|
308 |
-
@overload
|
309 |
-
def gather(
|
310 |
-
self,
|
311 |
-
output_tensors: List[Tensor],
|
312 |
-
input_tensor: Tensor,
|
313 |
-
root: int,
|
314 |
-
) -> Work: ...
|
315 |
-
@overload
|
316 |
-
def scatter(
|
317 |
-
self,
|
318 |
-
output_tensors: List[Tensor],
|
319 |
-
input_tensors: List[List[Tensor]],
|
320 |
-
opts=...,
|
321 |
-
) -> Work: ...
|
322 |
-
@overload
|
323 |
-
def scatter(
|
324 |
-
self,
|
325 |
-
output_tensor: Tensor,
|
326 |
-
input_tensors: List[Tensor],
|
327 |
-
root: int,
|
328 |
-
) -> Work: ...
|
329 |
-
@overload
|
330 |
-
def reduce_scatter(
|
331 |
-
self,
|
332 |
-
output_tensors: List[Tensor],
|
333 |
-
input_tensors: List[List[Tensor]],
|
334 |
-
opts=...,
|
335 |
-
) -> Work: ...
|
336 |
-
@overload
|
337 |
-
def reduce_scatter(
|
338 |
-
self,
|
339 |
-
output_tensors: Tensor,
|
340 |
-
input_tensor: List[Tensor],
|
341 |
-
) -> Work: ...
|
342 |
-
def _reduce_scatter_base(
|
343 |
-
self,
|
344 |
-
outputTensor: Tensor,
|
345 |
-
inputTensor: Tensor,
|
346 |
-
) -> Work: ...
|
347 |
-
@overload
|
348 |
-
def alltoall_base(
|
349 |
-
self,
|
350 |
-
output_tensor: Tensor,
|
351 |
-
input_tensor: Tensor,
|
352 |
-
output_split_sizes: List[int],
|
353 |
-
input_split_sizes: List[int],
|
354 |
-
opts=...,
|
355 |
-
) -> Work: ...
|
356 |
-
@overload
|
357 |
-
def alltoall_base(
|
358 |
-
self,
|
359 |
-
output: Tensor,
|
360 |
-
input: Tensor,
|
361 |
-
output_split_sizes: List[int],
|
362 |
-
input_split_sizes: List[int],
|
363 |
-
) -> Work: ...
|
364 |
-
@overload
|
365 |
-
def alltoall(
|
366 |
-
self,
|
367 |
-
output_tensor: List[Tensor],
|
368 |
-
input_tensor: List[Tensor],
|
369 |
-
opts=...,
|
370 |
-
) -> Work: ...
|
371 |
-
@overload
|
372 |
-
def alltoall(
|
373 |
-
self,
|
374 |
-
output: List[Tensor],
|
375 |
-
input: List[Tensor],
|
376 |
-
) -> Work: ...
|
377 |
-
def send(
|
378 |
-
self,
|
379 |
-
tensors: List[Tensor],
|
380 |
-
dstRank: int,
|
381 |
-
tag: int,
|
382 |
-
) -> Work: ...
|
383 |
-
def recv(
|
384 |
-
self,
|
385 |
-
tensors: List[Tensor],
|
386 |
-
srcRank: int,
|
387 |
-
tag: int,
|
388 |
-
) -> Work: ...
|
389 |
-
def recv_anysource(self, tensors: List[Tensor], tag: int) -> Work: ...
|
390 |
-
def barrier(self, opts=...) -> Work: ...
|
391 |
-
def boxed(self) -> ScriptObject: ...
|
392 |
-
@staticmethod
|
393 |
-
def unbox(obj: ScriptObject) -> ProcessGroup: ...
|
394 |
-
|
395 |
-
class ProcessGroupRoundRobin(ProcessGroup): ...
|
396 |
-
|
397 |
-
def _round_robin_process_groups(
|
398 |
-
process_groups: List[ProcessGroup],
|
399 |
-
) -> ProcessGroupRoundRobin: ...
|
400 |
-
|
401 |
-
class ProcessGroupGloo(ProcessGroup):
|
402 |
-
class Device: ...
|
403 |
-
class Options: ...
|
404 |
-
|
405 |
-
def __init__(
|
406 |
-
self,
|
407 |
-
store: Store,
|
408 |
-
rank: int,
|
409 |
-
size: int,
|
410 |
-
timeout: timedelta,
|
411 |
-
): ...
|
412 |
-
@staticmethod
|
413 |
-
def create_device(hostname="", interface="") -> Device: ...
|
414 |
-
@staticmethod
|
415 |
-
def create_default_device() -> Device: ...
|
416 |
-
|
417 |
-
class _ProcessGroupWrapper(ProcessGroup):
|
418 |
-
def __init__(self, pg: ProcessGroup, gloo_pg: ProcessGroupGloo): ...
|
419 |
-
wrapped_pg: ProcessGroup
|
420 |
-
|
421 |
-
class ProcessGroupNCCL(ProcessGroup):
|
422 |
-
class Options: ...
|
423 |
-
|
424 |
-
def __init__(
|
425 |
-
self,
|
426 |
-
store: Store,
|
427 |
-
rank: int,
|
428 |
-
size: int,
|
429 |
-
timeout: timedelta,
|
430 |
-
): ...
|
431 |
-
def _group_start(self) -> None: ...
|
432 |
-
def _group_end(self) -> None: ...
|
433 |
-
|
434 |
-
class ProcessGroupUCC(ProcessGroup):
|
435 |
-
def __init__(
|
436 |
-
self,
|
437 |
-
store: Store,
|
438 |
-
rank: int,
|
439 |
-
size: int,
|
440 |
-
timeout: timedelta,
|
441 |
-
): ...
|
442 |
-
|
443 |
-
class ProcessGroupMPI(ProcessGroup):
|
444 |
-
def __init__(
|
445 |
-
self,
|
446 |
-
rank: int,
|
447 |
-
size: int,
|
448 |
-
pgComm: int,
|
449 |
-
): ...
|
450 |
-
@staticmethod
|
451 |
-
def create(ranks: List[int]) -> ProcessGroupMPI: ...
|
452 |
-
|
453 |
-
def _compute_bucket_assignment_by_size(
|
454 |
-
tensors: List[Tensor],
|
455 |
-
bucket_size_limits: List[int],
|
456 |
-
expect_sparse_gradient: List[bool] = ...,
|
457 |
-
tensor_indices: List[int] = ...,
|
458 |
-
) -> Tuple[List[List[int]], List[int]]: ...
|
459 |
-
def _broadcast_coalesced(
|
460 |
-
process_group: ProcessGroup,
|
461 |
-
tensors: List[Tensor],
|
462 |
-
buffer_size: int,
|
463 |
-
src: int,
|
464 |
-
): ...
|
465 |
-
def _test_python_store(store: Store): ...
|
466 |
-
def _verify_params_across_processes(
|
467 |
-
process_group: ProcessGroup,
|
468 |
-
params: List[Tensor],
|
469 |
-
logger: Optional[Logger],
|
470 |
-
): ...
|
471 |
-
def _make_nccl_premul_sum(factor: Union[float, List[Tensor]]) -> ReduceOp: ...
|
472 |
-
|
473 |
-
class Backend:
|
474 |
-
def __init__(
|
475 |
-
self,
|
476 |
-
rank: int,
|
477 |
-
size: int,
|
478 |
-
): ...
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
torch/_C/_distributed_rpc.pyi
DELETED
@@ -1,188 +0,0 @@
|
|
1 |
-
# mypy: disable-error-code="type-arg"
|
2 |
-
from datetime import timedelta
|
3 |
-
from typing import Any, Dict, Generic, List, Optional, overload, Tuple, Type, TypeVar
|
4 |
-
|
5 |
-
import torch
|
6 |
-
|
7 |
-
from . import Future
|
8 |
-
from ._autograd import ProfilerEvent
|
9 |
-
from ._distributed_c10d import Store
|
10 |
-
from ._profiler import ProfilerConfig
|
11 |
-
|
12 |
-
# This module is defined in torch/csrc/distributed/rpc/init.cpp
|
13 |
-
|
14 |
-
_DEFAULT_INIT_METHOD: str
|
15 |
-
_DEFAULT_NUM_WORKER_THREADS: int
|
16 |
-
_UNSET_RPC_TIMEOUT: float
|
17 |
-
_DEFAULT_RPC_TIMEOUT_SEC: float
|
18 |
-
|
19 |
-
_T = TypeVar("_T")
|
20 |
-
|
21 |
-
class RpcBackendOptions:
|
22 |
-
rpc_timeout: float
|
23 |
-
init_method: str
|
24 |
-
def __init__(
|
25 |
-
self,
|
26 |
-
rpc_timeout: float = ...,
|
27 |
-
init_method: str = ...,
|
28 |
-
): ...
|
29 |
-
|
30 |
-
class WorkerInfo:
|
31 |
-
def __init__(self, name: str, worker_id: int): ...
|
32 |
-
@property
|
33 |
-
def name(self) -> str: ...
|
34 |
-
@property
|
35 |
-
def id(self) -> int: ...
|
36 |
-
def __eq__(self, other: object) -> bool: ...
|
37 |
-
|
38 |
-
class RpcAgent:
|
39 |
-
def join(self, shutdown: bool = False, timeout: float = 0): ...
|
40 |
-
def sync(self): ...
|
41 |
-
def shutdown(self): ...
|
42 |
-
@overload
|
43 |
-
def get_worker_info(self) -> WorkerInfo: ...
|
44 |
-
@overload
|
45 |
-
def get_worker_info(self, workerName: str) -> WorkerInfo: ...
|
46 |
-
def get_worker_infos(self) -> List[WorkerInfo]: ...
|
47 |
-
def _get_device_map(self, dst: WorkerInfo) -> Dict[torch.device, torch.device]: ...
|
48 |
-
def get_debug_info(self) -> Dict[str, str]: ...
|
49 |
-
def get_metrics(self) -> Dict[str, str]: ...
|
50 |
-
|
51 |
-
class PyRRef(Generic[_T]):
|
52 |
-
def __init__(self, value: _T, type_hint: Any = None) -> None: ...
|
53 |
-
def is_owner(self) -> bool: ...
|
54 |
-
def confirmed_by_owner(self) -> bool: ...
|
55 |
-
def owner(self) -> WorkerInfo: ...
|
56 |
-
def owner_name(self) -> str: ...
|
57 |
-
def to_here(self, timeout: float = ...) -> _T: ...
|
58 |
-
def local_value(self) -> Any: ...
|
59 |
-
def rpc_sync(self, timeout: float = ...) -> Any: ...
|
60 |
-
def rpc_async(self, timeout: float = ...) -> Any: ...
|
61 |
-
def remote(self, timeout: float = ...) -> Any: ...
|
62 |
-
def _serialize(self) -> Tuple: ...
|
63 |
-
@staticmethod
|
64 |
-
def _deserialize(tp: Tuple) -> PyRRef: ...
|
65 |
-
def _get_type(self) -> Type[_T]: ...
|
66 |
-
def _get_future(self) -> Future[_T]: ...
|
67 |
-
def _get_profiling_future(self) -> Future[_T]: ...
|
68 |
-
def _set_profiling_future(self, profilingFuture: Future[_T]): ...
|
69 |
-
|
70 |
-
class _TensorPipeRpcBackendOptionsBase(RpcBackendOptions):
|
71 |
-
num_worker_threads: int
|
72 |
-
device_maps: Dict[str, Dict[torch.device, torch.device]]
|
73 |
-
devices: List[torch.device]
|
74 |
-
def __init__(
|
75 |
-
self,
|
76 |
-
num_worker_threads: int,
|
77 |
-
_transports: Optional[List],
|
78 |
-
_channels: Optional[List],
|
79 |
-
rpc_timeout: float = ...,
|
80 |
-
init_method: str = ...,
|
81 |
-
device_maps: Dict[str, Dict[torch.device, torch.device]] = {}, # noqa: B006
|
82 |
-
devices: List[torch.device] = [], # noqa: B006
|
83 |
-
): ...
|
84 |
-
def _set_device_map(
|
85 |
-
self,
|
86 |
-
to: str,
|
87 |
-
device_map: Dict[torch.device, torch.device],
|
88 |
-
): ...
|
89 |
-
|
90 |
-
class TensorPipeAgent(RpcAgent):
|
91 |
-
def __init__(
|
92 |
-
self,
|
93 |
-
store: Store,
|
94 |
-
name: str,
|
95 |
-
worker_id: int,
|
96 |
-
world_size: Optional[int],
|
97 |
-
opts: _TensorPipeRpcBackendOptionsBase,
|
98 |
-
reverse_device_maps: Dict[str, Dict[torch.device, torch.device]],
|
99 |
-
devices: List[torch.device],
|
100 |
-
): ...
|
101 |
-
def join(self, shutdown: bool = False, timeout: float = 0): ...
|
102 |
-
def shutdown(self): ...
|
103 |
-
@overload
|
104 |
-
def get_worker_info(self) -> WorkerInfo: ...
|
105 |
-
@overload
|
106 |
-
def get_worker_info(self, workerName: str) -> WorkerInfo: ...
|
107 |
-
@overload
|
108 |
-
def get_worker_info(self, id: int) -> WorkerInfo: ...
|
109 |
-
def get_worker_infos(self) -> List[WorkerInfo]: ...
|
110 |
-
def _get_device_map(self, dst: WorkerInfo) -> Dict[torch.device, torch.device]: ...
|
111 |
-
def _update_group_membership(
|
112 |
-
self,
|
113 |
-
worker_info: WorkerInfo,
|
114 |
-
my_devices: List[torch.device],
|
115 |
-
reverse_device_map: Dict[str, Dict[torch.device, torch.device]],
|
116 |
-
is_join: bool,
|
117 |
-
): ...
|
118 |
-
def _get_backend_options(self) -> _TensorPipeRpcBackendOptionsBase: ...
|
119 |
-
@property
|
120 |
-
def is_static_group(self) -> bool: ...
|
121 |
-
@property
|
122 |
-
def store(self) -> Store: ...
|
123 |
-
|
124 |
-
def _is_current_rpc_agent_set() -> bool: ...
|
125 |
-
def _get_current_rpc_agent() -> RpcAgent: ...
|
126 |
-
def _set_and_start_rpc_agent(agent: RpcAgent): ...
|
127 |
-
def _reset_current_rpc_agent(): ...
|
128 |
-
def _delete_all_user_and_unforked_owner_rrefs(timeout: timedelta = ...): ...
|
129 |
-
def _destroy_rref_context(ignoreRRefLeak: bool): ...
|
130 |
-
def _rref_context_get_debug_info() -> Dict[str, str]: ...
|
131 |
-
def _cleanup_python_rpc_handler(): ...
|
132 |
-
def _invoke_rpc_builtin(
|
133 |
-
dst: WorkerInfo,
|
134 |
-
opName: str,
|
135 |
-
rpcTimeoutSeconds: float,
|
136 |
-
*args: Any,
|
137 |
-
**kwargs: Any,
|
138 |
-
): ...
|
139 |
-
def _invoke_rpc_python_udf(
|
140 |
-
dst: WorkerInfo,
|
141 |
-
pickledPythonUDF: str,
|
142 |
-
tensors: List[torch.Tensor],
|
143 |
-
rpcTimeoutSeconds: float,
|
144 |
-
isAsyncExecution: bool,
|
145 |
-
): ...
|
146 |
-
def _invoke_rpc_torchscript(
|
147 |
-
dstWorkerName: str,
|
148 |
-
qualifiedNameStr: str,
|
149 |
-
argsTuple: Tuple,
|
150 |
-
kwargsDict: Dict,
|
151 |
-
rpcTimeoutSeconds: float,
|
152 |
-
isAsyncExecution: bool,
|
153 |
-
): ...
|
154 |
-
def _invoke_remote_builtin(
|
155 |
-
dst: WorkerInfo,
|
156 |
-
opName: str,
|
157 |
-
rpcTimeoutSeconds: float,
|
158 |
-
*args: Any,
|
159 |
-
**kwargs: Any,
|
160 |
-
): ...
|
161 |
-
def _invoke_remote_python_udf(
|
162 |
-
dst: WorkerInfo,
|
163 |
-
pickledPythonUDF: str,
|
164 |
-
tensors: List[torch.Tensor],
|
165 |
-
rpcTimeoutSeconds: float,
|
166 |
-
isAsyncExecution: bool,
|
167 |
-
): ...
|
168 |
-
def _invoke_remote_torchscript(
|
169 |
-
dstWorkerName: WorkerInfo,
|
170 |
-
qualifiedNameStr: str,
|
171 |
-
rpcTimeoutSeconds: float,
|
172 |
-
isAsyncExecution: bool,
|
173 |
-
*args: Any,
|
174 |
-
**kwargs: Any,
|
175 |
-
): ...
|
176 |
-
def get_rpc_timeout() -> float: ...
|
177 |
-
def enable_gil_profiling(flag: bool): ...
|
178 |
-
def _set_rpc_timeout(rpcTimeoutSeconds: float): ...
|
179 |
-
|
180 |
-
class RemoteProfilerManager:
|
181 |
-
@staticmethod
|
182 |
-
def set_current_profiling_key(key: str): ...
|
183 |
-
|
184 |
-
def _enable_server_process_global_profiler(new_config: ProfilerConfig): ...
|
185 |
-
def _disable_server_process_global_profiler() -> List[List[List[ProfilerEvent]]]: ...
|
186 |
-
def _set_profiler_node_id(default_node_id: int): ...
|
187 |
-
def _enable_jit_rref_pickle(): ...
|
188 |
-
def _disable_jit_rref_pickle(): ...
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
torch/_C/_distributed_rpc_testing.pyi
DELETED
@@ -1,35 +0,0 @@
|
|
1 |
-
from typing import Dict, List
|
2 |
-
|
3 |
-
import torch
|
4 |
-
|
5 |
-
from ._distributed_c10d import Store
|
6 |
-
from ._distributed_rpc import _TensorPipeRpcBackendOptionsBase, TensorPipeAgent
|
7 |
-
|
8 |
-
# This module is defined in torch/csrc/distributed/rpc/testing/init.cpp
|
9 |
-
|
10 |
-
class FaultyTensorPipeRpcBackendOptions(_TensorPipeRpcBackendOptionsBase):
|
11 |
-
def __init__(
|
12 |
-
self,
|
13 |
-
num_worker_threads: int,
|
14 |
-
rpc_timeout: float,
|
15 |
-
init_method: str,
|
16 |
-
messages_to_fail: List[str],
|
17 |
-
messages_to_delay: Dict[str, float],
|
18 |
-
num_fail_sends: int,
|
19 |
-
): ...
|
20 |
-
num_send_recv_threads: int
|
21 |
-
messages_to_fail: List[str]
|
22 |
-
messages_to_delay: Dict[str, float]
|
23 |
-
num_fail_sends: int
|
24 |
-
|
25 |
-
class FaultyTensorPipeAgent(TensorPipeAgent):
|
26 |
-
def __init__(
|
27 |
-
self,
|
28 |
-
store: Store,
|
29 |
-
name: str,
|
30 |
-
rank: int,
|
31 |
-
world_size: int,
|
32 |
-
options: FaultyTensorPipeRpcBackendOptions,
|
33 |
-
reverse_device_maps: Dict[str, Dict[torch.device, torch.device]],
|
34 |
-
devices: List[torch.device],
|
35 |
-
): ...
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
torch/_C/_functions.pyi
DELETED
@@ -1,11 +0,0 @@
|
|
1 |
-
from typing import AnyStr, List
|
2 |
-
|
3 |
-
from torch import Tensor
|
4 |
-
|
5 |
-
class UndefinedGrad:
|
6 |
-
def __init__(self) -> None: ...
|
7 |
-
def __call__(self, *inputs: Tensor) -> List[Tensor]: ...
|
8 |
-
|
9 |
-
class DelayedError:
|
10 |
-
def __init__(self, msg: AnyStr, num_inputs: int) -> None: ...
|
11 |
-
def __call__(self, inputs: List[Tensor]) -> List[Tensor]: ...
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
torch/_C/_functorch.pyi
DELETED
@@ -1,71 +0,0 @@
|
|
1 |
-
from enum import Enum
|
2 |
-
from typing import Optional, Tuple
|
3 |
-
|
4 |
-
from torch import Tensor
|
5 |
-
|
6 |
-
# Defined in torch/csrc/functorch/init.cpp
|
7 |
-
|
8 |
-
def _set_dynamic_layer_keys_included(included: bool) -> None: ...
|
9 |
-
def get_unwrapped(tensor: Tensor) -> Tensor: ...
|
10 |
-
def is_batchedtensor(tensor: Tensor) -> bool: ...
|
11 |
-
def is_functionaltensor(tensor: Tensor) -> bool: ...
|
12 |
-
def is_functorch_wrapped_tensor(tensor: Tensor) -> bool: ...
|
13 |
-
def is_gradtrackingtensor(tensor: Tensor) -> bool: ...
|
14 |
-
def maybe_get_bdim(tensor: Tensor) -> int: ...
|
15 |
-
def maybe_get_level(tensor: Tensor) -> int: ...
|
16 |
-
def unwrap_if_dead(tensor: Tensor) -> Tensor: ...
|
17 |
-
def _unwrap_for_grad(tensor: Tensor, level: int) -> Tensor: ...
|
18 |
-
def _wrap_for_grad(tensor: Tensor, level: int) -> Tensor: ...
|
19 |
-
def _unwrap_batched(tensor: Tensor, level: int) -> Tuple[Tensor, Optional[int]]: ...
|
20 |
-
def current_level() -> int: ...
|
21 |
-
def _add_batch_dim(tensor: Tensor, bdim: int, level: int) -> Tensor: ...
|
22 |
-
def set_single_level_autograd_function_allowed(allowed: bool) -> None: ...
|
23 |
-
def get_single_level_autograd_function_allowed() -> bool: ...
|
24 |
-
def _unwrap_functional_tensor(tensor: Tensor, reapply_views: bool) -> Tensor: ...
|
25 |
-
def _wrap_functional_tensor(tensor: Tensor, level: int) -> Tensor: ...
|
26 |
-
|
27 |
-
# Defined in aten/src/ATen/functorch/Interpreter.h
|
28 |
-
class TransformType(Enum):
|
29 |
-
Torch: TransformType = ...
|
30 |
-
Vmap: TransformType = ...
|
31 |
-
Grad: TransformType = ...
|
32 |
-
Jvp: TransformType = ...
|
33 |
-
Functionalize: TransformType = ...
|
34 |
-
|
35 |
-
class RandomnessType(Enum):
|
36 |
-
Error: TransformType = ...
|
37 |
-
Same: TransformType = ...
|
38 |
-
Different: TransformType = ...
|
39 |
-
|
40 |
-
class CInterpreter:
|
41 |
-
def key(self) -> TransformType: ...
|
42 |
-
def level(self) -> int: ...
|
43 |
-
|
44 |
-
class CGradInterpreterPtr:
|
45 |
-
def __init__(self, interpreter: CInterpreter): ...
|
46 |
-
def lift(self, Tensor) -> Tensor: ...
|
47 |
-
def prevGradMode(self) -> bool: ...
|
48 |
-
|
49 |
-
class CJvpInterpreterPtr:
|
50 |
-
def __init__(self, interpreter: CInterpreter): ...
|
51 |
-
def lift(self, Tensor) -> Tensor: ...
|
52 |
-
def prevFwdGradMode(self) -> bool: ...
|
53 |
-
|
54 |
-
class CFunctionalizeInterpreterPtr:
|
55 |
-
def __init__(self, interpreter: CInterpreter): ...
|
56 |
-
def key(self) -> TransformType: ...
|
57 |
-
def level(self) -> int: ...
|
58 |
-
def functionalizeAddBackViews(self) -> bool: ...
|
59 |
-
|
60 |
-
class CVmapInterpreterPtr:
|
61 |
-
def __init__(self, interpreter: CInterpreter): ...
|
62 |
-
def key(self) -> TransformType: ...
|
63 |
-
def level(self) -> int: ...
|
64 |
-
def batchSize(self) -> int: ...
|
65 |
-
def randomness(self) -> RandomnessType: ...
|
66 |
-
|
67 |
-
class DynamicLayer: ...
|
68 |
-
|
69 |
-
def peek_interpreter_stack() -> CInterpreter: ...
|
70 |
-
def pop_dynamic_layer_stack() -> DynamicLayer: ...
|
71 |
-
def push_dynamic_layer_stack(dl: DynamicLayer) -> int: ...
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
torch/_C/_itt.pyi
DELETED
@@ -1,5 +0,0 @@
|
|
1 |
-
# Defined in torch/csrc/itt.cpp
|
2 |
-
def is_available() -> None: ...
|
3 |
-
def rangePush(message: str) -> None: ...
|
4 |
-
def rangePop() -> None: ...
|
5 |
-
def mark(message: str) -> None: ...
|
|
|
|
|
|
|
|
|
|
|
|
torch/_C/_lazy.pyi
DELETED
@@ -1,28 +0,0 @@
|
|
1 |
-
from typing import List
|
2 |
-
|
3 |
-
from torch import Tensor
|
4 |
-
|
5 |
-
# defined in torch/csrc/lazy/python/init.cpp
|
6 |
-
def _mark_step(device: str, devices: List[str], wait: bool): ...
|
7 |
-
def _wait_device_ops(devices: List[str]): ...
|
8 |
-
def _reset_metrics(): ...
|
9 |
-
def _counter_names() -> List[str]: ...
|
10 |
-
def _counter_value(name: str) -> int: ...
|
11 |
-
def _metrics_report() -> str: ...
|
12 |
-
def _get_graph_hash(tensors: List[Tensor]) -> str: ...
|
13 |
-
def _sync_multi(
|
14 |
-
tensors: List[Tensor],
|
15 |
-
devices: List[str],
|
16 |
-
wait: bool = True,
|
17 |
-
sync_ltc_data: bool = True,
|
18 |
-
): ...
|
19 |
-
def _get_tensor_id(tensor: Tensor) -> int: ...
|
20 |
-
def _get_tensors_text(tensors: List[Tensor]) -> str: ...
|
21 |
-
def _get_tensors_dot(tensors: List[Tensor]) -> str: ...
|
22 |
-
def _get_tensors_backend(tensors: List[Tensor]) -> str: ...
|
23 |
-
def _get_force_fallback() -> str: ...
|
24 |
-
def _set_force_fallback(newval: str): ...
|
25 |
-
def _clear_ir_cache(): ...
|
26 |
-
def _dump_ir_cache(filename: str): ...
|
27 |
-
def _set_reuse_ir(val: bool): ...
|
28 |
-
def _get_default_device_type(): ...
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
torch/_C/_lazy_ts_backend.pyi
DELETED
@@ -1,11 +0,0 @@
|
|
1 |
-
# defined in torch/csrc/lazy/python/init.cpp
|
2 |
-
|
3 |
-
from typing import Any, List, Tuple
|
4 |
-
|
5 |
-
from torch import Tensor
|
6 |
-
|
7 |
-
def _init(): ...
|
8 |
-
def _get_tensors_ts_device_data_node(
|
9 |
-
tensors: List[Tensor],
|
10 |
-
) -> Tuple[List[int], List[Any]]: ...
|
11 |
-
def _run_cached_graph(hash_str: str, graph_inputs: List[Any]) -> List[Tensor]: ...
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
torch/_C/_monitor.pyi
DELETED
@@ -1,44 +0,0 @@
|
|
1 |
-
# Defined in torch/csrc/monitor/python_init.cpp
|
2 |
-
|
3 |
-
import datetime
|
4 |
-
from enum import Enum
|
5 |
-
from typing import Callable, Dict, List, Union
|
6 |
-
|
7 |
-
class Aggregation(Enum):
|
8 |
-
VALUE = ...
|
9 |
-
MEAN = ...
|
10 |
-
COUNT = ...
|
11 |
-
SUM = ...
|
12 |
-
MAX = ...
|
13 |
-
MIN = ...
|
14 |
-
|
15 |
-
class Stat:
|
16 |
-
name: str
|
17 |
-
count: int
|
18 |
-
def __init__(
|
19 |
-
self,
|
20 |
-
name: str,
|
21 |
-
aggregations: List[Aggregation],
|
22 |
-
window_size: int,
|
23 |
-
max_samples: int = -1,
|
24 |
-
) -> None: ...
|
25 |
-
def add(self, v: float) -> None: ...
|
26 |
-
def get(self) -> Dict[Aggregation, float]: ...
|
27 |
-
|
28 |
-
class Event:
|
29 |
-
name: str
|
30 |
-
timestamp: datetime.datetime
|
31 |
-
data: Dict[str, Union[int, float, bool, str]]
|
32 |
-
def __init__(
|
33 |
-
self,
|
34 |
-
name: str,
|
35 |
-
timestamp: datetime.datetime,
|
36 |
-
data: Dict[str, Union[int, float, bool, str]],
|
37 |
-
) -> None: ...
|
38 |
-
|
39 |
-
def log_event(e: Event) -> None: ...
|
40 |
-
|
41 |
-
class EventHandlerHandle: ...
|
42 |
-
|
43 |
-
def register_event_handler(handler: Callable[[Event], None]) -> EventHandlerHandle: ...
|
44 |
-
def unregister_event_handler(handle: EventHandlerHandle) -> None: ...
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
torch/_C/_nn.pyi
DELETED
@@ -1,86 +0,0 @@
|
|
1 |
-
# mypy: disable-error-code="type-arg"
|
2 |
-
from typing import List, Optional, overload, Sequence, Tuple, Union
|
3 |
-
|
4 |
-
from torch import memory_format, Tensor
|
5 |
-
from torch.types import _bool, _device, _dtype, _int, _size
|
6 |
-
|
7 |
-
# Defined in tools/autograd/templates/python_nn_functions.cpp
|
8 |
-
|
9 |
-
def adaptive_max_pool2d(input: Tensor, output_size: Union[_int, _size]) -> Tuple[Tensor, Tensor]: ...
|
10 |
-
def adaptive_max_pool3d(input: Tensor, output_size: Union[_int, _size]) -> Tuple[Tensor, Tensor]: ...
|
11 |
-
def avg_pool2d(input: Tensor, kernel_size: Union[_int, _size], stride: Optional[Union[_int, _size]] = None, padding: Union[_int, _size] = 0, ceil_mode: bool = False, count_include_pad: bool = True, divisor_override: Optional[int] = None) -> Tensor: ...
|
12 |
-
def avg_pool3d(input: Tensor, kernel_size: Union[_int, _size], stride: Optional[Union[_int, _size]] = None, padding: Union[_int, _size] = 0, ceil_mode: bool = False, count_include_pad: bool = True, divisor_override: Optional[int] = None) -> Tensor: ...
|
13 |
-
def elu_(input: Tensor, alpha: float = ...) -> Tensor: ...
|
14 |
-
def fractional_max_pool2d(input: Tensor, kernel_size: Union[_int, _size], output_size: Union[_int, _size], _random_samples: Tensor) -> Tuple[Tensor, Tensor]: ...
|
15 |
-
def fractional_max_pool3d(input: Tensor, kernel_size: Union[_int, _size], output_size: Union[_int, _size], _random_samples: Tensor) -> Tuple[Tensor, Tensor]: ...
|
16 |
-
def gelu(input: Tensor, approximate: str = ...) -> Tensor: ...
|
17 |
-
def hardsigmoid(input: Tensor, *, out: Optional[Tensor] = None) -> Tensor: ...
|
18 |
-
def hardtanh(input: Tensor, min_val: float = ..., max_val: float = ..., *, out: Optional[Tensor] = None) -> Tensor: ...
|
19 |
-
def hardtanh_(input: Tensor, min_val: float = ..., max_val: float = ...) -> Tensor: ...
|
20 |
-
def leaky_relu(input: Tensor, negative_slope: float = ..., *, out: Optional[Tensor] = None) -> Tensor: ...
|
21 |
-
def leaky_relu_(input: Tensor, negative_slope: float = ...) -> Tensor: ...
|
22 |
-
def linear(input: Tensor, weight: Tensor, bias: Optional[Tensor] = None) -> Tensor: ...
|
23 |
-
def log_sigmoid(input: Tensor) -> Tensor: ...
|
24 |
-
def one_hot(tensor: Tensor, num_classes: int = ...) -> Tensor: ...
|
25 |
-
def pad(input: Tensor, pad: Sequence[int], mode: str = ..., value: Optional[float] = None) -> Tensor: ...
|
26 |
-
def scaled_dot_product_attention(query: Tensor, key: Tensor, value: Tensor, attn_mask: Optional[Tensor] = None, dropout_p: float = 0.0, is_causal: bool = False, scale: Optional[float] = None) -> Tensor: ...
|
27 |
-
def softplus(input: Tensor, beta: int = ..., threshold: int = ...) -> Tensor: ...
|
28 |
-
def softshrink(input: Tensor, lambd: float = ...) -> Tensor: ...
|
29 |
-
|
30 |
-
# Defined in aten/src/ATen/native/mkldnn/Linear.cpp
|
31 |
-
def mkldnn_linear(input: Tensor, weight: Tensor, bias: Optional[Tensor]) -> Tensor: ...
|
32 |
-
|
33 |
-
# Defined at aten/src/ATen/native/mkldnn/MKLDNNConversions.cpp
|
34 |
-
def mkldnn_reorder_conv2d_weight(
|
35 |
-
self: Tensor,
|
36 |
-
padding: List,
|
37 |
-
stride: List,
|
38 |
-
dilatation: List,
|
39 |
-
groups: int,
|
40 |
-
) -> Tensor: ...
|
41 |
-
def mkldnn_reorder_conv3d_weight(
|
42 |
-
self: Tensor,
|
43 |
-
padding: List,
|
44 |
-
stride: List,
|
45 |
-
dilatation: List,
|
46 |
-
groups: int,
|
47 |
-
) -> Tensor: ...
|
48 |
-
|
49 |
-
# Defined in aten/src/ATen/native/mkldnn/Prelu.cpp
|
50 |
-
def mkldnn_prelu(input: Tensor, weight: Tensor) -> Tensor: ...
|
51 |
-
|
52 |
-
# Defined at tools/autograd/templates/python_nn_functions.cpp
|
53 |
-
@overload
|
54 |
-
def _parse_to(
|
55 |
-
device: _device,
|
56 |
-
dtype: _dtype,
|
57 |
-
non_blocking: _bool,
|
58 |
-
copy: _bool,
|
59 |
-
*,
|
60 |
-
memory_format: memory_format,
|
61 |
-
) -> Tuple[_device, _dtype, _bool, memory_format]: ...
|
62 |
-
@overload
|
63 |
-
def _parse_to(
|
64 |
-
dtype: _dtype,
|
65 |
-
non_blocking: _bool,
|
66 |
-
copy: _bool,
|
67 |
-
*,
|
68 |
-
memory_format: memory_format,
|
69 |
-
) -> Tuple[_device, _dtype, _bool, memory_format]: ...
|
70 |
-
@overload
|
71 |
-
def _parse_to(
|
72 |
-
tensor: Tensor,
|
73 |
-
non_blocking: _bool,
|
74 |
-
copy: _bool,
|
75 |
-
*,
|
76 |
-
memory_format: memory_format,
|
77 |
-
) -> Tuple[_device, _dtype, _bool, memory_format]: ...
|
78 |
-
|
79 |
-
# Defined in aten/src/ATen/native/PadSequence.cpp
|
80 |
-
def pad_sequence(
|
81 |
-
sequences: List[Tensor],
|
82 |
-
batch_first: bool = False,
|
83 |
-
padding_value: float = ...,
|
84 |
-
) -> Tensor: ...
|
85 |
-
def flatten_dense_tensors(tensors: List[Tensor]) -> Tensor: ...
|
86 |
-
def unflatten_dense_tensors(flat: Tensor, tensors: List[Tensor]) -> List[Tensor]: ...
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
torch/_C/_nvtx.pyi
DELETED
@@ -1,6 +0,0 @@
|
|
1 |
-
# Defined in torch/csrc/cuda/shared/nvtx.cpp
|
2 |
-
def rangePushA(message: str) -> int: ...
|
3 |
-
def rangePop() -> int: ...
|
4 |
-
def rangeStartA(message: str) -> int: ...
|
5 |
-
def rangeEnd(int) -> None: ...
|
6 |
-
def markA(message: str) -> None: ...
|
|
|
|
|
|
|
|
|
|
|
|
|
|
torch/_C/_onnx.pyi
DELETED
@@ -1,38 +0,0 @@
|
|
1 |
-
# Defined in torch/csrc/onnx/init.cpp
|
2 |
-
|
3 |
-
from enum import Enum
|
4 |
-
|
5 |
-
_CAFFE2_ATEN_FALLBACK: bool
|
6 |
-
PRODUCER_VERSION: str
|
7 |
-
|
8 |
-
class TensorProtoDataType(Enum):
|
9 |
-
UNDEFINED = ...
|
10 |
-
FLOAT = ...
|
11 |
-
UINT8 = ...
|
12 |
-
INT8 = ...
|
13 |
-
UINT16 = ...
|
14 |
-
INT16 = ...
|
15 |
-
INT32 = ...
|
16 |
-
INT64 = ...
|
17 |
-
STRING = ...
|
18 |
-
BOOL = ...
|
19 |
-
FLOAT16 = ...
|
20 |
-
DOUBLE = ...
|
21 |
-
UINT32 = ...
|
22 |
-
UINT64 = ...
|
23 |
-
COMPLEX64 = ...
|
24 |
-
COMPLEX128 = ...
|
25 |
-
BFLOAT16 = ...
|
26 |
-
FLOAT8E5M2 = ...
|
27 |
-
FLOAT8E4M3FN = ...
|
28 |
-
|
29 |
-
class OperatorExportTypes(Enum):
|
30 |
-
ONNX = ...
|
31 |
-
ONNX_ATEN = ...
|
32 |
-
ONNX_ATEN_FALLBACK = ...
|
33 |
-
ONNX_FALLTHROUGH = ...
|
34 |
-
|
35 |
-
class TrainingMode(Enum):
|
36 |
-
EVAL = ...
|
37 |
-
PRESERVE = ...
|
38 |
-
TRAINING = ...
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
torch/_C/_profiler.pyi
DELETED
@@ -1,238 +0,0 @@
|
|
1 |
-
from enum import Enum
|
2 |
-
from typing import Any, Dict, List, Literal, Optional, Tuple, Union
|
3 |
-
|
4 |
-
from torch._C import device, dtype, layout
|
5 |
-
from typing_extensions import TypeAlias
|
6 |
-
|
7 |
-
# defined in torch/csrc/profiler/python/init.cpp
|
8 |
-
|
9 |
-
class RecordScope(Enum):
|
10 |
-
FUNCTION = ...
|
11 |
-
BACKWARD_FUNCTION = ...
|
12 |
-
TORCHSCRIPT_FUNCTION = ...
|
13 |
-
KERNEL_FUNCTION_DTYPE = ...
|
14 |
-
CUSTOM_CLASS = ...
|
15 |
-
BUILD_FEATURE = ...
|
16 |
-
LITE_INTERPRETER = ...
|
17 |
-
USER_SCOPE = ...
|
18 |
-
STATIC_RUNTIME_OP = ...
|
19 |
-
STATIC_RUNTIME_MODEL = ...
|
20 |
-
|
21 |
-
class ProfilerState(Enum):
|
22 |
-
Disable = ...
|
23 |
-
CPU = ...
|
24 |
-
CUDA = ...
|
25 |
-
NVTX = ...
|
26 |
-
ITT = ...
|
27 |
-
KINETO = ...
|
28 |
-
KINETO_GPU_FALLBACK = ...
|
29 |
-
KINETO_PRIVATEUSE1_FALLBACK = ...
|
30 |
-
KINETO_PRIVATEUSE1 = ...
|
31 |
-
|
32 |
-
class ActiveProfilerType(Enum):
|
33 |
-
NONE = ...
|
34 |
-
LEGACY = ...
|
35 |
-
KINETO = ...
|
36 |
-
NVTX = ...
|
37 |
-
ITT = ...
|
38 |
-
|
39 |
-
class ProfilerActivity(Enum):
|
40 |
-
CPU = ...
|
41 |
-
CUDA = ...
|
42 |
-
MTIA = ...
|
43 |
-
PrivateUse1 = ...
|
44 |
-
|
45 |
-
class _EventType(Enum):
|
46 |
-
TorchOp = ...
|
47 |
-
Backend = ...
|
48 |
-
Allocation = ...
|
49 |
-
OutOfMemory = ...
|
50 |
-
PyCall = ...
|
51 |
-
PyCCall = ...
|
52 |
-
Kineto = ...
|
53 |
-
|
54 |
-
class _ExperimentalConfig:
|
55 |
-
def __init__(
|
56 |
-
self,
|
57 |
-
profiler_metrics: List[str] = ...,
|
58 |
-
profiler_measure_per_kernel: bool = ...,
|
59 |
-
verbose: bool = ...,
|
60 |
-
performance_events: List[str] = ...,
|
61 |
-
enable_cuda_sync_events: bool = ...,
|
62 |
-
) -> None: ...
|
63 |
-
|
64 |
-
class ProfilerConfig:
|
65 |
-
def __init__(
|
66 |
-
self,
|
67 |
-
state: ProfilerState,
|
68 |
-
report_input_shapes: bool,
|
69 |
-
profile_memory: bool,
|
70 |
-
with_stack: bool,
|
71 |
-
with_flops: bool,
|
72 |
-
with_modules: bool,
|
73 |
-
experimental_config: _ExperimentalConfig,
|
74 |
-
) -> None: ...
|
75 |
-
|
76 |
-
class _ProfilerEvent:
|
77 |
-
start_tid: int
|
78 |
-
start_time_ns: int
|
79 |
-
children: List[_ProfilerEvent]
|
80 |
-
|
81 |
-
# TODO(robieta): remove in favor of `self.typed`
|
82 |
-
extra_fields: Union[
|
83 |
-
_ExtraFields_TorchOp,
|
84 |
-
_ExtraFields_Backend,
|
85 |
-
_ExtraFields_Allocation,
|
86 |
-
_ExtraFields_OutOfMemory,
|
87 |
-
_ExtraFields_PyCall,
|
88 |
-
_ExtraFields_PyCCall,
|
89 |
-
_ExtraFields_Kineto,
|
90 |
-
]
|
91 |
-
|
92 |
-
@property
|
93 |
-
def typed(
|
94 |
-
self,
|
95 |
-
) -> Union[
|
96 |
-
Tuple[Literal[_EventType.TorchOp], _ExtraFields_TorchOp],
|
97 |
-
Tuple[Literal[_EventType.Backend], _ExtraFields_Backend],
|
98 |
-
Tuple[Literal[_EventType.Allocation], _ExtraFields_Allocation],
|
99 |
-
Tuple[Literal[_EventType.OutOfMemory], _ExtraFields_OutOfMemory],
|
100 |
-
Tuple[Literal[_EventType.PyCall], _ExtraFields_PyCall],
|
101 |
-
Tuple[Literal[_EventType.PyCCall], _ExtraFields_PyCCall],
|
102 |
-
Tuple[Literal[_EventType.Kineto], _ExtraFields_Kineto],
|
103 |
-
]: ...
|
104 |
-
@property
|
105 |
-
def name(self) -> str: ...
|
106 |
-
@property
|
107 |
-
def tag(self) -> _EventType: ...
|
108 |
-
@property
|
109 |
-
def id(self) -> int: ...
|
110 |
-
@property
|
111 |
-
def parent(self) -> Optional[_ProfilerEvent]: ...
|
112 |
-
@property
|
113 |
-
def correlation_id(self) -> int: ...
|
114 |
-
@property
|
115 |
-
def end_time_ns(self) -> int: ...
|
116 |
-
@property
|
117 |
-
def duration_time_ns(self) -> int: ...
|
118 |
-
|
119 |
-
class _TensorMetadata:
|
120 |
-
impl_ptr: Optional[int]
|
121 |
-
storage_data_ptr: Optional[int]
|
122 |
-
id: Optional[int]
|
123 |
-
|
124 |
-
@property
|
125 |
-
def allocation_id(self) -> Optional[int]: ...
|
126 |
-
@property
|
127 |
-
def layout(self) -> layout: ...
|
128 |
-
@property
|
129 |
-
def device(self) -> device: ...
|
130 |
-
@property
|
131 |
-
def dtype(self) -> dtype: ...
|
132 |
-
@property
|
133 |
-
def sizes(self) -> List[int]: ...
|
134 |
-
@property
|
135 |
-
def strides(self) -> List[int]: ...
|
136 |
-
|
137 |
-
Scalar: TypeAlias = Union[int, float, bool, complex]
|
138 |
-
Input: TypeAlias = Optional[Union[_TensorMetadata, List[_TensorMetadata], Scalar]]
|
139 |
-
|
140 |
-
class _ExtraFields_TorchOp:
|
141 |
-
name: str
|
142 |
-
sequence_number: int
|
143 |
-
allow_tf32_cublas: bool
|
144 |
-
|
145 |
-
@property
|
146 |
-
def inputs(self) -> List[Input]: ...
|
147 |
-
@property
|
148 |
-
def scope(self) -> RecordScope: ...
|
149 |
-
|
150 |
-
class _ExtraFields_Backend: ...
|
151 |
-
|
152 |
-
class _ExtraFields_Allocation:
|
153 |
-
ptr: int
|
154 |
-
id: Optional[int]
|
155 |
-
alloc_size: int
|
156 |
-
total_allocated: int
|
157 |
-
total_reserved: int
|
158 |
-
|
159 |
-
@property
|
160 |
-
def allocation_id(self) -> Optional[int]: ...
|
161 |
-
@property
|
162 |
-
def device(self) -> device: ...
|
163 |
-
|
164 |
-
class _ExtraFields_OutOfMemory: ...
|
165 |
-
|
166 |
-
class _PyFrameState:
|
167 |
-
line_number: int
|
168 |
-
function_name: str
|
169 |
-
|
170 |
-
@property
|
171 |
-
def file_name(self) -> str: ...
|
172 |
-
|
173 |
-
class _NNModuleInfo:
|
174 |
-
@property
|
175 |
-
def self_ptr(self) -> int: ...
|
176 |
-
@property
|
177 |
-
def cls_ptr(self) -> int: ...
|
178 |
-
@property
|
179 |
-
def cls_name(self) -> str: ...
|
180 |
-
@property
|
181 |
-
def parameters(
|
182 |
-
self,
|
183 |
-
) -> List[Tuple[str, _TensorMetadata, Optional[_TensorMetadata]]]: ...
|
184 |
-
|
185 |
-
class _OptimizerInfo:
|
186 |
-
@property
|
187 |
-
def parameters(
|
188 |
-
self,
|
189 |
-
) -> List[
|
190 |
-
Tuple[
|
191 |
-
# Parameter
|
192 |
-
_TensorMetadata,
|
193 |
-
#
|
194 |
-
# Gradient (if present during optimizer.step())
|
195 |
-
Optional[_TensorMetadata],
|
196 |
-
#
|
197 |
-
# Optimizer state for Parameter as (name, tensor) pairs
|
198 |
-
List[Tuple[str, _TensorMetadata]],
|
199 |
-
]
|
200 |
-
]: ...
|
201 |
-
|
202 |
-
class _ExtraFields_PyCCall:
|
203 |
-
@property
|
204 |
-
def caller(self) -> _PyFrameState: ...
|
205 |
-
|
206 |
-
class _ExtraFields_PyCall:
|
207 |
-
@property
|
208 |
-
def callsite(self) -> _PyFrameState: ...
|
209 |
-
@property
|
210 |
-
def caller(self) -> _PyFrameState: ...
|
211 |
-
@property
|
212 |
-
def module(self) -> Optional[_NNModuleInfo]: ...
|
213 |
-
@property
|
214 |
-
def optimizer(self) -> Optional[_OptimizerInfo]: ...
|
215 |
-
|
216 |
-
class _ExtraFields_Kineto: ...
|
217 |
-
|
218 |
-
def _add_execution_trace_observer(output_file_path: str) -> bool: ...
|
219 |
-
def _remove_execution_trace_observer() -> None: ...
|
220 |
-
def _enable_execution_trace_observer() -> None: ...
|
221 |
-
def _disable_execution_trace_observer() -> None: ...
|
222 |
-
def _set_record_concrete_inputs_enabled_val(val: bool) -> None: ...
|
223 |
-
def _set_fwd_bwd_enabled_val(val: bool) -> None: ...
|
224 |
-
def _set_cuda_sync_enabled_val(val: bool) -> None: ...
|
225 |
-
|
226 |
-
class CapturedTraceback: ...
|
227 |
-
|
228 |
-
def gather_traceback(python: bool, script: bool, cpp: bool) -> CapturedTraceback: ...
|
229 |
-
|
230 |
-
# The Dict has name, filename, line
|
231 |
-
def symbolize_tracebacks(
|
232 |
-
to_symbolize: List[CapturedTraceback],
|
233 |
-
) -> List[List[Dict[str, str]]]: ...
|
234 |
-
|
235 |
-
class _RecordFunctionFast:
|
236 |
-
def __init__(self, name: str) -> None: ...
|
237 |
-
def __enter__(self) -> None: ...
|
238 |
-
def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None: ...
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
torch/_C/_verbose.pyi
DELETED
@@ -1,3 +0,0 @@
|
|
1 |
-
# Defined in torch/csrc/utils/verbose.cpp
|
2 |
-
def mkl_set_verbose(enable: int) -> int: ...
|
3 |
-
def mkldnn_set_verbose(level: int) -> int: ...
|
|
|
|
|
|
|
|
torch/_VF.py
DELETED
@@ -1,30 +0,0 @@
|
|
1 |
-
"""
|
2 |
-
This makes the functions in torch._C._VariableFunctions available as
|
3 |
-
torch._VF.<funcname>
|
4 |
-
without mypy being able to find them.
|
5 |
-
|
6 |
-
A subset of those functions are mapped to ATen functions in
|
7 |
-
torch/jit/_builtins.py
|
8 |
-
|
9 |
-
See https://github.com/pytorch/pytorch/issues/21478 for the reason for
|
10 |
-
introducing torch._VF
|
11 |
-
|
12 |
-
"""
|
13 |
-
import sys
|
14 |
-
import types
|
15 |
-
|
16 |
-
import torch
|
17 |
-
|
18 |
-
|
19 |
-
class VFModule(types.ModuleType):
|
20 |
-
vf: types.ModuleType
|
21 |
-
|
22 |
-
def __init__(self, name):
|
23 |
-
super().__init__(name)
|
24 |
-
self.vf = torch._C._VariableFunctions
|
25 |
-
|
26 |
-
def __getattr__(self, attr):
|
27 |
-
return getattr(self.vf, attr)
|
28 |
-
|
29 |
-
|
30 |
-
sys.modules[__name__] = VFModule(__name__)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
torch/_VF.pyi
DELETED
The diff for this file is too large to render.
See raw diff
|
|
torch/__config__.py
DELETED
@@ -1,22 +0,0 @@
|
|
1 |
-
import torch
|
2 |
-
|
3 |
-
|
4 |
-
def show():
|
5 |
-
"""
|
6 |
-
Return a human-readable string with descriptions of the
|
7 |
-
configuration of PyTorch.
|
8 |
-
"""
|
9 |
-
return torch._C._show_config()
|
10 |
-
|
11 |
-
|
12 |
-
# TODO: In principle, we could provide more structured version/config
|
13 |
-
# information here. For now only CXX_FLAGS is exposed, as Timer
|
14 |
-
# uses them.
|
15 |
-
def _cxx_flags():
|
16 |
-
"""Returns the CXX_FLAGS used when building PyTorch."""
|
17 |
-
return torch._C._cxx_flags()
|
18 |
-
|
19 |
-
|
20 |
-
def parallel_info():
|
21 |
-
r"""Returns detailed string with parallelization settings"""
|
22 |
-
return torch._C._parallel_info()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
torch/__future__.py
DELETED
@@ -1,21 +0,0 @@
|
|
1 |
-
"""
|
2 |
-
This global flag controls whether to assign new tensors to the parameters
|
3 |
-
instead of changing the existing parameters in-place when converting an `nn.Module`
|
4 |
-
using the following methods:
|
5 |
-
1. `module.cuda()` / `.cpu()` (for moving `module` between devices)
|
6 |
-
2. `module.float()` / `.double()` / `.half()` (for converting `module` to a different dtype)
|
7 |
-
3. `module.to()` / `.type()` (for changing `module`'s device or dtype)
|
8 |
-
4. `module._apply(fn)` (for generic functions applied to `module`)
|
9 |
-
|
10 |
-
Default: False
|
11 |
-
"""
|
12 |
-
_overwrite_module_params_on_conversion = False
|
13 |
-
|
14 |
-
|
15 |
-
def set_overwrite_module_params_on_conversion(value):
|
16 |
-
global _overwrite_module_params_on_conversion
|
17 |
-
_overwrite_module_params_on_conversion = value
|
18 |
-
|
19 |
-
|
20 |
-
def get_overwrite_module_params_on_conversion():
|
21 |
-
return _overwrite_module_params_on_conversion
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
torch/_appdirs.py
DELETED
@@ -1,666 +0,0 @@
|
|
1 |
-
#!/usr/bin/env python3
|
2 |
-
# -*- coding: utf-8 -*-
|
3 |
-
# Copyright (c) 2005-2010 ActiveState Software Inc.
|
4 |
-
# Copyright (c) 2013 Eddy Petrișor
|
5 |
-
|
6 |
-
# flake8: noqa
|
7 |
-
|
8 |
-
"""
|
9 |
-
This file is directly from
|
10 |
-
https://github.com/ActiveState/appdirs/blob/3fe6a83776843a46f20c2e5587afcffe05e03b39/appdirs.py
|
11 |
-
|
12 |
-
The license of https://github.com/ActiveState/appdirs copied below:
|
13 |
-
|
14 |
-
|
15 |
-
# This is the MIT license
|
16 |
-
|
17 |
-
Copyright (c) 2010 ActiveState Software Inc.
|
18 |
-
|
19 |
-
Permission is hereby granted, free of charge, to any person obtaining a
|
20 |
-
copy of this software and associated documentation files (the
|
21 |
-
"Software"), to deal in the Software without restriction, including
|
22 |
-
without limitation the rights to use, copy, modify, merge, publish,
|
23 |
-
distribute, sublicense, and/or sell copies of the Software, and to
|
24 |
-
permit persons to whom the Software is furnished to do so, subject to
|
25 |
-
the following conditions:
|
26 |
-
|
27 |
-
The above copyright notice and this permission notice shall be included
|
28 |
-
in all copies or substantial portions of the Software.
|
29 |
-
|
30 |
-
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
|
31 |
-
OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
|
32 |
-
MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
|
33 |
-
IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
|
34 |
-
CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
|
35 |
-
TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
|
36 |
-
SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
37 |
-
"""
|
38 |
-
|
39 |
-
"""Utilities for determining application-specific dirs.
|
40 |
-
|
41 |
-
See <https://github.com/ActiveState/appdirs> for details and usage.
|
42 |
-
"""
|
43 |
-
# Dev Notes:
|
44 |
-
# - MSDN on where to store app data files:
|
45 |
-
# http://support.microsoft.com/default.aspx?scid=kb;en-us;310294#XSLTH3194121123120121120120
|
46 |
-
# - Mac OS X: http://developer.apple.com/documentation/MacOSX/Conceptual/BPFileSystem/index.html
|
47 |
-
# - XDG spec for Un*x: https://standards.freedesktop.org/basedir-spec/basedir-spec-latest.html
|
48 |
-
|
49 |
-
__version__ = "1.4.4"
|
50 |
-
__version_info__ = tuple(int(segment) for segment in __version__.split("."))
|
51 |
-
|
52 |
-
|
53 |
-
import os
|
54 |
-
import sys
|
55 |
-
|
56 |
-
unicode = str
|
57 |
-
|
58 |
-
if sys.platform.startswith("java"):
|
59 |
-
import platform
|
60 |
-
|
61 |
-
os_name = platform.java_ver()[3][0]
|
62 |
-
if os_name.startswith("Windows"): # "Windows XP", "Windows 7", etc.
|
63 |
-
system = "win32"
|
64 |
-
elif os_name.startswith("Mac"): # "Mac OS X", etc.
|
65 |
-
system = "darwin"
|
66 |
-
else: # "Linux", "SunOS", "FreeBSD", etc.
|
67 |
-
# Setting this to "linux2" is not ideal, but only Windows or Mac
|
68 |
-
# are actually checked for and the rest of the module expects
|
69 |
-
# *sys.platform* style strings.
|
70 |
-
system = "linux2"
|
71 |
-
else:
|
72 |
-
system = sys.platform
|
73 |
-
|
74 |
-
|
75 |
-
def user_data_dir(appname=None, appauthor=None, version=None, roaming=False):
|
76 |
-
r"""Return full path to the user-specific data dir for this application.
|
77 |
-
|
78 |
-
"appname" is the name of application.
|
79 |
-
If None, just the system directory is returned.
|
80 |
-
"appauthor" (only used on Windows) is the name of the
|
81 |
-
appauthor or distributing body for this application. Typically
|
82 |
-
it is the owning company name. This falls back to appname. You may
|
83 |
-
pass False to disable it.
|
84 |
-
"version" is an optional version path element to append to the
|
85 |
-
path. You might want to use this if you want multiple versions
|
86 |
-
of your app to be able to run independently. If used, this
|
87 |
-
would typically be "<major>.<minor>".
|
88 |
-
Only applied when appname is present.
|
89 |
-
"roaming" (boolean, default False) can be set True to use the Windows
|
90 |
-
roaming appdata directory. That means that for users on a Windows
|
91 |
-
network setup for roaming profiles, this user data will be
|
92 |
-
sync'd on login. See
|
93 |
-
<http://technet.microsoft.com/en-us/library/cc766489(WS.10).aspx>
|
94 |
-
for a discussion of issues.
|
95 |
-
|
96 |
-
Typical user data directories are:
|
97 |
-
Mac OS X: ~/Library/Application Support/<AppName>
|
98 |
-
Unix: ~/.local/share/<AppName> # or in $XDG_DATA_HOME, if defined
|
99 |
-
Win XP (not roaming): C:\Documents and Settings\<username>\Application Data\<AppAuthor>\<AppName>
|
100 |
-
Win XP (roaming): C:\Documents and Settings\<username>\Local Settings\Application Data\<AppAuthor>\<AppName>
|
101 |
-
Win 7 (not roaming): C:\Users\<username>\AppData\Local\<AppAuthor>\<AppName>
|
102 |
-
Win 7 (roaming): C:\Users\<username>\AppData\Roaming\<AppAuthor>\<AppName>
|
103 |
-
|
104 |
-
For Unix, we follow the XDG spec and support $XDG_DATA_HOME.
|
105 |
-
That means, by default "~/.local/share/<AppName>".
|
106 |
-
"""
|
107 |
-
if system == "win32":
|
108 |
-
if appauthor is None:
|
109 |
-
appauthor = appname
|
110 |
-
const = roaming and "CSIDL_APPDATA" or "CSIDL_LOCAL_APPDATA"
|
111 |
-
path = os.path.normpath(_get_win_folder(const))
|
112 |
-
if appname:
|
113 |
-
if appauthor is not False:
|
114 |
-
path = os.path.join(path, appauthor, appname)
|
115 |
-
else:
|
116 |
-
path = os.path.join(path, appname)
|
117 |
-
elif system == "darwin":
|
118 |
-
path = os.path.expanduser("~/Library/Application Support/")
|
119 |
-
if appname:
|
120 |
-
path = os.path.join(path, appname)
|
121 |
-
else:
|
122 |
-
path = os.getenv("XDG_DATA_HOME", os.path.expanduser("~/.local/share"))
|
123 |
-
if appname:
|
124 |
-
path = os.path.join(path, appname)
|
125 |
-
if appname and version:
|
126 |
-
path = os.path.join(path, version)
|
127 |
-
return path
|
128 |
-
|
129 |
-
|
130 |
-
def site_data_dir(appname=None, appauthor=None, version=None, multipath=False):
|
131 |
-
r"""Return full path to the user-shared data dir for this application.
|
132 |
-
|
133 |
-
"appname" is the name of application.
|
134 |
-
If None, just the system directory is returned.
|
135 |
-
"appauthor" (only used on Windows) is the name of the
|
136 |
-
appauthor or distributing body for this application. Typically
|
137 |
-
it is the owning company name. This falls back to appname. You may
|
138 |
-
pass False to disable it.
|
139 |
-
"version" is an optional version path element to append to the
|
140 |
-
path. You might want to use this if you want multiple versions
|
141 |
-
of your app to be able to run independently. If used, this
|
142 |
-
would typically be "<major>.<minor>".
|
143 |
-
Only applied when appname is present.
|
144 |
-
"multipath" is an optional parameter only applicable to *nix
|
145 |
-
which indicates that the entire list of data dirs should be
|
146 |
-
returned. By default, the first item from XDG_DATA_DIRS is
|
147 |
-
returned, or '/usr/local/share/<AppName>',
|
148 |
-
if XDG_DATA_DIRS is not set
|
149 |
-
|
150 |
-
Typical site data directories are:
|
151 |
-
Mac OS X: /Library/Application Support/<AppName>
|
152 |
-
Unix: /usr/local/share/<AppName> or /usr/share/<AppName>
|
153 |
-
Win XP: C:\Documents and Settings\All Users\Application Data\<AppAuthor>\<AppName>
|
154 |
-
Vista: (Fail! "C:\ProgramData" is a hidden *system* directory on Vista.)
|
155 |
-
Win 7: C:\ProgramData\<AppAuthor>\<AppName> # Hidden, but writeable on Win 7.
|
156 |
-
|
157 |
-
For Unix, this is using the $XDG_DATA_DIRS[0] default.
|
158 |
-
|
159 |
-
WARNING: Do not use this on Windows. See the Vista-Fail note above for why.
|
160 |
-
"""
|
161 |
-
if system == "win32":
|
162 |
-
if appauthor is None:
|
163 |
-
appauthor = appname
|
164 |
-
path = os.path.normpath(_get_win_folder("CSIDL_COMMON_APPDATA"))
|
165 |
-
if appname:
|
166 |
-
if appauthor is not False:
|
167 |
-
path = os.path.join(path, appauthor, appname)
|
168 |
-
else:
|
169 |
-
path = os.path.join(path, appname)
|
170 |
-
elif system == "darwin":
|
171 |
-
path = os.path.expanduser("/Library/Application Support")
|
172 |
-
if appname:
|
173 |
-
path = os.path.join(path, appname)
|
174 |
-
else:
|
175 |
-
# XDG default for $XDG_DATA_DIRS
|
176 |
-
# only first, if multipath is False
|
177 |
-
path = os.getenv(
|
178 |
-
"XDG_DATA_DIRS", os.pathsep.join(["/usr/local/share", "/usr/share"])
|
179 |
-
)
|
180 |
-
pathlist = [
|
181 |
-
os.path.expanduser(x.rstrip(os.sep)) for x in path.split(os.pathsep)
|
182 |
-
]
|
183 |
-
if appname:
|
184 |
-
if version:
|
185 |
-
appname = os.path.join(appname, version)
|
186 |
-
pathlist = [os.sep.join([x, appname]) for x in pathlist]
|
187 |
-
|
188 |
-
if multipath:
|
189 |
-
path = os.pathsep.join(pathlist)
|
190 |
-
else:
|
191 |
-
path = pathlist[0]
|
192 |
-
return path
|
193 |
-
|
194 |
-
if appname and version:
|
195 |
-
path = os.path.join(path, version)
|
196 |
-
return path
|
197 |
-
|
198 |
-
|
199 |
-
def user_config_dir(appname=None, appauthor=None, version=None, roaming=False):
|
200 |
-
r"""Return full path to the user-specific config dir for this application.
|
201 |
-
|
202 |
-
"appname" is the name of application.
|
203 |
-
If None, just the system directory is returned.
|
204 |
-
"appauthor" (only used on Windows) is the name of the
|
205 |
-
appauthor or distributing body for this application. Typically
|
206 |
-
it is the owning company name. This falls back to appname. You may
|
207 |
-
pass False to disable it.
|
208 |
-
"version" is an optional version path element to append to the
|
209 |
-
path. You might want to use this if you want multiple versions
|
210 |
-
of your app to be able to run independently. If used, this
|
211 |
-
would typically be "<major>.<minor>".
|
212 |
-
Only applied when appname is present.
|
213 |
-
"roaming" (boolean, default False) can be set True to use the Windows
|
214 |
-
roaming appdata directory. That means that for users on a Windows
|
215 |
-
network setup for roaming profiles, this user data will be
|
216 |
-
sync'd on login. See
|
217 |
-
<http://technet.microsoft.com/en-us/library/cc766489(WS.10).aspx>
|
218 |
-
for a discussion of issues.
|
219 |
-
|
220 |
-
Typical user config directories are:
|
221 |
-
Mac OS X: ~/Library/Preferences/<AppName>
|
222 |
-
Unix: ~/.config/<AppName> # or in $XDG_CONFIG_HOME, if defined
|
223 |
-
Win *: same as user_data_dir
|
224 |
-
|
225 |
-
For Unix, we follow the XDG spec and support $XDG_CONFIG_HOME.
|
226 |
-
That means, by default "~/.config/<AppName>".
|
227 |
-
"""
|
228 |
-
if system == "win32":
|
229 |
-
path = user_data_dir(appname, appauthor, None, roaming)
|
230 |
-
elif system == "darwin":
|
231 |
-
path = os.path.expanduser("~/Library/Preferences/")
|
232 |
-
if appname:
|
233 |
-
path = os.path.join(path, appname)
|
234 |
-
else:
|
235 |
-
path = os.getenv("XDG_CONFIG_HOME", os.path.expanduser("~/.config"))
|
236 |
-
if appname:
|
237 |
-
path = os.path.join(path, appname)
|
238 |
-
if appname and version:
|
239 |
-
path = os.path.join(path, version)
|
240 |
-
return path
|
241 |
-
|
242 |
-
|
243 |
-
def site_config_dir(appname=None, appauthor=None, version=None, multipath=False):
|
244 |
-
r"""Return full path to the user-shared data dir for this application.
|
245 |
-
|
246 |
-
"appname" is the name of application.
|
247 |
-
If None, just the system directory is returned.
|
248 |
-
"appauthor" (only used on Windows) is the name of the
|
249 |
-
appauthor or distributing body for this application. Typically
|
250 |
-
it is the owning company name. This falls back to appname. You may
|
251 |
-
pass False to disable it.
|
252 |
-
"version" is an optional version path element to append to the
|
253 |
-
path. You might want to use this if you want multiple versions
|
254 |
-
of your app to be able to run independently. If used, this
|
255 |
-
would typically be "<major>.<minor>".
|
256 |
-
Only applied when appname is present.
|
257 |
-
"multipath" is an optional parameter only applicable to *nix
|
258 |
-
which indicates that the entire list of config dirs should be
|
259 |
-
returned. By default, the first item from XDG_CONFIG_DIRS is
|
260 |
-
returned, or '/etc/xdg/<AppName>', if XDG_CONFIG_DIRS is not set
|
261 |
-
|
262 |
-
Typical site config directories are:
|
263 |
-
Mac OS X: same as site_data_dir
|
264 |
-
Unix: /etc/xdg/<AppName> or $XDG_CONFIG_DIRS[i]/<AppName> for each value in
|
265 |
-
$XDG_CONFIG_DIRS
|
266 |
-
Win *: same as site_data_dir
|
267 |
-
Vista: (Fail! "C:\ProgramData" is a hidden *system* directory on Vista.)
|
268 |
-
|
269 |
-
For Unix, this is using the $XDG_CONFIG_DIRS[0] default, if multipath=False
|
270 |
-
|
271 |
-
WARNING: Do not use this on Windows. See the Vista-Fail note above for why.
|
272 |
-
"""
|
273 |
-
if system == "win32":
|
274 |
-
path = site_data_dir(appname, appauthor)
|
275 |
-
if appname and version:
|
276 |
-
path = os.path.join(path, version)
|
277 |
-
elif system == "darwin":
|
278 |
-
path = os.path.expanduser("/Library/Preferences")
|
279 |
-
if appname:
|
280 |
-
path = os.path.join(path, appname)
|
281 |
-
else:
|
282 |
-
# XDG default for $XDG_CONFIG_DIRS
|
283 |
-
# only first, if multipath is False
|
284 |
-
path = os.getenv("XDG_CONFIG_DIRS", "/etc/xdg")
|
285 |
-
pathlist = [
|
286 |
-
os.path.expanduser(x.rstrip(os.sep)) for x in path.split(os.pathsep)
|
287 |
-
]
|
288 |
-
if appname:
|
289 |
-
if version:
|
290 |
-
appname = os.path.join(appname, version)
|
291 |
-
pathlist = [os.sep.join([x, appname]) for x in pathlist]
|
292 |
-
|
293 |
-
if multipath:
|
294 |
-
path = os.pathsep.join(pathlist)
|
295 |
-
else:
|
296 |
-
path = pathlist[0]
|
297 |
-
return path
|
298 |
-
|
299 |
-
|
300 |
-
def user_cache_dir(appname=None, appauthor=None, version=None, opinion=True):
|
301 |
-
r"""Return full path to the user-specific cache dir for this application.
|
302 |
-
|
303 |
-
"appname" is the name of application.
|
304 |
-
If None, just the system directory is returned.
|
305 |
-
"appauthor" (only used on Windows) is the name of the
|
306 |
-
appauthor or distributing body for this application. Typically
|
307 |
-
it is the owning company name. This falls back to appname. You may
|
308 |
-
pass False to disable it.
|
309 |
-
"version" is an optional version path element to append to the
|
310 |
-
path. You might want to use this if you want multiple versions
|
311 |
-
of your app to be able to run independently. If used, this
|
312 |
-
would typically be "<major>.<minor>".
|
313 |
-
Only applied when appname is present.
|
314 |
-
"opinion" (boolean) can be False to disable the appending of
|
315 |
-
"Cache" to the base app data dir for Windows. See
|
316 |
-
discussion below.
|
317 |
-
|
318 |
-
Typical user cache directories are:
|
319 |
-
Mac OS X: ~/Library/Caches/<AppName>
|
320 |
-
Unix: ~/.cache/<AppName> (XDG default)
|
321 |
-
Win XP: C:\Documents and Settings\<username>\Local Settings\Application Data\<AppAuthor>\<AppName>\Cache
|
322 |
-
Vista: C:\Users\<username>\AppData\Local\<AppAuthor>\<AppName>\Cache
|
323 |
-
|
324 |
-
On Windows the only suggestion in the MSDN docs is that local settings go in
|
325 |
-
the `CSIDL_LOCAL_APPDATA` directory. This is identical to the non-roaming
|
326 |
-
app data dir (the default returned by `user_data_dir` above). Apps typically
|
327 |
-
put cache data somewhere *under* the given dir here. Some examples:
|
328 |
-
...\Mozilla\Firefox\Profiles\<ProfileName>\Cache
|
329 |
-
...\Acme\SuperApp\Cache\1.0
|
330 |
-
OPINION: This function appends "Cache" to the `CSIDL_LOCAL_APPDATA` value.
|
331 |
-
This can be disabled with the `opinion=False` option.
|
332 |
-
"""
|
333 |
-
if system == "win32":
|
334 |
-
if appauthor is None:
|
335 |
-
appauthor = appname
|
336 |
-
path = os.path.normpath(_get_win_folder("CSIDL_LOCAL_APPDATA"))
|
337 |
-
if appname:
|
338 |
-
if appauthor is not False:
|
339 |
-
path = os.path.join(path, appauthor, appname)
|
340 |
-
else:
|
341 |
-
path = os.path.join(path, appname)
|
342 |
-
if opinion:
|
343 |
-
path = os.path.join(path, "Cache")
|
344 |
-
elif system == "darwin":
|
345 |
-
path = os.path.expanduser("~/Library/Caches")
|
346 |
-
if appname:
|
347 |
-
path = os.path.join(path, appname)
|
348 |
-
else:
|
349 |
-
path = os.getenv("XDG_CACHE_HOME", os.path.expanduser("~/.cache"))
|
350 |
-
if appname:
|
351 |
-
path = os.path.join(path, appname)
|
352 |
-
if appname and version:
|
353 |
-
path = os.path.join(path, version)
|
354 |
-
return path
|
355 |
-
|
356 |
-
|
357 |
-
def user_state_dir(appname=None, appauthor=None, version=None, roaming=False):
|
358 |
-
r"""Return full path to the user-specific state dir for this application.
|
359 |
-
|
360 |
-
"appname" is the name of application.
|
361 |
-
If None, just the system directory is returned.
|
362 |
-
"appauthor" (only used on Windows) is the name of the
|
363 |
-
appauthor or distributing body for this application. Typically
|
364 |
-
it is the owning company name. This falls back to appname. You may
|
365 |
-
pass False to disable it.
|
366 |
-
"version" is an optional version path element to append to the
|
367 |
-
path. You might want to use this if you want multiple versions
|
368 |
-
of your app to be able to run independently. If used, this
|
369 |
-
would typically be "<major>.<minor>".
|
370 |
-
Only applied when appname is present.
|
371 |
-
"roaming" (boolean, default False) can be set True to use the Windows
|
372 |
-
roaming appdata directory. That means that for users on a Windows
|
373 |
-
network setup for roaming profiles, this user data will be
|
374 |
-
sync'd on login. See
|
375 |
-
<http://technet.microsoft.com/en-us/library/cc766489(WS.10).aspx>
|
376 |
-
for a discussion of issues.
|
377 |
-
|
378 |
-
Typical user state directories are:
|
379 |
-
Mac OS X: same as user_data_dir
|
380 |
-
Unix: ~/.local/state/<AppName> # or in $XDG_STATE_HOME, if defined
|
381 |
-
Win *: same as user_data_dir
|
382 |
-
|
383 |
-
For Unix, we follow this Debian proposal <https://wiki.debian.org/XDGBaseDirectorySpecification#state>
|
384 |
-
to extend the XDG spec and support $XDG_STATE_HOME.
|
385 |
-
|
386 |
-
That means, by default "~/.local/state/<AppName>".
|
387 |
-
"""
|
388 |
-
if system in ["win32", "darwin"]:
|
389 |
-
path = user_data_dir(appname, appauthor, None, roaming)
|
390 |
-
else:
|
391 |
-
path = os.getenv("XDG_STATE_HOME", os.path.expanduser("~/.local/state"))
|
392 |
-
if appname:
|
393 |
-
path = os.path.join(path, appname)
|
394 |
-
if appname and version:
|
395 |
-
path = os.path.join(path, version)
|
396 |
-
return path
|
397 |
-
|
398 |
-
|
399 |
-
def user_log_dir(appname=None, appauthor=None, version=None, opinion=True):
|
400 |
-
r"""Return full path to the user-specific log dir for this application.
|
401 |
-
|
402 |
-
"appname" is the name of application.
|
403 |
-
If None, just the system directory is returned.
|
404 |
-
"appauthor" (only used on Windows) is the name of the
|
405 |
-
appauthor or distributing body for this application. Typically
|
406 |
-
it is the owning company name. This falls back to appname. You may
|
407 |
-
pass False to disable it.
|
408 |
-
"version" is an optional version path element to append to the
|
409 |
-
path. You might want to use this if you want multiple versions
|
410 |
-
of your app to be able to run independently. If used, this
|
411 |
-
would typically be "<major>.<minor>".
|
412 |
-
Only applied when appname is present.
|
413 |
-
"opinion" (boolean) can be False to disable the appending of
|
414 |
-
"Logs" to the base app data dir for Windows, and "log" to the
|
415 |
-
base cache dir for Unix. See discussion below.
|
416 |
-
|
417 |
-
Typical user log directories are:
|
418 |
-
Mac OS X: ~/Library/Logs/<AppName>
|
419 |
-
Unix: ~/.cache/<AppName>/log # or under $XDG_CACHE_HOME if defined
|
420 |
-
Win XP: C:\Documents and Settings\<username>\Local Settings\Application Data\<AppAuthor>\<AppName>\Logs
|
421 |
-
Vista: C:\Users\<username>\AppData\Local\<AppAuthor>\<AppName>\Logs
|
422 |
-
|
423 |
-
On Windows the only suggestion in the MSDN docs is that local settings
|
424 |
-
go in the `CSIDL_LOCAL_APPDATA` directory. (Note: I'm interested in
|
425 |
-
examples of what some windows apps use for a logs dir.)
|
426 |
-
|
427 |
-
OPINION: This function appends "Logs" to the `CSIDL_LOCAL_APPDATA`
|
428 |
-
value for Windows and appends "log" to the user cache dir for Unix.
|
429 |
-
This can be disabled with the `opinion=False` option.
|
430 |
-
"""
|
431 |
-
if system == "darwin":
|
432 |
-
path = os.path.join(os.path.expanduser("~/Library/Logs"), appname)
|
433 |
-
elif system == "win32":
|
434 |
-
path = user_data_dir(appname, appauthor, version)
|
435 |
-
version = False
|
436 |
-
if opinion:
|
437 |
-
path = os.path.join(path, "Logs")
|
438 |
-
else:
|
439 |
-
path = user_cache_dir(appname, appauthor, version)
|
440 |
-
version = False
|
441 |
-
if opinion:
|
442 |
-
path = os.path.join(path, "log")
|
443 |
-
if appname and version:
|
444 |
-
path = os.path.join(path, version)
|
445 |
-
return path
|
446 |
-
|
447 |
-
|
448 |
-
class AppDirs(object):
|
449 |
-
"""Convenience wrapper for getting application dirs."""
|
450 |
-
|
451 |
-
def __init__(
|
452 |
-
self, appname=None, appauthor=None, version=None, roaming=False, multipath=False
|
453 |
-
):
|
454 |
-
self.appname = appname
|
455 |
-
self.appauthor = appauthor
|
456 |
-
self.version = version
|
457 |
-
self.roaming = roaming
|
458 |
-
self.multipath = multipath
|
459 |
-
|
460 |
-
@property
|
461 |
-
def user_data_dir(self):
|
462 |
-
return user_data_dir(
|
463 |
-
self.appname, self.appauthor, version=self.version, roaming=self.roaming
|
464 |
-
)
|
465 |
-
|
466 |
-
@property
|
467 |
-
def site_data_dir(self):
|
468 |
-
return site_data_dir(
|
469 |
-
self.appname, self.appauthor, version=self.version, multipath=self.multipath
|
470 |
-
)
|
471 |
-
|
472 |
-
@property
|
473 |
-
def user_config_dir(self):
|
474 |
-
return user_config_dir(
|
475 |
-
self.appname, self.appauthor, version=self.version, roaming=self.roaming
|
476 |
-
)
|
477 |
-
|
478 |
-
@property
|
479 |
-
def site_config_dir(self):
|
480 |
-
return site_config_dir(
|
481 |
-
self.appname, self.appauthor, version=self.version, multipath=self.multipath
|
482 |
-
)
|
483 |
-
|
484 |
-
@property
|
485 |
-
def user_cache_dir(self):
|
486 |
-
return user_cache_dir(self.appname, self.appauthor, version=self.version)
|
487 |
-
|
488 |
-
@property
|
489 |
-
def user_state_dir(self):
|
490 |
-
return user_state_dir(self.appname, self.appauthor, version=self.version)
|
491 |
-
|
492 |
-
@property
|
493 |
-
def user_log_dir(self):
|
494 |
-
return user_log_dir(self.appname, self.appauthor, version=self.version)
|
495 |
-
|
496 |
-
|
497 |
-
# ---- internal support stuff
|
498 |
-
|
499 |
-
|
500 |
-
def _get_win_folder_from_registry(csidl_name):
|
501 |
-
"""This is a fallback technique at best. I'm not sure if using the
|
502 |
-
registry for this guarantees us the correct answer for all CSIDL_*
|
503 |
-
names.
|
504 |
-
"""
|
505 |
-
import winreg as _winreg
|
506 |
-
|
507 |
-
shell_folder_name = {
|
508 |
-
"CSIDL_APPDATA": "AppData",
|
509 |
-
"CSIDL_COMMON_APPDATA": "Common AppData",
|
510 |
-
"CSIDL_LOCAL_APPDATA": "Local AppData",
|
511 |
-
}[csidl_name]
|
512 |
-
|
513 |
-
key = _winreg.OpenKey(
|
514 |
-
_winreg.HKEY_CURRENT_USER,
|
515 |
-
r"Software\Microsoft\Windows\CurrentVersion\Explorer\Shell Folders",
|
516 |
-
)
|
517 |
-
dir, type = _winreg.QueryValueEx(key, shell_folder_name)
|
518 |
-
return dir
|
519 |
-
|
520 |
-
|
521 |
-
def _get_win_folder_with_pywin32(csidl_name):
|
522 |
-
from win32com.shell import shell, shellcon
|
523 |
-
|
524 |
-
dir = shell.SHGetFolderPath(0, getattr(shellcon, csidl_name), 0, 0)
|
525 |
-
# Try to make this a unicode path because SHGetFolderPath does
|
526 |
-
# not return unicode strings when there is unicode data in the
|
527 |
-
# path.
|
528 |
-
try:
|
529 |
-
dir = unicode(dir)
|
530 |
-
|
531 |
-
# Downgrade to short path name if have highbit chars. See
|
532 |
-
# <http://bugs.activestate.com/show_bug.cgi?id=85099>.
|
533 |
-
has_high_char = False
|
534 |
-
for c in dir:
|
535 |
-
if ord(c) > 255:
|
536 |
-
has_high_char = True
|
537 |
-
break
|
538 |
-
if has_high_char:
|
539 |
-
try:
|
540 |
-
import win32api
|
541 |
-
|
542 |
-
dir = win32api.GetShortPathName(dir)
|
543 |
-
except ImportError:
|
544 |
-
pass
|
545 |
-
except UnicodeError:
|
546 |
-
pass
|
547 |
-
return dir
|
548 |
-
|
549 |
-
|
550 |
-
def _get_win_folder_with_ctypes(csidl_name):
|
551 |
-
import ctypes
|
552 |
-
|
553 |
-
csidl_const = {
|
554 |
-
"CSIDL_APPDATA": 26,
|
555 |
-
"CSIDL_COMMON_APPDATA": 35,
|
556 |
-
"CSIDL_LOCAL_APPDATA": 28,
|
557 |
-
}[csidl_name]
|
558 |
-
|
559 |
-
buf = ctypes.create_unicode_buffer(1024)
|
560 |
-
ctypes.windll.shell32.SHGetFolderPathW(None, csidl_const, None, 0, buf)
|
561 |
-
|
562 |
-
# Downgrade to short path name if have highbit chars. See
|
563 |
-
# <http://bugs.activestate.com/show_bug.cgi?id=85099>.
|
564 |
-
has_high_char = False
|
565 |
-
for c in buf:
|
566 |
-
if ord(c) > 255:
|
567 |
-
has_high_char = True
|
568 |
-
break
|
569 |
-
if has_high_char:
|
570 |
-
buf2 = ctypes.create_unicode_buffer(1024)
|
571 |
-
if ctypes.windll.kernel32.GetShortPathNameW(buf.value, buf2, 1024):
|
572 |
-
buf = buf2
|
573 |
-
|
574 |
-
return buf.value
|
575 |
-
|
576 |
-
|
577 |
-
def _get_win_folder_with_jna(csidl_name):
|
578 |
-
import array
|
579 |
-
|
580 |
-
from com.sun import jna
|
581 |
-
from com.sun.jna.platform import win32
|
582 |
-
|
583 |
-
buf_size = win32.WinDef.MAX_PATH * 2
|
584 |
-
buf = array.zeros("c", buf_size)
|
585 |
-
shell = win32.Shell32.INSTANCE
|
586 |
-
shell.SHGetFolderPath(
|
587 |
-
None,
|
588 |
-
getattr(win32.ShlObj, csidl_name),
|
589 |
-
None,
|
590 |
-
win32.ShlObj.SHGFP_TYPE_CURRENT,
|
591 |
-
buf,
|
592 |
-
)
|
593 |
-
dir = jna.Native.toString(buf.tostring()).rstrip("\0")
|
594 |
-
|
595 |
-
# Downgrade to short path name if have highbit chars. See
|
596 |
-
# <http://bugs.activestate.com/show_bug.cgi?id=85099>.
|
597 |
-
has_high_char = False
|
598 |
-
for c in dir:
|
599 |
-
if ord(c) > 255:
|
600 |
-
has_high_char = True
|
601 |
-
break
|
602 |
-
if has_high_char:
|
603 |
-
buf = array.zeros("c", buf_size)
|
604 |
-
kernel = win32.Kernel32.INSTANCE
|
605 |
-
if kernel.GetShortPathName(dir, buf, buf_size):
|
606 |
-
dir = jna.Native.toString(buf.tostring()).rstrip("\0")
|
607 |
-
|
608 |
-
return dir
|
609 |
-
|
610 |
-
|
611 |
-
if system == "win32":
|
612 |
-
try:
|
613 |
-
import win32com.shell
|
614 |
-
|
615 |
-
_get_win_folder = _get_win_folder_with_pywin32
|
616 |
-
except ImportError:
|
617 |
-
try:
|
618 |
-
from ctypes import windll
|
619 |
-
|
620 |
-
_get_win_folder = _get_win_folder_with_ctypes
|
621 |
-
except ImportError:
|
622 |
-
try:
|
623 |
-
import com.sun.jna
|
624 |
-
|
625 |
-
_get_win_folder = _get_win_folder_with_jna
|
626 |
-
except ImportError:
|
627 |
-
_get_win_folder = _get_win_folder_from_registry
|
628 |
-
|
629 |
-
|
630 |
-
# ---- self test code
|
631 |
-
|
632 |
-
if __name__ == "__main__":
|
633 |
-
appname = "MyApp"
|
634 |
-
appauthor = "MyCompany"
|
635 |
-
|
636 |
-
props = (
|
637 |
-
"user_data_dir",
|
638 |
-
"user_config_dir",
|
639 |
-
"user_cache_dir",
|
640 |
-
"user_state_dir",
|
641 |
-
"user_log_dir",
|
642 |
-
"site_data_dir",
|
643 |
-
"site_config_dir",
|
644 |
-
)
|
645 |
-
|
646 |
-
print(f"-- app dirs {__version__} --")
|
647 |
-
|
648 |
-
print("-- app dirs (with optional 'version')")
|
649 |
-
dirs = AppDirs(appname, appauthor, version="1.0")
|
650 |
-
for prop in props:
|
651 |
-
print(f"{prop}: {getattr(dirs, prop)}")
|
652 |
-
|
653 |
-
print("\n-- app dirs (without optional 'version')")
|
654 |
-
dirs = AppDirs(appname, appauthor)
|
655 |
-
for prop in props:
|
656 |
-
print(f"{prop}: {getattr(dirs, prop)}")
|
657 |
-
|
658 |
-
print("\n-- app dirs (without optional 'appauthor')")
|
659 |
-
dirs = AppDirs(appname)
|
660 |
-
for prop in props:
|
661 |
-
print(f"{prop}: {getattr(dirs, prop)}")
|
662 |
-
|
663 |
-
print("\n-- app dirs (with disabled 'appauthor')")
|
664 |
-
dirs = AppDirs(appname, appauthor=False)
|
665 |
-
for prop in props:
|
666 |
-
print(f"{prop}: {getattr(dirs, prop)}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
torch/_awaits/__init__.py
DELETED
@@ -1,54 +0,0 @@
|
|
1 |
-
from __future__ import annotations
|
2 |
-
|
3 |
-
from typing import cast, Callable, Generic, Type, TypeVar
|
4 |
-
|
5 |
-
import torch
|
6 |
-
|
7 |
-
__all__ = ['Await']
|
8 |
-
|
9 |
-
W = TypeVar("W")
|
10 |
-
|
11 |
-
class _PyAwaitMeta(type(torch._C._Await), type(Generic)): # type: ignore[misc, no-redef]
|
12 |
-
pass
|
13 |
-
|
14 |
-
class _Await(torch._C._Await, Generic[W], metaclass=_PyAwaitMeta):
|
15 |
-
r"""
|
16 |
-
Wrapper around a ``torch._C.Await`` which encapsulates delayed execution
|
17 |
-
of a callable. All manipulations happen with functions ``torch.jit._awaitable``,
|
18 |
-
``torch.jit._awaitable_wait``, ``torch.jit._awaitable_nowait``.
|
19 |
-
|
20 |
-
Torch scriptable manipulations:
|
21 |
-
``torch.jit._awaitable(func, *args)``
|
22 |
-
Creates ``Await[W]`` object, where W is return type of func.
|
23 |
-
|
24 |
-
Returns:
|
25 |
-
``torch.jit._awaitable_wait(Await[W])``
|
26 |
-
Returns the result of the function, specified at ``_awaitable``, with specified arguments.
|
27 |
-
|
28 |
-
Returns:
|
29 |
-
The result of type ``W`` of the function call. The result is owned by ``Await[W]``
|
30 |
-
and returned on all following ``_awaitable_wait`` calls.
|
31 |
-
|
32 |
-
|
33 |
-
``torch.jit._awaitable_nowait(W)``
|
34 |
-
Returns:
|
35 |
-
Trivial ``Await[W]`` with specified result.
|
36 |
-
|
37 |
-
|
38 |
-
Only in eager mode:
|
39 |
-
``fn() -> Callable[Tuple[Any], W]``
|
40 |
-
Returns:
|
41 |
-
Specified at ``_awaitable`` python function ``func``.
|
42 |
-
|
43 |
-
``args() -> Tuple[Any]``
|
44 |
-
Returns:
|
45 |
-
Specified at ``_awaitable`` python args.
|
46 |
-
|
47 |
-
``is_nowait() -> _bool``
|
48 |
-
Returns:
|
49 |
-
``True`` if this object was created via ``_awaitable_nowait`` call (trivial `Await[W]`).
|
50 |
-
|
51 |
-
In eager mode ``Await[W]`` can be used as ``W`` i.e. attributes of W can be called on ``Await[W]``,
|
52 |
-
``_awaitable_wait()`` call will be transparently added.
|
53 |
-
"""
|
54 |
-
pass
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
torch/_awaits/__pycache__/__init__.cpython-310.pyc
DELETED
Binary file (2.08 kB)
|
|
torch/_classes.py
DELETED
@@ -1,55 +0,0 @@
|
|
1 |
-
import types
|
2 |
-
|
3 |
-
import torch._C
|
4 |
-
|
5 |
-
|
6 |
-
class _ClassNamespace(types.ModuleType):
|
7 |
-
def __init__(self, name):
|
8 |
-
super().__init__("torch.classes" + name)
|
9 |
-
self.name = name
|
10 |
-
|
11 |
-
def __getattr__(self, attr):
|
12 |
-
proxy = torch._C._get_custom_class_python_wrapper(self.name, attr)
|
13 |
-
if proxy is None:
|
14 |
-
raise RuntimeError(f"Class {self.name}.{attr} not registered!")
|
15 |
-
return proxy
|
16 |
-
|
17 |
-
|
18 |
-
class _Classes(types.ModuleType):
|
19 |
-
__file__ = "_classes.py"
|
20 |
-
|
21 |
-
def __init__(self):
|
22 |
-
super().__init__("torch.classes")
|
23 |
-
|
24 |
-
def __getattr__(self, name):
|
25 |
-
namespace = _ClassNamespace(name)
|
26 |
-
setattr(self, name, namespace)
|
27 |
-
return namespace
|
28 |
-
|
29 |
-
@property
|
30 |
-
def loaded_libraries(self):
|
31 |
-
return torch.ops.loaded_libraries
|
32 |
-
|
33 |
-
def load_library(self, path):
|
34 |
-
"""
|
35 |
-
Loads a shared library from the given path into the current process.
|
36 |
-
|
37 |
-
The library being loaded may run global initialization code to register
|
38 |
-
custom classes with the PyTorch JIT runtime. This allows dynamically
|
39 |
-
loading custom classes. For this, you should compile your class
|
40 |
-
and the static registration code into a shared library object, and then
|
41 |
-
call ``torch.classes.load_library('path/to/libcustom.so')`` to load the
|
42 |
-
shared object.
|
43 |
-
|
44 |
-
After the library is loaded, it is added to the
|
45 |
-
``torch.classes.loaded_libraries`` attribute, a set that may be inspected
|
46 |
-
for the paths of all libraries loaded using this function.
|
47 |
-
|
48 |
-
Args:
|
49 |
-
path (str): A path to a shared library to load.
|
50 |
-
"""
|
51 |
-
torch.ops.load_library(path)
|
52 |
-
|
53 |
-
|
54 |
-
# The classes "namespace"
|
55 |
-
classes = _Classes()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
torch/_compile.py
DELETED
@@ -1,30 +0,0 @@
|
|
1 |
-
"""
|
2 |
-
APIs related to torch.compile which lazily import torch._dynamo to avoid
|
3 |
-
circular dependencies.
|
4 |
-
"""
|
5 |
-
import functools
|
6 |
-
|
7 |
-
|
8 |
-
def _disable_dynamo(fn=None, recursive=True):
|
9 |
-
"""
|
10 |
-
This API should be only used inside torch, external users should still use
|
11 |
-
torch._dynamo.disable. The main goal of this API is to avoid circular
|
12 |
-
imports issues that is common while using _dynamo.disable inside torch
|
13 |
-
itself.
|
14 |
-
|
15 |
-
This API avoids it by lazily importing torch._dynamo from the import time to
|
16 |
-
the invocation of the decorated function.
|
17 |
-
"""
|
18 |
-
if fn is not None:
|
19 |
-
|
20 |
-
@functools.wraps(fn)
|
21 |
-
def inner(*args, **kwargs):
|
22 |
-
import torch._dynamo
|
23 |
-
|
24 |
-
return torch._dynamo.disable(fn, recursive)(*args, **kwargs)
|
25 |
-
|
26 |
-
return inner
|
27 |
-
else:
|
28 |
-
# decorator usage like @_disable_dynamo(recursive=False). The resulting
|
29 |
-
# object expects the original decorated function as the arg.
|
30 |
-
return functools.partial(_disable_dynamo, recursive=recursive)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
torch/_custom_op/__init__.py
DELETED
File without changes
|
torch/_custom_op/__pycache__/__init__.cpython-310.pyc
DELETED
Binary file (156 Bytes)
|
|
torch/_custom_op/__pycache__/autograd.cpython-310.pyc
DELETED
Binary file (8.86 kB)
|
|
torch/_custom_op/__pycache__/functional.cpython-310.pyc
DELETED
Binary file (5.95 kB)
|
|
torch/_custom_op/__pycache__/impl.cpython-310.pyc
DELETED
Binary file (33.5 kB)
|
|
torch/_custom_op/autograd.py
DELETED
@@ -1,274 +0,0 @@
|
|
1 |
-
import torch
|
2 |
-
import torch.utils._pytree as pytree
|
3 |
-
from collections import namedtuple
|
4 |
-
import functools
|
5 |
-
|
6 |
-
|
7 |
-
# NOTE [CustomOp autograd kernel indirection]
|
8 |
-
# We register `inner` as the autograd kernel for this custom_op.
|
9 |
-
# `inner` either calls the autograd formula registered by the user,
|
10 |
-
# or goes into an `autograd_not_implemented` kernel.
|
11 |
-
#
|
12 |
-
# The reason why this indirection exists is
|
13 |
-
# so that we can swap out the autograd kernel (the PyTorch dispatcher
|
14 |
-
# doesn't actually allow us to do this). By default, we want
|
15 |
-
# the `autograd_not_implemented` behavior, but then the user may come
|
16 |
-
# and register something that is actually a backward formula
|
17 |
-
def autograd_kernel_indirection(custom_op):
|
18 |
-
autograd_fallback = autograd_not_implemented(custom_op)
|
19 |
-
|
20 |
-
def inner(*args, **kwargs):
|
21 |
-
if custom_op._has_impl('autograd'):
|
22 |
-
kernel = custom_op._get_impl('autograd').func
|
23 |
-
return kernel(*args, **kwargs)
|
24 |
-
# As explained in NOTE ["backward", "save_for_backward", and "autograd"],
|
25 |
-
# after the user gives us "backward" and "save_for_backward", we generate
|
26 |
-
# the "autograd" impl. If the user only provided one, then we tell
|
27 |
-
# the user they've done something wrong.
|
28 |
-
if custom_op._has_impl('save_for_backward') or custom_op._has_impl('backward'):
|
29 |
-
missing = (
|
30 |
-
'save_for_backward' if custom_op._has_impl('backward')
|
31 |
-
else 'backward'
|
32 |
-
)
|
33 |
-
found = 'save_for_backward' if missing == 'backward' else 'backward'
|
34 |
-
loc = custom_op._get_impl(found).location
|
35 |
-
raise RuntimeError(
|
36 |
-
f"We found a '{found}' registration for {custom_op} at "
|
37 |
-
f"{loc} but were unable to find a '{missing}' registration. "
|
38 |
-
f"To use the CustomOp API to register a backward formula, "
|
39 |
-
f"please provide us both a backward function and a "
|
40 |
-
f"'save for backward' function via `impl_backward` and "
|
41 |
-
f"`impl_save_for_backward` respectively.")
|
42 |
-
return autograd_fallback(*args, **kwargs)
|
43 |
-
return inner
|
44 |
-
|
45 |
-
|
46 |
-
# TODO(#101191): Use the actual C++ autograd not implemented fallback,
|
47 |
-
# or change the default autograd fallback to the autograd not implemented fallback.
|
48 |
-
def autograd_not_implemented(custom_op):
|
49 |
-
def kernel(*args, **kwargs):
|
50 |
-
if torch.is_grad_enabled() and pytree.tree_any(
|
51 |
-
lambda x: isinstance(x, torch.Tensor) and x.requires_grad, (args, kwargs)
|
52 |
-
):
|
53 |
-
raise RuntimeError("Autograd has not been implemented for operator")
|
54 |
-
with torch._C._AutoDispatchBelowAutograd():
|
55 |
-
return custom_op(*args, **kwargs)
|
56 |
-
return kernel
|
57 |
-
|
58 |
-
|
59 |
-
def mark_non_differentiable(ctx, output, output_differentiability):
|
60 |
-
# Output types are restricted to be:
|
61 |
-
# - Tensor
|
62 |
-
# - Tensor[]
|
63 |
-
# - int, bool, Scalar, float
|
64 |
-
# See _check_can_register_backward
|
65 |
-
if output_differentiability is not None:
|
66 |
-
if not isinstance(output, tuple):
|
67 |
-
tuple_output = (output,)
|
68 |
-
else:
|
69 |
-
tuple_output = output # type: ignore[assignment]
|
70 |
-
assert len(output_differentiability) == len(tuple_output)
|
71 |
-
non_differentiable_tensors = []
|
72 |
-
for idx, (differentiable, out) in enumerate(zip(output_differentiability, tuple_output)):
|
73 |
-
if isinstance(out, torch.Tensor):
|
74 |
-
if not differentiable:
|
75 |
-
non_differentiable_tensors.append(out)
|
76 |
-
continue
|
77 |
-
if isinstance(out, list):
|
78 |
-
if not differentiable:
|
79 |
-
non_differentiable_tensors.extend(out)
|
80 |
-
continue
|
81 |
-
if differentiable:
|
82 |
-
raise RuntimeError(
|
83 |
-
f"With output_differentiability={output_differentiability}. "
|
84 |
-
f"At idx {idx}, we received an object of type {type(out)} that "
|
85 |
-
f"is not a Tensor, so it cannot have be marked as differentiable in "
|
86 |
-
f"output_differentiability.")
|
87 |
-
if non_differentiable_tensors:
|
88 |
-
ctx.mark_non_differentiable(*non_differentiable_tensors)
|
89 |
-
|
90 |
-
|
91 |
-
def construct_autograd_kernel(
|
92 |
-
schema,
|
93 |
-
output_differentiability,
|
94 |
-
custom_op,
|
95 |
-
op_overload,
|
96 |
-
save_for_backward_fn,
|
97 |
-
backward_fn):
|
98 |
-
|
99 |
-
def apply(*args):
|
100 |
-
flat_args, spec = pytree.tree_flatten(args)
|
101 |
-
out_spec = None
|
102 |
-
|
103 |
-
def forward(ctx, *flat_args):
|
104 |
-
ctx.set_materialize_grads(True)
|
105 |
-
args = pytree.tree_unflatten(list(flat_args), spec)
|
106 |
-
with torch._C._AutoDispatchBelowAutograd():
|
107 |
-
output = op_overload(*args)
|
108 |
-
|
109 |
-
# We use the info about args to give better error messages in backward
|
110 |
-
args_info = namedtuple_args(
|
111 |
-
schema, pytree.tree_map(type, args))
|
112 |
-
|
113 |
-
save_for_backward_fn_inputs = namedtuple_args(schema, args)
|
114 |
-
to_save = save_for_backward_fn(save_for_backward_fn_inputs, output)
|
115 |
-
|
116 |
-
save_pytree_for_backward(ctx, (to_save, args_info))
|
117 |
-
mark_non_differentiable(ctx, output, output_differentiability)
|
118 |
-
|
119 |
-
nonlocal out_spec
|
120 |
-
flat_output, out_spec = pytree.tree_flatten(output)
|
121 |
-
return tuple(flat_output)
|
122 |
-
|
123 |
-
def backward(ctx, *flat_grad_output):
|
124 |
-
assert out_spec is not None
|
125 |
-
grads = pytree.tree_unflatten(list(flat_grad_output), out_spec)
|
126 |
-
saved, args_info = unpack_saved(ctx)
|
127 |
-
# There is nothing on the ctx object for now, it is just there so
|
128 |
-
# that we can add additional things in the future.
|
129 |
-
inner_ctx = object()
|
130 |
-
if not isinstance(grads, tuple):
|
131 |
-
grads = (grads,)
|
132 |
-
grad_inputs_dict = backward_fn(inner_ctx, saved, *grads)
|
133 |
-
|
134 |
-
# Massage the grad_inputs_dict to a form acceptable by
|
135 |
-
# autograd.Function.
|
136 |
-
validate_grad_inputs_dict(grad_inputs_dict, custom_op, args_info)
|
137 |
-
return grad_inputs_dict_to_flat_tuple(grad_inputs_dict, args_info)
|
138 |
-
|
139 |
-
generated_cls = gen_autograd_function(
|
140 |
-
custom_op._opname + '_customop', forward, backward)
|
141 |
-
|
142 |
-
flat_output = generated_cls.apply(*flat_args)
|
143 |
-
assert out_spec is not None
|
144 |
-
return pytree.tree_unflatten(list(flat_output), out_spec)
|
145 |
-
return apply
|
146 |
-
|
147 |
-
|
148 |
-
def gen_autograd_function(name, forward, backward):
|
149 |
-
generated_cls = type(
|
150 |
-
name,
|
151 |
-
(torch.autograd.Function,),
|
152 |
-
{
|
153 |
-
'forward': staticmethod(forward),
|
154 |
-
'backward': staticmethod(backward),
|
155 |
-
}
|
156 |
-
)
|
157 |
-
return generated_cls
|
158 |
-
|
159 |
-
|
160 |
-
@functools.lru_cache
|
161 |
-
def namedtuple_args_cls(schema):
|
162 |
-
attribs = [arg.name for arg in schema.arguments.flat_all]
|
163 |
-
name = str(schema.name) + "_args"
|
164 |
-
# mypy doesn't support dynamic namedtuple name
|
165 |
-
tuple_cls = namedtuple(name, attribs) # type: ignore[misc]
|
166 |
-
return tuple_cls
|
167 |
-
|
168 |
-
|
169 |
-
def namedtuple_args(schema, args):
|
170 |
-
assert isinstance(args, tuple)
|
171 |
-
tuple_cls = namedtuple_args_cls(schema)
|
172 |
-
return tuple_cls(*args)
|
173 |
-
|
174 |
-
|
175 |
-
def validate_grad_inputs_dict(grad_inputs_dict, forward_op, args_info):
|
176 |
-
def error(what):
|
177 |
-
backward = forward_op._get_impl('backward')
|
178 |
-
raise RuntimeError(
|
179 |
-
f"In the backward function defined for {forward_op} at "
|
180 |
-
f"{backward.location} using the CustomOp API, {what}")
|
181 |
-
|
182 |
-
if not isinstance(grad_inputs_dict, dict):
|
183 |
-
error(f"expected the output of the backward function to be a dict but "
|
184 |
-
f"got {type(grad_inputs_dict)}")
|
185 |
-
|
186 |
-
expected_keys = {arg.name for arg in forward_op._schema.arguments.flat_all
|
187 |
-
if arg.type.is_tensor_like()}
|
188 |
-
actual_keys = grad_inputs_dict.keys()
|
189 |
-
if expected_keys != actual_keys:
|
190 |
-
error(f"expected the returned grad_input dict to have keys "
|
191 |
-
f"{expected_keys} but got {actual_keys}. The backward "
|
192 |
-
f"function must return a gradient (can be None) for each arg "
|
193 |
-
f"to the CustomOp that may be a Tensor or Sequence[Tensor]. "
|
194 |
-
f"Args declared to be non-Tensor-like types should not appear "
|
195 |
-
f"in the grad_input dict")
|
196 |
-
|
197 |
-
for name, grad in grad_inputs_dict.items():
|
198 |
-
arg_info = getattr(args_info, name)
|
199 |
-
|
200 |
-
if isinstance(arg_info, list):
|
201 |
-
if not isinstance(grad, (tuple, list)):
|
202 |
-
error(f"for input '{name}' expected the grad_input dict to "
|
203 |
-
f"hold a list of gradients but got object of type "
|
204 |
-
f"{type(grad)}.")
|
205 |
-
if not len(grad) == len(arg_info):
|
206 |
-
error(f"for input '{name}' expected the grad_input dict to "
|
207 |
-
f"hold a list of {len(arg_info)} gradients but got "
|
208 |
-
f"{len(grad)}")
|
209 |
-
for idx, (g, info) in enumerate(zip(grad, arg_info)):
|
210 |
-
if g is None:
|
211 |
-
continue
|
212 |
-
if not isinstance(g, torch.Tensor):
|
213 |
-
error(f"for input '{name}' expected the grad_input dict to "
|
214 |
-
f"hold a list of None or Tensor gradients but got "
|
215 |
-
f"object of {type(g)} at index {idx}")
|
216 |
-
if not issubclass(info, torch.Tensor):
|
217 |
-
error(f"for input '{name}', got a Tensor as the gradient "
|
218 |
-
f"for the {idx}-th value but expected None because "
|
219 |
-
f"the {idx}-th value was not a Tensor (it was "
|
220 |
-
f"type {arg_info}")
|
221 |
-
continue
|
222 |
-
|
223 |
-
if grad is None:
|
224 |
-
continue
|
225 |
-
if not isinstance(grad, torch.Tensor):
|
226 |
-
error(f"got object of type {type(grad)} as the gradient for input "
|
227 |
-
f"'{name}', "
|
228 |
-
f"but expected the gradient to be either None or a Tensor")
|
229 |
-
if not issubclass(arg_info, torch.Tensor):
|
230 |
-
error(f"got a Tensor as the gradient for input '{name}' but "
|
231 |
-
f"expected None as the gradient because input '{name}' "
|
232 |
-
f"was not a Tensor (it was type {arg_info}).")
|
233 |
-
|
234 |
-
|
235 |
-
def grad_inputs_dict_to_flat_tuple(grad_inputs_dict, args_info):
|
236 |
-
result = []
|
237 |
-
for name, arg_info in args_info._asdict().items():
|
238 |
-
if name not in grad_inputs_dict:
|
239 |
-
result.append(pytree.tree_map(lambda x: None, arg_info))
|
240 |
-
continue
|
241 |
-
result.append(grad_inputs_dict[name])
|
242 |
-
return tuple(pytree.tree_leaves(result))
|
243 |
-
|
244 |
-
# Saves "stuff" (a pytree) onto the ctx object. Use unpack_saved to unpack it.
|
245 |
-
# autograd.Function prefers that users use ctx.save_for_backward to
|
246 |
-
# save Tensors (to avoid reference cycles) and for non-Tensors to go onto the
|
247 |
-
# ctx object.
|
248 |
-
def save_pytree_for_backward(ctx, stuff):
|
249 |
-
flat_stuff, spec = pytree.tree_flatten(stuff)
|
250 |
-
num_elts = len(flat_stuff)
|
251 |
-
tensor_idxs = [idx for idx, thing in enumerate(flat_stuff)
|
252 |
-
if isinstance(thing, torch.Tensor)]
|
253 |
-
non_tensor_idxs = [idx for idx, thing in enumerate(flat_stuff)
|
254 |
-
if not isinstance(thing, torch.Tensor)]
|
255 |
-
tensors = [thing for thing in flat_stuff if isinstance(thing, torch.Tensor)]
|
256 |
-
non_tensors = [thing for thing in flat_stuff if not isinstance(thing, torch.Tensor)]
|
257 |
-
|
258 |
-
ctx.spec = spec
|
259 |
-
ctx.num_elts = num_elts
|
260 |
-
ctx.save_for_backward(*tensors)
|
261 |
-
ctx.tensor_idxs = tensor_idxs
|
262 |
-
ctx.saved_non_tensors = non_tensors
|
263 |
-
ctx.non_tensor_idxs = non_tensor_idxs
|
264 |
-
|
265 |
-
|
266 |
-
# Inverse operation to save_pytree_for_backward
|
267 |
-
def unpack_saved(ctx):
|
268 |
-
flat_stuff = [None] * ctx.num_elts
|
269 |
-
for tensor, idx in zip(ctx.saved_tensors, ctx.tensor_idxs):
|
270 |
-
flat_stuff[idx] = tensor
|
271 |
-
for non_tensor, idx in zip(ctx.saved_non_tensors, ctx.non_tensor_idxs):
|
272 |
-
flat_stuff[idx] = non_tensor
|
273 |
-
stuff = pytree.tree_unflatten(flat_stuff, ctx.spec)
|
274 |
-
return stuff
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
torch/_custom_op/functional.py
DELETED
@@ -1,187 +0,0 @@
|
|
1 |
-
import weakref
|
2 |
-
|
3 |
-
import torch
|
4 |
-
import torch.utils._pytree as pytree
|
5 |
-
from torch._C import _ExcludeDispatchKeyGuard, DispatchKey, DispatchKeySet
|
6 |
-
from torch._ops import OpOverload
|
7 |
-
from torch.library import Library
|
8 |
-
from torchgen.model import (
|
9 |
-
BaseTy,
|
10 |
-
BaseType,
|
11 |
-
FunctionSchema,
|
12 |
-
OperatorName,
|
13 |
-
OptionalType,
|
14 |
-
SchemaKind,
|
15 |
-
)
|
16 |
-
|
17 |
-
from .autograd import autograd_not_implemented
|
18 |
-
|
19 |
-
|
20 |
-
def register_functional_op(
|
21 |
-
lib: Library,
|
22 |
-
new_op_name: str,
|
23 |
-
mutable_op: OpOverload,
|
24 |
-
) -> None:
|
25 |
-
"""Given a mutable operator, registers the functional variant.
|
26 |
-
|
27 |
-
This API also correctly links the functional variant with the mutable
|
28 |
-
operator for the purposes of functionalization.
|
29 |
-
|
30 |
-
All of the new registrations are performed on the ``lib`` passed in.
|
31 |
-
|
32 |
-
Arguments:
|
33 |
-
lib (Library): Should be a torch.library.Library object that has
|
34 |
-
the same namespace as ``mutable_op``'s namespace.
|
35 |
-
lib will be used to register the new functional op as well
|
36 |
-
as a functionalization kernel for the ``mutable_op``
|
37 |
-
If you don't have a library handy, use
|
38 |
-
``torch.library.Library(ns, 'FRAGMENT')`` to construct one.
|
39 |
-
new_op_name (str): The name of the functional operator (without the
|
40 |
-
namespace). If no namespace, the new functional variant will be
|
41 |
-
accessible under ``torch.ops.{lib.ns}.new_op_name``.
|
42 |
-
mutable_op (OpOverload): The mutable custom operator. Note
|
43 |
-
that you may need to add a `.default` to it, like
|
44 |
-
`torch.ops.aten.abs_.default`.
|
45 |
-
|
46 |
-
"""
|
47 |
-
validate(mutable_op)
|
48 |
-
schema = functional_schema(new_op_name, mutable_op)
|
49 |
-
lib.define(schema)
|
50 |
-
|
51 |
-
functional_impl = construct_functional_impl(mutable_op)
|
52 |
-
lib.impl(new_op_name, functional_impl, 'CompositeExplicitAutograd')
|
53 |
-
|
54 |
-
functional_op = getattr(getattr(torch.ops, lib.ns), new_op_name).default
|
55 |
-
|
56 |
-
# There's no easy way for us to generate the autograd kernel, so we
|
57 |
-
# use autograd_not_implemented. Also, this makes it so that the user
|
58 |
-
# is unable to register an autograd formula themselves. This shouldn't
|
59 |
-
# be a problem if the user doesn't use the functional op direclty
|
60 |
-
# in their program, but we may need to revist this in the future.
|
61 |
-
lib.impl(new_op_name, autograd_not_implemented(functional_op), 'Autograd')
|
62 |
-
|
63 |
-
f_kernel = construct_functionalization_kernel(weakref.proxy(mutable_op), functional_op)
|
64 |
-
|
65 |
-
lib.impl(mutable_op, f_kernel, 'Functionalize')
|
66 |
-
|
67 |
-
|
68 |
-
def construct_functional_impl(mutable_op):
|
69 |
-
def functional_impl(*args):
|
70 |
-
# Strategy:
|
71 |
-
# - clone args that would have been mutated
|
72 |
-
# - run mutable_op
|
73 |
-
# - return the cloned args as additional outputs
|
74 |
-
new_args = []
|
75 |
-
extra_rets = []
|
76 |
-
for is_write, arg in zip(mutable_args(mutable_op), args):
|
77 |
-
if is_write:
|
78 |
-
cloned = arg.clone() if arg is not None else None
|
79 |
-
new_args.append(cloned)
|
80 |
-
extra_rets.append(cloned)
|
81 |
-
else:
|
82 |
-
new_args.append(arg)
|
83 |
-
result = mutable_op(*new_args)
|
84 |
-
if result is None:
|
85 |
-
return tuple(extra_rets)
|
86 |
-
if isinstance(result, tuple):
|
87 |
-
return (*result, *extra_rets)
|
88 |
-
return (result, *extra_rets)
|
89 |
-
return functional_impl
|
90 |
-
|
91 |
-
|
92 |
-
def construct_functionalization_kernel(mutable_op, functional_op):
|
93 |
-
def kernel(*args):
|
94 |
-
# There's nothing to be functionalized!
|
95 |
-
# We can still end up here because DispatchKey::Functionalize is a mode key
|
96 |
-
if pytree.tree_all_only(torch.Tensor, lambda x: not torch._is_functional_tensor(x), args):
|
97 |
-
with _ExcludeDispatchKeyGuard(DispatchKeySet(DispatchKey.Functionalize)):
|
98 |
-
return mutable_op(*args)
|
99 |
-
|
100 |
-
# NB: This differs from the codegen -- codegen handles cases where there
|
101 |
-
# are mixed FunctionalTensorWrapper and non-FunctionalTensorWrapper.
|
102 |
-
# This only really matters for XLA (mixed CPU-XLA tensors) and
|
103 |
-
# running functionalization without the PT2 stack (which guarantees to us that
|
104 |
-
# all tensors are FunctionalTensorWrapper).
|
105 |
-
if not pytree.tree_all_only(torch.Tensor, torch._is_functional_tensor, args):
|
106 |
-
raise RuntimeError("{mutable_op}: expected all args to be FunctionalTensorWrapper")
|
107 |
-
|
108 |
-
unwrapped_args = []
|
109 |
-
for arg in args:
|
110 |
-
if isinstance(arg, torch.Tensor) and torch._is_functional_tensor(arg):
|
111 |
-
torch._sync(arg)
|
112 |
-
unwrapped = torch._from_functional_tensor(arg)
|
113 |
-
unwrapped_args.append(unwrapped)
|
114 |
-
else:
|
115 |
-
unwrapped_args.append(arg)
|
116 |
-
|
117 |
-
with _ExcludeDispatchKeyGuard(DispatchKeySet(DispatchKey.Functionalize)):
|
118 |
-
output = functional_op(*unwrapped_args)
|
119 |
-
|
120 |
-
num_actual_output = len(mutable_op._schema.returns)
|
121 |
-
actual_output = pytree.tree_map(
|
122 |
-
torch._to_functional_tensor, output[:num_actual_output])
|
123 |
-
|
124 |
-
new_values_to_propagate = output[num_actual_output:]
|
125 |
-
inputs_to_replace = [arg for is_write, arg in zip(mutable_args(mutable_op), args)
|
126 |
-
if is_write]
|
127 |
-
assert len(new_values_to_propagate) == len(inputs_to_replace)
|
128 |
-
for new_value, arg in zip(new_values_to_propagate, inputs_to_replace):
|
129 |
-
if (arg is None and new_value is None) or (arg is not None and new_value is not None):
|
130 |
-
continue
|
131 |
-
torch._C._propagate_xla_data(arg, new_value)
|
132 |
-
torch._C._replace_(arg, new_value)
|
133 |
-
torch._C._commit_update(arg)
|
134 |
-
torch._sync(arg)
|
135 |
-
|
136 |
-
if len(actual_output) == 1:
|
137 |
-
return actual_output[0]
|
138 |
-
elif len(actual_output) == 0:
|
139 |
-
return None
|
140 |
-
return actual_output
|
141 |
-
|
142 |
-
return kernel
|
143 |
-
|
144 |
-
|
145 |
-
def validate(mutable_op: OpOverload):
|
146 |
-
if not isinstance(mutable_op, OpOverload):
|
147 |
-
raise TypeError(
|
148 |
-
f"register_functional_op(mutable_op): expected mutable_op to be instance of "
|
149 |
-
f"OpOverload but got {type(mutable_op)}")
|
150 |
-
|
151 |
-
# There are generally three types of "in-place" or "mutable" ops.
|
152 |
-
# Each of them have their own conventions:
|
153 |
-
# - inplace (first input modified in-place and returned as only output)
|
154 |
-
# - out= (some args modified in-place and returned as outputs)
|
155 |
-
# - mutable (some args modified in-place but none of those returned as outputs)
|
156 |
-
# In theory we can support all three, but we'll just support the last
|
157 |
-
# option right now for simplicity.
|
158 |
-
schema = FunctionSchema.parse(str(mutable_op._schema))
|
159 |
-
if not schema.kind() == SchemaKind.mutable:
|
160 |
-
raise RuntimeError("Expected op to be mutable (as opposed to functional, inplace or out)")
|
161 |
-
for ret in schema.returns:
|
162 |
-
# construct_functionalization_kernel assumes this for simplicity
|
163 |
-
if ret.annotation is not None:
|
164 |
-
raise NotImplementedError(
|
165 |
-
"NYI: register_functional_op(op) where op returns a mutated or aliased value. "
|
166 |
-
"Please file an issue (and as a workaround, modify your operator to "
|
167 |
-
"not return the mutated value or aliases)")
|
168 |
-
for arg in schema.arguments.flat_all:
|
169 |
-
# construct_functionalization_kernel assumes this for simplicity
|
170 |
-
if arg.type.is_tensor_like() and (
|
171 |
-
arg.type != BaseType(BaseTy.Tensor)
|
172 |
-
and arg.type != OptionalType(BaseType(BaseTy.Tensor))
|
173 |
-
):
|
174 |
-
raise NotImplementedError(
|
175 |
-
"NYI: register_functional_op(op) where op has a List[Tensor] input."
|
176 |
-
"Please file an issue.")
|
177 |
-
|
178 |
-
|
179 |
-
def functional_schema(new_op_name, op: OpOverload):
|
180 |
-
schema = FunctionSchema.parse(str(op._schema))
|
181 |
-
schema = schema.signature().with_name(OperatorName.parse(new_op_name))
|
182 |
-
return str(schema)
|
183 |
-
|
184 |
-
|
185 |
-
def mutable_args(op: OpOverload):
|
186 |
-
return tuple(False if arg.alias_info is None else arg.alias_info.is_write
|
187 |
-
for arg in op._schema.arguments)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
torch/_custom_op/impl.py
DELETED
@@ -1,976 +0,0 @@
|
|
1 |
-
import dataclasses
|
2 |
-
import functools
|
3 |
-
import inspect
|
4 |
-
import sys
|
5 |
-
import typing
|
6 |
-
import weakref
|
7 |
-
|
8 |
-
from torchgen.model import FunctionSchema, OperatorName, SchemaKind, BaseType, ListType, BaseTy
|
9 |
-
|
10 |
-
import torch
|
11 |
-
import torch._C as _C
|
12 |
-
import torch.library as library
|
13 |
-
from torch._library.abstract_impl import AbstractImplCtx
|
14 |
-
from torch.library import get_ctx
|
15 |
-
|
16 |
-
from .autograd import autograd_kernel_indirection, construct_autograd_kernel
|
17 |
-
|
18 |
-
"""
|
19 |
-
For a detailed guide on custom ops, please see
|
20 |
-
https://docs.google.com/document/d/1aGWtgxV3HppuxQAdddyPrs74_aEntpkYt9MalnCKnhk
|
21 |
-
|
22 |
-
This file includes pieces of the implementation of our custom operator API.
|
23 |
-
"""
|
24 |
-
|
25 |
-
__all__ = ["custom_op", "CustomOp", "get_ctx", "AbstractImplCtx"]
|
26 |
-
|
27 |
-
|
28 |
-
SUPPORTED_DEVICE_TYPE_TO_KEY = {
|
29 |
-
"cpu": "CPU",
|
30 |
-
"cuda": "CUDA",
|
31 |
-
}
|
32 |
-
|
33 |
-
# We will not let users register CustomOps with anything that could look like
|
34 |
-
# PyTorch internals to avoid confusion.
|
35 |
-
RESERVED_NS = {
|
36 |
-
"prim",
|
37 |
-
"prims",
|
38 |
-
"aten",
|
39 |
-
"at",
|
40 |
-
"torch",
|
41 |
-
"pytorch",
|
42 |
-
}
|
43 |
-
|
44 |
-
|
45 |
-
def custom_op(
|
46 |
-
qualname: str, manual_schema: typing.Optional[str] = None
|
47 |
-
) -> typing.Callable:
|
48 |
-
r"""Creates a new CustomOp object.
|
49 |
-
|
50 |
-
WARNING: if you're a user, please do not use this directly
|
51 |
-
(instead use the torch._custom_ops APIs).
|
52 |
-
Also please see the following for a detailed guide on custom ops.
|
53 |
-
https://docs.google.com/document/d/1aGWtgxV3HppuxQAdddyPrs74_aEntpkYt9MalnCKnhk
|
54 |
-
|
55 |
-
In PyTorch, defining an op (short for "operator") is a two step-process:
|
56 |
-
- we need to define (create) the op
|
57 |
-
- we need to implement behavior for how the operator interacts with
|
58 |
-
various PyTorch subsystems, like CPU/CUDA Tensors, Autograd, etc.
|
59 |
-
|
60 |
-
This entrypoint defines the CustomOp object (the first step);
|
61 |
-
you must then perform the second step by calling various methods on
|
62 |
-
the CustomOp object.
|
63 |
-
|
64 |
-
This API is used as a decorator (see examples).
|
65 |
-
|
66 |
-
Arguments:
|
67 |
-
qualname (str): Should be a string that looks like
|
68 |
-
"namespace::operator_name". Operators in PyTorch need a namespace to
|
69 |
-
avoid name collisions; a given operator may only be created once.
|
70 |
-
If you are writing a Python library, we recommend the namespace to
|
71 |
-
be the name of your top-level module. The operator_name must be
|
72 |
-
the same as the name of the function you pass to custom_op
|
73 |
-
(see examples).
|
74 |
-
manual_schema (Optional[str]): Each PyTorch operator needs a schema that
|
75 |
-
tells PyTorch the types of the inputs/outputs. If None (default),
|
76 |
-
we will infer the schema from the type annotations on the function
|
77 |
-
(see examples). Otherwise, if you don't want to use type annotations,
|
78 |
-
you may provide us the schema string.
|
79 |
-
|
80 |
-
Example::
|
81 |
-
>>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA)
|
82 |
-
>>> import numpy as np
|
83 |
-
>>> from torch import Tensor
|
84 |
-
>>>
|
85 |
-
>>> # Step 1: define the CustomOp.
|
86 |
-
>>> # We need to provide the decorator a "prototype function"
|
87 |
-
>>> # (a function with Python ellipses as the body).
|
88 |
-
>>> @custom_op("my_library::numpy_sin")
|
89 |
-
>>> def numpy_sin(x: Tensor) -> Tensor:
|
90 |
-
>>> ...
|
91 |
-
>>>
|
92 |
-
>>> # numpy_sin is now an instance of class CustomOp
|
93 |
-
>>> print(type(numpy_sin))
|
94 |
-
>>>
|
95 |
-
>>> # Step 2: Register an implementation for various PyTorch subsystems
|
96 |
-
>>>
|
97 |
-
>>> # Register an implementation for CPU tensors
|
98 |
-
>>> @numpy_sin.impl('cpu')
|
99 |
-
>>> def numpy_sin_impl_cpu(x):
|
100 |
-
>>> return torch.from_numpy(np.sin(x.numpy()))
|
101 |
-
>>>
|
102 |
-
>>> # Register an implementation for CUDA tensors
|
103 |
-
>>> @numpy_sin.impl('cuda')
|
104 |
-
>>> def numpy_sin_impl_cuda(x):
|
105 |
-
>>> return torch.from_numpy(np.sin(x.cpu().numpy())).to(x.device)
|
106 |
-
>>>
|
107 |
-
>>> x = torch.randn(3)
|
108 |
-
>>> numpy_sin(x) # calls numpy_sin_impl_cpu
|
109 |
-
>>>
|
110 |
-
>>> x_cuda = x.cuda()
|
111 |
-
>>> numpy_sin(x) # calls numpy_sin_impl_cuda
|
112 |
-
|
113 |
-
"""
|
114 |
-
|
115 |
-
def inner(func):
|
116 |
-
if not inspect.isfunction(func):
|
117 |
-
raise ValueError(
|
118 |
-
f"custom_op(...)(func): Expected `func` to be a Python "
|
119 |
-
f"function, got: {type(func)}"
|
120 |
-
)
|
121 |
-
|
122 |
-
ns, name = parse_qualname(qualname)
|
123 |
-
validate_namespace(ns)
|
124 |
-
if func.__name__ != name:
|
125 |
-
raise ValueError(
|
126 |
-
f"custom_op(qualname='{qualname}', ...)(func): expected `func` "
|
127 |
-
f"to have name '{name}' but got '{func.__name__}'. "
|
128 |
-
f"Please either change the name of `func` or the qualname that "
|
129 |
-
f"is passed to `custom_op`"
|
130 |
-
)
|
131 |
-
|
132 |
-
schema = infer_schema(func) if manual_schema is None else manual_schema
|
133 |
-
schema_str = f"{name}{schema}"
|
134 |
-
function_schema = FunctionSchema.parse(schema_str)
|
135 |
-
validate_schema(function_schema)
|
136 |
-
if manual_schema is not None:
|
137 |
-
validate_function_matches_schema(function_schema, func)
|
138 |
-
|
139 |
-
lib = library.Library(ns, "FRAGMENT")
|
140 |
-
lib.define(schema_str)
|
141 |
-
ophandle = find_ophandle_or_throw(ns, function_schema.name)
|
142 |
-
result = CustomOp(lib, ns, function_schema, name, ophandle, _private_access=True)
|
143 |
-
|
144 |
-
result.__name__ = func.__name__
|
145 |
-
result.__module__ = func.__module__
|
146 |
-
result.__doc__ = func.__doc__
|
147 |
-
|
148 |
-
library.impl(lib, result._opname, "Autograd")(
|
149 |
-
autograd_kernel_indirection(weakref.proxy(result))
|
150 |
-
)
|
151 |
-
|
152 |
-
torch._C._dispatch_set_report_error_callback(
|
153 |
-
ophandle, functools.partial(report_error_callback, weakref.proxy(result))
|
154 |
-
)
|
155 |
-
|
156 |
-
return result
|
157 |
-
|
158 |
-
return inner
|
159 |
-
|
160 |
-
|
161 |
-
# Global dictionary holding references to all CustomOp objects
|
162 |
-
# Yes, it keeps all CustomOps alive (see NOTE [CustomOp lifetime])
|
163 |
-
# Used to query the CustomOp associated with a specific C++ dispatcher operator.
|
164 |
-
# An example usage is FakeTensor: FakeTensor checks if a specific operator
|
165 |
-
# has an implementation registered via the CustomOp API.
|
166 |
-
# Indexed by qualname (e.g. aten::foo)
|
167 |
-
global_registry: typing.Dict[str, "CustomOp"] = {}
|
168 |
-
|
169 |
-
|
170 |
-
class CustomOp:
    r"""Class for custom operators in PyTorch.

    Use the CustomOp API to create user-defined custom operators that behave
    just like regular PyTorch operators (e.g. torch.sin, torch.mm) when it
    comes to various PyTorch subsystems (like torch.compile).

    To construct a `CustomOp`, use `custom_op`.
    """

    def __init__(self, lib, cpp_ns, schema, operator_name, ophandle, *, _private_access=False):
        """Store the operator's identity and add it to ``global_registry``.

        Raises:
            RuntimeError: if ``_private_access`` is not set; the constructor
                is private and ``custom_op(...)`` is the supported entry point.
        """
        super().__init__()
        if not _private_access:
            raise RuntimeError(
                "The CustomOp constructor is private and we do not guarantee "
                "BC for it. Please use custom_op(...) to create a CustomOp object"
            )
        name = f"{cpp_ns}::{operator_name}"
        self._schema = schema
        self._cpp_ns = cpp_ns
        self._lib: library.Library = lib
        self._ophandle: _C._DispatchOperatorHandle = ophandle
        # Has the name of the op, e.g. "foo". We cache here for convenience.
        self._opname: str = operator_name
        # this is _opname but with namespace. e.g. "custom::foo"
        self._qualname: str = name
        self.__name__ = None  # mypy requires this
        # NB: Some of these impls are registered as kernels to DispatchKeys.
        # Modifying the _impls dict directly won't do anything in that case.
        self._impls: typing.Dict[str, typing.Optional[FuncAndLocation]] = {}
        # See NOTE [CustomOp autograd kernel indirection]
        self._registered_autograd_kernel_indirection = False
        # Overwritten by impl_backward; initialised here so the attribute
        # always exists on the instance.
        self._output_differentiability = None

        global_registry[self._qualname] = self

    def _register_autograd_kernel_indirection(self):
        """Register the Autograd indirection kernel on the library (at most once)."""
        assert not self._registered_autograd_kernel_indirection
        self._lib.impl(self._opname, autograd_kernel_indirection(weakref.proxy(self)), "Autograd")
        self._registered_autograd_kernel_indirection = True

    # Records the impl and the source location in self._impls
    # Note that this doesn't cause torch.library to use the impl, that
    # needs to be done in a separate self._lib.impl call.
    def _register_impl(self, kind, func, stacklevel=2):
        """Record ``func`` and the caller's ``file:line`` under ``kind``.

        Raises:
            RuntimeError: if an impl of this ``kind`` was already recorded.
        """
        if self._has_impl(kind):
            func_and_location = self._impls[kind]
            assert func_and_location is not None  # Pacify mypy
            location = func_and_location.location
            raise RuntimeError(
                f"Attempting to register a {kind} impl for operator {self._qualname} "
                f"that already has a {kind} impl registered from Python at "
                f"{location}. This is not supported."
            )
        # stacklevel selects which caller frame is reported as the location.
        frame = inspect.getframeinfo(sys._getframe(stacklevel))
        location = f"{frame.filename}:{frame.lineno}"
        self._impls[kind] = FuncAndLocation(func, location)

    def _get_impl(self, kind):
        """Return the recorded entry for ``kind`` (KeyError if absent)."""
        return self._impls[kind]

    def _has_impl(self, kind):
        """Return whether an impl of ``kind`` has been recorded."""
        return kind in self._impls

    def _destroy(self):
        # NOTE: [CustomOp lifetime]
        # A CustomOp, once created, lives forever. The mechanism is that the
        # global registry holds a reference to it. However, to make testing
        # easier, we want to be able to destroy CustomOp objects.
        # CustomOp._destroy does the job, though it leaves the CustomOp
        # in a garbage state.
        del self._lib

        opnamespace = getattr(torch.ops, self._cpp_ns)
        if hasattr(opnamespace, self._opname):
            delattr(opnamespace, self._opname)

        del global_registry[self._qualname]

    def __repr__(self):
        return f'<CustomOp(op="{self._qualname}")>'

    def __call__(self, *args, **kwargs):
        # Bypass torch.ops.* and directly do OperatorHandle::callBoxed.
        # Using torch.ops.* is a bit of a pain (it can be slow and it has lifetime
        # issues from caching operators that make testing CustomOp difficult).
        result = _C._dispatch_call_boxed(self._ophandle, *args, **kwargs)
        return result

    def impl(
        self, device_types: typing.Union[str, typing.Iterable[str]], _stacklevel=2,
    ) -> typing.Callable:
        r"""Register an implementation for a device type for this CustomOp object.

        WARNING: if you're a user, please do not use this directly
        (instead use the torch._custom_ops APIs).
        Also please see the following for a detailed guide on custom ops.
        https://docs.google.com/document/d/1aGWtgxV3HppuxQAdddyPrs74_aEntpkYt9MalnCKnhk

        If the CustomOp is passed multiple Tensor inputs with different device
        types, it will dispatch to the registered implementation for the highest
        priority device type among those present.
        The supported device types, in order of priority, are {'cuda', 'cpu'}.

        This API is used as a decorator (see examples).

        Arguments:
            device_types (str or Iterable[str]): the device type(s) to register the function for.

        Examples::
            >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA)
            >>> import numpy as np
            >>> from torch import Tensor
            >>>
            >>> @custom_op("my_library::numpy_cos")
            >>> def numpy_cos(x: Tensor) -> Tensor:
            >>>     ...
            >>>
            >>> # Register an implementation for CPU Tensors
            >>> @numpy_cos.impl('cpu')
            >>> def numpy_cos_impl_cpu(x):
            >>>     return torch.from_numpy(np.cos(x.numpy()))
            >>>
            >>> # Register an implementation for CUDA Tensors
            >>> @numpy_cos.impl('cuda')
            >>> def numpy_cos_impl_cuda(x):
            >>>     return torch.from_numpy(np.cos(x.cpu().numpy())).to(x.device)
            >>>
            >>> x = torch.randn(3)
            >>> numpy_cos(x)  # calls numpy_cos_impl_cpu
            >>>
            >>> x_cuda = x.cuda()
            >>> numpy_cos(x_cuda)  # calls numpy_cos_impl_cuda

        """
        if isinstance(device_types, str):
            device_types = [device_types]
        else:
            # Materialise once: the values are walked here for validation and
            # again inside `inner`, so a one-shot iterable (e.g. a generator)
            # would otherwise be empty by the time registration happens.
            device_types = list(device_types)
        for device_type in device_types:
            validate_device_type(device_type)

        def inner(f):
            for device_type in set(device_types):
                self._check_doesnt_have_library_impl(device_type)
                self._register_impl(device_type, f, stacklevel=_stacklevel)
                dispatch_key = SUPPORTED_DEVICE_TYPE_TO_KEY[device_type]
                library.impl(self._lib, self._opname, dispatch_key)(f)
            return f

        return inner

    def _check_doesnt_have_library_impl(self, device_type):
        """Raise if a kernel for ``device_type`` exists that this object did not record."""
        if self._has_impl(device_type):
            return
        key = SUPPORTED_DEVICE_TYPE_TO_KEY[device_type]
        if _C._dispatch_has_computed_kernel_for_dispatch_key(self._qualname, key):
            raise RuntimeError(
                f"impl(..., device_types={device_type}): the operator {self._qualname} "
                f"already has an implementation for this device type via a "
                f"pre-existing torch.library or TORCH_LIBRARY registration.")

    def impl_factory(self) -> typing.Callable:
        r"""Register an implementation for a factory function."""

        def inner(f):
            self._register_impl("factory", f)
            library.impl(self._lib, self._opname, "BackendSelect")(f)
            return f

        return inner

    def impl_abstract(self, _stacklevel=2) -> typing.Callable:
        r"""Register an abstract implementation for this operator.

        WARNING: please do not use this directly (and instead use the torch._custom_ops
        APIs). Also please see the following for a detailed guide on custom ops.
        https://docs.google.com/document/d/1aGWtgxV3HppuxQAdddyPrs74_aEntpkYt9MalnCKnhk

        An "abstract implementation" specifies the behavior of this operator on
        Tensors that carry no data. Given some input Tensors with certain properties
        (sizes/strides/storage_offset/device), it specifies what the properties of
        the output Tensors are.

        The abstract implementation has the same signature as the operator.
        It is run for both FakeTensors and meta tensors. To write an abstract
        implementation, assume that all Tensor inputs to the operator are
        regular CPU/CUDA/Meta tensors, but they do not have storage, and
        you are trying to return regular CPU/CUDA/Meta tensor(s) as output.
        The abstract implementation must consist of only PyTorch operations
        (and may not directly access the storage or data of any input or
        intermediate Tensors).

        This API is used as a decorator (see examples).

        Examples::
            >>> import numpy as np
            >>> from torch import Tensor
            >>>
            >>> # Example 1: an operator without data-dependent output shape
            >>> @custom_op('my_library::custom_linear')
            >>> def custom_linear(x: Tensor, weight: Tensor, bias: Tensor) -> Tensor:
            >>>     ...
            >>>
            >>> @custom_linear.impl_abstract()
            >>> def custom_linear_abstract(x, weight, bias):
            >>>     assert x.dim() == 2
            >>>     assert weight.dim() == 2
            >>>     assert bias.dim() == 1
            >>>     assert x.shape[1] == weight.shape[1]
            >>>     assert weight.shape[0] == bias.shape[0]
            >>>     assert x.device == weight.device
            >>>
            >>>     return (x @ weight.t()) + bias
            >>>
            >>> # Example 2: an operator with data-dependent output shape
            >>> @custom_op('my_library::custom_nonzero')
            >>> def custom_nonzero(x: Tensor) -> Tensor:
            >>>     ...
            >>>
            >>> @custom_nonzero.impl_abstract()
            >>> def custom_nonzero_abstract(x):
            >>>     # Number of nonzero-elements is data-dependent.
            >>>     # Since we cannot peek at the data in an abstract impl,
            >>>     # we use the ctx object to construct a new symint that
            >>>     # represents the data-dependent size.
            >>>     ctx = torch._custom_op.get_ctx()
            >>>     nnz = ctx.create_unbacked_symint()
            >>>     shape = [x.dim(), nnz]
            >>>     result = x.new_empty(shape, dtype=torch.long)
            >>>     return result
            >>>
            >>> @custom_nonzero.impl(['cpu', 'cuda'])
            >>> def custom_nonzero_impl(x):
            >>>     x_np = to_numpy(x)
            >>>     res = np.stack(np.nonzero(x_np), axis=1)
            >>>     # unbacked symbolic ints in PyTorch must be >= 2, so we
            >>>     # constrain the range to at least 2
            >>>     if res.shape[0] <= 1:
            >>>         raise RuntimeError("not supported")
            >>>     return torch.tensor(res, device=x.device)

        """

        def inner(f):
            self._check_doesnt_have_library_meta_impl()
            self._register_impl("abstract", f, stacklevel=_stacklevel)
            location = self._get_impl("abstract").location

            qualname = self._qualname

            # Handle DispatchKey.Meta registration
            @functools.wraps(f)
            def f_with_ctx(*args, **kwargs):
                def error_on_ctx():
                    raise RuntimeError(
                        f"Attempted to call get_ctx() for the meta implementation "
                        f"for {qualname}. "
                        f"You have presumably called get_ctx() because the operator "
                        f"has a data-dependent output shape; if so, there is no "
                        f"such meta implementation and this error is the correct "
                        f"behavior. Otherwise, please remove the call to get_ctx() "
                        f"in the implementation registered with impl_abstract "
                        f"at {location}"
                    )

                with torch._library.abstract_impl.set_ctx_getter(error_on_ctx):
                    return f(*args, **kwargs)

            self._lib.impl(self._opname, f_with_ctx, "Meta")
            return f

        return inner

    def _check_can_register_backward(self):
        """Raise RuntimeError if the schema cannot take a backward formula."""

        def error(detail):
            raise RuntimeError(
                f"Cannot use torch._custom_ops APIs to register backward "
                f"formula for {detail}. Got operator "
                f"{self._qualname} with schema: {schema}"
            )

        schema = self._schema
        if schema.kind() != SchemaKind.functional:
            error("non-functional operator")

        rets = schema.returns
        if not schema.returns:
            error("operator with no returns")

        assert len(rets) > 0
        is_non_mutating_view = any(
            r.annotation is not None and not r.annotation.is_write for r in rets
        )
        if is_non_mutating_view:
            error("operator that returns views")

        # We make assumptions about the schema's return types.
        allowed_return_types = {
            BaseType(BaseTy.int): "int",
            BaseType(BaseTy.SymInt): "SymInt",
            BaseType(BaseTy.bool): "bool",
            BaseType(BaseTy.float): "float",
            BaseType(BaseTy.Tensor): "Tensor",
            ListType(BaseType(BaseTy.Tensor), None): "List[Tensor]",
        }
        for ret in schema.returns:
            if ret.type in allowed_return_types:
                continue
            error(f"operator with return not in {list(allowed_return_types.values())} (got {ret.type})")

    def _check_doesnt_have_library_autograd_impl(self):
        """Raise if an autograd-related kernel was registered outside this API."""
        if self._registered_autograd_kernel_indirection:
            return

        if _C._dispatch_has_kernel_for_dispatch_key(self._qualname, "CompositeImplicitAutograd"):
            raise RuntimeError(
                f"impl_backward/impl_save_for_backward: the operator {self._qualname} "
                f"already has an implementation for this device type via a "
                f"pre-existing registration to DispatchKey::CompositeImplicitAutograd. "
                f"CompositeImplicitAutograd operators do not need an autograd formula; "
                f"instead, the operator will decompose into its constituents and those "
                f"can have autograd formulas defined on them.")

        # We can improve this by adding "all Autograd<BACKEND> keys", but
        # realistically people will just be using this API for CPU/CUDA for now.
        for key in ["Autograd", "AutogradCPU", "AutogradCUDA"]:
            if _C._dispatch_has_kernel_for_dispatch_key(self._qualname, key):
                raise RuntimeError(
                    f"impl_backward/impl_save_for_backward: "
                    f"the operator {self._qualname} already has an Autograd kernel "
                    f"registered to DispatchKey::{key} via a pre-existing "
                    f"torch.library or TORCH_LIBRARY registration. Please either "
                    f"remove those registrations or don't use the torch._custom_ops APIs")

    def _check_doesnt_have_library_meta_impl(self):
        """Raise if a Meta/composite kernel registered elsewhere conflicts with impl_abstract."""
        if self._has_impl("abstract"):
            return

        # If the user's operator is CompositeExplicitAutograd,
        # allow them to impl_abstract. This is being pragmatic
        # (existing custom ops may have CompositeExplicitAutograd
        # registration that don't work with Meta kernels, so this
        # gives them an escape hatch).
        if (
            _C._dispatch_has_kernel_for_dispatch_key(self._qualname, "CompositeExplicitAutograd")
            and not _C._dispatch_has_kernel_for_dispatch_key(self._qualname, "Meta")
        ):
            return

        # Otherwise, if the user's already has a Meta kernel or their
        # op is CompositeImplicitAutograd or some other alias dispatch key,
        # raise.

        # Special case for CompositeImplicitAutograd
        if _C._dispatch_has_kernel_for_dispatch_key(self._qualname, "CompositeImplicitAutograd"):
            raise RuntimeError(
                f"impl_abstract(...): the operator {self._qualname} "
                f"already has an implementation for this device type via a "
                f"pre-existing registration to DispatchKey::CompositeImplicitAutograd. "
                f"CompositeImplicitAutograd operators do not need an abstract impl; "
                f"instead, the operator will decompose into its constituents and those "
                f"can have abstract impls defined on them.")

        if _C._dispatch_has_kernel_for_dispatch_key(self._qualname, "Meta"):
            raise RuntimeError(
                f"impl_abstract(...): the operator {self._qualname} "
                f"already has a DispatchKey::Meta implementation via a "
                f"pre-existing torch.library or TORCH_LIBRARY registration. "
                f"Please either remove that registration or don't call impl_abstract.")

    # NOTE ["backward", "save_for_backward", and "autograd"]
    # As a part of the explicit autograd API, a user must provide us
    # a "save_for_backward" function and a "backward" function.
    # When both of these have been provided, then we automatically
    # construct the "autograd" kernel.
    def _register_autograd_kernel(self):
        """Build the "autograd" impl from the recorded backward pieces."""
        assert self._has_impl("backward")
        assert self._has_impl("save_for_backward")
        kernel = construct_autograd_kernel(
            self._schema,
            self._output_differentiability,
            self,
            get_op(self._qualname),
            self._get_impl("save_for_backward").func,
            self._get_impl("backward").func)
        self._register_impl("autograd", kernel)

    def impl_save_for_backward(self, _stacklevel=2):
        r"""Register a function that tells us what to save for backward.

        Please see impl_backward for more details.
        """
        def inner(f):
            self._check_can_register_backward()
            self._check_doesnt_have_library_autograd_impl()
            if not self._registered_autograd_kernel_indirection:
                self._register_autograd_kernel_indirection()
            self._register_impl("save_for_backward", f, stacklevel=_stacklevel)
            if self._has_impl("backward"):
                self._register_autograd_kernel()
            # Hand the function back so decorator usage does not rebind the
            # decorated name to None (matches impl / impl_abstract).
            return f
        return inner

    def impl_backward(self, output_differentiability=None, _stacklevel=2):
        r"""Registers a backward formula.

        WARNING: if you're a user, please do not use this directly
        (instead use the torch._custom_ops APIs).
        Also please see the following for a detailed guide on custom ops.
        https://docs.google.com/document/d/1aGWtgxV3HppuxQAdddyPrs74_aEntpkYt9MalnCKnhk

        In order for the CustomOp to work with autograd, you need to register
        a backward formula. There are two pieces to this:
        1. You must give us a function to specify what to save for backward.
           Call this the "save for backward" function.
        2. You must give us a function that computes gradients. Call this the
           "backward" function.

        Use `impl_save_for_backward` to define a "save for backward" function
        that specifies what gets saved for backward. The function should accept
        two arguments ``(inputs, output)`` and return the quantities to be saved
        for backward.

        During runtime, when you call the CustomOp, PyTorch will invoke the
        "save for backward" function with the inputs and output of the CustomOp.

        Use `impl_backward` to define the "backward" function. The backward
        function must accept ``(ctx, saved, *grads)``:
        - ``ctx`` is a context object where we may provide information
        - ``saved`` is exactly what gets returned from the "save for backward"
          function
        - ``grads`` is one or more gradients. The number of gradients matches
          the number of outputs of the CustomOp.

        The backward function must return a dict that maps the name of
        an input to the CustomOp to its corresponding gradient. All inputs that
        were declared to be Tensors in the CustomOp definition must be accounted
        for in the dict. The gradient may be a Tensor or None.

        """
        if output_differentiability is not None:
            def yell():
                raise RuntimeError(
                    f"impl_backward(output_differentiability): expected "
                    f"output_differentiability to be a list of bools with "
                    f"length equal to the number of outputs of this CustomOp "
                    f"got: {output_differentiability}")

            if not isinstance(output_differentiability, list):
                yell()
            for diff in output_differentiability:
                if not isinstance(diff, bool):
                    yell()
            if len(self._schema.returns) != len(output_differentiability):
                yell()

        def inner(f):
            self._check_can_register_backward()
            self._check_doesnt_have_library_autograd_impl()
            if not self._registered_autograd_kernel_indirection:
                self._register_autograd_kernel_indirection()
            self._register_impl("backward", f, stacklevel=_stacklevel)
            self._output_differentiability = output_differentiability
            if self._has_impl("save_for_backward"):
                self._register_autograd_kernel()
            # Hand the function back so decorator usage does not rebind the
            # decorated name to None (matches impl / impl_abstract).
            return f
        return inner
|
633 |
-
|
634 |
-
|
635 |
-
# Pairs a registered implementation with the place it was registered from.
@dataclasses.dataclass
class FuncAndLocation:
    # The callable that was registered.
    func: typing.Callable
    # Where the registration happened, formatted as "<filename>:<lineno>".
    location: str
|
639 |
-
|
640 |
-
|
641 |
-
def find_ophandle_or_throw(cpp_ns: str, operator_name: OperatorName):
    """Look up the dispatcher handle for ``cpp_ns::<operator_name>``.

    A missing overload name (``None``) is passed to the dispatcher as the
    empty string. The lookup is delegated to
    ``_C._dispatch_find_schema_or_throw``.
    """
    overload = operator_name.overload_name
    if overload is None:
        overload = ""
    qualified_name = f"{cpp_ns}::{str(operator_name.name)}"
    return _C._dispatch_find_schema_or_throw(qualified_name, overload)
|
648 |
-
|
649 |
-
|
650 |
-
def validate_namespace(ns: str) -> None:
    """Reject namespaces that contain a dot or that appear in ``RESERVED_NS``.

    Raises:
        ValueError: if ``ns`` contains "." or is a reserved namespace.
    """
    if "." in ns:
        dotted_message = (
            f'custom_op(..., ns="{ns}"): expected ns to not contain any . (and be a '
            f"valid variable name)"
        )
        raise ValueError(dotted_message)

    if ns in RESERVED_NS:
        reserved_message = (
            f"custom_op(..., ns='{ns}'): '{ns}' is a reserved namespace, "
            f"please choose something else. "
        )
        raise ValueError(reserved_message)
|
661 |
-
|
662 |
-
def validate_schema(schema: FunctionSchema) -> None:
    """Check that ``schema`` is usable with ``custom_op``.

    Raises:
        ValueError: if ``torch._library.utils.is_functional_schema`` rejects
            the schema, or if the schema has a ``self`` argument.
    """
    is_functional = torch._library.utils.is_functional_schema(schema)
    if not is_functional:
        raise ValueError(
            f"custom_op only supports functional operators "
            f"(ops that do not mutate any inputs, do not return "
            f"views of the inputs, and has at least one return). "
            f"Got the following non-functional schema: {schema}"
        )

    # For simplicity: don't allow self arguments
    self_argument = schema.arguments.self_arg
    if self_argument is not None:
        raise ValueError(
            f"custom_op does not support arguments named 'self'. Please "
            f"rename your argument. Got: {schema}"
        )
|
677 |
-
|
678 |
-
|
679 |
-
def parse_qualname(qualname: str) -> typing.Tuple[str, str]:
    """Split ``"ns::name"`` into ``(ns, name)``.

    Only the first ``"::"`` is used as the separator.

    Raises:
        ValueError: if there is no ``"::"`` separator, or if the operator
            name part contains a "." (overloads are not handled).
    """
    pieces = qualname.split("::", 1)
    if len(pieces) != 2:
        raise ValueError(f"Expected there to be a namespace in {qualname}, i.e. The "
                         f"operator name should look something like ns::foo")

    namespace, opname = pieces
    if '.' in opname:
        raise ValueError(f"The torch.custom_ops APIs do not handle overloads, "
                         f"i.e. operator names with '.' in them. "
                         f"Please name your operator something like ns::foo. "
                         f"Got: {qualname}")
    return namespace, opname
|
690 |
-
|
691 |
-
|
692 |
-
def validate_device_type(device_type: str) -> None:
    """Raise ValueError unless ``device_type`` is a key of ``SUPPORTED_DEVICE_TYPE_TO_KEY``."""
    if device_type in SUPPORTED_DEVICE_TYPE_TO_KEY:
        return
    raise ValueError(
        f"CustomOp.impl(device_types=[{device_type}, ...]): we only support device_type "
        f"in {SUPPORTED_DEVICE_TYPE_TO_KEY.keys()}."
    )
|
698 |
-
|
699 |
-
|
700 |
-
def supported_param(param: inspect.Parameter) -> bool:
    """Return True for positional-or-keyword and keyword-only parameters only."""
    kind = param.kind
    if kind == inspect.Parameter.POSITIONAL_OR_KEYWORD:
        return True
    return kind == inspect.Parameter.KEYWORD_ONLY
|
705 |
-
|
706 |
-
|
707 |
-
def validate_function_matches_schema(
    schema: FunctionSchema, func: typing.Callable
) -> None:
    """Check that ``func``'s signature lines up with a manually supplied schema.

    ``func`` must use only positional-or-keyword and keyword-only parameters,
    carry no type annotations, and have the same argument names, in the same
    order, as ``schema``. Neither side may declare default values.

    Raises:
        ValueError: on any of the mismatches described above.
    """
    sig = inspect.signature(func)
    parameters = sig.parameters

    for parameter in parameters.values():
        if not supported_param(parameter):
            raise ValueError(
                f"custom_op(..., manual_schema)(func): positional-only args, "
                f"varargs, and kwargs are not supported. Please rewrite `func` "
                f"to not have them. Got `func` with signature: {sig}"
            )

    has_param_annotation = any(
        parameter.annotation is not inspect.Parameter.empty
        for parameter in parameters.values()
    )
    if has_param_annotation or sig.return_annotation is not inspect.Signature.empty:
        raise ValueError(
            f"custom_op(..., manual_schema)(func): When passing in a manual "
            f"schema, we expect `func` to have no type annotations to avoid "
            f"ambiguity. Got `func` with signature: {sig}"
        )

    def fail_signature_mismatch():
        raise ValueError(
            f"custom_op(..., manual_schema)(func): When passing in a manual "
            f"schema, we expect `func`'s signature to match `manual_schema` "
            f"(aside from type annotations). "
            f"func's signature: {sig}, manual_schema: {schema}"
        )

    def fail_default_args():
        raise ValueError(
            f"custom_op(..., manual_schema)(func): "
            f"neither func nor manual_schema should have default "
            f"arguments. Got "
            f"func's signature: {sig}, manual_schema: {schema}"
        )

    def check_group(kind, schema_args):
        # Compare the parameters of one kind against the matching schema group.
        sig_args = [
            (arg_name, parameter)
            for arg_name, parameter in parameters.items()
            if parameter.kind == kind
        ]
        if len(sig_args) != len(schema_args):
            fail_signature_mismatch()
        for (arg_name, parameter), schema_arg in zip(sig_args, schema_args):
            if arg_name != schema_arg.name:
                fail_signature_mismatch()
            has_default = (
                parameter.default is not inspect.Parameter.empty
                or schema_arg.default is not None
            )
            if has_default:
                fail_default_args()

    check_group(inspect.Parameter.POSITIONAL_OR_KEYWORD, schema.arguments.flat_positional)
    check_group(inspect.Parameter.KEYWORD_ONLY, schema.arguments.flat_kwarg_only)
|
770 |
-
|
771 |
-
|
772 |
-
def infer_schema(prototype_function: typing.Callable) -> str:
    """Build a schema string ``"(<params>) -> <ret>"`` from type annotations.

    Each parameter is rendered by ``parse_param`` and the return annotation
    by ``parse_return``. Both report problems through ``error_fn``, which
    raises ``ValueError`` with the offending signature appended.

    Raises:
        ValueError: if ``parse_param`` or ``parse_return`` reports a problem.
    """
    sig = inspect.signature(prototype_function)

    def error_fn(what):
        # The original message ended with an unbalanced ")" after the
        # signature; it is dropped here.
        raise ValueError(
            f"custom_op(...)(func): {what} " f"Got func with signature {sig}"
        )

    params = [
        parse_param(name, param, error_fn) for name, param in sig.parameters.items()
    ]
    ret = parse_return(sig.return_annotation, error_fn)
    return f"({', '.join(params)}) -> {ret}"
|
785 |
-
|
786 |
-
|
787 |
-
def parse_param(name, param, error_fn):
    """Render one parameter as ``"<schema type> <name>"``.

    ``error_fn`` is called with a description when the parameter is of an
    unsupported kind, has no annotation, has an annotation that is not a key
    of ``SUPPORTED_PARAM_TYPES``, or has a default value.
    """
    if not supported_param(param):
        error_fn("We do not support positional-only args, varargs, or varkwargs.")

    annotation = param.annotation
    if annotation is inspect.Parameter.empty:
        error_fn(f"Parameter {name} must have a type annotation.")

    if annotation not in SUPPORTED_PARAM_TYPES:
        error_fn(
            f"Parameter {name} has unsupported type {param.annotation}. "
            f"The valid types are: {SUPPORTED_PARAM_TYPES.keys()}."
        )

    if param.default is not inspect.Parameter.empty:
        error_fn(
            f"Parameter {name} has a default value; this is not supported. "
            f"If you want to use default values then create a function with "
            f"default values that calls the CustomOp"
        )

    schema_type = SUPPORTED_PARAM_TYPES[annotation]
    return f"{schema_type} {name}"
|
808 |
-
|
809 |
-
|
810 |
-
def derived_types(
    base_type, cpp_type, list_base, optional_base_list, optional_list_base
):
    """List the ``(python type, schema type)`` pairs derived from one base type.

    The base type and its ``Optional`` form are always included. The three
    flags opt in to the ``T[]``, ``T?[]`` and ``T[]?`` variants respectively.
    """
    pairs = [(base_type, cpp_type)]
    pairs.append((typing.Optional[base_type], f"{cpp_type}?"))
    if list_base:
        sequence_of_base = typing.Sequence[base_type]  # type: ignore[valid-type]
        pairs.append((sequence_of_base, f"{cpp_type}[]"))
    if optional_base_list:
        sequence_of_optional = typing.Sequence[typing.Optional[base_type]]  # type: ignore[valid-type]
        pairs.append((sequence_of_optional, f"{cpp_type}?[]"))
    if optional_list_base:
        optional_sequence = typing.Optional[typing.Sequence[base_type]]  # type: ignore[valid-type]
        pairs.append((optional_sequence, f"{cpp_type}[]?"))
    return pairs
|
824 |
-
|
825 |
-
|
826 |
-
def get_supported_param_types():
    """Build the mapping from Python annotation to schema type string.

    Each row is expanded through ``derived_types``; the resulting pairs are
    collected into one dict.
    """
    # Columns: python type, schema type, type[] variant, type?[] variant,
    # type[]? variant
    rows = [
        (torch.Tensor, "Tensor", True, True, False),
        (int, "SymInt", True, False, True),
        (float, "float", True, False, True),
        (bool, "bool", True, False, True),
        (str, "str", False, False, False),
        (torch.types.Number, "Scalar", True, False, False),
        (torch.dtype, "ScalarType", False, False, False),
        (torch.device, "Device", False, False, False),
    ]
    mapping = {}
    for row in rows:
        for python_type, schema_type in derived_types(*row):
            mapping[python_type] = schema_type
    return mapping
|
842 |
-
|
843 |
-
|
844 |
-
# Return annotations accepted by parse_return, mapped to the schema type
# string each one is rendered as.
SUPPORTED_RETURN_TYPES = {
    torch.Tensor: "Tensor",
    typing.List[torch.Tensor]: "Tensor[]",
    int: "SymInt",
    float: "float",
    bool: "bool",
    torch.types.Number: "Scalar",
}
|
852 |
-
|
853 |
-
|
854 |
-
def parse_return(annotation, error_fn):
    """Render a return annotation as a schema return type string.

    A ``tuple[...]`` annotation becomes a parenthesised, comma-separated list
    of its element types; any other annotation is looked up directly in
    ``SUPPORTED_RETURN_TYPES``. ``error_fn`` is called for unsupported types.
    """
    if typing.get_origin(annotation) is tuple:
        element_types = typing.get_args(annotation)
        for element in element_types:
            if element not in SUPPORTED_RETURN_TYPES:
                error_fn(
                    f"Return has unsupported type {annotation}. "
                    f"The valid types are: {SUPPORTED_RETURN_TYPES}."
                )
        rendered = [SUPPORTED_RETURN_TYPES[element] for element in element_types]
        return "(" + ", ".join(rendered) + ")"

    if annotation not in SUPPORTED_RETURN_TYPES:
        error_fn(
            f"Return has unsupported type {annotation}. "
            f"The valid types are: {SUPPORTED_RETURN_TYPES}."
        )
    return SUPPORTED_RETURN_TYPES[annotation]
|
873 |
-
|
874 |
-
|
875 |
-
# Annotation -> schema type table, built once when this module is loaded.
# parse_param checks parameter annotations against its keys and uses its
# values when rendering the schema string.
SUPPORTED_PARAM_TYPES = get_supported_param_types()
|
876 |
-
|
877 |
-
|
878 |
-
def report_error_callback(custom_op: typing.Any, key: str) -> None:
    """Raise NotImplementedError describing a missing kernel for ``key``.

    The message is tailored for the "Undefined", "Meta", "CPU" and "CUDA"
    dispatch keys; every other key gets a generic message. This function
    always raises.
    """
    if key == "Undefined":
        message = (
            f"{custom_op}: There were no Tensor inputs to this operator "
            f"(e.g. you passed an empty list of Tensors). If your operator is a "
            f"factory function (that is, it takes no Tensors and constructs "
            f"a new one), then please use CustomOp.impl_factory to register "
            f"an implementation for it"
        )
    elif key == "Meta":
        message = (
            f"{custom_op}: when running with device='Meta' tensors: there is no "
            f"abstract impl registered for this CustomOp. Please register one via "
            f"CustomOp.impl_abstract to get this CustomOp to work with Meta tensors"
        )
    elif key in ("CPU", "CUDA"):
        device = key.lower()
        message = (
            f"{custom_op}: when running with device='{device}' tensors: there is no "
            f"{device} impl registered for this CustomOp. Please register one via "
            f"CustomOp.impl(device_type='{device}')"
        )
    else:
        message = (
            f"{custom_op}: No implementation for dispatch key {key}. It is likely "
            f"that we have not added this functionality yet, please either open an "
            f"issue or if you're feeling adventurous, use the low-level "
            f"torch.library API"
        )
    raise NotImplementedError(message)
|
906 |
-
|
907 |
-
|
908 |
-
def custom_op_from_existing(op):
|
909 |
-
ns = op.namespace
|
910 |
-
lib = torch.library.Library(ns, "FRAGMENT")
|
911 |
-
name = op.name().split("::")[-1]
|
912 |
-
schema_str = str(op._schema)
|
913 |
-
# CustomOp expects the schema string without the namespace
|
914 |
-
schema_str = schema_str.split("::")[-1]
|
915 |
-
schema = FunctionSchema.parse(schema_str)
|
916 |
-
return CustomOp(lib, ns, schema, name, op, _private_access=True)
|
917 |
-
|
918 |
-
|
919 |
-
def get_op(qualname):
|
920 |
-
def error_not_found():
|
921 |
-
raise ValueError(
|
922 |
-
f"Could not find the operator {qualname}. Please make sure you have "
|
923 |
-
f"already registered the operator and (if registered from C++) "
|
924 |
-
f"loaded it via torch.ops.load_library.")
|
925 |
-
|
926 |
-
ns, name = parse_qualname(qualname)
|
927 |
-
if not hasattr(torch.ops, ns):
|
928 |
-
error_not_found()
|
929 |
-
opnamespace = getattr(torch.ops, ns)
|
930 |
-
if not hasattr(opnamespace, name):
|
931 |
-
error_not_found()
|
932 |
-
packet = getattr(opnamespace, name)
|
933 |
-
if not hasattr(packet, 'default'):
|
934 |
-
error_not_found()
|
935 |
-
return packet.default
|
936 |
-
|
937 |
-
|
938 |
-
def _find_custom_op(qualname, also_check_torch_library=False):
|
939 |
-
if qualname in global_registry:
|
940 |
-
return global_registry[qualname]
|
941 |
-
if not also_check_torch_library:
|
942 |
-
raise RuntimeError(
|
943 |
-
f"Could not find custom op \"{qualname}\". Did you register it via "
|
944 |
-
f"the torch._custom_ops API?")
|
945 |
-
overload = get_op(qualname)
|
946 |
-
result = custom_op_from_existing(overload)
|
947 |
-
return result
|
948 |
-
|
949 |
-
|
950 |
-
def get_abstract_impl(qualname):
|
951 |
-
if qualname not in torch._custom_op.impl.global_registry:
|
952 |
-
return None
|
953 |
-
custom_op = torch._custom_op.impl.global_registry[qualname]
|
954 |
-
if custom_op is None:
|
955 |
-
return None
|
956 |
-
if not custom_op._has_impl("abstract"):
|
957 |
-
return None
|
958 |
-
return custom_op._get_impl("abstract").func
|
959 |
-
|
960 |
-
|
961 |
-
def _custom_op_with_schema(qualname, schema):
|
962 |
-
ns, name = qualname.split("::")
|
963 |
-
schema_str = f"{name}{schema}"
|
964 |
-
function_schema = FunctionSchema.parse(schema_str)
|
965 |
-
validate_schema(function_schema)
|
966 |
-
|
967 |
-
lib = library.Library(ns, "FRAGMENT")
|
968 |
-
lib.define(schema_str)
|
969 |
-
ophandle = find_ophandle_or_throw(ns, function_schema.name)
|
970 |
-
result = CustomOp(lib, ns, function_schema, name, ophandle, _private_access=True)
|
971 |
-
result._register_autograd_kernel_indirection()
|
972 |
-
|
973 |
-
torch._C._dispatch_set_report_error_callback(
|
974 |
-
ophandle, functools.partial(report_error_callback, weakref.proxy(result))
|
975 |
-
)
|
976 |
-
return get_op(qualname)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
torch/_custom_ops.py
DELETED
@@ -1,322 +0,0 @@
|
|
1 |
-
import inspect
|
2 |
-
|
3 |
-
from torch._custom_op.impl import (
|
4 |
-
_custom_op_with_schema,
|
5 |
-
_find_custom_op,
|
6 |
-
infer_schema,
|
7 |
-
parse_qualname,
|
8 |
-
validate_namespace,
|
9 |
-
)
|
10 |
-
from torch.library import get_ctx
|
11 |
-
|
12 |
-
__all__ = [
|
13 |
-
"custom_op",
|
14 |
-
"impl",
|
15 |
-
"impl_abstract",
|
16 |
-
"get_ctx",
|
17 |
-
"impl_save_for_backward",
|
18 |
-
"impl_backward",
|
19 |
-
]
|
20 |
-
|
21 |
-
|
22 |
-
def custom_op(qualname, func_or_schema=None):
|
23 |
-
r"""Register a new custom operator
|
24 |
-
|
25 |
-
In PyTorch, defining an op (short for "operator") is a two step-process:
|
26 |
-
- we need to define the op (by providing an operator name and schema)
|
27 |
-
- we need to implement behavior for how the operator interacts with
|
28 |
-
various PyTorch subsystems, like CPU/CUDA Tensors, Autograd, etc.
|
29 |
-
|
30 |
-
This entrypoint defines the custom operator (the first step)
|
31 |
-
you must then perform the second step by calling various
|
32 |
-
``impl_*`` APIs.
|
33 |
-
|
34 |
-
This API may be used as a decorator (see examples).
|
35 |
-
|
36 |
-
For a detailed guide on custom ops, please see
|
37 |
-
https://docs.google.com/document/d/1aGWtgxV3HppuxQAdddyPrs74_aEntpkYt9MalnCKnhk
|
38 |
-
|
39 |
-
Arguments:
|
40 |
-
qualname (str): Should be a string that looks like
|
41 |
-
"namespace::operator_name". Operators in PyTorch need a namespace to
|
42 |
-
avoid name collisions; a given operator may only be created once.
|
43 |
-
If you are writing a Python library, we recommend the namespace to
|
44 |
-
be the name of your top-level module.
|
45 |
-
func_or_schema (Union[Callable, str]): Each PyTorch operator needs a
|
46 |
-
schema that tells PyTorch the types of the inputs/outputs.
|
47 |
-
If this is a Callable, we will automatically infer the schema from
|
48 |
-
the type annotations on the function (see examples). Otherwise,
|
49 |
-
if you don't want to use type annotations, you may provide us the
|
50 |
-
schema string.
|
51 |
-
|
52 |
-
Example::
|
53 |
-
>>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA)
|
54 |
-
>>> import torch
|
55 |
-
>>> import numpy as np
|
56 |
-
>>> from torch import Tensor
|
57 |
-
>>>
|
58 |
-
>>> # Step 1: define the custom op.
|
59 |
-
>>> # We need to provide the API a "prototype function"
|
60 |
-
>>> # (a function that returns NotImplementedError), from which
|
61 |
-
>>> # we will infer the types of the inputs and outputs.
|
62 |
-
>>> @torch._custom_ops.custom_op("mylibrary::numpy_sin")
|
63 |
-
>>> def numpy_sin(x: Tensor) -> Tensor:
|
64 |
-
>>> raise NotImplementedError()
|
65 |
-
>>>
|
66 |
-
>>> # The custom op is now accessible via the torch.ops module:
|
67 |
-
>>> torch.ops.mylibrary.numpy_sin
|
68 |
-
>>>
|
69 |
-
>>> # Step 2: Register an implementation for various PyTorch subsystems
|
70 |
-
>>>
|
71 |
-
>>> # Register an implementation for CPU tensors
|
72 |
-
>>> @torch._custom_ops.impl("mylibrary::numpy_sin", device_types="cpu")
|
73 |
-
>>> def numpy_sin_impl_cpu(x):
|
74 |
-
>>> return torch.from_numpy(np.sin(x.numpy()))
|
75 |
-
>>>
|
76 |
-
>>> # Register an implementation for CUDA tensors
|
77 |
-
>>> @torch._custom_ops.impl("mylibrary::numpy_sin", device_types="cuda")
|
78 |
-
>>> def numpy_sin_impl_cuda(x):
|
79 |
-
>>> return torch.from_numpy(np.sin(x.cpu().numpy())).to(x.device)
|
80 |
-
>>>
|
81 |
-
>>> x = torch.randn(3)
|
82 |
-
>>> torch.ops.mylibrary.numpy_sin(x) # calls numpy_sin_impl_cpu
|
83 |
-
>>>
|
84 |
-
>>> x_cuda = x.cuda()
|
85 |
-
>>> torch.ops.mylibrary.numpy_sin(x) # calls numpy_sin_impl_cuda
|
86 |
-
|
87 |
-
"""
|
88 |
-
ns, name = parse_qualname(qualname)
|
89 |
-
validate_namespace(ns)
|
90 |
-
|
91 |
-
def inner(func):
|
92 |
-
if not inspect.isfunction(func):
|
93 |
-
raise ValueError(
|
94 |
-
f"custom_op(...)(func): Expected `func` to be a Python "
|
95 |
-
f"function, got: {type(func)}"
|
96 |
-
)
|
97 |
-
|
98 |
-
if func.__name__ != name:
|
99 |
-
raise ValueError(
|
100 |
-
f"custom_op(qualname='{qualname}', ...)(func): expected `func` "
|
101 |
-
f"to have name '{name}' but got '{func.__name__}'. "
|
102 |
-
f"Please either change the name of `func` or the qualname that "
|
103 |
-
f"is passed to `custom_op`"
|
104 |
-
)
|
105 |
-
|
106 |
-
schema = infer_schema(func)
|
107 |
-
_custom_op_with_schema(qualname, schema)
|
108 |
-
return func
|
109 |
-
|
110 |
-
if func_or_schema is None:
|
111 |
-
return inner
|
112 |
-
if isinstance(func_or_schema, str):
|
113 |
-
_custom_op_with_schema(qualname, func_or_schema)
|
114 |
-
else:
|
115 |
-
return inner(func_or_schema)
|
116 |
-
|
117 |
-
|
118 |
-
def impl(qualname, *, device_types=("cpu", "cuda"), func=None):
|
119 |
-
r"""Register an implementation for a device type for this custom op.
|
120 |
-
|
121 |
-
If the op is passed multiple Tensor inputs with different device
|
122 |
-
types, it will dispatch to the registered implementation for the highest
|
123 |
-
priority device type among those present.
|
124 |
-
The supported device types, in order of priority, are {'cuda', 'cpu'}.
|
125 |
-
|
126 |
-
This API may be used as a decorator (see examples).
|
127 |
-
|
128 |
-
For a detailed guide on custom ops, please see
|
129 |
-
https://docs.google.com/document/d/1aGWtgxV3HppuxQAdddyPrs74_aEntpkYt9MalnCKnhk
|
130 |
-
|
131 |
-
Arguments:
|
132 |
-
device_types (str or Iterable[str]): the device type(s) to register the function for.
|
133 |
-
|
134 |
-
Example::
|
135 |
-
>>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA)
|
136 |
-
>>> import torch
|
137 |
-
>>> import numpy as np
|
138 |
-
>>> from torch import Tensor
|
139 |
-
>>>
|
140 |
-
>>> # Step 1: define the custom op.
|
141 |
-
>>> # We need to provide the API a "prototype function"
|
142 |
-
>>> # (a function that returns NotImplementedError), from which
|
143 |
-
>>> # we will infer the types of the inputs and outputs.
|
144 |
-
>>> @torch._custom_ops.custom_op("mylibrary::numpy_cos")
|
145 |
-
>>> def numpy_cos(x: Tensor) -> Tensor:
|
146 |
-
>>> raise NotImplementedError()
|
147 |
-
>>>
|
148 |
-
>>> # The custom op is now accessible via the torch.ops module:
|
149 |
-
>>> torch.ops.mylibrary.numpy_cos
|
150 |
-
>>>
|
151 |
-
>>> # Step 2: Register an implementation for various PyTorch subsystems
|
152 |
-
>>>
|
153 |
-
>>> # Register an implementation for CPU tensors
|
154 |
-
>>> @torch._custom_ops.impl("mylibrary::numpy_cos", device_types="cpu")
|
155 |
-
>>> def numpy_cos_impl_cpu(x):
|
156 |
-
>>> return torch.from_numpy(np.cos(x.numpy()))
|
157 |
-
>>>
|
158 |
-
>>> # Register an implementation for CUDA tensors
|
159 |
-
>>> @torch._custom_ops.impl("mylibrary::numpy_cos", device_types="cuda")
|
160 |
-
>>> def numpy_cos_impl_cuda(x):
|
161 |
-
>>> return torch.from_numpy(np.cos(x.cpu().numpy())).to(x.device)
|
162 |
-
>>>
|
163 |
-
>>> x = torch.randn(3)
|
164 |
-
>>> torch.ops.mylibrary.numpy_cos(x) # calls numpy_cos_impl_cpu
|
165 |
-
>>>
|
166 |
-
>>> x_cuda = x.cuda()
|
167 |
-
>>> torch.ops.mylibrary.numpy_cos(x) # calls numpy_cos_impl_cuda
|
168 |
-
|
169 |
-
"""
|
170 |
-
|
171 |
-
def inner(func):
|
172 |
-
custom_op = _find_custom_op(qualname, also_check_torch_library=True)
|
173 |
-
custom_op.impl(device_types, _stacklevel=3)(func)
|
174 |
-
return func
|
175 |
-
|
176 |
-
if func is None:
|
177 |
-
return inner
|
178 |
-
return inner(func)
|
179 |
-
|
180 |
-
|
181 |
-
def impl_abstract(qualname, *, func=None):
|
182 |
-
r"""Register an abstract implementation for this operator.
|
183 |
-
|
184 |
-
An "abstract implementation" specifies the behavior of this operator on
|
185 |
-
Tensors that carry no data. Given some input Tensors with certain properties
|
186 |
-
(sizes/strides/storage_offset/device), it specifies what the properties of
|
187 |
-
the output Tensors are.
|
188 |
-
|
189 |
-
The abstract implementation has the same signature as the operator.
|
190 |
-
It is run for both FakeTensors and meta tensors. To write an abstract
|
191 |
-
implementation, assume that all Tensor inputs to the operator are
|
192 |
-
regular CPU/CUDA/Meta tensors, but they do not have storage, and
|
193 |
-
you are trying to return regular CPU/CUDA/Meta tensor(s) as output.
|
194 |
-
The abstract implementation must consist of only PyTorch operations
|
195 |
-
(and may not directly access the storage or data of any input or
|
196 |
-
intermediate Tensors).
|
197 |
-
|
198 |
-
This API may be used as a decorator (see examples).
|
199 |
-
|
200 |
-
For a detailed guide on custom ops, please see
|
201 |
-
https://docs.google.com/document/d/1aGWtgxV3HppuxQAdddyPrs74_aEntpkYt9MalnCKnhk
|
202 |
-
|
203 |
-
Examples::
|
204 |
-
>>> import numpy as np
|
205 |
-
>>> from torch import Tensor
|
206 |
-
>>>
|
207 |
-
>>> # Example 1: an operator without data-dependent output shape
|
208 |
-
>>> @torch._custom_ops.custom_op("mylibrary::custom_linear")
|
209 |
-
>>> def custom_linear(x: Tensor, weight: Tensor, bias: Tensor) -> Tensor:
|
210 |
-
>>> raise NotImplementedError()
|
211 |
-
>>>
|
212 |
-
>>> @torch._custom_ops.impl_abstract("mylibrary::custom_linear")
|
213 |
-
>>> def custom_linear_abstract(x, weight):
|
214 |
-
>>> assert x.dim() == 2
|
215 |
-
>>> assert weight.dim() == 2
|
216 |
-
>>> assert bias.dim() == 1
|
217 |
-
>>> assert x.shape[1] == weight.shape[1]
|
218 |
-
>>> assert weight.shape[0] == bias.shape[0]
|
219 |
-
>>> assert x.device == weight.device
|
220 |
-
>>>
|
221 |
-
>>> return (x @ weight.t()) + bias
|
222 |
-
>>>
|
223 |
-
>>> # Example 2: an operator with data-dependent output shape
|
224 |
-
>>> @torch._custom_ops.custom_op('mylibrary::custom_nonzero')
|
225 |
-
>>> def custom_nonzero(x: Tensor) -> Tensor:
|
226 |
-
>>> ...
|
227 |
-
>>>
|
228 |
-
>>> @torch._custom_ops.impl_abstract("mylibrary::custom_nonzero")
|
229 |
-
>>> def custom_nonzero_abstract(x):
|
230 |
-
>>> # Number of nonzero-elements is data-dependent.
|
231 |
-
>>> # Since we cannot peek at the data in an abstract impl,
|
232 |
-
>>> # we use the ctx object to construct a new symint that
|
233 |
-
>>> # represents the data-dependent size.
|
234 |
-
>>> ctx = torch._custom_ops.get_ctx()
|
235 |
-
>>> nnz = ctx.create_unbacked_symint()
|
236 |
-
>>> shape = [x.dim(), nnz]
|
237 |
-
>>> result = x.new_empty(shape, dtype=torch.long)
|
238 |
-
>>> return result
|
239 |
-
>>>
|
240 |
-
>>> @torch._custom_ops.impl("mylibrary::custom_nonzero")
|
241 |
-
>>> def custom_nonzero_impl(x):
|
242 |
-
>>> x_np = to_numpy(x)
|
243 |
-
>>> res = np.stack(np.nonzero(x_np), axis=1)
|
244 |
-
>>> # unbacked symbolic ints in PyTorch must be >= 2, so we
|
245 |
-
>>> # constrain the range to at least 2
|
246 |
-
>>> if res.shape[0] <= 1:
|
247 |
-
>>> raise RuntimeError("not supported")
|
248 |
-
>>> return torch.tensor(res, device=x.device)
|
249 |
-
|
250 |
-
"""
|
251 |
-
import torch.library
|
252 |
-
|
253 |
-
return torch.library.impl_abstract(qualname, func, _stacklevel=2)
|
254 |
-
|
255 |
-
|
256 |
-
def impl_save_for_backward(qualname, *, func=None):
|
257 |
-
r"""Register a function that tells us what to save for backward.
|
258 |
-
|
259 |
-
Please see :func:`impl_backward` for more details.
|
260 |
-
"""
|
261 |
-
|
262 |
-
def inner(func):
|
263 |
-
custom_op = _find_custom_op(qualname, also_check_torch_library=True)
|
264 |
-
custom_op.impl_save_for_backward(_stacklevel=3)(func)
|
265 |
-
return func
|
266 |
-
|
267 |
-
if func is None:
|
268 |
-
return inner
|
269 |
-
return inner(func)
|
270 |
-
|
271 |
-
|
272 |
-
def impl_backward(qualname, output_differentiability=None, *, func=None):
|
273 |
-
r"""Registers a backward formula for an operator.
|
274 |
-
|
275 |
-
In order for an operator to work with autograd, you need to register
|
276 |
-
a backward formula. There are two pieces to this:
|
277 |
-
1. You must give us a function to specify what to save for backward.
|
278 |
-
Call this the "save for backward" function.
|
279 |
-
2. You must give us a function that computes gradients. Call this the
|
280 |
-
"backward" function.
|
281 |
-
|
282 |
-
Use `impl_save_for_backward` to define a "save for backward" function
|
283 |
-
that specifies what gets saved for backward. The function should accept
|
284 |
-
two arguments ``(inputs, output)`` and return the quantities to be saved
|
285 |
-
for backward.
|
286 |
-
|
287 |
-
During runtime, when you call the operator in a forwards pass, PyTorch
|
288 |
-
will invoke the "save for backward" function with the inputs and output
|
289 |
-
of the operator.
|
290 |
-
|
291 |
-
Use `impl_backward` to define the "backward" function. The backward
|
292 |
-
function must accept ``(ctx, saved, *grads)``:
|
293 |
-
- ``ctx`` is a context object where we may provide information
|
294 |
-
- ``saved`` is exactly what gets returned from the "save for backward"
|
295 |
-
function
|
296 |
-
- ``grads`` is one or more gradients. The number of gradients matches
|
297 |
-
the number of outputs of the operator.
|
298 |
-
|
299 |
-
The backward function must return a dict that maps the name of
|
300 |
-
an input to the operator to its corresponding gradient. All inputs that
|
301 |
-
were declared to be Tensors in the operator definition must be accounted
|
302 |
-
for in the dict. The gradient may be a Tensor or None.
|
303 |
-
|
304 |
-
For a detailed guide on custom ops, please see
|
305 |
-
https://docs.google.com/document/d/1aGWtgxV3HppuxQAdddyPrs74_aEntpkYt9MalnCKnhk
|
306 |
-
|
307 |
-
"""
|
308 |
-
|
309 |
-
def inner(func):
|
310 |
-
custom_op = _find_custom_op(qualname, also_check_torch_library=True)
|
311 |
-
custom_op.impl_backward(output_differentiability, _stacklevel=3)(func)
|
312 |
-
return func
|
313 |
-
|
314 |
-
if func is None:
|
315 |
-
return inner
|
316 |
-
return inner(func)
|
317 |
-
|
318 |
-
|
319 |
-
def _destroy(qualname):
|
320 |
-
"""De-registers a custom op. For testing purposes only"""
|
321 |
-
custom_op = _find_custom_op(qualname)
|
322 |
-
custom_op._destroy()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
torch/_decomp/__init__.py
DELETED
@@ -1,444 +0,0 @@
|
|
1 |
-
import inspect
|
2 |
-
from collections import defaultdict
|
3 |
-
from functools import wraps
|
4 |
-
from itertools import chain
|
5 |
-
from typing import Callable, Dict, List, Sequence, Union
|
6 |
-
|
7 |
-
import torch
|
8 |
-
import torch.library
|
9 |
-
from torch._ops import HigherOrderOperator, OpOverload, OpOverloadPacket
|
10 |
-
from torch._prims_common import CustomOutParamAnnotation
|
11 |
-
from torch.utils import _pytree as pytree
|
12 |
-
|
13 |
-
__all__ = [
|
14 |
-
"decomposition_table",
|
15 |
-
"pre_autograd_decomposition_table",
|
16 |
-
"meta_table",
|
17 |
-
"register_decomposition",
|
18 |
-
"get_decompositions",
|
19 |
-
"core_aten_decompositions",
|
20 |
-
]
|
21 |
-
|
22 |
-
|
23 |
-
# TODO: relax key type here; torch registrations should be possible to; but
|
24 |
-
# right now this type is accurate
|
25 |
-
global_decomposition_table: Dict[
|
26 |
-
str, Dict[torch._ops.OperatorBase, Callable]
|
27 |
-
] = defaultdict(dict)
|
28 |
-
|
29 |
-
decomposition_table = global_decomposition_table["post_autograd"]
|
30 |
-
pre_autograd_decomposition_table = global_decomposition_table["pre_autograd"]
|
31 |
-
meta_table = global_decomposition_table["meta"]
|
32 |
-
|
33 |
-
|
34 |
-
def _add_op_to_registry(registry, op, fn):
|
35 |
-
"""
|
36 |
-
This is an internal API for adding an op to the decomposition table.
|
37 |
-
|
38 |
-
If op is OpOverload, it will be added to the registry directly.
|
39 |
-
If op is OpOverloadPacket, all the valid op_overloads in the packet will be added to the registry.
|
40 |
-
"""
|
41 |
-
overloads: List[Union[torch._ops.OperatorBase]] = []
|
42 |
-
if isinstance(op, HigherOrderOperator):
|
43 |
-
# There's no concept of overloads for HigherOrderOperator
|
44 |
-
registry[op] = fn
|
45 |
-
return
|
46 |
-
elif isinstance(op, OpOverload):
|
47 |
-
overloads.append(op)
|
48 |
-
else:
|
49 |
-
assert isinstance(op, OpOverloadPacket)
|
50 |
-
for ol in op.overloads():
|
51 |
-
overloads.append(getattr(op, ol))
|
52 |
-
|
53 |
-
for op_overload in overloads:
|
54 |
-
if op_overload in registry:
|
55 |
-
raise RuntimeError(f"duplicate registrations for {op_overload}")
|
56 |
-
# TorchScript dumps a bunch of extra nonsense overloads
|
57 |
-
# which don't have corresponding dispatcher entries, we need
|
58 |
-
# to filter those out, e.g aten.add.float_int
|
59 |
-
if torch._C._dispatch_has_kernel(op_overload.name()):
|
60 |
-
registry[op_overload] = fn
|
61 |
-
|
62 |
-
|
63 |
-
def _convert_out_params(f):
|
64 |
-
out_annotation = f.__annotations__.get("out")
|
65 |
-
|
66 |
-
# If there are no out params, do not wrap the function.
|
67 |
-
if not out_annotation:
|
68 |
-
return f
|
69 |
-
|
70 |
-
# Hack to detect when out is a Tuple. There seems to be no pretty way of doing this
|
71 |
-
if getattr(out_annotation, "__origin__", None) is tuple:
|
72 |
-
sig = inspect.signature(f)
|
73 |
-
out_names = sig.return_annotation._fields
|
74 |
-
# If out is a tuple, we need to register a function that unpacks all the out
|
75 |
-
# elements as this is what native_functions.yaml expects
|
76 |
-
|
77 |
-
@wraps(f)
|
78 |
-
def _fn(*args, **kwargs):
|
79 |
-
out_kwargs = tuple(kwargs.pop(o, None) for o in out_names)
|
80 |
-
# Either all of the out kwargs are set or none of them
|
81 |
-
is_none = out_kwargs[0] is None
|
82 |
-
assert all((o is None) == is_none for o in out_kwargs)
|
83 |
-
return f(*args, **kwargs, out=None if is_none else out_kwargs)
|
84 |
-
|
85 |
-
out_params = [
|
86 |
-
inspect.Parameter(
|
87 |
-
o,
|
88 |
-
kind=inspect.Parameter.KEYWORD_ONLY,
|
89 |
-
default=None,
|
90 |
-
annotation=t,
|
91 |
-
)
|
92 |
-
for o, t in zip(out_names, out_annotation.__args__)
|
93 |
-
]
|
94 |
-
# Drop the out parameter and concatenate the new kwargs in the signature
|
95 |
-
params = chain((v for k, v in sig.parameters.items() if k != "out"), out_params)
|
96 |
-
_fn.__signature__ = inspect.Signature( # type: ignore[attr-defined]
|
97 |
-
parameters=params, return_annotation=sig.return_annotation # type: ignore[arg-type]
|
98 |
-
)
|
99 |
-
# Drop the out parameter and concatenate the new kwargs in the annotations
|
100 |
-
_fn.__annotations__ = {k: v for k, v in f.__annotations__.items() if k != "out"}
|
101 |
-
for o in out_params:
|
102 |
-
_fn.__annotations__[o.name] = o.annotation
|
103 |
-
|
104 |
-
# Propagate that this function is wrapped by `out_wrapper`
|
105 |
-
_fn._torch_decompositions_out_wrapper = f._torch_decompositions_out_wrapper # type: ignore[attr-defined]
|
106 |
-
|
107 |
-
return _fn
|
108 |
-
|
109 |
-
# Alternatively, there may be a single tensor out parameter with a name
|
110 |
-
# other than "out". This will need special treatment and is indicated by an
|
111 |
-
# annotation, which we will remove here so it is not exposed after wrapping.
|
112 |
-
custom_out_param_name = f.__annotations__.pop(CustomOutParamAnnotation, None)
|
113 |
-
if custom_out_param_name:
|
114 |
-
|
115 |
-
@wraps(f)
|
116 |
-
def _fn(*args, **kwargs):
|
117 |
-
out_kwarg = kwargs.pop(custom_out_param_name, None)
|
118 |
-
return f(*args, **kwargs, out=out_kwarg)
|
119 |
-
|
120 |
-
out_param = inspect.Parameter(
|
121 |
-
custom_out_param_name,
|
122 |
-
kind=inspect.Parameter.KEYWORD_ONLY,
|
123 |
-
default=None,
|
124 |
-
annotation=out_annotation,
|
125 |
-
)
|
126 |
-
|
127 |
-
# Drop the out parameter and concatenate the new kwarg in the signature
|
128 |
-
sig = inspect.signature(f)
|
129 |
-
params = chain(
|
130 |
-
(v for k, v in sig.parameters.items() if k != "out"), (out_param,)
|
131 |
-
)
|
132 |
-
_fn.__signature__ = inspect.Signature( # type: ignore[attr-defined]
|
133 |
-
parameters=params, return_annotation=sig.return_annotation # type: ignore[arg-type]
|
134 |
-
)
|
135 |
-
|
136 |
-
# Drop the out parameter and concatenate the new kwargs in the annotations
|
137 |
-
_fn.__annotations__ = {k: v for k, v in f.__annotations__.items() if k != "out"}
|
138 |
-
_fn.__annotations__[out_param.name] = out_param.annotation
|
139 |
-
|
140 |
-
return _fn
|
141 |
-
|
142 |
-
return f
|
143 |
-
|
144 |
-
|
145 |
-
def register_decomposition(
|
146 |
-
aten_op, registry=None, *, type="post_autograd", unsafe=False
|
147 |
-
):
|
148 |
-
"""
|
149 |
-
A decorator to register a function as a decomposition to the Python
|
150 |
-
decomposition table. Use it like this::
|
151 |
-
|
152 |
-
@register_decomposition(torch.ops.aten.clamp_min)
|
153 |
-
def clamp_min(x):
|
154 |
-
return torch.clamp(self, min=min)
|
155 |
-
|
156 |
-
If you are writing a new decomposition, consider contributing it
|
157 |
-
directly to PyTorch in torch._decomp.decompositions.
|
158 |
-
|
159 |
-
This API is experimental; we are almost certainly going to extend
|
160 |
-
the API when we make decompositions eligible for use in transforms (e.g.,
|
161 |
-
autograd) and not just backend tracing, where we then need to know if a
|
162 |
-
decomposition can be used to simulate a transform.
|
163 |
-
|
164 |
-
By default, we also will register it to the Meta key of dispatcher,
|
165 |
-
and replace the c++ Meta implementation if there is already one.
|
166 |
-
|
167 |
-
unsafe kwarg is for reuse of this function for registering non-function
|
168 |
-
things
|
169 |
-
"""
|
170 |
-
|
171 |
-
assert type in {"post_autograd", "pre_autograd", "meta"}
|
172 |
-
|
173 |
-
def decomposition_decorator(fn: Callable) -> Callable:
|
174 |
-
if not unsafe:
|
175 |
-
fn = _convert_out_params(fn)
|
176 |
-
|
177 |
-
nonlocal registry
|
178 |
-
if registry is None:
|
179 |
-
registry = global_decomposition_table[type]
|
180 |
-
|
181 |
-
def register(op):
|
182 |
-
_add_op_to_registry(registry, op, fn)
|
183 |
-
|
184 |
-
# To handle allowing multiple aten_ops at once
|
185 |
-
pytree.tree_map_(register, aten_op)
|
186 |
-
return fn
|
187 |
-
|
188 |
-
return decomposition_decorator
|
189 |
-
|
190 |
-
|
191 |
-
def get_decompositions(
|
192 |
-
aten_ops: Sequence[Union[torch._ops.OperatorBase, OpOverloadPacket]],
|
193 |
-
type: str = "post_autograd",
|
194 |
-
) -> Dict[torch._ops.OperatorBase, Callable]:
|
195 |
-
"""
|
196 |
-
Retrieve a dictionary of decompositions corresponding to the list of
|
197 |
-
operator overloads and overload packets passed as input. Overload
|
198 |
-
packets will include all decomposed overloads in the packet. If there is
|
199 |
-
no decomposition for a requested operator, it is silently ignored.
|
200 |
-
|
201 |
-
This API is experimental; we are almost certainly going to give an alternate,
|
202 |
-
more recommended formulation, where a user provides the set of operators
|
203 |
-
they know how to implement, and we provide decompositions for everything
|
204 |
-
not in this set.
|
205 |
-
"""
|
206 |
-
assert type in {"post_autograd", "pre_autograd", "meta"}
|
207 |
-
|
208 |
-
registry = global_decomposition_table[type]
|
209 |
-
packets_to_overloads = defaultdict(list)
|
210 |
-
for opo in registry:
|
211 |
-
if isinstance(opo, (OpOverload, OpOverloadPacket)):
|
212 |
-
packets_to_overloads[opo.overloadpacket].append(opo)
|
213 |
-
decompositions: Dict[torch._ops.OperatorBase, Callable] = {}
|
214 |
-
for op in aten_ops:
|
215 |
-
if isinstance(op, OpOverloadPacket) and op in packets_to_overloads:
|
216 |
-
for op_overload in packets_to_overloads[op]:
|
217 |
-
decompositions[op_overload] = registry[op_overload]
|
218 |
-
elif isinstance(op, (torch._ops.OperatorBase)) and op in registry:
|
219 |
-
decompositions[op] = registry[op]
|
220 |
-
return decompositions
|
221 |
-
|
222 |
-
|
223 |
-
def remove_decompositions(
|
224 |
-
decompositions: Dict[torch._ops.OperatorBase, Callable],
|
225 |
-
aten_ops: Sequence[Union[OpOverload, OpOverloadPacket]],
|
226 |
-
) -> None:
|
227 |
-
"""
|
228 |
-
Given a dictionary of decompositions obtained from get_decompositions(), removes
|
229 |
-
operators associated with a list of operator overloads and overload packets passed
|
230 |
-
as input. If the decomposition dictionary does not contain a decomposition that is
|
231 |
-
specified to be removed, it is silently ignored.
|
232 |
-
"""
|
233 |
-
for op in aten_ops:
|
234 |
-
if isinstance(op, OpOverloadPacket):
|
235 |
-
for overload_name in op.overloads():
|
236 |
-
opo = getattr(op, overload_name)
|
237 |
-
decompositions.pop(opo, None)
|
238 |
-
elif isinstance(op, OpOverload):
|
239 |
-
decompositions.pop(op, None)
|
240 |
-
|
241 |
-
|
242 |
-
# populate the table
|
243 |
-
import torch._decomp.decompositions
|
244 |
-
import torch._refs
|
245 |
-
|
246 |
-
|
247 |
-
# See NOTE [Core ATen Ops]
|
248 |
-
#
|
249 |
-
# list was copied from torch/_inductor/decomposition.py
|
250 |
-
# excluding decompositions that results in prim ops
|
251 |
-
# Resulting opset of decomposition is core aten ops
|
252 |
-
def core_aten_decompositions() -> Dict[torch._ops.OperatorBase, Callable]:
|
253 |
-
aten = torch.ops.aten
|
254 |
-
return get_decompositions(
|
255 |
-
[
|
256 |
-
aten.addcdiv,
|
257 |
-
aten.addcdiv_,
|
258 |
-
aten.addcmul,
|
259 |
-
aten.addcmul_,
|
260 |
-
aten.addr,
|
261 |
-
aten.affine_grid_generator,
|
262 |
-
aten.all,
|
263 |
-
aten.aminmax,
|
264 |
-
aten.arange.default,
|
265 |
-
aten.arange.start,
|
266 |
-
aten.avg_pool2d_backward,
|
267 |
-
aten.baddbmm,
|
268 |
-
aten.binary_cross_entropy,
|
269 |
-
aten.binary_cross_entropy_backward,
|
270 |
-
aten.binary_cross_entropy_with_logits,
|
271 |
-
aten.celu,
|
272 |
-
aten.celu_,
|
273 |
-
aten.clamp_max,
|
274 |
-
aten.clamp_min,
|
275 |
-
aten.col2im,
|
276 |
-
aten.count_nonzero,
|
277 |
-
aten.cudnn_batch_norm,
|
278 |
-
aten.cudnn_batch_norm_backward,
|
279 |
-
aten.deg2rad,
|
280 |
-
aten.deg2rad_,
|
281 |
-
aten.detach,
|
282 |
-
aten.diag_embed,
|
283 |
-
aten.diagonal_backward,
|
284 |
-
aten.dot,
|
285 |
-
aten.vdot,
|
286 |
-
aten.elu,
|
287 |
-
aten.elu_,
|
288 |
-
aten.elu_backward,
|
289 |
-
aten._embedding_bag,
|
290 |
-
aten.embedding_dense_backward,
|
291 |
-
aten.empty_like,
|
292 |
-
aten._euclidean_dist.default,
|
293 |
-
aten.expand_as,
|
294 |
-
aten.eye,
|
295 |
-
aten.fill,
|
296 |
-
aten.fill_,
|
297 |
-
aten.floor_divide,
|
298 |
-
aten.frac,
|
299 |
-
aten.frac_,
|
300 |
-
aten._fused_moving_avg_obs_fq_helper,
|
301 |
-
aten.gelu_,
|
302 |
-
aten.gelu_backward,
|
303 |
-
aten.glu,
|
304 |
-
aten.glu_backward,
|
305 |
-
aten.hardshrink,
|
306 |
-
aten.hardsigmoid,
|
307 |
-
aten.hardsigmoid_,
|
308 |
-
aten.hardsigmoid_backward,
|
309 |
-
aten.hardswish,
|
310 |
-
aten.hardswish_,
|
311 |
-
aten.hardswish_backward,
|
312 |
-
aten.hardtanh_,
|
313 |
-
aten.hardtanh_backward,
|
314 |
-
aten.heaviside,
|
315 |
-
aten.heaviside_,
|
316 |
-
aten.huber_loss,
|
317 |
-
aten.huber_loss_backward,
|
318 |
-
aten.im2col,
|
319 |
-
aten.index_add,
|
320 |
-
aten.index_add_,
|
321 |
-
aten.index_copy,
|
322 |
-
aten.index_copy_,
|
323 |
-
aten.index_fill,
|
324 |
-
aten.index_fill_,
|
325 |
-
aten.isneginf,
|
326 |
-
aten.isposinf,
|
327 |
-
aten.l1_loss,
|
328 |
-
aten.leaky_relu_,
|
329 |
-
aten.leaky_relu_backward,
|
330 |
-
aten.lerp,
|
331 |
-
aten.lerp_,
|
332 |
-
aten.linspace,
|
333 |
-
aten.logaddexp,
|
334 |
-
aten.logaddexp2,
|
335 |
-
aten.logit,
|
336 |
-
aten.logit_,
|
337 |
-
aten.logit_backward,
|
338 |
-
aten.log_sigmoid_backward,
|
339 |
-
aten.log_sigmoid_forward,
|
340 |
-
aten._log_softmax_backward_data,
|
341 |
-
aten.logspace,
|
342 |
-
aten.logsumexp.default,
|
343 |
-
aten.masked_fill,
|
344 |
-
aten.masked_fill_,
|
345 |
-
aten.mish,
|
346 |
-
aten.mish_,
|
347 |
-
aten.mse_loss,
|
348 |
-
aten.mse_loss_backward,
|
349 |
-
aten.multi_margin_loss,
|
350 |
-
aten.multilabel_margin_loss_forward,
|
351 |
-
aten.mv,
|
352 |
-
aten.mvlgamma,
|
353 |
-
aten.mvlgamma_,
|
354 |
-
aten.nansum,
|
355 |
-
aten.nan_to_num,
|
356 |
-
aten.nan_to_num_,
|
357 |
-
aten.narrow,
|
358 |
-
aten.native_batch_norm_backward,
|
359 |
-
aten.native_dropout_backward,
|
360 |
-
aten.native_group_norm_backward,
|
361 |
-
aten.native_layer_norm_backward,
|
362 |
-
aten.new_empty,
|
363 |
-
aten.new_full,
|
364 |
-
aten.new_ones,
|
365 |
-
aten.new_zeros,
|
366 |
-
aten.nll_loss_backward,
|
367 |
-
aten.nll_loss_forward,
|
368 |
-
aten.norm,
|
369 |
-
aten.ones,
|
370 |
-
aten.ones_like,
|
371 |
-
aten._prelu_kernel,
|
372 |
-
aten._prelu_kernel_backward,
|
373 |
-
aten._reshape_alias,
|
374 |
-
aten.rad2deg,
|
375 |
-
aten.rad2deg_,
|
376 |
-
aten.renorm,
|
377 |
-
aten.renorm_,
|
378 |
-
aten.replication_pad2d,
|
379 |
-
aten.rot90,
|
380 |
-
aten.rrelu_with_noise,
|
381 |
-
aten.rrelu_with_noise_,
|
382 |
-
aten.rsub.Scalar,
|
383 |
-
aten.rsub.Tensor,
|
384 |
-
aten._scaled_dot_product_flash_attention.default,
|
385 |
-
aten.select_backward,
|
386 |
-
aten.select_scatter,
|
387 |
-
aten.sgn,
|
388 |
-
aten.sgn_,
|
389 |
-
aten.sigmoid_backward,
|
390 |
-
aten.silu,
|
391 |
-
aten.silu_,
|
392 |
-
aten.silu_backward,
|
393 |
-
aten.sinc,
|
394 |
-
aten.sinc_,
|
395 |
-
aten.slice_backward,
|
396 |
-
aten.smooth_l1_loss,
|
397 |
-
aten.smooth_l1_loss_backward,
|
398 |
-
aten.soft_margin_loss,
|
399 |
-
aten.soft_margin_loss_backward,
|
400 |
-
aten._softmax_backward_data,
|
401 |
-
aten.softplus,
|
402 |
-
aten.softplus_backward,
|
403 |
-
aten.softshrink,
|
404 |
-
aten.special_entr,
|
405 |
-
aten.special_log_ndtr,
|
406 |
-
aten.special_xlog1py,
|
407 |
-
aten.split.Tensor,
|
408 |
-
aten.squeeze.default,
|
409 |
-
aten.squeeze.dim,
|
410 |
-
aten.std,
|
411 |
-
aten.std_mean,
|
412 |
-
aten.stack,
|
413 |
-
aten.sum.default,
|
414 |
-
aten.sum.out,
|
415 |
-
aten.t,
|
416 |
-
aten.tanh_backward,
|
417 |
-
aten.threshold,
|
418 |
-
aten.threshold_,
|
419 |
-
aten.threshold_backward,
|
420 |
-
aten.trace,
|
421 |
-
aten.transpose.int,
|
422 |
-
aten.tril,
|
423 |
-
aten.tril_,
|
424 |
-
aten.triu,
|
425 |
-
aten.triu_,
|
426 |
-
aten.unbind,
|
427 |
-
aten.unfold_backward,
|
428 |
-
aten.unfold_copy,
|
429 |
-
aten._unsafe_index,
|
430 |
-
aten.unsafe_split.Tensor,
|
431 |
-
aten.unsafe_split_with_sizes,
|
432 |
-
aten._unsafe_view,
|
433 |
-
aten.upsample_bilinear2d,
|
434 |
-
aten.upsample_nearest2d_backward,
|
435 |
-
aten.view_as_complex,
|
436 |
-
aten.xlogy,
|
437 |
-
aten.xlogy_,
|
438 |
-
aten.zero,
|
439 |
-
aten.zero_,
|
440 |
-
aten.zeros,
|
441 |
-
aten.zeros_like,
|
442 |
-
aten._weight_norm_interface,
|
443 |
-
]
|
444 |
-
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
torch/_decomp/__pycache__/__init__.cpython-310.pyc
DELETED
Binary file (12.2 kB)
|
|
torch/_decomp/__pycache__/decompositions.cpython-310.pyc
DELETED
Binary file (102 kB)
|
|
torch/_decomp/__pycache__/decompositions_for_jvp.cpython-310.pyc
DELETED
Binary file (6.27 kB)
|
|
torch/_decomp/__pycache__/decompositions_for_rng.cpython-310.pyc
DELETED
Binary file (7.99 kB)
|
|
torch/_decomp/decompositions.py
DELETED
The diff for this file is too large to render.
See raw diff
|
|
torch/_decomp/decompositions_for_jvp.py
DELETED
@@ -1,302 +0,0 @@
|
|
1 |
-
import inspect
|
2 |
-
from typing import Callable, Dict, List, Optional, Tuple
|
3 |
-
|
4 |
-
import torch
|
5 |
-
import torch._decomp
|
6 |
-
from torch import Tensor
|
7 |
-
from torch._prims_common.wrappers import _maybe_remove_out_wrapper
|
8 |
-
|
9 |
-
decomposition_table = torch._decomp.decomposition_table
|
10 |
-
decomposition_table_for_jvp: Dict[torch._ops.OperatorBase, Callable] = {}
|
11 |
-
register_decomposition = torch._decomp.register_decomposition
|
12 |
-
aten = torch.ops.aten
|
13 |
-
|
14 |
-
# NOTE: [forward-mode AD decompositions mechanism]
|
15 |
-
#
|
16 |
-
# The mechanism is in VariableType,
|
17 |
-
# IF any inputs have forward grad
|
18 |
-
# AND there is no forward AD formula implemented
|
19 |
-
# AND the functions is actually differentiable
|
20 |
-
# run the decomposition
|
21 |
-
# See run_jit_decomposition_with_args_for_jvp
|
22 |
-
# We currently use python decompositions that we torchscript.
|
23 |
-
#
|
24 |
-
# Note that we would be building the backward graph at the decomposed level
|
25 |
-
# too, but that is OK, because we would've errored out otherwise anyway.
|
26 |
-
#
|
27 |
-
# TODO: The mechanism we are using to register decompositions doesn't
|
28 |
-
# seem to be exclusively used for jvp. So open question here is whether
|
29 |
-
# torch/csrc/jit/runtime/decomposition_registry.cpp is being used for other things.
|
30 |
-
# If that is the case, we may go down the decomposition path unexpectedly
|
31 |
-
# (and possibly produce an unintelligible error) vs erroring out earlier and
|
32 |
-
# printing that the forward AD formula is not implemented.
|
33 |
-
#
|
34 |
-
# The solution to this may be to have a explicitly white list control when
|
35 |
-
# to enable the decomposition.
|
36 |
-
|
37 |
-
|
38 |
-
def maybe_register_decomposition(op):
|
39 |
-
def decorator(f):
|
40 |
-
try:
|
41 |
-
return register_decomposition(op)(f)
|
42 |
-
except Exception:
|
43 |
-
return f
|
44 |
-
|
45 |
-
return decorator
|
46 |
-
|
47 |
-
|
48 |
-
# Functions where we need a special decomposition for jvp but there's another version that
|
49 |
-
# should be used more generally (ex. for jvp we need to recompute the mean and variance for
|
50 |
-
# the backwards of a normalization function. Without jvp, it should use the saved value)
|
51 |
-
decomposition_table_for_jvp = {}
|
52 |
-
|
53 |
-
|
54 |
-
def register_decomposition_for_jvp(fn):
|
55 |
-
return register_decomposition(fn, registry=decomposition_table_for_jvp)
|
56 |
-
|
57 |
-
|
58 |
-
def _register_jit_decomposition_for_jvp(decomp, use_python=False):
|
59 |
-
if decomp in decomposition_table_for_jvp:
|
60 |
-
decomposition_table_used = decomposition_table_for_jvp
|
61 |
-
elif decomp in decomposition_table:
|
62 |
-
decomposition_table_used = decomposition_table
|
63 |
-
else:
|
64 |
-
raise RuntimeError(f"could not find decomposition for {decomp}")
|
65 |
-
decomp_fn = decomposition_table_used[decomp]
|
66 |
-
|
67 |
-
# `out_wrapper` extends a decompositions signature with
|
68 |
-
# an `out` parameter. However jit will use the unwrapped function's
|
69 |
-
# signature instead so we need to unwrap here to prevent an error
|
70 |
-
decomp_fn = _maybe_remove_out_wrapper(decomp_fn)
|
71 |
-
|
72 |
-
if use_python:
|
73 |
-
decomp_fn = torch.jit.ignore(decomp_fn)
|
74 |
-
sig = inspect.signature(decomp_fn)
|
75 |
-
|
76 |
-
# Create a string wrapping the function from the signature
|
77 |
-
# example output:
|
78 |
-
# def wrapped_decomp(x: torch.Tensor, y: int, z: int):
|
79 |
-
# return decomp_fn(x, y, z)
|
80 |
-
# Thanks copilot!
|
81 |
-
def get_function_def(sig):
|
82 |
-
param_def = [f"{param_str}" for param_str in sig.parameters.values()]
|
83 |
-
param_use = [f"{param_str}" for param_str in sig.parameters.keys()]
|
84 |
-
|
85 |
-
return f"def wrapped_decomp({', '.join(param_def)}):\n return decomp_fn({', '.join(param_use)})\n"
|
86 |
-
|
87 |
-
f_str = get_function_def(sig)
|
88 |
-
graph = torch.jit.CompilationUnit(f_str).wrapped_decomp.graph
|
89 |
-
else:
|
90 |
-
graph = torch.jit.script(decomp_fn).graph
|
91 |
-
torch.jit._register_decomposition(decomp, graph)
|
92 |
-
|
93 |
-
|
94 |
-
# The only decompositions here are temporary or hacks for the purposes of jvp
|
95 |
-
|
96 |
-
|
97 |
-
# TODO: do these also belong here?
|
98 |
-
@maybe_register_decomposition(aten.trace.default)
|
99 |
-
def trace(self: Tensor) -> Tensor:
|
100 |
-
return torch.sum(torch.diag(self))
|
101 |
-
|
102 |
-
|
103 |
-
@maybe_register_decomposition(aten.log_sigmoid_forward.default)
|
104 |
-
def log_sigmoid_forward(self: Tensor) -> Tuple[Tensor, Tensor]:
|
105 |
-
min = torch.minimum(self.new_zeros(()), self)
|
106 |
-
z = torch.exp(-torch.abs(self))
|
107 |
-
if self.is_cuda:
|
108 |
-
buffer = self.new_zeros((0,))
|
109 |
-
else:
|
110 |
-
buffer = z
|
111 |
-
return min - torch.log1p(z), buffer
|
112 |
-
|
113 |
-
|
114 |
-
def recompute_mean_var(
|
115 |
-
input: Tensor, rstd: Tensor, inner_dim_indices: List[int], keepdim: bool
|
116 |
-
):
|
117 |
-
# for most norm decompositions, it will be the same as the core version except for here.
|
118 |
-
# We recompute the mean and variance so that they track gradients through input
|
119 |
-
|
120 |
-
mean = torch.mean(input, dim=inner_dim_indices, keepdim=keepdim)
|
121 |
-
var = torch.var(input, dim=inner_dim_indices, unbiased=False, keepdim=keepdim)
|
122 |
-
eps = torch.pow(1 / rstd, 2) - var # this makes me so sad inside
|
123 |
-
eps = eps.detach()
|
124 |
-
rstd = 1 / torch.sqrt(var + eps)
|
125 |
-
return mean, rstd
|
126 |
-
|
127 |
-
|
128 |
-
@register_decomposition_for_jvp(aten.native_layer_norm_backward)
|
129 |
-
def native_layer_norm_backward(
|
130 |
-
grad_out: Tensor,
|
131 |
-
input: Tensor,
|
132 |
-
normalized_shape: List[int],
|
133 |
-
mean: Tensor,
|
134 |
-
rstd: Tensor,
|
135 |
-
weight: Optional[Tensor],
|
136 |
-
bias: Optional[Tensor],
|
137 |
-
output_mask: List[bool],
|
138 |
-
) -> Tuple[Optional[Tensor], Optional[Tensor], Optional[Tensor]]:
|
139 |
-
input_shape = input.shape
|
140 |
-
input_ndim = input.dim()
|
141 |
-
|
142 |
-
axis = input_ndim - len(normalized_shape)
|
143 |
-
inner_dims = input_shape[axis:]
|
144 |
-
outer_dims = input_shape[:axis]
|
145 |
-
inner_dim_indices = list(range(axis, input_ndim))
|
146 |
-
outer_dim_indices = list(range(0, axis))
|
147 |
-
|
148 |
-
N = 1
|
149 |
-
for i in inner_dims:
|
150 |
-
N *= i
|
151 |
-
M = 1
|
152 |
-
for i in outer_dims:
|
153 |
-
M *= i
|
154 |
-
if M <= 0 or N <= 0:
|
155 |
-
return (
|
156 |
-
input.new_zeros(input_shape),
|
157 |
-
input.new_zeros(input_shape[axis:]),
|
158 |
-
input.new_zeros(input_shape[axis:]),
|
159 |
-
)
|
160 |
-
|
161 |
-
mean_, rstd_ = recompute_mean_var(input, rstd, inner_dim_indices, keepdim=True)
|
162 |
-
|
163 |
-
x_hat = (input - mean_) * rstd_
|
164 |
-
if weight is not None:
|
165 |
-
grad_x_hat = grad_out * weight
|
166 |
-
else:
|
167 |
-
grad_x_hat = grad_out
|
168 |
-
a = grad_x_hat * N
|
169 |
-
b = torch.sum(grad_x_hat, inner_dim_indices, True)
|
170 |
-
c1 = torch.mul(grad_x_hat, x_hat)
|
171 |
-
c2 = torch.sum(c1, inner_dim_indices, True)
|
172 |
-
c3 = torch.mul(x_hat, c2)
|
173 |
-
inner = a - b - c3
|
174 |
-
|
175 |
-
if output_mask[0]:
|
176 |
-
d_input: Optional[Tensor] = (rstd_ / N) * inner
|
177 |
-
else:
|
178 |
-
d_input = torch.zeros_like(input) # should be None but doesn't work with vjp
|
179 |
-
|
180 |
-
if output_mask[1] and weight is not None:
|
181 |
-
if len(outer_dim_indices) > 0:
|
182 |
-
d_weight: Optional[Tensor] = torch.sum(
|
183 |
-
grad_out * x_hat, outer_dim_indices, False
|
184 |
-
)
|
185 |
-
else:
|
186 |
-
d_weight = grad_out * x_hat
|
187 |
-
elif weight is not None:
|
188 |
-
d_weight = torch.zeros_like(weight) # should be None but doesn't work with vjp
|
189 |
-
else:
|
190 |
-
d_weight = torch.zeros(()) # should be None but doesn't work with vjp
|
191 |
-
|
192 |
-
if output_mask[2] and bias is not None:
|
193 |
-
if len(outer_dim_indices) > 0:
|
194 |
-
d_bias: Optional[Tensor] = torch.sum(grad_out, outer_dim_indices, False)
|
195 |
-
else:
|
196 |
-
d_bias = grad_out.clone()
|
197 |
-
elif bias is not None:
|
198 |
-
d_bias = torch.zeros_like(bias) # should be None but doesn't work with vjp
|
199 |
-
else:
|
200 |
-
d_bias = torch.zeros(()) # should be None but doesn't work with vjp
|
201 |
-
|
202 |
-
return (d_input, d_weight, d_bias)
|
203 |
-
|
204 |
-
|
205 |
-
def prod(x: List[int]):
|
206 |
-
r = 1
|
207 |
-
for i in x:
|
208 |
-
r *= i
|
209 |
-
return r
|
210 |
-
|
211 |
-
|
212 |
-
@register_decomposition_for_jvp(aten.native_batch_norm_backward)
|
213 |
-
def native_batch_norm_backward(
|
214 |
-
grad_out: Tensor,
|
215 |
-
input: Tensor,
|
216 |
-
weight: Optional[Tensor],
|
217 |
-
running_mean: Optional[Tensor],
|
218 |
-
running_var: Optional[Tensor],
|
219 |
-
save_mean: Optional[Tensor],
|
220 |
-
save_invstd: Optional[Tensor],
|
221 |
-
train: bool,
|
222 |
-
eps: float,
|
223 |
-
output_mask: List[bool],
|
224 |
-
) -> Tuple[Tensor, Optional[Tensor], Optional[Tensor]]:
|
225 |
-
input_shape = input.shape
|
226 |
-
input_rank = input.dim()
|
227 |
-
assert input_rank >= 2, "rank of the input must be at least 2"
|
228 |
-
|
229 |
-
axis = 1
|
230 |
-
num_features = prod(input_shape) / input_shape[axis] # type: ignore[arg-type]
|
231 |
-
mean = save_mean
|
232 |
-
invstd = save_invstd
|
233 |
-
if train:
|
234 |
-
assert (
|
235 |
-
save_mean is not None and save_invstd is not None
|
236 |
-
), "when train=True, save_mean and save_invstd are required"
|
237 |
-
|
238 |
-
reduciton_dims = [0] + list(range(2, input.dim()))
|
239 |
-
assert invstd is not None # for typing
|
240 |
-
mean, invstd = recompute_mean_var(input, invstd, reduciton_dims, keepdim=False)
|
241 |
-
else:
|
242 |
-
assert running_mean is not None and running_var is not None
|
243 |
-
mean = running_mean
|
244 |
-
invstd = torch.rsqrt(running_var + eps)
|
245 |
-
|
246 |
-
assert invstd is not None and mean is not None
|
247 |
-
|
248 |
-
broadcast_mask = [1] * input_rank
|
249 |
-
broadcast_mask[axis] = input_shape[axis]
|
250 |
-
|
251 |
-
reduction_axes: List[int] = []
|
252 |
-
for i in range(input_rank):
|
253 |
-
if i != axis:
|
254 |
-
reduction_axes.append(i)
|
255 |
-
|
256 |
-
mean = torch.reshape(mean, broadcast_mask)
|
257 |
-
norm = 1.0 / num_features
|
258 |
-
grad_output_sum = torch.sum(grad_out, reduction_axes)
|
259 |
-
dot_p = torch.sum(grad_out * (input - mean), reduction_axes)
|
260 |
-
|
261 |
-
grad_mean = torch.reshape(grad_output_sum * norm, broadcast_mask)
|
262 |
-
proj_scale = torch.reshape(torch.mul(dot_p * norm, invstd * invstd), broadcast_mask)
|
263 |
-
|
264 |
-
if weight is None:
|
265 |
-
grad_scale = torch.reshape(invstd, broadcast_mask) * 1.0
|
266 |
-
else:
|
267 |
-
grad_scale = torch.reshape(invstd * weight, broadcast_mask)
|
268 |
-
|
269 |
-
if train:
|
270 |
-
proj = (input - mean) * proj_scale
|
271 |
-
grad_input = ((grad_out - proj) - grad_mean) * grad_scale
|
272 |
-
else:
|
273 |
-
grad_input = grad_out * grad_scale
|
274 |
-
|
275 |
-
if output_mask[1]:
|
276 |
-
grad_weight = dot_p * invstd
|
277 |
-
elif weight is not None:
|
278 |
-
grad_weight = torch.zeros_like(
|
279 |
-
weight
|
280 |
-
) # should be None but doesn't work with vjp
|
281 |
-
else:
|
282 |
-
grad_weight = torch.zeros(()) # should be None but doesn't work with vjp
|
283 |
-
|
284 |
-
if output_mask[2]:
|
285 |
-
grad_bias = grad_output_sum
|
286 |
-
else:
|
287 |
-
grad_bias = torch.zeros_like(
|
288 |
-
grad_output_sum
|
289 |
-
) # should be None but doesn't work with vjp
|
290 |
-
|
291 |
-
return (grad_input, grad_weight, grad_bias)
|
292 |
-
|
293 |
-
|
294 |
-
_register_jit_decomposition_for_jvp(torch.ops.aten.trace.default, use_python=True)
|
295 |
-
_register_jit_decomposition_for_jvp(torch.ops.aten.nll_loss_backward.default)
|
296 |
-
_register_jit_decomposition_for_jvp(torch.ops.aten.nll_loss2d_backward.default)
|
297 |
-
_register_jit_decomposition_for_jvp(torch.ops.aten._log_softmax_backward_data.default)
|
298 |
-
_register_jit_decomposition_for_jvp(torch.ops.aten._softmax_backward_data.default)
|
299 |
-
_register_jit_decomposition_for_jvp(torch.ops.aten.log_sigmoid_forward.default)
|
300 |
-
_register_jit_decomposition_for_jvp(torch.ops.aten.native_layer_norm_backward.default)
|
301 |
-
_register_jit_decomposition_for_jvp(torch.ops.aten.native_batch_norm_backward.default)
|
302 |
-
_register_jit_decomposition_for_jvp(torch.ops.aten.cudnn_batch_norm_backward.default)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
torch/_decomp/decompositions_for_rng.py
DELETED
@@ -1,263 +0,0 @@
|
|
1 |
-
import functools
|
2 |
-
from collections import defaultdict
|
3 |
-
from typing import Callable, Dict
|
4 |
-
|
5 |
-
import torch
|
6 |
-
import torch._decomp as decomp
|
7 |
-
from torch._decomp import get_decompositions
|
8 |
-
from torch._ops import OpOverload
|
9 |
-
|
10 |
-
aten = torch.ops.aten
|
11 |
-
|
12 |
-
rng_decompositions: Dict[str, Dict[OpOverload, Callable]] = defaultdict(dict)
|
13 |
-
|
14 |
-
|
15 |
-
def register_rng_decomposition(aten_op):
|
16 |
-
return decomp.register_decomposition(aten_op, rng_decompositions)
|
17 |
-
|
18 |
-
|
19 |
-
def throw_on_non_cuda(device):
|
20 |
-
raise RuntimeError(
|
21 |
-
f"You are trying to functionalize a {device.type} RNG operator but {device.type} does not "
|
22 |
-
f"use Philox/counter-based RNG. Therefore, functionalizing a {device.type} RNG operator is "
|
23 |
-
"not supported. We are discussing the possibility of a Philox-based RNG implementation for CPU."
|
24 |
-
)
|
25 |
-
|
26 |
-
|
27 |
-
# TODO - We have to register many more distributions here, and also higher level
|
28 |
-
# ops like dropout which have fused implementation and can hide the rand inside.
|
29 |
-
@register_rng_decomposition(aten.rand)
|
30 |
-
def rand(shape, dtype=None, layout=torch.strided, device=None, pin_memory=False):
|
31 |
-
if device and device.type != "cuda":
|
32 |
-
throw_on_non_cuda(device)
|
33 |
-
seed, offset = PhiloxStateTracker.get_state_as_tuple()
|
34 |
-
dtype = dtype or torch.float32
|
35 |
-
out, offset_jump = torch.ops.rngprims.philox_rand(
|
36 |
-
shape, seed, offset, None, device, dtype
|
37 |
-
)
|
38 |
-
PhiloxStateTracker.advance_offset(offset_jump)
|
39 |
-
return out
|
40 |
-
|
41 |
-
|
42 |
-
@register_rng_decomposition(aten.rand_like)
|
43 |
-
def rand_like(
|
44 |
-
x: torch.Tensor,
|
45 |
-
dtype=None,
|
46 |
-
layout=None,
|
47 |
-
device=None,
|
48 |
-
pin_memory=False,
|
49 |
-
memory_format=torch.preserve_format,
|
50 |
-
):
|
51 |
-
device = device or x.device
|
52 |
-
if device.type != "cuda":
|
53 |
-
throw_on_non_cuda(device)
|
54 |
-
dtype = dtype or x.dtype
|
55 |
-
seed, offset = PhiloxStateTracker.get_state_as_tuple()
|
56 |
-
out, offset_jump = torch.ops.rngprims.philox_rand(
|
57 |
-
x.shape, seed, offset, None, device, dtype
|
58 |
-
)
|
59 |
-
PhiloxStateTracker.advance_offset(offset_jump)
|
60 |
-
return out
|
61 |
-
|
62 |
-
|
63 |
-
class PhiloxState:
|
64 |
-
"""
|
65 |
-
Represents a PhiloxRngState - (seed, offset) where offset = base_offset +
|
66 |
-
relative_offset. seed and base_offset basically point to the rng state just
|
67 |
-
before tracing starts. relative offset tracks the totally consumed offset at
|
68 |
-
trace time.
|
69 |
-
"""
|
70 |
-
|
71 |
-
def __init__(self):
|
72 |
-
self.reset()
|
73 |
-
|
74 |
-
def reset(self):
|
75 |
-
self.seed = torch.tensor(())
|
76 |
-
self.base_offset = torch.tensor(())
|
77 |
-
self.relative_offset = 0
|
78 |
-
self.offset_advanced_alteast_once = False
|
79 |
-
|
80 |
-
def validate_state(self):
|
81 |
-
assert self.seed.numel() != 0 and self.base_offset.numel() != 0
|
82 |
-
|
83 |
-
def advance_offset(self, consumed_offset):
|
84 |
-
self.offset_advanced_alteast_once = True
|
85 |
-
self.relative_offset = self.relative_offset + consumed_offset
|
86 |
-
|
87 |
-
def set_state(self, seed, base_offset, relative_offset=0):
|
88 |
-
self.seed = seed
|
89 |
-
self.base_offset = base_offset
|
90 |
-
self.relative_offset = relative_offset
|
91 |
-
|
92 |
-
def get_state_as_tuple(self):
|
93 |
-
self.validate_state()
|
94 |
-
return (self.seed, self.base_offset + self.relative_offset)
|
95 |
-
|
96 |
-
def get_state_as_tensor(self):
|
97 |
-
# Only needed because we override get_rng_state.
|
98 |
-
self.validate_state()
|
99 |
-
return torch.stack([self.seed, self.base_offset + self.relative_offset])
|
100 |
-
|
101 |
-
def set_state_from_tensor(self, state):
|
102 |
-
# Only needed because we override set_rng_state.
|
103 |
-
self.seed, self.base_offset = torch.unbind(state)
|
104 |
-
self.relative_offset = 0
|
105 |
-
|
106 |
-
|
107 |
-
class PhiloxStateTracker:
|
108 |
-
"""
|
109 |
-
Singleton class to track the philox rng state during AOT Autograd tracing.
|
110 |
-
For each aot tracing instance, AOT Autograd resets this tracker and keeps
|
111 |
-
track of both forward and backward offsets. At runtime, we only care about
|
112 |
-
the total consumed forward and backward offsets. For dynamic shapes, these
|
113 |
-
offsets are a function of input shapes. Therefore, the AOT generated graphs
|
114 |
-
have additional outputs that compute total consumed forward and backward
|
115 |
-
offsets.
|
116 |
-
"""
|
117 |
-
|
118 |
-
running_state: PhiloxState
|
119 |
-
fwd_state: PhiloxState
|
120 |
-
bwd_state: PhiloxState
|
121 |
-
|
122 |
-
def __enter__(self):
|
123 |
-
PhiloxStateTracker.reset()
|
124 |
-
return self
|
125 |
-
|
126 |
-
def __exit__(self, exc_type, exc_cal, exc_tb):
|
127 |
-
PhiloxStateTracker.reset()
|
128 |
-
|
129 |
-
@classmethod
|
130 |
-
def reset(cls):
|
131 |
-
cls.running_state = PhiloxState()
|
132 |
-
cls.fwd_state = PhiloxState()
|
133 |
-
cls.bwd_state = PhiloxState()
|
134 |
-
|
135 |
-
@classmethod
|
136 |
-
def mark_beginning_of_forward(cls):
|
137 |
-
# Tells the tracker to use fwd_state as the running state
|
138 |
-
cls.running_state = cls.fwd_state
|
139 |
-
|
140 |
-
@classmethod
|
141 |
-
def mark_beginning_of_backward(cls):
|
142 |
-
# Tells the tracker to use bwd_state as the running state
|
143 |
-
cls.running_state = cls.bwd_state
|
144 |
-
|
145 |
-
@classmethod
|
146 |
-
def record_state(cls, seed, offset, mode):
|
147 |
-
# Records the seed and offset tensors. These tensors are used to invoke
|
148 |
-
# the philox_rand functional primitives.
|
149 |
-
if mode == "forward":
|
150 |
-
cls.fwd_state.set_state(seed, offset)
|
151 |
-
cls.mark_beginning_of_forward()
|
152 |
-
else:
|
153 |
-
assert mode == "backward"
|
154 |
-
cls.bwd_state.set_state(seed, offset)
|
155 |
-
|
156 |
-
@classmethod
|
157 |
-
def get_state_as_tensor(cls):
|
158 |
-
# The only reason this exists is because we override get_rng_state and
|
159 |
-
# set_rng_state during tracing. get_rng_state expects a tensor output,
|
160 |
-
# so return (seed, offset) tuple upset other parts of the program like
|
161 |
-
# ctx.saved_tensors.
|
162 |
-
|
163 |
-
# A bad consequence is that if user saves and restores rng state, we
|
164 |
-
# have little bit of ugliness in the generated code, where we first
|
165 |
-
# concat the (seed, offset) to create a tensor for get_rng_state, and
|
166 |
-
# then split it back to get (seed, offset) tuple in set_rng_state.
|
167 |
-
|
168 |
-
# TODO: Investigate if there is be a better way to wrap the tuple in a
|
169 |
-
# false Tensor object, and then desugar it later on.
|
170 |
-
return cls.running_state.get_state_as_tensor()
|
171 |
-
|
172 |
-
@classmethod
|
173 |
-
def get_state_as_tuple(cls):
|
174 |
-
return cls.running_state.get_state_as_tuple()
|
175 |
-
|
176 |
-
@classmethod
|
177 |
-
def set_state_from_tensor(cls, x):
|
178 |
-
# This is only needed because we override set_rng_state. Look at the
|
179 |
-
# comment in get_state_from_tensor method.
|
180 |
-
cls.running_state.set_state_from_tensor(x)
|
181 |
-
|
182 |
-
@classmethod
|
183 |
-
def advance_offset(cls, consumed_offset):
|
184 |
-
cls.running_state.advance_offset(consumed_offset)
|
185 |
-
|
186 |
-
@classmethod
|
187 |
-
def get_current_relative_offset(cls):
|
188 |
-
return cls.running_state.relative_offset
|
189 |
-
|
190 |
-
@staticmethod
|
191 |
-
def multiple_of_4(offset):
|
192 |
-
# torch cuda rng state offset must be a multiple of 4. For inductor, as
|
193 |
-
# we sum up all the numel, the result might not be a multiple of 4. This
|
194 |
-
# method achieves that.
|
195 |
-
return (offset + 3) // 4 * 4
|
196 |
-
|
197 |
-
@classmethod
|
198 |
-
def get_updated_fwd_offset(cls):
|
199 |
-
# Short circuit if no rand ops were observed
|
200 |
-
if not cls.fwd_state.offset_advanced_alteast_once:
|
201 |
-
return cls.fwd_state.base_offset
|
202 |
-
return cls.multiple_of_4(
|
203 |
-
cls.fwd_state.base_offset + cls.fwd_state.relative_offset
|
204 |
-
)
|
205 |
-
|
206 |
-
@classmethod
|
207 |
-
def get_updated_bwd_offset(cls):
|
208 |
-
# Short circuit if no rand ops were observed
|
209 |
-
if not cls.bwd_state.offset_advanced_alteast_once:
|
210 |
-
return cls.bwd_state.base_offset
|
211 |
-
return cls.multiple_of_4(
|
212 |
-
cls.bwd_state.base_offset + cls.bwd_state.relative_offset
|
213 |
-
)
|
214 |
-
|
215 |
-
|
216 |
-
# Adding more decompositions which eventually use rand_like inside decomps.
|
217 |
-
# Adding these in rng_decompositions ensures the functionalization of rand_like
|
218 |
-
# ops used in these decomps. The list is copied from inductor codebase, which
|
219 |
-
# uses it for similar purpose.
|
220 |
-
#
|
221 |
-
# Caution - These decomps do not have the same accuracy as eager. However,
|
222 |
-
# we can't just disable them with a config flag like fallback_random, because
|
223 |
-
# for functionalization of rng ops, we have to decompose these ops.
|
224 |
-
extra_random_decomps = get_decompositions(
|
225 |
-
[
|
226 |
-
aten.cauchy,
|
227 |
-
aten.cauchy_,
|
228 |
-
aten.exponential,
|
229 |
-
aten.exponential_,
|
230 |
-
aten.geometric,
|
231 |
-
aten.geometric_,
|
232 |
-
aten.native_dropout,
|
233 |
-
aten.normal,
|
234 |
-
aten.normal_,
|
235 |
-
aten.normal_functional,
|
236 |
-
aten.log_normal,
|
237 |
-
aten.log_normal_,
|
238 |
-
aten.rrelu_with_noise,
|
239 |
-
aten.rrelu_with_noise_,
|
240 |
-
aten.uniform_,
|
241 |
-
]
|
242 |
-
)
|
243 |
-
register_extra_random_decomp = functools.partial(
|
244 |
-
decomp.register_decomposition, registry=extra_random_decomps
|
245 |
-
)
|
246 |
-
|
247 |
-
|
248 |
-
@register_extra_random_decomp([aten.bernoulli_])
|
249 |
-
def bernoulli_(self, p=0.5):
|
250 |
-
if self.device == torch.device("cpu"):
|
251 |
-
return NotImplemented
|
252 |
-
return self.copy_(torch.rand_like(self, dtype=torch.float32) < p)
|
253 |
-
|
254 |
-
|
255 |
-
@register_extra_random_decomp([aten.bernoulli.p])
|
256 |
-
def bernoulli_p(self, p=0.5, *, generator=None):
|
257 |
-
if self.device == torch.device("cpu"):
|
258 |
-
return NotImplemented
|
259 |
-
assert generator is None
|
260 |
-
return torch.rand_like(self, dtype=torch.float32) < p
|
261 |
-
|
262 |
-
|
263 |
-
rng_decompositions.update(extra_random_decomps) # type: ignore[arg-type]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
torch/_deploy.py
DELETED
@@ -1,105 +0,0 @@
|
|
1 |
-
import io
|
2 |
-
|
3 |
-
import torch
|
4 |
-
from torch.package import Importer, OrderedImporter, PackageImporter, sys_importer
|
5 |
-
from torch.package._package_pickler import create_pickler
|
6 |
-
from torch.package._package_unpickler import PackageUnpickler
|
7 |
-
from torch.serialization import _maybe_decode_ascii
|
8 |
-
|
9 |
-
|
10 |
-
def _save_storages(importer, obj):
|
11 |
-
serialized_storages = []
|
12 |
-
serialized_dtypes = []
|
13 |
-
|
14 |
-
importer = importer if isinstance(importer, torch.package.PackageImporter) else None
|
15 |
-
importers: Importer
|
16 |
-
if importer is not None:
|
17 |
-
importers = OrderedImporter(importer, sys_importer)
|
18 |
-
else:
|
19 |
-
importers = sys_importer
|
20 |
-
|
21 |
-
def persistent_id(obj):
|
22 |
-
if torch.is_storage(obj) or isinstance(obj, torch.storage.TypedStorage):
|
23 |
-
if isinstance(obj, torch.storage.TypedStorage):
|
24 |
-
# TODO: Once we decide to break serialization FC, we can
|
25 |
-
# remove this case
|
26 |
-
storage = obj._untyped_storage
|
27 |
-
dtype = obj.dtype
|
28 |
-
else:
|
29 |
-
storage = obj
|
30 |
-
dtype = torch.uint8
|
31 |
-
|
32 |
-
serialized_storages.append(obj)
|
33 |
-
serialized_dtypes.append(dtype)
|
34 |
-
return ("storage", len(serialized_storages) - 1)
|
35 |
-
|
36 |
-
if hasattr(obj, "__reduce_deploy__"):
|
37 |
-
if _serialized_reduces.get(id(obj)) is None:
|
38 |
-
_serialized_reduces[id(obj)] = (
|
39 |
-
"reduce_deploy",
|
40 |
-
id(obj),
|
41 |
-
*obj.__reduce_deploy__(importers),
|
42 |
-
)
|
43 |
-
return _serialized_reduces[id(obj)]
|
44 |
-
|
45 |
-
return None
|
46 |
-
|
47 |
-
# Write the pickle data for `obj`
|
48 |
-
data_buf = io.BytesIO()
|
49 |
-
pickler = create_pickler(data_buf, importers)
|
50 |
-
pickler.persistent_id = persistent_id
|
51 |
-
pickler.dump(obj)
|
52 |
-
data_value = data_buf.getvalue()
|
53 |
-
return (
|
54 |
-
data_value,
|
55 |
-
serialized_storages,
|
56 |
-
serialized_dtypes,
|
57 |
-
importer.zip_reader if importer else None,
|
58 |
-
)
|
59 |
-
|
60 |
-
|
61 |
-
def _load_storages(id, zip_reader, obj_bytes, serialized_storages, serialized_dtypes):
|
62 |
-
def persistent_load(saved_id):
|
63 |
-
assert isinstance(saved_id, tuple)
|
64 |
-
typename = _maybe_decode_ascii(saved_id[0])
|
65 |
-
data = saved_id[1:]
|
66 |
-
|
67 |
-
if typename == "storage":
|
68 |
-
# TODO: Once we decide to break serialization FC, we can
|
69 |
-
# stop wrapping with TypedStorage
|
70 |
-
storage = serialized_storages[data[0]]
|
71 |
-
dtype = serialized_dtypes[data[0]]
|
72 |
-
return torch.storage.TypedStorage(
|
73 |
-
wrap_storage=storage.untyped(), dtype=dtype
|
74 |
-
)
|
75 |
-
|
76 |
-
if typename == "reduce_deploy":
|
77 |
-
reduce_id, func, args = data
|
78 |
-
if reduce_id not in _loaded_reduces:
|
79 |
-
_loaded_reduces[reduce_id] = func(_raw_packages[zip_reader], *args)
|
80 |
-
return _loaded_reduces[reduce_id]
|
81 |
-
|
82 |
-
return None
|
83 |
-
|
84 |
-
importer: Importer
|
85 |
-
if zip_reader is not None:
|
86 |
-
importer = OrderedImporter(_get_package(zip_reader), sys_importer)
|
87 |
-
else:
|
88 |
-
importer = sys_importer
|
89 |
-
|
90 |
-
unpickler = PackageUnpickler(importer, io.BytesIO(obj_bytes))
|
91 |
-
unpickler.persistent_load = persistent_load # type: ignore[assignment]
|
92 |
-
result = _deploy_objects[id] = unpickler.load()
|
93 |
-
return result
|
94 |
-
|
95 |
-
|
96 |
-
def _get_package(zip_reader):
|
97 |
-
if zip_reader not in _raw_packages:
|
98 |
-
_raw_packages[zip_reader] = PackageImporter(zip_reader)
|
99 |
-
return _raw_packages[zip_reader]
|
100 |
-
|
101 |
-
|
102 |
-
_raw_packages: dict = {}
|
103 |
-
_deploy_objects: dict = {}
|
104 |
-
_serialized_reduces: dict = {}
|
105 |
-
_loaded_reduces: dict = {}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
torch/_dispatch/__init__.py
DELETED
File without changes
|
torch/_dispatch/__pycache__/__init__.cpython-310.pyc
DELETED
Binary file (155 Bytes)
|
|