lpw commited on
Commit
32e2145
1 Parent(s): 605e90b

Upload 3 files

Browse files
Files changed (3) hide show
  1. app.py +147 -0
  2. audio_pipe.py +161 -0
  3. requirements.txt +3 -0
app.py ADDED
@@ -0,0 +1,147 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import gradio as gr
3
+ import numpy as np
4
+ from audio_pipe import SpeechToSpeechPipeline
5
+
6
+ # io1 = gr.Interface.load("huggingface/facebook/xm_transformer_s2ut_en-hk", api_key=os.environ['api_key'])
7
+ # io2 = gr.Interface.load("huggingface/facebook/xm_transformer_s2ut_hk-en", api_key=os.environ['api_key'])
8
+ # io3 = gr.Interface.load("huggingface/facebook/xm_transformer_unity_en-hk", api_key=os.environ['api_key'])
9
+ # io4 = gr.Interface.load("huggingface/facebook/xm_transformer_unity_hk-en", api_key=os.environ['api_key'])
10
+ pipe1 = SpeechToSpeechPipeline("facebook/xm_transformer_s2ut_en-hk")
11
+ pipe2 = SpeechToSpeechPipeline("facebook/xm_transformer_s2ut_hk-en")
12
+ pipe3 = SpeechToSpeechPipeline("facebook/xm_transformer_unity_en-hk")
13
+ pipe4 = SpeechToSpeechPipeline("facebook/xm_transformer_unity_hk-en")
14
+
15
+ def inference(audio, model):
16
+ if model == "xm_transformer_s2ut_en-hk":
17
+ out_audio = pipe1(audio).get_config()["value"]["name"]
18
+ elif model == "xm_transformer_s2ut_hk-en":
19
+ out_audio = pipe2(audio).get_config()["value"]["name"]
20
+ elif model == "xm_transformer_unity_en-hk":
21
+ out_audio = pipe3(audio).get_config()["value"]["name"]
22
+ else:
23
+ out_audio = pipe4(audio).get_config()["value"]["name"]
24
+ return out_audio
25
+
26
+
27
+ css = """
28
+ .gradio-container {
29
+ font-family: 'IBM Plex Sans', sans-serif;
30
+ }
31
+ .gr-button {
32
+ color: black;
33
+ border-color: grey;
34
+ background: white;
35
+ }
36
+ input[type='range'] {
37
+ accent-color: black;
38
+ }
39
+ .dark input[type='range'] {
40
+ accent-color: #dfdfdf;
41
+ }
42
+ .container {
43
+ max-width: 730px;
44
+ margin: auto;
45
+ padding-top: 1.5rem;
46
+ }
47
+
48
+ .details:hover {
49
+ text-decoration: underline;
50
+ }
51
+ .gr-button {
52
+ white-space: nowrap;
53
+ }
54
+ .gr-button:focus {
55
+ border-color: rgb(147 197 253 / var(--tw-border-opacity));
56
+ outline: none;
57
+ box-shadow: var(--tw-ring-offset-shadow), var(--tw-ring-shadow), var(--tw-shadow, 0 0 #0000);
58
+ --tw-border-opacity: 1;
59
+ --tw-ring-offset-shadow: var(--tw-ring-inset) 0 0 0 var(--tw-ring-offset-width) var(--tw-ring-offset-color);
60
+ --tw-ring-shadow: var(--tw-ring-inset) 0 0 0 calc(3px var(--tw-ring-offset-width)) var(--tw-ring-color);
61
+ --tw-ring-color: rgb(191 219 254 / var(--tw-ring-opacity));
62
+ --tw-ring-opacity: .5;
63
+ }
64
+ .footer {
65
+ margin-bottom: 45px;
66
+ margin-top: 35px;
67
+ text-align: center;
68
+ border-bottom: 1px solid #e5e5e5;
69
+ }
70
+ .footer>p {
71
+ font-size: .8rem;
72
+ display: inline-block;
73
+ padding: 0 10px;
74
+ transform: translateY(10px);
75
+ background: white;
76
+ }
77
+ .dark .footer {
78
+ border-color: #303030;
79
+ }
80
+ .dark .footer>p {
81
+ background: #0b0f19;
82
+ }
83
+ .prompt h4{
84
+ margin: 1.25em 0 .25em 0;
85
+ font-weight: bold;
86
+ font-size: 115%;
87
+ }
88
+ .animate-spin {
89
+ animation: spin 1s linear infinite;
90
+ }
91
+ @keyframes spin {
92
+ from {
93
+ transform: rotate(0deg);
94
+ }
95
+ to {
96
+ transform: rotate(360deg);
97
+ }
98
+ }
99
+ """
100
+
101
+ block = gr.Blocks(css=css)
102
+
103
+
104
+
105
+ with block:
106
+ gr.HTML(
107
+ """
108
+ <div style="text-align: center; max-width: 700px; margin: 0 auto;">
109
+ <div
110
+ style="
111
+ display: inline-flex;
112
+ align-items: center;
113
+ gap: 0.8rem;
114
+ font-size: 1.75rem;
115
+ "
116
+ >
117
+ <h1 style="font-weight: 900; margin-bottom: 7px;">
118
+ Hokkien Translation
119
+ </h1>
120
+ </div>
121
+ <p style="margin-bottom: 10px; font-size: 94%">
122
+ A demo for fairseq speech-to-speech translation models. It supports S2UT and UnitY models for bidirectional Hokkien and English translation. Please select the model and record the input to submit.
123
+ </p>
124
+ </div>
125
+ """
126
+ )
127
+ with gr.Group():
128
+ with gr.Box():
129
+ with gr.Row().style(mobile_collapse=False, equal_height=True):
130
+ audio = gr.Audio(
131
+ source="microphone", type="filepath", label="Input"
132
+ )
133
+
134
+ btn = gr.Button("Submit")
135
+ model = gr.Dropdown(choices=["xm_transformer_unity_en-hk", "xm_transformer_unity_hk-en", "xm_transformer_s2ut_en-hk", "xm_transformer_s2ut_hk-en"], value="xm_transformer_unity_en-hk",type="value", label="Model")
136
+ # model = gr.Dropdown(choices=["xm_transformer_unity_en-hk", "xm_transformer_unity_hk-en"], value="xm_transformer_unity_en-hk",type="value", label="Model")
137
+ out = gr.Audio(label="Output")
138
+
139
+ btn.click(inference, inputs=[audio, model], outputs=out)
140
+ gr.HTML('''
141
+ <div class="footer">
142
+ <p>Model by <a href="https://ai.facebook.com/" style="text-decoration: underline;" target="_blank">Meta AI</a>
143
+ </p>
144
+ </div>
145
+ ''')
146
+
147
+ block.launch()
audio_pipe.py ADDED
@@ -0,0 +1,161 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import os
3
+ from pathlib import Path
4
+ from typing import List, Tuple
5
+ import tempfile
6
+ import soundfile as sf
7
+ import gradio as gr
8
+
9
+ import numpy as np
10
+ import torch
11
+ import torchaudio
12
+ # from app.pipelines import Pipeline
13
+ from fairseq import hub_utils
14
+ from fairseq.checkpoint_utils import load_model_ensemble_and_task_from_hf_hub
15
+ from fairseq.models.speech_to_speech.hub_interface import S2SHubInterface
16
+ from fairseq.models.speech_to_text.hub_interface import S2THubInterface
17
+ from fairseq.models.text_to_speech import CodeHiFiGANVocoder
18
+ from fairseq.models.text_to_speech.hub_interface import (
19
+ TTSHubInterface,
20
+ VocoderHubInterface,
21
+ )
22
+ from huggingface_hub import snapshot_download
23
+
24
+ ARG_OVERRIDES_MAP = {
25
+ "facebook/xm_transformer_s2ut_800m-es-en-st-asr-bt_h1_2022": {
26
+ "config_yaml": "config.yaml",
27
+ "task": "speech_to_text",
28
+ }
29
+ }
30
+
31
+ class SpeechToSpeechPipeline():
32
+ def __init__(self, model_id: str):
33
+ arg_overrides = ARG_OVERRIDES_MAP.get(
34
+ model_id, {}
35
+ ) # Model specific override. TODO: Update on checkpoint side in the future
36
+ arg_overrides["config_yaml"] = "config.yaml" # common override
37
+ models, cfg, task = load_model_ensemble_and_task_from_hf_hub(
38
+ model_id,
39
+ arg_overrides=arg_overrides,
40
+ cache_dir=os.getenv("HUGGINGFACE_HUB_CACHE"),
41
+ )
42
+ self.cfg = cfg
43
+ self.model = models[0].cpu()
44
+ self.model.eval()
45
+ self.task = task
46
+
47
+ self.sampling_rate = getattr(self.task, "sr", None) or 16_000
48
+
49
+ tgt_lang = self.task.data_cfg.hub.get("tgt_lang", None)
50
+ pfx = f"{tgt_lang}_" if self.task.data_cfg.prepend_tgt_lang_tag else ""
51
+
52
+ generation_args = self.task.data_cfg.hub.get(f"{pfx}generation_args", None)
53
+ if generation_args is not None:
54
+ for key in generation_args:
55
+ setattr(cfg.generation, key, generation_args[key])
56
+ self.generator = task.build_generator([self.model], cfg.generation)
57
+
58
+ tts_model_id = self.task.data_cfg.hub.get(f"{pfx}tts_model_id", None)
59
+ self.unit_vocoder = self.task.data_cfg.hub.get(f"{pfx}unit_vocoder", None)
60
+ self.tts_model, self.tts_task, self.tts_generator = None, None, None
61
+ if tts_model_id is not None:
62
+ _id = tts_model_id.split(":")[-1]
63
+ cache_dir = os.getenv("HUGGINGFACE_HUB_CACHE")
64
+ if self.unit_vocoder is not None:
65
+ library_name = "fairseq"
66
+ cache_dir = (
67
+ cache_dir or (Path.home() / ".cache" / library_name).as_posix()
68
+ )
69
+ cache_dir = snapshot_download(
70
+ f"facebook/{_id}", cache_dir=cache_dir, library_name=library_name
71
+ )
72
+
73
+ x = hub_utils.from_pretrained(
74
+ cache_dir,
75
+ "model.pt",
76
+ ".",
77
+ archive_map=CodeHiFiGANVocoder.hub_models(),
78
+ config_yaml="config.json",
79
+ fp16=False,
80
+ is_vocoder=True,
81
+ )
82
+
83
+ with open(f"{x['args']['data']}/config.json") as f:
84
+ vocoder_cfg = json.load(f)
85
+ assert (
86
+ len(x["args"]["model_path"]) == 1
87
+ ), "Too many vocoder models in the input"
88
+
89
+ vocoder = CodeHiFiGANVocoder(x["args"]["model_path"][0], vocoder_cfg)
90
+ self.tts_model = VocoderHubInterface(vocoder_cfg, vocoder)
91
+
92
+ else:
93
+ (
94
+ tts_models,
95
+ tts_cfg,
96
+ self.tts_task,
97
+ ) = load_model_ensemble_and_task_from_hf_hub(
98
+ f"facebook/{_id}",
99
+ arg_overrides={"vocoder": "griffin_lim", "fp16": False},
100
+ cache_dir=cache_dir,
101
+ )
102
+ self.tts_model = tts_models[0].cpu()
103
+ self.tts_model.eval()
104
+ tts_cfg["task"].cpu = True
105
+ TTSHubInterface.update_cfg_with_data_cfg(
106
+ tts_cfg, self.tts_task.data_cfg
107
+ )
108
+ self.tts_generator = self.tts_task.build_generator(
109
+ [self.tts_model], tts_cfg
110
+ )
111
+
112
+ def __call__(self, inputs: str) -> Tuple[np.array, int, List[str]]:
113
+ """
114
+ Args:
115
+ inputs (:obj:`np.array`):
116
+ The raw waveform of audio received. By default sampled at `self.sampling_rate`.
117
+ The shape of this array is `T`, where `T` is the time axis
118
+ Return:
119
+ A :obj:`tuple` containing:
120
+ - :obj:`np.array`:
121
+ The return shape of the array must be `C'`x`T'`
122
+ - a :obj:`int`: the sampling rate as an int in Hz.
123
+ - a :obj:`List[str]`: the annotation for each out channel.
124
+ This can be the name of the instruments for audio source separation
125
+ or some annotation for speech enhancement. The length must be `C'`.
126
+ """
127
+ # _inputs = torch.from_numpy(inputs).unsqueeze(0)
128
+ # print(f"input: {inputs}")
129
+ # _inputs = torchaudio.load(inputs)
130
+ _inputs = inputs
131
+ sample, text = None, None
132
+ if self.cfg.task._name in ["speech_to_text", "speech_to_text_sharded"]:
133
+ sample = S2THubInterface.get_model_input(self.task, _inputs)
134
+ text = S2THubInterface.get_prediction(
135
+ self.task, self.model, self.generator, sample
136
+ )
137
+ elif self.cfg.task._name in ["speech_to_speech"]:
138
+ s2shubinerface = S2SHubInterface(self.cfg, self.task, self.model)
139
+ sample = s2shubinerface.get_model_input(self.task, _inputs)
140
+ text = S2SHubInterface.get_prediction(
141
+ self.task, self.model, self.generator, sample
142
+ )
143
+
144
+ wav, sr = np.zeros((0,)), self.sampling_rate
145
+ if self.unit_vocoder is not None:
146
+ tts_sample = self.tts_model.get_model_input(text)
147
+ wav, sr = self.tts_model.get_prediction(tts_sample)
148
+ text = ""
149
+ else:
150
+ tts_sample = TTSHubInterface.get_model_input(self.tts_task, text)
151
+ wav, sr = TTSHubInterface.get_prediction(
152
+ self.tts_task, self.tts_model, self.tts_generator, tts_sample
153
+ )
154
+ temp_file = ""
155
+ with tempfile.NamedTemporaryFile(suffix=".wav") as tmp_output_file:
156
+ sf.write(tmp_output_file, wav.detach().cpu().numpy(), sr)
157
+ tmp_output_file.seek(0)
158
+ temp_file = gr.Audio(tmp_output_file.name)
159
+
160
+ # return wav, sr, [text]
161
+ return temp_file
requirements.txt ADDED
@@ -0,0 +1,3 @@
 
 
 
1
+ api-inference-community==0.0.23
2
+ git+https://github.com/facebookresearch/fairseq.git@d47119871c2ac9a0a0aa2904dd8cfc1929b113d9#egg=fairseq
3
+ huggingface_hub==0.5.1