# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import os
import shutil

import torch
import torch.distributed as dist


def get_model(model):
    """Unwrap a model from (Distributed)DataParallel, if necessary."""
    if isinstance(model, torch.nn.DataParallel) \
            or isinstance(model, torch.nn.parallel.DistributedDataParallel):
        return model.module
    else:
        return model


def setup_for_distributed(is_master):
    """
    Disable printing when not in the master process, unless `force=True`
    is passed to print().
    """
    import builtins as __builtin__
    builtin_print = __builtin__.print

    def print(*args, **kwargs):
        force = kwargs.pop('force', False)
        if is_master or force:
            builtin_print(*args, **kwargs)

    __builtin__.print = print


def is_dist_avail_and_initialized():
    if not dist.is_available():
        return False
    if not dist.is_initialized():
        return False
    return True


def get_world_size():
    if not is_dist_avail_and_initialized():
        return 1
    return dist.get_world_size()


def get_rank():
    if not is_dist_avail_and_initialized():
        return 0
    return dist.get_rank()


def is_main_process():
    return get_rank() == 0


def save_on_master(state, is_best, output_dir, is_epoch=True):
    """Save checkpoints from the main process only.

    Writes `checkpoint.pt` plus a per-epoch copy `checkpoint_{epoch}.pt`
    when `is_epoch` is True, and `checkpoint_best.pt` when `is_best` is True.
    """
    if is_main_process():
        ckpt_path = f'{output_dir}/checkpoint.pt'
        best_path = f'{output_dir}/checkpoint_best.pt'
        if is_best:
            torch.save(state, best_path)
        if is_epoch:
            if isinstance(state['epoch'], int):
                ckpt2_path = '{}/checkpoint_{:04d}.pt'.format(output_dir, state['epoch'])
            else:
                ckpt2_path = '{}/checkpoint_{:.4f}.pt'.format(output_dir, state['epoch'])
            torch.save(state, ckpt_path)
            shutil.copy(ckpt_path, ckpt2_path)


def init_distributed_mode(args):
    """Initialize torch.distributed from torchrun or SLURM environment variables.

    Expects `args.dist_url` to be set by the caller; under SLURM, `args.world_size`
    is also expected to come from the caller. Falls back to single-process mode
    when no launcher environment is detected.
    """
    if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
        args.rank = int(os.environ['RANK'])
        args.world_size = int(os.environ['WORLD_SIZE'])
        args.gpu = int(os.environ['LOCAL_RANK'])
    elif 'SLURM_PROCID' in os.environ:
        args.rank = int(os.environ['SLURM_PROCID'])
        args.gpu = args.rank % torch.cuda.device_count()
    else:
        print('Not using distributed mode')
        args.distributed = False
        return

    args.distributed = True

    torch.cuda.set_device(args.gpu)
    args.dist_backend = 'nccl'
    print('| distributed init (rank {}): {}'.format(
        args.rank, args.dist_url), flush=True)
    torch.distributed.init_process_group(
        backend=args.dist_backend, init_method=args.dist_url,
        world_size=args.world_size, rank=args.rank
    )
    torch.distributed.barrier()
    setup_for_distributed(args.rank == 0)
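

if __name__ == '__main__':
    # Usage sketch (illustrative only, not part of the original module): it
    # shows how a training script might drive these helpers. The attribute
    # names on `args` (dist_url, world_size) mirror what the functions above
    # read, but the real scripts build `args` with argparse and more options.
    import argparse
    import tempfile

    args = argparse.Namespace(dist_url='env://', world_size=1)
    # Without RANK/WORLD_SIZE or SLURM variables in the environment this
    # falls back to single-process mode and sets args.distributed = False.
    init_distributed_mode(args)

    with tempfile.TemporaryDirectory() as output_dir:
        # Hypothetical checkpoint payload; real code would include model,
        # optimizer, and scheduler state dicts.
        state = {'epoch': 0, 'model': {}}
        # On the main process this writes checkpoint.pt, copies it to
        # checkpoint_0000.pt, and also writes checkpoint_best.pt.
        save_on_master(state, is_best=True, output_dir=output_dir)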