Spaces:
Running
on
Zero
Running
on
Zero
File size: 8,625 Bytes
a89c362 b409041 a89c362 629d1bf c43b5ca b409041 c43b5ca 629d1bf a89c362 7d00b29 629d1bf 49bb4d0 9a7456a 629d1bf f4c1365 a89c362 629d1bf a89c362 361d70a 101c1cd 629d1bf c43b5ca 629d1bf a89c362 629d1bf d1abaff 629d1bf d1abaff 629d1bf f342085 6116af2 f342085 629d1bf 101c1cd 629d1bf d1abaff 629d1bf 101c1cd 629d1bf 101c1cd 629d1bf 101c1cd 629d1bf 101c1cd 629d1bf 101c1cd 629d1bf d1abaff 629d1bf 101c1cd 629d1bf 101c1cd 629d1bf 101c1cd 629d1bf d1abaff 629d1bf 2c5a278 ab6542d 629d1bf a89c362 d1abaff 629d1bf |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 |
import spaces
import os
import json
import numpy as np
import torch
import soundfile as sf
import gradio as gr
from diffusers import DDPMScheduler
from pico_model import PicoDiffusion
from audioldm.variational_autoencoder.autoencoder import AutoencoderKL
from llm_preprocess import get_event, preprocess_gemini, preprocess_gpt
class dotdict(dict):
    """A ``dict`` whose items can also be used as attributes.

    * Reading an attribute that is not a real attribute looks the key up with
      ``dict.get``, so a missing key yields ``None`` instead of raising
      ``AttributeError``.
    * Assigning an attribute stores the key.
    * Deleting an attribute deletes the key (``KeyError`` if it is absent).
    """

    def __getattr__(self, name, default=None):
        # Only called after normal attribute lookup fails, so real dict methods
        # (keys, items, ...) are still found first.
        return dict.get(self, name, default)

    def __setattr__(self, name, value):
        dict.__setitem__(self, name, value)

    def __delattr__(self, name):
        dict.__delitem__(self, name)
class InferRunner:
    """Load the models used to turn a timestamp caption into audio.

    On construction the following are loaded from ``ckpt_dir``:

    * ``self.vae``: an ``AutoencoderKL`` built from ``ldm/vae_config.json``
      with weights from ``ldm/pytorch_model_vae.bin``.
    * ``self.pico_model``: a ``PicoDiffusion``. It is configured from the JSON
      object on the first line of ``pico_model/summary.jsonl`` and uses the
      ``laion_clap/630k-audioset-best.pt`` text-encoder checkpoint and
      ``pico_model/diffusion.pt``.
    * ``self.scheduler``: a ``DDPMScheduler`` loaded from the ``scheduler``
      subfolder of the configured ``scheduler_name``.

    Args:
        device: Device the VAE and diffusion model are moved to (e.g. ``"cuda"``).
        ckpt_dir: Root checkpoint directory. Defaults to ``"ckpts"``, resolved
            relative to the current working directory.
    """

    def __init__(self, device, ckpt_dir="ckpts"):
        ldm_dir = os.path.join(ckpt_dir, "ldm")
        pico_dir = os.path.join(ckpt_dir, "pico_model")

        # Use `with` so the config file handle is closed after parsing.
        with open(os.path.join(ldm_dir, "vae_config.json")) as f:
            vae_config = json.load(f)
        self.vae = AutoencoderKL(**vae_config)
        # NOTE(review): only a state dict is loaded here, so `weights_only=True`
        # (torch >= 1.13) would be safer if the deployed torch supports it.
        vae_weights = torch.load(os.path.join(ldm_dir, "pytorch_model_vae.bin"), map_location="cpu")
        self.vae.load_state_dict(vae_weights)
        # The VAE is only used for inference-time decoding. Put it in eval mode,
        # as is already done for the diffusion model below.
        self.vae = self.vae.eval().to(device)

        # Only the first line of summary.jsonl is used; it is parsed as the
        # training arguments, so don't read the whole file.
        with open(os.path.join(pico_dir, "summary.jsonl")) as f:
            train_args = dotdict(json.loads(f.readline()))

        self.pico_model = PicoDiffusion(
            scheduler_name=train_args.scheduler_name,
            unet_model_config_path=train_args.unet_model_config,
            snr_gamma=train_args.snr_gamma,
            freeze_text_encoder_ckpt=os.path.join(ckpt_dir, "laion_clap", "630k-audioset-best.pt"),
            diffusion_pt=os.path.join(pico_dir, "diffusion.pt"),
        ).eval().to(device)
        self.scheduler = DDPMScheduler.from_pretrained(train_args.scheduler_name, subfolder="scheduler")
# Always target CUDA. A CPU fallback would be
#   "cuda" if torch.cuda.is_available() else "cpu"
# but it is deliberately not used. NOTE(review): presumably because this Space
# runs on ZeroGPU (see the @spaces.GPU decorator), where CUDA is requested
# at import time. Confirm before re-enabling it.
device = "cuda"
runner = InferRunner(device=device)
event_list = get_event()
@spaces.GPU(duration=240)
def infer(caption, num_steps=200, guidance_scale=3.0, audio_len=16000*10):
    """Generate audio for a timestamp caption and save it as a WAV file.

    Args:
        caption: Timestamp caption, e.g. ``"dog_barking at 0.562-2.562_4.25-6.25."``.
        num_steps: Number of sampling steps passed to ``demo_inference``.
        guidance_scale: Guidance scale passed to ``demo_inference``.
        audio_len: Maximum number of samples kept from the decoded waveform
            (default 10 s at the 16 kHz rate used when writing the file).

    Returns:
        Path to a newly created 16 kHz, 16-bit PCM WAV file.
    """
    import tempfile  # local import keeps this change self-contained

    # Gradio slider values may arrive as floats; coerce them to the numeric
    # types the sampler presumably expects.
    num_steps = int(num_steps)
    guidance_scale = float(guidance_scale)

    with torch.no_grad():
        latents = runner.pico_model.demo_inference(
            caption,
            runner.scheduler,
            num_steps=num_steps,
            guidance_scale=guidance_scale,
            num_samples_per_prompt=1,
            disable_progress=True,
        )
        mel = runner.vae.decode_first_stage(latents)
        wave = runner.vae.decode_to_waveform(mel)[0][:audio_len]

    # Write a unique file per request. A shared "output.wav" could be
    # overwritten by a concurrent request before Gradio serves it.
    # delete=False because Gradio reads the file after this function returns.
    with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
        outpath = tmp.name
    sf.write(outpath, wave, samplerate=16000, subtype='PCM_16')
    return outpath
def preprocess(caption):
    """Convert a free-text caption into a timestamp caption via ``preprocess_gemini``.

    The converted caption is returned twice, so a single Gradio event can fill
    two output components with the same text.
    """
    timestamp_caption = preprocess_gemini(caption)
    return (timestamp_caption,) * 2
def update_textbox(event_name, current_text):
    """Append ``"<event_name> two times."`` to the free-text prompt.

    If ``current_text`` is empty, ``None`` or only whitespace/periods, the event
    phrase becomes the whole prompt. Otherwise it is chained onto the trimmed
    text with ``" then "``.

    Args:
        event_name: Label of the clicked event button.
        current_text: Current contents of the free-text prompt box.

    Returns:
        The new prompt text.
    """
    event = event_name + ' two times.'
    # Trim whitespace as well as periods at both ends. Trimming only '.' turned
    # "a dog barks. " into "a dog barks.  then ..." and "  " into "   then ...".
    base = (current_text or "").strip(" \t\r\n.")
    if base:
        return base + ' then ' + event
    return event
with gr.Blocks() as demo:
    with gr.Row():
        gr.Markdown("## PicoAudio")
    with gr.Row():
        gr.Markdown("""
[![arXiv](https://img.shields.io/badge/arXiv-2407.02869v2-brightgreen.svg?style=flat-square)](https://arxiv.org/abs/2407.02869v2)
[![githubio](https://img.shields.io/badge/GitHub.io-Audio_Samples-blue?logo=Github&style=flat-square)](https://zeyuxie29.github.io/PicoAudio.github.io/)
[![github](https://img.shields.io/badge/GitHub-Code-blue?logo=Github&style=flat-square)](https://github.com/zeyuxie29/PicoAudio)
[![Hugging Face data](https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Dataset-blue)](https://huggingface.co/datasets/ZeyuXie/PicoAudio/tree/main)
""")

    # ---- Event quick-insert buttons ----------------------------------------
    # One button per entry of `event_list`, `events_per_row` per row. The count
    # and rows come from the list rather than a hard-coded 18 / 3x6, so the
    # layout follows whatever get_event() returns.
    events_per_row = 6
    with gr.Row():
        description_text = f"{len(event_list)} events supported:"
        gr.Markdown(description_text)
    btn_event = []
    for row_start in range(0, len(event_list), events_per_row):
        with gr.Row():
            for event_name in event_list[row_start:row_start + events_per_row]:
                btn_event.append(gr.Button(f"{event_name}"))

    # ---- Step 1: free text -> timestamp caption (LLM) ----------------------
    with gr.Row():
        gr.Markdown("## Step1-Preprocess")
    with gr.Row():
        preprocess_description_text = (
            "Transfer free-text into timestamp caption via LLM. "
            "This demo uses Gemini as the preprocessor. If any errors occur, please try a few more times. "
            # The GPT preprocessor (preprocess_gpt) is imported from the
            # llm_preprocess module, so point users at llm_preprocess.py.
            "We also provide the GPT version consistent with the paper in the file 'Files/llm_preprocess.py'. You can use your own api_key to modify and run 'Files/inference.py' for local inference."
        )
        gr.Markdown(preprocess_description_text)
    with gr.Row():
        with gr.Column():
            freetext_prompt = gr.Textbox(label="Free-text Prompt: Input your free-text caption here. (e.g. a dog barks three times.)",
                                         value="a dog barks three times.",)
            with gr.Row():
                preprocess_run_button = gr.Button()
                preprocess_run_clear = gr.ClearButton([freetext_prompt])
        with gr.Column():
            freetext_prompt_out = gr.Textbox(label="Timestamp Caption: Preprocess output")
    with gr.Row():
        with gr.Column():
            # Examples only fill the input box. No `fn` is attached, so no
            # `outputs` are needed (the Step-2 textbox does not exist yet here).
            gr.Examples(
                examples=[["spraying two times then gunshot three times."],
                          ["a dog barks three times."],
                          ["cow mooing two times."]],
                inputs=[freetext_prompt],
            )
        with gr.Column():
            pass

    # ---- Step 2: timestamp caption -> audio --------------------------------
    with gr.Row():
        gr.Markdown("## Step2-Generate")
    with gr.Row():
        generate_description_text = "Generate audio based on timestamp caption."
        gr.Markdown(generate_description_text)
    with gr.Row():
        with gr.Column():
            prompt = gr.Textbox(label="Timestamp Caption: Specify your timestamp caption formatted as 'event1 at onset1-offset1_onset2-offset2 and event2 at onset1-offset1'.",
                                value="spraying at 0.38-1.176_3.06-3.856 and gunshot at 1.729-3.729_4.367-6.367_7.031-9.031.",)
            with gr.Row():
                generate_run_button = gr.Button()
                generate_run_clear = gr.ClearButton([prompt])
            with gr.Accordion("Advanced options", open=False):
                num_steps = gr.Slider(label="num_steps", minimum=1, maximum=300, value=200, step=1)
                guidance_scale = gr.Slider(label="guidance_scale", minimum=0.1, maximum=8.0, value=3.0, step=0.1)
        with gr.Column():
            outaudio = gr.Audio()

    # ---- Event wiring ------------------------------------------------------
    for button, event_name in zip(btn_event, event_list):
        # gr.State passes this button's event name into update_textbox.
        button.click(fn=update_textbox, inputs=[gr.State(f"{event_name}"), freetext_prompt], outputs=freetext_prompt)
    # Step 1 writes its result to both the Step-2 caption box and the Step-1 output box.
    preprocess_run_button.click(fn=preprocess, inputs=[freetext_prompt], outputs=[prompt, freetext_prompt_out])
    generate_run_button.click(fn=infer, inputs=[prompt, num_steps, guidance_scale], outputs=[outaudio])

    with gr.Row():
        with gr.Column():
            # Each example supplies only a caption, so `prompt` is the only
            # input. No `fn` is attached, so no `outputs` are needed.
            gr.Examples(
                examples=[["spraying at 0.38-1.176_3.06-3.856 and gunshot at 1.729-3.729_4.367-6.367_7.031-9.031."],
                          ["dog_barking at 0.562-2.562_4.25-6.25."],
                          ["cow_mooing at 0.958-3.582_5.272-7.896."],
                          ["tapping_clicking_clanking at 0.579-4.019_5.882-9.322"],
                          ["duck_quacking at 1.51-2.51_4.904-5.904"],
                          ],
                inputs=[prompt],
            )
        with gr.Column():
            pass
demo.launch()
# description_text = f"18 events: {', '.join(event_list)}"
# prompt = gr.Textbox(label="Prompt: Input your caption formatted as 'event1 at onset1-offset1_onset2-offset2 and event2 at onset1-offset1'.",
# value="spraying at 0.38-1.176_3.06-3.856 and gunshot at 1.729-3.729_4.367-6.367_7.031-9.031.",)
# outaudio = gr.Audio()
# num_steps = gr.Slider(label="num_steps", minimum=1, maximum=300, value=200, step=1)
# guidance_scale = gr.Slider(label="guidance_scale", minimum=0.1, maximum=8.0, value=3.0, step=0.1)
# gr_interface = gr.Interface(
# fn=infer,
# inputs=[prompt, num_steps, guidance_scale],
# outputs=[outaudio],
# title="PicoAudio",
# description=description_text,
# allow_flagging=False,
# examples=[
# ["spraying at 0.38-1.176_3.06-3.856 and gunshot at 1.729-3.729_4.367-6.367_7.031-9.031."],
# ["dog_barking at 0.562-2.562_4.25-6.25."],
# ["cow_mooing at 0.958-3.582_5.272-7.896."],
# ],
# cache_examples="lazy", # Turn on to cache.
# )
# gr_interface.queue(10).launch() |