_base_ = [
    '../_base_/losses/all_losses.py',
    '../_base_/models/encoder_decoder/dino_vit_large_reg.dpt_raft.py',
    '../_base_/datasets/nyu.py',
    '../_base_/datasets/kitti.py',
]

import numpy as np
model=dict(
    decode_head=dict(
        type='RAFTDepthNormalDPT5',
        iters=8,
        n_downsample=2,
        detach=False,
    ),
)
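# The decode_head dict above only overrides selected fields of the RAFT-style DPT
# head declared in the _base_ model config; judging from the head name, 'iters' is
# the number of recurrent GRU refinement steps and 'n_downsample' the downsampling
# level at which the iterative update runs (inferred, not documented here).
# A faster, lower-quality variant could be sketched as:
# model = dict(decode_head=dict(iters=4))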

# loss configuration
losses=dict(
    decoder_losses=[
        dict(type='VNLoss', sample_ratio=0.2, loss_weight=0.1),   
        dict(type='GRUSequenceLoss', loss_weight=1.0, loss_gamma=0.9, stereo_sup=0),
        dict(type='DeNoConsistencyLoss', loss_weight=0.001, loss_fn='CEL', scale=2)
    ],
)
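# VNLoss is the virtual-normal geometric loss (sample_ratio controls how many point
# samples are drawn); GRUSequenceLoss presumably supervises every intermediate GRU
# prediction with an exponentially decaying weight, in the spirit of RAFT:
#   total = sum_i loss_gamma ** (num_predictions - 1 - i) * loss_i
# DeNoConsistencyLoss appears to be a small-weight depth-normal consistency term.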

data_array = [
    [
        dict(KITTI='KITTI_dataset'),
    ],
]



# configuration of the canonical camera space
data_basic=dict(
    canonical_space = dict(
        # img_size=(540, 960),
        focal_length=1000.0,
    ),
    depth_range=(0, 1),
    depth_normalize=(0.1, 200),
#     crop_size=(544, 1216),
#     crop_size = (544, 992),
    crop_size = (616, 1064),  # both dimensions are divisible by 28
) 
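# LabelScaleCononical in the pipelines below rescales every sample into this
# canonical camera; in the Metric3D formulation this amounts to roughly
#   depth_canonical = depth_metric * focal_length_canonical / focal_length_image
# with focal_length_canonical = 1000.0 as set above. depth_normalize appears to be
# the (min, max) depth range kept after that transform, and the 616 x 1064 crop is
# presumably chosen so both sides align with the ViT-L/14 patch grid.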

# online evaluation
# evaluation = dict(online_eval=True, interval=1000, metrics=['abs_rel', 'delta1', 'rmse'], multi_dataset_eval=True)
#log_interval = 100

interval = 4000
log_interval = 100
evaluation = dict(
    online_eval=False, 
    interval=interval, 
    metrics=['abs_rel', 'delta1', 'rmse', 'normal_mean', 'normal_rmse', 'normal_a1'], 
    multi_dataset_eval=True,
    exclude=['DIML_indoor', 'GL3D', 'Tourism', 'MegaDepth'],
)

# save checkpoints during training; a runner type ending in '_AMP' enables automatic mixed-precision training
checkpoint_config = dict(by_epoch=False, interval=interval)
runner = dict(type='IterBasedRunner_AMP', max_iters=20010)

# optimizer
optimizer = dict(
    type='AdamW', 
    encoder=dict(lr=5e-7, betas=(0.9, 0.999), weight_decay=0, eps=1e-10),
    decoder=dict(lr=1e-5, betas=(0.9, 0.999), weight_decay=0, eps=1e-10),
    strict_match=True
)
# schedule
lr_config = dict(policy='poly',
                 warmup='linear',
                 warmup_iters=20,
                 warmup_ratio=1e-6,
                 power=0.9, min_lr=1e-8, by_epoch=False)
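# Under mmcv's 'poly' policy the learning rate is expected to decay as
#   lr(t) = (base_lr - min_lr) * (1 - t / max_iters) ** power + min_lr
# after the 20-iteration linear warmup, with base_lr taken per parameter group
# from the optimizer above (5e-7 for the encoder, 1e-5 for the decoder).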

acc_batch = 1
batchsize_per_gpu = 2
thread_per_gpu = 2
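# acc_batch looks like a gradient-accumulation factor, so the effective batch size
# per optimizer step should be roughly batchsize_per_gpu * num_gpus * acc_batch;
# thread_per_gpu presumably sets the number of dataloader workers per GPU.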

KITTI_dataset = dict(
    data=dict(
        train=dict(
            pipeline=[
                dict(type='BGR2RGB'),
                dict(type='LabelScaleCononical'),
                dict(type='RandomResize',
                     prob=0.5,
                     ratio_range=(0.85, 1.15),
                     is_lidar=True),
                dict(type='RandomCrop',
                     crop_size=(0, 0),  # crop_size is overwritten by the data_basic config
                     crop_type='rand',
                     ignore_label=-1,
                     padding=[0, 0, 0]),
                dict(type='RandomEdgeMask',
                     mask_maxsize=50,
                     prob=0.2,
                     rgb_invalid=[0, 0, 0],
                     label_invalid=-1),
                dict(type='RandomHorizontalFlip',
                     prob=0.4),
                dict(type='PhotoMetricDistortion',
                     to_gray_prob=0.1,
                     distortion_prob=0.1),
                dict(type='Weather',
                     prob=0.05),
                dict(type='RandomBlur',
                     prob=0.05),
                dict(type='RGBCompresion', prob=0.1, compression=(0, 40)),
                dict(type='ToTensor'),
                dict(type='Normalize', mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375]),
            ],
            # sample_size = 10,
        ),
        val=dict(
            pipeline=[
                dict(type='BGR2RGB'),
                dict(type='LabelScaleCononical'),
                dict(type='RandomCrop',
                     crop_size=(0, 0),  # crop_size is overwritten by the data_basic config
                     crop_type='center',
                     ignore_label=-1,
                     padding=[0, 0, 0]),
                dict(type='ToTensor'),
                dict(type='Normalize', mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375]),
            ],
            sample_size=1200,
        ),
    ),
)