```python
import json
import logging
import os
import subprocess
from argparse import ArgumentParser

logger = logging.getLogger(__name__)


def parse_args():
    """Parse arbitrary --key value / --key=value flags so they can be forwarded to run_glue.py."""
    parser = ArgumentParser()
    # First pass: collect the flags we don't know in advance, register each of them,
    # then parse again so their values land on the namespace.
    _, unknown = parser.parse_known_args()
    for arg in unknown:
        if arg.startswith(("-", "--")):
            parser.add_argument(arg.split("=")[0])
    return parser.parse_args()


def main():
    args = parse_args()
    port = 8888
    # SageMaker exposes the cluster topology through environment variables.
    num_gpus = int(os.environ["SM_NUM_GPUS"])
    hosts = json.loads(os.environ["SM_HOSTS"])
    num_nodes = len(hosts)
    current_host = os.environ["SM_CURRENT_HOST"]
    rank = hosts.index(current_host)
    os.environ["NCCL_DEBUG"] = "INFO"

    # Re-serialize every parsed hyperparameter as a "--key value" flag for run_glue.py.
    hyperparameters = "".join(f" --{parameter} {value}" for parameter, value in vars(args).items())

    if num_nodes > 1:
        # Multi-node training: the first host acts as the rendezvous master.
        cmd = f"""python -m torch.distributed.launch \
            --nnodes={num_nodes} \
            --node_rank={rank} \
            --nproc_per_node={num_gpus} \
            --master_addr={hosts[0]} \
            --master_port={port} \
            ./run_glue.py \
            {hyperparameters}"""
    else:
        # Single node: only the per-node process count is needed.
        cmd = f"""python -m torch.distributed.launch \
            --nproc_per_node={num_gpus} \
            ./run_glue.py \
            {hyperparameters}"""

    try:
        # check=True surfaces a non-zero exit code as an exception instead of silently passing.
        subprocess.run(cmd, shell=True, check=True)
    except Exception as e:
        logger.error(e)
        raise


if __name__ == "__main__":
    main()
```
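
The launcher's `parse_args()` relies on `parse_known_args()` to accept hyperparameters whose names it does not know in advance. As a minimal standalone sketch (the flag names and values below are illustrative assumptions, not taken from the script), this is how unknown `--key value` pairs become namespace attributes and are then re-serialized into the flag string appended to the `run_glue.py` command:

```python
from argparse import ArgumentParser

# Hypothetical hyperparameters, e.g. as they might be passed to the entry point.
argv = ["--model_name_or_path", "bert-base-uncased", "--per_device_train_batch_size", "32"]

parser = ArgumentParser()
_, unknown = parser.parse_known_args(argv)      # first pass: discover the flag names
for arg in unknown:
    if arg.startswith("-"):
        parser.add_argument(arg.split("=")[0])  # register each discovered flag
args = parser.parse_args(argv)                  # second pass: parse their values

flags = "".join(f" --{k} {v}" for k, v in vars(args).items())
print(flags)
# ' --model_name_or_path bert-base-uncased --per_device_train_batch_size 32'
```

On a SageMaker training job, such flags typically originate from the hyperparameters passed to the estimator, which the training toolkit forwards to the entry point as command-line arguments.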