# Author: cotxetj — commit 3a2a55c ("Update app.py")
import torch
from transformers import pipeline, VitsModel, VitsTokenizer
import numpy as np
import gradio as gr
device = "cuda:0" if torch.cuda.is_available() else "cpu"

# Whisper-small handles speech recognition / speech translation. It runs on
# the GPU when one is available.
pipe = pipeline(
    "automatic-speech-recognition",
    model="openai/whisper-small",
    device=device,
)

# MMS-TTS voice that reads the translation aloud (kept on the CPU).
# Whisper's translation output is English, so the English checkpoint is used;
# the French one (facebook/mms-tts-fra) would pronounce English text with
# French reading rules.
model = VitsModel.from_pretrained("facebook/mms-tts-eng")
tokenizer = VitsTokenizer.from_pretrained("facebook/mms-tts-eng")
# Define a function that translates spoken audio into English text
def translate(audio, language=None):
    """Translate speech to English text with the Whisper pipeline.

    Args:
        audio: Either a ``(sampling_rate, samples)`` tuple as produced by
            ``gr.Audio(type="numpy")``, or any input the Hugging Face ASR
            pipeline accepts directly (file path, bytes, or a dict with
            ``raw``/``sampling_rate``).
        language: Optional source-language hint for Whisper (e.g.
            ``"swedish"``). ``None`` lets Whisper detect the language itself.

    Returns:
        The translated text (the pipeline's ``"text"`` field).
    """
    if isinstance(audio, tuple):
        # The ASR pipeline rejects Gradio's (rate, array) tuple; convert it to
        # the dict form with a float, single-channel signal.
        sampling_rate, samples = audio
        samples = np.asarray(samples)
        if np.issubdtype(samples.dtype, np.integer):
            # Integer PCM (e.g. int16) -> floats roughly in [-1, 1].
            scale = float(np.iinfo(samples.dtype).max) + 1.0
            samples = samples.astype(np.float32) / scale
        else:
            samples = samples.astype(np.float32)
        if samples.ndim > 1:
            # Assumes Gradio's (samples, channels) layout; the pipeline only
            # accepts mono input, so average the channels.
            samples = samples.mean(axis=1)
        audio = {"sampling_rate": sampling_rate, "raw": samples}

    # "translate" (not "transcribe") makes Whisper emit English text.
    generate_kwargs = {"task": "translate"}
    if language is not None:
        generate_kwargs["language"] = language
    outputs = pipe(audio, max_new_tokens=256, generate_kwargs=generate_kwargs)
    return outputs["text"]
# Define function to generate the waveform output
def synthesise(text):
    """Synthesise speech for ``text`` with the VITS (MMS-TTS) model.

    Args:
        text: Text to speak.

    Returns:
        A 1-D float ``torch.Tensor`` with the waveform of the single batch
        item. It is empty when ``text`` is empty/whitespace or tokenizes to
        nothing.
    """
    if not text or not text.strip():
        # VITS cannot run on an empty token sequence (e.g. silent input audio).
        return torch.zeros(0)
    inputs = tokenizer(text, return_tensors="pt")
    if inputs["input_ids"].shape[-1] == 0:
        return torch.zeros(0)
    with torch.no_grad():
        outputs = model(**inputs)
    # VitsModel exposes its audio as ``waveform`` (batch, samples);
    # there is no ``audio`` attribute.
    return outputs.waveform[0]
# Define the full speech -> text -> speech pipeline
def speech_to_speech_translation(audio):
    """Translate speech and read the translation aloud.

    Args:
        audio: Any input accepted by ``translate``.

    Returns:
        A ``(sampling_rate, samples)`` tuple with int16 PCM samples, the
        format ``gr.Audio(type="numpy")`` outputs accept.
    """
    translated_text = translate(audio)
    waveform = synthesise(translated_text)
    samples = waveform.detach().cpu().numpy()
    # The float waveform can slightly overshoot [-1, 1]; clip so the int16
    # cast cannot wrap around into loud clicks.
    pcm = (np.clip(samples, -1.0, 1.0) * 32767).astype(np.int16)
    sampling_rate = getattr(model.config, "sampling_rate", 16000)
    return sampling_rate, pcm
def predict(transType, language, audio, audio_mic=None):
    """Gradio callback: translate uploaded or recorded audio.

    Args:
        transType: Output format, ``"Text"`` or ``"Audio"``.
        language: Source language selected in the UI. It is currently not
            forwarded to the translation step.
        audio: Uploaded clip, or ``None``.
        audio_mic: Microphone recording, used when nothing was uploaded.

    Returns:
        The translated text for ``"Text"``, or a ``(sampling_rate, samples)``
        tuple for ``"Audio"``.

    Raises:
        ValueError: If no audio was provided or ``transType`` is unsupported.
    """
    if audio is None:
        audio = audio_mic
    if audio is None:
        raise ValueError("Please upload or record an audio clip.")
    if transType == "Text":
        return translate(audio)
    if transType == "Audio":
        return speech_to_speech_translation(audio)
    raise ValueError(f"Unsupported output format: {transType!r}")
# Define the title etc
title = "Swedish STSOT (Speech To Speech Or Text)"
description = "Use Whisper pretrained model to convert swedish audio to english (text or audio)"
supportLangs = ["Swedish", "French (in training)"]
transTypes = ["Text", "Audio"]
# TODO: enable examples once the sample files exist in the Space, e.g.
# [["Text", "Swedish", "ex1.mp3", None], ["Audio", "Swedish", "ex2.mp3", None]]
examples = []


def _predict_to_outputs(transType, language, audio, audio_mic=None):
    """Route ``predict``'s result to the matching output component.

    ``predict`` returns a string for "Text" and a ``(sampling_rate, samples)``
    tuple for "Audio". The interface has one Text and one Audio output, so the
    unused one is cleared with ``None``.
    """
    result = predict(transType, language, audio, audio_mic)
    if isinstance(result, tuple):
        return None, result
    return result, None


demo = gr.Interface(
    fn=_predict_to_outputs,
    inputs=[
        # Default value so the callback never receives None for the format.
        gr.Radio(label="Choose your output format", choices=transTypes, value="Text"),
        gr.Radio(label="Choose a source language", choices=supportLangs, value="Swedish"),
        gr.Audio(label="Import an audio", sources="upload", type="numpy"),
        gr.Audio(label="Record an audio", sources="microphone", type="numpy"),
    ],
    outputs=[
        gr.Text(label="Translation"),
        gr.Audio(label="Spoken translation", type="numpy"),
    ],
    title=title,
    description=description,
    article="",
    # An empty list is passed as None so no empty examples table is built.
    examples=examples or None,
)
# Launch exactly once; ``demo`` must stay the Interface (launch() returns a
# tuple, so the old ``demo = gr.Interface(...).launch()`` broke the second call).
demo.launch()