barc_gradio / src /configs /barc_cfg_defaults.py
Nadine Rueegg
initial commit for barc
7629b39
raw history blame
No virus
3.46 kB
from yacs.config import CfgNode as CN
import argparse
import yaml
import os
abs_barc_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), '..', '..',))
_C = CN()
_C.barc_dir = abs_barc_dir
_C.device = 'cuda'
## path settings
_C.paths = CN()
_C.paths.ROOT_OUT_PATH = abs_barc_dir + '/results/'
_C.paths.ROOT_CHECKPOINT_PATH = abs_barc_dir + '/checkpoint/'
_C.paths.MODELPATH_NORMFLOW = abs_barc_dir + '/checkpoint/barc_normflow_pret/rgbddog_v3_model.pt'
## parameter settings
_C.params = CN()
_C.params.ARCH = 'hg8'
_C.params.STRUCTURE_POSE_NET = 'normflow' # 'default' # 'vae'
_C.params.NF_VERSION = 3
_C.params.N_JOINTS = 35
_C.params.N_KEYP = 24 #20
_C.params.N_SEG = 2
_C.params.N_PARTSEG = 15
_C.params.UPSAMPLE_SEG = True
_C.params.ADD_PARTSEG = True # partseg: for the CVPR paper this part of the network exists, but is not trained (no part labels in StanExt)
_C.params.N_BETAS = 30 # 10
_C.params.N_BETAS_LIMBS = 7
_C.params.N_BONES = 24
_C.params.N_BREEDS = 121 # 120 breeds plus background
_C.params.IMG_SIZE = 256
_C.params.SILH_NO_TAIL = False
_C.params.KP_THRESHOLD = None
_C.params.ADD_Z_TO_3D_INPUT = False
_C.params.N_SEGBPS = 64*2
_C.params.ADD_SEGBPS_TO_3D_INPUT = True
_C.params.FIX_FLENGTH = False
_C.params.RENDER_ALL = True
_C.params.VLIN = 2
_C.params.STRUCTURE_Z_TO_B = 'lin'
_C.params.N_Z_FREE = 64
_C.params.PCK_THRESH = 0.15
## optimization settings
_C.optim = CN()
_C.optim.LR = 5e-4
_C.optim.SCHEDULE = [150, 175, 200]
_C.optim.GAMMA = 0.1
_C.optim.MOMENTUM = 0
_C.optim.WEIGHT_DECAY = 0
_C.optim.EPOCHS = 220
_C.optim.BATCH_SIZE = 12 # keep 12 (needs to be an even number, as we have a custom data sampler)
_C.optim.TRAIN_PARTS = 'all_without_shapedirs'
## dataset settings
_C.data = CN()
_C.data.DATASET = 'stanext24'
_C.data.V12 = True
_C.data.SHORTEN_VAL_DATASET_TO = None
_C.data.VAL_OPT = 'val'
_C.data.VAL_METRICS = 'no_loss'
# ---------------------------------------
def update_dependent_vars(cfg):
cfg.params.N_CLASSES = cfg.params.N_KEYP + cfg.params.N_SEG
if cfg.params.VLIN == 0:
cfg.params.NUM_STAGE_COMB = 2
cfg.params.NUM_STAGE_HEADS = 1
cfg.params.NUM_STAGE_HEADS_POSE = 1
cfg.params.TRANS_SEP = False
elif cfg.params.VLIN == 1:
cfg.params.NUM_STAGE_COMB = 3
cfg.params.NUM_STAGE_HEADS = 1
cfg.params.NUM_STAGE_HEADS_POSE = 2
cfg.params.TRANS_SEP = False
elif cfg.params.VLIN == 2:
cfg.params.NUM_STAGE_COMB = 3
cfg.params.NUM_STAGE_HEADS = 1
cfg.params.NUM_STAGE_HEADS_POSE = 2
cfg.params.TRANS_SEP = True
else:
raise NotImplementedError
if cfg.params.STRUCTURE_Z_TO_B == '1dconv':
cfg.params.N_Z = cfg.params.N_BETAS + cfg.params.N_BETAS_LIMBS
else:
cfg.params.N_Z = cfg.params.N_Z_FREE
return
update_dependent_vars(_C)
global _cfg_global
_cfg_global = _C.clone()
def get_cfg_defaults():
# Get a yacs CfgNode object with default values as defined within this file.
# Return a clone so that the defaults will not be altered.
return _C.clone()
def update_cfg_global_with_yaml(cfg_yaml_file):
_cfg_global.merge_from_file(cfg_yaml_file)
update_dependent_vars(_cfg_global)
return
def get_cfg_global_updated():
# return _cfg_global.clone()
return _cfg_global