Spaces: Running on Zero
import argparse | |
from ui.components import create_main_demo_ui | |
from pipeline_ace_step import ACEStepPipeline | |
from data_sampler import DataSampler | |
import os | |
def _str2bool(value):
    """Parse a command-line boolean such as "true"/"false", "1"/"0", "yes"/"no".

    ``type=bool`` cannot be used for this: every non-empty string is truthy,
    so ``bool("False")`` is ``True`` and ``--flag False`` would enable the flag.

    Raises:
        argparse.ArgumentTypeError: if *value* is not a recognised boolean spelling.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ("1", "true", "t", "yes", "y", "on"):
        return True
    if lowered in ("0", "false", "f", "no", "n", "off"):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


parser = argparse.ArgumentParser()
parser.add_argument("--checkpoint_path", type=str, default=None)
parser.add_argument("--server_name", type=str, default="0.0.0.0")
parser.add_argument("--port", type=int, default=7860)
parser.add_argument("--device_id", type=int, default=0)
parser.add_argument("--share", action="store_true", default=False)
# bf16 stays on by default. A `store_true` action with `default=True` could never be
# turned off. Both `--bf16` (bare -> True) and `--bf16 false` are now accepted.
parser.add_argument("--bf16", type=_str2bool, nargs="?", const=True, default=True)
# Both the historical `--torch_compile True/False` form and a bare `--torch_compile`
# are accepted. "False" now really means False.
parser.add_argument("--torch_compile", type=_str2bool, nargs="?", const=True, default=False)
args = parser.parse_args()

# Restrict CUDA to the selected device. This is set after the project imports at the
# top of the file, so it only takes effect if none of them has already initialised
# CUDA. NOTE(review): confirm pipeline_ace_step does not touch the GPU at import time.
os.environ["CUDA_VISIBLE_DEVICES"] = str(args.device_id)

# Directory passed to ACEStepPipeline as its persistent_storage_path.
persistent_storage_path = "/data"
def main(args):
    """Build the ACE-Step pipeline and serve the Gradio demo UI.

    Args:
        args: Parsed command-line namespace. Reads ``checkpoint_path``, ``bf16``,
            ``torch_compile``, ``server_name``, ``port`` and ``share``.
    """
    model_demo = ACEStepPipeline(
        checkpoint_dir=args.checkpoint_path,
        dtype="bfloat16" if args.bf16 else "float32",
        persistent_storage_path=persistent_storage_path,
        torch_compile=args.torch_compile,
    )
    data_sampler = DataSampler()

    # The UI calls the pipeline for generation and the sampler for example inputs.
    demo = create_main_demo_ui(
        text2music_process_func=model_demo.__call__,
        sample_data_func=data_sampler.sample,
    )

    # Forward the networking flags. Previously launch() received no arguments,
    # so --server_name, --port and --share were parsed but silently ignored.
    demo.queue(default_concurrency_limit=8).launch(
        server_name=args.server_name,
        server_port=args.port,
        share=args.share,
    )
if __name__ == "__main__":
    # Start the demo only when run as a script, using the module-level CLI arguments.
    main(args=args)