# mypy: allow-untyped-defs
import contextlib
import warnings
from collections.abc import Generator
import torch
from torch._C import default_generator
def set_rng_state(new_state: torch.Tensor) -> None:
r"""Sets the random number generator state.
.. note:: This function only works for CPU. For CUDA, please use
:func:`torch.manual_seed`, which works for both CPU and CUDA.
Args:
new_state (torch.ByteTensor): The desired state
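
    A minimal sketch: restoring a previously captured CPU state replays the
    same random draws::

        >>> state = torch.get_rng_state()
        >>> draw = torch.rand(3)
        >>> torch.set_rng_state(state)
        >>> torch.equal(draw, torch.rand(3))
        True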
"""
default_generator.set_state(new_state)
def get_rng_state() -> torch.Tensor:
r"""Returns the random number generator state as a `torch.ByteTensor`.
.. note:: The returned state is for the default generator on CPU only.
See also: :func:`torch.random.fork_rng`.
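
    A minimal sketch: the captured state can be stashed, e.g. in a checkpoint
    dict (the key name below is hypothetical), and handed back to
    :func:`torch.set_rng_state` later::

        >>> checkpoint = {"cpu_rng_state": torch.get_rng_state()}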
"""
return default_generator.get_state()
def manual_seed(seed) -> torch._C.Generator:
r"""Sets the seed for generating random numbers on all devices. Returns a
`torch.Generator` object.
Args:
seed (int): The desired seed. Value must be within the inclusive range
`[-0x8000_0000_0000_0000, 0xffff_ffff_ffff_ffff]`. Otherwise, a RuntimeError
is raised. Negative inputs are remapped to positive values with the formula
`0xffff_ffff_ffff_ffff + seed`.
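
    A minimal sketch: re-seeding with the same value repeats the draws from the
    default CPU generator::

        >>> _ = torch.manual_seed(42)
        >>> first = torch.rand(2)
        >>> _ = torch.manual_seed(42)
        >>> torch.equal(first, torch.rand(2))
        True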
"""
seed = int(seed)
import torch.cuda
if not torch.cuda._is_in_bad_fork():
torch.cuda.manual_seed_all(seed)
import torch.mps
if not torch.mps._is_in_bad_fork():
torch.mps.manual_seed(seed)
import torch.xpu
if not torch.xpu._is_in_bad_fork():
torch.xpu.manual_seed_all(seed)
_seed_custom_device(seed)
return default_generator.manual_seed(seed)
def seed() -> int:
r"""Sets the seed for generating random numbers to a non-deterministic
random number on all devices. Returns a 64-bit number used to seed the RNG.
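
    A minimal sketch: the returned value can be logged and passed to
    :func:`torch.manual_seed` later to reproduce a run that was seeded
    non-deterministically::

        >>> s = torch.seed()
        >>> first = torch.rand(2)
        >>> _ = torch.manual_seed(s)
        >>> torch.equal(first, torch.rand(2))
        True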
"""
seed = default_generator.seed()
import torch.cuda
if not torch.cuda._is_in_bad_fork():
torch.cuda.manual_seed_all(seed)
import torch.mps
if not torch.mps._is_in_bad_fork():
torch.mps.manual_seed(seed)
import torch.xpu
if not torch.xpu._is_in_bad_fork():
torch.xpu.manual_seed_all(seed)
_seed_custom_device(seed)
return seed
def _seed_custom_device(seed) -> None:
r"""Sets the seed to generate random numbers for custom device.
Args:
seed (int): The desired seed.
See [Note: support the custom device with privateuse1]
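
    A sketch of the attributes this helper looks up on the custom backend
    module; the backend name ``foo`` and the module below are hypothetical::

        # made visible as torch.foo beforehand, e.g. via torch._register_device_module
        class foo_module:
            @staticmethod
            def _is_in_bad_fork() -> bool:
                return False

            @staticmethod
            def manual_seed_all(seed: int) -> None:
                ...  # forward the seed to every device of the backend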
"""
seed = int(seed)
custom_backend_name = torch._C._get_privateuse1_backend_name()
if hasattr(torch, custom_backend_name):
custom_device_mod = getattr(torch, custom_backend_name)
_bad_fork_name = "_is_in_bad_fork"
_seed_all_name = "manual_seed_all"
if hasattr(custom_device_mod, _bad_fork_name) and hasattr(
custom_device_mod, _seed_all_name
):
if not getattr(custom_device_mod, _bad_fork_name)():
getattr(custom_device_mod, _seed_all_name)(seed)
else:
message = f"Set seed for `{custom_backend_name}` device does not take effect, please add API's "
message += f"`{_bad_fork_name}` and `{_seed_all_name}` to `{custom_backend_name}` device module."
warnings.warn(message, UserWarning, stacklevel=3)
def initial_seed() -> int:
r"""Returns the initial seed for generating random numbers as a
Python `long`.
.. note:: The returned seed is for the default generator on CPU only.
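
    A minimal sketch, assuming the default generator was seeded explicitly::

        >>> _ = torch.manual_seed(7)
        >>> torch.initial_seed()
        7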
"""
return default_generator.initial_seed()
_fork_rng_warned_already = False
@contextlib.contextmanager
def fork_rng(
devices=None,
enabled=True,
_caller="fork_rng",
_devices_kw="devices",
device_type="cuda",
) -> Generator:
"""
Forks the RNG, so that when you return, the RNG is reset
to the state that it was previously in.

    Args:
        devices (iterable of Device IDs): devices for which to fork
            the RNG. CPU RNG state is always forked. By default, :meth:`fork_rng` operates
            on all devices, but will emit a warning if your machine has a lot
            of devices, since this function will run very slowly in that case.
            If you explicitly specify devices, this warning will be suppressed.
        enabled (bool): if ``False``, the RNG is not forked. This is a convenience
            argument for easily disabling the context manager without having
            to delete it and unindent your Python code under it.
        device_type (str): device type string, default is ``cuda``. For custom devices,
            see the details in [Note: support the custom device with privateuse1].
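
    A minimal CPU-only sketch (``devices=[]``, so no device RNG is forked):
    draws made inside the block do not disturb the RNG state outside it::

        >>> _ = torch.manual_seed(0)
        >>> with torch.random.fork_rng(devices=[]):
        ...     _ = torch.rand(1000)  # consumes RNG state only inside the fork
        >>> outside = torch.rand(2)
        >>> _ = torch.manual_seed(0)
        >>> torch.equal(outside, torch.rand(2))
        True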
"""
if device_type == "meta":
yield
return
device_type = torch.device(device_type).type
device_mod = getattr(torch, device_type, None)
if device_mod is None:
        raise RuntimeError(
            f"torch has no module named `{device_type}`; you should register "
            + "a module via `torch._register_device_module`."
        )
global _fork_rng_warned_already
# Internal arguments:
# _caller: the function which called fork_rng, which the user used
# _devices_kw: the devices keyword of _caller
if not enabled:
yield
return
if devices is None:
num_devices = device_mod.device_count()
if num_devices > 1 and not _fork_rng_warned_already:
message = (
f"{device_type.upper()} reports that you have {num_devices} available devices, and "
f"you have used {_caller} without explicitly specifying which devices are being used. "
f"For safety, we initialize *every* {device_type.upper()} device by default, which can "
f"be quite slow if you have a lot of {device_type.upper()}s. If you know that you are only"
f" making use of a few {device_type.upper()} devices, set the environment variable "
f"{device_type.upper()}_VISIBLE_DEVICES or the '{_devices_kw}' keyword argument of {_caller} "
"with the set of devices you are actually using. For example, if you are using CPU only, "
"set device.upper()_VISIBLE_DEVICES= or devices=[]; if you are using device 0 only, "
f"set {device_type.upper()}_VISIBLE_DEVICES=0 or devices=[0]. To initialize all devices "
f"and suppress this warning, set the '{_devices_kw}' keyword argument to "
f"`range(torch.{device_type}.device_count())`."
)
warnings.warn(message)
_fork_rng_warned_already = True
devices = list(range(num_devices))
else:
# Protect against user passing us a generator; we need to traverse this
# multiple times but a generator will be exhausted upon first traversal
devices = list(devices)
cpu_rng_state = torch.get_rng_state()
device_rng_states = [device_mod.get_rng_state(device) for device in devices]
try:
yield
finally:
torch.set_rng_state(cpu_rng_state)
for device, device_rng_state in zip(devices, device_rng_states):
device_mod.set_rng_state(device_rng_state, device)