Sayoyo commited on
Commit
b030337
·
1 Parent(s): 3b7cae0

[feat] update args

Browse files
Files changed (1) hide show
  1. app.py +8 -2
app.py CHANGED
@@ -7,10 +7,12 @@ import os
7
 
8
  parser = argparse.ArgumentParser()
9
  parser.add_argument("--checkpoint_path", type=str, default=None)
 
10
  parser.add_argument("--port", type=int, default=7860)
11
  parser.add_argument("--device_id", type=int, default=0)
12
  parser.add_argument("--share", action='store_true', default=False)
13
  parser.add_argument("--bf16", action='store_true', default=True)
 
14
 
15
  args = parser.parse_args()
16
  os.environ["CUDA_VISIBLE_DEVICES"] = str(args.device_id)
@@ -25,7 +27,7 @@ def main(args):
25
  checkpoint_dir=args.checkpoint_path,
26
  dtype="bfloat16" if args.bf16 else "float32",
27
  persistent_storage_path=persistent_storage_path,
28
- torch_compile=True
29
  )
30
  data_sampler = DataSampler()
31
 
@@ -33,7 +35,11 @@ def main(args):
33
  text2music_process_func=model_demo.__call__,
34
  sample_data_func=data_sampler.sample,
35
  )
36
- demo.launch()
 
 
 
 
37
 
38
 
39
  if __name__ == "__main__":
 
7
 
8
  parser = argparse.ArgumentParser()
9
  parser.add_argument("--checkpoint_path", type=str, default=None)
10
+ parser.add_argument("--server_name", type=str, default="0.0.0.0")
11
  parser.add_argument("--port", type=int, default=7860)
12
  parser.add_argument("--device_id", type=int, default=0)
13
  parser.add_argument("--share", action='store_true', default=False)
14
  parser.add_argument("--bf16", action='store_true', default=True)
15
+ parser.add_argument("--torch_compile", type=bool, default=True)
16
 
17
  args = parser.parse_args()
18
  os.environ["CUDA_VISIBLE_DEVICES"] = str(args.device_id)
 
27
  checkpoint_dir=args.checkpoint_path,
28
  dtype="bfloat16" if args.bf16 else "float32",
29
  persistent_storage_path=persistent_storage_path,
30
+ torch_compile=args.torch_compile
31
  )
32
  data_sampler = DataSampler()
33
 
 
35
  text2music_process_func=model_demo.__call__,
36
  sample_data_func=data_sampler.sample,
37
  )
38
+ demo.launch(
39
+ server_name=args.server_name,
40
+ server_port=args.port,
41
+ share=args.share
42
+ )
43
 
44
 
45
  if __name__ == "__main__":