lpw commited on
Commit
5dad022
1 Parent(s): 76feae7

Upload 4 files

Browse files
Files changed (4) hide show
  1. app.py +47 -0
  2. audio_pipe.py +161 -0
  3. packages.txt +1 -0
  4. requirements.txt +5 -0
app.py ADDED
@@ -0,0 +1,47 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ os.system("pip install gradio==3.3")
3
+ import gradio as gr
4
+ import numpy as np
5
+ import streamlit as st
6
+ from audio_pipe import SpeechToSpeechPipeline
7
+
8
+ title = "SpeechMatrix Speech-to-speech Translation"
9
+
10
+ description = "Gradio Demo for SpeechMatrix. To use it, simply record your audio, or click the example to load. Read more at the links below. \nNote: These models are trained on SpeechMatrix data only, and meant to serve as a baseline for future research."
11
+
12
+ article = "<p style='text-align: center'><a href='https://research.facebook.com/publications/speechmatrix' target='_blank'>SpeechMatrix</a> | <a href='https://github.com/facebookresearch/fairseq/tree/ust' target='_blank'>Github Repo</a></p>"
13
+
14
+ SRC_LIST = ['cs', 'de', 'en', 'es', 'et', 'fi', 'fr', 'hr', 'hu', 'it', 'nl', 'pl', 'pt', 'ro', 'sk', 'sl']
15
+ # SRC_LIST = ['cs', 'de', 'en', 'es', 'et', 'fi', 'fr', 'hr', 'hu', 'nl', 'pl', 'pt', 'ro', 'sk', 'sl']
16
+ TGT_LIST = ['en', 'fr', 'es']
17
+ MODEL_LIST = ['xm_transformer_sm_all-en']
18
+ for src in SRC_LIST:
19
+ for tgt in TGT_LIST:
20
+ if src != tgt:
21
+ MODEL_LIST.append(f"textless_sm_{src}_{tgt}")
22
+
23
+ examples = []
24
+ pipe_dict = {}
25
+
26
+ # io_dict = {model: gr.Interface.load(f"huggingface/facebook/{model}", api_key=st.secrets["api_key"]) for model in MODEL_LIST}
27
+ # pipe_dict = {model: SpeechToSpeechPipeline(f"facebook/{model}") for model in MODEL_LIST}
28
+ for model in MODEL_LIST:
29
+ print(f"model: {model}")
30
+ pipe_dict[model] = SpeechToSpeechPipeline(f"facebook/{model}")
31
+
32
+ def inference(audio, model):
33
+ out_audio = pipe_dict[model](audio).get_config()["value"]["name"]
34
+ # pipe = SpeechToSpeechPipeline(f"facebook/{model}")
35
+ # out_audio = pipe(audio).get_config()["value"]["name"]
36
+ return out_audio
37
+
38
+ gr.Interface(
39
+ inference,
40
+ [gr.inputs.Audio(source="microphone", type="filepath", label="Input"),gr.inputs.Dropdown(choices=MODEL_LIST, default="xm_transformer_sm_all-en",type="value", label="Model")
41
+ ],
42
+ gr.outputs.Audio(label="Output"),
43
+ article=article,
44
+ title=title,
45
+ examples=examples,
46
+ cache_examples=False,
47
+ description=description).queue().launch()
audio_pipe.py ADDED
@@ -0,0 +1,161 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import os
3
+ from pathlib import Path
4
+ from typing import List, Tuple
5
+ import tempfile
6
+ import soundfile as sf
7
+ import gradio as gr
8
+
9
+ import numpy as np
10
+ import torch
11
+ import torchaudio
12
+ # from app.pipelines import Pipeline
13
+ from fairseq import hub_utils
14
+ from fairseq.checkpoint_utils import load_model_ensemble_and_task_from_hf_hub
15
+ from fairseq.models.speech_to_speech.hub_interface import S2SHubInterface
16
+ from fairseq.models.speech_to_text.hub_interface import S2THubInterface
17
+ from fairseq.models.text_to_speech import CodeHiFiGANVocoder
18
+ from fairseq.models.text_to_speech.hub_interface import (
19
+ TTSHubInterface,
20
+ VocoderHubInterface,
21
+ )
22
+ from huggingface_hub import snapshot_download
23
+
24
+ ARG_OVERRIDES_MAP = {
25
+ "facebook/xm_transformer_s2ut_800m-es-en-st-asr-bt_h1_2022": {
26
+ "config_yaml": "config.yaml",
27
+ "task": "speech_to_text",
28
+ }
29
+ }
30
+
31
+ class SpeechToSpeechPipeline():
32
+ def __init__(self, model_id: str):
33
+ arg_overrides = ARG_OVERRIDES_MAP.get(
34
+ model_id, {}
35
+ ) # Model specific override. TODO: Update on checkpoint side in the future
36
+ arg_overrides["config_yaml"] = "config.yaml" # common override
37
+ models, cfg, task = load_model_ensemble_and_task_from_hf_hub(
38
+ model_id,
39
+ arg_overrides=arg_overrides,
40
+ cache_dir=os.getenv("HUGGINGFACE_HUB_CACHE"),
41
+ )
42
+ self.cfg = cfg
43
+ self.model = models[0].cpu()
44
+ self.model.eval()
45
+ self.task = task
46
+
47
+ self.sampling_rate = getattr(self.task, "sr", None) or 16_000
48
+
49
+ tgt_lang = self.task.data_cfg.hub.get("tgt_lang", None)
50
+ pfx = f"{tgt_lang}_" if self.task.data_cfg.prepend_tgt_lang_tag else ""
51
+
52
+ generation_args = self.task.data_cfg.hub.get(f"{pfx}generation_args", None)
53
+ if generation_args is not None:
54
+ for key in generation_args:
55
+ setattr(cfg.generation, key, generation_args[key])
56
+ self.generator = task.build_generator([self.model], cfg.generation)
57
+
58
+ tts_model_id = self.task.data_cfg.hub.get(f"{pfx}tts_model_id", None)
59
+ self.unit_vocoder = self.task.data_cfg.hub.get(f"{pfx}unit_vocoder", None)
60
+ self.tts_model, self.tts_task, self.tts_generator = None, None, None
61
+ if tts_model_id is not None:
62
+ _id = tts_model_id.split(":")[-1]
63
+ cache_dir = os.getenv("HUGGINGFACE_HUB_CACHE")
64
+ if self.unit_vocoder is not None:
65
+ library_name = "fairseq"
66
+ cache_dir = (
67
+ cache_dir or (Path.home() / ".cache" / library_name).as_posix()
68
+ )
69
+ cache_dir = snapshot_download(
70
+ f"facebook/{_id}", cache_dir=cache_dir, library_name=library_name
71
+ )
72
+
73
+ x = hub_utils.from_pretrained(
74
+ cache_dir,
75
+ "model.pt",
76
+ ".",
77
+ archive_map=CodeHiFiGANVocoder.hub_models(),
78
+ config_yaml="config.json",
79
+ fp16=False,
80
+ is_vocoder=True,
81
+ )
82
+
83
+ with open(f"{x['args']['data']}/config.json") as f:
84
+ vocoder_cfg = json.load(f)
85
+ assert (
86
+ len(x["args"]["model_path"]) == 1
87
+ ), "Too many vocoder models in the input"
88
+
89
+ vocoder = CodeHiFiGANVocoder(x["args"]["model_path"][0], vocoder_cfg)
90
+ self.tts_model = VocoderHubInterface(vocoder_cfg, vocoder)
91
+
92
+ else:
93
+ (
94
+ tts_models,
95
+ tts_cfg,
96
+ self.tts_task,
97
+ ) = load_model_ensemble_and_task_from_hf_hub(
98
+ f"facebook/{_id}",
99
+ arg_overrides={"vocoder": "griffin_lim", "fp16": False},
100
+ cache_dir=cache_dir,
101
+ )
102
+ self.tts_model = tts_models[0].cpu()
103
+ self.tts_model.eval()
104
+ tts_cfg["task"].cpu = True
105
+ TTSHubInterface.update_cfg_with_data_cfg(
106
+ tts_cfg, self.tts_task.data_cfg
107
+ )
108
+ self.tts_generator = self.tts_task.build_generator(
109
+ [self.tts_model], tts_cfg
110
+ )
111
+
112
+ def __call__(self, inputs: str) -> Tuple[np.array, int, List[str]]:
113
+ """
114
+ Args:
115
+ inputs (:obj:`np.array`):
116
+ The raw waveform of audio received. By default sampled at `self.sampling_rate`.
117
+ The shape of this array is `T`, where `T` is the time axis
118
+ Return:
119
+ A :obj:`tuple` containing:
120
+ - :obj:`np.array`:
121
+ The return shape of the array must be `C'`x`T'`
122
+ - a :obj:`int`: the sampling rate as an int in Hz.
123
+ - a :obj:`List[str]`: the annotation for each out channel.
124
+ This can be the name of the instruments for audio source separation
125
+ or some annotation for speech enhancement. The length must be `C'`.
126
+ """
127
+ # _inputs = torch.from_numpy(inputs).unsqueeze(0)
128
+ # print(f"input: {inputs}")
129
+ # _inputs = torchaudio.load(inputs)
130
+ _inputs = inputs
131
+ sample, text = None, None
132
+ if self.cfg.task._name in ["speech_to_text", "speech_to_text_sharded"]:
133
+ sample = S2THubInterface.get_model_input(self.task, _inputs)
134
+ text = S2THubInterface.get_prediction(
135
+ self.task, self.model, self.generator, sample
136
+ )
137
+ elif self.cfg.task._name in ["speech_to_speech"]:
138
+ s2shubinerface = S2SHubInterface(self.cfg, self.task, self.model)
139
+ sample = s2shubinerface.get_model_input(self.task, _inputs)
140
+ text = S2SHubInterface.get_prediction(
141
+ self.task, self.model, self.generator, sample
142
+ )
143
+
144
+ wav, sr = np.zeros((0,)), self.sampling_rate
145
+ if self.unit_vocoder is not None:
146
+ tts_sample = self.tts_model.get_model_input(text)
147
+ wav, sr = self.tts_model.get_prediction(tts_sample)
148
+ text = ""
149
+ else:
150
+ tts_sample = TTSHubInterface.get_model_input(self.tts_task, text)
151
+ wav, sr = TTSHubInterface.get_prediction(
152
+ self.tts_task, self.tts_model, self.tts_generator, tts_sample
153
+ )
154
+ temp_file = ""
155
+ with tempfile.NamedTemporaryFile(suffix=".wav") as tmp_output_file:
156
+ sf.write(tmp_output_file, wav.detach().cpu().numpy(), sr)
157
+ tmp_output_file.seek(0)
158
+ temp_file = gr.Audio(tmp_output_file.name)
159
+
160
+ # return wav, sr, [text]
161
+ return temp_file
packages.txt ADDED
@@ -0,0 +1 @@
 
 
1
+ ffmpeg
requirements.txt ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ six==1.15.0
2
+ urllib3
3
+ scikit-learn
4
+ requests==2.21.0
5
+ git+https://github.com/facebookresearch/fairseq.git@d47119871c2ac9a0a0aa2904dd8cfc1929b113d9#egg=fairseq