Spaces:
Running
on
Zero
Running
on
Zero
File size: 1,289 Bytes
5488167 071dfa9 5488167 3177554 5488167 4617cbd b030337 5488167 65fbe9a 071dfa9 b030337 5488167 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 |
import argparse
from ui.components import create_main_demo_ui
from pipeline_ace_step import ACEStepPipeline
from data_sampler import DataSampler
import os
def _str2bool(value):
    """Parse a command-line boolean such as ``true``/``false``/``1``/``0``.

    ``type=bool`` cannot be used for this: any non-empty string is truthy,
    so ``bool("False")`` is ``True``.

    Raises:
        argparse.ArgumentTypeError: if ``value`` is not a recognised boolean.
    """
    normalized = value.strip().lower()
    if normalized in {"1", "true", "t", "yes", "y", "on"}:
        return True
    if normalized in {"0", "false", "f", "no", "n", "off"}:
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


# Command-line options for the demo server.
parser = argparse.ArgumentParser()
parser.add_argument("--checkpoint_path", type=str, default=None)
parser.add_argument("--server_name", type=str, default="0.0.0.0")
parser.add_argument("--port", type=int, default=7860)
parser.add_argument("--device_id", type=int, default=0)
parser.add_argument("--share", action="store_true", default=False)
# bf16 is on by default. `store_true` with `default=True` alone could never
# turn it off, so `--no_bf16` writes False to the same destination.
parser.add_argument("--bf16", dest="bf16", action="store_true", default=True)
parser.add_argument(
    "--no_bf16",
    dest="bf16",
    action="store_false",
    help="Run the pipeline in float32 instead of bfloat16.",
)
# Accepts `--torch_compile`, `--torch_compile true` or `--torch_compile false`.
parser.add_argument(
    "--torch_compile", type=_str2bool, nargs="?", const=True, default=False
)
args = parser.parse_args()

# Restrict this process to the requested GPU. This only takes effect if CUDA
# has not been initialised yet; presumably none of the imports above touch the
# GPU at import time. TODO confirm.
os.environ["CUDA_VISIBLE_DEVICES"] = str(args.device_id)

# Directory passed to ACEStepPipeline as its persistent storage location.
persistent_storage_path = "/data"
def main(args):
    """Build the ACE-Step pipeline and serve the Gradio demo UI.

    Args:
        args: Parsed command-line namespace. Reads ``checkpoint_path``,
            ``bf16``, ``torch_compile``, ``server_name``, ``port`` and
            ``share``.
    """
    model_demo = ACEStepPipeline(
        checkpoint_dir=args.checkpoint_path,
        dtype="bfloat16" if args.bf16 else "float32",
        persistent_storage_path=persistent_storage_path,
        torch_compile=args.torch_compile,
    )
    data_sampler = DataSampler()
    demo = create_main_demo_ui(
        text2music_process_func=model_demo.__call__,
        sample_data_func=data_sampler.sample,
    )
    # Pass the networking options through explicitly. Otherwise
    # --server_name/--port/--share are parsed but have no effect.
    demo.queue(default_concurrency_limit=8).launch(
        server_name=args.server_name,
        server_port=args.port,
        share=args.share,
    )
if __name__ == "__main__":
    # Start the demo only when executed as a script, never on import.
    main(args=args)
|