"""
GPU wrappers
"""

import os

import torch

use_gpu = False
gpu_id = 0
device = None

distributed = False
dist_rank = 0
world_size = 1


def set_gpu_mode(mode, pbs=False):
    global use_gpu
    global device
    global gpu_id
    global distributed
    global dist_rank
    global world_size

    if pbs:
        # PBS / MPI launch: read local rank, global rank, and world size
        # from the MPI process manager environment.
        gpu_id = int(os.environ.get("MPI_LOCALRANKID", 0))
        dist_rank = int(os.environ.get("PMI_RANK", 0))
        world_size = int(os.environ.get("PMI_SIZE", 1))
    else:
        # SLURM launch: read the same values from the SLURM environment.
        gpu_id = int(os.environ.get("SLURM_LOCALID", 0))
        dist_rank = int(os.environ.get("SLURM_PROCID", 0))
        world_size = int(os.environ.get("SLURM_NTASKS", 1))

    distributed = world_size > 1
    use_gpu = mode

    print(
        "gpu_id: {}, dist_rank: {}, world_size: {}, distributed: {}".format(
            gpu_id, dist_rank, world_size, distributed
        )
    )

    # Bind this process to its local GPU (or fall back to CPU) and enable
    # cuDNN autotuning for fixed-size workloads.
    device = torch.device(f"cuda:{gpu_id}" if use_gpu else "cpu")
    torch.backends.cudnn.benchmark = True
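
# ---------------------------------------------------------------------------
# Usage sketch (assumptions: this module is importable as ``gpu_utils``; the
# NCCL backend and the explicit ``init_process_group`` call below are
# illustrative and not part of this module; ``MASTER_ADDR``/``MASTER_PORT``
# are expected to be set by the launcher when using the env:// init method):
#
#     import torch
#     import torch.distributed as dist
#     import gpu_utils
#
#     gpu_utils.set_gpu_mode(torch.cuda.is_available())  # SLURM; pbs=True for PBS/MPI
#     if gpu_utils.distributed:
#         dist.init_process_group(
#             backend="nccl",
#             rank=gpu_utils.dist_rank,
#             world_size=gpu_utils.world_size,
#         )
#     torch.cuda.set_device(gpu_utils.gpu_id)
#     model = model.to(gpu_utils.device)
# ---------------------------------------------------------------------------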