#!/usr/bin/env python

import os
import pathlib
import tempfile

import gradio as gr
import torch
import torchaudio
from fairseq2.assets import InProcAssetMetadataProvider, asset_store
from fairseq2.data import Collater, SequenceData, VocabularyInfo
from fairseq2.data.audio import (
    AudioDecoder,
    WaveformToFbankConverter,
    WaveformToFbankOutput,
)
from fairseq2.generation import NGramRepeatBlockProcessor
from fairseq2.memory import MemoryBlock
from fairseq2.typing import DataType, Device
from huggingface_hub import snapshot_download
# SequenceGeneratorOptions used to be imported twice from this module; one import suffices.
from seamless_communication.inference import BatchedSpeechOutput, Translator, SequenceGeneratorOptions
from seamless_communication.models.generator.loader import load_pretssel_vocoder_model
from seamless_communication.models.unity import (
    UnitTokenizer,
    load_gcmvn_stats,
    load_unity_text_tokenizer,
    load_unity_unit_tokenizer,
)
from torch.nn import Module
from seamless_communication.cli.expressivity.evaluate.pretssel_inference_helper import PretsselGenerator

from utils import LANGUAGE_CODE_TO_NAME

DESCRIPTION = """\
# Seamless Expressive

[SeamlessExpressive](https://github.com/facebookresearch/seamless_communication/blob/main/docs/expressive/README.md) is a speech-to-speech translation model that captures certain underexplored aspects of prosody such as speech rate and pauses, while preserving the style of one's voice and high content translation quality.
"""

# Example caching is only enabled on request, and only when a GPU is present.
CACHE_EXAMPLES = os.getenv("CACHE_EXAMPLES") == "1" and torch.cuda.is_available()

CHECKPOINTS_PATH = pathlib.Path(os.getenv("CHECKPOINTS_PATH", "/home/user/app/models"))

# Hugging Face model repos to mirror into CHECKPOINTS_PATH, each keyed to the
# main checkpoint file it provides. The file (not just the directory) is
# checked so that a partially downloaded directory is completed on restart.
_CHECKPOINT_REPOS = {
    "facebook/seamless-expressive": "m2m_expressive_unity.pt",
    "facebook/seamless-m4t-v2-large": "seamlessM4T_v2_large.pt",
}


def _download_missing_checkpoints(checkpoints_path: pathlib.Path) -> None:
    """Download every repo in ``_CHECKPOINT_REPOS`` whose marker file is missing.

    Args:
        checkpoints_path: Local directory the snapshots are written into.
    """
    for repo_id, marker_file in _CHECKPOINT_REPOS.items():
        if not (checkpoints_path / marker_file).exists():
            snapshot_download(repo_id=repo_id, repo_type="model", local_dir=checkpoints_path)


_download_missing_checkpoints(CHECKPOINTS_PATH)

# Ensure that we do not have any other environment resolvers and always return
# "demo" for demo purposes.
# Force the asset store to always resolve the "demo" environment.
asset_store.env_resolvers.clear()
asset_store.env_resolvers.append(lambda: "demo")


def _checkpoint_uri(filename: str) -> str:
    """Return the ``file://`` URI of *filename* inside ``CHECKPOINTS_PATH``."""
    return f"file://{CHECKPOINTS_PATH}/{filename}"


# Environment-specific asset cards that override the regular metadata for the
# "demo" environment (note the "@demo" suffix on every name). The expressive
# model card and the SeamlessM4T v2 card point at the same char tokenizer file.
_char_tokenizer_uri = _checkpoint_uri("spm_char_lang38_tc.model")
demo_metadata = [
    dict(
        name="seamless_expressivity@demo",
        checkpoint=_checkpoint_uri("m2m_expressive_unity.pt"),
        char_tokenizer=_char_tokenizer_uri,
    ),
    dict(
        name="vocoder_pretssel@demo",
        checkpoint=_checkpoint_uri("pretssel_melhifigan_wm-final.pt"),
    ),
    dict(
        name="seamlessM4T_v2_large@demo",
        checkpoint=_checkpoint_uri("seamlessM4T_v2_large.pt"),
        char_tokenizer=_char_tokenizer_uri,
    ),
]
asset_store.metadata_providers.append(InProcAssetMetadataProvider(demo_metadata))

# Reverse lookup: human-readable language name -> language code.
LANGUAGE_NAME_TO_CODE = {name: code for code, name in LANGUAGE_CODE_TO_NAME.items()}

# fp16 on the first CUDA device when available, otherwise fp32 on CPU.
_on_gpu = torch.cuda.is_available()
device = torch.device("cuda:0" if _on_gpu else "cpu")
dtype = torch.float16 if _on_gpu else torch.float32

MODEL_NAME = "seamless_expressivity"
VOCODER_NAME = "vocoder_pretssel"

# used for ASR for toxicity
m4t_translator = Translator(
    model_name_or_card="seamlessM4T_v2_large",
    vocoder_name_or_card=None,
    device=device,
    dtype=dtype,
)

unit_tokenizer = load_unity_unit_tokenizer(MODEL_NAME)

# Global CMVN statistics of the vocoder, moved onto the inference device/dtype.
_gcmvn_mean, _gcmvn_std = load_gcmvn_stats(VOCODER_NAME)
gcmvn_mean, gcmvn_std = (
    torch.tensor(stat, device=device, dtype=dtype) for stat in (_gcmvn_mean, _gcmvn_std)
)

translator = Translator(
    MODEL_NAME,
    vocoder_name_or_card=None,
    device=device,
    dtype=dtype,
    apply_mintox=False,
)


def _beam_search_opts(soft_max_seq_len_a: int) -> SequenceGeneratorOptions:
    """Build the shared beam-search options used by both translators.

    Args:
        soft_max_seq_len_a: First element of the ``soft_max_seq_len`` pair; the
            second element is fixed at 200. (Presumably the fairseq2 convention
            ``max_len = a * source_len + b`` — confirm against fairseq2 docs.)
    """
    return SequenceGeneratorOptions(
        beam_size=5,
        unk_penalty=torch.inf,
        soft_max_seq_len=(soft_max_seq_len_a, 200),
        step_processor=NGramRepeatBlockProcessor(ngram_size=10),
    )


text_generation_opts = _beam_search_opts(0)
m4t_text_generation_opts = _beam_search_opts(1)
pretssel_generator = PretsselGenerator(
    VOCODER_NAME,
    vocab_info=unit_tokenizer.vocab_info,
    device=device,
    dtype=dtype,
)

decode_audio = AudioDecoder(dtype=torch.float32, device=device)

convert_to_fbank = WaveformToFbankConverter(
    num_mel_bins=80,
    waveform_scale=2**15,
    channel_last=True,
    standardize=False,
    device=device,
    dtype=dtype,
)


def normalize_fbank(data: WaveformToFbankOutput) -> WaveformToFbankOutput:
    """Attach two normalized versions of the filterbank features to *data*.

    Mutates and returns *data*:

    * ``data["fbank"]`` is replaced by the features standardized with their own
      mean/std computed over dim 0;
    * ``data["gcmvn_fbank"]`` is set to the *raw* features normalized with the
      global ``gcmvn_mean`` / ``gcmvn_std`` statistics.
    """
    fbank = data["fbank"]
    std, mean = torch.std_mean(fbank, dim=0)
    data["fbank"] = fbank.subtract(mean).divide(std)
    # Uses the un-standardized `fbank`, not the value just written above.
    data["gcmvn_fbank"] = fbank.subtract(gcmvn_mean).divide(gcmvn_std)
    return data


collate = Collater(pad_value=0, pad_to_multiple=1)


AUDIO_SAMPLE_RATE = 16000
MAX_INPUT_AUDIO_LENGTH = 10  # in seconds


def remove_prosody_tokens_from_text(text: str) -> str:
    """Strip prosody markers from *text* and collapse runs of whitespace."""
    # filter out prosody tokens, there is only emphasis '*', and pause '='
    text = text.replace("*", "").replace("=", "")
    text = " ".join(text.split())
    return text


def preprocess_audio(input_audio_path: str) -> None:
    """Resample and truncate the audio file at *input_audio_path* in place.

    The audio is resampled to ``AUDIO_SAMPLE_RATE``; if it is longer than
    ``MAX_INPUT_AUDIO_LENGTH`` seconds it is truncated and a UI warning is
    shown. The result overwrites the original file.
    """
    arr, org_sr = torchaudio.load(input_audio_path)
    new_arr = torchaudio.functional.resample(arr, orig_freq=org_sr, new_freq=AUDIO_SAMPLE_RATE)
    max_length = int(MAX_INPUT_AUDIO_LENGTH * AUDIO_SAMPLE_RATE)
    # Axis 1 is the frame axis (the slice below keeps every row of axis 0).
    if new_arr.shape[1] > max_length:
        new_arr = new_arr[:, :max_length]
        gr.Warning(f"Input audio is too long. Only the first {MAX_INPUT_AUDIO_LENGTH} seconds is used.")
    torchaudio.save(input_audio_path, new_arr, sample_rate=AUDIO_SAMPLE_RATE)


def run(
    input_audio_path: str,
    source_language: str,
    target_language: str,
) -> tuple[str, str]:
    """Translate the speech at *input_audio_path* with SeamlessExpressive.

    Args:
        input_audio_path: Path to the uploaded audio file (``None`` when the
            user submitted without audio). The file is rewritten in place by
            ``preprocess_audio``.
        source_language: Display name of the spoken language.
        target_language: Display name of the language to translate into.

    Returns:
        ``(wav_path, text)``: the path of a temporary ``.wav`` file holding the
        translated speech, and the translated text with prosody tokens removed.

    Raises:
        gr.Error: if no input audio was provided.
    """
    if not input_audio_path:
        raise gr.Error("Please provide input speech.")

    target_language_code = LANGUAGE_NAME_TO_CODE[target_language]
    source_language_code = LANGUAGE_NAME_TO_CODE[source_language]

    preprocess_audio(input_audio_path)

    block = MemoryBlock(pathlib.Path(input_audio_path).read_bytes())
    example = decode_audio(block)
    example = convert_to_fbank(example)
    example = normalize_fbank(example)
    example = collate(example)

    # get transcription for mintox
    # NOTE(review): `translator` is constructed with apply_mintox=False, so this
    # transcription may be unused; confirm before removing the extra model call.
    source_sentences, _ = m4t_translator.predict(
        input=example["fbank"],
        task_str="S2TT",  # get source text
        tgt_lang=source_language_code,
        text_generation_opts=m4t_text_generation_opts,
    )
    source_text = str(source_sentences[0])

    prosody_encoder_input = example["gcmvn_fbank"]
    text_output, unit_output = translator.predict(
        example["fbank"],
        "S2ST",
        tgt_lang=target_language_code,
        src_lang=source_language_code,
        text_generation_opts=text_generation_opts,
        unit_generation_ngram_filtering=False,
        duration_factor=1.0,
        prosody_encoder_input=prosody_encoder_input,
        src_text=source_text,  # for mintox check
    )
    speech_output = pretssel_generator.predict(
        unit_output.units,
        tgt_lang=target_language_code,
        prosody_encoder_input=prosody_encoder_input,
    )

    # delete=False: the file must outlive this function so Gradio can serve it.
    with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as f:
        torchaudio.save(
            f.name,
            speech_output.audio_wavs[0][0].to(torch.float32).cpu(),
            sample_rate=speech_output.sample_rate,
        )

    text_out = remove_prosody_tokens_from_text(str(text_output[0]))

    return f.name, text_out


TARGET_LANGUAGE_NAMES = [
    "English",
    "French",
    "German",
    "Spanish",
]

# For each source language, the target languages offered (every other one).
UPDATED_LANGUAGE_LIST = {
    "English": ["French", "German", "Spanish"],
    "French": ["English", "German", "Spanish"],
    "German": ["English", "French", "Spanish"],
    "Spanish": ["English", "French", "German"],
}


def rs_change(rs, current_target=None):
    """Restrict the target-language dropdown to languages other than *rs*.

    Args:
        rs: The newly selected source language name.
        current_target: The target language currently selected, if known. It
            is kept when still allowed; this matters because the source
            dropdown's ``change`` event also fires when an example fills both
            dropdowns at once, and resetting the target would overwrite the
            example's target language.

    Returns:
        A ``gr.update`` for the target dropdown.
    """
    choices = UPDATED_LANGUAGE_LIST[rs]
    value = current_target if current_target in choices else choices[0]
    return gr.update(
        choices=choices,
        value=value,
    )


with gr.Blocks(css="style.css") as demo:
    gr.Markdown(DESCRIPTION)
    gr.DuplicateButton(
        value="Duplicate Space for private use",
        elem_id="duplicate-button",
        visible=os.getenv("SHOW_DUPLICATE_BUTTON") == "1",
    )
    with gr.Row():
        with gr.Column():
            with gr.Group():
                input_audio = gr.Audio(label="Input speech", type="filepath")
                source_language = gr.Dropdown(
                    label="Source language",
                    choices=TARGET_LANGUAGE_NAMES,
                    value="English",
                )
                target_language = gr.Dropdown(
                    label="Target language",
                    choices=TARGET_LANGUAGE_NAMES,
                    value="French",
                    interactive=True,
                )
                # Pass the current target too so a still-valid selection survives.
                source_language.change(
                    fn=rs_change,
                    inputs=[source_language, target_language],
                    outputs=[target_language],
                )
            btn = gr.Button()
        with gr.Column():
            with gr.Group():
                output_audio = gr.Audio(label="Translated speech")
                output_text = gr.Textbox(label="Translated text")

    gr.Examples(
        examples=[
            ["assets/Excited-Es.wav", "English", "Spanish"],
            ["assets/FastTalking-En.wav", "French", "English"],
            ["assets/Sad-Es.wav", "English", "Spanish"],
        ],
        inputs=[input_audio, source_language, target_language],
        outputs=[output_audio, output_text],
        fn=run,
        cache_examples=CACHE_EXAMPLES,
        api_name=False,
    )

    btn.click(
        fn=run,
        inputs=[input_audio, source_language, target_language],
        outputs=[output_audio, output_text],
        api_name="run",
    )

if __name__ == "__main__":
    demo.queue(max_size=50).launch()