gpt-omni commited on
Commit
8b673c6
1 Parent(s): 03a92bb

Update inference.py

Browse files
Files changed (1) hide show
  1. inference.py +1 -1
inference.py CHANGED
@@ -417,7 +417,7 @@ class OmniInference:
417
  list_output = [[] for i in range(8)]
418
  tokens_A, token_T = next_token_batch(
419
  model,
420
- audio_feature.to(torch.float32).to(model.device),
421
  input_ids,
422
  [T - 3, T - 3],
423
  ["A1T2", "A1T2"],
 
417
  list_output = [[] for i in range(8)]
418
  tokens_A, token_T = next_token_batch(
419
  model,
420
+ audio_feature.to(torch.float32).to(device),
421
  input_ids,
422
  [T - 3, T - 3],
423
  ["A1T2", "A1T2"],