glenn-jocher
commited on
Commit
•
f340235
1
Parent(s):
10d56d7
check `batch_size % utilized_device_count` (#3276)
Browse filesBug fix to check batch_size divisibility of utilized CUDA device count vs total system CUDA device count.
- utils/torch_utils.py +4 -3
utils/torch_utils.py
CHANGED
@@ -72,11 +72,12 @@ def select_device(device='', batch_size=None):
|
|
72 |
|
73 |
cuda = not cpu and torch.cuda.is_available()
|
74 |
if cuda:
|
75 |
-
|
76 |
-
|
|
|
77 |
assert batch_size % n == 0, f'batch-size {batch_size} not multiple of GPU count {n}'
|
78 |
space = ' ' * len(s)
|
79 |
-
for i, d in enumerate(
|
80 |
p = torch.cuda.get_device_properties(i)
|
81 |
s += f"{'' if i == 0 else space}CUDA:{d} ({p.name}, {p.total_memory / 1024 ** 2}MB)\n" # bytes to MB
|
82 |
else:
|
|
|
72 |
|
73 |
cuda = not cpu and torch.cuda.is_available()
|
74 |
if cuda:
|
75 |
+
devices = device.split(',') if device else range(torch.cuda.device_count()) # i.e. 0,1,6,7
|
76 |
+
n = len(devices) # device count
|
77 |
+
if n > 1 and batch_size: # check batch_size is divisible by device_count
|
78 |
assert batch_size % n == 0, f'batch-size {batch_size} not multiple of GPU count {n}'
|
79 |
space = ' ' * len(s)
|
80 |
+
for i, d in enumerate(devices):
|
81 |
p = torch.cuda.get_device_properties(i)
|
82 |
s += f"{'' if i == 0 else space}CUDA:{d} ({p.name}, {p.total_memory / 1024 ** 2}MB)\n" # bytes to MB
|
83 |
else:
|