import torch
import torchaudio
import os
from einops import rearrange
import gc
import spaces
import gradio as gr
from stable_audio_tools import get_pretrained_model
from stable_audio_tools.inference.generation import generate_diffusion_cond
from stable_audio_tools.data.utils import read_video, merge_video_audio, load_and_process_audio
# audio_spectrogram_image is used by the preview callback below; stable-audio-tools
# takes it from the aeiou package.
from aeiou.viz import audio_spectrogram_image
import stat
import platform
import logging
from transformers import logging as transformers_logging

# Silence verbose transformers logging in the demo.
transformers_logging.set_verbosity_error()
logging.getLogger("transformers").setLevel(logging.ERROR)
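
# Load the pretrained AudioX model once at startup; its config provides the native
# sample rate and sample (window) size used for generation.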
model, model_config = get_pretrained_model('HKUSTAudio/AudioX')
sample_rate = model_config["sample_rate"]
sample_size = model_config["sample_size"]

TEMP_DIR = "tmp/gradio"
os.makedirs(TEMP_DIR, exist_ok=True)
os.chmod(TEMP_DIR, stat.S_IRWXU | stat.S_IRWXG | stat.S_IRWXO)

VIDEO_TEMP_DIR = os.path.join(TEMP_DIR, "videos")
os.makedirs(VIDEO_TEMP_DIR, exist_ok=True)
os.chmod(VIDEO_TEMP_DIR, stat.S_IRWXU | stat.S_IRWXG | stat.S_IRWXO)


@spaces.GPU(duration=10)
def generate_cond(
    prompt,
    negative_prompt=None,
    video_file=None,
    audio_prompt_file=None,
    audio_prompt_path=None,
    seconds_start=0,
    seconds_total=10,
    cfg_scale=7.0,
    steps=100,
    preview_every=0,
    seed=-1,
    sampler_type="dpmpp-3m-sde",
    sigma_min=0.03,
    sigma_max=500,
    cfg_rescale=0.0,
    use_init=False,
    init_audio=None,
    init_noise_level=0.1,
    mask_cropfrom=None,
    mask_pastefrom=None,
    mask_pasteto=None,
    mask_maskstart=None,
    mask_maskend=None,
    mask_softnessL=None,
    mask_softnessR=None,
    mask_marination=None,
    batch_size=1
):
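    """Generate audio conditioned on text, an optional video clip, and an optional audio
    prompt, then return (output_video_path, output_audio_path) for the Gradio outputs."""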
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    gc.collect()

    print(f"Prompt: {prompt}")

    preview_images = []
    if preview_every == 0:
        preview_every = None
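
    # Prefer Apple-Silicon MPS when available, otherwise CUDA, otherwise CPU.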
    try:
        has_mps = platform.system() == "Darwin" and torch.backends.mps.is_available()
    except Exception:
        has_mps = False

    if has_mps:
        device = torch.device("mps")
    elif torch.cuda.is_available():
        device = torch.device("cuda")
    else:
        device = torch.device("cpu")

    global model
    model = model.to(device)

    target_fps = model_config.get("video_fps", 5)
    model_type = model_config.get("model_type", "diffusion_cond")

    if video_file is not None:
        # Newer Gradio versions pass uploads as plain path strings; older ones pass dicts or tempfile objects.
        if isinstance(video_file, str):
            actual_video_path = video_file
        else:
            actual_video_path = video_file['name'] if isinstance(video_file, dict) else video_file.name
    else:
        actual_video_path = None

    if audio_prompt_file is not None:
        audio_path = audio_prompt_file if isinstance(audio_prompt_file, str) else audio_prompt_file.name
    elif audio_prompt_path:
        audio_path = audio_prompt_path.strip()
    else:
        audio_path = None

    video_tensors = read_video(actual_video_path, seek_time=seconds_start, duration=seconds_total, target_fps=target_fps)
    audio_tensor = load_and_process_audio(audio_path, sample_rate, seconds_start, seconds_total)
    audio_tensor = audio_tensor.to(device)

    seconds_input = sample_size / sample_rate

    if not prompt:
        prompt = ""
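
    # Conditioning for AudioX: video frames, text, optional audio prompt, and timing info,
    # matching the conditioner names declared in the model config.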
    conditioning = [{
        "video_prompt": [video_tensors.unsqueeze(0)],
        "text_prompt": prompt,
        "audio_prompt": audio_tensor.unsqueeze(0),
        "seconds_start": seconds_start,
        "seconds_total": seconds_input
    }]

    if negative_prompt:
        negative_conditioning = [{
            "video_prompt": [video_tensors.unsqueeze(0)],
            "text_prompt": negative_prompt,
            "audio_prompt": audio_tensor.unsqueeze(0),
            "seconds_start": seconds_start,
            "seconds_total": seconds_total
        }]
    else:
        negative_conditioning = None

    seed = int(seed)

    if not use_init:
        init_audio = None

    input_sample_size = sample_size
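
    # Preview callback: every `preview_every` steps decode the current denoised latents to audio
    # and store a spectrogram image (collected in preview_images; not wired to a UI output here).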
    def progress_callback(callback_info):
        nonlocal preview_images
        denoised = callback_info["denoised"]
        current_step = callback_info["i"]
        sigma = callback_info["sigma"]
        if (current_step - 1) % preview_every == 0:
            if model.pretransform is not None:
                denoised = model.pretransform.decode(denoised)
            denoised = rearrange(denoised, "b d n -> d (b n)")
            denoised = denoised.clamp(-1, 1).mul(32767).to(torch.int16).cpu()
            audio_spectrogram = audio_spectrogram_image(denoised, sample_rate=sample_rate)
            preview_images.append((audio_spectrogram, f"Step {current_step} sigma={sigma:.3f}"))
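
    # AudioX is shipped as a "diffusion_cond" model; run classifier-free-guided diffusion sampling.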
    if model_type == "diffusion_cond":
        audio = generate_diffusion_cond(
            model,
            conditioning=conditioning,
            negative_conditioning=negative_conditioning,
            steps=steps,
            cfg_scale=cfg_scale,
            batch_size=batch_size,
            sample_size=input_sample_size,
            sample_rate=sample_rate,
            seed=seed,
            device=device,
            sampler_type=sampler_type,
            sigma_min=sigma_min,
            sigma_max=sigma_max,
            init_audio=init_audio,
            init_noise_level=init_noise_level,
            mask_args=None,
            callback=progress_callback if preview_every is not None else None,
            scale_phi=cfg_rescale
        )
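
    # Post-process: collapse the batch dimension, keep the first 10 seconds,
    # peak-normalize, and convert to 16-bit PCM before saving the WAV.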
    audio = rearrange(audio, "b d n -> d (b n)")

    samples_10s = 10 * sample_rate
    audio = audio[:, :samples_10s]
    audio = audio.to(torch.float32).div(torch.max(torch.abs(audio))).clamp(-1, 1).mul(32767).to(torch.int16).cpu()

    output_dir = "demo_result"
    os.makedirs(output_dir, exist_ok=True)
    output_audio_path = f"{output_dir}/output.wav"
    torchaudio.save(output_audio_path, audio, sample_rate)

    if actual_video_path:
        # Mux the generated audio back into the input video clip.
        output_video_path = f"{output_dir}/{os.path.basename(actual_video_path)}"
        merge_video_audio(
            actual_video_path,
            output_audio_path,
            output_video_path,
            seconds_start,
            seconds_total
        )
    else:
        output_video_path = None

    del actual_video_path
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    gc.collect()

    return output_video_path, output_audio_path
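

# Build the Gradio demo interface.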
with gr.Blocks() as interface:
    gr.Markdown(
        """
        # 🎧AudioX: Diffusion Transformer for Anything-to-Audio Generation
        **[Paper](https://arxiv.org/abs/2503.10522) · [Project Page](https://zeyuet.github.io/AudioX/) · [Huggingface](https://huggingface.co/HKUSTAudio/AudioX) · [GitHub](https://github.com/ZeyueT/AudioX)**
        """
    )

    with gr.Tab("Generation"):
        with gr.Row():
            with gr.Column():
                prompt = gr.Textbox(
                    show_label=False,
                    placeholder="Enter your prompt"
                )
                negative_prompt = gr.Textbox(
                    show_label=False,
                    placeholder="Negative prompt",
                    visible=False
                )
                video_file = gr.File(label="Upload Video File")
                audio_prompt_file = gr.File(
                    label="Upload Audio Prompt File",
                    visible=False
                )
                audio_prompt_path = gr.Textbox(
                    label="Audio Prompt Path",
                    placeholder="Enter audio file path",
                    visible=False
                )

        with gr.Row():
            with gr.Column(scale=6):
                with gr.Accordion("Video Params", open=False):
                    seconds_start = gr.Slider(
                        minimum=0,
                        maximum=512,
                        step=1,
                        value=0,
                        label="Video Seconds Start"
                    )
                    seconds_total = gr.Slider(
                        minimum=0,
                        maximum=10,
                        step=1,
                        value=10,
                        label="Seconds Total",
                        interactive=False
                    )

        with gr.Row():
            with gr.Column(scale=4):
                with gr.Accordion("Sampler Params", open=False):
                    steps = gr.Slider(
                        minimum=1,
                        maximum=500,
                        step=1,
                        value=100,
                        label="Steps"
                    )
                    preview_every = gr.Slider(
                        minimum=0,
                        maximum=100,
                        step=1,
                        value=0,
                        label="Preview Every"
                    )
                    cfg_scale = gr.Slider(
                        minimum=0.0,
                        maximum=25.0,
                        step=0.1,
                        value=7.0,
                        label="CFG Scale"
                    )
                    seed = gr.Textbox(
                        label="Seed (set to -1 for random seed)",
                        value="-1"
                    )
                    sampler_type = gr.Dropdown(
                        choices=[
                            "dpmpp-2m-sde",
                            "dpmpp-3m-sde",
                            "k-heun",
                            "k-lms",
                            "k-dpmpp-2s-ancestral",
                            "k-dpm-2",
                            "k-dpm-fast"
                        ],
                        label="Sampler Type",
                        value="dpmpp-3m-sde"
                    )
                    sigma_min = gr.Slider(
                        minimum=0.0,
                        maximum=2.0,
                        step=0.01,
                        value=0.03,
                        label="Sigma Min"
                    )
                    sigma_max = gr.Slider(
                        minimum=0.0,
                        maximum=1000.0,
                        step=0.1,
                        value=500,
                        label="Sigma Max"
                    )
                    cfg_rescale = gr.Slider(
                        minimum=0.0,
                        maximum=1.0,
                        step=0.01,
                        value=0.0,
                        label="CFG Rescale Amount"
                    )

        with gr.Row():
            with gr.Column(scale=4):
                with gr.Accordion("Init Audio", open=False, visible=False):
                    init_audio_checkbox = gr.Checkbox(label="Use Init Audio")
                    init_audio_input = gr.Audio(label="Init Audio")
                    init_noise_level = gr.Slider(
                        minimum=0.1,
                        maximum=100.0,
                        step=0.01,
                        value=0.1,
                        label="Init Noise Level"
                    )

        with gr.Row():
            generate_button = gr.Button("Generate", variant="primary")

        with gr.Row():
            with gr.Column(scale=6):
                video_output = gr.Video(label="Output Video", interactive=False)
                audio_output = gr.Audio(label="Output Audio", interactive=False)
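
        # The order here must match the positional parameters of generate_cond;
        # the mask_* arguments and batch_size are left at their defaults.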
        inputs = [
            prompt,
            negative_prompt,
            video_file,
            audio_prompt_file,
            audio_prompt_path,
            seconds_start,
            seconds_total,
            cfg_scale,
            steps,
            preview_every,
            seed,
            sampler_type,
            sigma_min,
            sigma_max,
            cfg_rescale,
            init_audio_checkbox,
            init_audio_input,
            init_noise_level
        ]

        generate_button.click(
            fn=generate_cond,
            inputs=inputs,
            outputs=[video_output, audio_output]
        )

    gr.Markdown("## Examples")

    with gr.Accordion("Click to show examples", open=False):
        with gr.Row():
            gr.Markdown("**📝 Task: Text-to-Audio**")
            with gr.Column(scale=1.2):
                gr.Markdown("Prompt: *Typing on a keyboard*")
                ex1 = gr.Button("Load Example")
            with gr.Column(scale=1.2):
                gr.Markdown("Prompt: *Ocean waves crashing*")
                ex2 = gr.Button("Load Example")
            with gr.Column(scale=1.2):
                gr.Markdown("Prompt: *Footsteps in snow*")
                ex3 = gr.Button("Load Example")

        with gr.Row():
            gr.Markdown("**🎶 Task: Text-to-Music**")
            with gr.Column(scale=1.2):
                gr.Markdown("Prompt: *An orchestral music piece for a fantasy world.*")
                ex4 = gr.Button("Load Example")
            with gr.Column(scale=1.2):
                gr.Markdown("Prompt: *Produce upbeat electronic music for a dance party*")
                ex5 = gr.Button("Load Example")
            with gr.Column(scale=1.2):
                gr.Markdown("Prompt: *A dreamy lo-fi beat with vinyl crackle*")
                ex6 = gr.Button("Load Example")
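
        # Each "Load Example" button fills the inputs with a preset prompt, a fixed seed,
        # and default sampler settings; the list order follows `inputs` above.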
        ex1.click(lambda: ["Typing on a keyboard", None, None, None, None, 0, 10, 7.0, 100, 0, "1225575558", "dpmpp-3m-sde", 0.03, 500, 0.0, False, None, 0.1], inputs=[], outputs=inputs)
        ex2.click(lambda: ["Ocean waves crashing", None, None, None, None, 0, 10, 7.0, 100, 0, "3615819170", "dpmpp-3m-sde", 0.03, 500, 0.0, False, None, 0.1], inputs=[], outputs=inputs)
        ex3.click(lambda: ["Footsteps in snow", None, None, None, None, 0, 10, 7.0, 100, 0, "1703896811", "dpmpp-3m-sde", 0.03, 500, 0.0, False, None, 0.1], inputs=[], outputs=inputs)
        ex4.click(lambda: ["An orchestral music piece for a fantasy world.", None, None, None, None, 0, 10, 7.0, 100, 0, "1561898939", "dpmpp-3m-sde", 0.03, 500, 0.0, False, None, 0.1], inputs=[], outputs=inputs)
        ex5.click(lambda: ["Produce upbeat electronic music for a dance party", None, None, None, None, 0, 10, 7.0, 100, 0, "406022999", "dpmpp-3m-sde", 0.03, 500, 0.0, False, None, 0.1], inputs=[], outputs=inputs)
        ex6.click(lambda: ["A dreamy lo-fi beat with vinyl crackle", None, None, None, None, 0, 10, 7.0, 100, 0, "807934770", "dpmpp-3m-sde", 0.03, 500, 0.0, False, None, 0.1], inputs=[], outputs=inputs)


interface.queue(5).launch(server_name="0.0.0.0", server_port=2159, share=True)