lpw committed on
Commit
418e72c
1 Parent(s): 796bd4c

Update audio_pipe.py

Browse files
Files changed (1) hide show
  1. audio_pipe.py +3 -2
audio_pipe.py CHANGED
@@ -5,6 +5,7 @@ from typing import List, Tuple
5
 
6
  import numpy as np
7
  import torch
 
8
  # from app.pipelines import Pipeline
9
  from fairseq import hub_utils
10
  from fairseq.checkpoint_utils import load_model_ensemble_and_task_from_hf_hub
@@ -105,7 +106,7 @@ class SpeechToSpeechPipeline():
105
  [self.tts_model], tts_cfg
106
  )
107
 
108
- def __call__(self, inputs: np.array) -> Tuple[np.array, int, List[str]]:
109
  """
110
  Args:
111
  inputs (:obj:`np.array`):
@@ -120,7 +121,7 @@ class SpeechToSpeechPipeline():
120
  This can be the name of the instruments for audio source separation
121
  or some annotation for speech enhancement. The length must be `C'`.
122
  """
123
- _inputs = torch.from_numpy(inputs).unsqueeze(0)
124
  sample, text = None, None
125
  if self.cfg.task._name in ["speech_to_text", "speech_to_text_sharded"]:
126
  sample = S2THubInterface.get_model_input(self.task, _inputs)
 
5
 
6
  import numpy as np
7
  import torch
8
+ import torchaudio
9
  # from app.pipelines import Pipeline
10
  from fairseq import hub_utils
11
  from fairseq.checkpoint_utils import load_model_ensemble_and_task_from_hf_hub
 
106
  [self.tts_model], tts_cfg
107
  )
108
 
109
+ def __call__(self, inputs: str) -> Tuple[np.array, int, List[str]]:
110
  """
111
  Args:
112
  inputs (:obj:`np.array`):
 
121
  This can be the name of the instruments for audio source separation
122
  or some annotation for speech enhancement. The length must be `C'`.
123
  """
124
+ _inputs = torchaudio.load(inputs)
125
  sample, text = None, None
126
  if self.cfg.task._name in ["speech_to_text", "speech_to_text_sharded"]:
127
  sample = S2THubInterface.get_model_input(self.task, _inputs)