# Author: vladelesin
# Last commit: "Update app.py" (8ab88f8), 2.08 kB
import gradio as gr
import numpy as np
import torch
from transformers import pipeline
from transformers import VitsModel, VitsTokenizer, FSMTForConditionalGeneration, FSMTTokenizer, Wav2Vec2ForCTC, Wav2Vec2Processor, MarianMTModel, MarianTokenizer, T5ForConditionalGeneration, T5Tokenizer
# Run the neural pipelines on the first GPU when one is available.
device = "cuda:0" if torch.cuda.is_available() else "cpu"

# Speech -> English text (SEW-D tiny checkpoint, fine-tuned on LibriSpeech
# 100h according to its name).
asr_pipe = pipeline(
    "automatic-speech-recognition",
    model="asapp/sew-d-tiny-100k-ft-ls100h",
    device=device,
)

# English text -> Russian text. Use the same device as the ASR pipeline
# instead of silently defaulting to CPU.
translation_en_to_rus = pipeline(
    "translation",
    model="Helsinki-NLP/opus-mt-en-ru",
    device=device,
)

# Russian text -> speech. Loaded without a device argument, i.e. on the
# default device. NOTE: these module-level names `model` / `tokenizer` are
# the TTS model used by the rest of the app; do not reuse them for other
# checkpoints.
model = VitsModel.from_pretrained("facebook/mms-tts-rus")
tokenizer = VitsTokenizer.from_pretrained("facebook/mms-tts-rus")
def translate(audio):
    """Transcribe English speech and translate the transcript into Russian.

    Args:
        audio: Input accepted by the ASR pipeline; the Gradio inputs in this
            app pass a path to an audio file.

    Returns:
        The Russian translation as a plain string.
    """
    # The ASR checkpoint is a CTC model, not Whisper, so the Whisper-only
    # `generate_kwargs={"task": "translate"}` and the deprecated
    # `max_new_tokens` argument have no effect here and are not passed.
    transcription = asr_pipe(audio)["text"]
    # LibriSpeech-fine-tuned CTC checkpoints typically emit upper-case text
    # without punctuation; sentence-case it so the MT model sees ordinary
    # English rather than ALL CAPS.
    english_text = transcription.strip().capitalize()
    translated = translation_en_to_rus(english_text)
    return translated[0]["translation_text"]
def synthesise(text):
    """Synthesise Russian speech for *text* with the module-level VITS model.

    Args:
        text: Russian text to speak.

    Returns:
        The model's ``waveform`` output tensor, moved to CPU.
    """
    # Place the token tensors on whatever device the model lives on, so this
    # keeps working if the model is moved to a GPU.
    inputs = tokenizer(text, return_tensors="pt").to(model.device)
    with torch.no_grad():  # inference only; skip autograd bookkeeping
        speech = model(**inputs).waveform
    return speech.cpu()
def speech_to_speech_translation(audio):
    """Translate English speech into synthesised Russian speech.

    Args:
        audio: Audio file path from the Gradio input, or ``None`` when the
            user submits without recording or uploading anything.

    Returns:
        ``(sampling_rate, samples)``: the TTS model's sampling rate and a
        1-D int16 array (first batch item), the tuple form accepted by the
        ``gr.Audio(type="numpy")`` outputs.

    Raises:
        gr.Error: If no audio was provided.
    """
    if audio is None:
        # Show a readable message in the UI instead of a pipeline traceback.
        raise gr.Error("Please record or upload some audio first.")
    translated_text = translate(audio)
    waveform = synthesise(translated_text).numpy()
    # Clip before scaling so any out-of-range sample cannot wrap around
    # during the int16 cast.
    pcm = (np.clip(waveform, -1.0, 1.0) * 32767).astype(np.int16)
    # Report the TTS model's own rate instead of a hard-coded 16000.
    return model.config.sampling_rate, pcm[0]
def _audio_input(source):
    """Build a filepath-typed ``gr.Audio`` input for *source* on Gradio 3 and 4+.

    Gradio 4 replaced ``gr.Audio(source="...")`` with ``sources=[...]``;
    use whichever keyword the installed constructor accepts.
    """
    import inspect

    params = inspect.signature(gr.Audio.__init__).parameters
    if "sources" in params:
        return gr.Audio(sources=[source], type="filepath")
    return gr.Audio(source=source, type="filepath")


demo = gr.Blocks()

# Tab 1: record from the microphone.
mic_translate = gr.Interface(
    fn=speech_to_speech_translation,
    inputs=_audio_input("microphone"),
    outputs=gr.Audio(label="Generated Speech", type="numpy"),
)

# Tab 2: upload an audio file (with a bundled example).
file_translate = gr.Interface(
    fn=speech_to_speech_translation,
    inputs=_audio_input("upload"),
    outputs=gr.Audio(label="Generated Speech", type="numpy"),
    examples=[["./example.wav"]],
)

with demo:
    gr.TabbedInterface([mic_translate, file_translate], ["Microphone", "Audio File"])

# Launch only when run as a script (as Spaces does), not on import.
if __name__ == "__main__":
    demo.launch()