# Hugging Face Spaces app (Space page status at capture time: "Sleeping").
import concurrent
import logging
import os
import tempfile
from typing import Optional, Tuple

import gradio as gr
import numpy as np
import spaces
import torch
import torchaudio
from resemble_enhance.enhancer.inference import denoise, enhance
from transformers import pipeline

from flore200_codes import flores_codes
from tts import BambaraTTS
# Device string used by the enhancement calls: GPU when torch can see one.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Hugging Face model identifiers.
translation_model = "oza75/nllb-600M-mt-french-bambara"
tts_model = "oza75/bambara-tts"

# Both models are loaded once, at import time.
# Translation pipeline; max_length=512 is forwarded to the pipeline
# (presumably a generation length limit -- see transformers docs).
translator = pipeline("translation", model=translation_model, max_length=512)
# Bambara text-to-speech model (project-local BambaraTTS class).
tts = BambaraTTS(tts_model)
# Function to translate text to Bambara | |
def translate_to_bambara(text, src_lang): | |
translation = translator(text, src_lang=src_lang, tgt_lang="bam_Latn") | |
return translation[0]['translation_text'] | |
def text_to_speech(bambara_text, reference_audio: Optional[Tuple] = None):
    """Synthesize speech for ``bambara_text`` with the module-level ``tts`` model.

    Args:
        bambara_text: Bambara text to synthesize.
        reference_audio: Optional ``(sample_rate, samples)`` tuple, as produced by a
            ``gr.Audio(type="numpy")`` input. ``samples`` is 1-D (mono) or
            ``(samples, channels)``. When given, it is written to a temporary WAV
            file that is passed as ``speaker_reference_wav_path`` (voice cloning).
            The file is always deleted afterwards.

    Returns:
        ``(audio, sample_rate)``, where ``audio`` is the model output averaged
        over dimension 0. Note that this is the reverse of the
        ``(sample_rate, audio)`` order returned by ``tts.text_to_speech``.
    """
    if reference_audio is None:
        # No reference provided: use the model's default voice.
        sr, audio = tts.text_to_speech(bambara_text)
    else:
        ref_sr, ref_samples = reference_audio
        ref_audio = torch.from_numpy(ref_samples)
        if ref_audio.ndim == 1:
            # Mono: add the leading channel dimension torchaudio.save expects.
            ref_audio = ref_audio.unsqueeze(0)
        elif ref_audio.ndim == 2:
            # Gradio delivers multi-channel audio as (samples, channels);
            # torchaudio.save (channels_first=True) wants (channels, samples).
            ref_audio = ref_audio.T.contiguous()

        # Create and immediately close the file, then write to it by name.
        # This avoids reopening a still-open temp file, which fails on Windows.
        fd, tmp_path = tempfile.mkstemp(suffix=".wav")
        os.close(fd)
        try:
            torchaudio.save(tmp_path, ref_audio, ref_sr)
            sr, audio = tts.text_to_speech(bambara_text, speaker_reference_wav_path=tmp_path)
        finally:
            # Remove the temp file even if saving or synthesis raised.
            os.unlink(tmp_path)

    audio = audio.mean(dim=0)
    return audio, sr
# Function to enhance speech | |
# @spaces.GPU | |
# def enhance_speech(audio_array, sampling_rate, solver, nfe, tau, denoise_before_enhancement): | |
# solver = solver.lower() | |
# nfe = int(nfe) | |
# lambd = 0.9 if denoise_before_enhancement else 0.1 | |
# | |
# @spaces.GPU(duration=360) | |
# def denoise_audio(): | |
# try: | |
# return denoise(audio_array, sampling_rate, device) | |
# except Exception as e: | |
# print("> Error while denoising : ", str(e)) | |
# return audio_array, sampling_rate | |
# | |
# @spaces.GPU(duration=360) | |
# def enhance_audio(): | |
# try: | |
# return enhance(audio_array, sampling_rate, device, nfe=nfe, solver=solver, lambd=lambd, tau=tau) | |
# except Exception as e: | |
# print("> Error while enhancement : ", str(e)) | |
# return audio_array, sampling_rate | |
# | |
# with concurrent.futures.ThreadPoolExecutor() as executor: | |
# future_denoise = executor.submit(denoise_audio) | |
# future_enhance = executor.submit(enhance_audio) | |
# | |
# denoised_audio, new_sr1 = future_denoise.result() | |
# enhanced_audio, new_sr2 = future_enhance.result() | |
# | |
# # Convert to numpy and return | |
# return (new_sr1, denoised_audio.cpu().numpy()), (new_sr2, enhanced_audio.cpu().numpy()) | |
def enhance_speech(audio_array, sampling_rate, solver, nfe, tau, denoise_before_enhancement):
    """Run resemble-enhance ``denoise`` and ``enhance`` on the same input audio.

    Both passes start from ``audio_array``; the enhancement does not consume
    the denoised output.

    Args:
        audio_array: Waveform passed to ``denoise``/``enhance``. Assumed to be the
            torch tensor returned by ``text_to_speech`` -- confirm for other callers.
        sampling_rate: Sample rate of ``audio_array``.
        solver: ODE solver name; lower-cased before being forwarded.
        nfe: Number of function evaluations; coerced with ``int``.
        tau: Prior temperature forwarded to ``enhance``.
        denoise_before_enhancement: Truthy selects ``lambd=0.9``, otherwise ``0.1``.

    Returns:
        ``((denoised_sr, denoised_np), (enhanced_sr, enhanced_np))``, where each
        waveform has been moved to CPU and converted with ``.numpy()``.
    """
    def to_host_pair(wav, rate):
        # Pair the rate with a CPU numpy copy of the waveform.
        return rate, wav.cpu().numpy()

    solver_name = solver.lower()
    n_evals = int(nfe)
    if denoise_before_enhancement:
        lambd = 0.9
    else:
        lambd = 0.1

    denoised_wav, denoised_rate = denoise(audio_array, sampling_rate, device)
    enhanced_wav, enhanced_rate = enhance(
        audio_array,
        sampling_rate,
        device,
        nfe=n_evals,
        solver=solver_name,
        lambd=lambd,
        tau=tau,
    )
    return to_host_pair(denoised_wav, denoised_rate), to_host_pair(enhanced_wav, enhanced_rate)
def convert_to_int16(audio_array):
    """Convert a floating-point waveform in [-1.0, 1.0] to 16-bit PCM.

    Accepts either a ``torch.Tensor`` or a ``np.ndarray`` and returns the same
    kind of object.

    Floating-point samples are handled as follows:
    - They are computed in at least float32 precision, so that 32767 is exact
      and half-precision input cannot round to 32768 and overflow.
    - They are clipped to [-1.0, 1.0], so out-of-range samples saturate
      instead of wrapping around.
    - They are scaled by 32767 and cast to int16, which truncates toward zero.

    Input of any other dtype is returned unchanged.

    Args:
        audio_array: Waveform samples as a torch tensor or numpy array.

    Returns:
        An int16 tensor/array for floating-point input, otherwise
        ``audio_array`` itself.
    """
    if isinstance(audio_array, torch.Tensor):
        if not audio_array.is_floating_point():
            return audio_array
        samples = audio_array.to(torch.float32).clamp(-1.0, 1.0)
        return (samples * 32767).to(torch.int16)

    if isinstance(audio_array, np.ndarray) and np.issubdtype(audio_array.dtype, np.floating):
        samples = np.clip(audio_array.astype(np.float32, copy=False), -1.0, 1.0)
        return (samples * 32767).astype(np.int16)

    return audio_array
# Gradio callback: translate the text, synthesize speech, then denoise/enhance it.
def _fn(
    src_lang,
    text,
    reference_audio=None,
    solver="Midpoint",
    nfe=64,
    prior_temp=0.5,
    denoise_before_enhancement=False
):
    """Translate ``text`` to Bambara, speak it, and return raw/denoised/enhanced audio.

    Args:
        src_lang: Source-language display name; must be a key of ``flores_codes``.
        text: Text to translate.
        reference_audio: Optional ``(sample_rate, samples)`` voice-cloning reference,
            forwarded to ``text_to_speech``.
        solver: ODE solver name forwarded to ``enhance_speech``.
        nfe: Number of function evaluations forwarded to ``enhance_speech``.
        prior_temp: Prior temperature, forwarded to ``enhance_speech`` as ``tau``.
        denoise_before_enhancement: Forwarded to ``enhance_speech``.

    Returns:
        ``(bambara_text, original, denoised, enhanced)``. Each audio item is a
        ``(sample_rate, numpy samples)`` tuple whose samples have gone through
        ``convert_to_int16``.

    Raises:
        gr.Error: If ``text`` is empty or whitespace only.
    """
    if not text or not text.strip():
        raise gr.Error("Please enter some text to translate.")

    source_lang = flores_codes[src_lang]

    # Step 1: Translate the text to Bambara.
    bambara_text = translate_to_bambara(text, source_lang)

    # Step 2: Synthesize speech. text_to_speech already falls back to the
    # default voice when reference_audio is None, so no branch is needed here.
    audio_array, sampling_rate = text_to_speech(bambara_text, reference_audio)

    # Step 3: Denoise and enhance the synthesized audio.
    (denoised_sr, denoised_samples), (enhanced_sr, enhanced_samples) = enhance_speech(
        audio_array,
        sampling_rate,
        solver,
        nfe,
        prior_temp,
        denoise_before_enhancement,
    )

    logger = logging.getLogger(__name__)
    if logger.isEnabledFor(logging.DEBUG):
        # Guarded so the min/max reductions only run when debug logging is on.
        logger.debug(
            "TTS audio shape=%s dtype=%s min=%s max=%s sr=%r (%s); "
            "denoised sr=%r (%s); enhanced sr=%r (%s)",
            tuple(audio_array.shape),
            audio_array.dtype,
            torch.min(audio_array).item(),
            torch.max(audio_array).item(),
            sampling_rate,
            type(sampling_rate).__name__,
            denoised_sr,
            type(denoised_sr).__name__,
            enhanced_sr,
            type(enhanced_sr).__name__,
        )

    return (
        bambara_text,
        # .cpu() is a no-op for CPU tensors and makes .numpy() safe otherwise.
        (sampling_rate, convert_to_int16(audio_array).cpu().numpy()),
        # enhance_speech returns numpy arrays. Wrap them as tensors -- the same
        # type convert_to_int16 receives for the raw TTS audio above -- so all
        # three outputs are converted the same way.
        (denoised_sr, convert_to_int16(torch.from_numpy(denoised_samples)).numpy()),
        (enhanced_sr, convert_to_int16(torch.from_numpy(enhanced_samples)).numpy()),
    )
def main():
    """Build the Gradio ``Interface`` around ``_fn`` and launch it without a share link."""
    language_names = list(flores_codes.keys())

    # Order must match the positional parameters of _fn.
    input_components = [
        gr.Dropdown(label="Source Language", choices=language_names, value='French'),
        gr.Textbox(label="Text to Translate", lines=3),
        gr.Audio(label="Clone your voice (optional)", type="numpy", format="wav"),
        gr.Dropdown(
            choices=["Midpoint", "RK4", "Euler"],
            value="Midpoint",
            label="ODE Solver (Midpoint is recommended)",
        ),
        gr.Slider(
            minimum=1,
            maximum=128,
            value=64,
            step=1,
            label="Number of Function Evaluations",
        ),
        gr.Slider(
            minimum=0.1,
            maximum=1,
            value=0.5,
            step=0.01,
            label="Prior Temperature",
        ),
        gr.Checkbox(value=False, label="Denoise Before Enhancement"),
    ]

    # Order must match the 4-tuple returned by _fn.
    output_components = [
        gr.Textbox(label="Translated Text"),
        gr.Audio(label="Original TTS Audio"),
        gr.Audio(label="Denoised Audio"),
        gr.Audio(label="Enhanced Audio"),
    ]

    interface = gr.Interface(
        fn=_fn,
        inputs=input_components,
        outputs=output_components,
        title="Bambara Translation and Text to Speech with Audio Enhancement",
        description="Translate text to Bambara and convert it to speech with options to enhance audio quality.",
    )
    interface.launch(share=False)


if __name__ == "__main__":
    main()