gyrojeff commited on
Commit
ac3ee6a
1 Parent(s): 65e514f

feat: add cli option for model name (log)

Browse files
Files changed (1) hide show
  1. train.py +8 -1
train.py CHANGED
@@ -77,6 +77,13 @@ parser.add_argument(
77
  default=["./dataset/font_img"],
78
  help="Datasets paths, seperated by space (default: ['./dataset/font_img'])",
79
  )
 
 
 
 
 
 
 
80
 
81
  args = parser.parse_args()
82
 
@@ -124,7 +131,7 @@ data_module = FontDataModule(
124
  num_iters = data_module.get_train_num_iter(num_device) * num_epochs
125
  num_warmup_iter = data_module.get_train_num_iter(num_device) * num_warmup_epochs
126
 
127
- model_name = f"{get_current_tag()}"
128
 
129
  logger_unconditioned = TensorBoardLogger(
130
  save_dir=os.getcwd(), name="tensorboard", version=model_name
 
77
  default=["./dataset/font_img"],
78
  help="Datasets paths, seperated by space (default: ['./dataset/font_img'])",
79
  )
80
+ parser.add_argument(
81
+ "-n",
82
+ "--model-name",
83
+ type=str,
84
+ default=get_current_tag(),
85
+ help="Model name (default: current tag)",
86
+ )
87
 
88
  args = parser.parse_args()
89
 
 
131
  num_iters = data_module.get_train_num_iter(num_device) * num_epochs
132
  num_warmup_iter = data_module.get_train_num_iter(num_device) * num_warmup_epochs
133
 
134
+ model_name = args.model_name
135
 
136
  logger_unconditioned = TensorBoardLogger(
137
  save_dir=os.getcwd(), name="tensorboard", version=model_name