import argparse
import os

from ui.simple_components import create_simple_ui
from pipeline_ace_step import ACEStepPipeline
from data_sampler import DataSampler

parser = argparse.ArgumentParser()
parser.add_argument("--checkpoint_path", type=str, default=None)
parser.add_argument("--server_name", type=str, default="0.0.0.0")
parser.add_argument("--port", type=int, default=7860)
parser.add_argument("--device_id", type=int, default=0)
parser.add_argument("--share", action="store_true")
# bf16 defaults to on; BooleanOptionalAction adds a --no-bf16 flag to disable it
# (the previous store_true/default=True combination could never be turned off).
parser.add_argument("--bf16", action=argparse.BooleanOptionalAction, default=True)
# A plain flag instead of type=bool: argparse's type=bool treats any non-empty
# string (including "False") as True.
parser.add_argument("--torch_compile", action="store_true")
args = parser.parse_args()

# Must be set before the first CUDA context is created so the pipeline lands on
# the requested GPU.
os.environ["CUDA_VISIBLE_DEVICES"] = str(args.device_id)

persistent_storage_path = "./data"


def main(args):
    model_demo = ACEStepPipeline(
        checkpoint_dir=args.checkpoint_path,
        dtype="bfloat16" if args.bf16 else "float32",
        persistent_storage_path=persistent_storage_path,
        torch_compile=args.torch_compile,
    )
    data_sampler = DataSampler()

    # API function for external (programmatic) calls.
    def generate_music_api(
        duration: float = 20.0,
        tags: str = "edm, synth, bass, kick drum, 128 bpm, euphoric, pulsating, energetic, instrumental",
        lyrics: str = "[instrumental]",
        infer_steps: int = 60,
        guidance_scale: float = 15.0,
    ):
        """Generate music and return the path to the resulting audio file.

        Args:
            duration: Duration in seconds (default 20).
            tags: Music tags / style description.
            lyrics: Lyrics, or "[instrumental]" for no vocals.
            infer_steps: Number of inference steps (default 60).
            guidance_scale: Guidance scale (default 15.0).

        Returns:
            audio_path: Path to the generated audio file, or None if generation failed.
        """
        result = model_demo(
            audio_duration=duration,
            prompt=tags,
            lyrics=lyrics,
            infer_step=infer_steps,
            guidance_scale=guidance_scale,
            scheduler_type="euler",
            cfg_type="apg",
            omega_scale=10.0,
            manual_seeds=None,
            guidance_interval=0.5,
            guidance_interval_decay=0.0,
            min_guidance_scale=3.0,
            use_erg_tag=True,
            use_erg_lyric=False,
            use_erg_diffusion=True,
            oss_steps=None,
            guidance_scale_text=0.0,
            guidance_scale_lyric=0.0,
            audio2audio_enable=False,
            ref_audio_strength=0.5,
            ref_audio_input=None,
            lora_name_or_path="none",
        )
        # Return the first audio output (always a 24 kHz WAV).
        if result and len(result) > 0:
            return result[0]
        return None

    # Use the simplified UI.
    demo = create_simple_ui(
        text2music_process_func=model_demo.__call__,
    )

    # api_open is a queue() parameter; passing it there (rather than setting an
    # attribute on the Blocks object) keeps the demo's API endpoints reachable
    # by external clients.
    demo.queue(default_concurrency_limit=8, api_open=True).launch(
        server_name=args.server_name,
        server_port=args.port,
        share=args.share,
    )


if __name__ == "__main__":
    main(args)
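
# Example client-side call (a sketch, not part of this script): with the queue's
# API open, an external process can drive the running demo via gradio_client.
# The exact api_name and argument order depend on how create_simple_ui wires its
# components, so treat "/generate" and the positional arguments below as
# placeholders rather than the real endpoint signature:
#
#     from gradio_client import Client
#
#     client = Client("http://localhost:7860")
#     audio_path = client.predict(
#         20.0,                          # duration in seconds
#         "edm, synth, bass, 128 bpm",   # tags / style description
#         "[instrumental]",              # lyrics
#         api_name="/generate",          # placeholder endpoint name
#     )
#     print(audio_path)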