import os
import gc
import stat
import platform
import logging

import torch
import torchaudio
import gradio as gr
import spaces
from einops import rearrange

from stable_audio_tools import get_pretrained_model
from stable_audio_tools.inference.generation import generate_diffusion_cond
from stable_audio_tools.data.utils import read_video, merge_video_audio, load_and_process_audio

# Needed for the optional step-preview spectrograms below; assumed to come from
# the aeiou package, as in stable-audio-tools' own Gradio interface.
from aeiou.viz import audio_spectrogram_image

from transformers import logging as transformers_logging

# Silence verbose logging from the transformers-based text encoder.
transformers_logging.set_verbosity_error()
logging.getLogger("transformers").setLevel(logging.ERROR)

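# Load the pretrained AudioX model and its config once at startup; the config
# provides the native sample rate and sample window size used for generation.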
					
						
model, model_config = get_pretrained_model('HKUSTAudio/AudioX')
sample_rate = model_config["sample_rate"]
sample_size = model_config["sample_size"]

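# Temporary directories for Gradio uploads and extracted video clips, created up
# front and made world-writable (0o777).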
					
						
TEMP_DIR = "tmp/gradio"
os.makedirs(TEMP_DIR, exist_ok=True)
os.chmod(TEMP_DIR, stat.S_IRWXU | stat.S_IRWXG | stat.S_IRWXO)

VIDEO_TEMP_DIR = os.path.join(TEMP_DIR, "videos")
os.makedirs(VIDEO_TEMP_DIR, exist_ok=True)
os.chmod(VIDEO_TEMP_DIR, stat.S_IRWXU | stat.S_IRWXG | stat.S_IRWXO)

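# On Hugging Face ZeroGPU Spaces, @spaces.GPU requests a GPU for each call to the
# decorated function; duration hints the expected per-call runtime in seconds.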
					
						
@spaces.GPU(duration=10)
def generate_cond(
        prompt,
        negative_prompt=None,
        video_file=None,
        audio_prompt_file=None,
        audio_prompt_path=None,
        seconds_start=0,
        seconds_total=10,
        cfg_scale=7.0,
        steps=100,
        preview_every=0,
        seed=-1,
        sampler_type="dpmpp-3m-sde",
        sigma_min=0.03,
        sigma_max=500,
        cfg_rescale=0.0,
        use_init=False,
        init_audio=None,
        init_noise_level=0.1,
        mask_cropfrom=None,
        mask_pastefrom=None,
        mask_pasteto=None,
        mask_maskstart=None,
        mask_maskend=None,
        mask_softnessL=None,
        mask_softnessR=None,
        mask_marination=None,
        batch_size=1
):
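    # Release any GPU memory still held from a previous request before sampling.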
					
						
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    gc.collect()

    print(f"Prompt: {prompt}")

    preview_images = []
    if preview_every == 0:
        preview_every = None

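    # Pick the best available backend: Apple MPS on macOS, then CUDA, then CPU.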
					
						
    try:
        has_mps = platform.system() == "Darwin" and torch.backends.mps.is_available()
    except Exception:
        has_mps = False

    if has_mps:
        device = torch.device("mps")
    elif torch.cuda.is_available():
        device = torch.device("cuda")
    else:
        device = torch.device("cpu")

    global model
    model = model.to(device)

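    # Resolve the uploaded video/audio prompts to file paths, then load the video
    # frames at the frame rate the model expects and the audio prompt at the
    # model's sample rate.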
					
						
    target_fps = model_config.get("video_fps", 5)
    model_type = model_config.get("model_type", "diffusion_cond")

    if video_file is not None:
        actual_video_path = video_file['name'] if isinstance(video_file, dict) else video_file.name
    else:
        actual_video_path = None

    if audio_prompt_file is not None:
        audio_path = audio_prompt_file.name
    elif audio_prompt_path:
        audio_path = audio_prompt_path.strip()
    else:
        audio_path = None

    Video_tensors = read_video(actual_video_path, seek_time=seconds_start, duration=seconds_total, target_fps=target_fps)
    audio_tensor = load_and_process_audio(audio_path, sample_rate, seconds_start, seconds_total)

    audio_tensor = audio_tensor.to(device)
    seconds_input = sample_size / sample_rate

    if not prompt:
        prompt = ""

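    # Bundle the text, video, and audio prompts into the conditioning dict expected
    # by the diffusion model; an optional negative prompt reuses the same video and
    # audio conditioning.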
					
						
    conditioning = [{
        "video_prompt": [Video_tensors.unsqueeze(0)],
        "text_prompt": prompt,
        "audio_prompt": audio_tensor.unsqueeze(0),
        "seconds_start": seconds_start,
        "seconds_total": seconds_input
    }]

    if negative_prompt:
        negative_conditioning = [{
            "video_prompt": [Video_tensors.unsqueeze(0)],
            "text_prompt": negative_prompt,
            "audio_prompt": audio_tensor.unsqueeze(0),
            "seconds_start": seconds_start,
            "seconds_total": seconds_total
        }]
    else:
        negative_conditioning = None

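    # The seed arrives as text from the UI; -1 means "pick a random seed". Init
    # audio is only honoured when the (hidden) "Use Init Audio" box is checked.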
					
						
    seed = int(seed)

    if not use_init:
        init_audio = None

    input_sample_size = sample_size

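    # Optional progress callback: every `preview_every` steps it decodes the current
    # denoised latents and stores a spectrogram image of the intermediate audio.
    # Note that the demo only returns the final audio/video, so these previews are
    # collected but not displayed.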
					
						
    def progress_callback(callback_info):
        nonlocal preview_images
        denoised = callback_info["denoised"]
        current_step = callback_info["i"]
        sigma = callback_info["sigma"]
        if (current_step - 1) % preview_every == 0:
            if model.pretransform is not None:
                denoised = model.pretransform.decode(denoised)
            denoised = rearrange(denoised, "b d n -> d (b n)")
            denoised = denoised.clamp(-1, 1).mul(32767).to(torch.int16).cpu()
            audio_spectrogram = audio_spectrogram_image(denoised, sample_rate=sample_rate)
            preview_images.append((audio_spectrogram, f"Step {current_step} sigma={sigma:.3f}"))

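    # Run conditioned diffusion sampling. Classifier-free guidance strength comes
    # from cfg_scale; scale_phi applies the CFG rescale amount.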
					
						
    if model_type == "diffusion_cond":
        audio = generate_diffusion_cond(
            model,
            conditioning=conditioning,
            negative_conditioning=negative_conditioning,
            steps=steps,
            cfg_scale=cfg_scale,
            batch_size=batch_size,
            sample_size=input_sample_size,
            sample_rate=sample_rate,
            seed=seed,
            device=device,
            sampler_type=sampler_type,
            sigma_min=sigma_min,
            sigma_max=sigma_max,
            init_audio=init_audio,
            init_noise_level=init_noise_level,
            mask_args=None,
            callback=progress_callback if preview_every is not None else None,
            scale_phi=cfg_rescale
        )

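    # Fold the batch dimension into the time axis, keep only the first 10 seconds,
    # then peak-normalise and convert to 16-bit PCM before writing the WAV file.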
					
						
    audio = rearrange(audio, "b d n -> d (b n)")

    samples_10s = 10 * sample_rate
    audio = audio[:, :samples_10s]
    audio = audio.to(torch.float32).div(torch.max(torch.abs(audio))).clamp(-1, 1).mul(32767).to(torch.int16).cpu()

    output_dir = "demo_result"
    os.makedirs(output_dir, exist_ok=True)
    output_audio_path = f"{output_dir}/output.wav"
    torchaudio.save(output_audio_path, audio, sample_rate)

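    # If a video was uploaded, mux the generated audio back into the selected clip
    # so the result can be previewed as a video with sound.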
					
						
    if actual_video_path:
        output_video_path = f"{output_dir}/{os.path.basename(actual_video_path)}"
        target_width = 1280
        target_height = 720
        merge_video_audio(
            actual_video_path,
            output_audio_path,
            output_video_path,
            seconds_start,
            seconds_total
        )
    else:
        output_video_path = None

    del actual_video_path
    torch.cuda.empty_cache()
    gc.collect()

    return output_video_path, output_audio_path

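# Gradio UI: a single "Generation" tab with the prompt/video inputs, collapsible
# sampler settings, and video/audio outputs.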
					
						
with gr.Blocks() as interface:
    gr.Markdown(
        """
        # 🎧AudioX: Diffusion Transformer for Anything-to-Audio Generation
        **[Paper](https://arxiv.org/abs/2503.10522) · [Project Page](https://zeyuet.github.io/AudioX/) · [Huggingface](https://huggingface.co/HKUSTAudio/AudioX) · [GitHub](https://github.com/ZeyueT/AudioX)**
        """
    )

    with gr.Tab("Generation"):
        with gr.Row():
            with gr.Column():
                prompt = gr.Textbox(
                    show_label=False,
                    placeholder="Enter your prompt"
                )
                negative_prompt = gr.Textbox(
                    show_label=False,
                    placeholder="Negative prompt",
                    visible=False
                )
                video_file = gr.File(label="Upload Video File")
                audio_prompt_file = gr.File(
                    label="Upload Audio Prompt File",
                    visible=False
                )
                audio_prompt_path = gr.Textbox(
                    label="Audio Prompt Path",
                    placeholder="Enter audio file path",
                    visible=False
                )

        with gr.Row():
            with gr.Column(scale=6):
                with gr.Accordion("Video Params", open=False):
                    seconds_start = gr.Slider(
                        minimum=0,
                        maximum=512,
                        step=1,
                        value=0,
                        label="Video Seconds Start"
                    )
                    seconds_total = gr.Slider(
                        minimum=0,
                        maximum=10,
                        step=1,
                        value=10,
                        label="Seconds Total",
                        interactive=False
                    )

        with gr.Row():
            with gr.Column(scale=4):
                with gr.Accordion("Sampler Params", open=False):
                    steps = gr.Slider(
                        minimum=1,
                        maximum=500,
                        step=1,
                        value=100,
                        label="Steps"
                    )
                    preview_every = gr.Slider(
                        minimum=0,
                        maximum=100,
                        step=1,
                        value=0,
                        label="Preview Every"
                    )
                    cfg_scale = gr.Slider(
                        minimum=0.0,
                        maximum=25.0,
                        step=0.1,
                        value=7.0,
                        label="CFG Scale"
                    )
                    seed = gr.Textbox(
                        label="Seed (set to -1 for random seed)",
                        value="-1"
                    )
                    sampler_type = gr.Dropdown(
                        choices=[
                            "dpmpp-2m-sde",
                            "dpmpp-3m-sde",
                            "k-heun",
                            "k-lms",
                            "k-dpmpp-2s-ancestral",
                            "k-dpm-2",
                            "k-dpm-fast"
                        ],
                        label="Sampler Type",
                        value="dpmpp-3m-sde"
                    )
                    sigma_min = gr.Slider(
                        minimum=0.0,
                        maximum=2.0,
                        step=0.01,
                        value=0.03,
                        label="Sigma Min"
                    )
                    sigma_max = gr.Slider(
                        minimum=0.0,
                        maximum=1000.0,
                        step=0.1,
                        value=500,
                        label="Sigma Max"
                    )
                    cfg_rescale = gr.Slider(
                        minimum=0.0,
                        maximum=1,
                        step=0.01,
                        value=0.0,
                        label="CFG Rescale Amount"
                    )

        with gr.Row():
            with gr.Column(scale=4):
                with gr.Accordion("Init Audio", open=False, visible=False):
                    init_audio_checkbox = gr.Checkbox(label="Use Init Audio")
                    init_audio_input = gr.Audio(label="Init Audio")
                    init_noise_level = gr.Slider(
                        minimum=0.1,
                        maximum=100.0,
                        step=0.01,
                        value=0.1,
                        label="Init Noise Level"
                    )

        with gr.Row():
            generate_button = gr.Button("Generate", variant="primary")

        with gr.Row():
            with gr.Column(scale=6):
                video_output = gr.Video(label="Output Video", interactive=False)
                audio_output = gr.Audio(label="Output Audio", interactive=False)

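        # The order of this list must match the positional parameters of
        # generate_cond above (the checkbox feeds use_init, the audio component
        # feeds init_audio).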
					
						
        inputs = [
            prompt,
            negative_prompt,
            video_file,
            audio_prompt_file,
            audio_prompt_path,
            seconds_start,
            seconds_total,
            cfg_scale,
            steps,
            preview_every,
            seed,
            sampler_type,
            sigma_min,
            sigma_max,
            cfg_rescale,
            init_audio_checkbox,
            init_audio_input,
            init_noise_level
        ]

        generate_button.click(
            fn=generate_cond,
            inputs=inputs,
            outputs=[video_output, audio_output]
        )

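        # Example buttons: each click fills every input component with a preset
        # configuration (the returned list follows the same order as `inputs`).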
					
						
        gr.Markdown("## Examples")
        with gr.Accordion("Click to show examples", open=False):
            with gr.Row():
                gr.Markdown("**📝 Task: Text-to-Audio**")
                with gr.Column(scale=1.2):
                    gr.Markdown("Prompt: *Typing on a keyboard*")
                    ex1 = gr.Button("Load Example")
                with gr.Column(scale=1.2):
                    gr.Markdown("Prompt: *Ocean waves crashing*")
                    ex2 = gr.Button("Load Example")
                with gr.Column(scale=1.2):
                    gr.Markdown("Prompt: *Footsteps in snow*")
                    ex3 = gr.Button("Load Example")

            with gr.Row():
                gr.Markdown("**🎶 Task: Text-to-Music**")
                with gr.Column(scale=1.2):
                    gr.Markdown("Prompt: *An orchestral music piece for a fantasy world.*")
                    ex4 = gr.Button("Load Example")
                with gr.Column(scale=1.2):
                    gr.Markdown("Prompt: *Produce upbeat electronic music for a dance party*")
                    ex5 = gr.Button("Load Example")
                with gr.Column(scale=1.2):
                    gr.Markdown("Prompt: *A dreamy lo-fi beat with vinyl crackle*")
                    ex6 = gr.Button("Load Example")

        ex1.click(lambda: ["Typing on a keyboard", None, None, None, None, 0, 10, 7.0, 100, 0, "1225575558", "dpmpp-3m-sde", 0.03, 500, 0.0, False, None, 0.1], inputs=[], outputs=inputs)
        ex2.click(lambda: ["Ocean waves crashing", None, None, None, None, 0, 10, 7.0, 100, 0, "3615819170", "dpmpp-3m-sde", 0.03, 500, 0.0, False, None, 0.1], inputs=[], outputs=inputs)
        ex3.click(lambda: ["Footsteps in snow", None, None, None, None, 0, 10, 7.0, 100, 0, "1703896811", "dpmpp-3m-sde", 0.03, 500, 0.0, False, None, 0.1], inputs=[], outputs=inputs)
        ex4.click(lambda: ["An orchestral music piece for a fantasy world.", None, None, None, None, 0, 10, 7.0, 100, 0, "1561898939", "dpmpp-3m-sde", 0.03, 500, 0.0, False, None, 0.1], inputs=[], outputs=inputs)
        ex5.click(lambda: ["Produce upbeat electronic music for a dance party", None, None, None, None, 0, 10, 7.0, 100, 0, "406022999", "dpmpp-3m-sde", 0.03, 500, 0.0, False, None, 0.1], inputs=[], outputs=inputs)
        ex6.click(lambda: ["A dreamy lo-fi beat with vinyl crackle", None, None, None, None, 0, 10, 7.0, 100, 0, "807934770", "dpmpp-3m-sde", 0.03, 500, 0.0, False, None, 0.1], inputs=[], outputs=inputs)

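# Queue incoming requests and serve the demo on port 2159 with a public share link.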
interface.queue(5).launch(server_name="0.0.0.0", server_port=2159, share=True)