# ConsistencyTTA / run_gradio.py
import torch
import gradio as gr
import soundfile as sf
import numpy as np
import random, os
from consistencytta import ConsistencyTTA
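
# `ConsistencyTTA` is the model wrapper expected to be defined in consistencytta.py
# alongside this script; it generates audio in a single forward pass.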

def seed_all(seed):
    """ Seed all random number generators. """
    seed = int(seed)
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.cuda.random.manual_seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

device = torch.device(
    "cuda:0" if torch.cuda.is_available() else
    "mps" if torch.backends.mps.is_available() else "cpu"
)

sr = 16000  # output sampling rate (Hz)

# Build the ConsistencyTTA model and freeze it for inference
consistencytta = ConsistencyTTA().to(device)
consistencytta.eval()
consistencytta.requires_grad_(False)

def generate(prompt: str, seed: str = '', cfg_weight: float = 4.):
    """ Generate audio from a given prompt.

    Args:
        prompt (str): Text prompt to generate audio from.
        seed (str, optional): Random seed. Defaults to '', which means no seed.
        cfg_weight (float, optional): Classifier-free guidance scale applied to the
            model input. Defaults to 4.

    Returns:
        str: Path to the generated waveform file.
    """
    if seed != '':
        try:
            seed_all(int(seed))
        except ValueError:
            pass  # Ignore seeds that cannot be parsed as integers

    # Single-step generation; bfloat16 autocast is enabled only on CUDA devices
    with torch.no_grad():
        with torch.autocast(
            device_type="cuda", dtype=torch.bfloat16, enabled=torch.cuda.is_available()
        ):
            wav = consistencytta(
                [prompt], num_steps=1, cfg_scale_input=cfg_weight, cfg_scale_post=1., sr=sr
            )

    # soundfile expects audio shaped (frames, channels), hence the transpose
    sf.write("output.wav", wav.T, samplerate=sr, subtype='PCM_16')
    return "output.wav"
# Generate test audio
print("Generating test audio...")
generate("A dog barks as a train passes by.", seed=1)
print("Test audio generated successfully! Starting Gradio interface...")

# Build and launch the Gradio interface
iface = gr.Interface(
    fn=generate,
    inputs=[
        gr.Textbox(
            label="Text", value="Several people cheer and scream and speak as water flows hard."
        ),
        gr.Textbox(label="Random Seed (Optional)", value=''),
        gr.Slider(
            minimum=0., maximum=8., value=3.5, label="Classifier-Free Guidance Strength"
        )
    ],
    outputs="audio",
    title="ConsistencyTTA: Accelerating Diffusion-Based Text-to-Audio "
          "Generation with Consistency Distillation",
    description="This is the official demo page for <a href='https://consistency-tta.github."
                "io' target='_blank'>ConsistencyTTA</a>, a model that accelerates "
                "diffusion-based text-to-audio generation hundreds of times with consistency "
                "models. <br> Here, the audio is generated within a single non-autoregressive "
                "forward pass from the CLAP-finetuned ConsistencyTTA checkpoint. <br> Since "
                "the training dataset does not include speech, the model is not expected to "
                "generate coherent speech. <br> Have fun!"
)
iface.launch()
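
# Note: `launch()` serves the demo locally with Gradio's defaults. To expose a
# temporary public link instead, one could call `iface.launch(share=True)`.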