import gc
import logging
import os
import platform
import stat

import gradio as gr
import spaces
import torch
import torchaudio
from einops import rearrange
from transformers import logging as transformers_logging

from stable_audio_tools import get_pretrained_model
from stable_audio_tools.data.utils import read_video, merge_video_audio, load_and_process_audio
from stable_audio_tools.inference.generation import generate_diffusion_cond

# audio_spectrogram_image is called in the preview callback below but was never
# imported in the original script; it is assumed to come from the aeiou viz helpers.
from aeiou.viz import audio_spectrogram_image

# Silence verbose transformers logging.
transformers_logging.set_verbosity_error()
logging.getLogger("transformers").setLevel(logging.ERROR)

# Load the pretrained AudioX model and its configuration.
model, model_config = get_pretrained_model('HKUSTAudio/AudioX')
sample_rate = model_config["sample_rate"]
sample_size = model_config["sample_size"]

# World-writable temp directories for Gradio uploads and extracted video clips.
TEMP_DIR = "tmp/gradio"
os.makedirs(TEMP_DIR, exist_ok=True)
os.chmod(TEMP_DIR, stat.S_IRWXU | stat.S_IRWXG | stat.S_IRWXO)

VIDEO_TEMP_DIR = os.path.join(TEMP_DIR, "videos")
os.makedirs(VIDEO_TEMP_DIR, exist_ok=True)
os.chmod(VIDEO_TEMP_DIR, stat.S_IRWXU | stat.S_IRWXG | stat.S_IRWXO)


@spaces.GPU(duration=10)
def generate_cond(
    prompt,
    negative_prompt=None,
    video_file=None,
    audio_prompt_file=None,
    audio_prompt_path=None,
    seconds_start=0,
    seconds_total=10,
    cfg_scale=7.0,
    steps=100,
    preview_every=0,
    seed=-1,
    sampler_type="dpmpp-3m-sde",
    sigma_min=0.03,
    sigma_max=500,
    cfg_rescale=0.0,
    use_init=False,
    init_audio=None,
    init_noise_level=0.1,
    mask_cropfrom=None,
    mask_pastefrom=None,
    mask_pasteto=None,
    mask_maskstart=None,
    mask_maskend=None,
    mask_softnessL=None,
    mask_softnessR=None,
    mask_marination=None,
    batch_size=1
):
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    gc.collect()

    print(f"Prompt: {prompt}")

    preview_images = []
    if preview_every == 0:
        preview_every = None

    # Pick the device: Apple MPS, then CUDA, then CPU.
    try:
        has_mps = platform.system() == "Darwin" and torch.backends.mps.is_available()
    except Exception:
        has_mps = False

    if has_mps:
        device = torch.device("mps")
    elif torch.cuda.is_available():
        device = torch.device("cuda")
    else:
        device = torch.device("cpu")

    global model
    model = model.to(device)

    target_fps = model_config.get("video_fps", 5)
    model_type = model_config.get("model_type", "diffusion_cond")

    # Resolve the uploaded video path (gr.File may deliver a dict or a file object).
    if video_file is not None:
        actual_video_path = video_file['name'] if isinstance(video_file, dict) else video_file.name
    else:
        actual_video_path = None

    # Resolve the audio prompt: an uploaded file takes precedence over a path string.
    if audio_prompt_file is not None:
        audio_path = audio_prompt_file.name
    elif audio_prompt_path:
        audio_path = audio_prompt_path.strip()
    else:
        audio_path = None

    video_tensors = read_video(actual_video_path, seek_time=seconds_start, duration=seconds_total, target_fps=target_fps)
    audio_tensor = load_and_process_audio(audio_path, sample_rate, seconds_start, seconds_total)
    audio_tensor = audio_tensor.to(device)

    seconds_input = sample_size / sample_rate

    if not prompt:
        prompt = ""

    # Build the conditioning dictionaries consumed by the diffusion model.
    conditioning = [{
        "video_prompt": [video_tensors.unsqueeze(0)],
        "text_prompt": prompt,
        "audio_prompt": audio_tensor.unsqueeze(0),
        "seconds_start": seconds_start,
        "seconds_total": seconds_input
    }]

    if negative_prompt:
        negative_conditioning = [{
            "video_prompt": [video_tensors.unsqueeze(0)],
            "text_prompt": negative_prompt,
            "audio_prompt": audio_tensor.unsqueeze(0),
            "seconds_start": seconds_start,
            "seconds_total": seconds_total
        }]
    else:
        negative_conditioning = None

    seed = int(seed)

    if not use_init:
        init_audio = None

    input_sample_size = sample_size

    # Decode intermediate denoised latents into preview spectrograms.
    def progress_callback(callback_info):
        nonlocal preview_images
        denoised = callback_info["denoised"]
        current_step = callback_info["i"]
        sigma = callback_info["sigma"]

        if (current_step - 1) % preview_every == 0:
            if model.pretransform is not None:
                denoised = model.pretransform.decode(denoised)
            denoised = rearrange(denoised, "b d n -> d (b n)")
            denoised = denoised.clamp(-1, 1).mul(32767).to(torch.int16).cpu()
            audio_spectrogram = audio_spectrogram_image(denoised, sample_rate=sample_rate)
            preview_images.append((audio_spectrogram, f"Step {current_step} sigma={sigma:.3f}"))

    if model_type == "diffusion_cond":
        audio = generate_diffusion_cond(
            model,
            conditioning=conditioning,
            negative_conditioning=negative_conditioning,
            steps=steps,
            cfg_scale=cfg_scale,
            batch_size=batch_size,
            sample_size=input_sample_size,
            sample_rate=sample_rate,
            seed=seed,
            device=device,
            sampler_type=sampler_type,
            sigma_min=sigma_min,
            sigma_max=sigma_max,
            init_audio=init_audio,
            init_noise_level=init_noise_level,
            mask_args=None,
            callback=progress_callback if preview_every is not None else None,
            scale_phi=cfg_rescale
        )

    # Fold the batch into a single waveform and trim to the first 10 seconds.
    audio = rearrange(audio, "b d n -> d (b n)")
    samples_10s = 10 * sample_rate
    audio = audio[:, :samples_10s]
    # Peak-normalize and convert to 16-bit PCM.
    audio = audio.to(torch.float32).div(torch.max(torch.abs(audio))).clamp(-1, 1).mul(32767).to(torch.int16).cpu()

    output_dir = "demo_result"
    os.makedirs(output_dir, exist_ok=True)
    output_audio_path = f"{output_dir}/output.wav"
    torchaudio.save(output_audio_path, audio, sample_rate)

    # If a video was provided, mux the generated audio back into it.
    if actual_video_path:
        output_video_path = f"{output_dir}/{os.path.basename(actual_video_path)}"
        merge_video_audio(
            actual_video_path,
            output_audio_path,
            output_video_path,
            seconds_start,
            seconds_total
        )
    else:
        output_video_path = None

    del actual_video_path
    torch.cuda.empty_cache()
    gc.collect()

    return output_video_path, output_audio_path


with gr.Blocks() as interface:
    gr.Markdown(
        """
        # 🎧AudioX: Diffusion Transformer for Anything-to-Audio Generation
        **[Paper](https://arxiv.org/abs/2503.10522) · [Project Page](https://zeyuet.github.io/AudioX/) · [Huggingface](https://huggingface.co/HKUSTAudio/AudioX) · [GitHub](https://github.com/ZeyueT/AudioX)**
        """
    )

    with gr.Tab("Generation"):
        with gr.Row():
            with gr.Column():
                prompt = gr.Textbox(
                    show_label=False,
                    placeholder="Enter your prompt"
                )
                negative_prompt = gr.Textbox(
                    show_label=False,
                    placeholder="Negative prompt",
                    visible=False
                )
                video_file = gr.File(label="Upload Video File")
                audio_prompt_file = gr.File(
                    label="Upload Audio Prompt File",
                    visible=False
                )
                audio_prompt_path = gr.Textbox(
                    label="Audio Prompt Path",
                    placeholder="Enter audio file path",
                    visible=False
                )

        with gr.Row():
            with gr.Column(scale=6):
                with gr.Accordion("Video Params", open=False):
                    seconds_start = gr.Slider(
                        minimum=0, maximum=512, step=1, value=0,
                        label="Video Seconds Start"
                    )
                    seconds_total = gr.Slider(
                        minimum=0, maximum=10, step=1, value=10,
                        label="Seconds Total", interactive=False
                    )

        with gr.Row():
            with gr.Column(scale=4):
                with gr.Accordion("Sampler Params", open=False):
                    steps = gr.Slider(
                        minimum=1, maximum=500, step=1, value=100, label="Steps"
                    )
                    preview_every = gr.Slider(
                        minimum=0, maximum=100, step=1, value=0, label="Preview Every"
                    )
                    cfg_scale = gr.Slider(
                        minimum=0.0, maximum=25.0, step=0.1, value=7.0, label="CFG Scale"
                    )
                    seed = gr.Textbox(
                        label="Seed (set to -1 for random seed)", value="-1"
                    )
                    sampler_type = gr.Dropdown(
                        choices=[
                            "dpmpp-2m-sde", "dpmpp-3m-sde", "k-heun", "k-lms",
                            "k-dpmpp-2s-ancestral", "k-dpm-2", "k-dpm-fast"
                        ],
                        label="Sampler Type",
                        value="dpmpp-3m-sde"
                    )
                    sigma_min = gr.Slider(
                        minimum=0.0, maximum=2.0, step=0.01, value=0.03, label="Sigma Min"
                    )
                    sigma_max = gr.Slider(
                        minimum=0.0, maximum=1000.0, step=0.1, value=500, label="Sigma Max"
                    )
                    cfg_rescale = gr.Slider(
                        minimum=0.0, maximum=1, step=0.01, value=0.0,
                        label="CFG Rescale Amount"
                    )

        with gr.Row():
            with gr.Column(scale=4):
                with gr.Accordion("Init Audio", open=False, visible=False):
                    init_audio_checkbox = gr.Checkbox(label="Use Init Audio")
                    init_audio_input = gr.Audio(label="Init Audio")
                    init_noise_level = gr.Slider(
                        minimum=0.1, maximum=100.0, step=0.01, value=0.1,
                        label="Init Noise Level"
                    )

        with gr.Row():
            generate_button = gr.Button("Generate", variant="primary")

        with gr.Row():
            with gr.Column(scale=6):
                video_output = gr.Video(label="Output Video", interactive=False)
                audio_output = gr.Audio(label="Output Audio", interactive=False)

    inputs = [
        prompt,
        negative_prompt,
        video_file,
        audio_prompt_file,
        audio_prompt_path,
        seconds_start,
        seconds_total,
        cfg_scale,
        steps,
        preview_every,
        seed,
        sampler_type,
        sigma_min,
        sigma_max,
        cfg_rescale,
        init_audio_checkbox,
        init_audio_input,
        init_noise_level
    ]

    generate_button.click(
        fn=generate_cond,
        inputs=inputs,
        outputs=[video_output, audio_output]
    )

    gr.Markdown("## Examples")
    with gr.Accordion("Click to show examples", open=False):
        with gr.Row():
            gr.Markdown("**📝 Task: Text-to-Audio**")
            with gr.Column(scale=1.2):
                gr.Markdown("Prompt: *Typing on a keyboard*")
                ex1 = gr.Button("Load Example")
            with gr.Column(scale=1.2):
                gr.Markdown("Prompt: *Ocean waves crashing*")
                ex2 = gr.Button("Load Example")
            with gr.Column(scale=1.2):
                gr.Markdown("Prompt: *Footsteps in snow*")
                ex3 = gr.Button("Load Example")

        with gr.Row():
            gr.Markdown("**🎶 Task: Text-to-Music**")
            with gr.Column(scale=1.2):
                gr.Markdown("Prompt: *An orchestral music piece for a fantasy world.*")
                ex4 = gr.Button("Load Example")
            with gr.Column(scale=1.2):
                gr.Markdown("Prompt: *Produce upbeat electronic music for a dance party*")
                ex5 = gr.Button("Load Example")
            with gr.Column(scale=1.2):
                gr.Markdown("Prompt: *A dreamy lo-fi beat with vinyl crackle*")
                ex6 = gr.Button("Load Example")

    # Each example button loads a preset prompt and a fixed seed into the input widgets.
    ex1.click(lambda: ["Typing on a keyboard", None, None, None, None, 0, 10, 7.0, 100, 0, "1225575558", "dpmpp-3m-sde", 0.03, 500, 0.0, False, None, 0.1], inputs=[], outputs=inputs)
    ex2.click(lambda: ["Ocean waves crashing", None, None, None, None, 0, 10, 7.0, 100, 0, "3615819170", "dpmpp-3m-sde", 0.03, 500, 0.0, False, None, 0.1], inputs=[], outputs=inputs)
    ex3.click(lambda: ["Footsteps in snow", None, None, None, None, 0, 10, 7.0, 100, 0, "1703896811", "dpmpp-3m-sde", 0.03, 500, 0.0, False, None, 0.1], inputs=[], outputs=inputs)
    ex4.click(lambda: ["An orchestral music piece for a fantasy world.", None, None, None, None, 0, 10, 7.0, 100, 0, "1561898939", "dpmpp-3m-sde", 0.03, 500, 0.0, False, None, 0.1], inputs=[], outputs=inputs)
    ex5.click(lambda: ["Produce upbeat electronic music for a dance party", None, None, None, None, 0, 10, 7.0, 100, 0, "406022999", "dpmpp-3m-sde", 0.03, 500, 0.0, False, None, 0.1], inputs=[], outputs=inputs)
    ex6.click(lambda: ["A dreamy lo-fi beat with vinyl crackle", None, None, None, None, 0, 10, 7.0, 100, 0, "807934770", "dpmpp-3m-sde", 0.03, 500, 0.0, False, None, 0.1], inputs=[], outputs=inputs)

interface.queue(5).launch(server_name="0.0.0.0", server_port=2159, share=True)