#!/usr/bin/env python3
"""Distributed helpers."""
import torch
import torch.distributed as dist
_LOCAL_PROCESS_GROUP = None
def all_reduce(tensors, average=True):
    """
    All-reduce the provided tensors across all processes on all machines.
    Args:
        tensors (list): tensors to all-reduce across all processes in all
            machines.
        average (bool): if True, divide the reduced tensors by the total
            number of processes across all machines, i.e. return the mean.
    """
    for tensor in tensors:
        dist.all_reduce(tensor, async_op=False)
    if average:
        world_size = dist.get_world_size()
        for tensor in tensors:
            tensor.mul_(1.0 / world_size)
    return tensors


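# Illustrative usage sketch (an addition, not one of the original helpers):
# averaging a per-process scalar loss with ``all_reduce``. The ``loss`` tensor,
# its dtype, and the device choice are assumptions made for the example only;
# it expects the default process group to already be initialized.
def _example_average_loss(loss_value):
    """Sketch: return the mean of ``loss_value`` across all processes."""
    device = "cuda" if torch.cuda.is_available() else "cpu"
    loss = torch.tensor([loss_value], dtype=torch.float32, device=device)
    # With average=True the reduced tensor is divided by the world size,
    # so every process ends up holding the mean loss.
    (loss,) = all_reduce([loss], average=True)
    return loss.item()

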
def is_master_proc(num_gpus=8):
    """
    Determines if the current process is the master process.
    """
    if dist.is_initialized():
        return dist.get_rank() % num_gpus == 0
    else:
        return True


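# Illustrative sketch (an addition): gating side effects such as logging or
# checkpoint writing on the master process. The ``num_gpus`` default mirrors
# the signature above; printing the message is an assumption for the example.
def _example_log_on_master(message, num_gpus=8):
    """Sketch: print ``message`` only from the master process."""
    if is_master_proc(num_gpus):
        print(message)

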
def get_world_size():
    """
    Get the number of processes in the current distributed job. Returns 1
    when distributed training is unavailable or not initialized.
    """
    if not dist.is_available():
        return 1
    if not dist.is_initialized():
        return 1
    return dist.get_world_size()


def init_distributed_training(cfg):
    """
    Initialize variables needed for distributed training, i.e. the per-machine
    process group stored in ``_LOCAL_PROCESS_GROUP``.
    """
    if cfg.NUM_GPUS <= 1:
        return
    num_gpus_per_machine = cfg.NUM_GPUS
    num_machines = dist.get_world_size() // num_gpus_per_machine
    for i in range(num_machines):
        # Ranks are assigned machine by machine, so machine i owns the
        # contiguous block of ranks [i * num_gpus, (i + 1) * num_gpus).
        ranks_on_i = list(
            range(i * num_gpus_per_machine, (i + 1) * num_gpus_per_machine)
        )
        pg = dist.new_group(ranks_on_i)
        if i == cfg.SHARD_ID:
            # Remember the process group containing this machine's ranks.
            global _LOCAL_PROCESS_GROUP
            _LOCAL_PROCESS_GROUP = pg


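# Illustrative per-process entry point (a sketch, not this repo's launcher):
# it assumes a ``cfg`` object exposing the NUM_GPUS and SHARD_ID fields used
# above, plus hypothetical NUM_SHARDS and INIT_METHOD fields for the total
# number of machines and the rendezvous address.
def _example_run_worker(local_rank, cfg):
    """Sketch: set up the default process group, then the local one."""
    world_size = cfg.NUM_GPUS * cfg.NUM_SHARDS  # assumed field
    global_rank = cfg.SHARD_ID * cfg.NUM_GPUS + local_rank
    torch.cuda.set_device(local_rank)
    dist.init_process_group(
        backend="nccl",
        init_method=cfg.INIT_METHOD,  # assumed field, e.g. "tcp://host:port"
        world_size=world_size,
        rank=global_rank,
    )
    # Create per-machine process groups and record the one for this shard.
    init_distributed_training(cfg)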