feat: add cli option for model name (log)
Browse files
train.py
CHANGED
@@ -77,6 +77,13 @@ parser.add_argument(
|
|
77 |
default=["./dataset/font_img"],
|
78 |
help="Datasets paths, seperated by space (default: ['./dataset/font_img'])",
|
79 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
80 |
|
81 |
args = parser.parse_args()
|
82 |
|
@@ -124,7 +131,7 @@ data_module = FontDataModule(
|
|
124 |
num_iters = data_module.get_train_num_iter(num_device) * num_epochs
|
125 |
num_warmup_iter = data_module.get_train_num_iter(num_device) * num_warmup_epochs
|
126 |
|
127 |
-
model_name =
|
128 |
|
129 |
logger_unconditioned = TensorBoardLogger(
|
130 |
save_dir=os.getcwd(), name="tensorboard", version=model_name
|
|
|
77 |
default=["./dataset/font_img"],
|
78 |
help="Datasets paths, seperated by space (default: ['./dataset/font_img'])",
|
79 |
)
|
80 |
+
parser.add_argument(
|
81 |
+
"-n",
|
82 |
+
"--model-name",
|
83 |
+
type=str,
|
84 |
+
default=get_current_tag(),
|
85 |
+
help="Model name (default: current tag)",
|
86 |
+
)
|
87 |
|
88 |
args = parser.parse_args()
|
89 |
|
|
|
131 |
num_iters = data_module.get_train_num_iter(num_device) * num_epochs
|
132 |
num_warmup_iter = data_module.get_train_num_iter(num_device) * num_warmup_epochs
|
133 |
|
134 |
+
model_name = args.model_name
|
135 |
|
136 |
logger_unconditioned = TensorBoardLogger(
|
137 |
save_dir=os.getcwd(), name="tensorboard", version=model_name
|