Spaces:
Runtime error
Runtime error
[feat] update args
Browse files
app.py
CHANGED
|
@@ -7,10 +7,12 @@ import os
|
|
| 7 |
|
| 8 |
parser = argparse.ArgumentParser()
|
| 9 |
parser.add_argument("--checkpoint_path", type=str, default=None)
|
|
|
|
| 10 |
parser.add_argument("--port", type=int, default=7860)
|
| 11 |
parser.add_argument("--device_id", type=int, default=0)
|
| 12 |
parser.add_argument("--share", action='store_true', default=False)
|
| 13 |
parser.add_argument("--bf16", action='store_true', default=True)
|
|
|
|
| 14 |
|
| 15 |
args = parser.parse_args()
|
| 16 |
os.environ["CUDA_VISIBLE_DEVICES"] = str(args.device_id)
|
|
@@ -25,7 +27,7 @@ def main(args):
|
|
| 25 |
checkpoint_dir=args.checkpoint_path,
|
| 26 |
dtype="bfloat16" if args.bf16 else "float32",
|
| 27 |
persistent_storage_path=persistent_storage_path,
|
| 28 |
-
torch_compile=
|
| 29 |
)
|
| 30 |
data_sampler = DataSampler()
|
| 31 |
|
|
@@ -33,7 +35,11 @@ def main(args):
|
|
| 33 |
text2music_process_func=model_demo.__call__,
|
| 34 |
sample_data_func=data_sampler.sample,
|
| 35 |
)
|
| 36 |
-
demo.launch(
|
|
|
|
|
|
|
|
|
|
|
|
|
| 37 |
|
| 38 |
|
| 39 |
if __name__ == "__main__":
|
|
|
|
| 7 |
|
| 8 |
parser = argparse.ArgumentParser()
|
| 9 |
parser.add_argument("--checkpoint_path", type=str, default=None)
|
| 10 |
+
parser.add_argument("--server_name", type=str, default="0.0.0.0")
|
| 11 |
parser.add_argument("--port", type=int, default=7860)
|
| 12 |
parser.add_argument("--device_id", type=int, default=0)
|
| 13 |
parser.add_argument("--share", action='store_true', default=False)
|
| 14 |
parser.add_argument("--bf16", action='store_true', default=True)
|
| 15 |
+
parser.add_argument("--torch_compile", type=bool, default=True)
|
| 16 |
|
| 17 |
args = parser.parse_args()
|
| 18 |
os.environ["CUDA_VISIBLE_DEVICES"] = str(args.device_id)
|
|
|
|
| 27 |
checkpoint_dir=args.checkpoint_path,
|
| 28 |
dtype="bfloat16" if args.bf16 else "float32",
|
| 29 |
persistent_storage_path=persistent_storage_path,
|
| 30 |
+
torch_compile=args.torch_compile
|
| 31 |
)
|
| 32 |
data_sampler = DataSampler()
|
| 33 |
|
|
|
|
| 35 |
text2music_process_func=model_demo.__call__,
|
| 36 |
sample_data_func=data_sampler.sample,
|
| 37 |
)
|
| 38 |
+
demo.launch(
|
| 39 |
+
server_name=args.server_name,
|
| 40 |
+
server_port=args.port,
|
| 41 |
+
share=args.share
|
| 42 |
+
)
|
| 43 |
|
| 44 |
|
| 45 |
if __name__ == "__main__":
|