Spaces:
Runtime error
Runtime error
File size: 1,396 Bytes
e775f6d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 |
# For all things related to devices
#### ONLY USE PROVIDED FUNCTIONS, DO NOT USE GLOBAL CONSTANTS ####
import torch
# The CPU device always exists; a CUDA device handle is only created when
# torch reports at least one CUDA device at import time.
TORCH_CPU_DEVICE = torch.device("cpu")
if torch.cuda.device_count() == 0:
    # No CUDA device: keep the handle as None so callers fall back to the CPU.
    print("----- WARNING: CUDA devices not detected. This will cause the model to run very slow! -----")
    print("")
    TORCH_CUDA_DEVICE = None
else:
    TORCH_CUDA_DEVICE = torch.device("cuda")

# Module-wide preference toggled by use_cuda(); CUDA is preferred by default.
USE_CUDA = True
# use_cuda
def use_cuda(cuda_bool: bool) -> None:
    """
    ----------
    Author: Damon Gwinn
    ----------
    Turns the module-wide CUDA preference on or off. With the preference off,
    get_device() hands back the CPU device even when a CUDA device exists
    (running on the CPU is not recommended). The argument is stored as-is in
    the module-level USE_CUDA flag.
    ----------
    """
    global USE_CUDA
    USE_CUDA = cuda_bool
# get_device
def get_device():
    """
    ----------
    Author: Damon Gwinn
    ----------
    Returns the default device: the CUDA device when the CUDA preference is
    enabled (see use_cuda) and a CUDA device was detected, otherwise the CPU
    device.
    ----------
    """
    # CUDA is chosen only when it is both wanted and present; any other
    # combination falls through to the CPU.
    if USE_CUDA and TORCH_CUDA_DEVICE is not None:
        return TORCH_CUDA_DEVICE
    return TORCH_CPU_DEVICE
# cuda_device
def cuda_device():
    """
    ----------
    Author: Damon Gwinn
    ----------
    Returns the module's CUDA device handle, or None when no CUDA device was
    detected when this module was imported. Unlike get_device(), the result
    does not depend on the use_cuda() preference.
    ----------
    """
    return TORCH_CUDA_DEVICE
# cpu_device
def cpu_device():
    """
    ----------
    Author: Damon Gwinn
    ----------
    Returns the module's CPU device handle. The same handle is returned
    regardless of the use_cuda() preference or CUDA availability.
    ----------
    """
    return TORCH_CPU_DEVICE
|