bambara-mt / app.py
Aboubacar OUATTARA - kaira
initial commit
d72812e
raw history blame
No virus
3.84 kB
import concurrent
from transformers import pipeline
import gradio as gr
import torch
import torchaudio
from resemble_enhance.enhancer.inference import denoise, enhance
from flore200_codes import flores_codes
# Select the GPU when one is available; both pipelines below are placed on it.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Translation pipeline (target language is fixed to Bambara in translate_to_bambara).
translation_model = "oza75/nllb-600M-mt-french-bambara"
# Pass `device` explicitly: a transformers pipeline defaults to the CPU when no
# device is given, which left translation on the CPU even on CUDA machines
# while the TTS pipeline below ran on the GPU.
translator = pipeline("translation", model=translation_model, max_length=512, device=device)

# Text-to-speech pipeline.
tts_model = "oza75/bambara-tts-male-001"
tts = pipeline("text-to-speech", model=tts_model, device=device)
def translate_to_bambara(text, src_lang):
    """Translate ``text`` into Bambara with the module-level ``translator``.

    Args:
        text: Source text to translate.
        src_lang: Language code of ``text`` (the caller ``_fn`` looks it up
            in ``flores_codes``).

    Returns:
        The ``translation_text`` field of the first pipeline result.
    """
    results = translator(text, src_lang=src_lang, tgt_lang="bam_Latn")
    first = results[0]
    return first['translation_text']
def text_to_speech(bambara_text):
    """Synthesize speech for ``bambara_text`` with the module-level ``tts`` pipeline.

    Args:
        bambara_text: Text to synthesize (produced by ``translate_to_bambara``).

    Returns:
        ``(audio, sr)``: ``audio`` is a ``torch.Tensor`` built from the
        pipeline's ``"audio"`` array with any leading axis averaged away, and
        ``sr`` is the pipeline's ``"sampling_rate"``.
    """
    speech = tts(bambara_text)
    audio, sr = speech['audio'], speech['sampling_rate']
    # NOTE(review): assumes the pipeline returns a numpy array, as required by
    # torch.from_numpy.
    audio = torch.from_numpy(audio)
    # The pipeline output may carry a leading batch/channel axis (e.g. (1, n));
    # average it to get a mono waveform. Only do so when that axis exists:
    # mean(dim=0) on an already 1-D waveform would collapse it to a scalar.
    if audio.ndim > 1:
        audio = audio.mean(dim=0)
    return audio, sr
def enhance_speech(audio_array, sampling_rate, solver, nfe, tau, denoise_before_enhancement):
    """Produce denoised and enhanced versions of a waveform with resemble-enhance.

    ``denoise`` and ``enhance`` are both applied to the same input
    ``audio_array``. They are submitted together to a thread pool and run on
    the module-level ``device``.

    Args:
        audio_array: Waveform tensor as returned by ``text_to_speech``
            (presumably mono; TODO confirm the shape resemble-enhance expects).
        sampling_rate: Sample rate of ``audio_array``.
        solver: ODE solver name; lower-cased before being passed to ``enhance``.
        nfe: Number of function evaluations; coerced with ``int()`` because
            the UI slider value may not be an int.
        tau: Prior temperature forwarded to ``enhance``.
        denoise_before_enhancement: Truthy selects ``lambd=0.9``, otherwise
            ``lambd=0.1``.

    Returns:
        ``((sr_denoised, denoised_np), (sr_enhanced, enhanced_np))``. These are
        ``(sample_rate, numpy_waveform)`` pairs, the form ``_fn`` hands to the
        ``gr.Audio`` outputs.
    """
    # `import concurrent` alone does not load the `futures` submodule; import it
    # explicitly instead of relying on another library having imported it first.
    import concurrent.futures

    solver = solver.lower()
    nfe = int(nfe)
    lambd = 0.9 if denoise_before_enhancement else 0.1

    def denoise_audio():
        return denoise(audio_array, sampling_rate, device)

    def enhance_audio():
        return enhance(audio_array, sampling_rate, device, nfe=nfe, solver=solver, lambd=lambd, tau=tau)

    # Exactly two independent jobs, so two workers are enough.
    with concurrent.futures.ThreadPoolExecutor(max_workers=2) as executor:
        future_denoise = executor.submit(denoise_audio)
        future_enhance = executor.submit(enhance_audio)
        denoised_audio, new_sr1 = future_denoise.result()
        enhanced_audio, new_sr2 = future_enhance.result()

    # Move results to host memory as numpy arrays for Gradio.
    return (new_sr1, denoised_audio.cpu().numpy()), (new_sr2, enhanced_audio.cpu().numpy())
def _fn(src_lang, text, solver="Midpoint", nfe=64, prior_temp=0.5, denoise_before_enhancement=False):
    """Gradio callback: translate to Bambara, synthesize speech, then enhance it.

    Args:
        src_lang: Display name of the source language; a key of ``flores_codes``.
        text: Text to translate.
        solver: ODE solver name forwarded to ``enhance_speech``.
        nfe: Number of function evaluations forwarded to ``enhance_speech``.
        prior_temp: Forwarded to ``enhance_speech`` as ``tau``.
        denoise_before_enhancement: Forwarded to ``enhance_speech``.

    Returns:
        A 4-tuple in the order of the interface outputs: the Bambara text, the
        raw TTS audio as ``(sampling_rate, numpy_array)``, and the denoised and
        enhanced ``(sample_rate, numpy_array)`` pairs.
    """
    lang_code = flores_codes[src_lang]

    translated = translate_to_bambara(text, lang_code)
    waveform, sr = text_to_speech(translated)

    denoised, enhanced = enhance_speech(
        audio_array=waveform,
        sampling_rate=sr,
        solver=solver,
        nfe=nfe,
        tau=prior_temp,
        denoise_before_enhancement=denoise_before_enhancement,
    )

    raw_audio = (sr, waveform.cpu().numpy())
    return translated, raw_audio, denoised, enhanced
def main():
    """Assemble the Gradio interface around ``_fn`` and launch it."""
    languages = list(flores_codes)

    # Input components, in the same order as _fn's positional parameters.
    input_components = [
        gr.Dropdown(label="Source Language", choices=languages, value='French'),
        gr.Textbox(label="Text to Translate"),
        gr.Dropdown(
            choices=["Midpoint", "RK4", "Euler"],
            value="Midpoint",
            label="ODE Solver (Midpoint is recommended)",
        ),
        gr.Slider(minimum=1, maximum=128, value=64, step=1, label="Number of Function Evaluations"),
        gr.Slider(minimum=0, maximum=1, value=0.5, step=0.01, label="Prior Temperature"),
        gr.Checkbox(value=False, label="Denoise Before Enhancement"),
    ]

    # Output components, matching the 4-tuple returned by _fn.
    output_components = [
        gr.Textbox(label="Translated Text"),
        gr.Audio(label="Original TTS Audio"),
        gr.Audio(label="Denoised Audio"),
        gr.Audio(label="Enhanced Audio"),
    ]

    interface = gr.Interface(
        fn=_fn,
        inputs=input_components,
        outputs=output_components,
        title="Bambara Translation and Text to Speech with Audio Enhancement",
        description="Translate text to Bambara and convert it to speech with options to enhance audio quality.",
    )
    interface.launch()
if __name__ == "__main__":
    # Launch the app only when this file is executed directly, not on import.
    main()