AP1075 committed on
Commit
b4aec50
1 Parent(s): 8a3381b

Adding app.py for CPU inference

Browse files
Files changed (1) hide show
  1. app.py +106 -0
app.py ADDED
@@ -0,0 +1,106 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import json
3
+ import torch
4
+ import wavio
5
+ from tqdm import tqdm
6
+ from huggingface_hub import snapshot_download
7
+ from models import AudioDiffusion, DDPMScheduler
8
+ from audioldm.audio.stft import TacotronSTFT
9
+ from audioldm.variational_autoencoder import AutoencoderKL
10
+ from gradio import Markdown
11
+
12
class Tango:
    """Text-to-audio wrapper bundling a VAE, an STFT module and a diffusion model.

    The constructor fetches a checkpoint snapshot from the Hugging Face Hub,
    builds the three sub-modules from their JSON configs, loads their weights
    and switches them to eval mode. `generate` and `generate_for_batch` then
    run inference without gradient tracking.
    """

    def __init__(self, name="declare-lab/tango", device="cpu"):
        """Download the checkpoint `name` and load every component onto `device`.

        Args:
            name: Hub repository id passed to `snapshot_download`.
            device: Device the modules are moved to and the weights are mapped to.
        """
        path = snapshot_download(repo_id=name)

        vae_config = self._read_json("{}/vae_config.json".format(path))
        stft_config = self._read_json("{}/stft_config.json".format(path))
        main_config = self._read_json("{}/main_config.json".format(path))

        self.vae = AutoencoderKL(**vae_config).to(device)
        self.stft = TacotronSTFT(**stft_config).to(device)
        self.model = AudioDiffusion(**main_config).to(device)

        # NOTE(review): torch.load unpickles the files, so only point `name`
        # at repositories you trust.
        vae_weights = torch.load("{}/pytorch_model_vae.bin".format(path), map_location=device)
        stft_weights = torch.load("{}/pytorch_model_stft.bin".format(path), map_location=device)
        main_weights = torch.load("{}/pytorch_model_main.bin".format(path), map_location=device)

        self.vae.load_state_dict(vae_weights)
        self.stft.load_state_dict(stft_weights)
        self.model.load_state_dict(main_weights)

        print("Successfully loaded checkpoint from:", name)

        # Inference only: put every module in eval mode.
        self.vae.eval()
        self.stft.eval()
        self.model.eval()

        self.scheduler = DDPMScheduler.from_pretrained(main_config["scheduler_name"], subfolder="scheduler")

    @staticmethod
    def _read_json(file_path):
        """Parse the JSON file at `file_path`, closing the handle afterwards."""
        with open(file_path, encoding="utf-8") as handle:
            return json.load(handle)

    def chunks(self, lst, n):
        """Yield successive n-sized chunks from a list."""
        for i in range(0, len(lst), n):
            yield lst[i:i + n]

    def generate(self, prompt, steps=100, guidance=3, samples=1, disable_progress=True):
        """Generate audio for a single prompt string.

        Args:
            prompt: Text prompt; it is wrapped in a one-element list for the model.
            steps, guidance, samples: Forwarded positionally to `self.model.inference`.
            disable_progress: Forwarded to `self.model.inference`.

        Returns:
            The first element of the decoded waveform batch. Only index 0 is
            returned regardless of `samples`.
        """
        with torch.no_grad():
            latents = self.model.inference([prompt], self.scheduler, steps, guidance, samples, disable_progress=disable_progress)
            mel = self.vae.decode_first_stage(latents)
            wave = self.vae.decode_to_waveform(mel)
        return wave[0]

    def generate_for_batch(self, prompts, steps=200, guidance=3, samples=1, batch_size=8, disable_progress=True):
        """Generate audio for a list of prompt strings.

        Prompts are processed `batch_size` at a time.

        Returns:
            A flat list of waveforms when `samples == 1`; otherwise a list of
            lists, each holding `samples` consecutive waveforms.
        """
        outputs = []
        for k in tqdm(range(0, len(prompts), batch_size)):
            batch = prompts[k: k + batch_size]
            with torch.no_grad():
                latents = self.model.inference(batch, self.scheduler, steps, guidance, samples, disable_progress=disable_progress)
                mel = self.vae.decode_first_stage(latents)
                wave = self.vae.decode_to_waveform(mel)
            outputs.extend(wave)
        if samples == 1:
            return outputs
        return list(self.chunks(outputs, samples))
68
+
69
# Build the shared Tango instance once at import time (default checkpoint and
# device); the request handler below reuses it for every call.
tango = Tango()
71
+
72
def gradio_generate(prompt):
    """Generate audio for `prompt` and return the path of the WAV file written.

    Each call writes to its own temporary file, so overlapping requests cannot
    overwrite one another's output.

    Args:
        prompt: Text prompt forwarded to `tango.generate`.

    Returns:
        Filesystem path of the generated 16-bit WAV file.
    """
    import tempfile  # local import keeps this block self-contained

    output_wave = tango.generate(prompt)

    # Reserve a unique file name; delete=False keeps the file on disk after
    # the handle is closed so the caller can read it by path.
    with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as handle:
        output_filename = handle.name

    # NOTE(review): confirm 22050 Hz matches the sampling rate the model was
    # trained with (see the checkpoint's stft config); a mismatch changes
    # playback speed and pitch.
    wavio.write(output_filename, output_wave, rate=22050, sampwidth=2)

    return output_filename
81
+
82
# Long-form project blurb. It is defined here but not passed to the interface
# below, which uses its own short description string.
description_text = '''
TANGO is a latent diffusion model (LDM) for text-to-audio (TTA) generation. TANGO can generate realistic audios including human sounds, animal sounds, natural and artificial sounds and sound effects from textual prompts. We use the frozen instruction-tuned LLM Flan-T5 as the text encoder and train a UNet based diffusion model for audio generation. We perform comparably to current state-of-the-art models for TTA across both objective and subjective metrics, despite training the LDM on a 63 times smaller dataset. We release our model, training, inference code, and pre-trained checkpoints for the research community.
'''

# UI widgets: a two-line prompt box in, a file-path-backed audio player out.
input_text = gr.inputs.Textbox(label="Prompt", lines=2)
output_audio = gr.outputs.Audio(type="filepath", label="Generated Audio")

# Wire the widgets to the generation callback.
gr_interface = gr.Interface(
    title="Tango Audio Generator",
    description="Generate audio using Tango model by providing a text prompt.",
    fn=gradio_generate,
    inputs=input_text,
    outputs=[output_audio],
    examples=[["A Dog Barking"], ["A loud thunderstorm"]],
    allow_flagging=False,
)

# Start serving the app.
gr_interface.launch()