asigalov61 commited on
Commit
36bfdeb
1 Parent(s): 1cd950b

Update pytorch_utils.py

Browse files
Files changed (1) hide show
  1. pytorch_utils.py +3 -2
pytorch_utils.py CHANGED
@@ -54,8 +54,9 @@ def forward(model, x, batch_size):
54
  pointer += batch_size
55
 
56
  with torch.no_grad():
57
- model.eval()
58
- batch_output_dict = model(batch_waveform)
 
59
 
60
  for key in batch_output_dict.keys():
61
  append_to_dict(output_dict, key, batch_output_dict[key].data.cpu().numpy())
 
54
  pointer += batch_size
55
 
56
  with torch.no_grad():
57
+ with torch.amp.autocast(device_type='cuda', dtype=torch.bfloat16):
58
+ model.eval()
59
+ batch_output_dict = model(batch_waveform)
60
 
61
  for key in batch_output_dict.keys():
62
  append_to_dict(output_dict, key, batch_output_dict[key].data.cpu().numpy())