Add files using upload-large-folder tool
- .gitattributes +1 -0
- .venv/Lib/site-packages/torch/lib/kineto.lib +3 -0
- .venv/Lib/site-packages/torch/mtia/__init__.py +332 -0
- .venv/Lib/site-packages/torch/mtia/__pycache__/__init__.cpython-39.pyc +0 -0
- .venv/Lib/site-packages/torch/mtia/__pycache__/_utils.cpython-39.pyc +0 -0
- .venv/Lib/site-packages/torch/multiprocessing/__init__.py +100 -0
- .venv/Lib/site-packages/torch/multiprocessing/__pycache__/reductions.cpython-39.pyc +0 -0
- .venv/Lib/site-packages/torch/multiprocessing/_atfork.py +35 -0
- .venv/Lib/site-packages/torch/multiprocessing/pool.py +52 -0
- .venv/Lib/site-packages/torch/multiprocessing/queue.py +43 -0
- .venv/Lib/site-packages/torch/multiprocessing/reductions.py +647 -0
- .venv/Lib/site-packages/torch/multiprocessing/spawn.py +328 -0
- .venv/Lib/site-packages/torch/nn/parallel/__init__.py +28 -0
- .venv/Lib/site-packages/torch/nn/parallel/__pycache__/__init__.cpython-39.pyc +0 -0
- .venv/Lib/site-packages/torch/nn/parallel/__pycache__/_functions.cpython-39.pyc +0 -0
- .venv/Lib/site-packages/torch/nn/parallel/__pycache__/comm.cpython-39.pyc +0 -0
- .venv/Lib/site-packages/torch/nn/parallel/__pycache__/data_parallel.cpython-39.pyc +0 -0
- .venv/Lib/site-packages/torch/nn/parallel/__pycache__/distributed.cpython-39.pyc +0 -0
- .venv/Lib/site-packages/torch/nn/parallel/__pycache__/parallel_apply.cpython-39.pyc +0 -0
- .venv/Lib/site-packages/torch/nn/parallel/__pycache__/replicate.cpython-39.pyc +0 -0
- .venv/Lib/site-packages/torch/nn/parallel/__pycache__/scatter_gather.cpython-39.pyc +0 -0
- .venv/Lib/site-packages/torch/nn/qat/__init__.py +18 -0
- .venv/Lib/site-packages/torch/nn/qat/dynamic/__init__.py +7 -0
- .venv/Lib/site-packages/torch/nn/qat/dynamic/__pycache__/__init__.cpython-39.pyc +0 -0
- .venv/Lib/site-packages/torch/nn/qat/dynamic/modules/__init__.py +4 -0
- .venv/Lib/site-packages/torch/nn/qat/dynamic/modules/__pycache__/__init__.cpython-39.pyc +0 -0
- .venv/Lib/site-packages/torch/nn/qat/dynamic/modules/__pycache__/linear.cpython-39.pyc +0 -0
- .venv/Lib/site-packages/torch/nn/qat/dynamic/modules/linear.py +10 -0
- .venv/Lib/site-packages/torch/nn/qat/modules/__init__.py +20 -0
- .venv/Lib/site-packages/torch/nn/qat/modules/__pycache__/conv.cpython-39.pyc +0 -0
- .venv/Lib/site-packages/torch/nn/qat/modules/conv.py +11 -0
- .venv/Lib/site-packages/torch/nn/qat/modules/embedding_ops.py +14 -0
- .venv/Lib/site-packages/torch/nn/qat/modules/linear.py +10 -0
- .venv/Lib/site-packages/torch/nn/quantized/__init__.py +39 -0
- .venv/Lib/site-packages/torch/nn/quantized/dynamic/__init__.py +1 -0
- .venv/Lib/site-packages/torch/nn/quantized/dynamic/__pycache__/__init__.cpython-39.pyc +0 -0
- .venv/Lib/site-packages/torch/nn/quantized/dynamic/modules/__init__.py +43 -0
- .venv/Lib/site-packages/torch/nn/quantized/dynamic/modules/conv.py +28 -0
- .venv/Lib/site-packages/torch/nn/quantized/dynamic/modules/linear.py +10 -0
- .venv/Lib/site-packages/torch/nn/quantized/dynamic/modules/rnn.py +34 -0
- .venv/Lib/site-packages/torch/nn/quantized/functional.py +10 -0
- .venv/Lib/site-packages/torch/nn/quantized/modules/__init__.py +97 -0
- .venv/Lib/site-packages/torch/nn/quantized/modules/__pycache__/__init__.cpython-39.pyc +0 -0
- .venv/Lib/site-packages/torch/nn/quantized/modules/activation.py +20 -0
- .venv/Lib/site-packages/torch/nn/quantized/modules/batchnorm.py +11 -0
- .venv/Lib/site-packages/torch/nn/quantized/modules/conv.py +29 -0
- .venv/Lib/site-packages/torch/nn/quantized/modules/dropout.py +14 -0
- .venv/Lib/site-packages/torch/nn/quantized/modules/embedding_ops.py +18 -0
- .venv/Lib/site-packages/torch/nn/quantized/modules/functional_modules.py +18 -0
- .venv/Lib/site-packages/torch/nn/quantized/modules/linear.py +14 -0
.gitattributes
CHANGED
@@ -122,3 +122,4 @@ reference_sample_wavs/syuukovoice_200918_3_01.wav filter=lfs diff=lfs merge=lfs
 .venv/Lib/site-packages/torch/lib/cudnn_engines_runtime_compiled64_9.dll filter=lfs diff=lfs merge=lfs -text
 .venv/Lib/site-packages/torch/lib/libiomp5md.dll filter=lfs diff=lfs merge=lfs -text
 .venv/Lib/site-packages/torch/lib/nvrtc-builtins64_121.dll filter=lfs diff=lfs merge=lfs -text
+.venv/Lib/site-packages/torch/lib/kineto.lib filter=lfs diff=lfs merge=lfs -text
.venv/Lib/site-packages/torch/lib/kineto.lib
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:e4b349dc42209360e73bcabaf3160289923aec193db7996966926407cb51fb76
size 21732956
.venv/Lib/site-packages/torch/mtia/__init__.py
ADDED
@@ -0,0 +1,332 @@
# mypy: allow-untyped-defs
r"""
This package enables an interface for accessing MTIA backend in python
"""

import threading
import warnings
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

import torch
from torch import device as _device, Tensor
from torch._utils import _dummy_type, _LazySeedTracker, classproperty
from torch.types import Device

from ._utils import _get_device_index


_device_t = Union[_device, str, int, None]

# torch.mtia.Event/Stream is alias of torch.Event/Stream
Event = torch.Event
Stream = torch.Stream

_initialized = False
_queued_calls: List[
    Tuple[Callable[[], None], List[str]]
] = []  # don't invoke these until initialization occurs
_tls = threading.local()
_initialization_lock = threading.Lock()
_lazy_seed_tracker = _LazySeedTracker()


def init():
    _lazy_init()


def is_initialized():
    r"""Return whether PyTorch's MTIA state has been initialized."""
    return _initialized and not _is_in_bad_fork()


def _is_in_bad_fork() -> bool:
    return torch._C._mtia_isInBadFork()


def _lazy_init() -> None:
    global _initialized, _queued_calls
    if is_initialized() or hasattr(_tls, "is_initializing"):
        return
    with _initialization_lock:
        # We be double-checking locking, boys! This is OK because
        # the above test was GIL protected anyway. The inner test
        # is for when a thread blocked on some other thread which was
        # doing the initialization; when they get the lock, they will
        # find there is nothing left to do.
        if is_initialized():
            return
        # It is important to prevent other threads from entering _lazy_init
        # immediately, while we are still guaranteed to have the GIL, because some
        # of the C calls we make below will release the GIL
        if _is_in_bad_fork():
            raise RuntimeError(
                "Cannot re-initialize MTIA in forked subprocess. To use MTIA with "
                "multiprocessing, you must use the 'spawn' start method"
            )
        if not _is_compiled():
            raise AssertionError(
                "Torch not compiled with MTIA enabled. "
                "Ensure you have `import mtia.host_runtime.torch_mtia` in your python "
                "src file and include `//mtia/host_runtime/torch_mtia:torch_mtia` as "
                "your target dependency!"
            )

        torch._C._mtia_init()
        # Some of the queued calls may reentrantly call _lazy_init();
        # we need to just return without initializing in that case.
        # However, we must not let any *other* threads in!
        _tls.is_initializing = True

        for calls in _lazy_seed_tracker.get_calls():
            if calls:
                _queued_calls.append(calls)

        try:
            for queued_call, orig_traceback in _queued_calls:
                try:
                    queued_call()
                except Exception as e:
                    msg = (
                        f"MTIA call failed lazily at initialization with error: {str(e)}\n\n"
                        f"MTIA call was originally invoked at:\n\n{''.join(orig_traceback)}"
                    )
                    raise DeferredMtiaCallError(msg) from e
        finally:
            delattr(_tls, "is_initializing")
        _initialized = True


class DeferredMtiaCallError(Exception):
    pass


def _is_compiled() -> bool:
    r"""Return true if compiled with MTIA support."""
    return torch._C._mtia_isBuilt()


def is_available() -> bool:
    r"""Return true if MTIA device is available"""
    if not _is_compiled():
        return False
    # MTIA has to init devices first to know if there is any devices available.
    return device_count() > 0


def synchronize(device: Optional[_device_t] = None) -> None:
    r"""Waits for all jobs in all streams on a MTIA device to complete."""
    with torch.mtia.device(device):
        return torch._C._mtia_deviceSynchronize()


def device_count() -> int:
    r"""Return the number of MTIA devices available."""
    return torch._C._accelerator_hooks_device_count()


def current_device() -> int:
    r"""Return the index of a currently selected device."""
    return torch._C._accelerator_hooks_get_current_device()


def current_stream(device: Optional[_device_t] = None) -> Stream:
    r"""Return the currently selected :class:`Stream` for a given device.

    Args:
        device (torch.device or int, optional): selected device. Returns
            the currently selected :class:`Stream` for the current device, given
            by :func:`~torch.mtia.current_device`, if :attr:`device` is ``None``
            (default).
    """
    return torch._C._mtia_getCurrentStream(_get_device_index(device, optional=True))


def default_stream(device: Optional[_device_t] = None) -> Stream:
    r"""Return the default :class:`Stream` for a given device.

    Args:
        device (torch.device or int, optional): selected device. Returns
            the default :class:`Stream` for the current device, given by
            :func:`~torch.mtia.current_device`, if :attr:`device` is ``None``
            (default).
    """
    return torch._C._mtia_getDefaultStream(_get_device_index(device, optional=True))


def memory_stats(device: Optional[_device_t] = None) -> Dict[str, Any]:
    r"""Return a dictionary of MTIA memory allocator statistics for a given device.

    Args:
        device (torch.device or int, optional) selected device. Returns
            statistics for the current device, given by current_device(),
            if device is None (default).
    """
    if not is_initialized():
        return {}
    return torch._C._mtia_memoryStats(_get_device_index(device, optional=True))


def set_stream(stream: Stream):
    r"""Set the current stream.This is a wrapper API to set the stream.
    Usage of this function is discouraged in favor of the ``stream``
    context manager.

    Args:
        stream (Stream): selected stream. This function is a no-op
            if this argument is ``None``.
    """
    if stream is None:
        return
    torch._C._mtia_setCurrentStream(stream)


def set_device(device: _device_t) -> None:
    r"""Set the current device.

    Args:
        device (torch.device or int): selected device. This function is a no-op
            if this argument is negative.
    """
    device = _get_device_index(device)
    if device >= 0:
        torch._C._accelerator_hooks_set_current_device(device)


class device:
    r"""Context-manager that changes the selected device.

    Args:
        device (torch.device or int): device index to select. It's a no-op if
            this argument is a negative integer or ``None``.
    """

    def __init__(self, device: Any):
        self.idx = _get_device_index(device, optional=True)
        self.prev_idx = -1

    def __enter__(self):
        self.prev_idx = torch._C._accelerator_hooks_maybe_exchange_device(self.idx)

    def __exit__(self, type: Any, value: Any, traceback: Any):
        self.idx = torch._C._accelerator_hooks_maybe_exchange_device(self.prev_idx)
        return False


class StreamContext:
    r"""Context-manager that selects a given stream.

    All MTIA kernels queued within its context will be enqueued on a selected
    stream.

    Args:
        Stream (Stream): selected stream. This manager is a no-op if it's
            ``None``.
    .. note:: Streams are per-device.
    """

    cur_stream: Optional["torch.mtia.Stream"]

    def __init__(self, stream: Optional["torch.mtia.Stream"]):
        self.cur_stream = None
        self.stream = stream
        self.idx = _get_device_index(None, True)
        if not torch.jit.is_scripting():
            if self.idx is None:
                self.idx = -1

        self.src_prev_stream = (
            None if not torch.jit.is_scripting() else torch.mtia.default_stream(None)
        )
        self.dst_prev_stream = (
            None if not torch.jit.is_scripting() else torch.mtia.default_stream(None)
        )

    def __enter__(self):
        # Local cur_stream variable for type refinement
        cur_stream = self.stream
        # Return if stream is None or MTIA device not available
        if cur_stream is None or self.idx == -1:
            return
        self.src_prev_stream = torch.mtia.current_stream(None)

        # If the stream is not on the current device, then
        # set the current stream on the device
        if self.src_prev_stream.device != cur_stream.device:
            with device(cur_stream.device):
                self.dst_prev_stream = torch.mtia.current_stream(cur_stream.device)
        torch.mtia.set_stream(cur_stream)

    def __exit__(self, type: Any, value: Any, traceback: Any):
        # Local cur_stream variable for type refinement
        cur_stream = self.stream
        # If stream is None or no MTIA device available, return
        if cur_stream is None or self.idx == -1:
            return

        # Reset the stream on the original device
        # and destination device
        if self.src_prev_stream.device != cur_stream.device:  # type: ignore[union-attr]
            torch.mtia.set_stream(self.dst_prev_stream)  # type: ignore[arg-type]
        torch.mtia.set_stream(self.src_prev_stream)  # type: ignore[arg-type]


def stream(stream: Optional["torch.mtia.Stream"]) -> StreamContext:
    r"""Wrap around the Context-manager StreamContext that selects a given stream.

    Arguments:
        stream (Stream): selected stream. This manager is a no-op if it's
            ``None``.
    ..Note:: In eager mode stream is of type Stream class while in JIT it doesn't support torch.mtia.stream
    """
    return StreamContext(stream)


def get_rng_state(device: Union[int, str, torch.device] = "mtia") -> Tensor:
    r"""Returns the random number generator state as a ByteTensor.

    Args:
        device (torch.device or int, optional): The device to return the RNG state of.
            Default: ``'mtia'`` (i.e., ``torch.device('mtia')``, the current mtia device).
    """
    warnings.warn(
        "get_rng_state is not implemented in torch.mtia",
        UserWarning,
        stacklevel=2,
    )
    return torch.zeros([1], dtype=torch.uint8, device=device)


def set_rng_state(
    new_state: Tensor, device: Union[int, str, torch.device] = "mtia"
) -> None:
    r"""Sets the random number generator state.

    Args:
        new_state (torch.ByteTensor): The desired state
        device (torch.device or int, optional): The device to set the RNG state.
            Default: ``'mtia'`` (i.e., ``torch.device('mtia')``, the current mtia device).
    """
    warnings.warn(
        "set_rng_state is not implemented in torch.mtia",
        UserWarning,
        stacklevel=2,
    )


__all__ = [
    "init",
    "is_available",
    "is_initialized",
    "synchronize",
    "device_count",
    "current_device",
    "current_stream",
    "default_stream",
    "memory_stats",
    "set_device",
    "set_stream",
    "stream",
    "device",
    "set_rng_state",
    "get_rng_state",
]
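The module above mirrors the torch.cuda-style device/stream surface. As a rough, hypothetical usage sketch (mine, not part of the diff; it only does real work on a build where MTIA is compiled in and a device is present), the context managers compose like this:

import torch

if torch.mtia.is_available():
    # Select device 0 and queue work on a side stream; both context managers
    # restore the previous device/stream on exit.
    s = torch.Stream(device="mtia:0")
    with torch.mtia.device(0), torch.mtia.stream(s):
        x = torch.ones(8, device="mtia")
        y = x * 2
    torch.mtia.synchronize()
else:
    print("MTIA backend not compiled/available; nothing to demonstrate.")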
.venv/Lib/site-packages/torch/mtia/__pycache__/__init__.cpython-39.pyc
ADDED
Binary file (10.2 kB)
|
|
.venv/Lib/site-packages/torch/mtia/__pycache__/_utils.cpython-39.pyc
ADDED
Binary file (1.53 kB)
|
|
.venv/Lib/site-packages/torch/multiprocessing/__init__.py
ADDED
@@ -0,0 +1,100 @@
# mypy: allow-untyped-defs
"""torch.multiprocessing is a wrapper around the native :mod:`multiprocessing` module.

It registers custom reducers, that use shared memory to provide shared
views on the same data in different processes. Once the tensor/storage is moved
to shared_memory (see :func:`~torch.Tensor.share_memory_`), it will be possible
to send it to other processes without making any copies.

The API is 100% compatible with the original module - it's enough to change
``import multiprocessing`` to ``import torch.multiprocessing`` to have all the
tensors sent through the queues or shared via other mechanisms, moved to shared
memory.

Because of the similarity of APIs we do not document most of this package
contents, and we recommend referring to very good docs of the original module.
"""
import multiprocessing
import sys

import torch

from .reductions import init_reductions


__all__ = ["set_sharing_strategy", "get_sharing_strategy", "get_all_sharing_strategies"]


from multiprocessing import *  # noqa: F403


__all__ += multiprocessing.__all__  # noqa: PLE0605 type: ignore[attr-defined]


# This call adds a Linux specific prctl(2) wrapper function to this module.
# See https://github.com/pytorch/pytorch/pull/14391 for more information.
torch._C._multiprocessing_init()


"""Add helper function to spawn N processes and wait for completion of any of
them. This depends `mp.get_context` which was added in Python 3.4."""
from .spawn import (
    ENV_VAR_PARALLEL_START,
    ProcessContext,
    ProcessExitedException,
    ProcessRaisedException,
    spawn,
    SpawnContext,
    start_processes,
)


if sys.platform == "darwin" or sys.platform == "win32":
    _sharing_strategy = "file_system"
    _all_sharing_strategies = {"file_system"}
else:
    _sharing_strategy = "file_descriptor"
    _all_sharing_strategies = {"file_descriptor", "file_system"}


def set_sharing_strategy(new_strategy):
    """Set the strategy for sharing CPU tensors.

    Args:
        new_strategy (str): Name of the selected strategy. Should be one of
            the values returned by :func:`get_all_sharing_strategies()`.
    """
    global _sharing_strategy
    assert new_strategy in _all_sharing_strategies
    _sharing_strategy = new_strategy


def get_sharing_strategy():
    """Return the current strategy for sharing CPU tensors."""
    return _sharing_strategy


def get_all_sharing_strategies():
    """Return a set of sharing strategies supported on a current system."""
    return _all_sharing_strategies


def _set_thread_name(name: str) -> None:
    """Set the name of the current thread.

    Args:
        name (str): Name of the current thread.
    """
    torch._C._set_thread_name(name)


def _get_thread_name() -> str:
    """Get the name of the current thread.

    Returns:
        str: Name of the current thread.
    """
    return torch._C._get_thread_name()


init_reductions()
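As a minimal illustration of the drop-in behaviour the docstring above describes (my sketch, not part of the diff): a tensor moved to shared memory with share_memory_() is visible to workers started with spawn(), so they can mutate it in place without any copies.

import torch
import torch.multiprocessing as mp

def worker(rank, shared):
    # every spawned process writes into the same shared-memory storage
    shared[rank] = float(rank)

if __name__ == "__main__":
    t = torch.zeros(4).share_memory_()
    mp.spawn(worker, args=(t,), nprocs=4, join=True)
    print(t)  # expected: tensor([0., 1., 2., 3.])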
.venv/Lib/site-packages/torch/multiprocessing/__pycache__/reductions.cpython-39.pyc
ADDED
Binary file (11.5 kB)
|
|
.venv/Lib/site-packages/torch/multiprocessing/_atfork.py
ADDED
@@ -0,0 +1,35 @@
# mypy: allow-untyped-defs
import sys


__all__ = ["register_after_fork"]

if sys.platform == "win32":
    import multiprocessing.util as _util

    def _register(func):
        def wrapper(arg):
            func()

        _util.register_after_fork(_register, wrapper)

else:
    import os

    def _register(func):
        os.register_at_fork(after_in_child=func)


def register_after_fork(func):
    """Register a callable to be executed in the child process after a fork.

    Note:
        In python < 3.7 this will only work with processes created using the
        ``multiprocessing`` module. In python >= 3.7 it also works with
        ``os.fork()``.

    Args:
        func (function): Function taking no arguments to be called in the child after fork

    """
    _register(func)
.venv/Lib/site-packages/torch/multiprocessing/pool.py
ADDED
@@ -0,0 +1,52 @@
import multiprocessing.pool
import multiprocessing.util as util

from .queue import SimpleQueue


def clean_worker(*args, **kwargs):
    import gc

    multiprocessing.pool.worker(*args, **kwargs)
    # Regular multiprocessing workers don't fully clean up after themselves,
    # so we have to explicitly trigger garbage collection to make sure that all
    # destructors are called...
    gc.collect()


class Pool(multiprocessing.pool.Pool):
    """Pool implementation which uses our version of SimpleQueue.

    This lets us pass tensors in shared memory across processes instead of
    serializing the underlying data.
    """

    def _setup_queues(self):
        self._inqueue = SimpleQueue()
        self._outqueue = SimpleQueue()
        self._quick_put = self._inqueue._writer.send
        self._quick_get = self._outqueue._reader.recv

    def _repopulate_pool(self):
        """Increase the number of pool processes to the specified number.

        Bring the number of pool processes up to the specified number, for use after
        reaping workers which have exited.
        """
        for i in range(self._processes - len(self._pool)):
            # changed worker -> clean_worker
            args = (
                self._inqueue,
                self._outqueue,
                self._initializer,
                self._initargs,
                self._maxtasksperchild,
            )
            if hasattr(self, "_wrap_exception"):
                args += (self._wrap_exception,)
            w = self.Process(target=clean_worker, args=args)
            self._pool.append(w)
            w.name = w.name.replace("Process", "PoolWorker")
            w.daemon = True
            w.start()
            util.debug("added worker")
.venv/Lib/site-packages/torch/multiprocessing/queue.py
ADDED
@@ -0,0 +1,43 @@
# mypy: allow-untyped-defs
import io
import multiprocessing.queues
import pickle
from multiprocessing.reduction import ForkingPickler


class ConnectionWrapper:
    """Proxy class for _multiprocessing.Connection which uses ForkingPickler for object serialization."""

    def __init__(self, conn):
        self.conn = conn

    def send(self, obj):
        buf = io.BytesIO()
        ForkingPickler(buf, pickle.HIGHEST_PROTOCOL).dump(obj)
        self.send_bytes(buf.getvalue())

    def recv(self):
        buf = self.recv_bytes()
        return pickle.loads(buf)

    def __getattr__(self, name):
        if "conn" in self.__dict__:
            return getattr(self.conn, name)
        raise AttributeError(f"'{type(self).__name__}' object has no attribute 'conn'")


class Queue(multiprocessing.queues.Queue):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self._reader: ConnectionWrapper = ConnectionWrapper(self._reader)
        self._writer: ConnectionWrapper = ConnectionWrapper(self._writer)
        self._send = self._writer.send
        self._recv = self._reader.recv


class SimpleQueue(multiprocessing.queues.SimpleQueue):
    def _make_methods(self):
        if not isinstance(self._reader, ConnectionWrapper):
            self._reader: ConnectionWrapper = ConnectionWrapper(self._reader)
            self._writer: ConnectionWrapper = ConnectionWrapper(self._writer)
        super()._make_methods()  # type: ignore[misc]
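For context, a hedged sketch (mine, not part of the diff) of what these ForkingPickler-backed wrappers enable: once init_reductions() has registered the tensor reducers, a CPU tensor put on a multiprocessing queue is rebuilt from shared memory on the receiving side rather than copied byte-by-byte.

import torch
import torch.multiprocessing as mp

def producer(q):
    q.put(torch.arange(3))  # serialized via ForkingPickler -> shared memory

if __name__ == "__main__":
    ctx = mp.get_context("spawn")
    q = ctx.Queue()
    p = ctx.Process(target=producer, args=(q,))
    p.start()
    print(q.get())  # tensor([0, 1, 2])
    p.join()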
.venv/Lib/site-packages/torch/multiprocessing/reductions.py
ADDED
@@ -0,0 +1,647 @@
# mypy: allow-untyped-defs
import multiprocessing
import os
import threading
from multiprocessing.reduction import ForkingPickler
from multiprocessing.util import register_after_fork
from typing import Union

import torch
from torch._namedtensor_internals import check_serializing_named_tensor


try:
    # Early load resource_sharer to prevent a partially initialized instance
    # from being inherited in a forked child process. The reduce_storage method
    # requires this module indirectly through DupFd(). The built-in mp.Queue
    # class pickles arguments in a background thread which may overlap with the
    # fork.
    import multiprocessing.resource_sharer
except ImportError:
    pass


class StorageWeakRef:
    r"""A weak reference to a Storage.

    The cdata member is a Python number containing the integer representation of
    the Storage pointer.
    """

    __slots__ = ["cdata", "_free_weak_ref"]

    def __init__(self, storage):
        self.cdata = storage._weak_ref()
        # Save a direct reference to _free_weak_ref because the `torch` module
        # might be cleared during Python shutdown before this module is cleared.
        self._free_weak_ref = torch.Storage._free_weak_ref  # type: ignore[attr-defined]

    @classmethod
    def from_weakref(cls, cdata):
        instance = cls.__new__(cls)
        instance.cdata = cdata
        instance._free_weak_ref = torch.Storage._free_weak_ref  # type: ignore[attr-defined]
        return instance

    def expired(self):
        return torch.Storage._expired(self.cdata)  # type: ignore[attr-defined]

    def __del__(self):
        self._free_weak_ref(self.cdata)

    def __hash__(self):
        return self.cdata

    def __eq__(self, other):
        if id(self) == id(other):
            return True
        return self.cdata == other.cdata


class SharedCache(dict):
    """Dictionary from multiprocessing handles to StorageWeakRef."""

    def __init__(self) -> None:
        # free_dead_references() is called if the len exceeds the current
        # limit. The limit scales with the number of remaining live objects.
        self.limit = 128
        # `fork` inherits lock state, so in case we fork when the lock is held,
        # we register a function to reset the lock to a new object to avoid
        # possible deadlocks, following python multiprocessing library design.
        self._after_fork()
        register_after_fork(self, SharedCache._after_fork)

    def _after_fork(self):
        self.lock = threading.Lock()

    def get(self, key):
        with self.lock:
            return dict.get(self, key)

    def __setitem__(self, key, storage_ref):
        with self.lock:
            dict.__setitem__(self, key, storage_ref)
            if len(self) > self.limit:
                self.free_dead_references()

    def free_dead_references(self):
        live = 0
        for key, storage_ref in list(self.items()):
            if storage_ref.expired():
                del self[key]
            else:
                live += 1
        self.limit = max(128, live * 2)


# mapping from handles to StorageWeakRef objects
shared_cache = SharedCache()


def rebuild_event(device, handle):
    return torch.cuda.Event.from_ipc_handle(device, handle)


def reduce_event(event):
    handle = event.ipc_handle()
    return (rebuild_event, (event.device, handle))


def rebuild_tensor(cls, storage, metadata):
    storage_offset, size, stride, requires_grad = metadata
    t = torch._utils._rebuild_tensor(storage, storage_offset, size, stride)
    if cls == torch.nn.parameter.Parameter:
        # we have to pass requires_grad into constructor, rather than set it as an
        # attribute later, because it's an important check for Integer Tensors to
        # have requires_grad=False (or else they raise an error)
        t = torch.nn.parameter.Parameter(t, requires_grad=requires_grad)
    else:
        t.requires_grad = requires_grad
    return t


def rebuild_meta_tensor(
    tensor_cls,
    tensor_size,
    tensor_stride,
    tensor_offset,
    dtype,
    storage_size_bytes,
    requires_grad,
):
    untyped_storage = torch.UntypedStorage(storage_size_bytes, device="meta")

    typed_storage = torch.TypedStorage(
        wrap_storage=untyped_storage, dtype=dtype, _internal=True
    )

    t = torch._utils._rebuild_tensor(
        typed_storage,
        tensor_offset,
        tensor_size,
        tensor_stride,
    )

    if tensor_cls == torch.nn.parameter.Parameter:
        # It is crucial for integer tensors to receive
        # the requires_grad=False as an argument in the constructor
        t = torch.nn.parameter.Parameter(t, requires_grad=requires_grad)
    else:
        t.requires_grad = requires_grad

    return t


def rebuild_cuda_tensor(
    tensor_cls,
    tensor_size,
    tensor_stride,
    tensor_offset,
    storage_cls,
    dtype,
    storage_device,
    storage_handle,
    storage_size_bytes,
    storage_offset_bytes,
    requires_grad,
    ref_counter_handle,
    ref_counter_offset,
    event_handle,
    event_sync_required,
):
    # If storage_handle is None, storage points to nullptr.
    if storage_handle is None or storage_size_bytes == 0:
        storage = storage_cls(0, dtype=dtype, device=storage_device, _internal=True)
    else:
        storage = storage_from_cache(
            storage_cls, (storage_handle, storage_offset_bytes)
        )
        if storage is None:
            torch.cuda._lazy_init()
            storage = storage_cls._new_shared_cuda(
                storage_device,
                storage_handle,
                storage_size_bytes,
                storage_offset_bytes,
                ref_counter_handle,
                ref_counter_offset,
                event_handle,
                event_sync_required,
            )
            shared_cache[(storage_handle, storage_offset_bytes)] = StorageWeakRef(
                storage
            )
        else:
            # We already ref counting this Storage, but producer needs new ref-counters to be released.
            storage_cls._release_ipc_counter(
                ref_counter_handle, ref_counter_offset, device=storage_device
            )

    _storage = (
        storage
        if isinstance(storage, torch.UntypedStorage)
        else storage._untyped_storage
    )

    t = torch._utils._rebuild_tensor(
        torch.storage.TypedStorage(wrap_storage=_storage, dtype=dtype, _internal=True),
        tensor_offset,
        tensor_size,
        tensor_stride,
    )

    if tensor_cls == torch.nn.parameter.Parameter:
        # It is crucial for integer tensors to receive
        # the requires_grad=False as an argument in the constructor
        t = torch.nn.parameter.Parameter(t, requires_grad=requires_grad)
    else:
        t.requires_grad = requires_grad

    return t


def reduce_tensor(tensor):
    if tensor.requires_grad and not tensor.is_leaf:
        raise RuntimeError(
            "Cowardly refusing to serialize non-leaf tensor which requires_grad, "
            "since autograd does not support crossing process boundaries. "
            "If you just want to transfer the data, call detach() on the tensor "
            "before serializing (e.g., putting it on the queue)."
        )

    check_serializing_named_tensor(tensor)
    torch.utils.hooks.warn_if_has_hooks(tensor)

    # Note [CUDA IPC and the caching allocator]
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    # When you send a CUDA tensor over IPC, you might expect that you will
    # get out the same storage from the other end. However, the CUDA caching
    # allocator makes it difficult to preserve this invariant. Consider
    # the following situation: a tensor of size 0x100 points to offset 0x20 of
    # a storage at 0xA100 of size 0x100. (For simplicity, all of these
    # sizes are given in bytes). HOWEVER, with the caching allocator, this storage
    # might be part of a larger cudaMalloc allocation 0xA000 of size 0x4000.
    #
    # When we want to send this CUDA tensor over IPC, we must send the
    # *entire* cudaMalloc allocation, i.e., the 0xA000 region, not just
    # the storage 0xA100 (because that is what CUDA supports). So, on the
    # other end, there simply isn't any way to say, "Wait, you gave me
    # a bigger region (0xA000) than the one I wanted (0xA100)".
    #
    # OK, so if you sent the cudaMalloc allocation, can you just wrap that up as
    # one storage itself? No, because this cudaMalloc allocation might contain
    # storages of mixed types: float, bytes, double... If you make the entire
    # allocation a single storage of a type A, we'll hit an error when constructing
    # a tensor of type B on the storage.
    #
    # cudaIpcMemHandle is an identifier to access the sender cudaMalloc allocation on the
    # receiver side. However, cudaIpcMemHandles from each device in a given process may
    # only be opened by one context per device per other process.
    # If we open and close a memory handle multiples times in a process, CUDA is allowed
    # to give it a different address; similarly, once we close the memory, we're not
    # allowed to access it(and the storage/tensor built on top of it), even if it is
    # still live in the original process. As we cannot make a cudaMalloc allocation
    # to a single storage in one go, this requires us to cache the device pointer for
    # each cudaIpcMemHandle on C++ side to reconstruct types of storages, while keep
    # the old ones alives.
    # See [https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__DEVICE.html]
    #
    # This is fine, because all we need to do is to save our position in the allocation,
    # and reconstruct storage and tensor from it.
    # 0xA000 ->  -------CUDA Allocation------
    #           |                            |
    #           |                            |
    #           |                            |
    #           |                            |
    # 0xA100 ->  --------storage1 begin------
    #           |                            |
    # 0xA120 ->  --------tensor1 begin ------
    #           |                            |
    #           |                            |
    #           |                            |
    #           |                            |
    #           |                            |
    # 0xA160 ->  --------tensor1 end---------
    #           |                            |
    #           |                            |
    #           |                            |
    # 0xA200 ->  --------storage1 end--------
    #           |                            |
    # 0xE000 ->  --------CUDA allocation-----
    #
    # To send tensor1, the following info are required from sender to receiver for
    # storage recontruction.
    #   1. cudaIpcMemHandle of 0xA000(which can be mapped to a basePtr in receiver process).
    #      basePtr may not be exactly 0xA000 since it's a different process.
    #   2. offset(0xA100) of storage1 in the CUDA allocation.
    #   3. size of storage1(0x100).
    #
    # On receiver side:
    #   1. Get the devPtr of the MemHandle to access the memory, reconstruct a storage
    #      of the same type using (basePtr, offset, size).
    #   2. we can reconstruct the tensor on top of the reconstructed storage
    #   Tensor(size=0x040, offset=0x020, storage=Storage(data=basePtr+0xA100, size=0x0100))
    #
    # This strategy has a few implications:
    #
    # 1. When we serialize a CUDA tensor for IPC, we cannot do it all in one
    #    go (non-compositionally), and this requires to have a global map
    #    memHandle -> devPtr for each process.
    #
    # 2. We MUST NOT let the new IPC tensor be resizable. Originally, a resize
    #    of the storage beyond 0x100 would merely have caused us to do a
    #    reallocation. You don't really want to do this, but if you did,
    #    all that would happen is that you would lose IPC sharing. But if
    #    you do this in the new world, we will happily let you write out of
    #    bounds of your "allocation", clobbering unrelated data in the cached
    #    allocator block. BAD!
    #
    # By the way, in old versions of PyTorch, we supported this situation
    #    natively using a "storage view", which permitted multiple storages to be
    #    views on each other. But this was the *only* use of storage views, so we
    #    eliminated it so that we could just use tensor views to implement the same
    #    thing.
    #

    # TODO: Handle distinguishing between subclass and non-subclass versions of NT better
    #       https://github.com/pytorch/pytorch/issues/110543
    from torch.nested._internal.nested_tensor import NestedTensor

    if tensor.is_nested and not isinstance(tensor, NestedTensor):
        return reduce_nested_tensor(tensor)

    if tensor.layout in {
        torch.sparse_coo,
        torch.sparse_csr,
        torch.sparse_bsr,
        torch.sparse_csc,
        torch.sparse_bsc,
    }:
        return reduce_sparse_tensor(tensor)

    storage = tensor._typed_storage()

    if storage._untyped_storage.device.type == "cuda":
        (
            device,
            handle,
            storage_size_bytes,
            storage_offset_bytes,
            ref_counter_handle,
            ref_counter_offset,
            event_handle,
            event_sync_required,
        ) = storage._share_cuda_()
        tensor_offset = tensor.storage_offset()
        shared_cache[handle] = StorageWeakRef(storage)
        # _backward_hooks purposely omitted here, see
        # Note [Don't serialize hooks]
        return (
            rebuild_cuda_tensor,
            (
                type(tensor),
                tensor.size(),
                tensor.stride(),
                tensor_offset,  # tensor offset in its storage
                type(storage),
                tensor.dtype,
                device,
                handle,  # identifier which CUDA allocation is the storage in.
                storage_size_bytes,  # size(in bytes) of the storage
                storage_offset_bytes,  # offset(in bytes) of the storage in the CUDA allocation
                tensor.requires_grad,
                ref_counter_handle,
                ref_counter_offset,
                event_handle,
                event_sync_required,
            ),
        )
    elif storage._untyped_storage.device.type == "meta":
        return (
            rebuild_meta_tensor,
            (
                type(tensor),
                tensor.size(),
                tensor.stride(),
                tensor.storage_offset(),
                tensor.dtype,
                tensor.untyped_storage().size(),
                tensor.requires_grad,
            ),
        )

    # _backward_hooks purposely omitted here, see Note [Don't serialize hooks]
    metadata = (
        tensor.storage_offset(),
        tensor.size(),
        tensor.stride(),
        tensor.requires_grad,
    )
    return (rebuild_tensor, (type(tensor), storage, metadata))


def rebuild_nested_tensor(
    rebuild_buffer_func,
    rebuild_buffer_args,
    rebuild_sizes_func,
    rebuild_sizes_args,
    rebuild_strides_func,
    rebuild_strides_args,
    rebuild_offsets_func,
    rebuild_offsets_args,
):
    buffer = rebuild_buffer_func(*rebuild_buffer_args)
    sizes = rebuild_sizes_func(*rebuild_sizes_args)
    strides = rebuild_strides_func(*rebuild_strides_args)
    offsets = rebuild_offsets_func(*rebuild_offsets_args)
    return torch._nested_view_from_buffer_copy(buffer, sizes, strides, offsets)


def reduce_nested_tensor(nt):
    rebuild_buffer_func, rebuild_buffer_args = reduce_tensor(nt.values())
    rebuild_sizes_func, rebuild_sizes_args = reduce_tensor(nt._nested_tensor_size())
    rebuild_strides_func, rebuild_strides_args = reduce_tensor(
        nt._nested_tensor_strides()
    )
    rebuild_offsets_func, rebuild_offsets_args = reduce_tensor(
        nt._nested_tensor_storage_offsets()
    )

    return (
        rebuild_nested_tensor,
        (
            rebuild_buffer_func,
            rebuild_buffer_args,
            rebuild_sizes_func,
            rebuild_sizes_args,
            rebuild_strides_func,
            rebuild_strides_args,
            rebuild_offsets_func,
            rebuild_offsets_args,
        ),
    )


def rebuild_sparse_coo_tensor(
    rebuild_indices_func,
    rebuild_indices_args,
    rebuild_values_func,
    rebuild_values_args,
    shape,
    is_coalesced,
):
    indices = rebuild_indices_func(*rebuild_indices_args)
    values = rebuild_values_func(*rebuild_values_args)
    return torch.sparse_coo_tensor(indices, values, shape, is_coalesced=is_coalesced)


def rebuild_sparse_compressed_tensor(
    rebuild_compressed_indices_func,
    rebuild_compressed_indices_args,
    rebuild_plain_indices_func,
    rebuild_plain_indices_args,
    rebuild_values_func,
    rebuild_values_args,
    shape,
    layout,
):
    compressed_indices = rebuild_compressed_indices_func(
        *rebuild_compressed_indices_args
    )
    plain_indices = rebuild_plain_indices_func(*rebuild_plain_indices_args)
    values = rebuild_values_func(*rebuild_values_args)
    return torch.sparse_compressed_tensor(
        compressed_indices, plain_indices, values, shape, layout=layout
    )


def reduce_sparse_tensor(sparse):
    if sparse.layout is torch.sparse_coo:
        rebuild_indices_func, rebuild_indices_args = reduce_tensor(sparse._indices())
        rebuild_values_func, rebuild_values_args = reduce_tensor(sparse._values())
        return (
            rebuild_sparse_coo_tensor,
            (
                rebuild_indices_func,
                rebuild_indices_args,
                rebuild_values_func,
                rebuild_values_args,
                sparse.shape,
                sparse.is_coalesced(),
            ),
        )
    else:
        if sparse.layout in {torch.sparse_csr, torch.sparse_bsr}:
            compressed_indices = sparse.crow_indices()
            plain_indices = sparse.col_indices()
        elif sparse.layout in {torch.sparse_csc, torch.sparse_bsc}:
            compressed_indices = sparse.ccol_indices()
            plain_indices = sparse.row_indices()
        else:
            raise NotImplementedError(sparse.layout)
        (
            rebuild_compressed_indices_func,
            rebuild_compressed_indices_args,
        ) = reduce_tensor(compressed_indices)
        rebuild_plain_indices_func, rebuild_plain_indices_args = reduce_tensor(
            plain_indices
        )
        rebuild_values_func, rebuild_values_args = reduce_tensor(sparse.values())
        return (
            rebuild_sparse_compressed_tensor,
            (
                rebuild_compressed_indices_func,
                rebuild_compressed_indices_args,
                rebuild_plain_indices_func,
                rebuild_plain_indices_args,
                rebuild_values_func,
                rebuild_values_args,
                sparse.shape,
                sparse.layout,
            ),
        )


def fd_id(fd):
    # Returns a tuple which uniquely identifies a file descriptor. In Mac OS,
    # this doesn't work with shared memory handles, which is why we don't
    # support the "file_descriptor" sharing method on that platform.
    stat = os.fstat(fd)
    return (stat.st_ino, stat.st_dev)


def storage_from_cache(cls, key):
    storage_ref = shared_cache.get(key)
    if storage_ref is None:
        return None
    return torch.UntypedStorage._new_with_weak_ptr(storage_ref.cdata)


def rebuild_storage_fd(cls, df, size):
    fd = df.detach()
    try:
        storage = storage_from_cache(cls, fd_id(fd))
        if storage is not None:
            return storage
        storage = cls._new_shared_fd_cpu(fd, size)
        shared_cache[fd_id(fd)] = StorageWeakRef(storage)
        return storage
    finally:
        os.close(fd)


def rebuild_storage_filename(cls, manager, handle, size, dtype=None):
    storage: Union[torch.TypedStorage, torch.UntypedStorage] = storage_from_cache(
        cls, handle
    )
    if storage is not None:
        return storage._shared_decref()
    if dtype is None:
        storage = torch.UntypedStorage._new_shared_filename_cpu(manager, handle, size)
    else:
        byte_size = size * torch._utils._element_size(dtype)
        untyped_storage: torch.UntypedStorage = (
            torch.UntypedStorage._new_shared_filename_cpu(manager, handle, byte_size)
        )
        storage = torch.TypedStorage(
            wrap_storage=untyped_storage, dtype=dtype, _internal=True
        )
    shared_cache[handle] = StorageWeakRef(storage)
    return storage._shared_decref()


def rebuild_storage_empty(cls):
    return cls()


def rebuild_typed_storage(storage, dtype):
    return torch.storage.TypedStorage(wrap_storage=storage, dtype=dtype, _internal=True)


# Use for torch.storage.TypedStorage
def reduce_typed_storage(storage):
    return (rebuild_typed_storage, (storage._untyped_storage, storage.dtype))


def rebuild_typed_storage_child(storage, storage_type):
    return storage_type(wrap_storage=storage, _internal=True)


# Use for child classes of torch.storage.TypedStorage, like torch.FloatStorage
def reduce_typed_storage_child(storage):
    return (rebuild_typed_storage_child, (storage._untyped_storage, type(storage)))


def reduce_storage(storage):
    from . import get_sharing_strategy

    if storage.is_cuda:
        raise RuntimeError(
            "Cannot pickle CUDA storage; try pickling a CUDA tensor instead"
        )
    elif storage.device.type == "meta":
        raise RuntimeError(
            "Cannot pickle meta storage; try pickling a meta tensor instead"
        )
    elif get_sharing_strategy() == "file_system":
        metadata = storage._share_filename_cpu_()
        cache_key = metadata[1]
        rebuild = rebuild_storage_filename
        if isinstance(storage, torch.TypedStorage):
            metadata += (storage.dtype,)
        storage._shared_incref()
    elif storage.size() == 0:
        # This is special cased because Empty tensors
        # (with size 0) cannot be mmapped.
        return (rebuild_storage_empty, (type(storage),))
    else:
        fd, size = storage._share_fd_cpu_()
        df = multiprocessing.reduction.DupFd(fd)
        cache_key = fd_id(fd)
        metadata = (df, size)
        rebuild = rebuild_storage_fd  # type: ignore[assignment]

    shared_cache[cache_key] = StorageWeakRef(storage)
    return (rebuild, (type(storage),) + metadata)
|
| 626 |
+
|
| 627 |
+
|
| 628 |
+
def init_reductions():
|
| 629 |
+
ForkingPickler.register(torch.cuda.Event, reduce_event)
|
| 630 |
+
|
| 631 |
+
for t in torch._storage_classes:
|
| 632 |
+
if t.__name__ == "UntypedStorage":
|
| 633 |
+
ForkingPickler.register(t, reduce_storage)
|
| 634 |
+
else:
|
| 635 |
+
ForkingPickler.register(t, reduce_typed_storage_child)
|
| 636 |
+
|
| 637 |
+
ForkingPickler.register(torch.storage.TypedStorage, reduce_typed_storage)
|
| 638 |
+
|
| 639 |
+
for t in torch._tensor_classes:
|
| 640 |
+
ForkingPickler.register(t, reduce_tensor)
|
| 641 |
+
|
| 642 |
+
# TODO: Maybe this should be in tensor_classes? :)
|
| 643 |
+
ForkingPickler.register(torch.Tensor, reduce_tensor)
|
| 644 |
+
|
| 645 |
+
from torch.nn.parameter import Parameter
|
| 646 |
+
|
| 647 |
+
ForkingPickler.register(Parameter, reduce_tensor)
|
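The reducers registered above are what let tensors and storages cross process boundaries with their memory shared rather than copied. A minimal sketch of how this is typically exercised; the worker function and variable names are illustrative, not part of the diff above:

```python
import torch
import torch.multiprocessing as mp  # importing this installs the reducers above


def worker(q):
    # The tensor received here aliases the parent's shared-memory storage,
    # so in-place writes are visible back in the parent process.
    t = q.get()
    t.add_(1)


if __name__ == "__main__":
    q = mp.Queue()
    t = torch.zeros(4)
    t.share_memory_()  # move the storage into shared memory up front
    p = mp.Process(target=worker, args=(q,))
    p.start()
    q.put(t)           # pickled through the registered ForkingPickler reducers
    p.join()
    print(t)           # expected: tensor([1., 1., 1., 1.])
```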
.venv/Lib/site-packages/torch/multiprocessing/spawn.py
ADDED
@@ -0,0 +1,328 @@
# mypy: allow-untyped-defs
import logging
import multiprocessing
import multiprocessing.connection
import os
import pickle
import signal
import sys
import tempfile
import time
import warnings
from concurrent.futures import as_completed, ThreadPoolExecutor
from typing import Optional

from . import _prctl_pr_set_pdeathsig  # type: ignore[attr-defined]


ENV_VAR_PARALLEL_START = "TORCH_MP_PARALLEL_START"

log = logging.getLogger(__name__)

__all__ = [
    "ProcessContext",
    "ProcessException",
    "ProcessExitedException",
    "ProcessRaisedException",
    "spawn",
    "SpawnContext",
    "start_processes",
]


class ProcessException(Exception):
    __slots__ = ["error_index", "error_pid"]

    def __init__(self, msg: str, error_index: int, pid: int):
        super().__init__(msg)
        self.msg = msg
        self.error_index = error_index
        self.pid = pid

    def __reduce__(self):
        return type(self), (self.msg, self.error_index, self.pid)


class ProcessRaisedException(ProcessException):
    """Exception raised when a process failed due to an exception raised by the code."""

    def __init__(
        self,
        msg: str,
        error_index: int,
        error_pid: int,
    ):
        super().__init__(msg, error_index, error_pid)


class ProcessExitedException(ProcessException):
    """Exception raised when a process failed due to a signal or exited with a specific code."""

    __slots__ = ["exit_code"]

    def __init__(
        self,
        msg: str,
        error_index: int,
        error_pid: int,
        exit_code: int,
        signal_name: Optional[str] = None,
    ):
        super().__init__(msg, error_index, error_pid)
        self.exit_code = exit_code
        self.signal_name = signal_name

    def __reduce__(self):
        return (
            type(self),
            (self.msg, self.error_index, self.pid, self.exit_code, self.signal_name),
        )


def _wrap(fn, i, args, error_file):
    # prctl(2) is a Linux specific system call.
    # On other systems the following function call has no effect.
    # This is set to ensure that non-daemonic child processes can
    # terminate if their parent terminates before they do.
    _prctl_pr_set_pdeathsig(signal.SIGINT)

    try:
        fn(i, *args)
    except KeyboardInterrupt:
        pass  # SIGINT; Killed by parent, do nothing
    except Exception:
        # Propagate exception to parent process, keeping original traceback
        import traceback

        with open(error_file, "wb") as fh:
            pickle.dump(traceback.format_exc(), fh)
        sys.exit(1)


class ProcessContext:
    def __init__(self, processes, error_files):
        self.error_files = error_files
        self.processes = processes
        self.sentinels = {
            process.sentinel: index for index, process in enumerate(processes)
        }

    def pids(self):
        return [int(process.pid) for process in self.processes]

    def join(self, timeout=None):
        r"""Join one or more processes within spawn context.

        Attempt to join one or more processes in this spawn context.
        If one of them exited with a non-zero exit status, this function
        kills the remaining processes and raises an exception with the cause
        of the first process exiting.

        Returns ``True`` if all processes have been joined successfully,
        ``False`` if there are more processes that need to be joined.

        Args:
            timeout (float): Wait this long before giving up on waiting.
        """
        # Ensure this function can be called even when we're done.
        if len(self.sentinels) == 0:
            return True

        # Wait for any process to fail or all of them to succeed.
        ready = multiprocessing.connection.wait(
            self.sentinels.keys(),
            timeout=timeout,
        )

        error_index = None
        for sentinel in ready:
            index = self.sentinels.pop(sentinel)
            process = self.processes[index]
            process.join()
            if process.exitcode != 0:
                error_index = index
                break

        # Return if there was no error.
        if error_index is None:
            # Return whether or not all processes have been joined.
            return len(self.sentinels) == 0

        # Assume failure. Terminate processes that are still alive.
        # Try SIGTERM first, then SIGKILL if the process isn't going down.
        # The reason is that Python signal handling is limited to the main
        # thread, and if that thread is stuck in C/C++ land it won't get a
        # chance to handle the signal. We have seen processes getting stuck
        # not handling SIGTERM for that reason.
        timeout: int = 30
        for process in self.processes:
            if process.is_alive():
                log.warning("Terminating process %s via signal SIGTERM", process.pid)
                process.terminate()
        end = time.monotonic() + timeout
        for process in self.processes:
            time_to_wait = max(0, end - time.monotonic())
            process.join(time_to_wait)
        for process in self.processes:
            if process.is_alive():
                log.warning(
                    "Unable to shutdown process %s via SIGTERM, forcefully exiting via SIGKILL",
                    process.pid,
                )
                process.kill()
            process.join()

        # The file will only be created if the process crashed.
        failed_process = self.processes[error_index]
        if not os.access(self.error_files[error_index], os.R_OK):
            exitcode = self.processes[error_index].exitcode
            if exitcode < 0:
                try:
                    name = signal.Signals(-exitcode).name
                except ValueError:
                    name = f"<Unknown signal {-exitcode}>"
                raise ProcessExitedException(
                    "process %d terminated with signal %s" % (error_index, name),
                    error_index=error_index,
                    error_pid=failed_process.pid,
                    exit_code=exitcode,
                    signal_name=name,
                )
            else:
                raise ProcessExitedException(
                    "process %d terminated with exit code %d" % (error_index, exitcode),
                    error_index=error_index,
                    error_pid=failed_process.pid,
                    exit_code=exitcode,
                )

        with open(self.error_files[error_index], "rb") as fh:
            original_trace = pickle.load(fh)
        msg = "\n\n-- Process %d terminated with the following error:\n" % error_index
        msg += original_trace
        raise ProcessRaisedException(msg, error_index, failed_process.pid)


class SpawnContext(ProcessContext):
    def __init__(self, processes, error_files):
        warnings.warn("SpawnContext is renamed to ProcessContext since 1.4 release.")
        super().__init__(processes, error_files)


# Note: [start_processes]
# mp.start_processes handles both start_method='spawn' and 'fork'. It's supposed to be a
# more generalized API than mp.spawn. Currently we only document mp.spawn as it's the
# CUDA compatible start_method. However, in environments like IPython notebooks, 'fork'
# works better than 'spawn'. Every helper function we created for mp.spawn is indeed
# general enough, and backends like XLA can reuse them in Colab notebooks as well.
# Currently we only add this API first, we can consider adding it to documentation as
# needed in the future.
def start_processes(
    fn,
    args=(),
    nprocs=1,
    join=True,
    daemon=False,
    start_method="spawn",
):
    # To speed up performance in certain cases (see https://github.com/pytorch/pytorch/issues/133010),
    # this func will start processes in parallel if start_method is 'forkserver'.
    # Please opt in to this perf optimization by setting env var (TORCH_MP_PARALLEL_START) to 1.
    # todo: investigate why spawn does not work with threadpool and raises SIGINT
    if (
        start_method == "forkserver"
        and os.environ.get(ENV_VAR_PARALLEL_START, "0") == "1"
    ):
        log.info("Starting processes in parallel.")
        start_parallel = True
    else:
        # Set env var TORCH_MP_PARALLEL_START to 0 to disable parallel start
        start_parallel = False

    mp = multiprocessing.get_context(start_method)
    error_files = [None] * nprocs
    processes = [None] * nprocs

    def start_process(i):
        # Each process is assigned a file to write tracebacks to. We
        # use the file being non-empty to indicate an exception
        # occurred (vs an expected shutdown). Note: this previously
        # used a multiprocessing.Queue but that can be prone to
        # deadlocks, so we went with a simpler solution for a one-shot
        # message between processes.
        tf = tempfile.NamedTemporaryFile(
            prefix="pytorch-errorfile-", suffix=".pickle", delete=False
        )
        tf.close()
        os.unlink(tf.name)
        process = mp.Process(
            target=_wrap,
            args=(fn, i, args, tf.name),
            daemon=daemon,
        )
        process.start()
        return i, process, tf.name

    if not start_parallel:
        for i in range(nprocs):
            idx, process, tf_name = start_process(i)
            error_files[idx] = tf_name
            processes[idx] = process
    else:
        with ThreadPoolExecutor(max_workers=nprocs) as executor:
            futures = [executor.submit(start_process, i) for i in range(nprocs)]
            for fut in as_completed(futures):
                idx, process, tf_name = fut.result()
                # idx and process rank needs to be the same.
                error_files[idx] = tf_name
                processes[idx] = process
    context = ProcessContext(processes, error_files)
    if not join:
        return context

    # Loop on join until it returns True or raises an exception.
    while not context.join():
        pass


def spawn(fn, args=(), nprocs=1, join=True, daemon=False, start_method="spawn"):
    r"""Spawns ``nprocs`` processes that run ``fn`` with ``args``.

    If one of the processes exits with a non-zero exit status, the
    remaining processes are killed and an exception is raised with the
    cause of termination. In the case an exception was caught in the
    child process, it is forwarded and its traceback is included in
    the exception raised in the parent process.

    Args:
        fn (function): Function is called as the entrypoint of the
            spawned process. This function must be defined at the top
            level of a module so it can be pickled and spawned. This
            is a requirement imposed by multiprocessing.

            The function is called as ``fn(i, *args)``, where ``i`` is
            the process index and ``args`` is the passed through tuple
            of arguments.

        args (tuple): Arguments passed to ``fn``.
        nprocs (int): Number of processes to spawn.
        join (bool): Perform a blocking join on all processes.
        daemon (bool): The spawned processes' daemon flag. If set to True,
            daemonic processes will be created.
        start_method (str): (deprecated) this method will always use ``spawn``
            as the start method. To use a different start method
            use ``start_processes()``.

    Returns:
        None if ``join`` is ``True``,
        :class:`~ProcessContext` if ``join`` is ``False``

    """
    if start_method != "spawn":
        msg = (
            f"This method only supports start_method=spawn (got: {start_method}).\n"
            "To use a different start_method use:\n\t\t"
            " torch.multiprocessing.start_processes(...)"
        )
        warnings.warn(msg, FutureWarning, stacklevel=2)
    return start_processes(fn, args, nprocs, join, daemon, start_method="spawn")

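For reference, the entry point defined above is typically driven as follows; the worker function must live at module top level so it can be pickled, and the function and argument names here are illustrative:

```python
import torch.multiprocessing as mp


def run(rank, total):
    # Called as fn(i, *args): `rank` is the process index injected by spawn.
    print(f"worker {rank} of {total}")


if __name__ == "__main__":
    # Blocks until all four workers exit; raises ProcessRaisedException /
    # ProcessExitedException if any worker fails.
    mp.spawn(run, args=(4,), nprocs=4, join=True)
```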
.venv/Lib/site-packages/torch/nn/parallel/__init__.py
ADDED
@@ -0,0 +1,28 @@
# mypy: allow-untyped-defs
from typing_extensions import deprecated

from torch.nn.parallel.data_parallel import data_parallel, DataParallel
from torch.nn.parallel.distributed import DistributedDataParallel
from torch.nn.parallel.parallel_apply import parallel_apply
from torch.nn.parallel.replicate import replicate
from torch.nn.parallel.scatter_gather import gather, scatter


__all__ = [
    "replicate",
    "scatter",
    "parallel_apply",
    "gather",
    "data_parallel",
    "DataParallel",
    "DistributedDataParallel",
]


@deprecated(
    "`torch.nn.parallel.DistributedDataParallelCPU` is deprecated, "
    "please use `torch.nn.parallel.DistributedDataParallel` instead.",
    category=FutureWarning,
)
class DistributedDataParallelCPU(DistributedDataParallel):
    pass

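The package above only re-exports the parallel wrappers; a minimal sketch of the single-process data-parallel path it exposes (module, shapes, and device handling are illustrative):

```python
import torch
import torch.nn as nn

model = nn.Linear(10, 5)
if torch.cuda.device_count() > 1:
    # Replicates the module and scatters each batch across the visible GPUs.
    model = nn.DataParallel(model)
device = "cuda" if torch.cuda.is_available() else "cpu"
model = model.to(device)

x = torch.randn(8, 10, device=device)
out = model(x)  # outputs gathered back to the first device, shape (8, 5)
```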
.venv/Lib/site-packages/torch/nn/parallel/__pycache__/__init__.cpython-39.pyc
ADDED
Binary file (1.03 kB).

.venv/Lib/site-packages/torch/nn/parallel/__pycache__/_functions.cpython-39.pyc
ADDED
Binary file (5.92 kB).

.venv/Lib/site-packages/torch/nn/parallel/__pycache__/comm.cpython-39.pyc
ADDED
Binary file (10.5 kB).

.venv/Lib/site-packages/torch/nn/parallel/__pycache__/data_parallel.cpython-39.pyc
ADDED
Binary file (10.8 kB).

.venv/Lib/site-packages/torch/nn/parallel/__pycache__/distributed.cpython-39.pyc
ADDED
Binary file (81.5 kB).

.venv/Lib/site-packages/torch/nn/parallel/__pycache__/parallel_apply.cpython-39.pyc
ADDED
Binary file (4.06 kB).

.venv/Lib/site-packages/torch/nn/parallel/__pycache__/replicate.cpython-39.pyc
ADDED
Binary file (5.24 kB).

.venv/Lib/site-packages/torch/nn/parallel/__pycache__/scatter_gather.cpython-39.pyc
ADDED
Binary file (5.19 kB).

.venv/Lib/site-packages/torch/nn/qat/__init__.py
ADDED
@@ -0,0 +1,18 @@
# flake8: noqa: F401
r"""QAT Dynamic Modules.

This package is in the process of being deprecated.
Please, use `torch.ao.nn.qat.dynamic` instead.
"""
from torch.nn.qat import dynamic, modules  # noqa: F403
from torch.nn.qat.modules import *  # noqa: F403


__all__ = [
    "Linear",
    "Conv1d",
    "Conv2d",
    "Conv3d",
    "Embedding",
    "EmbeddingBag",
]

.venv/Lib/site-packages/torch/nn/qat/dynamic/__init__.py
ADDED
@@ -0,0 +1,7 @@
# flake8: noqa: F401
r"""QAT Dynamic Modules.

This package is in the process of being deprecated.
Please, use `torch.ao.nn.qat.dynamic` instead.
"""
from torch.nn.qat.dynamic.modules import *  # noqa: F403

.venv/Lib/site-packages/torch/nn/qat/dynamic/__pycache__/__init__.cpython-39.pyc
ADDED
Binary file (375 Bytes).

.venv/Lib/site-packages/torch/nn/qat/dynamic/modules/__init__.py
ADDED
@@ -0,0 +1,4 @@
from torch.nn.qat.dynamic.modules.linear import Linear


__all__ = ["Linear"]

.venv/Lib/site-packages/torch/nn/qat/dynamic/modules/__pycache__/__init__.cpython-39.pyc
ADDED
Binary file (288 Bytes).

.venv/Lib/site-packages/torch/nn/qat/dynamic/modules/__pycache__/linear.cpython-39.pyc
ADDED
Binary file (615 Bytes).

.venv/Lib/site-packages/torch/nn/qat/dynamic/modules/linear.py
ADDED
@@ -0,0 +1,10 @@
# flake8: noqa: F401
r"""QAT Modules.

This file is in the process of migration to `torch/ao/nn/qat/dynamic`, and
is kept here for compatibility while the migration process is ongoing.
If you are adding a new entry/functionality, please, add it to the
appropriate file under the `torch/ao/nn/qat/dynamic/modules`,
while adding an import statement here.
"""
from torch.ao.nn.qat.dynamic.modules.linear import Linear

.venv/Lib/site-packages/torch/nn/qat/modules/__init__.py
ADDED
@@ -0,0 +1,20 @@
# flake8: noqa: F401
r"""QAT Modules.

This package is in the process of being deprecated.
Please, use `torch.ao.nn.qat.modules` instead.
"""
from torch.ao.nn.qat.modules.conv import Conv1d, Conv2d, Conv3d
from torch.ao.nn.qat.modules.embedding_ops import Embedding, EmbeddingBag
from torch.ao.nn.qat.modules.linear import Linear
from torch.nn.qat.modules import conv, embedding_ops, linear


__all__ = [
    "Linear",
    "Conv1d",
    "Conv2d",
    "Conv3d",
    "Embedding",
    "EmbeddingBag",
]

.venv/Lib/site-packages/torch/nn/qat/modules/__pycache__/conv.cpython-39.pyc
ADDED
Binary file (613 Bytes).

.venv/Lib/site-packages/torch/nn/qat/modules/conv.py
ADDED
@@ -0,0 +1,11 @@
# flake8: noqa: F401
r"""QAT Modules.

This file is in the process of migration to `torch/ao/nn/qat`, and
is kept here for compatibility while the migration process is ongoing.
If you are adding a new entry/functionality, please, add it to the
appropriate file under the `torch/ao/nn/qat/modules`,
while adding an import statement here.
"""

from torch.ao.nn.qat.modules.conv import Conv1d, Conv2d, Conv3d

.venv/Lib/site-packages/torch/nn/qat/modules/embedding_ops.py
ADDED
@@ -0,0 +1,14 @@
# flake8: noqa: F401
r"""QAT Modules.

This file is in the process of migration to `torch/ao/nn/qat`, and
is kept here for compatibility while the migration process is ongoing.
If you are adding a new entry/functionality, please, add it to the
appropriate file under the `torch/ao/nn/qat/modules`,
while adding an import statement here.
"""

from torch.ao.nn.qat.modules.embedding_ops import Embedding, EmbeddingBag


__all__ = ["Embedding", "EmbeddingBag"]

.venv/Lib/site-packages/torch/nn/qat/modules/linear.py
ADDED
@@ -0,0 +1,10 @@
# flake8: noqa: F401
r"""QAT Modules.

This file is in the process of migration to `torch/ao/nn/qat`, and
is kept here for compatibility while the migration process is ongoing.
If you are adding a new entry/functionality, please, add it to the
appropriate file under the `torch/ao/nn/qat/modules`,
while adding an import statement here.
"""
from torch.ao.nn.qat.modules.linear import Linear

.venv/Lib/site-packages/torch/nn/quantized/__init__.py
ADDED
@@ -0,0 +1,39 @@
from torch.nn.quantized import dynamic, functional, modules  # noqa: F403
from torch.nn.quantized.modules import *  # noqa: F403
from torch.nn.quantized.modules import MaxPool2d


__all__ = [
    "BatchNorm2d",
    "BatchNorm3d",
    "Conv1d",
    "Conv2d",
    "Conv3d",
    "ConvTranspose1d",
    "ConvTranspose2d",
    "ConvTranspose3d",
    "DeQuantize",
    "Dropout",
    "ELU",
    "Embedding",
    "EmbeddingBag",
    "GroupNorm",
    "Hardswish",
    "InstanceNorm1d",
    "InstanceNorm2d",
    "InstanceNorm3d",
    "LayerNorm",
    "LeakyReLU",
    "Linear",
    "LSTM",
    "MultiheadAttention",
    "PReLU",
    "Quantize",
    "ReLU6",
    "Sigmoid",
    "Softmax",
    # Wrapper modules
    "FloatFunctional",
    "FXFloatFunctional",
    "QFunctional",
]

.venv/Lib/site-packages/torch/nn/quantized/dynamic/__init__.py
ADDED
@@ -0,0 +1 @@
from torch.ao.nn.quantized.dynamic import *  # noqa: F403

.venv/Lib/site-packages/torch/nn/quantized/dynamic/__pycache__/__init__.cpython-39.pyc
ADDED
Binary file (244 Bytes).

.venv/Lib/site-packages/torch/nn/quantized/dynamic/modules/__init__.py
ADDED
@@ -0,0 +1,43 @@
# flake8: noqa: F401
r"""Quantized Dynamic Modules.

This file is in the process of migration to `torch/ao/nn/quantized/dynamic`,
and is kept here for compatibility while the migration process is ongoing.
If you are adding a new entry/functionality, please, add it to the
appropriate file under the `torch/ao/nn/quantized/dynamic`,
while adding an import statement here.
"""

from torch.ao.nn.quantized.dynamic.modules import conv, linear, rnn
from torch.ao.nn.quantized.dynamic.modules.conv import (
    Conv1d,
    Conv2d,
    Conv3d,
    ConvTranspose1d,
    ConvTranspose2d,
    ConvTranspose3d,
)
from torch.ao.nn.quantized.dynamic.modules.linear import Linear
from torch.ao.nn.quantized.dynamic.modules.rnn import (
    GRU,
    GRUCell,
    LSTM,
    LSTMCell,
    RNNCell,
)


__all__ = [
    "Linear",
    "LSTM",
    "GRU",
    "LSTMCell",
    "RNNCell",
    "GRUCell",
    "Conv1d",
    "Conv2d",
    "Conv3d",
    "ConvTranspose1d",
    "ConvTranspose2d",
    "ConvTranspose3d",
]

.venv/Lib/site-packages/torch/nn/quantized/dynamic/modules/conv.py
ADDED
@@ -0,0 +1,28 @@
# flake8: noqa: F401
r"""Quantized Dynamic Modules.

This file is in the process of migration to `torch/ao/nn/quantized/dynamic`,
and is kept here for compatibility while the migration process is ongoing.
If you are adding a new entry/functionality, please, add it to the
appropriate file under the `torch/ao/nn/quantized/dynamic/modules`,
while adding an import statement here.
"""

from torch.ao.nn.quantized.dynamic.modules.conv import (
    Conv1d,
    Conv2d,
    Conv3d,
    ConvTranspose1d,
    ConvTranspose2d,
    ConvTranspose3d,
)


__all__ = [
    "Conv1d",
    "Conv2d",
    "Conv3d",
    "ConvTranspose1d",
    "ConvTranspose2d",
    "ConvTranspose3d",
]

.venv/Lib/site-packages/torch/nn/quantized/dynamic/modules/linear.py
ADDED
@@ -0,0 +1,10 @@
# flake8: noqa: F401
r"""Quantized Dynamic Modules.

This file is in the process of migration to `torch/ao/nn/quantized/dynamic`,
and is kept here for compatibility while the migration process is ongoing.
If you are adding a new entry/functionality, please, add it to the
appropriate file under the `torch/ao/nn/quantized/dynamic/modules`,
while adding an import statement here.
"""
from torch.ao.nn.quantized.dynamic.modules.linear import Linear

.venv/Lib/site-packages/torch/nn/quantized/dynamic/modules/rnn.py
ADDED
@@ -0,0 +1,34 @@
# flake8: noqa: F401
r"""Quantized Dynamic Modules.

This file is in the process of migration to `torch/ao/nn/quantized/dynamic`,
and is kept here for compatibility while the migration process is ongoing.
If you are adding a new entry/functionality, please, add it to the
appropriate file under the `torch/ao/nn/quantized/dynamic/modules`,
while adding an import statement here.
"""

from torch.ao.nn.quantized.dynamic.modules.rnn import (
    GRU,
    GRUCell,
    LSTM,
    LSTMCell,
    pack_weight_bias,
    PackedParameter,
    RNNBase,
    RNNCell,
    RNNCellBase,
)


__all__ = [
    "pack_weight_bias",
    "PackedParameter",
    "RNNBase",
    "LSTM",
    "GRU",
    "RNNCellBase",
    "RNNCell",
    "LSTMCell",
    "GRUCell",
]

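The dynamic quantized RNN classes re-exported above are usually reached through `torch.ao.quantization.quantize_dynamic` rather than instantiated directly. A small illustrative sketch (the wrapper model and shapes are made up for the example):

```python
import torch
import torch.nn as nn


class TinyRNN(nn.Module):
    def __init__(self):
        super().__init__()
        self.lstm = nn.LSTM(input_size=16, hidden_size=32, batch_first=True)

    def forward(self, x):
        out, _ = self.lstm(x)
        return out


model = TinyRNN()
# Swaps the inner nn.LSTM for its dynamically quantized counterpart
# (weights stored as int8, activations kept in float).
qmodel = torch.ao.quantization.quantize_dynamic(model, {nn.LSTM}, dtype=torch.qint8)

out = qmodel(torch.randn(2, 5, 16))  # shape (2, 5, 32)
```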
.venv/Lib/site-packages/torch/nn/quantized/functional.py
ADDED
@@ -0,0 +1,10 @@
r"""nn.quantized.functional.

Quantized equivalents of the `nn.functional`.

Note::
    This location is in the process of being deprecated.
    Please, use the `torch.ao.nn.quantized.functional` instead.
"""

from torch.ao.nn.quantized.functional import *  # noqa: F401,F403

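Since this shim simply re-exports the `torch.ao` implementation, new code can target the non-deprecated path directly. A brief sketch of the equivalent call through the preferred namespace (the input tensor and scale/zero-point values are illustrative):

```python
import torch
import torch.ao.nn.quantized.functional as qF  # preferred, non-deprecated location

x = torch.quantize_per_tensor(
    torch.randn(1, 3, 8, 8), scale=0.1, zero_point=0, dtype=torch.quint8
)
# The same symbol is also reachable via torch.nn.quantized.functional.max_pool2d.
y = qF.max_pool2d(x, kernel_size=2)
```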
.venv/Lib/site-packages/torch/nn/quantized/modules/__init__.py
ADDED
@@ -0,0 +1,97 @@
r"""Quantized Modules.

Note::
    The `torch.nn.quantized` namespace is in the process of being deprecated.
    Please, use `torch.ao.nn.quantized` instead.
"""

# The following imports are needed in case the user decides
# to import the files directly,
# s.a. `from torch.nn.quantized.modules.conv import ...`.
# No need to add them to the `__all__`.
from torch.ao.nn.quantized.modules import (
    activation,
    batchnorm,
    conv,
    DeQuantize,
    dropout,
    embedding_ops,
    functional_modules,
    linear,
    MaxPool2d,
    normalization,
    Quantize,
    rnn,
    utils,
)
from torch.ao.nn.quantized.modules.activation import (
    ELU,
    Hardswish,
    LeakyReLU,
    MultiheadAttention,
    PReLU,
    ReLU6,
    Sigmoid,
    Softmax,
)
from torch.ao.nn.quantized.modules.batchnorm import BatchNorm2d, BatchNorm3d
from torch.ao.nn.quantized.modules.conv import (
    Conv1d,
    Conv2d,
    Conv3d,
    ConvTranspose1d,
    ConvTranspose2d,
    ConvTranspose3d,
)
from torch.ao.nn.quantized.modules.dropout import Dropout
from torch.ao.nn.quantized.modules.embedding_ops import Embedding, EmbeddingBag
from torch.ao.nn.quantized.modules.functional_modules import (
    FloatFunctional,
    FXFloatFunctional,
    QFunctional,
)
from torch.ao.nn.quantized.modules.linear import Linear
from torch.ao.nn.quantized.modules.normalization import (
    GroupNorm,
    InstanceNorm1d,
    InstanceNorm2d,
    InstanceNorm3d,
    LayerNorm,
)
from torch.ao.nn.quantized.modules.rnn import LSTM


__all__ = [
    "BatchNorm2d",
    "BatchNorm3d",
    "Conv1d",
    "Conv2d",
    "Conv3d",
    "ConvTranspose1d",
    "ConvTranspose2d",
    "ConvTranspose3d",
    "DeQuantize",
    "ELU",
    "Embedding",
    "EmbeddingBag",
    "GroupNorm",
    "Hardswish",
    "InstanceNorm1d",
    "InstanceNorm2d",
    "InstanceNorm3d",
    "LayerNorm",
    "LeakyReLU",
    "Linear",
    "LSTM",
    "MultiheadAttention",
    "Quantize",
    "ReLU6",
    "Sigmoid",
    "Softmax",
    "Dropout",
    "PReLU",
    # Wrapper modules
    "FloatFunctional",
    "FXFloatFunctional",
    "QFunctional",
]

.venv/Lib/site-packages/torch/nn/quantized/modules/__pycache__/__init__.cpython-39.pyc
ADDED
Binary file (1.93 kB).

.venv/Lib/site-packages/torch/nn/quantized/modules/activation.py
ADDED
@@ -0,0 +1,20 @@
# flake8: noqa: F401
r"""Quantized Modules.

This file is in the process of migration to `torch/ao/nn/quantized`, and
is kept here for compatibility while the migration process is ongoing.
If you are adding a new entry/functionality, please, add it to the
appropriate file under the `torch/ao/nn/quantized/modules`,
while adding an import statement here.
"""

from torch.ao.nn.quantized.modules.activation import (
    ELU,
    Hardswish,
    LeakyReLU,
    MultiheadAttention,
    PReLU,
    ReLU6,
    Sigmoid,
    Softmax,
)

.venv/Lib/site-packages/torch/nn/quantized/modules/batchnorm.py
ADDED
@@ -0,0 +1,11 @@
# flake8: noqa: F401
r"""Quantized Modules.

This file is in the process of migration to `torch/ao/nn/quantized`, and
is kept here for compatibility while the migration process is ongoing.
If you are adding a new entry/functionality, please, add it to the
appropriate file under the `torch/ao/nn/quantized/modules`,
while adding an import statement here.
"""

from torch.ao.nn.quantized.modules.batchnorm import BatchNorm2d, BatchNorm3d

.venv/Lib/site-packages/torch/nn/quantized/modules/conv.py
ADDED
@@ -0,0 +1,29 @@
# flake8: noqa: F401
r"""Quantized Modules.

This file is in the process of migration to `torch/ao/nn/quantized`, and
is kept here for compatibility while the migration process is ongoing.
If you are adding a new entry/functionality, please, add it to the
appropriate file under the `torch/ao/nn/quantized/modules`,
while adding an import statement here.
"""

from torch.ao.nn.quantized.modules.conv import (
    _reverse_repeat_padding,
    Conv1d,
    Conv2d,
    Conv3d,
    ConvTranspose1d,
    ConvTranspose2d,
    ConvTranspose3d,
)


__all__ = [
    "Conv1d",
    "Conv2d",
    "Conv3d",
    "ConvTranspose1d",
    "ConvTranspose2d",
    "ConvTranspose3d",
]

.venv/Lib/site-packages/torch/nn/quantized/modules/dropout.py
ADDED
@@ -0,0 +1,14 @@
# flake8: noqa: F401
r"""Quantized Modules.

This file is in the process of migration to `torch/ao/nn/quantized`, and
is kept here for compatibility while the migration process is ongoing.
If you are adding a new entry/functionality, please, add it to the
appropriate file under the `torch/ao/nn/quantized/modules`,
while adding an import statement here.
"""

from torch.ao.nn.quantized.modules.dropout import Dropout


__all__ = ["Dropout"]

.venv/Lib/site-packages/torch/nn/quantized/modules/embedding_ops.py
ADDED
@@ -0,0 +1,18 @@
# flake8: noqa: F401
r"""Quantized Modules.

This file is in the process of migration to `torch/ao/nn/quantized`, and
is kept here for compatibility while the migration process is ongoing.
If you are adding a new entry/functionality, please, add it to the
appropriate file under the `torch/ao/nn/quantized/modules`,
while adding an import statement here.
"""

from torch.ao.nn.quantized.modules.embedding_ops import (
    Embedding,
    EmbeddingBag,
    EmbeddingPackedParams,
)


__all__ = ["EmbeddingPackedParams", "Embedding", "EmbeddingBag"]

.venv/Lib/site-packages/torch/nn/quantized/modules/functional_modules.py
ADDED
@@ -0,0 +1,18 @@
# flake8: noqa: F401
r"""Quantized Modules.

This file is in the process of migration to `torch/ao/nn/quantized`, and
is kept here for compatibility while the migration process is ongoing.
If you are adding a new entry/functionality, please, add it to the
appropriate file under the `torch/ao/nn/quantized/modules`,
while adding an import statement here.
"""

from torch.ao.nn.quantized.modules.functional_modules import (
    FloatFunctional,
    FXFloatFunctional,
    QFunctional,
)


__all__ = ["FloatFunctional", "FXFloatFunctional", "QFunctional"]

.venv/Lib/site-packages/torch/nn/quantized/modules/linear.py
ADDED
@@ -0,0 +1,14 @@
# flake8: noqa: F401
r"""Quantized Modules.

This file is in the process of migration to `torch/ao/nn/quantized`, and
is kept here for compatibility while the migration process is ongoing.
If you are adding a new entry/functionality, please, add it to the
appropriate file under the `torch/ao/nn/quantized/modules`,
while adding an import statement here.
"""

from torch.ao.nn.quantized.modules.linear import Linear, LinearPackedParams


__all__ = ["LinearPackedParams", "Linear"]
