mars5_space / app.py
arnavmehta7's picture
Update app.py
9f8a599 verified
import gradio as gr
import torch
import librosa
from pathlib import Path
import tempfile, torchaudio
from transformers import pipeline
# Load the MARS5 TTS model and its inference-config class from torch.hub.
mars5, config_class = torch.hub.load('Camb-ai/mars5-tts', 'mars5_english', trust_repo=True)
# Whisper ASR, used by transcribe_file() to transcribe the reference clip when the
# user does not supply a transcript. Fall back to CPU instead of crashing at import
# time on machines without a CUDA device.
_asr_device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
asr_model = pipeline(
    "automatic-speech-recognition",
    model="openai/whisper-tiny",
    chunk_length_s=30,
    device=_asr_device,
)
def transcribe_file(f: str) -> str:
    """Transcribe the audio file at path ``f`` with the Whisper ASR pipeline.

    The module-level ``asr_model`` pipeline is called with
    ``return_timestamps=True``. The text of every returned chunk is stripped,
    empty chunks are dropped, and the rest are joined with single spaces.

    Args:
        f: Path to an audio file readable by the ASR pipeline.

    Returns:
        The transcript as a single string.
    """
    predictions = asr_model(f, return_timestamps=True)["chunks"]
    print(f">>>>>. predictions: {predictions}")
    # Whisper chunk texts typically carry their own leading space; stripping each
    # piece before joining avoids doubled spaces inside the transcript.
    pieces = (prediction["text"].strip() for prediction in predictions)
    return " ".join(piece for piece in pieces if piece)
# Function to process the text and audio input and generate the synthesized output
def synthesize(text, audio_file, transcript, kwargs_dict):
    """Clone the voice in ``audio_file`` and speak ``text`` with MARS5.

    Args:
        text: Text to synthesize.
        audio_file: Path to the reference recording whose voice is cloned.
        transcript: Transcript of the reference recording. When empty or None
            it is generated with ``transcribe_file``.
        kwargs_dict: Keyword arguments forwarded to ``config_class`` to build
            the MARS5 inference config.

    Returns:
        Path (as ``str``) of a newly created ``.wav`` file holding the
        synthesized audio at ``mars5.sr``.

    Raises:
        gr.Error: If ``text`` is blank or no reference audio was provided, so
            the user sees a readable message instead of a librosa/Whisper trace.
    """
    print(f">>>>>>> Kwargs dict: {kwargs_dict}")
    if not text or not text.strip():
        raise gr.Error("Please enter some text to synthesize.")
    if not audio_file:
        raise gr.Error("Please upload a reference audio file to clone from.")
    if not transcript:
        transcript = transcribe_file(audio_file)
    # Load the reference audio, resampled to the model's rate and downmixed to mono.
    wav, _ = librosa.load(audio_file, sr=mars5.sr, mono=True)
    wav = torch.from_numpy(wav)
    # Define the configuration for the TTS model
    cfg = config_class(**kwargs_dict)
    # Generate the synthesized audio (the returned AR codes are not needed here).
    _, wav_out = mars5.tts(text, wav, transcript.strip(), cfg=cfg)
    # tempfile.mktemp is deprecated and race-prone: create the file atomically and
    # keep it on disk (delete=False) so Gradio can serve it after we return.
    with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
        output_path = tmp.name
    # .cpu() is a no-op for CPU tensors and lets saving work if the model
    # returned audio on the GPU.
    torchaudio.save(output_path, wav_out.unsqueeze(0).cpu(), mars5.sr)
    return output_path
# Default MARS5 inference settings; the UI controls below are seeded from these.
# NOTE(review): max_prompt_phones has no UI control and is not forwarded by the
# click handler, so the config class's own default is used at inference time.
defaults = dict(
    temperature=0.8,
    top_k=-1,
    top_p=0.2,
    typical_p=1.0,
    freq_penalty=2.6,
    presence_penalty=0.4,
    rep_penalty_window=100,
    max_prompt_phones=360,
    deep_clone=True,
    nar_guidance_w=3,
)
with gr.Blocks() as demo:
    link = "https://github.com/Camb-ai/MARS5-TTS"
    # Must be an f-string: as a plain string the markdown link target was the
    # literal text "{link}" instead of the repository URL.
    gr.Markdown(f"## MARS5 TTS Demo\nEnter text and upload an audio file to clone the voice and generate synthesized speech using **[MARS5-TTS]({link})**")
    text = gr.Textbox(label="Text to synthesize")
    audio_file = gr.Audio(label="Audio file to clone from", type="filepath")
    generate_btn = gr.Button("Generate Synthesized Audio")
    with gr.Accordion("Advanced Settings", open=False):
        gr.Markdown("additional inference settings\nWARNING: changing these incorrectly may degrade quality.")
        prompt_text = gr.Textbox(label="Transcript of voice reference")
        temperature = gr.Slider(minimum=0.01, maximum=3, step=0.01, label="temperature", value=defaults['temperature'])
        top_k = gr.Slider(minimum=-1, maximum=2000, step=1, label="top_k", value=defaults['top_k'])
        top_p = gr.Slider(minimum=0.01, maximum=1.0, step=0.01, label="top_p", value=defaults['top_p'])
        typical_p = gr.Slider(minimum=0.01, maximum=1, step=0.01, label="typical_p", value=defaults['typical_p'])
        freq_penalty = gr.Slider(minimum=0, maximum=5, step=0.05, label="freq_penalty", value=defaults['freq_penalty'])
        presence_penalty = gr.Slider(minimum=0, maximum=5, step=0.05, label="presence_penalty", value=defaults['presence_penalty'])
        rep_penalty_window = gr.Slider(minimum=1, maximum=500, step=1, label="rep_penalty_window", value=defaults['rep_penalty_window'])
        nar_guidance_w = gr.Slider(minimum=1, maximum=8, step=0.1, label="nar_guidance_w", value=defaults['nar_guidance_w'])
        deep_clone = gr.Checkbox(value=defaults['deep_clone'], label='deep_clone')
    output = gr.Audio(label="Synthesized Audio", type="filepath")

    # Inference settings in the order they follow prompt_text in on_click's
    # parameters, in all_inputs, and in each example row.
    setting_keys = (
        "temperature", "top_k", "top_p", "typical_p", "freq_penalty",
        "presence_penalty", "rep_penalty_window", "nar_guidance_w", "deep_clone",
    )
    # One shared list of input components for both the button and the examples,
    # so the two can never drift out of sync with on_click's signature.
    all_inputs = [
        text, audio_file, prompt_text,
        temperature, top_k, top_p, typical_p, freq_penalty,
        presence_penalty, rep_penalty_window, nar_guidance_w, deep_clone,
    ]

    def on_click(
        text,
        audio_file,
        prompt_text,
        temperature,
        top_k,
        top_p,
        typical_p,
        freq_penalty,
        presence_penalty,
        rep_penalty_window,
        nar_guidance_w,
        deep_clone,
    ):
        """Gradio handler: collect the UI values and run ``synthesize``.

        Parameters arrive positionally in the order of ``all_inputs``.

        Returns:
            Path of the synthesized ``.wav`` file, shown in the output player.
        """
        print(f">>>> transcript: {prompt_text}; audio_file = {audio_file}")
        settings = {
            'temperature': temperature,
            # Integer counts in `defaults`; slider values may arrive as floats
            # depending on the Gradio version, so normalise them.
            'top_k': int(top_k),
            'top_p': top_p,
            'typical_p': typical_p,
            'freq_penalty': freq_penalty,
            'presence_penalty': presence_penalty,
            'rep_penalty_window': int(rep_penalty_window),
            'nar_guidance_w': nar_guidance_w,
            'deep_clone': deep_clone,
        }
        output_file = synthesize(text, audio_file, prompt_text, settings)
        print(f">>>> output file: {output_file}")
        return output_file

    generate_btn.click(on_click, inputs=all_inputs, outputs=[output])

    # Add examples. Their settings are derived from the `defaults` dict instead of
    # a hand-copied list (which also used to rebind and shadow `defaults`).
    example_settings = [defaults[key] for key in setting_keys]
    examples = [
        ["Can you please go there and figure it out?", "female_speaker_1.flac", "People look, but no one ever finds it.", *example_settings],
        ["Hey, do you need my help?", "male_speaker_1.flac", "Ask her to bring these things with her from the store.", *example_settings],
    ]
    gr.Examples(
        examples=examples,
        inputs=all_inputs,
        outputs=[output],
        cache_examples=False,
        fn=on_click,
    )
demo.launch(share=False)