File size: 3,460 Bytes
7629b39
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112

from yacs.config import CfgNode as CN
import argparse
import yaml
import os

# Repository root: two directory levels above the directory holding this file.
abs_barc_dir = os.path.abspath(
    os.path.join(os.path.dirname(__file__), os.pardir, os.pardir)
)

_C = CN()
_C.barc_dir = abs_barc_dir
_C.device = 'cuda'

# --- path settings (all rooted at the repository directory) ---
_C.paths = CN()
_C.paths.ROOT_OUT_PATH = abs_barc_dir + '/results/'
_C.paths.ROOT_CHECKPOINT_PATH = abs_barc_dir + '/checkpoint/'
# normflow checkpoint ('barc_normflow_pret'), stored below the checkpoint root
_C.paths.MODELPATH_NORMFLOW = (
    _C.paths.ROOT_CHECKPOINT_PATH + 'barc_normflow_pret/rgbddog_v3_model.pt'
)

# --- model / network parameters ---
_C.params = CN()
_C.params.ARCH = 'hg8'
_C.params.STRUCTURE_POSE_NET = 'normflow'  # other options noted: 'default', 'vae'
_C.params.NF_VERSION = 3
_C.params.N_JOINTS = 35
_C.params.N_KEYP = 24  # alternative value noted: 20
_C.params.N_SEG = 2
_C.params.N_PARTSEG = 15
_C.params.UPSAMPLE_SEG = True
# Part segmentation: for the CVPR paper this branch exists in the network but
# is not trained (StanExt has no part labels).
_C.params.ADD_PARTSEG = True
_C.params.N_BETAS = 30  # alternative value noted: 10
_C.params.N_BETAS_LIMBS = 7
_C.params.N_BONES = 24
_C.params.N_BREEDS = 121  # 120 breeds plus background
_C.params.IMG_SIZE = 256
_C.params.SILH_NO_TAIL = False
_C.params.KP_THRESHOLD = None
_C.params.ADD_Z_TO_3D_INPUT = False
_C.params.N_SEGBPS = 64 * 2
_C.params.ADD_SEGBPS_TO_3D_INPUT = True
_C.params.FIX_FLENGTH = False
_C.params.RENDER_ALL = True
_C.params.VLIN = 2  # picks the stage settings in update_dependent_vars (0, 1 or 2)
_C.params.STRUCTURE_Z_TO_B = 'lin'
_C.params.N_Z_FREE = 64  # becomes N_Z unless STRUCTURE_Z_TO_B == '1dconv'
_C.params.PCK_THRESH = 0.15

# --- optimisation settings ---
_C.optim = CN()
_C.optim.LR = 5e-4
_C.optim.SCHEDULE = [150, 175, 200]
_C.optim.GAMMA = 0.1
_C.optim.MOMENTUM = 0
_C.optim.WEIGHT_DECAY = 0
_C.optim.EPOCHS = 220
# Keep at 12: it must be an even number because of the custom data sampler.
_C.optim.BATCH_SIZE = 12
_C.optim.TRAIN_PARTS = 'all_without_shapedirs'

# --- dataset settings ---
_C.data = CN()
_C.data.DATASET = 'stanext24'
_C.data.V12 = True
_C.data.SHORTEN_VAL_DATASET_TO = None
_C.data.VAL_OPT = 'val'
_C.data.VAL_METRICS = 'no_loss'

# ---------------------------------------
def update_dependent_vars(cfg):
    """Recompute the config entries that are derived from other entries, in place.

    Sets on ``cfg.params``:
      * ``N_CLASSES``: ``N_KEYP + N_SEG``.
      * ``NUM_STAGE_COMB``, ``NUM_STAGE_HEADS``, ``NUM_STAGE_HEADS_POSE`` and
        ``TRANS_SEP``: selected by ``VLIN`` (0, 1 or 2).
      * ``N_Z``: ``N_BETAS + N_BETAS_LIMBS`` when ``STRUCTURE_Z_TO_B`` is
        ``'1dconv'``, otherwise ``N_Z_FREE``.

    This must be re-run whenever one of the source entries changes; for
    example, update_cfg_global_with_yaml calls it after merging a YAML file.

    Args:
        cfg: config object with a ``params`` sub-node; modified in place.

    Raises:
        NotImplementedError: if ``cfg.params.VLIN`` is not 0, 1 or 2.
    """
    params = cfg.params
    params.N_CLASSES = params.N_KEYP + params.N_SEG

    # VLIN -> (NUM_STAGE_COMB, NUM_STAGE_HEADS, NUM_STAGE_HEADS_POSE, TRANS_SEP)
    stage_settings = {
        0: (2, 1, 1, False),
        1: (3, 1, 2, False),
        2: (3, 1, 2, True),
    }
    try:
        comb, heads, heads_pose, trans_sep = stage_settings[params.VLIN]
    except (KeyError, TypeError):
        # TypeError covers unhashable values, which are just as unsupported.
        raise NotImplementedError(
            'Unsupported params.VLIN={!r}; expected one of {}'.format(
                params.VLIN, sorted(stage_settings)
            )
        ) from None
    params.NUM_STAGE_COMB = comb
    params.NUM_STAGE_HEADS = heads
    params.NUM_STAGE_HEADS_POSE = heads_pose
    params.TRANS_SEP = trans_sep

    if params.STRUCTURE_Z_TO_B == '1dconv':
        params.N_Z = params.N_BETAS + params.N_BETAS_LIMBS
    else:
        params.N_Z = params.N_Z_FREE


# Fill in the derived entries of the defaults, then keep a separate copy that
# update_cfg_global_with_yaml() merges into and get_cfg_global_updated() returns.
# (No `global` statement is needed here: at module scope it is a no-op.)
update_dependent_vars(_C)
_cfg_global = _C.clone()


def get_cfg_defaults():
    """Return a fresh yacs CfgNode holding this module's default settings.

    The result is a clone, so callers can modify it freely without altering
    the module-level defaults.
    """
    defaults_copy = _C.clone()
    return defaults_copy

def update_cfg_global_with_yaml(cfg_yaml_file):
    """Merge a YAML file into the module-wide config and refresh derived values.

    The shared global config (the object handed out by
    get_cfg_global_updated) is updated in place. Afterwards its dependent
    entries are recomputed with update_dependent_vars.

    Args:
        cfg_yaml_file: path of the YAML file, passed to yacs'
            ``CfgNode.merge_from_file``.
    """
    shared_cfg = _cfg_global
    shared_cfg.merge_from_file(cfg_yaml_file)
    update_dependent_vars(shared_cfg)

def get_cfg_global_updated():
    """Return the module-wide config, including any merged YAML overrides.

    NOTE: this is the shared object itself, not a copy. Changes made by the
    caller are visible to every other user of this module.
    """
    shared_cfg = _cfg_global
    return shared_cfg