"""Launcher for (optionally distributed) inference.

Imports the module named by ``--module``, parses that module's own options
via its ``get_parser()``, seeds via ``seed_everything(args.seed)``, joins an
NCCL process group and finally calls
``run_inference(inference_args, gpu_num, rank)`` on that module.
"""
import argparse
import datetime
import importlib
import os

import torch
import torch.distributed as dist
from pytorch_lightning import seed_everything


def setup_dist(local_rank):
    """Bind this process to GPU ``local_rank`` and join the NCCL process group.

    Does nothing if a default process group is already initialised.
    Initialisation uses ``init_method='env://'``, so the launcher is expected
    to export the rendezvous variables (MASTER_ADDR, MASTER_PORT, RANK,
    WORLD_SIZE) in the environment.

    Args:
        local_rank: Index of the CUDA device this process should use.
    """
    if dist.is_initialized():
        return
    torch.cuda.set_device(local_rank)
    dist.init_process_group('nccl', init_method='env://')


def get_dist_info():
    """Return ``(rank, world_size)`` for the current process.

    Falls back to ``(0, 1)`` when torch.distributed is unavailable in this
    build or no process group has been initialised.
    """
    if dist.is_available() and dist.is_initialized():
        return dist.get_rank(), dist.get_world_size()
    return 0, 1


def _parse_launcher_args():
    """Parse launcher-level options, leaving everything else to the inference module."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--module", type=str, help="module name", default="inference")
    # torchrun (and torch.distributed.launch --use_env) export LOCAL_RANK rather
    # than passing --local_rank, and PyTorch >= 2.0's torch.distributed.launch
    # passes the hyphenated --local-rank. Without the env fallback and alias,
    # every process would silently bind to GPU 0. An explicit flag still wins.
    parser.add_argument(
        "--local_rank",
        "--local-rank",
        type=int,
        nargs="?",
        help="for ddp",
        default=int(os.environ.get("LOCAL_RANK", 0)),
    )
    # Options belonging to the inference module are ignored here; that
    # module's own parser consumes them from sys.argv.
    launcher_args, _ignored = parser.parse_known_args()
    return launcher_args


def main():
    """Run inference using the module selected on the command line."""
    now = datetime.datetime.now().strftime("%Y-%m-%d-%H-%M-%S")
    args = _parse_launcher_args()

    # The selected module must provide get_parser() and run_inference().
    inference_api = importlib.import_module(args.module, package=None)
    inference_parser = inference_api.get_parser()
    inference_args, _ignored = inference_parser.parse_known_args()

    # Seed before the process group is created, matching the original order.
    seed_everything(inference_args.seed)
    setup_dist(args.local_rank)
    torch.backends.cudnn.benchmark = True

    rank, gpu_num = get_dist_info()
    print("@CoLVDM Inference [rank%d]: %s" % (rank, now))
    inference_api.run_inference(inference_args, gpu_num, rank)


if __name__ == '__main__':
    main()