ayousanz commited on
Commit
cc55e01
·
verified ·
1 Parent(s): 413e7ca

Add files using upload-large-folder tool

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .gitattributes +1 -0
  2. .venv/Lib/site-packages/torch/lib/kineto.lib +3 -0
  3. .venv/Lib/site-packages/torch/mtia/__init__.py +332 -0
  4. .venv/Lib/site-packages/torch/mtia/__pycache__/__init__.cpython-39.pyc +0 -0
  5. .venv/Lib/site-packages/torch/mtia/__pycache__/_utils.cpython-39.pyc +0 -0
  6. .venv/Lib/site-packages/torch/multiprocessing/__init__.py +100 -0
  7. .venv/Lib/site-packages/torch/multiprocessing/__pycache__/reductions.cpython-39.pyc +0 -0
  8. .venv/Lib/site-packages/torch/multiprocessing/_atfork.py +35 -0
  9. .venv/Lib/site-packages/torch/multiprocessing/pool.py +52 -0
  10. .venv/Lib/site-packages/torch/multiprocessing/queue.py +43 -0
  11. .venv/Lib/site-packages/torch/multiprocessing/reductions.py +647 -0
  12. .venv/Lib/site-packages/torch/multiprocessing/spawn.py +328 -0
  13. .venv/Lib/site-packages/torch/nn/parallel/__init__.py +28 -0
  14. .venv/Lib/site-packages/torch/nn/parallel/__pycache__/__init__.cpython-39.pyc +0 -0
  15. .venv/Lib/site-packages/torch/nn/parallel/__pycache__/_functions.cpython-39.pyc +0 -0
  16. .venv/Lib/site-packages/torch/nn/parallel/__pycache__/comm.cpython-39.pyc +0 -0
  17. .venv/Lib/site-packages/torch/nn/parallel/__pycache__/data_parallel.cpython-39.pyc +0 -0
  18. .venv/Lib/site-packages/torch/nn/parallel/__pycache__/distributed.cpython-39.pyc +0 -0
  19. .venv/Lib/site-packages/torch/nn/parallel/__pycache__/parallel_apply.cpython-39.pyc +0 -0
  20. .venv/Lib/site-packages/torch/nn/parallel/__pycache__/replicate.cpython-39.pyc +0 -0
  21. .venv/Lib/site-packages/torch/nn/parallel/__pycache__/scatter_gather.cpython-39.pyc +0 -0
  22. .venv/Lib/site-packages/torch/nn/qat/__init__.py +18 -0
  23. .venv/Lib/site-packages/torch/nn/qat/dynamic/__init__.py +7 -0
  24. .venv/Lib/site-packages/torch/nn/qat/dynamic/__pycache__/__init__.cpython-39.pyc +0 -0
  25. .venv/Lib/site-packages/torch/nn/qat/dynamic/modules/__init__.py +4 -0
  26. .venv/Lib/site-packages/torch/nn/qat/dynamic/modules/__pycache__/__init__.cpython-39.pyc +0 -0
  27. .venv/Lib/site-packages/torch/nn/qat/dynamic/modules/__pycache__/linear.cpython-39.pyc +0 -0
  28. .venv/Lib/site-packages/torch/nn/qat/dynamic/modules/linear.py +10 -0
  29. .venv/Lib/site-packages/torch/nn/qat/modules/__init__.py +20 -0
  30. .venv/Lib/site-packages/torch/nn/qat/modules/__pycache__/conv.cpython-39.pyc +0 -0
  31. .venv/Lib/site-packages/torch/nn/qat/modules/conv.py +11 -0
  32. .venv/Lib/site-packages/torch/nn/qat/modules/embedding_ops.py +14 -0
  33. .venv/Lib/site-packages/torch/nn/qat/modules/linear.py +10 -0
  34. .venv/Lib/site-packages/torch/nn/quantized/__init__.py +39 -0
  35. .venv/Lib/site-packages/torch/nn/quantized/dynamic/__init__.py +1 -0
  36. .venv/Lib/site-packages/torch/nn/quantized/dynamic/__pycache__/__init__.cpython-39.pyc +0 -0
  37. .venv/Lib/site-packages/torch/nn/quantized/dynamic/modules/__init__.py +43 -0
  38. .venv/Lib/site-packages/torch/nn/quantized/dynamic/modules/conv.py +28 -0
  39. .venv/Lib/site-packages/torch/nn/quantized/dynamic/modules/linear.py +10 -0
  40. .venv/Lib/site-packages/torch/nn/quantized/dynamic/modules/rnn.py +34 -0
  41. .venv/Lib/site-packages/torch/nn/quantized/functional.py +10 -0
  42. .venv/Lib/site-packages/torch/nn/quantized/modules/__init__.py +97 -0
  43. .venv/Lib/site-packages/torch/nn/quantized/modules/__pycache__/__init__.cpython-39.pyc +0 -0
  44. .venv/Lib/site-packages/torch/nn/quantized/modules/activation.py +20 -0
  45. .venv/Lib/site-packages/torch/nn/quantized/modules/batchnorm.py +11 -0
  46. .venv/Lib/site-packages/torch/nn/quantized/modules/conv.py +29 -0
  47. .venv/Lib/site-packages/torch/nn/quantized/modules/dropout.py +14 -0
  48. .venv/Lib/site-packages/torch/nn/quantized/modules/embedding_ops.py +18 -0
  49. .venv/Lib/site-packages/torch/nn/quantized/modules/functional_modules.py +18 -0
  50. .venv/Lib/site-packages/torch/nn/quantized/modules/linear.py +14 -0
.gitattributes CHANGED
@@ -122,3 +122,4 @@ reference_sample_wavs/syuukovoice_200918_3_01.wav filter=lfs diff=lfs merge=lfs
122
  .venv/Lib/site-packages/torch/lib/cudnn_engines_runtime_compiled64_9.dll filter=lfs diff=lfs merge=lfs -text
123
  .venv/Lib/site-packages/torch/lib/libiomp5md.dll filter=lfs diff=lfs merge=lfs -text
124
  .venv/Lib/site-packages/torch/lib/nvrtc-builtins64_121.dll filter=lfs diff=lfs merge=lfs -text
 
 
122
  .venv/Lib/site-packages/torch/lib/cudnn_engines_runtime_compiled64_9.dll filter=lfs diff=lfs merge=lfs -text
123
  .venv/Lib/site-packages/torch/lib/libiomp5md.dll filter=lfs diff=lfs merge=lfs -text
124
  .venv/Lib/site-packages/torch/lib/nvrtc-builtins64_121.dll filter=lfs diff=lfs merge=lfs -text
125
+ .venv/Lib/site-packages/torch/lib/kineto.lib filter=lfs diff=lfs merge=lfs -text
.venv/Lib/site-packages/torch/lib/kineto.lib ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e4b349dc42209360e73bcabaf3160289923aec193db7996966926407cb51fb76
3
+ size 21732956
.venv/Lib/site-packages/torch/mtia/__init__.py ADDED
@@ -0,0 +1,332 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # mypy: allow-untyped-defs
2
+ r"""
3
+ This package enables an interface for accessing MTIA backend in python
4
+ """
5
+
6
+ import threading
7
+ import warnings
8
+ from typing import Any, Callable, Dict, List, Optional, Tuple, Union
9
+
10
+ import torch
11
+ from torch import device as _device, Tensor
12
+ from torch._utils import _dummy_type, _LazySeedTracker, classproperty
13
+ from torch.types import Device
14
+
15
+ from ._utils import _get_device_index
16
+
17
+
18
+ _device_t = Union[_device, str, int, None]
19
+
20
+ # torch.mtia.Event/Stream is alias of torch.Event/Stream
21
+ Event = torch.Event
22
+ Stream = torch.Stream
23
+
24
+ _initialized = False
25
+ _queued_calls: List[
26
+ Tuple[Callable[[], None], List[str]]
27
+ ] = [] # don't invoke these until initialization occurs
28
+ _tls = threading.local()
29
+ _initialization_lock = threading.Lock()
30
+ _lazy_seed_tracker = _LazySeedTracker()
31
+
32
+
33
+ def init():
34
+ _lazy_init()
35
+
36
+
37
+ def is_initialized():
38
+ r"""Return whether PyTorch's MTIA state has been initialized."""
39
+ return _initialized and not _is_in_bad_fork()
40
+
41
+
42
+ def _is_in_bad_fork() -> bool:
43
+ return torch._C._mtia_isInBadFork()
44
+
45
+
46
+ def _lazy_init() -> None:
47
+ global _initialized, _queued_calls
48
+ if is_initialized() or hasattr(_tls, "is_initializing"):
49
+ return
50
+ with _initialization_lock:
51
+ # We be double-checking locking, boys! This is OK because
52
+ # the above test was GIL protected anyway. The inner test
53
+ # is for when a thread blocked on some other thread which was
54
+ # doing the initialization; when they get the lock, they will
55
+ # find there is nothing left to do.
56
+ if is_initialized():
57
+ return
58
+ # It is important to prevent other threads from entering _lazy_init
59
+ # immediately, while we are still guaranteed to have the GIL, because some
60
+ # of the C calls we make below will release the GIL
61
+ if _is_in_bad_fork():
62
+ raise RuntimeError(
63
+ "Cannot re-initialize MTIA in forked subprocess. To use MTIA with "
64
+ "multiprocessing, you must use the 'spawn' start method"
65
+ )
66
+ if not _is_compiled():
67
+ raise AssertionError(
68
+ "Torch not compiled with MTIA enabled. "
69
+ "Ensure you have `import mtia.host_runtime.torch_mtia` in your python "
70
+ "src file and include `//mtia/host_runtime/torch_mtia:torch_mtia` as "
71
+ "your target dependency!"
72
+ )
73
+
74
+ torch._C._mtia_init()
75
+ # Some of the queued calls may reentrantly call _lazy_init();
76
+ # we need to just return without initializing in that case.
77
+ # However, we must not let any *other* threads in!
78
+ _tls.is_initializing = True
79
+
80
+ for calls in _lazy_seed_tracker.get_calls():
81
+ if calls:
82
+ _queued_calls.append(calls)
83
+
84
+ try:
85
+ for queued_call, orig_traceback in _queued_calls:
86
+ try:
87
+ queued_call()
88
+ except Exception as e:
89
+ msg = (
90
+ f"MTIA call failed lazily at initialization with error: {str(e)}\n\n"
91
+ f"MTIA call was originally invoked at:\n\n{''.join(orig_traceback)}"
92
+ )
93
+ raise DeferredMtiaCallError(msg) from e
94
+ finally:
95
+ delattr(_tls, "is_initializing")
96
+ _initialized = True
97
+
98
+
99
+ class DeferredMtiaCallError(Exception):
100
+ pass
101
+
102
+
103
+ def _is_compiled() -> bool:
104
+ r"""Return true if compiled with MTIA support."""
105
+ return torch._C._mtia_isBuilt()
106
+
107
+
108
+ def is_available() -> bool:
109
+ r"""Return true if MTIA device is available"""
110
+ if not _is_compiled():
111
+ return False
112
+ # MTIA has to init devices first to know if there is any devices available.
113
+ return device_count() > 0
114
+
115
+
116
+ def synchronize(device: Optional[_device_t] = None) -> None:
117
+ r"""Waits for all jobs in all streams on a MTIA device to complete."""
118
+ with torch.mtia.device(device):
119
+ return torch._C._mtia_deviceSynchronize()
120
+
121
+
122
+ def device_count() -> int:
123
+ r"""Return the number of MTIA devices available."""
124
+ return torch._C._accelerator_hooks_device_count()
125
+
126
+
127
+ def current_device() -> int:
128
+ r"""Return the index of a currently selected device."""
129
+ return torch._C._accelerator_hooks_get_current_device()
130
+
131
+
132
+ def current_stream(device: Optional[_device_t] = None) -> Stream:
133
+ r"""Return the currently selected :class:`Stream` for a given device.
134
+
135
+ Args:
136
+ device (torch.device or int, optional): selected device. Returns
137
+ the currently selected :class:`Stream` for the current device, given
138
+ by :func:`~torch.mtia.current_device`, if :attr:`device` is ``None``
139
+ (default).
140
+ """
141
+ return torch._C._mtia_getCurrentStream(_get_device_index(device, optional=True))
142
+
143
+
144
+ def default_stream(device: Optional[_device_t] = None) -> Stream:
145
+ r"""Return the default :class:`Stream` for a given device.
146
+
147
+ Args:
148
+ device (torch.device or int, optional): selected device. Returns
149
+ the default :class:`Stream` for the current device, given by
150
+ :func:`~torch.mtia.current_device`, if :attr:`device` is ``None``
151
+ (default).
152
+ """
153
+ return torch._C._mtia_getDefaultStream(_get_device_index(device, optional=True))
154
+
155
+
156
+ def memory_stats(device: Optional[_device_t] = None) -> Dict[str, Any]:
157
+ r"""Return a dictionary of MTIA memory allocator statistics for a given device.
158
+
159
+ Args:
160
+ device (torch.device or int, optional) selected device. Returns
161
+ statistics for the current device, given by current_device(),
162
+ if device is None (default).
163
+ """
164
+ if not is_initialized():
165
+ return {}
166
+ return torch._C._mtia_memoryStats(_get_device_index(device, optional=True))
167
+
168
+
169
+ def set_stream(stream: Stream):
170
+ r"""Set the current stream.This is a wrapper API to set the stream.
171
+ Usage of this function is discouraged in favor of the ``stream``
172
+ context manager.
173
+
174
+ Args:
175
+ stream (Stream): selected stream. This function is a no-op
176
+ if this argument is ``None``.
177
+ """
178
+ if stream is None:
179
+ return
180
+ torch._C._mtia_setCurrentStream(stream)
181
+
182
+
183
+ def set_device(device: _device_t) -> None:
184
+ r"""Set the current device.
185
+
186
+ Args:
187
+ device (torch.device or int): selected device. This function is a no-op
188
+ if this argument is negative.
189
+ """
190
+ device = _get_device_index(device)
191
+ if device >= 0:
192
+ torch._C._accelerator_hooks_set_current_device(device)
193
+
194
+
195
+ class device:
196
+ r"""Context-manager that changes the selected device.
197
+
198
+ Args:
199
+ device (torch.device or int): device index to select. It's a no-op if
200
+ this argument is a negative integer or ``None``.
201
+ """
202
+
203
+ def __init__(self, device: Any):
204
+ self.idx = _get_device_index(device, optional=True)
205
+ self.prev_idx = -1
206
+
207
+ def __enter__(self):
208
+ self.prev_idx = torch._C._accelerator_hooks_maybe_exchange_device(self.idx)
209
+
210
+ def __exit__(self, type: Any, value: Any, traceback: Any):
211
+ self.idx = torch._C._accelerator_hooks_maybe_exchange_device(self.prev_idx)
212
+ return False
213
+
214
+
215
+ class StreamContext:
216
+ r"""Context-manager that selects a given stream.
217
+
218
+ All MTIA kernels queued within its context will be enqueued on a selected
219
+ stream.
220
+
221
+ Args:
222
+ Stream (Stream): selected stream. This manager is a no-op if it's
223
+ ``None``.
224
+ .. note:: Streams are per-device.
225
+ """
226
+
227
+ cur_stream: Optional["torch.mtia.Stream"]
228
+
229
+ def __init__(self, stream: Optional["torch.mtia.Stream"]):
230
+ self.cur_stream = None
231
+ self.stream = stream
232
+ self.idx = _get_device_index(None, True)
233
+ if not torch.jit.is_scripting():
234
+ if self.idx is None:
235
+ self.idx = -1
236
+
237
+ self.src_prev_stream = (
238
+ None if not torch.jit.is_scripting() else torch.mtia.default_stream(None)
239
+ )
240
+ self.dst_prev_stream = (
241
+ None if not torch.jit.is_scripting() else torch.mtia.default_stream(None)
242
+ )
243
+
244
+ def __enter__(self):
245
+ # Local cur_stream variable for type refinement
246
+ cur_stream = self.stream
247
+ # Return if stream is None or MTIA device not available
248
+ if cur_stream is None or self.idx == -1:
249
+ return
250
+ self.src_prev_stream = torch.mtia.current_stream(None)
251
+
252
+ # If the stream is not on the current device, then
253
+ # set the current stream on the device
254
+ if self.src_prev_stream.device != cur_stream.device:
255
+ with device(cur_stream.device):
256
+ self.dst_prev_stream = torch.mtia.current_stream(cur_stream.device)
257
+ torch.mtia.set_stream(cur_stream)
258
+
259
+ def __exit__(self, type: Any, value: Any, traceback: Any):
260
+ # Local cur_stream variable for type refinement
261
+ cur_stream = self.stream
262
+ # If stream is None or no MTIA device available, return
263
+ if cur_stream is None or self.idx == -1:
264
+ return
265
+
266
+ # Reset the stream on the original device
267
+ # and destination device
268
+ if self.src_prev_stream.device != cur_stream.device: # type: ignore[union-attr]
269
+ torch.mtia.set_stream(self.dst_prev_stream) # type: ignore[arg-type]
270
+ torch.mtia.set_stream(self.src_prev_stream) # type: ignore[arg-type]
271
+
272
+
273
+ def stream(stream: Optional["torch.mtia.Stream"]) -> StreamContext:
274
+ r"""Wrap around the Context-manager StreamContext that selects a given stream.
275
+
276
+ Arguments:
277
+ stream (Stream): selected stream. This manager is a no-op if it's
278
+ ``None``.
279
+ ..Note:: In eager mode stream is of type Stream class while in JIT it doesn't support torch.mtia.stream
280
+ """
281
+ return StreamContext(stream)
282
+
283
+
284
+ def get_rng_state(device: Union[int, str, torch.device] = "mtia") -> Tensor:
285
+ r"""Returns the random number generator state as a ByteTensor.
286
+
287
+ Args:
288
+ device (torch.device or int, optional): The device to return the RNG state of.
289
+ Default: ``'mtia'`` (i.e., ``torch.device('mtia')``, the current mtia device).
290
+ """
291
+ warnings.warn(
292
+ "get_rng_state is not implemented in torch.mtia",
293
+ UserWarning,
294
+ stacklevel=2,
295
+ )
296
+ return torch.zeros([1], dtype=torch.uint8, device=device)
297
+
298
+
299
+ def set_rng_state(
300
+ new_state: Tensor, device: Union[int, str, torch.device] = "mtia"
301
+ ) -> None:
302
+ r"""Sets the random number generator state.
303
+
304
+ Args:
305
+ new_state (torch.ByteTensor): The desired state
306
+ device (torch.device or int, optional): The device to set the RNG state.
307
+ Default: ``'mtia'`` (i.e., ``torch.device('mtia')``, the current mtia device).
308
+ """
309
+ warnings.warn(
310
+ "set_rng_state is not implemented in torch.mtia",
311
+ UserWarning,
312
+ stacklevel=2,
313
+ )
314
+
315
+
316
+ __all__ = [
317
+ "init",
318
+ "is_available",
319
+ "is_initialized",
320
+ "synchronize",
321
+ "device_count",
322
+ "current_device",
323
+ "current_stream",
324
+ "default_stream",
325
+ "memory_stats",
326
+ "set_device",
327
+ "set_stream",
328
+ "stream",
329
+ "device",
330
+ "set_rng_state",
331
+ "get_rng_state",
332
+ ]
.venv/Lib/site-packages/torch/mtia/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (10.2 kB). View file
 
.venv/Lib/site-packages/torch/mtia/__pycache__/_utils.cpython-39.pyc ADDED
Binary file (1.53 kB). View file
 
.venv/Lib/site-packages/torch/multiprocessing/__init__.py ADDED
@@ -0,0 +1,100 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # mypy: allow-untyped-defs
2
+ """torch.multiprocessing is a wrapper around the native :mod:`multiprocessing` module.
3
+
4
+ It registers custom reducers, that use shared memory to provide shared
5
+ views on the same data in different processes. Once the tensor/storage is moved
6
+ to shared_memory (see :func:`~torch.Tensor.share_memory_`), it will be possible
7
+ to send it to other processes without making any copies.
8
+
9
+ The API is 100% compatible with the original module - it's enough to change
10
+ ``import multiprocessing`` to ``import torch.multiprocessing`` to have all the
11
+ tensors sent through the queues or shared via other mechanisms, moved to shared
12
+ memory.
13
+
14
+ Because of the similarity of APIs we do not document most of this package
15
+ contents, and we recommend referring to very good docs of the original module.
16
+ """
17
+ import multiprocessing
18
+ import sys
19
+
20
+ import torch
21
+
22
+ from .reductions import init_reductions
23
+
24
+
25
+ __all__ = ["set_sharing_strategy", "get_sharing_strategy", "get_all_sharing_strategies"]
26
+
27
+
28
+ from multiprocessing import * # noqa: F403
29
+
30
+
31
+ __all__ += multiprocessing.__all__ # noqa: PLE0605 type: ignore[attr-defined]
32
+
33
+
34
+ # This call adds a Linux specific prctl(2) wrapper function to this module.
35
+ # See https://github.com/pytorch/pytorch/pull/14391 for more information.
36
+ torch._C._multiprocessing_init()
37
+
38
+
39
+ """Add helper function to spawn N processes and wait for completion of any of
40
+ them. This depends `mp.get_context` which was added in Python 3.4."""
41
+ from .spawn import (
42
+ ENV_VAR_PARALLEL_START,
43
+ ProcessContext,
44
+ ProcessExitedException,
45
+ ProcessRaisedException,
46
+ spawn,
47
+ SpawnContext,
48
+ start_processes,
49
+ )
50
+
51
+
52
+ if sys.platform == "darwin" or sys.platform == "win32":
53
+ _sharing_strategy = "file_system"
54
+ _all_sharing_strategies = {"file_system"}
55
+ else:
56
+ _sharing_strategy = "file_descriptor"
57
+ _all_sharing_strategies = {"file_descriptor", "file_system"}
58
+
59
+
60
+ def set_sharing_strategy(new_strategy):
61
+ """Set the strategy for sharing CPU tensors.
62
+
63
+ Args:
64
+ new_strategy (str): Name of the selected strategy. Should be one of
65
+ the values returned by :func:`get_all_sharing_strategies()`.
66
+ """
67
+ global _sharing_strategy
68
+ assert new_strategy in _all_sharing_strategies
69
+ _sharing_strategy = new_strategy
70
+
71
+
72
+ def get_sharing_strategy():
73
+ """Return the current strategy for sharing CPU tensors."""
74
+ return _sharing_strategy
75
+
76
+
77
+ def get_all_sharing_strategies():
78
+ """Return a set of sharing strategies supported on a current system."""
79
+ return _all_sharing_strategies
80
+
81
+
82
+ def _set_thread_name(name: str) -> None:
83
+ """Set the name of the current thread.
84
+
85
+ Args:
86
+ name (str): Name of the current thread.
87
+ """
88
+ torch._C._set_thread_name(name)
89
+
90
+
91
+ def _get_thread_name() -> str:
92
+ """Get the name of the current thread.
93
+
94
+ Returns:
95
+ str: Name of the current thread.
96
+ """
97
+ return torch._C._get_thread_name()
98
+
99
+
100
+ init_reductions()
.venv/Lib/site-packages/torch/multiprocessing/__pycache__/reductions.cpython-39.pyc ADDED
Binary file (11.5 kB). View file
 
.venv/Lib/site-packages/torch/multiprocessing/_atfork.py ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # mypy: allow-untyped-defs
2
+ import sys
3
+
4
+
5
+ __all__ = ["register_after_fork"]
6
+
7
+ if sys.platform == "win32":
8
+ import multiprocessing.util as _util
9
+
10
+ def _register(func):
11
+ def wrapper(arg):
12
+ func()
13
+
14
+ _util.register_after_fork(_register, wrapper)
15
+
16
+ else:
17
+ import os
18
+
19
+ def _register(func):
20
+ os.register_at_fork(after_in_child=func)
21
+
22
+
23
+ def register_after_fork(func):
24
+ """Register a callable to be executed in the child process after a fork.
25
+
26
+ Note:
27
+ In python < 3.7 this will only work with processes created using the
28
+ ``multiprocessing`` module. In python >= 3.7 it also works with
29
+ ``os.fork()``.
30
+
31
+ Args:
32
+ func (function): Function taking no arguments to be called in the child after fork
33
+
34
+ """
35
+ _register(func)
.venv/Lib/site-packages/torch/multiprocessing/pool.py ADDED
@@ -0,0 +1,52 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import multiprocessing.pool
2
+ import multiprocessing.util as util
3
+
4
+ from .queue import SimpleQueue
5
+
6
+
7
+ def clean_worker(*args, **kwargs):
8
+ import gc
9
+
10
+ multiprocessing.pool.worker(*args, **kwargs)
11
+ # Regular multiprocessing workers don't fully clean up after themselves,
12
+ # so we have to explicitly trigger garbage collection to make sure that all
13
+ # destructors are called...
14
+ gc.collect()
15
+
16
+
17
+ class Pool(multiprocessing.pool.Pool):
18
+ """Pool implementation which uses our version of SimpleQueue.
19
+
20
+ This lets us pass tensors in shared memory across processes instead of
21
+ serializing the underlying data.
22
+ """
23
+
24
+ def _setup_queues(self):
25
+ self._inqueue = SimpleQueue()
26
+ self._outqueue = SimpleQueue()
27
+ self._quick_put = self._inqueue._writer.send
28
+ self._quick_get = self._outqueue._reader.recv
29
+
30
+ def _repopulate_pool(self):
31
+ """Increase the number of pool processes to the specified number.
32
+
33
+ Bring the number of pool processes up to the specified number, for use after
34
+ reaping workers which have exited.
35
+ """
36
+ for i in range(self._processes - len(self._pool)):
37
+ # changed worker -> clean_worker
38
+ args = (
39
+ self._inqueue,
40
+ self._outqueue,
41
+ self._initializer,
42
+ self._initargs,
43
+ self._maxtasksperchild,
44
+ )
45
+ if hasattr(self, "_wrap_exception"):
46
+ args += (self._wrap_exception,)
47
+ w = self.Process(target=clean_worker, args=args)
48
+ self._pool.append(w)
49
+ w.name = w.name.replace("Process", "PoolWorker")
50
+ w.daemon = True
51
+ w.start()
52
+ util.debug("added worker")
.venv/Lib/site-packages/torch/multiprocessing/queue.py ADDED
@@ -0,0 +1,43 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # mypy: allow-untyped-defs
2
+ import io
3
+ import multiprocessing.queues
4
+ import pickle
5
+ from multiprocessing.reduction import ForkingPickler
6
+
7
+
8
+ class ConnectionWrapper:
9
+ """Proxy class for _multiprocessing.Connection which uses ForkingPickler for object serialization."""
10
+
11
+ def __init__(self, conn):
12
+ self.conn = conn
13
+
14
+ def send(self, obj):
15
+ buf = io.BytesIO()
16
+ ForkingPickler(buf, pickle.HIGHEST_PROTOCOL).dump(obj)
17
+ self.send_bytes(buf.getvalue())
18
+
19
+ def recv(self):
20
+ buf = self.recv_bytes()
21
+ return pickle.loads(buf)
22
+
23
+ def __getattr__(self, name):
24
+ if "conn" in self.__dict__:
25
+ return getattr(self.conn, name)
26
+ raise AttributeError(f"'{type(self).__name__}' object has no attribute 'conn'")
27
+
28
+
29
+ class Queue(multiprocessing.queues.Queue):
30
+ def __init__(self, *args, **kwargs):
31
+ super().__init__(*args, **kwargs)
32
+ self._reader: ConnectionWrapper = ConnectionWrapper(self._reader)
33
+ self._writer: ConnectionWrapper = ConnectionWrapper(self._writer)
34
+ self._send = self._writer.send
35
+ self._recv = self._reader.recv
36
+
37
+
38
+ class SimpleQueue(multiprocessing.queues.SimpleQueue):
39
+ def _make_methods(self):
40
+ if not isinstance(self._reader, ConnectionWrapper):
41
+ self._reader: ConnectionWrapper = ConnectionWrapper(self._reader)
42
+ self._writer: ConnectionWrapper = ConnectionWrapper(self._writer)
43
+ super()._make_methods() # type: ignore[misc]
.venv/Lib/site-packages/torch/multiprocessing/reductions.py ADDED
@@ -0,0 +1,647 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # mypy: allow-untyped-defs
2
+ import multiprocessing
3
+ import os
4
+ import threading
5
+ from multiprocessing.reduction import ForkingPickler
6
+ from multiprocessing.util import register_after_fork
7
+ from typing import Union
8
+
9
+ import torch
10
+ from torch._namedtensor_internals import check_serializing_named_tensor
11
+
12
+
13
+ try:
14
+ # Early load resource_sharer to prevent a partially initialized instance
15
+ # from being inherited in a forked child process. The reduce_storage method
16
+ # requires this module indirectly through DupFd(). The built-in mp.Queue
17
+ # class pickles arguments in a background thread which may overlap with the
18
+ # fork.
19
+ import multiprocessing.resource_sharer
20
+ except ImportError:
21
+ pass
22
+
23
+
24
+ class StorageWeakRef:
25
+ r"""A weak reference to a Storage.
26
+
27
+ The cdata member is a Python number containing the integer representation of
28
+ the Storage pointer.
29
+ """
30
+
31
+ __slots__ = ["cdata", "_free_weak_ref"]
32
+
33
+ def __init__(self, storage):
34
+ self.cdata = storage._weak_ref()
35
+ # Save a direct reference to _free_weak_ref because the `torch` module
36
+ # might be cleared during Python shutdown before this module is cleared.
37
+ self._free_weak_ref = torch.Storage._free_weak_ref # type: ignore[attr-defined]
38
+
39
+ @classmethod
40
+ def from_weakref(cls, cdata):
41
+ instance = cls.__new__(cls)
42
+ instance.cdata = cdata
43
+ instance._free_weak_ref = torch.Storage._free_weak_ref # type: ignore[attr-defined]
44
+ return instance
45
+
46
+ def expired(self):
47
+ return torch.Storage._expired(self.cdata) # type: ignore[attr-defined]
48
+
49
+ def __del__(self):
50
+ self._free_weak_ref(self.cdata)
51
+
52
+ def __hash__(self):
53
+ return self.cdata
54
+
55
+ def __eq__(self, other):
56
+ if id(self) == id(other):
57
+ return True
58
+ return self.cdata == other.cdata
59
+
60
+
61
+ class SharedCache(dict):
62
+ """Dictionary from multiprocessing handles to StorageWeakRef."""
63
+
64
+ def __init__(self) -> None:
65
+ # free_dead_references() is called if the len exceeds the current
66
+ # limit. The limit scales with the number of remaining live objects.
67
+ self.limit = 128
68
+ # `fork` inherits lock state, so in case we fork when the lock is held,
69
+ # we register a function to reset the lock to a new object to avoid
70
+ # possible deadlocks, following python multiprocessing library design.
71
+ self._after_fork()
72
+ register_after_fork(self, SharedCache._after_fork)
73
+
74
+ def _after_fork(self):
75
+ self.lock = threading.Lock()
76
+
77
+ def get(self, key):
78
+ with self.lock:
79
+ return dict.get(self, key)
80
+
81
+ def __setitem__(self, key, storage_ref):
82
+ with self.lock:
83
+ dict.__setitem__(self, key, storage_ref)
84
+ if len(self) > self.limit:
85
+ self.free_dead_references()
86
+
87
+ def free_dead_references(self):
88
+ live = 0
89
+ for key, storage_ref in list(self.items()):
90
+ if storage_ref.expired():
91
+ del self[key]
92
+ else:
93
+ live += 1
94
+ self.limit = max(128, live * 2)
95
+
96
+
97
+ # mapping from handles to StorageWeakRef objects
98
+ shared_cache = SharedCache()
99
+
100
+
101
+ def rebuild_event(device, handle):
102
+ return torch.cuda.Event.from_ipc_handle(device, handle)
103
+
104
+
105
+ def reduce_event(event):
106
+ handle = event.ipc_handle()
107
+ return (rebuild_event, (event.device, handle))
108
+
109
+
110
+ def rebuild_tensor(cls, storage, metadata):
111
+ storage_offset, size, stride, requires_grad = metadata
112
+ t = torch._utils._rebuild_tensor(storage, storage_offset, size, stride)
113
+ if cls == torch.nn.parameter.Parameter:
114
+ # we have to pass requires_grad into constructor, rather than set it as an
115
+ # attribute later, because it's an important check for Integer Tensors to
116
+ # have requires_grad=False (or else they raise an error)
117
+ t = torch.nn.parameter.Parameter(t, requires_grad=requires_grad)
118
+ else:
119
+ t.requires_grad = requires_grad
120
+ return t
121
+
122
+
123
+ def rebuild_meta_tensor(
124
+ tensor_cls,
125
+ tensor_size,
126
+ tensor_stride,
127
+ tensor_offset,
128
+ dtype,
129
+ storage_size_bytes,
130
+ requires_grad,
131
+ ):
132
+ untyped_storage = torch.UntypedStorage(storage_size_bytes, device="meta")
133
+
134
+ typed_storage = torch.TypedStorage(
135
+ wrap_storage=untyped_storage, dtype=dtype, _internal=True
136
+ )
137
+
138
+ t = torch._utils._rebuild_tensor(
139
+ typed_storage,
140
+ tensor_offset,
141
+ tensor_size,
142
+ tensor_stride,
143
+ )
144
+
145
+ if tensor_cls == torch.nn.parameter.Parameter:
146
+ # It is crucial for integer tensors to receive
147
+ # the requires_grad=False as an argument in the constructor
148
+ t = torch.nn.parameter.Parameter(t, requires_grad=requires_grad)
149
+ else:
150
+ t.requires_grad = requires_grad
151
+
152
+ return t
153
+
154
+
155
+ def rebuild_cuda_tensor(
156
+ tensor_cls,
157
+ tensor_size,
158
+ tensor_stride,
159
+ tensor_offset,
160
+ storage_cls,
161
+ dtype,
162
+ storage_device,
163
+ storage_handle,
164
+ storage_size_bytes,
165
+ storage_offset_bytes,
166
+ requires_grad,
167
+ ref_counter_handle,
168
+ ref_counter_offset,
169
+ event_handle,
170
+ event_sync_required,
171
+ ):
172
+ # If storage_handle is None, storage points to nullptr.
173
+ if storage_handle is None or storage_size_bytes == 0:
174
+ storage = storage_cls(0, dtype=dtype, device=storage_device, _internal=True)
175
+ else:
176
+ storage = storage_from_cache(
177
+ storage_cls, (storage_handle, storage_offset_bytes)
178
+ )
179
+ if storage is None:
180
+ torch.cuda._lazy_init()
181
+ storage = storage_cls._new_shared_cuda(
182
+ storage_device,
183
+ storage_handle,
184
+ storage_size_bytes,
185
+ storage_offset_bytes,
186
+ ref_counter_handle,
187
+ ref_counter_offset,
188
+ event_handle,
189
+ event_sync_required,
190
+ )
191
+ shared_cache[(storage_handle, storage_offset_bytes)] = StorageWeakRef(
192
+ storage
193
+ )
194
+ else:
195
+ # We already ref counting this Storage, but producer needs new ref-counters to be released.
196
+ storage_cls._release_ipc_counter(
197
+ ref_counter_handle, ref_counter_offset, device=storage_device
198
+ )
199
+
200
+ _storage = (
201
+ storage
202
+ if isinstance(storage, torch.UntypedStorage)
203
+ else storage._untyped_storage
204
+ )
205
+
206
+ t = torch._utils._rebuild_tensor(
207
+ torch.storage.TypedStorage(wrap_storage=_storage, dtype=dtype, _internal=True),
208
+ tensor_offset,
209
+ tensor_size,
210
+ tensor_stride,
211
+ )
212
+
213
+ if tensor_cls == torch.nn.parameter.Parameter:
214
+ # It is crucial for integer tensors to receive
215
+ # the requires_grad=False as an argument in the constructor
216
+ t = torch.nn.parameter.Parameter(t, requires_grad=requires_grad)
217
+ else:
218
+ t.requires_grad = requires_grad
219
+
220
+ return t
221
+
222
+
223
+ def reduce_tensor(tensor):
224
+ if tensor.requires_grad and not tensor.is_leaf:
225
+ raise RuntimeError(
226
+ "Cowardly refusing to serialize non-leaf tensor which requires_grad, "
227
+ "since autograd does not support crossing process boundaries. "
228
+ "If you just want to transfer the data, call detach() on the tensor "
229
+ "before serializing (e.g., putting it on the queue)."
230
+ )
231
+
232
+ check_serializing_named_tensor(tensor)
233
+ torch.utils.hooks.warn_if_has_hooks(tensor)
234
+
235
+ # Note [CUDA IPC and the caching allocator]
236
+ # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
237
+ # When you send a CUDA tensor over IPC, you might expect that you will
238
+ # get out the same storage from the other end. However, the CUDA caching
239
+ # allocator makes it difficult to preserve this invariant. Consider
240
+ # the following situation: a tensor of size 0x100 points to offset 0x20 of
241
+ # a storage at 0xA100 of size 0x100. (For simplicity, all of these
242
+ # sizes are given in bytes). HOWEVER, with the caching allocator, this storage
243
+ # might be part of a larger cudaMalloc allocation 0xA000 of size 0x4000.
244
+ #
245
+ # When we want to send this CUDA tensor over IPC, we must send the
246
+ # *entire* cudaMalloc allocation, i.e., the 0xA000 region, not just
247
+ # the storage 0xA100 (because that is what CUDA supports). So, on the
248
+ # other end, there simply isn't any way to say, "Wait, you gave me
249
+ # a bigger region (0xA000) than the one I wanted (0xA100)".
250
+ #
251
+ # OK, so if you sent the cudaMalloc allocation, can you just wrap that up as
252
+ # one storage itself? No, because this cudaMalloc allocation might contain
253
+ # storages of mixed types: float, bytes, double... If you make the entire
254
+ # allocation a single storage of a type A, we'll hit an error when constructing
255
+ # a tensor of type B on the storage.
256
+ #
257
+ # cudaIpcMemHandle is an identifier to access the sender cudaMalloc allocation on the
258
+ # receiver side. However, cudaIpcMemHandles from each device in a given process may
259
+ # only be opened by one context per device per other process.
260
+ # If we open and close a memory handle multiples times in a process, CUDA is allowed
261
+ # to give it a different address; similarly, once we close the memory, we're not
262
+ # allowed to access it(and the storage/tensor built on top of it), even if it is
263
+ # still live in the original process. As we cannot make a cudaMalloc allocation
264
+ # to a single storage in one go, this requires us to cache the device pointer for
265
+ # each cudaIpcMemHandle on C++ side to reconstruct types of storages, while keep
266
+ # the old ones alives.
267
+ # See [https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__DEVICE.html]
268
+ #
269
+ # This is fine, because all we need to do is to save our position in the allocation,
270
+ # and reconstruct storage and tensor from it.
271
+ # 0xA000 -> -------CUDA Allocation------
272
+ # | |
273
+ # | |
274
+ # | |
275
+ # | |
276
+ # 0xA100 -> --------storage1 begin------
277
+ # | |
278
+ # 0xA120 -> --------tensor1 begin ------
279
+ # | |
280
+ # | |
281
+ # | |
282
+ # | |
283
+ # | |
284
+ # 0xA160 -> --------tensor1 end---------
285
+ # | |
286
+ # | |
287
+ # | |
288
+ # 0xA200 -> --------storage1 end--------
289
+ # | |
290
+ # 0xE000 -> --------CUDA allocation-----
291
+ #
292
+ # To send tensor1, the following info are required from sender to receiver for
293
+ # storage recontruction.
294
+ # 1. cudaIpcMemHandle of 0xA000(which can be mapped to a basePtr in receiver process).
295
+ # basePtr may not be exactly 0xA000 since it's a different process.
296
+ # 2. offset(0xA100) of storage1 in the CUDA allocation.
297
+ # 3. size of storage1(0x100).
298
+ #
299
+ # On receiver side:
300
+ # 1. Get the devPtr of the MemHandle to access the memory, reconstruct a storage
301
+ # of the same type using (basePtr, offset, size).
302
+ # 2. we can reconstruct the tensor on top of the reconstructed storage
303
+ # Tensor(size=0x040, offset=0x020, storage=Storage(data=basePtr+0xA100, size=0x0100))
304
+ #
305
+ # This strategy has a few implications:
306
+ #
307
+ # 1. When we serialize a CUDA tensor for IPC, we cannot do it all in one
308
+ # go (non-compositionally), and this requires to have a global map
309
+ # memHandle -> devPtr for each process.
310
+ #
311
+ # 2. We MUST NOT let the new IPC tensor be resizable. Originally, a resize
312
+ # of the storage beyond 0x100 would merely have caused us to do a
313
+ # reallocation. You don't really want to do this, but if you did,
314
+ # all that would happen is that you would lose IPC sharing. But if
315
+ # you do this in the new world, we will happily let you write out of
316
+ # bounds of your "allocation", clobbering unrelated data in the cached
317
+ # allocator block. BAD!
318
+ #
319
+ # By the way, in old versions of PyTorch, we supported this situation
320
+ # natively using a "storage view", which permitted multiple storages to be
321
+ # views on each other. But this was the *only* use of storage views, so we
322
+ # eliminated it so that we could just use tensor views to implement the same
323
+ # thing.
324
+ #
325
+
326
+ # TODO: Handle distinguishing between subclass and non-subclass versions of NT better
327
+ # https://github.com/pytorch/pytorch/issues/110543
328
+ from torch.nested._internal.nested_tensor import NestedTensor
329
+
330
+ if tensor.is_nested and not isinstance(tensor, NestedTensor):
331
+ return reduce_nested_tensor(tensor)
332
+
333
+ if tensor.layout in {
334
+ torch.sparse_coo,
335
+ torch.sparse_csr,
336
+ torch.sparse_bsr,
337
+ torch.sparse_csc,
338
+ torch.sparse_bsc,
339
+ }:
340
+ return reduce_sparse_tensor(tensor)
341
+
342
+ storage = tensor._typed_storage()
343
+
344
+ if storage._untyped_storage.device.type == "cuda":
345
+ (
346
+ device,
347
+ handle,
348
+ storage_size_bytes,
349
+ storage_offset_bytes,
350
+ ref_counter_handle,
351
+ ref_counter_offset,
352
+ event_handle,
353
+ event_sync_required,
354
+ ) = storage._share_cuda_()
355
+ tensor_offset = tensor.storage_offset()
356
+ shared_cache[handle] = StorageWeakRef(storage)
357
+ # _backward_hooks purposely omitted here, see
358
+ # Note [Don't serialize hooks]
359
+ return (
360
+ rebuild_cuda_tensor,
361
+ (
362
+ type(tensor),
363
+ tensor.size(),
364
+ tensor.stride(),
365
+ tensor_offset, # tensor offset in its storage
366
+ type(storage),
367
+ tensor.dtype,
368
+ device,
369
+ handle, # identifier which CUDA allocation is the storage in.
370
+ storage_size_bytes, # size(in bytes) of the storage
371
+ storage_offset_bytes, # offset(in bytes) of the storage in the CUDA allocation
372
+ tensor.requires_grad,
373
+ ref_counter_handle,
374
+ ref_counter_offset,
375
+ event_handle,
376
+ event_sync_required,
377
+ ),
378
+ )
379
+ elif storage._untyped_storage.device.type == "meta":
380
+ return (
381
+ rebuild_meta_tensor,
382
+ (
383
+ type(tensor),
384
+ tensor.size(),
385
+ tensor.stride(),
386
+ tensor.storage_offset(),
387
+ tensor.dtype,
388
+ tensor.untyped_storage().size(),
389
+ tensor.requires_grad,
390
+ ),
391
+ )
392
+
393
+ # _backward_hooks purposely omitted here, see Note [Don't serialize hooks]
394
+ metadata = (
395
+ tensor.storage_offset(),
396
+ tensor.size(),
397
+ tensor.stride(),
398
+ tensor.requires_grad,
399
+ )
400
+ return (rebuild_tensor, (type(tensor), storage, metadata))
401
+
402
+
403
+ def rebuild_nested_tensor(
404
+ rebuild_buffer_func,
405
+ rebuild_buffer_args,
406
+ rebuild_sizes_func,
407
+ rebuild_sizes_args,
408
+ rebuild_strides_func,
409
+ rebuild_strides_args,
410
+ rebuild_offsets_func,
411
+ rebuild_offsets_args,
412
+ ):
413
+ buffer = rebuild_buffer_func(*rebuild_buffer_args)
414
+ sizes = rebuild_sizes_func(*rebuild_sizes_args)
415
+ strides = rebuild_strides_func(*rebuild_strides_args)
416
+ offsets = rebuild_offsets_func(*rebuild_offsets_args)
417
+ return torch._nested_view_from_buffer_copy(buffer, sizes, strides, offsets)
418
+
419
+
420
+ def reduce_nested_tensor(nt):
421
+ rebuild_buffer_func, rebuild_buffer_args = reduce_tensor(nt.values())
422
+ rebuild_sizes_func, rebuild_sizes_args = reduce_tensor(nt._nested_tensor_size())
423
+ rebuild_strides_func, rebuild_strides_args = reduce_tensor(
424
+ nt._nested_tensor_strides()
425
+ )
426
+ rebuild_offsets_func, rebuild_offsets_args = reduce_tensor(
427
+ nt._nested_tensor_storage_offsets()
428
+ )
429
+
430
+ return (
431
+ rebuild_nested_tensor,
432
+ (
433
+ rebuild_buffer_func,
434
+ rebuild_buffer_args,
435
+ rebuild_sizes_func,
436
+ rebuild_sizes_args,
437
+ rebuild_strides_func,
438
+ rebuild_strides_args,
439
+ rebuild_offsets_func,
440
+ rebuild_offsets_args,
441
+ ),
442
+ )
443
+
444
+
445
+ def rebuild_sparse_coo_tensor(
446
+ rebuild_indices_func,
447
+ rebuild_indices_args,
448
+ rebuild_values_func,
449
+ rebuild_values_args,
450
+ shape,
451
+ is_coalesced,
452
+ ):
453
+ indices = rebuild_indices_func(*rebuild_indices_args)
454
+ values = rebuild_values_func(*rebuild_values_args)
455
+ return torch.sparse_coo_tensor(indices, values, shape, is_coalesced=is_coalesced)
456
+
457
+
458
+ def rebuild_sparse_compressed_tensor(
459
+ rebuild_compressed_indices_func,
460
+ rebuild_compressed_indices_args,
461
+ rebuild_plain_indices_func,
462
+ rebuild_plain_indices_args,
463
+ rebuild_values_func,
464
+ rebuild_values_args,
465
+ shape,
466
+ layout,
467
+ ):
468
+ compressed_indices = rebuild_compressed_indices_func(
469
+ *rebuild_compressed_indices_args
470
+ )
471
+ plain_indices = rebuild_plain_indices_func(*rebuild_plain_indices_args)
472
+ values = rebuild_values_func(*rebuild_values_args)
473
+ return torch.sparse_compressed_tensor(
474
+ compressed_indices, plain_indices, values, shape, layout=layout
475
+ )
476
+
477
+
478
+ def reduce_sparse_tensor(sparse):
479
+ if sparse.layout is torch.sparse_coo:
480
+ rebuild_indices_func, rebuild_indices_args = reduce_tensor(sparse._indices())
481
+ rebuild_values_func, rebuild_values_args = reduce_tensor(sparse._values())
482
+ return (
483
+ rebuild_sparse_coo_tensor,
484
+ (
485
+ rebuild_indices_func,
486
+ rebuild_indices_args,
487
+ rebuild_values_func,
488
+ rebuild_values_args,
489
+ sparse.shape,
490
+ sparse.is_coalesced(),
491
+ ),
492
+ )
493
+ else:
494
+ if sparse.layout in {torch.sparse_csr, torch.sparse_bsr}:
495
+ compressed_indices = sparse.crow_indices()
496
+ plain_indices = sparse.col_indices()
497
+ elif sparse.layout in {torch.sparse_csc, torch.sparse_bsc}:
498
+ compressed_indices = sparse.ccol_indices()
499
+ plain_indices = sparse.row_indices()
500
+ else:
501
+ raise NotImplementedError(sparse.layout)
502
+ (
503
+ rebuild_compressed_indices_func,
504
+ rebuild_compressed_indices_args,
505
+ ) = reduce_tensor(compressed_indices)
506
+ rebuild_plain_indices_func, rebuild_plain_indices_args = reduce_tensor(
507
+ plain_indices
508
+ )
509
+ rebuild_values_func, rebuild_values_args = reduce_tensor(sparse.values())
510
+ return (
511
+ rebuild_sparse_compressed_tensor,
512
+ (
513
+ rebuild_compressed_indices_func,
514
+ rebuild_compressed_indices_args,
515
+ rebuild_plain_indices_func,
516
+ rebuild_plain_indices_args,
517
+ rebuild_values_func,
518
+ rebuild_values_args,
519
+ sparse.shape,
520
+ sparse.layout,
521
+ ),
522
+ )
523
+
524
+
525
+ def fd_id(fd):
526
+ # Returns a tuple which uniquely identifies a file descriptor. In Mac OS,
527
+ # this doesn't work with shared memory handles, which is why we don't
528
+ # support the "file_descriptor" sharing method on that platform.
529
+ stat = os.fstat(fd)
530
+ return (stat.st_ino, stat.st_dev)
531
+
532
+
533
+ def storage_from_cache(cls, key):
534
+ storage_ref = shared_cache.get(key)
535
+ if storage_ref is None:
536
+ return None
537
+ return torch.UntypedStorage._new_with_weak_ptr(storage_ref.cdata)
538
+
539
+
540
+ def rebuild_storage_fd(cls, df, size):
541
+ fd = df.detach()
542
+ try:
543
+ storage = storage_from_cache(cls, fd_id(fd))
544
+ if storage is not None:
545
+ return storage
546
+ storage = cls._new_shared_fd_cpu(fd, size)
547
+ shared_cache[fd_id(fd)] = StorageWeakRef(storage)
548
+ return storage
549
+ finally:
550
+ os.close(fd)
551
+
552
+
553
+ def rebuild_storage_filename(cls, manager, handle, size, dtype=None):
554
+ storage: Union[torch.TypedStorage, torch.UntypedStorage] = storage_from_cache(
555
+ cls, handle
556
+ )
557
+ if storage is not None:
558
+ return storage._shared_decref()
559
+ if dtype is None:
560
+ storage = torch.UntypedStorage._new_shared_filename_cpu(manager, handle, size)
561
+ else:
562
+ byte_size = size * torch._utils._element_size(dtype)
563
+ untyped_storage: torch.UntypedStorage = (
564
+ torch.UntypedStorage._new_shared_filename_cpu(manager, handle, byte_size)
565
+ )
566
+ storage = torch.TypedStorage(
567
+ wrap_storage=untyped_storage, dtype=dtype, _internal=True
568
+ )
569
+ shared_cache[handle] = StorageWeakRef(storage)
570
+ return storage._shared_decref()
571
+
572
+
573
+ def rebuild_storage_empty(cls):
574
+ return cls()
575
+
576
+
577
+ def rebuild_typed_storage(storage, dtype):
578
+ return torch.storage.TypedStorage(wrap_storage=storage, dtype=dtype, _internal=True)
579
+
580
+
581
+ # Use for torch.storage.TypedStorage
582
+ def reduce_typed_storage(storage):
583
+ return (rebuild_typed_storage, (storage._untyped_storage, storage.dtype))
584
+
585
+
586
+ def rebuild_typed_storage_child(storage, storage_type):
587
+ return storage_type(wrap_storage=storage, _internal=True)
588
+
589
+
590
+ # Use for child classes of torch.storage.TypedStorage, like torch.FloatStorage
591
+ def reduce_typed_storage_child(storage):
592
+ return (rebuild_typed_storage_child, (storage._untyped_storage, type(storage)))
593
+
594
+
595
+ def reduce_storage(storage):
596
+ from . import get_sharing_strategy
597
+
598
+ if storage.is_cuda:
599
+ raise RuntimeError(
600
+ "Cannot pickle CUDA storage; try pickling a CUDA tensor instead"
601
+ )
602
+ elif storage.device.type == "meta":
603
+ raise RuntimeError(
604
+ "Cannot pickle meta storage; try pickling a meta tensor instead"
605
+ )
606
+ elif get_sharing_strategy() == "file_system":
607
+ metadata = storage._share_filename_cpu_()
608
+ cache_key = metadata[1]
609
+ rebuild = rebuild_storage_filename
610
+ if isinstance(storage, torch.TypedStorage):
611
+ metadata += (storage.dtype,)
612
+ storage._shared_incref()
613
+ elif storage.size() == 0:
614
+ # This is special cased because Empty tensors
615
+ # (with size 0) cannot be mmapped.
616
+ return (rebuild_storage_empty, (type(storage),))
617
+ else:
618
+ fd, size = storage._share_fd_cpu_()
619
+ df = multiprocessing.reduction.DupFd(fd)
620
+ cache_key = fd_id(fd)
621
+ metadata = (df, size)
622
+ rebuild = rebuild_storage_fd # type: ignore[assignment]
623
+
624
+ shared_cache[cache_key] = StorageWeakRef(storage)
625
+ return (rebuild, (type(storage),) + metadata)
626
+
627
+
628
+ def init_reductions():
629
+ ForkingPickler.register(torch.cuda.Event, reduce_event)
630
+
631
+ for t in torch._storage_classes:
632
+ if t.__name__ == "UntypedStorage":
633
+ ForkingPickler.register(t, reduce_storage)
634
+ else:
635
+ ForkingPickler.register(t, reduce_typed_storage_child)
636
+
637
+ ForkingPickler.register(torch.storage.TypedStorage, reduce_typed_storage)
638
+
639
+ for t in torch._tensor_classes:
640
+ ForkingPickler.register(t, reduce_tensor)
641
+
642
+ # TODO: Maybe this should be in tensor_classes? :)
643
+ ForkingPickler.register(torch.Tensor, reduce_tensor)
644
+
645
+ from torch.nn.parameter import Parameter
646
+
647
+ ForkingPickler.register(Parameter, reduce_tensor)
.venv/Lib/site-packages/torch/multiprocessing/spawn.py ADDED
@@ -0,0 +1,328 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # mypy: allow-untyped-defs
2
+ import logging
3
+ import multiprocessing
4
+ import multiprocessing.connection
5
+ import os
6
+ import pickle
7
+ import signal
8
+ import sys
9
+ import tempfile
10
+ import time
11
+ import warnings
12
+ from concurrent.futures import as_completed, ThreadPoolExecutor
13
+ from typing import Optional
14
+
15
+ from . import _prctl_pr_set_pdeathsig # type: ignore[attr-defined]
16
+
17
+
18
+ ENV_VAR_PARALLEL_START = "TORCH_MP_PARALLEL_START"
19
+
20
+ log = logging.getLogger(__name__)
21
+
22
+ __all__ = [
23
+ "ProcessContext",
24
+ "ProcessException",
25
+ "ProcessExitedException",
26
+ "ProcessRaisedException",
27
+ "spawn",
28
+ "SpawnContext",
29
+ "start_processes",
30
+ ]
31
+
32
+
33
+ class ProcessException(Exception):
34
+ __slots__ = ["error_index", "error_pid"]
35
+
36
+ def __init__(self, msg: str, error_index: int, pid: int):
37
+ super().__init__(msg)
38
+ self.msg = msg
39
+ self.error_index = error_index
40
+ self.pid = pid
41
+
42
+ def __reduce__(self):
43
+ return type(self), (self.msg, self.error_index, self.pid)
44
+
45
+
46
+ class ProcessRaisedException(ProcessException):
47
+ """Exception raised when a process failed due to an exception raised by the code."""
48
+
49
+ def __init__(
50
+ self,
51
+ msg: str,
52
+ error_index: int,
53
+ error_pid: int,
54
+ ):
55
+ super().__init__(msg, error_index, error_pid)
56
+
57
+
58
+ class ProcessExitedException(ProcessException):
59
+ """Exception raised when a process failed due to signal or exited with a specific code."""
60
+
61
+ __slots__ = ["exit_code"]
62
+
63
+ def __init__(
64
+ self,
65
+ msg: str,
66
+ error_index: int,
67
+ error_pid: int,
68
+ exit_code: int,
69
+ signal_name: Optional[str] = None,
70
+ ):
71
+ super().__init__(msg, error_index, error_pid)
72
+ self.exit_code = exit_code
73
+ self.signal_name = signal_name
74
+
75
+ def __reduce__(self):
76
+ return (
77
+ type(self),
78
+ (self.msg, self.error_index, self.pid, self.exit_code, self.signal_name),
79
+ )
80
+
81
+
82
+ def _wrap(fn, i, args, error_file):
83
+ # prctl(2) is a Linux specific system call.
84
+ # On other systems the following function call has no effect.
85
+ # This is set to ensure that non-daemonic child processes can
86
+ # terminate if their parent terminates before they do.
87
+ _prctl_pr_set_pdeathsig(signal.SIGINT)
88
+
89
+ try:
90
+ fn(i, *args)
91
+ except KeyboardInterrupt:
92
+ pass # SIGINT; Killed by parent, do nothing
93
+ except Exception:
94
+ # Propagate exception to parent process, keeping original traceback
95
+ import traceback
96
+
97
+ with open(error_file, "wb") as fh:
98
+ pickle.dump(traceback.format_exc(), fh)
99
+ sys.exit(1)
100
+
101
+
102
+ class ProcessContext:
103
+ def __init__(self, processes, error_files):
104
+ self.error_files = error_files
105
+ self.processes = processes
106
+ self.sentinels = {
107
+ process.sentinel: index for index, process in enumerate(processes)
108
+ }
109
+
110
+ def pids(self):
111
+ return [int(process.pid) for process in self.processes]
112
+
113
+ def join(self, timeout=None):
114
+ r"""Join one or more processes within spawn context.
115
+
116
+ Attempt to join one or more processes in this spawn context.
117
+ If one of them exited with a non-zero exit status, this function
118
+ kills the remaining processes and raises an exception with the cause
119
+ of the first process exiting.
120
+
121
+ Returns ``True`` if all processes have been joined successfully,
122
+ ``False`` if there are more processes that need to be joined.
123
+
124
+ Args:
125
+ timeout (float): Wait this long before giving up on waiting.
126
+ """
127
+ # Ensure this function can be called even when we're done.
128
+ if len(self.sentinels) == 0:
129
+ return True
130
+
131
+ # Wait for any process to fail or all of them to succeed.
132
+ ready = multiprocessing.connection.wait(
133
+ self.sentinels.keys(),
134
+ timeout=timeout,
135
+ )
136
+
137
+ error_index = None
138
+ for sentinel in ready:
139
+ index = self.sentinels.pop(sentinel)
140
+ process = self.processes[index]
141
+ process.join()
142
+ if process.exitcode != 0:
143
+ error_index = index
144
+ break
145
+
146
+ # Return if there was no error.
147
+ if error_index is None:
148
+ # Return whether or not all processes have been joined.
149
+ return len(self.sentinels) == 0
150
+
151
+ # Assume failure. Terminate processes that are still alive.
152
+ # Try SIGTERM then SIGKILL if the process isn't going down.
153
+ # The reason is related to python signal handling is limited
154
+ # to main thread and if that is in c/c++ land and stuck it won't
155
+ # to handle it. We have seen processes getting stuck not handling
156
+ # SIGTERM for the above reason.
157
+ timeout: int = 30
158
+ for process in self.processes:
159
+ if process.is_alive():
160
+ log.warning("Terminating process %s via signal SIGTERM", process.pid)
161
+ process.terminate()
162
+ end = time.monotonic() + timeout
163
+ for process in self.processes:
164
+ time_to_wait = max(0, end - time.monotonic())
165
+ process.join(time_to_wait)
166
+ for process in self.processes:
167
+ if process.is_alive():
168
+ log.warning(
169
+ "Unable to shutdown process %s via SIGTERM , forcefully exiting via SIGKILL",
170
+ process.pid,
171
+ )
172
+ process.kill()
173
+ process.join()
174
+
175
+ # The file will only be created if the process crashed.
176
+ failed_process = self.processes[error_index]
177
+ if not os.access(self.error_files[error_index], os.R_OK):
178
+ exitcode = self.processes[error_index].exitcode
179
+ if exitcode < 0:
180
+ try:
181
+ name = signal.Signals(-exitcode).name
182
+ except ValueError:
183
+ name = f"<Unknown signal {-exitcode}>"
184
+ raise ProcessExitedException(
185
+ "process %d terminated with signal %s" % (error_index, name),
186
+ error_index=error_index,
187
+ error_pid=failed_process.pid,
188
+ exit_code=exitcode,
189
+ signal_name=name,
190
+ )
191
+ else:
192
+ raise ProcessExitedException(
193
+ "process %d terminated with exit code %d" % (error_index, exitcode),
194
+ error_index=error_index,
195
+ error_pid=failed_process.pid,
196
+ exit_code=exitcode,
197
+ )
198
+
199
+ with open(self.error_files[error_index], "rb") as fh:
200
+ original_trace = pickle.load(fh)
201
+ msg = "\n\n-- Process %d terminated with the following error:\n" % error_index
202
+ msg += original_trace
203
+ raise ProcessRaisedException(msg, error_index, failed_process.pid)
204
+
205
+
206
+ class SpawnContext(ProcessContext):
207
+ def __init__(self, processes, error_files):
208
+ warnings.warn("SpawnContext is renamed to ProcessContext since 1.4 release.")
209
+ super().__init__(processes, error_files)
210
+
211
+
212
+ # Note: [start_processes]
213
+ # mp.start_processes handles both start_method='spawn' and 'fork'. It's supposed to be a
214
+ # more generalized API than mp.spawn. Currently we only document mp.spawn as it's the
215
+ # CUDA compatible start_method. However, in environments like Ipython notebooks, 'fork'
216
+ # works better than 'spawn'. Every helper function we created for mp.spawn is indeed
217
+ # general enough, and backends like XLA can reuse them in Colab notebooks as well.
218
+ # Currently we only add this API first, we can consider adding it to documentation as
219
+ # needed in the future.
220
+ def start_processes(
221
+ fn,
222
+ args=(),
223
+ nprocs=1,
224
+ join=True,
225
+ daemon=False,
226
+ start_method="spawn",
227
+ ):
228
+ # To speed up performance in certain cases (see https://github.com/pytorch/pytorch/issues/133010),
229
+ # this func will start processes in parallel if start_method is 'forkserver'.
230
+ # Please opt in to this perf optimization by setting env var (TORCH_MP_PARALLEL_START) to 1.
231
+ # todo: investigate why spawn does not work with threadpool and raises SIGINT
232
+ if (
233
+ start_method == "forkserver"
234
+ and os.environ.get(ENV_VAR_PARALLEL_START, "0") == "1"
235
+ ):
236
+ log.info("Starting processes in parallel.")
237
+ start_parallel = True
238
+ else:
239
+ # Set env var TORCH_MP_PARALLEL_START to 0 to disable parallel start
240
+ start_parallel = False
241
+
242
+ mp = multiprocessing.get_context(start_method)
243
+ error_files = [None] * nprocs
244
+ processes = [None] * nprocs
245
+
246
+ def start_process(i):
247
+ # Each process is assigned a file to write tracebacks to. We
248
+ # use the file being non-empty to indicate an exception
249
+ # occurred (vs an expected shutdown). Note: this previously
250
+ # used a multiprocessing.Queue but that can be prone to
251
+ # deadlocks, so we went with a simpler solution for a one-shot
252
+ # message between processes.
253
+ tf = tempfile.NamedTemporaryFile(
254
+ prefix="pytorch-errorfile-", suffix=".pickle", delete=False
255
+ )
256
+ tf.close()
257
+ os.unlink(tf.name)
258
+ process = mp.Process(
259
+ target=_wrap,
260
+ args=(fn, i, args, tf.name),
261
+ daemon=daemon,
262
+ )
263
+ process.start()
264
+ return i, process, tf.name
265
+
266
+ if not start_parallel:
267
+ for i in range(nprocs):
268
+ idx, process, tf_name = start_process(i)
269
+ error_files[idx] = tf_name
270
+ processes[idx] = process
271
+ else:
272
+ with ThreadPoolExecutor(max_workers=nprocs) as executor:
273
+ futures = [executor.submit(start_process, i) for i in range(nprocs)]
274
+ for fut in as_completed(futures):
275
+ idx, process, tf_name = fut.result()
276
+ # idx and process rank needs to be the same.
277
+ error_files[idx] = tf_name
278
+ processes[idx] = process
279
+ context = ProcessContext(processes, error_files)
280
+ if not join:
281
+ return context
282
+
283
+ # Loop on join until it returns True or raises an exception.
284
+ while not context.join():
285
+ pass
286
+
287
+
288
+ def spawn(fn, args=(), nprocs=1, join=True, daemon=False, start_method="spawn"):
289
+ r"""Spawns ``nprocs`` processes that run ``fn`` with ``args``.
290
+
291
+ If one of the processes exits with a non-zero exit status, the
292
+ remaining processes are killed and an exception is raised with the
293
+ cause of termination. In the case an exception was caught in the
294
+ child process, it is forwarded and its traceback is included in
295
+ the exception raised in the parent process.
296
+
297
+ Args:
298
+ fn (function): Function is called as the entrypoint of the
299
+ spawned process. This function must be defined at the top
300
+ level of a module so it can be pickled and spawned. This
301
+ is a requirement imposed by multiprocessing.
302
+
303
+ The function is called as ``fn(i, *args)``, where ``i`` is
304
+ the process index and ``args`` is the passed through tuple
305
+ of arguments.
306
+
307
+ args (tuple): Arguments passed to ``fn``.
308
+ nprocs (int): Number of processes to spawn.
309
+ join (bool): Perform a blocking join on all processes.
310
+ daemon (bool): The spawned processes' daemon flag. If set to True,
311
+ daemonic processes will be created.
312
+ start_method (str): (deprecated) this method will always use ``spawn``
313
+ as the start method. To use a different start method
314
+ use ``start_processes()``.
315
+
316
+ Returns:
317
+ None if ``join`` is ``True``,
318
+ :class:`~ProcessContext` if ``join`` is ``False``
319
+
320
+ """
321
+ if start_method != "spawn":
322
+ msg = (
323
+ f"This method only supports start_method=spawn (got: {start_method}).\n"
324
+ "To use a different start_method use:\n\t\t"
325
+ " torch.multiprocessing.start_processes(...)"
326
+ )
327
+ warnings.warn(msg, FutureWarning, stacklevel=2)
328
+ return start_processes(fn, args, nprocs, join, daemon, start_method="spawn")
.venv/Lib/site-packages/torch/nn/parallel/__init__.py ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # mypy: allow-untyped-defs
2
+ from typing_extensions import deprecated
3
+
4
+ from torch.nn.parallel.data_parallel import data_parallel, DataParallel
5
+ from torch.nn.parallel.distributed import DistributedDataParallel
6
+ from torch.nn.parallel.parallel_apply import parallel_apply
7
+ from torch.nn.parallel.replicate import replicate
8
+ from torch.nn.parallel.scatter_gather import gather, scatter
9
+
10
+
11
+ __all__ = [
12
+ "replicate",
13
+ "scatter",
14
+ "parallel_apply",
15
+ "gather",
16
+ "data_parallel",
17
+ "DataParallel",
18
+ "DistributedDataParallel",
19
+ ]
20
+
21
+
22
+ @deprecated(
23
+ "`torch.nn.parallel.DistributedDataParallelCPU` is deprecated, "
24
+ "please use `torch.nn.parallel.DistributedDataParallel` instead.",
25
+ category=FutureWarning,
26
+ )
27
+ class DistributedDataParallelCPU(DistributedDataParallel):
28
+ pass
.venv/Lib/site-packages/torch/nn/parallel/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (1.03 kB). View file
 
.venv/Lib/site-packages/torch/nn/parallel/__pycache__/_functions.cpython-39.pyc ADDED
Binary file (5.92 kB). View file
 
.venv/Lib/site-packages/torch/nn/parallel/__pycache__/comm.cpython-39.pyc ADDED
Binary file (10.5 kB). View file
 
.venv/Lib/site-packages/torch/nn/parallel/__pycache__/data_parallel.cpython-39.pyc ADDED
Binary file (10.8 kB). View file
 
.venv/Lib/site-packages/torch/nn/parallel/__pycache__/distributed.cpython-39.pyc ADDED
Binary file (81.5 kB). View file
 
.venv/Lib/site-packages/torch/nn/parallel/__pycache__/parallel_apply.cpython-39.pyc ADDED
Binary file (4.06 kB). View file
 
.venv/Lib/site-packages/torch/nn/parallel/__pycache__/replicate.cpython-39.pyc ADDED
Binary file (5.24 kB). View file
 
.venv/Lib/site-packages/torch/nn/parallel/__pycache__/scatter_gather.cpython-39.pyc ADDED
Binary file (5.19 kB). View file
 
.venv/Lib/site-packages/torch/nn/qat/__init__.py ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # flake8: noqa: F401
2
+ r"""QAT Dynamic Modules.
3
+
4
+ This package is in the process of being deprecated.
5
+ Please, use `torch.ao.nn.qat.dynamic` instead.
6
+ """
7
+ from torch.nn.qat import dynamic, modules # noqa: F403
8
+ from torch.nn.qat.modules import * # noqa: F403
9
+
10
+
11
+ __all__ = [
12
+ "Linear",
13
+ "Conv1d",
14
+ "Conv2d",
15
+ "Conv3d",
16
+ "Embedding",
17
+ "EmbeddingBag",
18
+ ]
.venv/Lib/site-packages/torch/nn/qat/dynamic/__init__.py ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ # flake8: noqa: F401
2
+ r"""QAT Dynamic Modules.
3
+
4
+ This package is in the process of being deprecated.
5
+ Please, use `torch.ao.nn.qat.dynamic` instead.
6
+ """
7
+ from torch.nn.qat.dynamic.modules import * # noqa: F403
.venv/Lib/site-packages/torch/nn/qat/dynamic/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (375 Bytes). View file
 
.venv/Lib/site-packages/torch/nn/qat/dynamic/modules/__init__.py ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ from torch.nn.qat.dynamic.modules.linear import Linear
2
+
3
+
4
+ __all__ = ["Linear"]
.venv/Lib/site-packages/torch/nn/qat/dynamic/modules/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (288 Bytes). View file
 
.venv/Lib/site-packages/torch/nn/qat/dynamic/modules/__pycache__/linear.cpython-39.pyc ADDED
Binary file (615 Bytes). View file
 
.venv/Lib/site-packages/torch/nn/qat/dynamic/modules/linear.py ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ # flake8: noqa: F401
2
+ r"""QAT Modules.
3
+
4
+ This file is in the process of migration to `torch/ao/nn/qat/dynamic`, and
5
+ is kept here for compatibility while the migration process is ongoing.
6
+ If you are adding a new entry/functionality, please, add it to the
7
+ appropriate file under the `torch/ao/nn/qat/dynamic/modules`,
8
+ while adding an import statement here.
9
+ """
10
+ from torch.ao.nn.qat.dynamic.modules.linear import Linear
.venv/Lib/site-packages/torch/nn/qat/modules/__init__.py ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # flake8: noqa: F401
2
+ r"""QAT Modules.
3
+
4
+ This package is in the process of being deprecated.
5
+ Please, use `torch.ao.nn.qat.modules` instead.
6
+ """
7
+ from torch.ao.nn.qat.modules.conv import Conv1d, Conv2d, Conv3d
8
+ from torch.ao.nn.qat.modules.embedding_ops import Embedding, EmbeddingBag
9
+ from torch.ao.nn.qat.modules.linear import Linear
10
+ from torch.nn.qat.modules import conv, embedding_ops, linear
11
+
12
+
13
+ __all__ = [
14
+ "Linear",
15
+ "Conv1d",
16
+ "Conv2d",
17
+ "Conv3d",
18
+ "Embedding",
19
+ "EmbeddingBag",
20
+ ]
.venv/Lib/site-packages/torch/nn/qat/modules/__pycache__/conv.cpython-39.pyc ADDED
Binary file (613 Bytes). View file
 
.venv/Lib/site-packages/torch/nn/qat/modules/conv.py ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # flake8: noqa: F401
2
+ r"""QAT Modules.
3
+
4
+ This file is in the process of migration to `torch/ao/nn/qat`, and
5
+ is kept here for compatibility while the migration process is ongoing.
6
+ If you are adding a new entry/functionality, please, add it to the
7
+ appropriate file under the `torch/ao/nn/qat/modules`,
8
+ while adding an import statement here.
9
+ """
10
+
11
+ from torch.ao.nn.qat.modules.conv import Conv1d, Conv2d, Conv3d
.venv/Lib/site-packages/torch/nn/qat/modules/embedding_ops.py ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # flake8: noqa: F401
2
+ r"""QAT Modules.
3
+
4
+ This file is in the process of migration to `torch/ao/nn/qat`, and
5
+ is kept here for compatibility while the migration process is ongoing.
6
+ If you are adding a new entry/functionality, please, add it to the
7
+ appropriate file under the `torch/ao/nn/qat/modules`,
8
+ while adding an import statement here.
9
+ """
10
+
11
+ from torch.ao.nn.qat.modules.embedding_ops import Embedding, EmbeddingBag
12
+
13
+
14
+ __all__ = ["Embedding", "EmbeddingBag"]
.venv/Lib/site-packages/torch/nn/qat/modules/linear.py ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ # flake8: noqa: F401
2
+ r"""QAT Modules.
3
+
4
+ This file is in the process of migration to `torch/ao/nn/qat`, and
5
+ is kept here for compatibility while the migration process is ongoing.
6
+ If you are adding a new entry/functionality, please, add it to the
7
+ appropriate file under the `torch/ao/nn/qat/modules`,
8
+ while adding an import statement here.
9
+ """
10
+ from torch.ao.nn.qat.modules.linear import Linear
.venv/Lib/site-packages/torch/nn/quantized/__init__.py ADDED
@@ -0,0 +1,39 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from torch.nn.quantized import dynamic, functional, modules # noqa: F403
2
+ from torch.nn.quantized.modules import * # noqa: F403
3
+ from torch.nn.quantized.modules import MaxPool2d
4
+
5
+
6
+ __all__ = [
7
+ "BatchNorm2d",
8
+ "BatchNorm3d",
9
+ "Conv1d",
10
+ "Conv2d",
11
+ "Conv3d",
12
+ "ConvTranspose1d",
13
+ "ConvTranspose2d",
14
+ "ConvTranspose3d",
15
+ "DeQuantize",
16
+ "Dropout",
17
+ "ELU",
18
+ "Embedding",
19
+ "EmbeddingBag",
20
+ "GroupNorm",
21
+ "Hardswish",
22
+ "InstanceNorm1d",
23
+ "InstanceNorm2d",
24
+ "InstanceNorm3d",
25
+ "LayerNorm",
26
+ "LeakyReLU",
27
+ "Linear",
28
+ "LSTM",
29
+ "MultiheadAttention",
30
+ "PReLU",
31
+ "Quantize",
32
+ "ReLU6",
33
+ "Sigmoid",
34
+ "Softmax",
35
+ # Wrapper modules
36
+ "FloatFunctional",
37
+ "FXFloatFunctional",
38
+ "QFunctional",
39
+ ]
.venv/Lib/site-packages/torch/nn/quantized/dynamic/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ from torch.ao.nn.quantized.dynamic import * # noqa: F403
.venv/Lib/site-packages/torch/nn/quantized/dynamic/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (244 Bytes). View file
 
.venv/Lib/site-packages/torch/nn/quantized/dynamic/modules/__init__.py ADDED
@@ -0,0 +1,43 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # flake8: noqa: F401
2
+ r"""Quantized Dynamic Modules.
3
+
4
+ This file is in the process of migration to `torch/ao/nn/quantized/dynamic`,
5
+ and is kept here for compatibility while the migration process is ongoing.
6
+ If you are adding a new entry/functionality, please, add it to the
7
+ appropriate file under the `torch/ao/nn/quantized/dynamic`,
8
+ while adding an import statement here.
9
+ """
10
+
11
+ from torch.ao.nn.quantized.dynamic.modules import conv, linear, rnn
12
+ from torch.ao.nn.quantized.dynamic.modules.conv import (
13
+ Conv1d,
14
+ Conv2d,
15
+ Conv3d,
16
+ ConvTranspose1d,
17
+ ConvTranspose2d,
18
+ ConvTranspose3d,
19
+ )
20
+ from torch.ao.nn.quantized.dynamic.modules.linear import Linear
21
+ from torch.ao.nn.quantized.dynamic.modules.rnn import (
22
+ GRU,
23
+ GRUCell,
24
+ LSTM,
25
+ LSTMCell,
26
+ RNNCell,
27
+ )
28
+
29
+
30
+ __all__ = [
31
+ "Linear",
32
+ "LSTM",
33
+ "GRU",
34
+ "LSTMCell",
35
+ "RNNCell",
36
+ "GRUCell",
37
+ "Conv1d",
38
+ "Conv2d",
39
+ "Conv3d",
40
+ "ConvTranspose1d",
41
+ "ConvTranspose2d",
42
+ "ConvTranspose3d",
43
+ ]
.venv/Lib/site-packages/torch/nn/quantized/dynamic/modules/conv.py ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # flake8: noqa: F401
2
+ r"""Quantized Dynamic Modules.
3
+
4
+ This file is in the process of migration to `torch/ao/nn/quantized/dynamic`,
5
+ and is kept here for compatibility while the migration process is ongoing.
6
+ If you are adding a new entry/functionality, please, add it to the
7
+ appropriate file under the `torch/ao/nn/quantized/dynamic/modules`,
8
+ while adding an import statement here.
9
+ """
10
+
11
+ from torch.ao.nn.quantized.dynamic.modules.conv import (
12
+ Conv1d,
13
+ Conv2d,
14
+ Conv3d,
15
+ ConvTranspose1d,
16
+ ConvTranspose2d,
17
+ ConvTranspose3d,
18
+ )
19
+
20
+
21
+ __all__ = [
22
+ "Conv1d",
23
+ "Conv2d",
24
+ "Conv3d",
25
+ "ConvTranspose1d",
26
+ "ConvTranspose2d",
27
+ "ConvTranspose3d",
28
+ ]
.venv/Lib/site-packages/torch/nn/quantized/dynamic/modules/linear.py ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ # flake8: noqa: F401
2
+ r"""Quantized Dynamic Modules.
3
+
4
+ This file is in the process of migration to `torch/ao/nn/quantized/dynamic`,
5
+ and is kept here for compatibility while the migration process is ongoing.
6
+ If you are adding a new entry/functionality, please, add it to the
7
+ appropriate file under the `torch/ao/nn/quantized/dynamic/modules`,
8
+ while adding an import statement here.
9
+ """
10
+ from torch.ao.nn.quantized.dynamic.modules.linear import Linear
.venv/Lib/site-packages/torch/nn/quantized/dynamic/modules/rnn.py ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # flake8: noqa: F401
2
+ r"""Quantized Dynamic Modules.
3
+
4
+ This file is in the process of migration to `torch/ao/nn/quantized/dynamic`,
5
+ and is kept here for compatibility while the migration process is ongoing.
6
+ If you are adding a new entry/functionality, please, add it to the
7
+ appropriate file under the `torch/ao/nn/quantized/dynamic/modules`,
8
+ while adding an import statement here.
9
+ """
10
+
11
+ from torch.ao.nn.quantized.dynamic.modules.rnn import (
12
+ GRU,
13
+ GRUCell,
14
+ LSTM,
15
+ LSTMCell,
16
+ pack_weight_bias,
17
+ PackedParameter,
18
+ RNNBase,
19
+ RNNCell,
20
+ RNNCellBase,
21
+ )
22
+
23
+
24
+ __all__ = [
25
+ "pack_weight_bias",
26
+ "PackedParameter",
27
+ "RNNBase",
28
+ "LSTM",
29
+ "GRU",
30
+ "RNNCellBase",
31
+ "RNNCell",
32
+ "LSTMCell",
33
+ "GRUCell",
34
+ ]
.venv/Lib/site-packages/torch/nn/quantized/functional.py ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ r"""nn.quantized.functional.
2
+
3
+ Quantized equivalents of the `nn.functional`.
4
+
5
+ Note::
6
+ This location is in the process of being deprecated.
7
+ Please, use the `torch.ao.nn.quantized.functional` instead.
8
+ """
9
+
10
+ from torch.ao.nn.quantized.functional import * # noqa: F401,F403
.venv/Lib/site-packages/torch/nn/quantized/modules/__init__.py ADDED
@@ -0,0 +1,97 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ r"""Quantized Modules.
2
+
3
+ Note::
4
+ The `torch.nn.quantized` namespace is in the process of being deprecated.
5
+ Please, use `torch.ao.nn.quantized` instead.
6
+ """
7
+
8
+ # The following imports are needed in case the user decides
9
+ # to import the files directly,
10
+ # s.a. `from torch.nn.quantized.modules.conv import ...`.
11
+ # No need to add them to the `__all__`.
12
+ from torch.ao.nn.quantized.modules import (
13
+ activation,
14
+ batchnorm,
15
+ conv,
16
+ DeQuantize,
17
+ dropout,
18
+ embedding_ops,
19
+ functional_modules,
20
+ linear,
21
+ MaxPool2d,
22
+ normalization,
23
+ Quantize,
24
+ rnn,
25
+ utils,
26
+ )
27
+ from torch.ao.nn.quantized.modules.activation import (
28
+ ELU,
29
+ Hardswish,
30
+ LeakyReLU,
31
+ MultiheadAttention,
32
+ PReLU,
33
+ ReLU6,
34
+ Sigmoid,
35
+ Softmax,
36
+ )
37
+ from torch.ao.nn.quantized.modules.batchnorm import BatchNorm2d, BatchNorm3d
38
+ from torch.ao.nn.quantized.modules.conv import (
39
+ Conv1d,
40
+ Conv2d,
41
+ Conv3d,
42
+ ConvTranspose1d,
43
+ ConvTranspose2d,
44
+ ConvTranspose3d,
45
+ )
46
+ from torch.ao.nn.quantized.modules.dropout import Dropout
47
+ from torch.ao.nn.quantized.modules.embedding_ops import Embedding, EmbeddingBag
48
+ from torch.ao.nn.quantized.modules.functional_modules import (
49
+ FloatFunctional,
50
+ FXFloatFunctional,
51
+ QFunctional,
52
+ )
53
+ from torch.ao.nn.quantized.modules.linear import Linear
54
+ from torch.ao.nn.quantized.modules.normalization import (
55
+ GroupNorm,
56
+ InstanceNorm1d,
57
+ InstanceNorm2d,
58
+ InstanceNorm3d,
59
+ LayerNorm,
60
+ )
61
+ from torch.ao.nn.quantized.modules.rnn import LSTM
62
+
63
+
64
+ __all__ = [
65
+ "BatchNorm2d",
66
+ "BatchNorm3d",
67
+ "Conv1d",
68
+ "Conv2d",
69
+ "Conv3d",
70
+ "ConvTranspose1d",
71
+ "ConvTranspose2d",
72
+ "ConvTranspose3d",
73
+ "DeQuantize",
74
+ "ELU",
75
+ "Embedding",
76
+ "EmbeddingBag",
77
+ "GroupNorm",
78
+ "Hardswish",
79
+ "InstanceNorm1d",
80
+ "InstanceNorm2d",
81
+ "InstanceNorm3d",
82
+ "LayerNorm",
83
+ "LeakyReLU",
84
+ "Linear",
85
+ "LSTM",
86
+ "MultiheadAttention",
87
+ "Quantize",
88
+ "ReLU6",
89
+ "Sigmoid",
90
+ "Softmax",
91
+ "Dropout",
92
+ "PReLU",
93
+ # Wrapper modules
94
+ "FloatFunctional",
95
+ "FXFloatFunctional",
96
+ "QFunctional",
97
+ ]
.venv/Lib/site-packages/torch/nn/quantized/modules/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (1.93 kB). View file
 
.venv/Lib/site-packages/torch/nn/quantized/modules/activation.py ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # flake8: noqa: F401
2
+ r"""Quantized Modules.
3
+
4
+ This file is in the process of migration to `torch/ao/nn/quantized`, and
5
+ is kept here for compatibility while the migration process is ongoing.
6
+ If you are adding a new entry/functionality, please, add it to the
7
+ appropriate file under the `torch/ao/nn/quantized/modules`,
8
+ while adding an import statement here.
9
+ """
10
+
11
+ from torch.ao.nn.quantized.modules.activation import (
12
+ ELU,
13
+ Hardswish,
14
+ LeakyReLU,
15
+ MultiheadAttention,
16
+ PReLU,
17
+ ReLU6,
18
+ Sigmoid,
19
+ Softmax,
20
+ )
.venv/Lib/site-packages/torch/nn/quantized/modules/batchnorm.py ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # flake8: noqa: F401
2
+ r"""Quantized Modules.
3
+
4
+ This file is in the process of migration to `torch/ao/nn/quantized`, and
5
+ is kept here for compatibility while the migration process is ongoing.
6
+ If you are adding a new entry/functionality, please, add it to the
7
+ appropriate file under the `torch/ao/nn/quantized/modules`,
8
+ while adding an import statement here.
9
+ """
10
+
11
+ from torch.ao.nn.quantized.modules.batchnorm import BatchNorm2d, BatchNorm3d
.venv/Lib/site-packages/torch/nn/quantized/modules/conv.py ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # flake8: noqa: F401
2
+ r"""Quantized Modules.
3
+
4
+ This file is in the process of migration to `torch/ao/nn/quantized`, and
5
+ is kept here for compatibility while the migration process is ongoing.
6
+ If you are adding a new entry/functionality, please, add it to the
7
+ appropriate file under the `torch/ao/nn/quantized/modules`,
8
+ while adding an import statement here.
9
+ """
10
+
11
+ from torch.ao.nn.quantized.modules.conv import (
12
+ _reverse_repeat_padding,
13
+ Conv1d,
14
+ Conv2d,
15
+ Conv3d,
16
+ ConvTranspose1d,
17
+ ConvTranspose2d,
18
+ ConvTranspose3d,
19
+ )
20
+
21
+
22
+ __all__ = [
23
+ "Conv1d",
24
+ "Conv2d",
25
+ "Conv3d",
26
+ "ConvTranspose1d",
27
+ "ConvTranspose2d",
28
+ "ConvTranspose3d",
29
+ ]
.venv/Lib/site-packages/torch/nn/quantized/modules/dropout.py ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # flake8: noqa: F401
2
+ r"""Quantized Modules.
3
+
4
+ This file is in the process of migration to `torch/ao/nn/quantized`, and
5
+ is kept here for compatibility while the migration process is ongoing.
6
+ If you are adding a new entry/functionality, please, add it to the
7
+ appropriate file under the `torch/ao/nn/quantized/modules`,
8
+ while adding an import statement here.
9
+ """
10
+
11
+ from torch.ao.nn.quantized.modules.dropout import Dropout
12
+
13
+
14
+ __all__ = ["Dropout"]
.venv/Lib/site-packages/torch/nn/quantized/modules/embedding_ops.py ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # flake8: noqa: F401
2
+ r"""Quantized Modules.
3
+
4
+ This file is in the process of migration to `torch/ao/nn/quantized`, and
5
+ is kept here for compatibility while the migration process is ongoing.
6
+ If you are adding a new entry/functionality, please, add it to the
7
+ appropriate file under the `torch/ao/nn/quantized/modules`,
8
+ while adding an import statement here.
9
+ """
10
+
11
+ from torch.ao.nn.quantized.modules.embedding_ops import (
12
+ Embedding,
13
+ EmbeddingBag,
14
+ EmbeddingPackedParams,
15
+ )
16
+
17
+
18
+ __all__ = ["EmbeddingPackedParams", "Embedding", "EmbeddingBag"]
.venv/Lib/site-packages/torch/nn/quantized/modules/functional_modules.py ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # flake8: noqa: F401
2
+ r"""Quantized Modules.
3
+
4
+ This file is in the process of migration to `torch/ao/nn/quantized`, and
5
+ is kept here for compatibility while the migration process is ongoing.
6
+ If you are adding a new entry/functionality, please, add it to the
7
+ appropriate file under the `torch/ao/nn/quantized/modules`,
8
+ while adding an import statement here.
9
+ """
10
+
11
+ from torch.ao.nn.quantized.modules.functional_modules import (
12
+ FloatFunctional,
13
+ FXFloatFunctional,
14
+ QFunctional,
15
+ )
16
+
17
+
18
+ __all__ = ["FloatFunctional", "FXFloatFunctional", "QFunctional"]
.venv/Lib/site-packages/torch/nn/quantized/modules/linear.py ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # flake8: noqa: F401
2
+ r"""Quantized Modules.
3
+
4
+ This file is in the process of migration to `torch/ao/nn/quantized`, and
5
+ is kept here for compatibility while the migration process is ongoing.
6
+ If you are adding a new entry/functionality, please, add it to the
7
+ appropriate file under the `torch/ao/nn/quantized/modules`,
8
+ while adding an import statement here.
9
+ """
10
+
11
+ from torch.ao.nn.quantized.modules.linear import Linear, LinearPackedParams
12
+
13
+
14
+ __all__ = ["LinearPackedParams", "Linear"]