glenn-jocher committed on
Commit
f340235
1 Parent(s): 10d56d7

check `batch_size % utilized_device_count` (#3276)

Browse files

Bug fix: check that batch_size is divisible by the number of CUDA devices actually in use (e.g. those selected via `device='0,1'`), rather than by the total number of CUDA devices in the system.

Files changed (1) hide show
  1. utils/torch_utils.py +4 -3
utils/torch_utils.py CHANGED
@@ -72,11 +72,12 @@ def select_device(device='', batch_size=None):
72
 
73
  cuda = not cpu and torch.cuda.is_available()
74
  if cuda:
75
- n = torch.cuda.device_count()
76
- if n > 1 and batch_size: # check that batch_size is compatible with device_count
 
77
  assert batch_size % n == 0, f'batch-size {batch_size} not multiple of GPU count {n}'
78
  space = ' ' * len(s)
79
- for i, d in enumerate(device.split(',') if device else range(n)):
80
  p = torch.cuda.get_device_properties(i)
81
  s += f"{'' if i == 0 else space}CUDA:{d} ({p.name}, {p.total_memory / 1024 ** 2}MB)\n" # bytes to MB
82
  else:
 
72
 
73
  cuda = not cpu and torch.cuda.is_available()
74
  if cuda:
75
+ devices = device.split(',') if device else range(torch.cuda.device_count()) # i.e. 0,1,6,7
76
+ n = len(devices) # device count
77
+ if n > 1 and batch_size: # check batch_size is divisible by device_count
78
  assert batch_size % n == 0, f'batch-size {batch_size} not multiple of GPU count {n}'
79
  space = ' ' * len(s)
80
+ for i, d in enumerate(devices):
81
  p = torch.cuda.get_device_properties(i)
82
  s += f"{'' if i == 0 else space}CUDA:{d} ({p.name}, {p.total_memory / 1024 ** 2}MB)\n" # bytes to MB
83
  else: