lmzjms commited on
Commit
97b6dd1
1 Parent(s): f889c4d

Update NeuralSeq/inference/tts/base_tts_infer.py

Browse files
NeuralSeq/inference/tts/base_tts_infer.py CHANGED
@@ -92,7 +92,7 @@ class BaseTTSInfer:
92
  input_values = self.asr_processor(audio_input, sampling_rate=sample_rate, return_tensors="pt").input_values
93
 
94
  # retrieve logits & take argmax
95
- logits = self.asr_model(input_values.cuda()).logits
96
  predicted_ids = torch.argmax(logits, dim=-1)
97
 
98
  # transcribe
 
92
  input_values = self.asr_processor(audio_input, sampling_rate=sample_rate, return_tensors="pt").input_values
93
 
94
  # retrieve logits & take argmax
95
+ logits = self.asr_model(input_values).logits
96
  predicted_ids = torch.argmax(logits, dim=-1)
97
 
98
  # transcribe