# STST / app.py
# Uploaded by pragsGit: "Upload app.py" (commit 86fe041, verified, no virus detected, 1.61 kB).
import torch
from transformers import pipeline
from transformers import VitsModel, VitsTokenizer
import numpy as np
import gradio as gr
# Output audio is emitted as signed 16-bit PCM. max_range is the largest value
# that type can hold (32767); float waveforms are multiplied by it before the
# integer cast.
target_dtype = np.int16
max_range = np.iinfo(target_dtype).max
# Run speech recognition on the first CUDA GPU when one is present, else CPU.
if torch.cuda.is_available():
    device = "cuda:0"
else:
    device = "cpu"

# Whisper (base) speech-recognition pipeline, built once at import time.
# translate() below supplies the generation options for each call.
pipe = pipeline(
    task="automatic-speech-recognition",
    model="openai/whisper-base",
    device=device,
)
def translate(audio):
    """Run Whisper on *audio* and return the decoded text.

    The ASR pipeline is called with ``task="transcribe"`` and
    ``language="es"``, presumably so that the output text is Spanish
    regardless of the spoken language (it feeds a Spanish TTS model).

    Args:
        audio: Any input the ASR ``pipe`` accepts. The Gradio inputs in this
            app pass a file path (``type="filepath"``).

    Returns:
        The text produced by the pipeline, as a string.
    """
    outputs = pipe(
        audio,
        max_new_tokens=256,
        generate_kwargs={"task": "transcribe", "language": "es"},
    )
    # The ASR pipeline returns a dict; hand the text on so synthesise()
    # receives a string rather than None.
    return outputs["text"]
# Spanish MMS text-to-speech checkpoint. The model and its tokenizer are loaded
# once at import time and reused by every synthesise() call.
_TTS_CHECKPOINT = "facebook/mms-tts-spa"
model = VitsModel.from_pretrained(_TTS_CHECKPOINT)
tokenizer = VitsTokenizer.from_pretrained(_TTS_CHECKPOINT)
def synthesise(text):
    """Convert *text* to speech with the module-level VITS ``model``.

    The text is tokenized into PyTorch tensors, and the model's ``waveform``
    output is returned as-is (no reshaping or dtype conversion).
    """
    token_ids = tokenizer(text, return_tensors="pt")["input_ids"]
    # Inference only: skip autograd bookkeeping.
    with torch.no_grad():
        waveform = model(token_ids)["waveform"]
    return waveform
def speech_to_speech_translation(audio):
    """Turn input speech into synthesised Spanish speech for Gradio.

    Args:
        audio: Path to the input audio, as delivered by the
            ``gr.Audio(type="filepath")`` inputs; ``None`` when the user
            submits without recording or uploading anything.

    Returns:
        ``(sampling_rate, samples)``, the tuple form Gradio's numpy Audio
        output accepts, with ``samples`` a 1-D int16 array.

    Raises:
        gr.Error: if no audio was provided.
    """
    if audio is None:
        # Report the problem in the UI instead of failing inside the pipeline.
        raise gr.Error("Please record or upload some audio first.")

    translated_text = translate(audio)
    synthesised_speech = synthesise(translated_text)

    # synthesise() returns the model's batched waveform (presumably shape
    # (1, N) for a single input). Gradio treats a 2-D array as
    # (samples, channels), so flatten it to a 1-D mono signal.
    waveform = synthesised_speech.detach().cpu().numpy().reshape(-1)
    # Clip before scaling so out-of-range floats saturate instead of wrapping
    # around when cast to int16.
    waveform = np.clip(waveform, -1.0, 1.0)
    synthesised_speech = (waveform * max_range).astype(np.int16)

    # NOTE(review): 16 kHz is assumed to match the TTS model's output rate;
    # confirm against model.config.sampling_rate.
    return 16000, synthesised_speech
demo = gr.Blocks()


def _make_interface(source):
    """Build a speech-to-speech Interface whose audio comes from *source*."""
    return gr.Interface(
        fn=speech_to_speech_translation,
        inputs=gr.Audio(sources=source, type="filepath"),
        outputs=gr.Audio(label="Generated Speech", type="numpy"),
    )


# Two tabs share the same processing function: one records from the
# microphone, the other accepts an uploaded file.
mic_translate = _make_interface("microphone")
file_translate = _make_interface("upload")

with demo:
    gr.TabbedInterface(
        [mic_translate, file_translate],
        ["Microphone", "Audio File"],
    )

demo.launch()