juangtzi's picture
Update app.py
f9ef521 verified
raw
history blame
5.29 kB
import gradio as gr
import numpy as np
import torch
from transformers import pipeline, VitsModel, AutoTokenizer, AutoTokenizer
from transformers import SpeechT5ForTextToSpeech, SpeechT5HifiGan, SpeechT5Processor
# Run on the first CUDA GPU when one is visible to torch; otherwise on the CPU.
if torch.cuda.is_available():
    device = "cuda:0"
else:
    device = "cpu"

# Maps a source-language code to the Helsinki-NLP OPUS-MT checkpoint that
# translates from that language into Spanish ("opus-mt-<code>-es").
translation_models = {
    code: f"Helsinki-NLP/opus-mt-{code}-es"
    for code in (
        "en",  # English
        "fr",  # French
        "de",  # German
        "it",  # Italian
        "pt",  # Portuguese
        "nl",  # Dutch
        "fi",  # Finnish
        "sv",  # Swedish
        "da",  # Danish
        "no",  # Norwegian
        "ru",  # Russian
        "pl",  # Polish
        "cs",  # Czech
        "tr",  # Turkish
        "zh",  # Chinese
        "ja",  # Japanese
        "ar",  # Arabic
        "ro",  # Romanian
        "el",  # Greek
        "bg",  # Bulgarian
        "uk",  # Ukrainian
        "he",  # Hebrew
        "lt",  # Lithuanian
        "et",  # Estonian
        "hr",  # Croatian
        "hu",  # Hungarian
        "lv",  # Latvian
        "sl",  # Slovenian
        "sk",  # Slovak
        "sr",  # Serbian
        "fa",  # Persian
    )
}
# Speech recognition pipeline (Whisper base), placed on the selected device.
asr_pipe = pipeline(
    "automatic-speech-recognition",
    model="openai/whisper-base",
    device=device,
)

# Text-to-speech stack: SpeechT5 acoustic model, its processor and the HiFi-GAN
# vocoder. An alternative backend (VitsModel / AutoTokenizer loaded from
# "facebook/mms-tts-spa") is left disabled.
model = SpeechT5ForTextToSpeech.from_pretrained(
    "juangtzi/speecht5_finetuned_voxpopuli_es"
)
checkpoint = "microsoft/speecht5_tts"
processor = SpeechT5Processor.from_pretrained(checkpoint)
vocoder = SpeechT5HifiGan.from_pretrained("microsoft/speecht5_hifigan")

# Speaker embedding read from a local .npy file and converted to a torch
# tensor; it is printed once at start-up.
speaker_embeddings2 = torch.tensor(np.load("speaker_embeddings.npy"))
print(speaker_embeddings2)

# Text classifier whose top label is used as the detected language code.
lang_detector = pipeline(
    "text-classification",
    model="papluca/xlm-roberta-base-language-detection",
)
def language_detector(text):
    """Return the label of the top language-detection result for ``text``.

    The text is passed to the module-level ``lang_detector`` pipeline; the
    ``'label'`` field of its first result is printed and returned.
    """
    label = lang_detector(text)[0]['label']
    print(label)
    return label
# Translation pipelines that have already been built, keyed by source-language
# code, so each model is loaded at most once per process instead of once per
# request.
_translator_cache = {}


def _get_translator(codigo_idioma):
    """Return the translation pipeline for ``codigo_idioma``, building it on first use."""
    translator = _translator_cache.get(codigo_idioma)
    if translator is None:
        translator = pipeline("translation", model=translation_models[codigo_idioma])
        _translator_cache[codigo_idioma] = translator
    return translator


def translate(audio):
    """Transcribe ``audio`` and translate the transcription into Spanish.

    Args:
        audio: Input handed straight to ``asr_pipe`` (the UI passes a file path).

    Returns:
        The raw output of the translation pipeline when the detected language
        has an entry in ``translation_models`` (``synthesise`` reads
        ``[0]['translation_text']`` from it). Otherwise the untranslated
        transcription string, after printing a notice.
    """
    texto = asr_pipe(audio, max_new_tokens=256)['text']
    codigo_idioma = language_detector(texto)

    if codigo_idioma not in translation_models:
        # No model for this language: fall back to the plain transcription.
        print(f"No hay un modelo de traducción disponible para el idioma detectado {codigo_idioma}")
        return texto

    return _get_translator(codigo_idioma)(texto)
def synthesise(text):
    """Generate speech for ``text`` with the module-level SpeechT5 model.

    Args:
        text: Either a plain string, or a list whose first element is a dict
            with a ``'translation_text'`` key (the shape ``translate`` returns
            when a translation was performed).

    Returns:
        The result of ``model.generate_speech`` for the tokenised text, using
        ``speaker_embeddings2`` and ``vocoder``.
    """
    # Unwrap translation-pipeline output; plain strings are used as they are.
    sentence = text[0]['translation_text'] if isinstance(text, list) else text
    print(sentence)
    encoded = processor(sentence, return_tensors="pt")
    return model.generate_speech(
        encoded["input_ids"],
        speaker_embeddings2,
        vocoder=vocoder,
    )
def speech_to_speech_translation(audio):
    """Translate spoken ``audio`` into synthesised Spanish speech.

    Runs ``translate`` then ``synthesise``, converts the result to a NumPy
    array with singleton dimensions removed, and peak-normalises it.

    Args:
        audio: Input forwarded unchanged to ``translate``.

    Returns:
        A ``(sample_rate, audio_data)`` tuple with ``sample_rate`` fixed at
        16000 and ``audio_data`` scaled so that its largest absolute value
        is 1. Empty or all-zero audio is returned unscaled.
    """
    translated_text = translate(audio)
    synthesised_speech = synthesise(translated_text)
    audio_data = np.squeeze(synthesised_speech.cpu().numpy())

    # np.max raises ValueError on an empty array and dividing by a zero peak
    # would fill the output with NaN/inf, so only normalise when there is a
    # positive peak to scale by.
    peak = np.max(np.abs(audio_data)) if audio_data.size else 0.0
    if peak > 0:
        audio_data = audio_data / peak

    sample_rate = 16000
    return (sample_rate, audio_data)
# Heading and Markdown description shown on both tabs.
title = "Cascaded STST"
description = """
Demo for cascaded speech-to-speech translation (STST), mapping from source speech in any language to target speech in Spanish.
![Cascaded STST](https://huggingface.co/datasets/huggingface-course/audio-course-images/resolve/main/s2st_cascaded.png "Diagram of cascaded speech to speech translation")
"""

demo = gr.Blocks()


def _build_interface(source, **extra):
    """Create a translation Interface whose audio input comes from ``source``.

    Each call creates its own input and output ``gr.Audio`` components;
    ``extra`` is forwarded to ``gr.Interface`` (e.g. ``examples``).
    """
    return gr.Interface(
        fn=speech_to_speech_translation,
        inputs=gr.Audio(sources=source, type="filepath"),
        outputs=gr.Audio(label="Generated Speech", type="numpy"),
        title=title,
        description=description,
        **extra,
    )


# One tab records from the microphone, the other accepts an uploaded file.
mic_translate = _build_interface("microphone")
file_translate = _build_interface("upload", examples=[["./example.wav"]])

with demo:
    gr.TabbedInterface(
        [mic_translate, file_translate],
        ["Microphone", "Audio File"],
    )

demo.launch()