# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import os
import random
import re
import socket
from typing import Dict, List

import torch
import torch.distributed as dist

_LOCAL_RANK = -1
_LOCAL_WORLD_SIZE = -1


def is_enabled() -> bool:
    """
    Returns:
        True if distributed training is enabled
    """
    return dist.is_available() and dist.is_initialized()


def get_global_size() -> int:
    """
    Returns:
        The number of processes in the process group
    """
    return dist.get_world_size() if is_enabled() else 1


def get_global_rank() -> int:
    """
    Returns:
        The rank of the current process within the global process group.
    """
    return dist.get_rank() if is_enabled() else 0


def get_local_rank() -> int:
    """
    Returns:
        The rank of the current process within the local (per-machine) process group.
    """
    if not is_enabled():
        return 0
    assert 0 <= _LOCAL_RANK < _LOCAL_WORLD_SIZE
    return _LOCAL_RANK


def get_local_size() -> int:
    """
    Returns:
        The size of the per-machine process group, i.e. the number of processes per machine.
    """
    if not is_enabled():
        return 1
    assert 0 <= _LOCAL_RANK < _LOCAL_WORLD_SIZE
    return _LOCAL_WORLD_SIZE


def is_main_process() -> bool:
    """
    Returns:
        True if the current process is the main one.
    """
    return get_global_rank() == 0


def _restrict_print_to_main_process() -> None:
    """
    This function disables printing when not in the main process
    """
    import builtins as __builtin__

    builtin_print = __builtin__.print

    def print(*args, **kwargs):
        force = kwargs.pop("force", False)
        if is_main_process() or force:
            builtin_print(*args, **kwargs)

    __builtin__.print = print


def _get_master_port(seed: int = 0) -> int:
    MIN_MASTER_PORT, MAX_MASTER_PORT = (20_000, 60_000)

    master_port_str = os.environ.get("MASTER_PORT")
    if master_port_str is None:
        rng = random.Random(seed)
        return rng.randint(MIN_MASTER_PORT, MAX_MASTER_PORT)

    return int(master_port_str)


def _get_available_port() -> int:
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        # A "" host address means INADDR_ANY i.e. binding to all interfaces.
        # Note this is not compatible with IPv6.
        s.bind(("", 0))
        port = s.getsockname()[1]
        return port


_TORCH_DISTRIBUTED_ENV_VARS = (
    "MASTER_ADDR",
    "MASTER_PORT",
    "RANK",
    "WORLD_SIZE",
    "LOCAL_RANK",
    "LOCAL_WORLD_SIZE",
)


def _collect_env_vars() -> Dict[str, str]:
    return {env_var: os.environ[env_var] for env_var in _TORCH_DISTRIBUTED_ENV_VARS if env_var in os.environ}


def _is_slurm_job_process() -> bool:
    return "SLURM_JOB_ID" in os.environ


def _parse_slurm_node_list(s: str) -> List[str]:
    nodes = []
    # Extract "hostname", "hostname[1-2,3,4-5]," substrings
    p = re.compile(r"(([^\[]+)(?:\[([^\]]+)\])?),?")
    for m in p.finditer(s):
        prefix, suffixes = s[m.start(2) : m.end(2)], s[m.start(3) : m.end(3)]
        for suffix in suffixes.split(","):
            span = suffix.split("-")
            if len(span) == 1:
                nodes.append(prefix + suffix)
            else:
                width = len(span[0])
                start, end = int(span[0]), int(span[1]) + 1
                nodes.extend([prefix + f"{i:0{width}}" for i in range(start, end)])
    return nodes


def _check_env_variable(key: str, new_value: str):
    # Only check for difference with preset environment variables
    if key in os.environ and os.environ[key] != new_value:
        raise RuntimeError(f"Cannot export environment variables as {key} is already set")


class _TorchDistributedEnvironment:
    def __init__(self):
        self.master_addr = "127.0.0.1"
        self.master_port = 0
        self.rank = -1
        self.world_size = -1
        self.local_rank = -1
        self.local_world_size = -1

        if _is_slurm_job_process():
            return self._set_from_slurm_env()

        env_vars = _collect_env_vars()
        if not env_vars:
            # Environment is not set
            pass
        elif len(env_vars) == len(_TORCH_DISTRIBUTED_ENV_VARS):
            # Environment is fully set
            return self._set_from_preset_env()
        else:
            # Environment is partially set
            collected_env_vars = ", ".join(env_vars.keys())
            raise RuntimeError(f"Partially set environment: {collected_env_vars}")

        if torch.cuda.device_count() > 0:
            return self._set_from_local()

        raise RuntimeError("Can't initialize PyTorch distributed environment")

    # Slurm job created with sbatch, submitit, etc...
    def _set_from_slurm_env(self):
        # logger.info("Initialization from Slurm environment")
        job_id = int(os.environ["SLURM_JOB_ID"])
        node_count = int(os.environ["SLURM_JOB_NUM_NODES"])
        nodes = _parse_slurm_node_list(os.environ["SLURM_JOB_NODELIST"])
        assert len(nodes) == node_count

        self.master_addr = nodes[0]
        self.master_port = _get_master_port(seed=job_id)
        self.rank = int(os.environ["SLURM_PROCID"])
        self.world_size = int(os.environ["SLURM_NTASKS"])
        assert self.rank < self.world_size
        self.local_rank = int(os.environ["SLURM_LOCALID"])
        self.local_world_size = self.world_size // node_count
        assert self.local_rank < self.local_world_size

    # Single node job with preset environment (i.e. torchrun)
    def _set_from_preset_env(self):
        # logger.info("Initialization from preset environment")
        self.master_addr = os.environ["MASTER_ADDR"]
        self.master_port = os.environ["MASTER_PORT"]
        self.rank = int(os.environ["RANK"])
        self.world_size = int(os.environ["WORLD_SIZE"])
        assert self.rank < self.world_size
        self.local_rank = int(os.environ["LOCAL_RANK"])
        self.local_world_size = int(os.environ["LOCAL_WORLD_SIZE"])
        assert self.local_rank < self.local_world_size
    # Single node and GPU job (i.e. local script run)
    def _set_from_local(self):
        # logger.info("Initialization from local")
        self.master_addr = "127.0.0.1"
        self.master_port = _get_available_port()
        self.rank = 0
        self.world_size = 1
        self.local_rank = 0
        self.local_world_size = 1

    def export(self, *, overwrite: bool) -> "_TorchDistributedEnvironment":
        # See the "Environment variable initialization" section from
        # https://pytorch.org/docs/stable/distributed.html for the complete list of
        # environment variables required for the env:// initialization method.
        env_vars = {
            "MASTER_ADDR": self.master_addr,
            "MASTER_PORT": str(self.master_port),
            "RANK": str(self.rank),
            "WORLD_SIZE": str(self.world_size),
            "LOCAL_RANK": str(self.local_rank),
            "LOCAL_WORLD_SIZE": str(self.local_world_size),
        }
        if not overwrite:
            for k, v in env_vars.items():
                _check_env_variable(k, v)

        os.environ.update(env_vars)
        return self


def enable(*, set_cuda_current_device: bool = True, overwrite: bool = False, allow_nccl_timeout: bool = False):
    """Enable distributed mode

    Args:
        set_cuda_current_device: If True, call torch.cuda.set_device() to set the
            current PyTorch CUDA device to the one matching the local rank.
        overwrite: If True, overwrites already set variables. Else fails.
        allow_nccl_timeout: If True, export NCCL_ASYNC_ERROR_HANDLING=1 so that the
            torch distributed timeout takes effect with the NCCL backend.
    """
    global _LOCAL_RANK, _LOCAL_WORLD_SIZE
    if _LOCAL_RANK >= 0 or _LOCAL_WORLD_SIZE >= 0:
        raise RuntimeError("Distributed mode has already been enabled")
    torch_env = _TorchDistributedEnvironment()
    torch_env.export(overwrite=overwrite)

    if set_cuda_current_device:
        torch.cuda.set_device(torch_env.local_rank)

    if allow_nccl_timeout:
        # Allow the torch distributed timeout to take effect with the NCCL backend
        key, value = "NCCL_ASYNC_ERROR_HANDLING", "1"
        if not overwrite:
            _check_env_variable(key, value)
        os.environ[key] = value

    dist.init_process_group(backend="nccl")
    dist.barrier()

    # Finalize setup
    _LOCAL_RANK = torch_env.local_rank
    _LOCAL_WORLD_SIZE = torch_env.local_world_size
    _restrict_print_to_main_process()
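
# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative addition, not part of the module's API).
# It assumes the process is launched with torchrun, inside a Slurm task, or on
# a machine with at least one CUDA device, since enable() sets the current CUDA
# device and initializes the NCCL process group.
if __name__ == "__main__":
    # Set up the distributed environment exactly once per process.
    enable(set_cuda_current_device=True, overwrite=False, allow_nccl_timeout=False)

    # After enable(), the rank/size helpers reflect the initialized process group,
    # and print() is silenced on non-main ranks unless force=True is passed.
    print(
        f"global rank {get_global_rank()}/{get_global_size()}, "
        f"local rank {get_local_rank()}/{get_local_size()}",
        force=True,
    )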