File size: 7,298 Bytes
032e687 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 |
from itertools import zip_longest, chain
import os.path as osp
import random
import torch
import os
from torch import distributed as torch_dist
from torch.distributed import ProcessGroup
import functools
from typing import Callable, Optional, Tuple
import pickle
import shutil
def _init_dist_pytorch(backend, **kwargs) -> None:
    """Set up the distributed environment for the PyTorch launcher.

    Args:
        backend (str): Backend of torch.distributed. Supported backends are
            'nccl', 'gloo' and 'mpi'. Defaults to 'nccl'.
        **kwargs: Extra keyword arguments forwarded to
            ``init_process_group``.
    """
    # `torch.distributed.launch` exports LOCAL_RANK since PyTorch 1.1; bind
    # this process to its assigned GPU before joining the group.
    device_index = int(os.environ['LOCAL_RANK'])
    torch.cuda.set_device(device_index)
    torch_dist.init_process_group(backend=backend, **kwargs)
def get_dist_info(group=None) -> Tuple[int, int]:
    """Get distributed information of the given process group.

    Note:
        Calling ``get_dist_info`` in a non-distributed environment returns
        ``(0, 1)``.

    Args:
        group (ProcessGroup, optional): The process group to work on. If None,
            the default process group will be used. Defaults to None.

    Returns:
        tuple[int, int]: Return a tuple containing the ``rank`` and
        ``world_size``.
    """
    return get_rank(group), get_world_size(group)
def get_world_size(group: Optional[ProcessGroup] = None) -> int:
    """Return the number of processes in the given process group.

    Note:
        Calling ``get_world_size`` in a non-distributed environment returns
        1.

    Args:
        group (ProcessGroup, optional): The process group to work on. If None,
            the default process group will be used. Defaults to None.

    Returns:
        int: Return the number of processes of the given process group if in
        distributed environment, otherwise 1.
    """
    if not is_distributed():
        return 1
    # Old torch releases (e.g. 1.5.0) reject ``group=None``, so resolve the
    # default group explicitly before the call.
    if group is None:
        group = get_default_group()
    return torch_dist.get_world_size(group)
def get_rank(group: Optional[ProcessGroup] = None) -> int:
    """Return the rank of the given process group.

    Rank is a unique identifier assigned to each process within a distributed
    process group. They are always consecutive integers ranging from 0 to
    ``world_size``.

    Note:
        Calling ``get_rank`` in a non-distributed environment returns 0.

    Args:
        group (ProcessGroup, optional): The process group to work on. If None,
            the default process group will be used. Defaults to None.

    Returns:
        int: Return the rank of the process group if in distributed
        environment, otherwise 0.
    """
    if not is_distributed():
        return 0
    # Old torch releases (e.g. 1.5.0) reject ``group=None``, so resolve the
    # default group explicitly before the call.
    if group is None:
        group = get_default_group()
    return torch_dist.get_rank(group)
def is_distributed() -> bool:
    """Return True if distributed environment has been initialized."""
    if not torch_dist.is_available():
        return False
    return torch_dist.is_initialized()
def get_default_group() -> Optional[ProcessGroup]:
    """Return default process group."""
    # NOTE(review): relies on the private ``distributed_c10d`` API because
    # older torch versions expose no public accessor; presumably raises when
    # the process group is uninitialized — callers here guard with
    # ``is_distributed()`` first.
    return torch_dist.distributed_c10d._get_default_group()
def is_main_process(group: Optional[ProcessGroup] = None) -> bool:
    """Whether the current rank of the given process group is equal to 0.

    Args:
        group (ProcessGroup, optional): The process group to work on. If None,
            the default process group will be used. Defaults to None.

    Returns:
        bool: Return True if the current rank of the given process group is
        equal to 0, otherwise False.
    """
    rank = get_rank(group)
    return rank == 0
def master_only(func: Callable) -> Callable:
    """Decorate those methods which should be executed in master process.

    Args:
        func (callable): Function to be decorated.

    Returns:
        callable: Return decorated function.
    """

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        # Non-master ranks skip the call and implicitly yield None.
        if not is_main_process():
            return None
        return func(*args, **kwargs)

    return wrapper
def collect_results_cpu(result_part: list,
                        size: int,
                        tmpdir='./dist_test_temp'):
    """Collect results under cpu mode.

    On cpu mode, this function will save the results on different gpus to
    ``tmpdir`` and collect them by the rank 0 worker.

    Args:
        result_part (list): Result list containing result parts
            to be collected. Each item of ``result_part`` should be a picklable
            object.
        size (int): Size of the results, commonly equal to length of
            the results.
        tmpdir (str): Temporary directory shared by all ranks where the
            partial results are stored. Must point to a directory visible to
            every rank (e.g. on a shared filesystem).
            Defaults to './dist_test_temp'.

    Returns:
        list or None: The collected results on rank 0, ``None`` on the other
        ranks.
    """
    rank, world_size = get_dist_info()
    # Single-process run: nothing to gather across ranks.
    if world_size == 1:
        return result_part[:size]
    # Create the tmp dir. ``makedirs`` with ``exist_ok=True`` avoids the
    # check-then-create race between ranks that the previous
    # ``exists``+``mkdir`` pair had, and also creates missing parents.
    os.makedirs(tmpdir, exist_ok=True)
    # dump this rank's partial result to the shared dir
    with open(osp.join(tmpdir, f'part_{rank}.pkl'), 'wb') as f:  # type: ignore
        pickle.dump(result_part, f, protocol=2)
    # wait until every rank has written its part
    barrier()
    # collect all parts
    if rank != 0:
        return None
    else:
        # load results of all parts from tmp dir
        part_list = []
        for i in range(world_size):
            path = osp.join(tmpdir, f'part_{i}.pkl')  # type: ignore
            if not osp.exists(path):
                raise FileNotFoundError(
                    f'{tmpdir} is not a shared directory for '
                    f'rank {i}, please make sure {tmpdir} is a shared '
                    'directory for all ranks!')
            with open(path, 'rb') as f:
                part_list.append(pickle.load(f))
        # Samples were sharded round-robin across ranks, so interleave the
        # parts; ``zip_longest`` pads shorter parts with None, dropped below.
        zipped_results = zip_longest(*part_list)
        ordered_results = [
            i for i in chain.from_iterable(zipped_results) if i is not None
        ]
        # the dataloader may pad some samples; trim to the true size
        ordered_results = ordered_results[:size]
        # remove tmp dir
        shutil.rmtree(tmpdir)  # type: ignore
        return ordered_results
def barrier(group: Optional[ProcessGroup] = None) -> None:
    """Synchronize all processes from the given process group.

    This collective blocks processes until the whole group enters this
    function.

    Note:
        Calling ``barrier`` in a non-distributed environment does nothing.

    Args:
        group (ProcessGroup, optional): The process group to work on. If None,
            the default process group will be used. Defaults to None.
    """
    if not is_distributed():
        # Nothing to synchronize with outside a distributed environment.
        return
    # Old torch releases (e.g. 1.5.0) reject ``group=None``, so resolve the
    # default group explicitly before the call.
    if group is None:
        group = get_default_group()
    torch_dist.barrier(group)