gyrojeff commited on
Commit
705feb9
·
1 Parent(s): d1e10d9

feat: add cli config for size

Browse files
Files changed (1) hide show
  1. train.py +11 -2
train.py CHANGED
@@ -81,7 +81,7 @@ parser.add_argument(
81
  "-n",
82
  "--model-name",
83
  type=str,
84
- default=get_current_tag(),
85
  help="Model name (default: current tag)",
86
  )
87
  parser.add_argument(
@@ -90,6 +90,13 @@ parser.add_argument(
90
  action="store_true",
91
  help="Font classification only (default: False)",
92
  )
 
 
 
 
 
 
 
93
 
94
  args = parser.parse_args()
95
 
@@ -99,6 +106,8 @@ single_batch_size = args.single_batch_size
99
  total_num_workers = os.cpu_count()
100
  single_device_num_workers = total_num_workers // len(devices)
101
 
 
 
102
  if os.name == "nt":
103
  single_device_num_workers = 0
104
 
@@ -137,7 +146,7 @@ data_module = FontDataModule(
137
  num_iters = data_module.get_train_num_iter(num_device) * num_epochs
138
  num_warmup_iter = data_module.get_train_num_iter(num_device) * num_warmup_epochs
139
 
140
- model_name = args.model_name
141
 
142
  logger_unconditioned = TensorBoardLogger(
143
  save_dir=os.getcwd(), name="tensorboard", version=model_name
 
81
  "-n",
82
  "--model-name",
83
  type=str,
84
+ default=None,
85
  help="Model name (default: current tag)",
86
  )
87
  parser.add_argument(
 
90
  action="store_true",
91
  help="Font classification only (default: False)",
92
  )
93
+ parser.add_argument(
94
+ "-z",
95
+ "--size",
96
+ type=int,
97
+ default=512,
98
+ help="Model feature image input size (default: 512)",
99
+ )
100
 
101
  args = parser.parse_args()
102
 
 
106
  total_num_workers = os.cpu_count()
107
  single_device_num_workers = total_num_workers // len(devices)
108
 
109
+ config.INPUT_SIZE = args.size
110
+
111
  if os.name == "nt":
112
  single_device_num_workers = 0
113
 
 
146
  num_iters = data_module.get_train_num_iter(num_device) * num_epochs
147
  num_warmup_iter = data_module.get_train_num_iter(num_device) * num_warmup_epochs
148
 
149
+ model_name = get_current_tag() if args.model_name is None else args.model_name
150
 
151
  logger_unconditioned = TensorBoardLogger(
152
  save_dir=os.getcwd(), name="tensorboard", version=model_name