# File size: 11,195 Bytes
# c599a73
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
# (extraction artifact removed: a rendered line-number gutter, 1-319,
#  left over from the tool that exported this file)
# ---------------------------------------------------------------------------
# Runtime / distributed settings.
# This is an mmcv/mmsegmentation-style config that has been *dumped* with all
# values inlined, which is why the same literals (band stats, crop size, ...)
# repeat in the pipeline and `data` sections below.
# ---------------------------------------------------------------------------
dist_params = dict(backend='nccl')  # NCCL backend for multi-GPU training
log_level = 'INFO'
load_from = None  # no full-model checkpoint to warm-start from
resume_from = None  # no interrupted run to resume
cudnn_benchmark = True  # input size is fixed (224x224), so cudnn autotune is safe
# Imports the project package so its custom datasets/transforms/models
# (e.g. GeospatialDataset, TemporalViTEncoder) register themselves with mmseg.
custom_imports = dict(imports=['geospatial_fm'])

# ---------------------------------------------------------------------------
# Dataset constants.
# ---------------------------------------------------------------------------
dataset_type = 'GeospatialDataset'
data_root = '/dccstor/geofm-finetuning/fire-scars/finetune-data/6_bands_no_replant_extended'
num_frames = 1  # single-timestamp imagery (no temporal stacking)
img_size = 224
num_workers = 4
samples_per_gpu = 4
# Per-band mean/std over the 6 input bands; the same values are inlined in
# every TorchNormalize step below.
img_norm_cfg = dict(
    means=[
        0.033349706741586264, 0.05701185520536176, 0.05889748132001316,
        0.2323245113436119, 0.1972854853760658, 0.11944914225186566
    ],
    stds=[
        0.02269135568823774, 0.026807560223070237, 0.04004109844362779,
        0.07791732423672691, 0.08708738838140137, 0.07241979477437814
    ])
bands = [0, 1, 2, 3, 4, 5]  # keep all six bands
tile_size = 224
orig_nsize = 512  # size of the tiles on disk; training crops 224x224 from them
crop_size = (224, 224)
img_suffix = '_merged.tif'  # image filename suffix
seg_map_suffix = '.mask.tif'  # label filename suffix
ignore_index = -1  # label value excluded from loss/metrics
image_nodata = -9999  # nodata sentinel in the source rasters
image_nodata_replace = 0  # value substituted for nodata pixels
image_to_float32 = True

# ---------------------------------------------------------------------------
# Backbone hyperparameters and experiment bookkeeping.
# ---------------------------------------------------------------------------
# pretrained_weights_path = '/dccstor/geofm-finetuning/pretrain_ckpts/mae_weights/2023-04-29_21-50-47/epoch-725-loss-0.0365.pt'
# NOTE(review): pretrained weights are disabled here AND in model.backbone
# (pretrained=None), so the backbone starts from random init — confirm this
# is intentional for this experiment.
pretrained_weights_path = None
num_layers = 12  # ViT depth
patch_size = 16
embed_dim = 768
num_heads = 12
tubelet_size = 1  # no temporal patching (matches num_frames = 1)
epochs = 50
eval_epoch_interval = 5
experiment = 'test2'
project_dir = '/dccstor/geofm-finetuning/fire-scars/os'
work_dir = '/dccstor/geofm-finetuning/fire-scars/os/test2'
save_path = '/dccstor/geofm-finetuning/fire-scars/os/test2'
# Training pipeline: load 6-band GeoTIFF + mask, random flip, per-band
# normalization, random 224x224 crop, then reshape the image to
# (C, T, H, W) = (6, 1, 224, 224) for the temporal ViT encoder.
train_pipeline = [
    dict(type='LoadGeospatialImageFromFile', to_float32=True),
    dict(type='LoadGeospatialAnnotations', reduce_zero_label=False),  # labels used as-is
    dict(type='BandsExtract', bands=[0, 1, 2, 3, 4, 5]),
    dict(type='RandomFlip', prob=0.5),  # only geometric augmentation used
    dict(type='ToTensor', keys=['img', 'gt_semantic_seg']),
    # Same per-band statistics as img_norm_cfg above (inlined by the dump).
    dict(
        type='TorchNormalize',
        means=[
            0.033349706741586264, 0.05701185520536176, 0.05889748132001316,
            0.2323245113436119, 0.1972854853760658, 0.11944914225186566
        ],
        stds=[
            0.02269135568823774, 0.026807560223070237, 0.04004109844362779,
            0.07791732423672691, 0.08708738838140137, 0.07241979477437814
        ]),
    dict(type='TorchRandomCrop', crop_size=(224, 224)),
    # Insert the singleton time dimension expected by the backbone.
    dict(type='Reshape', keys=['img'], new_shape=(6, 1, 224, 224)),
    dict(type='Reshape', keys=['gt_semantic_seg'], new_shape=(1, 224, 224)),
    # The segmentation loss expects integer class indices.
    dict(
        type='CastTensor',
        keys=['gt_semantic_seg'],
        new_type='torch.LongTensor'),
    dict(type='Collect', keys=['img', 'gt_semantic_seg'])
]
# Inference pipeline: no annotations, no cropping — the full image is kept
# at its native size (H and W are -1 in the Reshape below) and sliding-window
# inference (model.test_cfg) handles tiling.
test_pipeline = [
    dict(type='LoadGeospatialImageFromFile', to_float32=True),
    dict(type='BandsExtract', bands=[0, 1, 2, 3, 4, 5]),
    dict(type='ToTensor', keys=['img']),
    # Same per-band statistics as img_norm_cfg above.
    dict(
        type='TorchNormalize',
        means=[
            0.033349706741586264, 0.05701185520536176, 0.05889748132001316,
            0.2323245113436119, 0.1972854853760658, 0.11944914225186566
        ],
        stds=[
            0.02269135568823774, 0.026807560223070237, 0.04004109844362779,
            0.07791732423672691, 0.08708738838140137, 0.07241979477437814
        ]),
    # -1 dims are resolved at runtime; look_up presumably maps output dims
    # 2/3 (H, W) to input dims 1/2 of the (C, H, W) tensor — TODO confirm
    # against the project's Reshape transform.
    dict(
        type='Reshape',
        keys=['img'],
        new_shape=(6, 1, -1, -1),
        look_up=dict({
            '2': 1,
            '3': 2
        })),
    dict(type='CastTensor', keys=['img'], new_type='torch.FloatTensor'),
    # CollectTestList keeps the metadata the slide-inference wrapper needs.
    dict(
        type='CollectTestList',
        keys=['img'],
        meta_keys=[
            'img_info', 'seg_fields', 'img_prefix', 'seg_prefix', 'filename',
            'ori_filename', 'img', 'img_shape', 'ori_shape', 'pad_shape',
            'scale_factor', 'img_norm_cfg'
        ])
]
# Dataset splits. The pipelines below are inlined copies of train_pipeline /
# test_pipeline above (an artifact of dumping the config with mmcv).
# NOTE(review): the `test` split points at the *validation* directories, so
# there is no held-out test set in this config — confirm this is intentional.
data = dict(
    samples_per_gpu=4,
    workers_per_gpu=4,
    train=dict(
        type='FireScars',
        data_root=
        '/dccstor/geofm-finetuning/fire-scars/finetune-data/6_bands_no_replant_extended',
        img_dir='training',
        ann_dir='training',  # images and masks live in the same directory
        img_suffix='_merged.tif',
        seg_map_suffix='.mask.tif',
        pipeline=[
            dict(type='LoadGeospatialImageFromFile', to_float32=True),
            dict(type='LoadGeospatialAnnotations', reduce_zero_label=False),
            dict(type='BandsExtract', bands=[0, 1, 2, 3, 4, 5]),
            dict(type='RandomFlip', prob=0.5),
            dict(type='ToTensor', keys=['img', 'gt_semantic_seg']),
            dict(
                type='TorchNormalize',
                means=[
                    0.033349706741586264, 0.05701185520536176,
                    0.05889748132001316, 0.2323245113436119,
                    0.1972854853760658, 0.11944914225186566
                ],
                stds=[
                    0.02269135568823774, 0.026807560223070237,
                    0.04004109844362779, 0.07791732423672691,
                    0.08708738838140137, 0.07241979477437814
                ]),
            dict(type='TorchRandomCrop', crop_size=(224, 224)),
            dict(type='Reshape', keys=['img'], new_shape=(6, 1, 224, 224)),
            dict(
                type='Reshape',
                keys=['gt_semantic_seg'],
                new_shape=(1, 224, 224)),
            dict(
                type='CastTensor',
                keys=['gt_semantic_seg'],
                new_type='torch.LongTensor'),
            dict(type='Collect', keys=['img', 'gt_semantic_seg'])
        ],
        ignore_index=-1),
    val=dict(
        type='FireScars',
        data_root=
        '/dccstor/geofm-finetuning/fire-scars/finetune-data/6_bands_no_replant_extended',
        img_dir='validation',
        ann_dir='validation',
        img_suffix='_merged.tif',
        seg_map_suffix='.mask.tif',
        # Validation uses the inference pipeline (no crop, no flip).
        pipeline=[
            dict(type='LoadGeospatialImageFromFile', to_float32=True),
            dict(type='BandsExtract', bands=[0, 1, 2, 3, 4, 5]),
            dict(type='ToTensor', keys=['img']),
            dict(
                type='TorchNormalize',
                means=[
                    0.033349706741586264, 0.05701185520536176,
                    0.05889748132001316, 0.2323245113436119,
                    0.1972854853760658, 0.11944914225186566
                ],
                stds=[
                    0.02269135568823774, 0.026807560223070237,
                    0.04004109844362779, 0.07791732423672691,
                    0.08708738838140137, 0.07241979477437814
                ]),
            dict(
                type='Reshape',
                keys=['img'],
                new_shape=(6, 1, -1, -1),
                look_up=dict({
                    '2': 1,
                    '3': 2
                })),
            dict(
                type='CastTensor', keys=['img'], new_type='torch.FloatTensor'),
            dict(
                type='CollectTestList',
                keys=['img'],
                meta_keys=[
                    'img_info', 'seg_fields', 'img_prefix', 'seg_prefix',
                    'filename', 'ori_filename', 'img', 'img_shape',
                    'ori_shape', 'pad_shape', 'scale_factor', 'img_norm_cfg'
                ])
        ],
        ignore_index=-1),
    test=dict(
        type='FireScars',
        data_root=
        '/dccstor/geofm-finetuning/fire-scars/finetune-data/6_bands_no_replant_extended',
        img_dir='validation',  # reuses the validation split (see NOTE above)
        ann_dir='validation',
        img_suffix='_merged.tif',
        seg_map_suffix='.mask.tif',
        pipeline=[
            dict(type='LoadGeospatialImageFromFile', to_float32=True),
            dict(type='BandsExtract', bands=[0, 1, 2, 3, 4, 5]),
            dict(type='ToTensor', keys=['img']),
            dict(
                type='TorchNormalize',
                means=[
                    0.033349706741586264, 0.05701185520536176,
                    0.05889748132001316, 0.2323245113436119,
                    0.1972854853760658, 0.11944914225186566
                ],
                stds=[
                    0.02269135568823774, 0.026807560223070237,
                    0.04004109844362779, 0.07791732423672691,
                    0.08708738838140137, 0.07241979477437814
                ]),
            dict(
                type='Reshape',
                keys=['img'],
                new_shape=(6, 1, -1, -1),
                look_up=dict({
                    '2': 1,
                    '3': 2
                })),
            dict(
                type='CastTensor', keys=['img'], new_type='torch.FloatTensor'),
            dict(
                type='CollectTestList',
                keys=['img'],
                meta_keys=[
                    'img_info', 'seg_fields', 'img_prefix', 'seg_prefix',
                    'filename', 'ori_filename', 'img', 'img_shape',
                    'ori_shape', 'pad_shape', 'scale_factor', 'img_norm_cfg'
                ])
        ],
        ignore_index=-1))
# ---------------------------------------------------------------------------
# Optimization schedule (iteration-based, see runner below).
# ---------------------------------------------------------------------------
optimizer = dict(type='Adam', lr=1.3e-05, betas=(0.9, 0.999))
optimizer_config = dict(grad_clip=None)  # no gradient clipping
# Polynomial LR decay (power=1.0 -> linear) with a 1500-iteration linear warmup.
lr_config = dict(
    policy='poly',
    warmup='linear',
    warmup_iters=1500,
    warmup_ratio=1e-06,
    power=1.0,
    min_lr=0.0,
    by_epoch=False)
log_config = dict(
    interval=20,  # log every 20 iterations
    hooks=[
        dict(type='TextLoggerHook', by_epoch=False),
        dict(type='TensorboardLoggerHook', by_epoch=False)
    ])
# NOTE(review): checkpoints go to a different user's directory than
# work_dir/save_path ('.../carlosgomes/...') — looks like a leftover from a
# copied config; confirm the intended output location.
checkpoint_config = dict(
    by_epoch=True,
    interval=10,
    out_dir=
    '/dccstor/geofm-finetuning/carlosgomes/fire_scars/carlos_replicate_experiment_fixed_lr'
)
# Evaluate mIoU every 1180 iterations and keep the best-mIoU checkpoint.
evaluation = dict(
    interval=1180,
    metric='mIoU',
    pre_eval=True,
    save_best='mIoU',
    by_epoch=False)
runner = dict(type='IterBasedRunner', max_iters=6300)  # total training length
workflow = [('train', 1)]  # training only; eval handled by the hook above
norm_cfg = dict(type='BN', requires_grad=True)  # BatchNorm in the FCN heads
# ---------------------------------------------------------------------------
# Model: temporal ViT encoder (full fine-tune, frozen_backbone=False) with a
# conv/transpose neck and FCN decode/auxiliary heads, 2 classes
# (presumably burned / not burned — confirm with the dataset class mapping).
# ---------------------------------------------------------------------------
model = dict(
    type='TemporalEncoderDecoder',
    frozen_backbone=False,
    backbone=dict(
        type='TemporalViTEncoder',
        pretrained=None,  # see pretrained_weights_path note above
        # '/dccstor/geofm-finetuning/pretrain_ckpts/mae_weights/2023-04-29_21-50-47/epoch-725-loss-0.0365.pt',
        img_size=224,
        patch_size=16,
        num_frames=1,
        tubelet_size=1,
        in_chans=6,  # the six extracted bands
        embed_dim=768,
        depth=12,
        num_heads=12,
        mlp_ratio=4.0,
        norm_pix_loss=False),
    neck=dict(
        type='ConvTransformerTokensToEmbeddingNeck',
        embed_dim=768,
        output_embed_dim=768,
        drop_cls_token=True,  # keep only patch tokens
        # 224 / 16 = 14 patches per side -> 14x14 token grid.
        Hp=14,
        Wp=14),
    decode_head=dict(
        num_classes=2,
        in_channels=768,
        type='FCNHead',
        in_index=-1,
        channels=256,
        num_convs=1,
        concat_input=False,
        dropout_ratio=0.1,
        norm_cfg=dict(type='BN', requires_grad=True),
        align_corners=False,
        loss_decode=dict(
            type='DiceLoss', use_sigmoid=False, loss_weight=1,
            ignore_index=-1)),
    # Deep-supervision head; differs from decode_head only in num_convs=2.
    auxiliary_head=dict(
        num_classes=2,
        in_channels=768,
        type='FCNHead',
        in_index=-1,
        channels=256,
        num_convs=2,
        concat_input=False,
        dropout_ratio=0.1,
        norm_cfg=dict(type='BN', requires_grad=True),
        align_corners=False,
        loss_decode=dict(
            type='DiceLoss', use_sigmoid=False, loss_weight=1,
            ignore_index=-1)),
    train_cfg=dict(),
    # Sliding-window inference: 224x224 windows with 50% overlap (stride 112).
    test_cfg=dict(mode='slide', stride=(112, 112), crop_size=(224, 224)))
gpu_ids = range(0, 1)  # single GPU
auto_resume = False