import gc
import os
import logging
from typing import Optional, TypeVar, List, Tuple
import torch
import torch.distributed as dist
T = TypeVar("T")
log = logging.getLogger(__name__)
def seed_all(seed: int):
"""Seed all rng objects."""
import random
import numpy as np
if seed < 0 or seed > 2**32 - 1:
raise ValueError(f"Seed {seed} is invalid. It must be on [0; 2^32 - 1]")
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
    # torch.manual_seed may already call manual_seed_all, but we call it again here
    # to make sure it gets called at least once.
torch.cuda.manual_seed_all(seed)
def is_distributed() -> bool:
return dist.is_available() and dist.is_initialized()
def get_node_rank() -> int:
return int(os.environ.get("NODE_RANK") or (get_global_rank() - get_local_rank()) // get_local_world_size())
def get_world_size() -> int:
if is_distributed():
return dist.get_world_size()
else:
return 1
def get_local_world_size() -> int:
return int(os.environ.get("LOCAL_WORLD_SIZE") or 1)
def get_global_rank() -> int:
if is_distributed():
return int(os.environ.get("RANK") or dist.get_rank())
else:
return 0
def get_local_rank() -> int:
return int(os.environ.get("LOCAL_RANK") or 0)
def get_fs_local_rank() -> int:
"""Get the local rank per filesystem, meaning that, regardless of the number of nodes,
if all ranks share the same filesystem then `get_fs_local_rank()` will be equivalent to `get_global_rank()`,
but if nodes do not share the same filesystem then `get_fs_local_rank()` will be equivalent to `get_local_rank()`.
"""
if os.environ.get("OLMO_SHARED_FS"):
return int(os.environ.get("FS_LOCAL_RANK") or get_global_rank())
else:
return int(os.environ.get("FS_LOCAL_RANK") or get_local_rank())
def move_to_device(o: T, device: torch.device) -> T:
if isinstance(o, torch.Tensor):
return o.to(device) # type: ignore[return-value]
elif isinstance(o, dict):
return {k: move_to_device(v, device) for k, v in o.items()} # type: ignore[return-value]
elif isinstance(o, list):
return [move_to_device(x, device) for x in o] # type: ignore[return-value]
elif isinstance(o, tuple):
return tuple((move_to_device(x, device) for x in o)) # type: ignore[return-value]
else:
return o
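# Example: recursively move a collated batch onto the default device. Tensors are moved,
# dict/list/tuple containers are rebuilt, and non-tensor leaves (here plain ints) are
# returned unchanged.
#
#     batch = {"input_ids": torch.zeros(2, 8, dtype=torch.long), "lengths": [8, 8]}
#     batch = move_to_device(batch, get_default_device())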
def ensure_finite_(x: torch.Tensor, check_neg_inf: bool = True, check_pos_inf: bool = False):
"""
Modify ``x`` in place to replace ``float("-inf")`` with the minimum value of the dtype when ``check_neg_inf``
is ``True`` and to replace ``float("inf")`` with the maximum value of the dtype when ``check_pos_inf`` is ``True``.
"""
if check_neg_inf:
x.masked_fill_(x == float("-inf"), torch.finfo(x.dtype).min)
if check_pos_inf:
x.masked_fill_(x == float("inf"), torch.finfo(x.dtype).max)
def get_default_device() -> torch.device:
if torch.cuda.is_available() and torch.cuda.is_initialized():
return torch.device("cuda")
else:
return torch.device("cpu")
def barrier() -> None:
if is_distributed():
dist.barrier()
def peak_gpu_memory(reset: bool = False) -> Optional[float]:
"""
Get the peak GPU memory usage in MB across all ranks.
Only rank 0 will get the final result.
"""
if not torch.cuda.is_available():
return None
device = torch.device("cuda")
peak_mb = torch.cuda.max_memory_allocated(device) / 1000000
if is_distributed():
peak_mb_tensor = torch.tensor(peak_mb, device=device)
dist.reduce(peak_mb_tensor, 0, dist.ReduceOp.MAX)
peak_mb = peak_mb_tensor.item()
    if reset:
        # Reset peak stats (`reset_max_memory_allocated` is deprecated in favor of this).
        torch.cuda.reset_peak_memory_stats(device)
return peak_mb
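# Example: report peak GPU memory at a logging step and reset the counter so the next
# report covers only the following interval. Only rank 0 sees the cross-rank maximum.
#
#     peak_mb = peak_gpu_memory(reset=True)
#     if peak_mb is not None and get_global_rank() == 0:
#         log.info(f"Peak GPU memory: {peak_mb:.0f} MB")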
V = TypeVar("V", bool, int, float)
def synchronize_value(value: V, device: torch.device) -> V:
    if is_distributed():
value_tensor = torch.tensor(value, device=device)
dist.broadcast(value_tensor, 0)
return value_tensor.item() # type: ignore
else:
return value
def synchronize_flag(flag: bool, device: torch.device) -> bool:
return synchronize_value(flag, device)
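# Example (``stop_early`` is a placeholder for some rank-0 decision): broadcast a
# control-flow flag from rank 0 so every rank takes the same branch; if ranks diverged
# here, a later collective such as ``barrier()`` could hang.
#
#     stop_early = synchronize_flag(stop_early, get_default_device())
#     if stop_early:
#         break  # every rank exits the training loop together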
def gc_cuda():
gc.collect()
if torch.cuda.is_available():
torch.cuda.empty_cache()
def listinstr(lst, s, delimiter=None):
    """Return ``True`` if any item in ``lst`` occurs in the string ``s``.

    When ``delimiter`` is given, an item matches only if every delimiter-separated
    part of it occurs in ``s``.
    """
    assert isinstance(lst, list)
for item in lst:
if delimiter:
if all(x in s for x in item.split(delimiter)):
return True
else:
if item in s:
return True
return False
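# Example: plain substring matching vs. delimiter matching, where every delimiter-separated
# part of an item must occur in the string.
#
#     listinstr(["attn"], "blocks.0.attn.w_q")                # True
#     listinstr(["blocks.0+w_q"], "blocks.0.attn.w_q", "+")   # True (both parts occur)
#     listinstr(["mlp"], "blocks.0.attn.w_q")                 # False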
def freeze_module(module: torch.nn.Module, exclude_params: Optional[List[str]] = None):
for name, param in module.named_parameters():
if exclude_params is not None and listinstr(exclude_params, name):
continue
param.requires_grad = False
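# Example (``vision_backbone`` and the "norm" naming are placeholders): freeze an entire
# submodule while leaving any parameter whose name contains "norm" trainable.
#
#     freeze_module(model.vision_backbone, exclude_params=["norm"])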
def freeze_parameters_by_name(model: torch.nn.Module, freeze_names: Tuple[str, ...]):
    for name in freeze_names:
        try:
            module_or_param = model.get_submodule(name)
        except AttributeError:
            try:
                module_or_param = model.get_parameter(name)
            except AttributeError:
                log.warning(f"Could not find module or parameter with name {name}")
                continue
        if isinstance(module_or_param, torch.nn.Module):
            freeze_module(module_or_param)
        else:
            module_or_param.requires_grad = False
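# Example (the dotted names are placeholders for entries from ``model.named_parameters()``):
# freeze a whole submodule and one individual weight by name.
#
#     freeze_parameters_by_name(model, ("transformer.wte", "blocks.0.attn_out.weight"))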