# AidenTTS / app.py: a MARS5 TTS voice-cloning demo built with Gradio.
import tempfile

import gradio as gr
import librosa
import torch
import torchaudio
from transformers import pipeline
# Load the MARS5 model
mars5, config_class = torch.hub.load('Camb-ai/mars5-tts', 'mars5_english', trust_repo=True)
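# The hub entry returns the model instance together with the class used to
# build its inference configuration (consumed below as config_class(**kwargs)).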
# Whisper ASR, used to transcribe the reference clip when no transcript is given.
asr_model = pipeline(
    "automatic-speech-recognition",
    model="openai/whisper-tiny",
    chunk_length_s=30,
    device="cuda:0" if torch.cuda.is_available() else "cpu",
)
def transcribe_file(f: str) -> str:
    """Transcribe an audio file with Whisper and join the chunk texts."""
    predictions = asr_model(f, return_timestamps=True)["chunks"]
    return " ".join(prediction["text"] for prediction in predictions)
# Process the input text and reference audio and generate the synthesized output.
def synthesize(text, audio_file, transcript, kwargs_dict):
    if audio_file is None:
        raise gr.Error("Please upload a reference audio file to clone from.")
    # Fall back to Whisper when no reference transcript was provided.
    if not transcript:
        transcript = transcribe_file(audio_file)
    # Load the reference audio at the model's sample rate.
    wav, _ = librosa.load(audio_file, sr=mars5.sr, mono=True)
    wav = torch.from_numpy(wav)
    # Build the inference configuration for the TTS model.
    cfg = config_class(**kwargs_dict)
    # Generate the synthesized audio.
    ar_codes, wav_out = mars5.tts(text, wav, transcript.strip(), cfg=cfg)
    # Save the result to a temporary file; NamedTemporaryFile replaces the
    # deprecated tempfile.mktemp, and delete=False keeps the file on disk.
    with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
        output_path = tmp.name
    # torchaudio expects a (channels, time) tensor, hence the unsqueeze.
    torchaudio.save(output_path, wav_out.unsqueeze(0), mars5.sr)
    return output_path
# Default inference settings for the MARS5 config; max_prompt_phones is part
# of the config but is not exposed as a UI control below.
defaults = {
    'temperature': 0.8,
    'top_k': -1,
    'top_p': 0.2,
    'typical_p': 1.0,
    'freq_penalty': 2.6,
    'presence_penalty': 0.4,
    'rep_penalty_window': 100,
    'max_prompt_phones': 360,
    'deep_clone': True,
    'nar_guidance_w': 3,
}
with gr.Blocks() as demo:
    gr.Markdown("## MARS5 TTS Demo\nEnter text and upload an audio file to clone the voice and generate synthesized speech using MARS5 TTS.")
    text = gr.Textbox(label="Text to synthesize")
    audio_file = gr.Audio(label="Audio file to clone from", type="filepath")
    generate_btn = gr.Button("Generate Synthesized Audio")
    with gr.Accordion("Advanced Settings", open=False):
        gr.Markdown("Additional inference settings.\nWARNING: changing these incorrectly may degrade quality.")
        prompt_text = gr.Textbox(label="Transcript of voice reference")
        temperature = gr.Slider(minimum=0.01, maximum=3, step=0.01, label="temperature", value=defaults['temperature'])
        top_k = gr.Slider(minimum=-1, maximum=2000, step=1, label="top_k", value=defaults['top_k'])
        top_p = gr.Slider(minimum=0.01, maximum=1.0, step=0.01, label="top_p", value=defaults['top_p'])
        typical_p = gr.Slider(minimum=0.01, maximum=1, step=0.01, label="typical_p", value=defaults['typical_p'])
        freq_penalty = gr.Slider(minimum=0, maximum=5, step=0.05, label="freq_penalty", value=defaults['freq_penalty'])
        presence_penalty = gr.Slider(minimum=0, maximum=5, step=0.05, label="presence_penalty", value=defaults['presence_penalty'])
        rep_penalty_window = gr.Slider(minimum=1, maximum=500, step=1, label="rep_penalty_window", value=defaults['rep_penalty_window'])
        nar_guidance_w = gr.Slider(minimum=1, maximum=8, step=0.1, label="nar_guidance_w", value=defaults['nar_guidance_w'])
        deep_clone = gr.Checkbox(value=defaults['deep_clone'], label='deep_clone')
    output = gr.Audio(label="Synthesized Audio", type="filepath")
    # Collect the UI values into a config dict and run synthesis.
    def on_click(
        text,
        audio_file,
        prompt_text,
        temperature,
        top_k,
        top_p,
        typical_p,
        freq_penalty,
        presence_penalty,
        rep_penalty_window,
        nar_guidance_w,
        deep_clone,
    ):
        return synthesize(
            text,
            audio_file,
            prompt_text,
            {
                'temperature': temperature,
                'top_k': top_k,
                'top_p': top_p,
                'typical_p': typical_p,
                'freq_penalty': freq_penalty,
                'presence_penalty': presence_penalty,
                'rep_penalty_window': rep_penalty_window,
                'nar_guidance_w': nar_guidance_w,
                'deep_clone': deep_clone,
            },
        )
    generate_btn.click(
        on_click,
        inputs=[
            text,
            audio_file,
            prompt_text,
            temperature,
            top_k,
            top_p,
            typical_p,
            freq_penalty,
            presence_penalty,
            rep_penalty_window,
            nar_guidance_w,
            deep_clone,
        ],
        outputs=[output],
    )
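    # Gradio passes component values positionally, so the inputs list above
    # must stay in the same order as on_click's parameters.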
gr.Markdown("### Examples")
# Add examples
defaults = [0.8, -1, 0.2, 1.0, 2.6, 0.4, 100, 3, True]
examples = [
["Today is a wonderful day!", "female_speaker_1.flac", "People look, but no one ever finds it.", *defaults],
["You guys need to figure this out.", "male_speaker_1.flac", "Ask her to bring these things with her from the store.", *defaults]
]
gr.Examples(
examples=examples,
inputs=[text, audio_file, prompt_text, temperature, top_k, top_p, typical_p, freq_penalty, presence_penalty, rep_penalty_window, nar_guidance_w, deep_clone],
outputs=[output],
cache_examples=False,
fn=on_click
)
demo.launch(share=False)