lpw committed
Commit 6dbb4b2
1 Parent(s): 4380a2f

Update app.py

Files changed (1)
  1. app.py +12 -20
app.py CHANGED
@@ -6,32 +6,24 @@ from fairseq.models.speech_to_speech.hub_interface import S2SHubInterface
 from fairseq.models.speech_to_text.hub_interface import S2THubInterface
 from audio_pipe import SpeechToSpeechPipeline
 
-io1 = gr.Interface.load("huggingface/facebook/xm_transformer_s2ut_en-hk", api_key=os.environ['api_key'])
-io2 = gr.Interface.load("huggingface/facebook/xm_transformer_s2ut_hk-en", api_key=os.environ['api_key'])
-io3 = gr.Interface.load("huggingface/facebook/xm_transformer_unity_en-hk", api_key=os.environ['api_key'])
-io4 = gr.Interface.load("huggingface/facebook/xm_transformer_unity_hk-en", api_key=os.environ['api_key'])
-pipe = SpeechToSpeechPipeline("facebook/xm_transformer_unity_hk-en")
-
-def call_model(audio, model):
-    # pipe = SpeechToSpeechPipeline("facebook/xm_transformer_unity_hk-en")
-    # wav, sr, text = pipe(audio)
-    # temp_file = pipe(audio)
-    # return gr.Audio(temp_file)
-    print(pipe(audio).get_config())
-    return pipe(audio).get_config()["value"]["name"]
+# io1 = gr.Interface.load("huggingface/facebook/xm_transformer_s2ut_en-hk", api_key=os.environ['api_key'])
+# io2 = gr.Interface.load("huggingface/facebook/xm_transformer_s2ut_hk-en", api_key=os.environ['api_key'])
+# io3 = gr.Interface.load("huggingface/facebook/xm_transformer_unity_en-hk", api_key=os.environ['api_key'])
+# io4 = gr.Interface.load("huggingface/facebook/xm_transformer_unity_hk-en", api_key=os.environ['api_key'])
+pipe1 = SpeechToSpeechPipeline("facebook/xm_transformer_s2ut_en-hk")
+pipe2 = SpeechToSpeechPipeline("facebook/xm_transformer_s2ut_hk-en")
+pipe3 = SpeechToSpeechPipeline("facebook/xm_transformer_unity_en-hk")
+pipe4 = SpeechToSpeechPipeline("facebook/xm_transformer_unity_hk-en")
 
 def inference(audio, model):
     if model == "xm_transformer_s2ut_en-hk":
-        out_audio = io1(audio)
+        out_audio = pipe1(audio).get_config()["value"]["name"]
     elif model == "xm_transformer_s2ut_hk-en":
-        out_audio = io2(audio)
+        out_audio = pipe2(audio).get_config()["value"]["name"]
     elif model == "xm_transformer_unity_en-hk":
-        out_audio = io3(audio)
-    elif model == "xm_transformer_unity_hk-en_gpu":
-        out_audio = call_model(audio, model)
+        out_audio = pipe3(audio).get_config()["value"]["name"]
     else:
-        out_audio = io4(audio)
-    print(out_audio)
+        out_audio = pipe4(audio).get_config()["value"]["name"]
     return out_audio
 
 
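For reference, below is a minimal sketch of how the refactored pieces could be wired into a Gradio demo. Only SpeechToSpeechPipeline, the four facebook/xm_transformer_* checkpoint names, and the get_config()["value"]["name"] access come from this diff; the gr.Interface wiring, the dropdown choices, and the dict-based dispatch are illustrative assumptions about the rest of app.py, not part of this commit.

# Sketch only: assumes Gradio 3.x-style components and that
# SpeechToSpeechPipeline(...)(audio) returns a gr.Audio component whose
# config carries the generated wav's path under ["value"]["name"],
# as the diff above suggests.
import gradio as gr

from audio_pipe import SpeechToSpeechPipeline

# One pipeline per translation direction, loaded once at startup
# (checkpoint names taken from the diff).
MODELS = {
    name: SpeechToSpeechPipeline(f"facebook/{name}")
    for name in (
        "xm_transformer_s2ut_en-hk",
        "xm_transformer_s2ut_hk-en",
        "xm_transformer_unity_en-hk",
        "xm_transformer_unity_hk-en",
    )
}


def inference(audio, model):
    # Fall back to the hk-en UnitY pipeline, mirroring the else branch above.
    pipe = MODELS.get(model, MODELS["xm_transformer_unity_hk-en"])
    # Pull the generated wav's file path out of the returned component's config.
    return pipe(audio).get_config()["value"]["name"]


demo = gr.Interface(
    fn=inference,
    inputs=[
        gr.Audio(type="filepath", label="Input speech"),
        gr.Dropdown(choices=list(MODELS), value="xm_transformer_s2ut_en-hk", label="Model"),
    ],
    outputs=gr.Audio(type="filepath", label="Translated speech"),
)

if __name__ == "__main__":
    demo.launch()

Dispatching through a dict keyed on the checkpoint name keeps the same fallback behaviour as the if/elif chain in the committed code while avoiding four near-identical branches; it is shown only as a possible follow-up tidy-up.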