Spaces:
Runtime error
Runtime error
File size: 3,484 Bytes
7629b39 432392d 7629b39 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 |
from yacs.config import CfgNode as CN
import argparse
import yaml
import os
# Absolute path two directory levels above the folder that contains this file;
# every default path below is built relative to it.
abs_barc_dir = os.path.abspath(
    os.path.join(os.path.dirname(__file__), os.pardir, os.pardir)
)
# Root node of the default configuration.
_C = CN({
    'barc_dir': abs_barc_dir,
    'device': 'cuda',  # alternative: 'cpu'
})
## path settings (all absolute, rooted at abs_barc_dir)
_C.paths = CN({
    'ROOT_OUT_PATH': abs_barc_dir + '/results/',
    'ROOT_CHECKPOINT_PATH': abs_barc_dir + '/checkpoint/',
    'MODELPATH_NORMFLOW': abs_barc_dir + '/checkpoint/barc_normflow_pret/rgbddog_v3_model.pt',
})
## parameter settings
# Derived keys (N_CLASSES, NUM_STAGE_*, TRANS_SEP, N_Z) are added later by
# update_dependent_vars().
_C.params = CN({
    'ARCH': 'hg8',
    'STRUCTURE_POSE_NET': 'normflow',  # alternatives: 'default', 'vae'
    'NF_VERSION': 3,
    'N_JOINTS': 35,
    'N_KEYP': 24,  # alternative: 20
    'N_SEG': 2,
    'N_PARTSEG': 15,
    'UPSAMPLE_SEG': True,
    # partseg: for the CVPR paper this part of the network exists, but is not
    # trained (no part labels in StanExt)
    'ADD_PARTSEG': True,
    'N_BETAS': 30,  # alternative: 10
    'N_BETAS_LIMBS': 7,
    'N_BONES': 24,
    'N_BREEDS': 121,  # 120 breeds plus background
    'IMG_SIZE': 256,
    'SILH_NO_TAIL': False,
    'KP_THRESHOLD': None,
    'ADD_Z_TO_3D_INPUT': False,
    'N_SEGBPS': 64 * 2,
    'ADD_SEGBPS_TO_3D_INPUT': True,
    'FIX_FLENGTH': False,
    'RENDER_ALL': True,
    'VLIN': 2,
    'STRUCTURE_Z_TO_B': 'lin',
    'N_Z_FREE': 64,
    'PCK_THRESH': 0.15,
})
## optimization settings
_C.optim = CN({
    'LR': 5e-4,
    'SCHEDULE': [150, 175, 200],
    'GAMMA': 0.1,
    'MOMENTUM': 0,
    'WEIGHT_DECAY': 0,
    'EPOCHS': 220,
    # keep 12 (needs to be an even number, as we have a custom data sampler)
    'BATCH_SIZE': 12,
    'TRAIN_PARTS': 'all_without_shapedirs',
})
## dataset settings
_C.data = CN({
    'DATASET': 'stanext24',
    'V12': True,
    'SHORTEN_VAL_DATASET_TO': None,
    'VAL_OPT': 'val',
    'VAL_METRICS': 'no_loss',
})
# ---------------------------------------
# VLIN -> (NUM_STAGE_COMB, NUM_STAGE_HEADS, NUM_STAGE_HEADS_POSE, TRANS_SEP),
# consumed by update_dependent_vars().
_VLIN_STAGE_SETTINGS = {
    0: (2, 1, 1, False),
    1: (3, 1, 2, False),
    2: (3, 1, 2, True),
}


def update_dependent_vars(cfg):
    """Recompute the config entries that are derived from other settings.

    Writes into ``cfg.params`` in place and returns ``None``. Call it again
    whenever a base value may have changed (e.g. after merging a YAML file).

    Derived keys written:
      * ``N_CLASSES`` = ``N_KEYP + N_SEG``.
      * ``NUM_STAGE_COMB``, ``NUM_STAGE_HEADS``, ``NUM_STAGE_HEADS_POSE`` and
        ``TRANS_SEP``, looked up from ``VLIN`` (0, 1 or 2).
      * ``N_Z`` = ``N_BETAS + N_BETAS_LIMBS`` when ``STRUCTURE_Z_TO_B`` is
        ``'1dconv'``, otherwise ``N_Z_FREE``.

    Args:
        cfg: config object with a ``params`` node holding the base keys above.

    Raises:
        NotImplementedError: if ``cfg.params.VLIN`` is not 0, 1 or 2. The
            message names the rejected value. ``N_CLASSES`` has already been
            written when this is raised.
    """
    params = cfg.params
    params.N_CLASSES = params.N_KEYP + params.N_SEG

    try:
        stage_settings = _VLIN_STAGE_SETTINGS[params.VLIN]
    except (KeyError, TypeError):
        # TypeError covers an unhashable VLIN (e.g. a list from YAML), which
        # must be rejected with the same exception type as any other value.
        raise NotImplementedError(
            'unsupported params.VLIN={!r}; expected one of {}'.format(
                params.VLIN, sorted(_VLIN_STAGE_SETTINGS))
        ) from None
    # Assigned left to right, i.e. in the same key order as before.
    (params.NUM_STAGE_COMB,
     params.NUM_STAGE_HEADS,
     params.NUM_STAGE_HEADS_POSE,
     params.TRANS_SEP) = stage_settings

    # '1dconv' ties the latent size to N_BETAS + N_BETAS_LIMBS; any other
    # structure uses the independently configured N_Z_FREE.
    if params.STRUCTURE_Z_TO_B == '1dconv':
        params.N_Z = params.N_BETAS + params.N_BETAS_LIMBS
    else:
        params.N_Z = params.N_Z_FREE
# Fill in the derived keys of the defaults, then take an independent copy that
# acts as the shared global config. update_cfg_global_with_yaml() updates that
# copy in place, while get_cfg_defaults() keeps cloning the untouched _C.
# (A `global` statement is unnecessary here: this assignment is already at
# module scope.)
update_dependent_vars(_C)
_cfg_global = _C.clone()
def get_cfg_defaults():
    """Return a fresh yacs ``CfgNode`` holding the default values of this file.

    The result is a clone of ``_C``, so callers can modify or merge into it
    without altering the module-level defaults seen by later calls.
    """
    defaults = _C.clone()
    return defaults
def update_cfg_global_with_yaml(cfg_yaml_file):
    """Merge a YAML file into the shared global config, in place.

    After merging, the derived entries are recomputed with
    ``update_dependent_vars`` so they agree with any base values the file
    overrode.

    Args:
        cfg_yaml_file: path of the YAML file handed to ``merge_from_file``.
    """
    target = _cfg_global
    target.merge_from_file(cfg_yaml_file)
    update_dependent_vars(target)
def get_cfg_global_updated():
    """Return the shared global config object itself, not a copy.

    Because the same object is returned each time, it reflects every earlier
    ``update_cfg_global_with_yaml`` call. Any change a caller makes to it is
    visible to every other holder of the global config.
    """
    shared_cfg = _cfg_global
    return shared_cfg
|