DaS / segmenter_model /torch.py
vobecant
Initial commit
dd78229
raw
history blame
1.03 kB
import os
import torch
# ---------------------------------------------------------------------------
# GPU wrappers
#
# Module-level state describing the compute device and the distributed layout
# of the current process. The defaults below describe a single, CPU-only,
# non-distributed process; set_gpu_mode() rebinds all of them.
#
# (This header used to be a bare triple-quoted string. Because it followed the
# imports it was a no-op expression statement, not the module docstring.)
# ---------------------------------------------------------------------------
use_gpu = False  # set to the `mode` argument of set_gpu_mode()
gpu_id = 0  # GPU index read from the launcher's local-id variable
device = None  # torch.device built by set_gpu_mode(); None until it runs
distributed = False  # True when world_size > 1
dist_rank = 0  # rank of this process as reported by the launcher
world_size = 1  # number of processes as reported by the launcher
def set_gpu_mode(mode, pbs=False):
    """Configure the module-level device and distributed-process globals.

    Rank information is read from environment variables set by the job
    launcher: ``SLURM_LOCALID`` / ``SLURM_PROCID`` / ``SLURM_NTASKS`` by
    default, or ``MPI_LOCALRANKID`` / ``PMI_RANK`` / ``PMI_SIZE`` when
    ``pbs`` is true. Missing variables fall back to a single-process layout
    (local id 0, rank 0, world size 1).

    Args:
        mode: Stored as ``use_gpu``. When truthy, ``device`` becomes
            ``cuda:<gpu_id>``; otherwise it becomes ``cpu``.
        pbs: Read the MPI/PMI variables instead of the SLURM ones.

    Raises:
        ValueError: If one of the variables is set but is not an integer
            string (raised by ``int``).

    Side effects:
        Rebinds the globals use_gpu, device, gpu_id, distributed, dist_rank
        and world_size. Prints a one-line summary. Sets
        ``torch.backends.cudnn.benchmark = True`` regardless of ``mode``.

    Note:
        CUDA availability is not checked here. ``torch.device`` only names
        the device; any failure surfaces later, when it is used.
    """
    global use_gpu, device, gpu_id, distributed, dist_rank, world_size

    # Each launcher exposes (local id, global rank, task count) under its own names.
    if pbs:
        local_var, rank_var, size_var = "MPI_LOCALRANKID", "PMI_RANK", "PMI_SIZE"
    else:
        local_var, rank_var, size_var = "SLURM_LOCALID", "SLURM_PROCID", "SLURM_NTASKS"

    env = os.environ
    gpu_id = int(env.get(local_var, 0))
    dist_rank = int(env.get(rank_var, 0))
    world_size = int(env.get(size_var, 1))

    distributed = world_size > 1
    use_gpu = mode

    summary = 'gpu_id: {}, dist_rank: {}, world_size: {}, distributed: {}'
    print(summary.format(gpu_id, dist_rank, world_size, distributed))

    if use_gpu:
        device_name = f"cuda:{gpu_id}"
    else:
        device_name = "cpu"
    device = torch.device(device_name)

    torch.backends.cudnn.benchmark = True