"""Command-line configuration: builds the shared argparse parser."""
import argparse


def str2bool(v):
    """Convert a command-line token to a bool.

    Args:
        v: A string such as ``"true"``/``"1"`` (case-insensitive), or an
            actual ``bool``.

    Returns:
        ``v`` itself when it is already a bool; otherwise ``True`` only for
        the strings ``"true"`` and ``"1"`` (any casing), ``False`` for every
        other string.
    """
    # Allow programmatic callers to pass a real bool without hitting
    # AttributeError on ``.lower()``.
    if isinstance(v, bool):
        return v
    return v.lower() in ("true", "1")


# Every argument group created through add_argument_group(), in creation order.
arg_lists = []
parser = argparse.ArgumentParser()


def add_argument_group(name):
    """Create a titled argument group on ``parser`` and record it.

    Args:
        name: Title of the group.

    Returns:
        The newly created argparse argument group (also appended to
        ``arg_lists``).
    """
    arg = parser.add_argument_group(name)
    arg_lists.append(arg)
    return arg


# -----------------------------------------------------------------------------
# Network
net_arg = add_argument_group("Network")
net_arg.add_argument(
    "--model_name", type=str, default="SGM", help="model for training"
)
net_arg.add_argument(
    "--config_path",
    type=str,
    default="configs/sgm.yaml",
    help="config path for model",
)

# -----------------------------------------------------------------------------
# Data
data_arg = add_argument_group("Data")
data_arg.add_argument(
    "--rawdata_path", type=str, default="rawdata", help="path for rawdata"
)
data_arg.add_argument(
    "--dataset_path", type=str, default="dataset", help="path for dataset"
)
data_arg.add_argument(
    "--desc_path", type=str, default="desc", help="path for descriptor(kpt) dir"
)
data_arg.add_argument(
    "--num_kpt", type=int, default=1000, help="number of kpt for training"
)
data_arg.add_argument(
    "--input_normalize",
    type=str,
    default="img",
    help="normalize type for input kpt, img or intrinsic",
)
data_arg.add_argument(
    "--data_aug",
    type=str2bool,
    default=True,
    help="apply kpt coordinate homography augmentation",
)
data_arg.add_argument(
    "--desc_suffix", type=str, default="suffix", help="desc file suffix"
)

# -----------------------------------------------------------------------------
# Loss
loss_arg = add_argument_group("loss")
loss_arg.add_argument("--momentum", type=float, default=0.9, help="momentum")
loss_arg.add_argument(
    "--seed_loss_weight",
    type=float,
    default=250,
    help="confidence loss weight for sgm",
)
loss_arg.add_argument(
    "--mid_loss_weight",
    type=float,
    default=1,
    help="midseeding loss weight for sgm",
)
loss_arg.add_argument(
    "--inlier_th",
    type=float,
    default=5e-3,
    help="inlier threshold for epipolar distance (for sgm and visualization)",
)

# -----------------------------------------------------------------------------
# Training
train_arg = add_argument_group("Train")
train_arg.add_argument(
    "--train_lr", type=float, default=1e-4, help="learning rate"
)
train_arg.add_argument(
    "--train_batch_size", type=int, default=16, help="batch size"
)
train_arg.add_argument(
    "--gpu_id", type=str, default="0", help="id(s) for CUDA_VISIBLE_DEVICES"
)
train_arg.add_argument(
    "--train_iter",
    type=int,
    default=1000000,
    help="training iterations to perform",
)
train_arg.add_argument(
    "--log_base", type=str, default="./log/", help="log path"
)
train_arg.add_argument(
    "--val_intv", type=int, default=20000, help="validation interval"
)
train_arg.add_argument(
    "--save_intv", type=int, default=1000, help="summary interval"
)
train_arg.add_argument(
    "--log_intv", type=int, default=100, help="log interval"
)
train_arg.add_argument(
    "--decay_rate", type=float, default=0.999996, help="lr decay rate"
)
# NOTE(review): parsed as float on the command line while the default is an
# int; left unchanged so existing invocations (e.g. "3e5") keep working.
train_arg.add_argument(
    "--decay_iter", type=float, default=300000, help="lr decay iter"
)
train_arg.add_argument(
    "--local_rank", type=int, default=0, help="local rank for ddp"
)
train_arg.add_argument(
    "--train_vis_folder",
    type=str,
    default=".",
    help="visualization folder during training",
)

# -----------------------------------------------------------------------------
# Visualization
vis_arg = add_argument_group("Visualization")
vis_arg.add_argument(
    "--tqdm_width", type=int, default=79, help="width of the tqdm bar"
)


def get_config(argv=None):
    """Parse known arguments and return them with the leftovers.

    Args:
        argv: Optional list of argument strings to parse. When ``None``
            (the default) argparse reads ``sys.argv[1:]``.

    Returns:
        A ``(config, unparsed)`` tuple: the populated namespace and the list
        of argument strings the parser did not recognise.
    """
    config, unparsed = parser.parse_known_args(argv)
    return config, unparsed


def print_usage():
    """Print the parser's usage message."""
    parser.print_usage()


#
# config.py ends here