# Vincentqyw
# fix: roma
# 8b973ee
# raw
# history blame
# 3.87 kB
import argparse
def str2bool(v):
    """Interpret a command-line value as a boolean.

    Intended for use as an argparse ``type=`` converter.

    Args:
        v: The value to convert. A ``bool`` is returned unchanged; anything
            else must provide ``.lower()`` (i.e. be a string).

    Returns:
        ``True`` when the lower-cased string is ``"true"`` or ``"1"``,
        ``False`` for every other string.
    """
    # A value that is already a bool (e.g. passed programmatically) has no
    # ``.lower()``; pass it through instead of raising AttributeError.
    if isinstance(v, bool):
        return v
    return v.lower() in ("true", "1")
# Shared parser that every option group in this module is attached to.
parser = argparse.ArgumentParser()
# Registry of the argument groups created through add_argument_group().
arg_lists = []
def add_argument_group(name):
    """Create a named argument group on the shared parser and register it.

    Args:
        name: Title forwarded to ``parser.add_argument_group``.

    Returns:
        The newly created group, which has also been appended to the
        module-level ``arg_lists``.
    """
    group = parser.add_argument_group(name)
    arg_lists.append(group)
    return group
# -----------------------------------------------------------------------------
# Network
net_arg = add_argument_group("Network")
# Which model to train.
net_arg.add_argument(
    "--model_name",
    type=str,
    default="SGM",
    help="model for training",
)
# Location of the model's configuration file.
net_arg.add_argument(
    "--config_path",
    type=str,
    default="configs/sgm.yaml",
    help="config path for model",
)
# -----------------------------------------------------------------------------
# Data
data_arg = add_argument_group("Data")
# Filesystem locations of the raw data, the dataset and the descriptors.
data_arg.add_argument(
    "--rawdata_path",
    type=str,
    default="rawdata",
    help="path for rawdata",
)
data_arg.add_argument(
    "--dataset_path",
    type=str,
    default="dataset",
    help="path for dataset",
)
data_arg.add_argument(
    "--desc_path",
    type=str,
    default="desc",
    help="path for descriptor(kpt) dir",
)
# Keypoint count used for training.
data_arg.add_argument(
    "--num_kpt",
    type=int,
    default=1000,
    help="number of kpt for training",
)
# Keypoint normalisation mode; the help text names "img" and "intrinsic".
data_arg.add_argument(
    "--input_normalize",
    type=str,
    default="img",
    help="normalize type for input kpt, img or intrinsic",
)
# Boolean flag parsed through str2bool ("true"/"1" enable it).
data_arg.add_argument(
    "--data_aug",
    type=str2bool,
    default=True,
    help="apply kpt coordinate homography augmentation",
)
data_arg.add_argument(
    "--desc_suffix",
    type=str,
    default="suffix",
    help="desc file suffix",
)
# -----------------------------------------------------------------------------
# Loss
loss_arg = add_argument_group("loss")
loss_arg.add_argument(
    "--momentum",
    type=float,
    default=0.9,
    help="momentum",
)
# NOTE(review): the two weights below are declared type=float but carry int
# defaults; argparse only applies `type` to string values, so the defaults
# stay ints unless overridden on the command line -- confirm this is fine.
loss_arg.add_argument(
    "--seed_loss_weight",
    type=float,
    default=250,
    help="confidence loss weight for sgm",
)
loss_arg.add_argument(
    "--mid_loss_weight",
    type=float,
    default=1,
    help="midseeding loss weight for sgm",
)
loss_arg.add_argument(
    "--inlier_th",
    type=float,
    default=5e-3,
    help="inlier threshold for epipolar distance (for sgm and visualization)",
)
# -----------------------------------------------------------------------------
# Training
train_arg = add_argument_group("Train")
# Optimisation hyper-parameters.
train_arg.add_argument(
    "--train_lr",
    type=float,
    default=1e-4,
    help="learning rate",
)
train_arg.add_argument(
    "--train_batch_size",
    type=int,
    default=16,
    help="batch size",
)
# Kept as a string so several ids can be given in one value.
train_arg.add_argument(
    "--gpu_id",
    type=str,
    default="0",
    help="id(s) for CUDA_VISIBLE_DEVICES",
)
train_arg.add_argument(
    "--train_iter",
    type=int,
    default=1000000,
    help="training iterations to perform",
)
# Logging / validation / checkpoint cadence.
train_arg.add_argument(
    "--log_base",
    type=str,
    default="./log/",
    help="log path",
)
train_arg.add_argument(
    "--val_intv",
    type=int,
    default=20000,
    help="validation interval",
)
# NOTE(review): the flag is named "save" but its help text says "summary
# interval" -- confirm which one is intended.
train_arg.add_argument(
    "--save_intv",
    type=int,
    default=1000,
    help="summary interval",
)
train_arg.add_argument(
    "--log_intv",
    type=int,
    default=100,
    help="log interval",
)
# Learning-rate decay schedule.
train_arg.add_argument(
    "--decay_rate",
    type=float,
    default=0.999996,
    help="lr decay rate",
)
# NOTE(review): declared type=float with an int default; argparse leaves
# non-string defaults unconverted, so the default is an int while a value
# given on the command line becomes a float -- confirm intended.
train_arg.add_argument(
    "--decay_iter",
    type=float,
    default=300000,
    help="lr decay iter",
)
train_arg.add_argument(
    "--local_rank",
    type=int,
    default=0,
    help="local rank for ddp",
)
train_arg.add_argument(
    "--train_vis_folder",
    type=str,
    default=".",
    help="visualization folder during training",
)
# -----------------------------------------------------------------------------
# Visualization
vis_arg = add_argument_group("Visualization")
# Character width used for tqdm progress bars.
vis_arg.add_argument(
    "--tqdm_width",
    type=int,
    default=79,
    help="width of the tqdm bar",
)
def get_config():
    """Parse the command line with the module-level parser.

    Uses ``parse_known_args`` so unrecognised arguments are collected
    rather than causing an error.

    Returns:
        A ``(config, unparsed)`` tuple: the populated namespace and the list
        of arguments the parser did not recognise.
    """
    return parser.parse_known_args()
def print_usage():
    """Print the module-level parser's short usage message."""
    parser.print_usage()
#
# config.py ends here