|
import argparse |
|
from ui.components import create_main_demo_ui |
|
from pipeline_ace_step import ACEStepPipeline |
|
from data_sampler import DataSampler |
|
import os |
|
|
|
|
|
def _str2bool(value):
    """Convert a textual CLI value into a bool.

    ``argparse`` with ``type=bool`` treats every non-empty string (including
    ``"False"``) as True, so boolean options that take a value need an
    explicit converter.

    Args:
        value: The raw command-line string (or an already-bool value).

    Returns:
        The parsed boolean.

    Raises:
        argparse.ArgumentTypeError: If ``value`` is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    normalized = value.strip().lower()
    if normalized in ("1", "true", "yes", "y", "on"):
        return True
    # "" maps to False to match the old ``bool("")`` behaviour.
    if normalized in ("0", "false", "no", "n", "off", ""):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


parser = argparse.ArgumentParser()
parser.add_argument("--checkpoint_path", type=str, default=None)
parser.add_argument("--server_name", type=str, default="0.0.0.0")
parser.add_argument("--port", type=int, default=7860)
parser.add_argument("--device_id", type=int, default=0)
parser.add_argument("--share", action='store_true', default=False)
# Previously `store_true` with default=True, which made bf16 impossible to
# turn off. A bare `--bf16` still means True; `--bf16 false` now disables it.
parser.add_argument("--bf16", type=_str2bool, nargs="?", const=True, default=True)
# Previously `type=bool`, so `--torch_compile False` evaluated to True.
# A bare `--torch_compile` (or `--torch_compile true`) enables it.
parser.add_argument("--torch_compile", type=_str2bool, nargs="?", const=True, default=False)

args = parser.parse_args()

# Restrict the process to the chosen GPU. This only takes effect if CUDA has
# not been initialised yet. NOTE(review): the imports above presumably
# do not initialise CUDA at import time; verify.
os.environ["CUDA_VISIBLE_DEVICES"] = str(args.device_id)

# Forwarded to ACEStepPipeline(persistent_storage_path=...) in main().
# Presumably a mounted persistent volume; confirm for the deployment target.
persistent_storage_path = "/data"
|
|
|
|
|
def main(args):
    """Build the ACE-Step pipeline and serve the demo UI.

    Args:
        args: Parsed CLI namespace. Reads ``checkpoint_path``, ``bf16`` and
            ``torch_compile`` for the pipeline, and (when present)
            ``server_name``, ``port`` and ``share`` for the web server.
    """
    model_demo = ACEStepPipeline(
        checkpoint_dir=args.checkpoint_path,
        dtype="bfloat16" if args.bf16 else "float32",
        persistent_storage_path=persistent_storage_path,
        torch_compile=args.torch_compile,
    )
    data_sampler = DataSampler()

    demo = create_main_demo_ui(
        text2music_process_func=model_demo.__call__,
        sample_data_func=data_sampler.sample,
        load_data_func=data_sampler.load_json,
    )

    # Forward the networking flags that the CLI declares; previously they
    # were parsed but dropped, so launch() always used its own defaults.
    # getattr keeps callers that pass a minimal namespace working (None means
    # "use the launcher default"). NOTE(review): assumes `demo` is a Gradio
    # Blocks object, as the queue/launch calls suggest; confirm.
    demo.queue(default_concurrency_limit=8).launch(
        server_name=getattr(args, "server_name", None),
        server_port=getattr(args, "port", None),
        share=getattr(args, "share", None),
    )
|
|
|
|
|
if __name__ == "__main__":
    # Script entry point: launch the demo using the CLI arguments parsed at
    # module import time.
    main(args=args)
|
|