import dataclasses
import os
import socket
from typing import Optional

import torch
import torch.distributed


@dataclasses.dataclass
class DistributedOption:
    # Whether distributed training is enabled at all
    distributed: bool = False
    # Backend for torch.distributed, e.g. "nccl" or "gloo"
    dist_backend: str = "nccl"
    # If "env://", the master address/port, world size, and rank are taken from
    # MASTER_ADDR, MASTER_PORT, WORLD_SIZE, and RANK (or the --dist_* options)
    dist_init_method: str = "env://"
    dist_world_size: Optional[int] = None
    dist_rank: Optional[int] = None
    local_rank: Optional[int] = None
    ngpu: int = 0
    dist_master_addr: Optional[str] = None
    dist_master_port: Optional[int] = None
    dist_launcher: Optional[str] = None
    multiprocessing_distributed: bool = True

    def init_options(self):
        if self.distributed:
            if self.dist_init_method == "env://":
                if get_master_addr(self.dist_master_addr, self.dist_launcher) is None:
                    raise RuntimeError(
                        "--dist_master_addr or MASTER_ADDR must be set "
                        "if --dist_init_method == 'env://'"
                    )
                if get_master_port(self.dist_master_port) is None:
                    raise RuntimeError(
                        "--dist_master_port or MASTER_PORT must be set "
                        "if --dist_init_method == 'env://'"
                    )

            # Resolution priority for each value: the explicit --dist_* option,
            # then the launcher (Slurm/MPI) environment, then the generic
            # torch.distributed environment variables.
            self.dist_rank = get_rank(self.dist_rank, self.dist_launcher)
            self.dist_world_size = get_world_size(
                self.dist_world_size, self.dist_launcher
            )
            self.local_rank = get_local_rank(self.local_rank, self.dist_launcher)

            if self.local_rank is not None:
                if self.ngpu > 1:
                    raise RuntimeError(
                        f"Assuming one GPU per process in this case: ngpu={self.ngpu}"
                    )
                if "CUDA_VISIBLE_DEVICES" in os.environ:
                    cvd = os.environ["CUDA_VISIBLE_DEVICES"]
                    if self.local_rank >= len(cvd.split(",")):
                        raise RuntimeError(
                            f"LOCAL_RANK={self.local_rank} is bigger "
                            f"than the number of visible devices: {cvd}"
                        )

            if (
                self.dist_rank is not None
                and self.dist_world_size is not None
                and self.dist_rank >= self.dist_world_size
            ):
                raise RuntimeError(
                    f"RANK >= WORLD_SIZE: {self.dist_rank} >= {self.dist_world_size}"
                )

            if self.dist_init_method == "env://":
                self.dist_master_addr = get_master_addr(
                    self.dist_master_addr, self.dist_launcher
                )
                self.dist_master_port = get_master_port(self.dist_master_port)
                if (
                    self.dist_master_addr is not None
                    and self.dist_master_port is not None
                ):
                    # Pass the resolved address and port explicitly to
                    # init_process_group() instead of relying on "env://".
                    self.dist_init_method = (
                        f"tcp://{self.dist_master_addr}:{self.dist_master_port}"
                    )

    def init_torch_distributed(self):
        if self.distributed:
            # Emit NCCL debugging information unless the user has overridden it
            os.environ.setdefault("NCCL_DEBUG", "INFO")
            # Use blocking waits so that NCCL failures surface as errors
            # instead of silent hangs
            os.environ.setdefault("NCCL_BLOCKING_WAIT", "1")

            torch.distributed.init_process_group(
                backend=self.dist_backend,
                init_method=self.dist_init_method,
                world_size=self.dist_world_size,
                rank=self.dist_rank,
            )

            # With one GPU per process, bind this process to its own device
            if self.local_rank is not None and self.ngpu > 0:
                torch.cuda.set_device(self.local_rank)
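

# A sketch of how these pieces are typically wired together (illustrative only,
# not part of the original module; the exact argument handling lives in the
# surrounding training script):
#
#     resolve_distributed_mode(args)              # decide distributed / spawn mode
#     option = DistributedOption(
#         distributed=args.distributed,
#         ngpu=args.ngpu,
#         dist_launcher=args.dist_launcher,
#     )
#     option.init_options()                       # resolve rank / world size / ports
#     option.init_torch_distributed()             # init_process_group + set_device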


def resolve_distributed_mode(args):
    # Decide whether distributed training is used and validate the
    # rank/launcher-related settings before any process is spawned.
    if args.multiprocessing_distributed:
        num_nodes = get_num_nodes(args.dist_world_size, args.dist_launcher)

        if num_nodes > 1:
            # Multi-node mode (with or without GPUs)
            args.distributed = True
        elif args.ngpu > 1:
            # Single-node, multi-GPU mode with one spawned process per GPU
            args.distributed = True
        else:
            # Single-node, single-process mode
            args.distributed = False

        if args.ngpu <= 1:
            # With at most one GPU per node there is nothing to spawn locally
            args.multiprocessing_distributed = False
            if args.ngpu == 1:
                # The single visible GPU corresponds to local rank 0
                args.local_rank = 0

        if num_nodes > 1 and get_node_rank(args.dist_rank, args.dist_launcher) is None:
            raise RuntimeError(
                "--dist_rank or RANK must be set "
                "if --multiprocessing_distributed == true"
            )

    else:
        # Each process is launched externally (e.g. by torch.distributed.launch
        # or a cluster launcher), so no local spawning is done here.
        if get_world_size(args.dist_world_size, args.dist_launcher) > 1:
            args.distributed = True
        else:
            args.distributed = False

        if args.distributed and args.ngpu > 0:
            if get_local_rank(args.local_rank, args.dist_launcher) is None:
                raise RuntimeError(
                    "--local_rank or LOCAL_RANK must be set "
                    "if --multiprocessing_distributed == false"
                )

        if args.distributed:
            if get_node_rank(args.dist_rank, args.dist_launcher) is None:
                raise RuntimeError(
                    "--dist_rank or RANK must be set "
                    "if --multiprocessing_distributed == false"
                )

    if args.distributed and args.dist_launcher == "slurm" and not is_in_slurm_step():
        raise RuntimeError("Launch by 'srun' command if --dist_launcher='slurm'")


def is_in_slurm_job() -> bool:
    return "SLURM_PROCID" in os.environ and "SLURM_NTASKS" in os.environ


def is_in_slurm_step() -> bool:
    return (
        is_in_slurm_job()
        and "SLURM_STEP_NUM_NODES" in os.environ
        and "SLURM_STEP_NODELIST" in os.environ
    )


def _int_or_none(x: Optional[str]) -> Optional[int]:
    if x is None:
        return x
    return int(x)


def free_port():
    """Find a free port using bind().

    There is some interval between finding this port and actually using it,
    and another process might take the port in the meantime, so it is not
    guaranteed that the port is still free when it is used.
    """
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
        sock.bind(("", 0))
        return sock.getsockname()[1]
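

# Illustrative sketch (not part of the original module): free_port() is typically
# used to pick a default MASTER_PORT before worker processes are spawned.  The
# helper below is hypothetical and only shows the intended use; as the docstring
# above notes, another process may still grab the port before it is bound.
def _example_default_master_port() -> int:
    """Return MASTER_PORT from the environment, or fall back to a free port."""
    port = os.environ.get("MASTER_PORT")
    if port is None:
        return free_port()
    return int(port)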


def get_rank(prior=None, launcher: Optional[str] = None) -> Optional[int]:
    if prior is None:
        if launcher == "slurm":
            if not is_in_slurm_step():
                raise RuntimeError("This process seems not to be launched by 'srun'")
            prior = os.environ["SLURM_PROCID"]
        elif launcher == "mpi":
            raise RuntimeError(
                "launcher=mpi is used only for 'multiprocessing-distributed' mode"
            )
        elif launcher is not None:
            raise RuntimeError(f"launcher='{launcher}' is not supported")

    if prior is not None:
        return int(prior)
    else:
        # Fall back to the generic torch.distributed environment variable
        return _int_or_none(os.environ.get("RANK"))


def get_world_size(prior=None, launcher: Optional[str] = None) -> int:
    if prior is None:
        if launcher == "slurm":
            if not is_in_slurm_step():
                raise RuntimeError("This process seems not to be launched by 'srun'")
            prior = int(os.environ["SLURM_NTASKS"])
        elif launcher == "mpi":
            raise RuntimeError(
                "launcher=mpi is used only for 'multiprocessing-distributed' mode"
            )
        elif launcher is not None:
            raise RuntimeError(f"launcher='{launcher}' is not supported")

    if prior is not None:
        return int(prior)
    else:
        # Fall back to the generic torch.distributed environment variable
        return int(os.environ.get("WORLD_SIZE", "1"))


def get_local_rank(prior=None, launcher: Optional[str] = None) -> Optional[int]:
    # The local rank selects which GPU this process uses on its node.
    if prior is None:
        if launcher == "slurm":
            if not is_in_slurm_step():
                raise RuntimeError("This process seems not to be launched by 'srun'")
            prior = int(os.environ["SLURM_LOCALID"])
        elif launcher == "mpi":
            raise RuntimeError(
                "launcher=mpi is used only for 'multiprocessing-distributed' mode"
            )
        elif launcher is not None:
            raise RuntimeError(f"launcher='{launcher}' is not supported")

    if prior is not None:
        return int(prior)
    elif "LOCAL_RANK" in os.environ:
        # Set by launchers such as torchrun / torch.distributed.launch
        return int(os.environ["LOCAL_RANK"])
    elif "CUDA_VISIBLE_DEVICES" in os.environ:
        # No explicit local rank is available.  If exactly one device is visible,
        # interpret its index as the local rank and unset CUDA_VISIBLE_DEVICES so
        # that the index refers to the physical device, e.g.
        #     CUDA_VISIBLE_DEVICES=2 python train.py  ->  local_rank = 2
        cvd = os.environ["CUDA_VISIBLE_DEVICES"].split(",")
        if len(cvd) == 1 and "LOCAL_RANK" not in os.environ:
            return int(os.environ.pop("CUDA_VISIBLE_DEVICES"))
        else:
            return None
    else:
        return None


def get_master_addr(prior=None, launcher: Optional[str] = None) -> Optional[str]:
    if prior is None:
        if launcher == "slurm":
            if not is_in_slurm_step():
                raise RuntimeError("This process seems not to be launched by 'srun'")
            # Use the first host of the step's nodelist as the master address,
            # e.g. "host[1-5],host8" -> "host1"
            nodelist = os.environ["SLURM_STEP_NODELIST"]
            prior = nodelist.split(",")[0].split("-")[0].replace("[", "")

    if prior is not None:
        return str(prior)
    else:
        return os.environ.get("MASTER_ADDR")


def get_master_port(prior=None) -> Optional[int]:
    if prior is not None:
        return prior
    else:
        return _int_or_none(os.environ.get("MASTER_PORT"))


def get_node_rank(prior=None, launcher: Optional[str] = None) -> Optional[int]:
    """Get the node rank.

    Used for the "multiprocessing distributed" mode, where the initial RANK equals
    the node id, and the real rank is later derived as
    (ngpu * node_id) + local_rank inside torch.distributed.
    For example, with 4 GPUs per node, the process with LOCAL_RANK=2 on the node
    with id 1 ends up with global rank 4 * 1 + 2 = 6.
    """
    if prior is not None:
        return prior
    elif launcher == "slurm":
        if not is_in_slurm_step():
            raise RuntimeError("This process seems not to be launched by 'srun'")

        # In this mode there must be exactly one task per node,
        # so the number of tasks has to match the number of nodes.
        if os.environ["SLURM_STEP_NUM_NODES"] != os.environ["SLURM_NTASKS"]:
            raise RuntimeError(
                "Run with --ntasks-per-node=1 if multiprocessing_distributed=true"
            )
        return int(os.environ["SLURM_NODEID"])
    elif launcher == "mpi":
        # Each MPI process corresponds to one node in this mode
        from mpi4py import MPI

        comm = MPI.COMM_WORLD
        return comm.Get_rank()
    elif launcher is not None:
        raise RuntimeError(f"launcher='{launcher}' is not supported")
    else:
        return _int_or_none(os.environ.get("RANK"))


def get_num_nodes(prior=None, launcher: Optional[str] = None) -> Optional[int]:
    """Get the number of nodes.

    Used for the "multiprocessing distributed" mode, where RANK equals the node id
    and the real rank is later derived as (ngpu * node_id) + local_rank inside
    torch.distributed.
    """
    if prior is not None:
        return prior
    elif launcher == "slurm":
        if not is_in_slurm_step():
            raise RuntimeError("This process seems not to be launched by 'srun'")

        # In this mode there must be exactly one task per node,
        # so the number of tasks has to match the number of nodes.
        if os.environ["SLURM_STEP_NUM_NODES"] != os.environ["SLURM_NTASKS"]:
            raise RuntimeError(
                "Run with --ntasks-per-node=1 if multiprocessing_distributed=true"
            )
        return int(os.environ["SLURM_STEP_NUM_NODES"])
    elif launcher == "mpi":
        # Each MPI process corresponds to one node in this mode
        from mpi4py import MPI

        comm = MPI.COMM_WORLD
        return comm.Get_size()
    elif launcher is not None:
        raise RuntimeError(f"launcher='{launcher}' is not supported")
    else:
        # Fall back to the generic torch.distributed environment variable
        return int(os.environ.get("WORLD_SIZE", 1))
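

if __name__ == "__main__":
    # Minimal, illustrative smoke test (assumes a plain single-process run with no
    # launcher and no distributed environment variables set).  It only exercises
    # the environment fallbacks and does not call init_process_group().
    option = DistributedOption(distributed=False, ngpu=0)
    option.init_options()
    print("distributed :", option.distributed)
    print("world_size  :", get_world_size())
    print("rank        :", get_rank())
    print("master_addr :", get_master_addr())
    print("master_port :", get_master_port())
    print("free_port   :", free_port())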