Rongjiehuang commited on
Commit
d9a4587
1 Parent(s): 90d9164
Files changed (1) hide show
  1. inference/base_tts_infer.py +2 -2
inference/base_tts_infer.py CHANGED
@@ -78,7 +78,7 @@ class BaseTTSInfer:
78
  # processed ref audio
79
  ref_audio = inp['ref_audio']
80
  processed_ref_audio = 'example/temp.wav'
81
- voice_encoder = VoiceEncoder().cuda()
82
  encoder = [self.ph_encoder, self.word_encoder]
83
  EmotionEncoder.load_model(self.hparams['emotion_encoder_path'])
84
  binarizer_cls = self.hparams.get("binarizer_cls", 'data_gen.tts.base_binarizerr.BaseBinarizer')
@@ -185,7 +185,7 @@ class BaseTTSInfer:
185
  input_values = self.asr_processor(audio_input, sampling_rate=sample_rate, return_tensors="pt").input_values
186
 
187
  # retrieve logits & take argmax
188
- logits = self.asr_model(input_values.cuda()).logits
189
  predicted_ids = torch.argmax(logits, dim=-1)
190
 
191
  # transcribe
 
78
  # processed ref audio
79
  ref_audio = inp['ref_audio']
80
  processed_ref_audio = 'example/temp.wav'
81
+ voice_encoder = VoiceEncoder().to(self.device)
82
  encoder = [self.ph_encoder, self.word_encoder]
83
  EmotionEncoder.load_model(self.hparams['emotion_encoder_path'])
84
  binarizer_cls = self.hparams.get("binarizer_cls", 'data_gen.tts.base_binarizerr.BaseBinarizer')
 
185
  input_values = self.asr_processor(audio_input, sampling_rate=sample_rate, return_tensors="pt").input_values
186
 
187
  # retrieve logits & take argmax
188
+ logits = self.asr_model(input_values.to(self.device)).logits
189
  predicted_ids = torch.argmax(logits, dim=-1)
190
 
191
  # transcribe