asigalov61 commited on
Commit
602f44d
·
verified ·
1 Parent(s): c16687b

Update pytorch_utils.py

Browse files
Files changed (1) hide show
  1. pytorch_utils.py +1 -1
pytorch_utils.py CHANGED
@@ -54,7 +54,7 @@ def forward(model, x, batch_size):
54
  pointer += batch_size
55
 
56
  with torch.no_grad():
57
- with torch.amp.autocast(device_type='cuda', dtype=torch.float16):
58
  model.eval()
59
  batch_output_dict = model(batch_waveform)
60
 
 
54
  pointer += batch_size
55
 
56
  with torch.no_grad():
57
+ with torch.amp.autocast(device_type='cuda'):
58
  model.eval()
59
  batch_output_dict = model(batch_waveform)
60