import os
import socket

import torch

try:
    import horovod.torch as hvd
except ImportError:
    hvd = None

# Variables exported by Open MPI launchers for each spawned process.
_OMPI_WORLD_VARS = (
    'OMPI_COMM_WORLD_SIZE',
    'OMPI_COMM_WORLD_RANK',
    'OMPI_COMM_WORLD_LOCAL_RANK',
)


def is_global_master(args):
    """Return True when ``args.rank`` is 0."""
    return args.rank == 0


def is_local_master(args):
    """Return True when ``args.local_rank`` is 0."""
    return args.local_rank == 0


def is_master(args, local=False):
    """Return whether this process is the master.

    Args:
        args: Object exposing ``rank`` and ``local_rank`` attributes.
        local: When True, test ``local_rank`` instead of the global ``rank``.
    """
    return is_local_master(args) if local else is_global_master(args)


def is_using_horovod():
    """Return True when either the OMPI or the PMI rank/size variables are all set."""
    # NOTE w/ horovod run, OMPI vars should be set, but w/ SLURM PMI vars will be set
    # Differentiating between horovod and DDP use via SLURM may not be possible,
    # so horovod arg still required...
    ompi_vars = ("OMPI_COMM_WORLD_RANK", "OMPI_COMM_WORLD_SIZE")
    pmi_vars = ("PMI_RANK", "PMI_SIZE")
    return (
        all(var in os.environ for var in ompi_vars)
        or all(var in os.environ for var in pmi_vars)
    )


def is_using_distributed():
    """Return True when the environment announces more than one process.

    ``WORLD_SIZE`` takes precedence over ``SLURM_NTASKS``; when neither is
    set the run is treated as non-distributed.
    """
    if 'WORLD_SIZE' in os.environ:
        return int(os.environ['WORLD_SIZE']) > 1
    if 'SLURM_NTASKS' in os.environ:
        return int(os.environ['SLURM_NTASKS']) > 1
    return False


def _first_env_int(names, default):
    """Return the int value of the first variable in *names* that is set, else *default*."""
    for name in names:
        if name in os.environ:
            return int(os.environ[name])
    return default


def world_info_from_env():
    """Read ``(local_rank, global_rank, world_size)`` from the environment.

    For each value the candidate variables are tried in a fixed priority
    order (SLURM, then MPI/PMI, then Open MPI, then the torch.distributed
    names). Missing values default to ``0``, ``0`` and ``1`` respectively.
    """
    local_rank = _first_env_int(
        ('SLURM_LOCALID', 'MPI_LOCALRANKID', 'OMPI_COMM_WORLD_LOCAL_RANK', 'LOCAL_RANK'), 0)
    global_rank = _first_env_int(
        ('SLURM_PROCID', 'PMI_RANK', 'OMPI_COMM_WORLD_RANK', 'RANK'), 0)
    world_size = _first_env_int(
        ('SLURM_NTASKS', 'PMI_SIZE', 'OMPI_COMM_WORLD_SIZE', 'WORLD_SIZE'), 1)
    return local_rank, global_rank, world_size


def _ompi_world_info():
    """Return ``(local_rank, global_rank, world_size)`` from the ``OMPI_COMM_WORLD_*`` variables.

    Raises:
        KeyError: If one of the three variables is not set.
    """
    world_size = int(os.environ['OMPI_COMM_WORLD_SIZE'])
    world_rank = int(os.environ['OMPI_COMM_WORLD_RANK'])
    local_rank = int(os.environ['OMPI_COMM_WORLD_LOCAL_RANK'])
    return local_rank, world_rank, world_size


def _export_world_info(args):
    """Mirror the rank/size stored on *args* into the torch.distributed env variables."""
    os.environ['LOCAL_RANK'] = str(args.local_rank)
    os.environ['RANK'] = str(args.rank)
    os.environ['WORLD_SIZE'] = str(args.world_size)


def _init_process_group_with_rank(args):
    """Initialise the default process group, passing rank and world size explicitly."""
    torch.distributed.init_process_group(
        backend=args.dist_backend,
        init_method=args.dist_url,
        world_size=args.world_size,
        rank=args.rank,
    )


def _print_distributed_info(args):
    """Print a one-line summary of this process's place in the job."""
    print(f"Distributed training: local_rank={args.local_rank}, "
          f"rank={args.rank}, world_size={args.world_size}, "
          f"hostname={socket.gethostname()}, pid={os.getpid()}")


def init_distributed_device(args):
    """Set up distributed state on *args* and return the device for this process.

    Distributed training = training on more than one GPU.
    Works in both single and multi-node scenarios.

    Reads ``args.horovod`` and, on the torch.distributed paths,
    ``args.dist_backend`` / ``args.dist_url``; ``args.no_set_device_rank`` is
    consulted only when CUDA is available and the run is distributed.
    Writes ``args.distributed``, ``args.world_size``, ``args.rank``,
    ``args.local_rank`` and ``args.device`` (the device string).

    Returns:
        torch.device: ``cuda:<local_rank>`` for distributed CUDA runs (unless
        ``args.no_set_device_rank``), ``cuda:0`` for other CUDA runs, else ``cpu``.

    Raises:
        AssertionError: If ``args.horovod`` is set but Horovod is not importable.
    """
    args.distributed = False
    args.world_size = 1
    args.rank = 0  # global rank
    args.local_rank = 0

    if args.horovod:
        assert hvd is not None, "Horovod is not installed"
        hvd.init()
        if all(var in os.environ for var in _OMPI_WORLD_VARS):
            args.local_rank, args.rank, args.world_size = _ompi_world_info()
        else:
            # Not launched through Open MPI (e.g. PMI under SLURM): the OMPI
            # variables are absent, so ask Horovod instead of raising KeyError.
            args.local_rank = int(hvd.local_rank())
            args.rank = int(hvd.rank())
            args.world_size = int(hvd.size())
        args.distributed = True
        _export_world_info(args)
        _print_distributed_info(args)
    elif is_using_distributed():
        if 'SLURM_PROCID' in os.environ:
            # DDP via SLURM
            args.local_rank, args.rank, args.world_size = world_info_from_env()
            # SLURM var -> torch.distributed vars in case needed
            _export_world_info(args)
            _init_process_group_with_rank(args)
        elif 'OMPI_COMM_WORLD_SIZE' in os.environ:
            # using Summit cluster
            args.local_rank, args.rank, args.world_size = _ompi_world_info()
            _init_process_group_with_rank(args)
        else:
            # DDP via torchrun, torch.distributed.launch
            args.local_rank, _, _ = world_info_from_env()
            torch.distributed.init_process_group(
                backend=args.dist_backend,
                init_method=args.dist_url)
            args.world_size = torch.distributed.get_world_size()
            args.rank = torch.distributed.get_rank()
        args.distributed = True
        _print_distributed_info(args)

    if torch.cuda.is_available():
        if args.distributed and not args.no_set_device_rank:
            device = f'cuda:{args.local_rank}'
        else:
            device = 'cuda:0'
        torch.cuda.set_device(device)
    else:
        device = 'cpu'
    args.device = device
    device = torch.device(device)
    return device