File size: 833 Bytes
0cb9530
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
import os
from enum import Enum
from .device_id import DeviceId

# NOTE: Select the device (via _Device.set) before torch is imported anywhere; otherwise the CUDA_VISIBLE_DEVICES value written here may not take effect.

class DeviceException(Exception):
    """Exception type for device errors; a plain ``Exception`` subclass with no extra state."""

class _Device:
    """Holds the currently selected compute device.

    Selecting a device also rewrites the ``CUDA_VISIBLE_DEVICES`` environment
    variable. A new instance starts out on ``DeviceId.CPU``.
    """

    def __init__(self):
        # Default to CPU; this also sets CUDA_VISIBLE_DEVICES to ''.
        self.set(DeviceId.CPU)

    def is_gpu(self):
        """Return ``True`` unless the current device is ``DeviceId.CPU``."""
        on_cpu = self.current() is DeviceId.CPU
        return not on_cpu

    def current(self):
        """Return the device most recently recorded by :meth:`set`."""
        return self._current_device

    def set(self, device: DeviceId):
        """Make *device* the current device and return it.

        For ``DeviceId.CPU`` the ``CUDA_VISIBLE_DEVICES`` variable is set to
        the empty string. For any other device it is set to
        ``str(device.value)``. Torch is then imported and
        ``torch.backends.cudnn.benchmark`` is switched off.
        """
        use_cpu = device == DeviceId.CPU
        # The variable is written before `import torch` below; presumably so a
        # first-time torch import sees the new value.
        os.environ['CUDA_VISIBLE_DEVICES'] = '' if use_cpu else str(device.value)
        if not use_cpu:
            import torch
            torch.backends.cudnn.benchmark = False

        self._current_device = device
        return device