|
import os |
|
from enum import Enum |
|
from .device_id import DeviceId |
|
|
|
|
|
|
|
class DeviceException(Exception): |
|
pass |
|
|
|
class _Device: |
|
def __init__(self): |
|
self.set(DeviceId.CPU) |
|
|
|
def is_gpu(self): |
|
''' Returns `True` if the current device is GPU, `False` otherwise. ''' |
|
return self.current() is not DeviceId.CPU |
|
|
|
def current(self): |
|
return self._current_device |
|
|
|
def set(self, device:DeviceId): |
|
if device == DeviceId.CPU: |
|
os.environ['CUDA_VISIBLE_DEVICES']='' |
|
else: |
|
os.environ['CUDA_VISIBLE_DEVICES']=str(device.value) |
|
import torch |
|
torch.backends.cudnn.benchmark=False |
|
|
|
self._current_device = device |
|
return device |