Adding app.py for CPU inference
app.py ADDED
@@ -0,0 +1,106 @@
import gradio as gr
import json
import torch
import wavio
from tqdm import tqdm
from huggingface_hub import snapshot_download

from models import AudioDiffusion, DDPMScheduler
from audioldm.audio.stft import TacotronSTFT
from audioldm.variational_autoencoder import AutoencoderKL


class Tango:
    def __init__(self, name="declare-lab/tango", device="cpu"):
        # Download configs and weights from the Hugging Face Hub
        path = snapshot_download(repo_id=name)

        vae_config = json.load(open("{}/vae_config.json".format(path)))
        stft_config = json.load(open("{}/stft_config.json".format(path)))
        main_config = json.load(open("{}/main_config.json".format(path)))

        self.vae = AutoencoderKL(**vae_config).to(device)
        self.stft = TacotronSTFT(**stft_config).to(device)
        self.model = AudioDiffusion(**main_config).to(device)

        vae_weights = torch.load("{}/pytorch_model_vae.bin".format(path), map_location=device)
        stft_weights = torch.load("{}/pytorch_model_stft.bin".format(path), map_location=device)
        main_weights = torch.load("{}/pytorch_model_main.bin".format(path), map_location=device)

        self.vae.load_state_dict(vae_weights)
        self.stft.load_state_dict(stft_weights)
        self.model.load_state_dict(main_weights)

        print("Successfully loaded checkpoint from:", name)

        self.vae.eval()
        self.stft.eval()
        self.model.eval()

        self.scheduler = DDPMScheduler.from_pretrained(main_config["scheduler_name"], subfolder="scheduler")

    def chunks(self, lst, n):
        """Yield successive n-sized chunks from a list."""
        for i in range(0, len(lst), n):
            yield lst[i:i + n]

    def generate(self, prompt, steps=100, guidance=3, samples=1, disable_progress=True):
        """Generate audio for a single prompt string."""
        with torch.no_grad():
            latents = self.model.inference([prompt], self.scheduler, steps, guidance, samples, disable_progress=disable_progress)
            mel = self.vae.decode_first_stage(latents)
            wave = self.vae.decode_to_waveform(mel)
        return wave[0]

    def generate_for_batch(self, prompts, steps=200, guidance=3, samples=1, batch_size=8, disable_progress=True):
        """Generate audio for a list of prompt strings."""
        outputs = []
        for k in tqdm(range(0, len(prompts), batch_size)):
            batch = prompts[k: k + batch_size]
            with torch.no_grad():
                latents = self.model.inference(batch, self.scheduler, steps, guidance, samples, disable_progress=disable_progress)
                mel = self.vae.decode_first_stage(latents)
                wave = self.vae.decode_to_waveform(mel)
                outputs += [item for item in wave]
        if samples == 1:
            return outputs
        else:
            return list(self.chunks(outputs, samples))


# Initialize the Tango model (CPU by default)
tango = Tango()


def gradio_generate(prompt):
    output_wave = tango.generate(prompt)
    # Save the output waveform as a temporary WAV file
    output_filename = "temp_output.wav"
    wavio.write(output_filename, output_wave, rate=22050, sampwidth=2)
    return output_filename


# Longer TANGO model description (the interface below uses a short one-line description)
description_text = '''
TANGO is a latent diffusion model (LDM) for text-to-audio (TTA) generation. TANGO can generate realistic audio, including human sounds, animal sounds, natural and artificial sounds, and sound effects, from textual prompts. We use the frozen instruction-tuned LLM Flan-T5 as the text encoder and train a UNet-based diffusion model for audio generation. We perform comparably to current state-of-the-art models for TTA across both objective and subjective metrics, despite training the LDM on a 63-times-smaller dataset. We release our model, training and inference code, and pre-trained checkpoints for the research community.
'''

# Define Gradio input and output components
input_text = gr.Textbox(lines=2, label="Prompt")
output_audio = gr.Audio(label="Generated Audio", type="filepath")

# Create the Gradio interface
gr_interface = gr.Interface(
    fn=gradio_generate,
    inputs=input_text,
    outputs=[output_audio],
    title="Tango Audio Generator",
    description="Generate audio using the Tango model by providing a text prompt.",
    allow_flagging="never",
    examples=[
        ["A Dog Barking"],
        ["A loud thunderstorm"],
    ],
)

# Launch the Gradio app
gr_interface.launch()
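
For reference, here is a minimal sketch of driving the Tango class above directly, without the Gradio UI, to batch-generate clips for several prompts. It assumes the class has been moved into an importable module (importing app.py as-is would launch the interface at import time); the module name `tango_model`, the prompt list, and the output filenames are all illustrative.

import wavio

# Hypothetical: Tango as defined in app.py above, importable without side effects
from tango_model import Tango

tango = Tango()  # downloads declare-lab/tango and loads the weights on CPU

prompts = [
    "A Dog Barking",
    "A loud thunderstorm",
]

# With samples=1, generate_for_batch returns one waveform per prompt
waves = tango.generate_for_batch(prompts, steps=100, samples=1)

for i, wave in enumerate(waves):
    # Same sample rate and sample width as gradio_generate above
    wavio.write("output_{}.wav".format(i), wave, rate=22050, sampwidth=2)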
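
The constructor defaults to device="cpu", matching this commit's CPU-inference goal, but it accepts any torch device string and loads the checkpoints with map_location=device. A small sketch of picking a device at startup, assuming a CUDA-enabled PyTorch build where a GPU is available:

import torch

# Fall back to CPU on machines without a GPU
device = "cuda:0" if torch.cuda.is_available() else "cpu"

# Tango.__init__ moves the VAE, STFT, and diffusion model to this device,
# so no other changes to app.py are needed
tango = Tango(name="declare-lab/tango", device=device)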