glenn-jocher
commited on
Commit
•
ffef771
1
Parent(s):
0f11aaf
Update torch_utils.py (#1895)
Browse files- utils/torch_utils.py +1 -1
utils/torch_utils.py
CHANGED
@@ -61,7 +61,7 @@ def select_device(device='', batch_size=None):
|
|
61 |
os.environ['CUDA_VISIBLE_DEVICES'] = device # set environment variable
|
62 |
assert torch.cuda.is_available(), f'CUDA unavailable, invalid device {device} requested' # check availability
|
63 |
|
64 |
-
cuda = torch.cuda.is_available()
|
65 |
if cuda:
|
66 |
n = torch.cuda.device_count()
|
67 |
if n > 1 and batch_size: # check that batch_size is compatible with device_count
|
|
|
61 |
os.environ['CUDA_VISIBLE_DEVICES'] = device # set environment variable
|
62 |
assert torch.cuda.is_available(), f'CUDA unavailable, invalid device {device} requested' # check availability
|
63 |
|
64 |
+
cuda = not cpu and torch.cuda.is_available()
|
65 |
if cuda:
|
66 |
n = torch.cuda.device_count()
|
67 |
if n > 1 and batch_size: # check that batch_size is compatible with device_count
|