#!/usr/bin/env python

import os
import pathlib
import tempfile

import gradio as gr
import torch
import torchaudio
from fairseq2.assets import InProcAssetMetadataProvider, asset_store
from fairseq2.data import Collater, SequenceData, VocabularyInfo
from fairseq2.data.audio import (
    AudioDecoder,
    WaveformToFbankConverter,
    WaveformToFbankOutput,
)
from fairseq2.generation import NGramRepeatBlockProcessor
from fairseq2.memory import MemoryBlock
from fairseq2.typing import DataType, Device
from huggingface_hub import snapshot_download
from seamless_communication.cli.expressivity.evaluate.pretssel_inference_helper import (
    PretsselGenerator,
)
from seamless_communication.inference import (
    BatchedSpeechOutput,
    SequenceGeneratorOptions,
    Translator,
)
from seamless_communication.models.generator.loader import load_pretssel_vocoder_model
from seamless_communication.models.unity import (
    UnitTokenizer,
    load_gcmvn_stats,
    load_unity_text_tokenizer,
    load_unity_unit_tokenizer,
)
from torch.nn import Module

from utils import LANGUAGE_CODE_TO_NAME

DESCRIPTION = """\
# Seamless Expressive

[SeamlessExpressive](https://github.com/facebookresearch/seamless_communication/blob/main/docs/expressive/README.md) is a speech-to-speech translation model that captures certain underexplored aspects of prosody, such as speech rate and pauses, while preserving the style of one's voice and maintaining high content translation quality.

The model is also in use on the [SeamlessExpressive demo website](https://seamless.metademolab.com/expressive?utm_source=huggingface&utm_medium=web&utm_campaign=seamless&utm_content=expressivespace).
"""

CACHE_EXAMPLES = os.getenv("CACHE_EXAMPLES") == "1" and torch.cuda.is_available()

CHECKPOINTS_PATH = pathlib.Path(os.getenv("CHECKPOINTS_PATH", "/home/user/app/models"))
if not CHECKPOINTS_PATH.exists():
    snapshot_download(
        repo_id="facebook/seamless-expressive",
        repo_type="model",
        local_dir=CHECKPOINTS_PATH,
    )
    snapshot_download(
        repo_id="facebook/seamless-m4t-v2-large",
        repo_type="model",
        local_dir=CHECKPOINTS_PATH,
    )

# Ensure that we do not have any other environment resolvers and always return
# "demo" for demo purposes.
asset_store.env_resolvers.clear()
asset_store.env_resolvers.append(lambda: "demo")

# Construct an `InProcAssetMetadataProvider` with environment-specific metadata
# that just overrides the regular metadata for the "demo" environment. Note the
# "@demo" suffix.
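# Each entry below points the asset store at the checkpoints downloaded into
# CHECKPOINTS_PATH above; the "file://" URIs make fairseq2 load them from local
# disk instead of fetching the default assets.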
demo_metadata = [
    {
        "name": "seamless_expressivity@demo",
        "checkpoint": f"file://{CHECKPOINTS_PATH}/m2m_expressive_unity.pt",
        "char_tokenizer": f"file://{CHECKPOINTS_PATH}/spm_char_lang38_tc.model",
    },
    {
        "name": "vocoder_pretssel@demo",
        "checkpoint": f"file://{CHECKPOINTS_PATH}/pretssel_melhifigan_wm-final.pt",
    },
    {
        "name": "seamlessM4T_v2_large@demo",
        "checkpoint": f"file://{CHECKPOINTS_PATH}/seamlessM4T_v2_large.pt",
        "char_tokenizer": f"file://{CHECKPOINTS_PATH}/spm_char_lang38_tc.model",
    },
]

asset_store.metadata_providers.append(InProcAssetMetadataProvider(demo_metadata))

LANGUAGE_NAME_TO_CODE = {v: k for k, v in LANGUAGE_CODE_TO_NAME.items()}

if torch.cuda.is_available():
    device = torch.device("cuda:0")
    dtype = torch.float16
else:
    device = torch.device("cpu")
    dtype = torch.float32

MODEL_NAME = "seamless_expressivity"
VOCODER_NAME = "vocoder_pretssel"

# Used to transcribe the source audio (ASR) for the toxicity (mintox) check.
m4t_translator = Translator(
    model_name_or_card="seamlessM4T_v2_large",
    vocoder_name_or_card=None,
    device=device,
    dtype=dtype,
)
unit_tokenizer = load_unity_unit_tokenizer(MODEL_NAME)

_gcmvn_mean, _gcmvn_std = load_gcmvn_stats(VOCODER_NAME)
gcmvn_mean = torch.tensor(_gcmvn_mean, device=device, dtype=dtype)
gcmvn_std = torch.tensor(_gcmvn_std, device=device, dtype=dtype)

translator = Translator(
    MODEL_NAME,
    vocoder_name_or_card=None,
    device=device,
    dtype=dtype,
    apply_mintox=False,
)

text_generation_opts = SequenceGeneratorOptions(
    beam_size=5,
    unk_penalty=torch.inf,
    soft_max_seq_len=(0, 200),
    step_processor=NGramRepeatBlockProcessor(
        ngram_size=10,
    ),
)
m4t_text_generation_opts = SequenceGeneratorOptions(
    beam_size=5,
    unk_penalty=torch.inf,
    soft_max_seq_len=(1, 200),
    step_processor=NGramRepeatBlockProcessor(
        ngram_size=10,
    ),
)

pretssel_generator = PretsselGenerator(
    VOCODER_NAME,
    vocab_info=unit_tokenizer.vocab_info,
    device=device,
    dtype=dtype,
)

decode_audio = AudioDecoder(dtype=torch.float32, device=device)

convert_to_fbank = WaveformToFbankConverter(
    num_mel_bins=80,
    waveform_scale=2**15,
    channel_last=True,
    standardize=False,
    device=device,
    dtype=dtype,
)


def normalize_fbank(data: WaveformToFbankOutput) -> WaveformToFbankOutput:
    fbank = data["fbank"]
    std, mean = torch.std_mean(fbank, dim=0)
    data["fbank"] = fbank.subtract(mean).divide(std)
    data["gcmvn_fbank"] = fbank.subtract(gcmvn_mean).divide(gcmvn_std)
    return data


collate = Collater(pad_value=0, pad_to_multiple=1)

AUDIO_SAMPLE_RATE = 16000
MAX_INPUT_AUDIO_LENGTH = 10  # in seconds


def remove_prosody_tokens_from_text(text):
    # Filter out prosody tokens; there is only emphasis '*' and pause '='.
    text = text.replace("*", "").replace("=", "")
    text = " ".join(text.split())
    return text


def preprocess_audio(input_audio_path: str) -> None:
    arr, org_sr = torchaudio.load(input_audio_path)
    new_arr = torchaudio.functional.resample(arr, orig_freq=org_sr, new_freq=AUDIO_SAMPLE_RATE)
    max_length = int(MAX_INPUT_AUDIO_LENGTH * AUDIO_SAMPLE_RATE)
    if new_arr.shape[1] > max_length:
        new_arr = new_arr[:, :max_length]
        gr.Warning(f"Input audio is too long. Only the first {MAX_INPUT_AUDIO_LENGTH} seconds is used.")
    torchaudio.save(input_audio_path, new_arr, sample_rate=AUDIO_SAMPLE_RATE)


def run(
    input_audio_path: str,
    source_language: str,
    target_language: str,
) -> tuple[str, str]:
    target_language_code = LANGUAGE_NAME_TO_CODE[target_language]
    source_language_code = LANGUAGE_NAME_TO_CODE[source_language]

    preprocess_audio(input_audio_path)

    with pathlib.Path(input_audio_path).open("rb") as fb:
        block = MemoryBlock(fb.read())
    example = decode_audio(block)
    example = convert_to_fbank(example)
    example = normalize_fbank(example)
    example = collate(example)

    # get transcription for mintox
    source_sentences, _ = m4t_translator.predict(
        input=example["fbank"],
        task_str="S2TT",  # get source text
        tgt_lang=source_language_code,
        text_generation_opts=m4t_text_generation_opts,
    )
    source_text = str(source_sentences[0])

    prosody_encoder_input = example["gcmvn_fbank"]
    text_output, unit_output = translator.predict(
        example["fbank"],
        "S2ST",
        tgt_lang=target_language_code,
        src_lang=source_language_code,
        text_generation_opts=text_generation_opts,
        unit_generation_ngram_filtering=False,
        duration_factor=1.0,
        prosody_encoder_input=prosody_encoder_input,
        src_text=source_text,  # for mintox check
    )
    speech_output = pretssel_generator.predict(
        unit_output.units,
        tgt_lang=target_language_code,
        prosody_encoder_input=prosody_encoder_input,
    )

    with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as f:
        torchaudio.save(
            f.name,
            speech_output.audio_wavs[0][0].to(torch.float32).cpu(),
            sample_rate=speech_output.sample_rate,
        )

    text_out = remove_prosody_tokens_from_text(str(text_output[0]))

    return f.name, text_out


TARGET_LANGUAGE_NAMES = [
    "English",
    "French",
    "German",
    "Spanish",
]

UPDATED_LANGUAGE_LIST = {
    "English": ["French", "German", "Spanish"],
    "French": ["English", "German", "Spanish"],
    "German": ["English", "French", "Spanish"],
    "Spanish": ["English", "French", "German"],
}


def rs_change(rs):
    return gr.update(
        choices=UPDATED_LANGUAGE_LIST[rs],
        value=UPDATED_LANGUAGE_LIST[rs][0],
    )


with gr.Blocks(css="style.css") as demo:
    gr.Markdown(DESCRIPTION)
    gr.DuplicateButton(
        value="Duplicate Space for private use",
        elem_id="duplicate-button",
        visible=os.getenv("SHOW_DUPLICATE_BUTTON") == "1",
    )
    with gr.Row():
        with gr.Column():
            with gr.Group():
                input_audio = gr.Audio(label="Input speech", type="filepath")
                source_language = gr.Dropdown(
                    label="Source language",
                    choices=TARGET_LANGUAGE_NAMES,
                    value="English",
                )
                target_language = gr.Dropdown(
                    label="Target language",
                    choices=TARGET_LANGUAGE_NAMES,
                    value="French",
                    interactive=True,
                )
                source_language.change(
                    fn=rs_change,
                    inputs=[source_language],
                    outputs=[target_language],
                )
            btn = gr.Button()
        with gr.Column():
            with gr.Group():
                output_audio = gr.Audio(label="Translated speech")
                output_text = gr.Textbox(label="Translated text")

    gr.Examples(
        examples=[
            ["assets/Excited-English.wav", "English", "Spanish"],
            ["assets/Whisper-English.wav", "English", "German"],
            ["assets/FastTalking-French.wav", "French", "English"],
            ["assets/Sad-English.wav", "English", "Spanish"],
        ],
        inputs=[input_audio, source_language, target_language],
        outputs=[output_audio, output_text],
        fn=run,
        cache_examples=CACHE_EXAMPLES,
        api_name=False,
    )

    btn.click(
        fn=run,
        inputs=[input_audio, source_language, target_language],
        outputs=[output_audio, output_text],
        api_name="run",
    )

if __name__ == "__main__":
    demo.queue(max_size=50).launch()