import os
import socket

import torch
try:
import horovod.torch as hvd
except ImportError:
    hvd = None


def is_global_master(args):
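    """True if this process is the global master (global rank 0)."""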
    return args.rank == 0


def is_local_master(args):
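    """True if this process is the master on its node (local rank 0)."""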
    return args.local_rank == 0


def is_master(args, local=False):
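    """Dispatch to the local or global master check."""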
    return is_local_master(args) if local else is_global_master(args)


def is_using_horovod():
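    """Heuristically detect a horovod/MPI launch from the environment."""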
    # NOTE: with horovodrun the OMPI_* vars should be set, but under SLURM the
    # PMI_* vars will be. Differentiating between horovod and DDP use via SLURM
    # may not be possible, so an explicit horovod arg is still required.
ompi_vars = ["OMPI_COMM_WORLD_RANK", "OMPI_COMM_WORLD_SIZE"]
pmi_vars = ["PMI_RANK", "PMI_SIZE"]
    return all(var in os.environ for var in ompi_vars) or all(
        var in os.environ for var in pmi_vars
    )


def is_using_distributed():
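    """True if the environment indicates a multi-process (distributed) run."""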
if "WORLD_SIZE" in os.environ:
return int(os.environ["WORLD_SIZE"]) > 1
if "SLURM_NTASKS" in os.environ:
return int(os.environ["SLURM_NTASKS"]) > 1
    return False


def world_info_from_env():
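    """Read (local_rank, global_rank, world_size) from launcher env vars,
    checking SLURM, MPI/OpenMPI, and torch.distributed variables in turn."""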
local_rank = 0
for v in (
"SLURM_LOCALID",
"MPI_LOCALRANKID",
"OMPI_COMM_WORLD_LOCAL_RANK",
"LOCAL_RANK",
):
if v in os.environ:
local_rank = int(os.environ[v])
break
global_rank = 0
for v in ("SLURM_PROCID", "PMI_RANK", "OMPI_COMM_WORLD_RANK", "RANK"):
if v in os.environ:
global_rank = int(os.environ[v])
break
world_size = 1
for v in ("SLURM_NTASKS", "PMI_SIZE", "OMPI_COMM_WORLD_SIZE", "WORLD_SIZE"):
if v in os.environ:
world_size = int(os.environ[v])
break
    return local_rank, global_rank, world_size


def init_distributed_device(args):
    """Initialize horovod / torch.distributed state on `args`.

    Distributed training = training on more than one GPU. Works in both
    single-node and multi-node scenarios, and returns the torch.device
    this process should use.
    """
args.distributed = False
args.world_size = 1
args.rank = 0 # global rank
args.local_rank = 0
if args.horovod:
assert hvd is not None, "Horovod is not installed"
hvd.init()
        # Read world info from the OpenMPI env vars set by the launcher
        # (the same values hvd.local_rank() / hvd.rank() / hvd.size()
        # would report under an OpenMPI run).
        args.local_rank = int(os.environ["OMPI_COMM_WORLD_LOCAL_RANK"])
        args.rank = int(os.environ["OMPI_COMM_WORLD_RANK"])
        args.world_size = int(os.environ["OMPI_COMM_WORLD_SIZE"])
args.distributed = True
os.environ["LOCAL_RANK"] = str(args.local_rank)
os.environ["RANK"] = str(args.rank)
os.environ["WORLD_SIZE"] = str(args.world_size)
print(
f"Distributed training: local_rank={args.local_rank}, "
f"rank={args.rank}, world_size={args.world_size}, "
f"hostname={socket.gethostname()}, pid={os.getpid()}"
)
elif is_using_distributed():
if "SLURM_PROCID" in os.environ:
# DDP via SLURM
args.local_rank, args.rank, args.world_size = world_info_from_env()
# SLURM var -> torch.distributed vars in case needed
os.environ["LOCAL_RANK"] = str(args.local_rank)
os.environ["RANK"] = str(args.rank)
os.environ["WORLD_SIZE"] = str(args.world_size)
torch.distributed.init_process_group(
backend=args.dist_backend,
init_method=args.dist_url,
world_size=args.world_size,
rank=args.rank,
)
elif "OMPI_COMM_WORLD_SIZE" in os.environ: # using Summit cluster
world_size = int(os.environ["OMPI_COMM_WORLD_SIZE"])
world_rank = int(os.environ["OMPI_COMM_WORLD_RANK"])
local_rank = int(os.environ["OMPI_COMM_WORLD_LOCAL_RANK"])
args.local_rank = local_rank
args.rank = world_rank
args.world_size = world_size
torch.distributed.init_process_group(
backend=args.dist_backend,
init_method=args.dist_url,
world_size=args.world_size,
rank=args.rank,
)
else:
            # DDP via torchrun or torch.distributed.launch
args.local_rank, _, _ = world_info_from_env()
torch.distributed.init_process_group(
backend=args.dist_backend, init_method=args.dist_url
)
args.world_size = torch.distributed.get_world_size()
args.rank = torch.distributed.get_rank()
args.distributed = True
print(
f"Distributed training: local_rank={args.local_rank}, "
f"rank={args.rank}, world_size={args.world_size}, "
f"hostname={socket.gethostname()}, pid={os.getpid()}"
)
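
    # Pin each process to its local GPU unless the caller disabled per-rank
    # device selection via args.no_set_device_rank (e.g. when
    # CUDA_VISIBLE_DEVICES already exposes a single GPU per process).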
if torch.cuda.is_available():
if args.distributed and not args.no_set_device_rank:
device = "cuda:%d" % args.local_rank
else:
device = "cuda:0"
torch.cuda.set_device(device)
else:
device = "cpu"
args.device = device
device = torch.device(device)
return device
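

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative addition, not part of the original file).
# It assumes an argparse namespace carrying the attributes read above:
# horovod, dist_backend, dist_url, and no_set_device_rank. The flag names and
# defaults below are hypothetical; under torchrun, RANK/WORLD_SIZE/LOCAL_RANK
# are supplied by the launcher and "env://" picks them up.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--horovod", action="store_true")
    parser.add_argument("--dist-backend", default="nccl")
    parser.add_argument("--dist-url", default="env://")
    parser.add_argument("--no-set-device-rank", action="store_true")
    args = parser.parse_args()

    device = init_distributed_device(args)
    if is_master(args):
        print(f"world_size={args.world_size}, master device={device}")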