glenn-jocher commited on
Commit
a1748a8
1 Parent(s): bd3e389

test during training default to FP16

Browse files
Files changed (1) hide show
  1. test.py +5 -3
test.py CHANGED
@@ -23,6 +23,7 @@ def test(data,
23
  verbose=False):
24
  # Initialize/load model and set device
25
  if model is None:
 
26
  device = torch_utils.select_device(opt.device, batch_size=batch_size)
27
  half = device.type != 'cpu' # half precision only supported on CUDA
28
 
@@ -42,11 +43,12 @@ def test(data,
42
  if device.type != 'cpu' and torch.cuda.device_count() > 1:
43
  model = nn.DataParallel(model)
44
 
45
- training = False
46
  else: # called by train.py
47
- device = next(model.parameters()).device # get model device
48
- half = False
49
  training = True
 
 
 
 
50
 
51
  # Configure
52
  model.eval()
 
23
  verbose=False):
24
  # Initialize/load model and set device
25
  if model is None:
26
+ training = False
27
  device = torch_utils.select_device(opt.device, batch_size=batch_size)
28
  half = device.type != 'cpu' # half precision only supported on CUDA
29
 
 
43
  if device.type != 'cpu' and torch.cuda.device_count() > 1:
44
  model = nn.DataParallel(model)
45
 
 
46
  else: # called by train.py
 
 
47
  training = True
48
+ device = next(model.parameters()).device # get model device
49
+ half = device.type != 'cpu' # half precision only supported on CUDA
50
+ if half:
51
+ model.half() # to FP16
52
 
53
  # Configure
54
  model.eval()