import logging
from mpi4py import MPI
import os
import re
import subprocess
import torch

logger = logging.getLogger(__name__)

class MPIAdapter:
    """
    MPIAdapter automatically detects and analyzes the training environment for distributed training
    and offers methods to set up distributed training jobs.
    For example, it determines whether training happens on AML, Philly, or locally.
    It also determines variables such as the world size and the rank of each GPU.
    """
    def __init__(self, port='55551', set_env_vars=True):
        local_address = '127.0.0.1'
        default_torch_distributed_port = port  # chosen arbitrarily

        if 'OMPI_COMM_WORLD_SIZE' not in os.environ:
            # application was started without MPI
            # default to single node with single process
            self.env_info = 'no MPI'
            self.world_size = 1
            self.local_size = 1
            self.rank = 0
            self.local_rank = 0
            self.master_address = local_address
            self.master_port = default_torch_distributed_port
        else:
            # application was started with MPI
            # get MPI parameters
            self.world_size = int(os.environ['OMPI_COMM_WORLD_SIZE'])
            self.local_size = int(os.environ['OMPI_COMM_WORLD_LOCAL_SIZE'])
            self.rank = int(os.environ['OMPI_COMM_WORLD_RANK'])
            self.local_rank = int(os.environ['OMPI_COMM_WORLD_LOCAL_RANK'])

            if 'PHILLY_CONTAINER_IP' in os.environ:
                # application is running on Philly
                # read environment variables on master node and broadcast via MPI
                self.env_info = 'philly'
                if self.rank == 0:
                    self.master_address = os.environ['PHILLY_CONTAINER_IP']
                    self.master_port = os.environ['PHILLY_CONTAINER_PORT_RANGE_START']
                else:
                    self.master_address = None
                    self.master_port = None
                self.master_address = MPI.COMM_WORLD.bcast(self.master_address, root=0)
                self.master_port = MPI.COMM_WORLD.bcast(self.master_port, root=0)
            elif "AMLK8S_NUM_WORKER" in os.environ or "AZ_CMK8S_JOB_WORK_DIR" in os.environ:
                # application is running on AMLK8S (ITP)
                # read master address from a specific file
                self.env_info = 'AMLK8S (ITP)'
                # from: https://k8s-wiki.azureml.com/faq.html
                regexp = r"[\s\S]*export[\s]*DLTS_SD_worker0_IP=([0-9.]+)[\s|s]*"
                with open("/dlts-runtime/env/init.env", 'r') as f:
                    line = f.read()
                match = re.match(regexp, line)
                if match:
                    self.master_address = str(match.group(1))
                else:
                    # Did not find master node ip in file. It must be a single-node
                    # debugging job with custom "mpirun" command
                    assert self.world_size == self.local_size, \
                        "It's not a single-node debugging job on AMLK8S (ITP), but no master ip is found in file."
                    self.env_info = 'single-node AMLK8S (ITP) debugging job'
                    self.master_address = local_address
                self.master_port = default_torch_distributed_port
            elif 'AZ_BATCH_MASTER_NODE' in os.environ:
                # application is running on multiple nodes on AML
                self.env_info = 'multi-node AML'
                master_node_params = os.environ['AZ_BATCH_MASTER_NODE'].split(':')
                self.master_address = master_node_params[0]
                self.master_port = default_torch_distributed_port
            elif self.world_size == self.local_size:
                # application is running with MPI on a single node
                self.env_info = 'single-node AML or other MPI environment'
                self.master_address = local_address
                self.master_port = default_torch_distributed_port
            else:
                # multi-node MPI environment, but not Philly or AML
                # we use the "hostname -I" command on rank 0 to get the master address
                self.env_info = 'multi-node other MPI environment'
                if self.rank == 0:
                    hostname_cmd = ["hostname -I"]
                    result = subprocess.check_output(hostname_cmd, shell=True)
                    self.master_address = result.decode('utf-8').split()[0]
                    self.master_port = default_torch_distributed_port
                else:
                    self.master_address = None
                    self.master_port = None
                self.master_address = MPI.COMM_WORLD.bcast(self.master_address, root=0)
                self.master_port = MPI.COMM_WORLD.bcast(self.master_port, root=0)

        self.init_method_url = f'tcp://{self.master_address}:{self.master_port}'
        if set_env_vars:
            self._set_env_vars()
    def log_info(self):
        """
        Logs information about the distributed training environment.
        """
        # use warning level because logger.info messages are typically not printed on processes with rank > 0
        logger.warning('----------------')
        logger.warning('MPI Adapter data')
        logger.warning('----------------')
        logger.warning(f'environment info: {self.env_info}')
        logger.warning(f'init method url: {self.init_method_url}')
        logger.warning(f'world size: {self.world_size}')
        logger.warning(f'local size: {self.local_size}')
        logger.warning(f'rank: {self.rank}')
        logger.warning(f'local rank: {self.local_rank}')
        logger.warning(f'master address: {self.master_address}')
        logger.warning(f'master port: {self.master_port}')
        logger.warning('----------------')
    def init_process_group(self, backend):
        """
        Initializes the default PyTorch distributed process group.
        """
        # use warning level because logger.info messages are typically not printed on processes with rank > 0
        logger.warning('trying to initialize process group ...')
        torch.distributed.init_process_group(backend=backend,
                                             init_method=self.init_method_url,
                                             world_size=self.world_size,
                                             rank=self.rank)
        logger.warning('process group initialized')
    def _set_env_vars(self):
        """
        Sets environment variables for world size, rank, local rank, master addr, and master port.
        """
        os.environ['WORLD_SIZE'] = str(self.world_size)
        os.environ['RANK'] = str(self.rank)
        os.environ['LOCAL_RANK'] = str(self.local_rank)
        os.environ['MASTER_ADDR'] = self.master_address
        os.environ['MASTER_PORT'] = self.master_port
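
A minimal usage sketch (not part of the original file): the module name mpi_adapter and the main() entry point below are assumptions, but a typical driver constructs the adapter, logs the detected environment, and initializes the PyTorch process group before building the model.

import torch

from mpi_adapter import MPIAdapter  # assumed module name for the file above


def main():
    adapter = MPIAdapter()                       # detect MPI / Philly / AML / local environment
    adapter.log_info()                           # log world size, rank, master address, ...
    backend = 'nccl' if torch.cuda.is_available() else 'gloo'
    adapter.init_process_group(backend=backend)  # wraps torch.distributed.init_process_group

    # pin each process to its local GPU before creating the model
    if torch.cuda.is_available():
        torch.cuda.set_device(adapter.local_rank)


if __name__ == '__main__':
    main()

Launched with mpirun (or an AML/Philly MPI job), the adapter reads the OMPI_COMM_WORLD_* variables set by the launcher; started as a plain python process, it falls back to a single-process setup on 127.0.0.1.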