import functools
import os
import os.path as osp
import pickle
import shutil
from itertools import chain, zip_longest
from typing import Callable, Optional, Tuple

import torch
from torch import distributed as torch_dist
from torch.distributed import ProcessGroup


def _init_dist_pytorch(backend, **kwargs) -> None:
    """Initialize the distributed environment with the PyTorch launcher.

    Args:
        backend (str): Backend of torch.distributed. Supported backends are
            'nccl', 'gloo' and 'mpi'.
        **kwargs: Keyword arguments passed to ``init_process_group``.
    """
# LOCAL_RANK is set by `torch.distributed.launch` since PyTorch 1.1
local_rank = int(os.environ['LOCAL_RANK'])
torch.cuda.set_device(local_rank)
torch_dist.init_process_group(backend=backend, **kwargs)
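

# Illustrative usage (a sketch, not part of the original module): this helper
# expects the environment variables set by `torchrun` / `torch.distributed.launch`
# (LOCAL_RANK, RANK, WORLD_SIZE, MASTER_ADDR, MASTER_PORT), e.g.
#
#     torchrun --nproc_per_node=4 train.py
#
# and inside `train.py` (a hypothetical entry script):
#
#     _init_dist_pytorch('nccl')
#     rank, world_size = get_dist_info()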


def get_dist_info(group: Optional[ProcessGroup] = None) -> Tuple[int, int]:
    """Get distributed information of the given process group.

    Note:
        Calling ``get_dist_info`` in a non-distributed environment will
        return (0, 1).

    Args:
        group (ProcessGroup, optional): The process group to work on. If None,
            the default process group will be used. Defaults to None.

    Returns:
        tuple[int, int]: A tuple containing the ``rank`` and ``world_size``.
    """
world_size = get_world_size(group)
rank = get_rank(group)
return rank, world_size
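

# A minimal sketch (an illustration added here, not part of the original
# utilities) of how the (rank, world_size) pair is typically consumed to shard
# a list of work items across processes, mirroring what a DistributedSampler
# does for dataset indices; ``all_items`` is a hypothetical name:
#
#     rank, world_size = get_dist_info()
#     my_items = all_items[rank::world_size]  # round-robin shard for this rank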


def get_world_size(group: Optional[ProcessGroup] = None) -> int:
    """Return the number of processes in the given process group.

    Note:
        Calling ``get_world_size`` in a non-distributed environment will
        return 1.

    Args:
        group (ProcessGroup, optional): The process group to work on. If None,
            the default process group will be used. Defaults to None.

    Returns:
        int: The number of processes of the given process group if in a
        distributed environment, otherwise 1.
    """
if is_distributed():
# handle low versions of torch like 1.5.0 which does not support
# passing in None for group argument
if group is None:
group = get_default_group()
return torch_dist.get_world_size(group)
else:
return 1


def get_rank(group: Optional[ProcessGroup] = None) -> int:
    """Return the rank of the given process group.

    Rank is a unique identifier assigned to each process within a distributed
    process group. Ranks are always consecutive integers ranging from 0 to
    ``world_size - 1``.

    Note:
        Calling ``get_rank`` in a non-distributed environment will return 0.

    Args:
        group (ProcessGroup, optional): The process group to work on. If None,
            the default process group will be used. Defaults to None.

    Returns:
        int: The rank of the process group if in a distributed environment,
        otherwise 0.
    """
if is_distributed():
# handle low versions of torch like 1.5.0 which does not support
# passing in None for group argument
if group is None:
group = get_default_group()
return torch_dist.get_rank(group)
else:
return 0


def is_distributed() -> bool:
    """Return True if the distributed environment has been initialized."""
    return torch_dist.is_available() and torch_dist.is_initialized()


def get_default_group() -> Optional[ProcessGroup]:
    """Return the default process group."""
    return torch_dist.distributed_c10d._get_default_group()


def is_main_process(group: Optional[ProcessGroup] = None) -> bool:
    """Whether the current rank of the given process group is equal to 0.

    Args:
        group (ProcessGroup, optional): The process group to work on. If None,
            the default process group will be used. Defaults to None.

    Returns:
        bool: Return True if the current rank of the given process group is
        equal to 0, otherwise False.
    """
    return get_rank(group) == 0


def master_only(func: Callable) -> Callable:
    """Decorate a function so that it is executed only in the master process.

    Args:
        func (callable): Function to be decorated.

    Returns:
        callable: The decorated function, which returns None on non-master
        ranks.
    """

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        if is_main_process():
            return func(*args, **kwargs)

    return wrapper
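

# Illustrative usage of ``master_only`` (a sketch; ``save_checkpoint`` and its
# arguments are hypothetical and not part of this module):
#
#     @master_only
#     def save_checkpoint(model, path):
#         torch.save(model.state_dict(), path)
#
# Only rank 0 executes the body; on other ranks the call is a no-op that
# returns None.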


def collect_results_cpu(result_part: list,
                        size: int,
                        tmpdir: str = './dist_test_temp'):
    """Collect results under CPU mode.

    Under CPU mode, this function saves the partial results produced on the
    different GPUs to ``tmpdir`` and collects them on the rank 0 worker.

    Args:
        result_part (list): Result list containing the result parts to be
            collected. Each item of ``result_part`` should be a picklable
            object.
        size (int): Size of the full result set, commonly equal to the length
            of the dataset (the dataloader may pad some samples).
        tmpdir (str): Temporary directory, which must be shared by all ranks,
            for storing the per-rank result parts. Defaults to
            './dist_test_temp'.

    Returns:
        list or None: The collected results on rank 0, otherwise None.
    """
rank, world_size = get_dist_info()
if world_size == 1:
return result_part[:size]
    # create the tmp dir if it does not exist yet; several ranks may attempt
    # this at the same time, hence exist_ok=True
    os.makedirs(tmpdir, exist_ok=True)
    # dump this rank's partial result to the shared dir
    with open(osp.join(tmpdir, f'part_{rank}.pkl'), 'wb') as f:
        pickle.dump(result_part, f, protocol=2)
barrier()
# collect all parts
if rank != 0:
return None
else:
# load results of all parts from tmp dir
part_list = []
for i in range(world_size):
            path = osp.join(tmpdir, f'part_{i}.pkl')
            if not osp.exists(path):
                raise FileNotFoundError(
                    f'{tmpdir} is not a shared directory for '
                    f'rank {i}, please make sure {tmpdir} is a shared '
                    'directory for all ranks!')
with open(path, 'rb') as f:
part_list.append(pickle.load(f))
        # interleave the per-rank parts to restore the original sample order
        # (assuming samples were sharded round-robin across ranks, as the
        # default DistributedSampler does)
        zipped_results = zip_longest(*part_list)
        ordered_results = [
            i for i in chain.from_iterable(zipped_results) if i is not None
        ]
# the dataloader may pad some samples
ordered_results = ordered_results[:size]
        # remove the tmp dir
        shutil.rmtree(tmpdir)
return ordered_results
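

# Typical evaluation-loop usage (a sketch under the assumption that the
# dataloader uses a DistributedSampler, so each rank sees an interleaved
# shard; ``model``, ``data_loader``, ``dataset`` and ``evaluate`` are
# hypothetical names):
#
#     results = [model(batch) for batch in data_loader]          # per-rank part
#     results = collect_results_cpu(results, size=len(dataset))  # rank 0 only
#     if is_main_process():
#         evaluate(results)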


def barrier(group: Optional[ProcessGroup] = None) -> None:
    """Synchronize all processes from the given process group.

    This collective blocks processes until the whole group enters this
    function.

    Note:
        Calling ``barrier`` in a non-distributed environment will do nothing.

    Args:
        group (ProcessGroup, optional): The process group to work on. If None,
            the default process group will be used. Defaults to None.
    """
if is_distributed():
# handle low versions of torch like 1.5.0 which does not support
# passing in None for group argument
if group is None:
group = get_default_group()
torch_dist.barrier(group)
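

# Minimal smoke test (an illustrative addition, not part of the original
# module). Run this file directly with plain `python` to exercise the
# non-distributed fallbacks, or launch it with `torchrun --nproc_per_node=2`
# and one GPU per process to exercise the NCCL path.
if __name__ == '__main__':
    if 'LOCAL_RANK' in os.environ and torch.cuda.is_available():
        _init_dist_pytorch('nccl')
    rank, world_size = get_dist_info()
    print(f'rank {rank} / world size {world_size}, main={is_main_process()}')
    # each rank contributes one item; only rank 0 receives the gathered list
    gathered = collect_results_cpu([f'result-from-rank-{rank}'], size=world_size)
    if is_main_process():
        print('collected:', gathered)
    barrier()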