lpw commited on
Commit
86e5f28
1 Parent(s): c3e5809

Update audio_pipe.py

Browse files
Files changed (1) hide show
  1. audio_pipe.py +4 -2
audio_pipe.py CHANGED
@@ -106,7 +106,7 @@ class SpeechToSpeechPipeline():
106
  [self.tts_model], tts_cfg
107
  )
108
 
109
- def __call__(self, inputs: np.array) -> Tuple[np.array, int, List[str]]:
110
  """
111
  Args:
112
  inputs (:obj:`np.array`):
@@ -121,7 +121,9 @@ class SpeechToSpeechPipeline():
121
  This can be the name of the instruments for audio source separation
122
  or some annotation for speech enhancement. The length must be `C'`.
123
  """
124
- _inputs = torch.from_numpy(inputs).unsqueeze(0)
 
 
125
  sample, text = None, None
126
  if self.cfg.task._name in ["speech_to_text", "speech_to_text_sharded"]:
127
  sample = S2THubInterface.get_model_input(self.task, _inputs)
 
106
  [self.tts_model], tts_cfg
107
  )
108
 
109
+ def __call__(self, inputs: str) -> Tuple[np.array, int, List[str]]:
110
  """
111
  Args:
112
  inputs (:obj:`np.array`):
 
121
  This can be the name of the instruments for audio source separation
122
  or some annotation for speech enhancement. The length must be `C'`.
123
  """
124
+ # _inputs = torch.from_numpy(inputs).unsqueeze(0)
125
+ print(f"input: {inputs}")
126
+ _inputs = torchaudio.load(inputs)
127
  sample, text = None, None
128
  if self.cfg.task._name in ["speech_to_text", "speech_to_text_sharded"]:
129
  sample = S2THubInterface.get_model_input(self.task, _inputs)