Imadsarvm committed
Commit 59f0369
Parent: 67933a0

Upload app (2).py

Files changed (1):
  1. app (2).py  +312 -0
app (2).py ADDED
#!/usr/bin/env python

import os
import pathlib
import tempfile

import gradio as gr
import torch
import torchaudio
from fairseq2.assets import InProcAssetMetadataProvider, asset_store
from fairseq2.data import Collater, SequenceData, VocabularyInfo
from fairseq2.data.audio import (
    AudioDecoder,
    WaveformToFbankConverter,
    WaveformToFbankOutput,
)
from fairseq2.generation import NGramRepeatBlockProcessor
from fairseq2.memory import MemoryBlock
from fairseq2.typing import DataType, Device
from huggingface_hub import snapshot_download
from seamless_communication.inference import BatchedSpeechOutput, Translator, SequenceGeneratorOptions
from seamless_communication.models.generator.loader import load_pretssel_vocoder_model
from seamless_communication.models.unity import (
    UnitTokenizer,
    load_gcmvn_stats,
    load_unity_text_tokenizer,
    load_unity_unit_tokenizer,
)
from torch.nn import Module
from seamless_communication.cli.expressivity.evaluate.pretssel_inference_helper import PretsselGenerator

from utils import LANGUAGE_CODE_TO_NAME
DESCRIPTION = """\
# Seamless Expressive

[SeamlessExpressive](https://github.com/facebookresearch/seamless_communication/blob/main/docs/expressive/README.md) is a speech-to-speech translation model that captures certain underexplored aspects of prosody, such as speech rate and pauses, while preserving the style of one's voice and high content translation quality. The model is also in use on the [SeamlessExpressive demo website](https://seamless.metademolab.com/expressive?utm_source=huggingface&utm_medium=web&utm_campaign=seamless&utm_content=expressivespace).
"""

CACHE_EXAMPLES = os.getenv("CACHE_EXAMPLES") == "1" and torch.cuda.is_available()

CHECKPOINTS_PATH = pathlib.Path(os.getenv("CHECKPOINTS_PATH", "/home/user/app/models"))
if not CHECKPOINTS_PATH.exists():
    snapshot_download(repo_id="facebook/seamless-expressive", repo_type="model", local_dir=CHECKPOINTS_PATH)
    snapshot_download(repo_id="facebook/seamless-m4t-v2-large", repo_type="model", local_dir=CHECKPOINTS_PATH)

# Ensure that we do not have any other environment resolvers and always return
# "demo" for demo purposes.
asset_store.env_resolvers.clear()
asset_store.env_resolvers.append(lambda: "demo")

# Construct an `InProcAssetMetadataProvider` with environment-specific metadata
# that overrides the regular metadata for the "demo" environment. Note the "@demo" suffix.
demo_metadata = [
    {
        "name": "seamless_expressivity@demo",
        "checkpoint": f"file://{CHECKPOINTS_PATH}/m2m_expressive_unity.pt",
        "char_tokenizer": f"file://{CHECKPOINTS_PATH}/spm_char_lang38_tc.model",
    },
    {
        "name": "vocoder_pretssel@demo",
        "checkpoint": f"file://{CHECKPOINTS_PATH}/pretssel_melhifigan_wm-final.pt",
    },
    {
        "name": "seamlessM4T_v2_large@demo",
        "checkpoint": f"file://{CHECKPOINTS_PATH}/seamlessM4T_v2_large.pt",
        "char_tokenizer": f"file://{CHECKPOINTS_PATH}/spm_char_lang38_tc.model",
    },
]

asset_store.metadata_providers.append(InProcAssetMetadataProvider(demo_metadata))

LANGUAGE_NAME_TO_CODE = {v: k for k, v in LANGUAGE_CODE_TO_NAME.items()}

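# Prefer fp16 on GPU for speed and memory; fall back to fp32 on CPU.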
if torch.cuda.is_available():
    device = torch.device("cuda:0")
    dtype = torch.float16
else:
    device = torch.device("cpu")
    dtype = torch.float32

MODEL_NAME = "seamless_expressivity"
VOCODER_NAME = "vocoder_pretssel"

# Used for ASR on the source audio; the transcript feeds the toxicity (mintox) check in `run`.
m4t_translator = Translator(
    model_name_or_card="seamlessM4T_v2_large",
    vocoder_name_or_card=None,
    device=device,
    dtype=dtype,
)
unit_tokenizer = load_unity_unit_tokenizer(MODEL_NAME)

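# Global CMVN statistics shipped with the vocoder card; normalize_fbank below uses
# them to build the feature stream consumed by the prosody encoder.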
_gcmvn_mean, _gcmvn_std = load_gcmvn_stats(VOCODER_NAME)
gcmvn_mean = torch.tensor(_gcmvn_mean, device=device, dtype=dtype)
gcmvn_std = torch.tensor(_gcmvn_std, device=device, dtype=dtype)

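# The expressive S2ST model itself. No vocoder card is attached; predicted units
# are rendered to audio by the PretsselGenerator below.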
translator = Translator(
    MODEL_NAME,
    vocoder_name_or_card=None,
    device=device,
    dtype=dtype,
    apply_mintox=False,
)

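# Beam-search settings for both models: beam of 5, <unk> banned outright,
# output length capped at 200 tokens, and repeated 10-grams blocked.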
text_generation_opts = SequenceGeneratorOptions(
    beam_size=5,
    unk_penalty=torch.inf,
    soft_max_seq_len=(0, 200),
    step_processor=NGramRepeatBlockProcessor(
        ngram_size=10,
    ),
)
m4t_text_generation_opts = SequenceGeneratorOptions(
    beam_size=5,
    unk_penalty=torch.inf,
    soft_max_seq_len=(1, 200),
    step_processor=NGramRepeatBlockProcessor(
        ngram_size=10,
    ),
)

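# Unit-to-speech generator built on the PRETSSEL expressive vocoder.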
pretssel_generator = PretsselGenerator(
    VOCODER_NAME,
    vocab_info=unit_tokenizer.vocab_info,
    device=device,
    dtype=dtype,
)

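# Audio front end: decode the input file, then extract 80-bin mel filterbank features.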
decode_audio = AudioDecoder(dtype=torch.float32, device=device)

convert_to_fbank = WaveformToFbankConverter(
    num_mel_bins=80,
    waveform_scale=2**15,
    channel_last=True,
    standardize=False,
    device=device,
    dtype=dtype,
)


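# Build two feature streams from one utterance: per-utterance mean/variance
# normalization for the translator ("fbank") and global CMVN for the prosody
# encoder ("gcmvn_fbank").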
def normalize_fbank(data: WaveformToFbankOutput) -> WaveformToFbankOutput:
    fbank = data["fbank"]
    std, mean = torch.std_mean(fbank, dim=0)
    data["fbank"] = fbank.subtract(mean).divide(std)
    data["gcmvn_fbank"] = fbank.subtract(gcmvn_mean).divide(gcmvn_std)
    return data


collate = Collater(pad_value=0, pad_to_multiple=1)

AUDIO_SAMPLE_RATE = 16000
MAX_INPUT_AUDIO_LENGTH = 10  # in seconds


def remove_prosody_tokens_from_text(text: str) -> str:
    # Filter out prosody tokens; the only ones are emphasis '*' and pause '='.
    text = text.replace("*", "").replace("=", "")
    text = " ".join(text.split())
    return text


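# Resample to 16 kHz and truncate anything longer than MAX_INPUT_AUDIO_LENGTH,
# overwriting the input file in place.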
def preprocess_audio(input_audio_path: str) -> None:
    arr, org_sr = torchaudio.load(input_audio_path)
    new_arr = torchaudio.functional.resample(arr, orig_freq=org_sr, new_freq=AUDIO_SAMPLE_RATE)
    max_length = int(MAX_INPUT_AUDIO_LENGTH * AUDIO_SAMPLE_RATE)
    if new_arr.shape[1] > max_length:
        new_arr = new_arr[:, :max_length]
        gr.Warning(f"Input audio is too long. Only the first {MAX_INPUT_AUDIO_LENGTH} seconds is used.")
    torchaudio.save(input_audio_path, new_arr, sample_rate=AUDIO_SAMPLE_RATE)


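# End-to-end pipeline: decode audio -> fbank features -> M4T transcription of the
# source (for the toxicity check) -> expressive S2ST -> unit-to-speech rendering.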
def run(
    input_audio_path: str,
    source_language: str,
    target_language: str,
) -> tuple[str, str]:
    target_language_code = LANGUAGE_NAME_TO_CODE[target_language]
    source_language_code = LANGUAGE_NAME_TO_CODE[source_language]

    preprocess_audio(input_audio_path)

    with pathlib.Path(input_audio_path).open("rb") as fb:
        block = MemoryBlock(fb.read())
        example = decode_audio(block)

    example = convert_to_fbank(example)
    example = normalize_fbank(example)
    example = collate(example)

    # get transcription for mintox
    source_sentences, _ = m4t_translator.predict(
        input=example["fbank"],
        task_str="S2TT",  # get source text
        tgt_lang=source_language_code,
        text_generation_opts=m4t_text_generation_opts,
    )
    source_text = str(source_sentences[0])

    prosody_encoder_input = example["gcmvn_fbank"]
    text_output, unit_output = translator.predict(
        example["fbank"],
        "S2ST",
        tgt_lang=target_language_code,
        src_lang=source_language_code,
        text_generation_opts=text_generation_opts,
        unit_generation_ngram_filtering=False,
        duration_factor=1.0,
        prosody_encoder_input=prosody_encoder_input,
        src_text=source_text,  # for mintox check
    )
    speech_output = pretssel_generator.predict(
        unit_output.units,
        tgt_lang=target_language_code,
        prosody_encoder_input=prosody_encoder_input,
    )

    with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as f:
        torchaudio.save(
            f.name,
            speech_output.audio_wavs[0][0].to(torch.float32).cpu(),
            sample_rate=speech_output.sample_rate,
        )

    text_out = remove_prosody_tokens_from_text(str(text_output[0]))

    return f.name, text_out


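# Language pairs exposed in the demo UI.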
TARGET_LANGUAGE_NAMES = [
    "English",
    "French",
    "German",
    "Spanish",
]

UPDATED_LANGUAGE_LIST = {
    "English": ["French", "German", "Spanish"],
    "French": ["English", "German", "Spanish"],
    "German": ["English", "French", "Spanish"],
    "Spanish": ["English", "French", "German"],
}

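# When the source language changes, refresh the target dropdown so the newly
# selected source language is not offered as a target.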
def rs_change(rs: str) -> dict:
    return gr.update(
        choices=UPDATED_LANGUAGE_LIST[rs],
        value=UPDATED_LANGUAGE_LIST[rs][0],
    )

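# Gradio UI: audio input and language selectors on the left, translated speech
# and text on the right.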
with gr.Blocks(css="style.css") as demo:
    gr.Markdown(DESCRIPTION)
    gr.DuplicateButton(
        value="Duplicate Space for private use",
        elem_id="duplicate-button",
        visible=os.getenv("SHOW_DUPLICATE_BUTTON") == "1",
    )
    with gr.Row():
        with gr.Column():
            with gr.Group():
                input_audio = gr.Audio(label="Input speech", type="filepath")
                source_language = gr.Dropdown(
                    label="Source language",
                    choices=TARGET_LANGUAGE_NAMES,
                    value="English",
                )
                target_language = gr.Dropdown(
                    label="Target language",
                    choices=TARGET_LANGUAGE_NAMES,
                    value="French",
                    interactive=True,
                )
                source_language.change(
                    fn=rs_change,
                    inputs=[source_language],
                    outputs=[target_language],
                )

            btn = gr.Button()
        with gr.Column():
            with gr.Group():
                output_audio = gr.Audio(label="Translated speech")
                output_text = gr.Textbox(label="Translated text")

    gr.Examples(
        examples=[
            ["assets/Excited-English.wav", "English", "Spanish"],
            ["assets/Whisper-English.wav", "English", "German"],
            ["assets/FastTalking-French.wav", "French", "English"],
            ["assets/Sad-English.wav", "English", "Spanish"],
        ],
        inputs=[input_audio, source_language, target_language],
        outputs=[output_audio, output_text],
        fn=run,
        cache_examples=CACHE_EXAMPLES,
        api_name=False,
    )

    btn.click(
        fn=run,
        inputs=[input_audio, source_language, target_language],
        outputs=[output_audio, output_text],
        api_name="run",
    )

if __name__ == "__main__":
    demo.queue(max_size=50).launch()