AP1075 committed on
Commit 4802060
1 Parent(s): de5b49d

Delete app.py

Files changed (1)
  1. app.py +0 -106
app.py DELETED
@@ -1,106 +0,0 @@
-import gradio as gr
-import json
-import torch
-import wavio
-from tqdm import tqdm
-from huggingface_hub import snapshot_download
-from models import AudioDiffusion, DDPMScheduler
-from audioldm.audio.stft import TacotronSTFT
-from audioldm.variational_autoencoder import AutoencoderKL
-from gradio import Markdown
-
-class Tango:
-    def __init__(self, name="declare-lab/tango", device="cpu"):
-
-        path = snapshot_download(repo_id=name)
-
-        vae_config = json.load(open("{}/vae_config.json".format(path)))
-        stft_config = json.load(open("{}/stft_config.json".format(path)))
-        main_config = json.load(open("{}/main_config.json".format(path)))
-
-        self.vae = AutoencoderKL(**vae_config).to(device)
-        self.stft = TacotronSTFT(**stft_config).to(device)
-        self.model = AudioDiffusion(**main_config).to(device)
-
-        vae_weights = torch.load("{}/pytorch_model_vae.bin".format(path), map_location=device)
-        stft_weights = torch.load("{}/pytorch_model_stft.bin".format(path), map_location=device)
-        main_weights = torch.load("{}/pytorch_model_main.bin".format(path), map_location=device)
-
-        self.vae.load_state_dict(vae_weights)
-        self.stft.load_state_dict(stft_weights)
-        self.model.load_state_dict(main_weights)
-
-        print("Successfully loaded checkpoint from:", name)
-
-        self.vae.eval()
-        self.stft.eval()
-        self.model.eval()
-
-        self.scheduler = DDPMScheduler.from_pretrained(main_config["scheduler_name"], subfolder="scheduler")
-
-    def chunks(self, lst, n):
-        """ Yield successive n-sized chunks from a list. """
-        for i in range(0, len(lst), n):
-            yield lst[i:i + n]
-
-    def generate(self, prompt, steps=100, guidance=3, samples=1, disable_progress=True):
-        """ Generate audio for a single prompt string. """
-        with torch.no_grad():
-            latents = self.model.inference([prompt], self.scheduler, steps, guidance, samples, disable_progress=disable_progress)
-            mel = self.vae.decode_first_stage(latents)
-            wave = self.vae.decode_to_waveform(mel)
-        return wave[0]
-
-    def generate_for_batch(self, prompts, steps=200, guidance=3, samples=1, batch_size=8, disable_progress=True):
-        """ Generate audio for a list of prompt strings. """
-        outputs = []
-        for k in tqdm(range(0, len(prompts), batch_size)):
-            batch = prompts[k: k+batch_size]
-            with torch.no_grad():
-                latents = self.model.inference(batch, self.scheduler, steps, guidance, samples, disable_progress=disable_progress)
-                mel = self.vae.decode_first_stage(latents)
-                wave = self.vae.decode_to_waveform(mel)
-                outputs += [item for item in wave]
-        if samples == 1:
-            return outputs
-        else:
-            return list(self.chunks(outputs, samples))
-
-# Initialize Tango model
-tango = Tango()
-
-def gradio_generate(prompt):
-
-    output_wave = tango.generate(prompt)
-
-    # Save the output_wave as a temporary WAV file
-    output_filename = "temp_output.wav"
-    wavio.write(output_filename, output_wave, rate=22050, sampwidth=2)
-
-    return output_filename
-
-# Add the description text box
-description_text = '''
-TANGO is a latent diffusion model (LDM) for text-to-audio (TTA) generation. TANGO can generate realistic audios including human sounds, animal sounds, natural and artificial sounds and sound effects from textual prompts. We use the frozen instruction-tuned LLM Flan-T5 as the text encoder and train a UNet based diffusion model for audio generation. We perform comparably to current state-of-the-art models for TTA across both objective and subjective metrics, despite training the LDM on a 63 times smaller dataset. We release our model, training, inference code, and pre-trained checkpoints for the research community.
-'''
-
-# Define Gradio input and output components
-input_text = gr.inputs.Textbox(lines=2, label="Prompt")
-output_audio = gr.outputs.Audio(label="Generated Audio", type="filepath")
-
-# Create Gradio interface
-gr_interface = gr.Interface(
-    fn=gradio_generate,
-    inputs=input_text,
-    outputs=[output_audio],
-    title="Tango Audio Generator",
-    description="Generate audio using Tango model by providing a text prompt.",
-    allow_flagging=False,
-    examples=[
-        ["A Dog Barking"],
-        ["A loud thunderstorm"],
-    ],
-)
-
-# Launch Gradio app
-gr_interface.launch()
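The deleted app.py wraps model download, loading, and inference in the Tango class, so the same generation path can be driven without the Gradio UI. A minimal sketch, assuming a local copy of the deleted file (the `from app import Tango` import is hypothetical) and that the Space's `models` and `audioldm` packages are importable:

```python
import wavio
from app import Tango  # hypothetical: requires keeping a local copy of the deleted app.py

# Same constructor arguments as in the deleted file.
tango = Tango(name="declare-lab/tango", device="cpu")

# Single prompt -> waveform, written out the same way gradio_generate did.
wave = tango.generate("A loud thunderstorm", steps=100, guidance=3)
wavio.write("thunderstorm.wav", wave, rate=22050, sampwidth=2)

# Batched generation over several prompts.
waves = tango.generate_for_batch(["A Dog Barking", "A loud thunderstorm"], steps=200, batch_size=8)
```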