Paolo-Fraccaro committed on
Commit
91b09fa
1 Parent(s): 44130b3

Update burn_scars_Prithvi_100M.py

Browse files
Files changed (1) hide show
  1. burn_scars_Prithvi_100M.py +89 -183
burn_scars_Prithvi_100M.py CHANGED
@@ -1,11 +1,17 @@
 
 
 
1
  dist_params = dict(backend='nccl')
2
  log_level = 'INFO'
3
  load_from = None
4
  resume_from = None
5
  cudnn_benchmark = True
6
- custom_imports = dict(imports=['geospatial_fm'])
7
  dataset_type = 'GeospatialDataset'
8
- data_root = '/dccstor/geofm-finetuning/fire-scars/finetune-data/6_bands_no_replant_extended'
 
 
 
9
  num_frames = 1
10
  img_size = 224
11
  num_workers = 4
@@ -22,45 +28,45 @@ img_norm_cfg = dict(
22
  bands = [0, 1, 2, 3, 4, 5]
23
  tile_size = 224
24
  orig_nsize = 512
25
- crop_size = (224, 224)
26
  img_suffix = '_merged.tif'
27
  seg_map_suffix = '.mask.tif'
28
  ignore_index = -1
29
  image_nodata = -9999
30
  image_nodata_replace = 0
31
  image_to_float32 = True
32
- # pretrained_weights_path = '/dccstor/geofm-finetuning/pretrain_ckpts/mae_weights/2023-04-29_21-50-47/epoch-725-loss-0.0365.pt'
33
- pretrained_weights_path = None
 
 
34
  num_layers = 12
35
  patch_size = 16
36
  embed_dim = 768
37
  num_heads = 12
38
  tubelet_size = 1
39
- epochs = 50
40
- eval_epoch_interval = 5
41
- experiment = 'test2'
42
- project_dir = '/dccstor/geofm-finetuning/fire-scars/os'
43
- work_dir = '/dccstor/geofm-finetuning/fire-scars/os/test2'
44
- save_path = '/dccstor/geofm-finetuning/fire-scars/os/test2'
 
 
 
 
 
45
  train_pipeline = [
46
- dict(type='LoadGeospatialImageFromFile', to_float32=True),
47
  dict(type='LoadGeospatialAnnotations', reduce_zero_label=False),
48
- dict(type='BandsExtract', bands=[0, 1, 2, 3, 4, 5]),
49
  dict(type='RandomFlip', prob=0.5),
50
  dict(type='ToTensor', keys=['img', 'gt_semantic_seg']),
51
- dict(
52
- type='TorchNormalize',
53
- means=[
54
- 0.033349706741586264, 0.05701185520536176, 0.05889748132001316,
55
- 0.2323245113436119, 0.1972854853760658, 0.11944914225186566
56
- ],
57
- stds=[
58
- 0.02269135568823774, 0.026807560223070237, 0.04004109844362779,
59
- 0.07791732423672691, 0.08708738838140137, 0.07241979477437814
60
- ]),
61
- dict(type='TorchRandomCrop', crop_size=(224, 224)),
62
- dict(type='Reshape', keys=['img'], new_shape=(6, 1, 224, 224)),
63
- dict(type='Reshape', keys=['gt_semantic_seg'], new_shape=(1, 224, 224)),
64
  dict(
65
  type='CastTensor',
66
  keys=['gt_semantic_seg'],
@@ -68,23 +74,16 @@ train_pipeline = [
68
  dict(type='Collect', keys=['img', 'gt_semantic_seg'])
69
  ]
70
  test_pipeline = [
71
- dict(type='LoadGeospatialImageFromFile', to_float32=True),
72
- dict(type='BandsExtract', bands=[0, 1, 2, 3, 4, 5]),
73
  dict(type='ToTensor', keys=['img']),
74
- dict(
75
- type='TorchNormalize',
76
- means=[
77
- 0.033349706741586264, 0.05701185520536176, 0.05889748132001316,
78
- 0.2323245113436119, 0.1972854853760658, 0.11944914225186566
79
- ],
80
- stds=[
81
- 0.02269135568823774, 0.026807560223070237, 0.04004109844362779,
82
- 0.07791732423672691, 0.08708738838140137, 0.07241979477437814
83
- ]),
84
  dict(
85
  type='Reshape',
86
  keys=['img'],
87
- new_shape=(6, 1, -1, -1),
88
  look_up=dict({
89
  '2': 1,
90
  '3': 2
@@ -99,136 +98,43 @@ test_pipeline = [
99
  'scale_factor', 'img_norm_cfg'
100
  ])
101
  ]
 
 
 
102
  data = dict(
103
- samples_per_gpu=4,
104
- workers_per_gpu=4,
105
  train=dict(
106
- type='FireScars',
107
- data_root=
108
- '/dccstor/geofm-finetuning/fire-scars/finetune-data/6_bands_no_replant_extended',
109
  img_dir='training',
110
  ann_dir='training',
111
- img_suffix='_merged.tif',
112
- seg_map_suffix='.mask.tif',
113
- pipeline=[
114
- dict(type='LoadGeospatialImageFromFile', to_float32=True),
115
- dict(type='LoadGeospatialAnnotations', reduce_zero_label=False),
116
- dict(type='BandsExtract', bands=[0, 1, 2, 3, 4, 5]),
117
- dict(type='RandomFlip', prob=0.5),
118
- dict(type='ToTensor', keys=['img', 'gt_semantic_seg']),
119
- dict(
120
- type='TorchNormalize',
121
- means=[
122
- 0.033349706741586264, 0.05701185520536176,
123
- 0.05889748132001316, 0.2323245113436119,
124
- 0.1972854853760658, 0.11944914225186566
125
- ],
126
- stds=[
127
- 0.02269135568823774, 0.026807560223070237,
128
- 0.04004109844362779, 0.07791732423672691,
129
- 0.08708738838140137, 0.07241979477437814
130
- ]),
131
- dict(type='TorchRandomCrop', crop_size=(224, 224)),
132
- dict(type='Reshape', keys=['img'], new_shape=(6, 1, 224, 224)),
133
- dict(
134
- type='Reshape',
135
- keys=['gt_semantic_seg'],
136
- new_shape=(1, 224, 224)),
137
- dict(
138
- type='CastTensor',
139
- keys=['gt_semantic_seg'],
140
- new_type='torch.LongTensor'),
141
- dict(type='Collect', keys=['img', 'gt_semantic_seg'])
142
- ],
143
  ignore_index=-1),
144
  val=dict(
145
- type='FireScars',
146
- data_root=
147
- '/dccstor/geofm-finetuning/fire-scars/finetune-data/6_bands_no_replant_extended',
148
  img_dir='validation',
149
  ann_dir='validation',
150
- img_suffix='_merged.tif',
151
- seg_map_suffix='.mask.tif',
152
- pipeline=[
153
- dict(type='LoadGeospatialImageFromFile', to_float32=True),
154
- dict(type='BandsExtract', bands=[0, 1, 2, 3, 4, 5]),
155
- dict(type='ToTensor', keys=['img']),
156
- dict(
157
- type='TorchNormalize',
158
- means=[
159
- 0.033349706741586264, 0.05701185520536176,
160
- 0.05889748132001316, 0.2323245113436119,
161
- 0.1972854853760658, 0.11944914225186566
162
- ],
163
- stds=[
164
- 0.02269135568823774, 0.026807560223070237,
165
- 0.04004109844362779, 0.07791732423672691,
166
- 0.08708738838140137, 0.07241979477437814
167
- ]),
168
- dict(
169
- type='Reshape',
170
- keys=['img'],
171
- new_shape=(6, 1, -1, -1),
172
- look_up=dict({
173
- '2': 1,
174
- '3': 2
175
- })),
176
- dict(
177
- type='CastTensor', keys=['img'], new_type='torch.FloatTensor'),
178
- dict(
179
- type='CollectTestList',
180
- keys=['img'],
181
- meta_keys=[
182
- 'img_info', 'seg_fields', 'img_prefix', 'seg_prefix',
183
- 'filename', 'ori_filename', 'img', 'img_shape',
184
- 'ori_shape', 'pad_shape', 'scale_factor', 'img_norm_cfg'
185
- ])
186
- ],
187
  ignore_index=-1),
188
  test=dict(
189
- type='FireScars',
190
- data_root=
191
- '/dccstor/geofm-finetuning/fire-scars/finetune-data/6_bands_no_replant_extended',
192
  img_dir='validation',
193
  ann_dir='validation',
194
- img_suffix='_merged.tif',
195
- seg_map_suffix='.mask.tif',
196
- pipeline=[
197
- dict(type='LoadGeospatialImageFromFile', to_float32=True),
198
- dict(type='BandsExtract', bands=[0, 1, 2, 3, 4, 5]),
199
- dict(type='ToTensor', keys=['img']),
200
- dict(
201
- type='TorchNormalize',
202
- means=[
203
- 0.033349706741586264, 0.05701185520536176,
204
- 0.05889748132001316, 0.2323245113436119,
205
- 0.1972854853760658, 0.11944914225186566
206
- ],
207
- stds=[
208
- 0.02269135568823774, 0.026807560223070237,
209
- 0.04004109844362779, 0.07791732423672691,
210
- 0.08708738838140137, 0.07241979477437814
211
- ]),
212
- dict(
213
- type='Reshape',
214
- keys=['img'],
215
- new_shape=(6, 1, -1, -1),
216
- look_up=dict({
217
- '2': 1,
218
- '3': 2
219
- })),
220
- dict(
221
- type='CastTensor', keys=['img'], new_type='torch.FloatTensor'),
222
- dict(
223
- type='CollectTestList',
224
- keys=['img'],
225
- meta_keys=[
226
- 'img_info', 'seg_fields', 'img_prefix', 'seg_prefix',
227
- 'filename', 'ori_filename', 'img', 'img_shape',
228
- 'ori_shape', 'pad_shape', 'scale_factor', 'img_norm_cfg'
229
- ])
230
- ],
231
  ignore_index=-1))
 
232
  optimizer = dict(type='Adam', lr=1.3e-05, betas=(0.9, 0.999))
233
  optimizer_config = dict(grad_clip=None)
234
  lr_config = dict(
@@ -248,16 +154,20 @@ log_config = dict(
248
  checkpoint_config = dict(
249
  by_epoch=True,
250
  interval=10,
251
- out_dir=
252
- '/dccstor/geofm-finetuning/carlosgomes/fire_scars/carlos_replicate_experiment_fixed_lr'
253
  )
254
  evaluation = dict(
255
- interval=1180,
256
  metric='mIoU',
257
  pre_eval=True,
258
  save_best='mIoU',
259
  by_epoch=False)
260
- runner = dict(type='IterBasedRunner', max_iters=6300)
 
 
 
 
 
261
  workflow = [('train', 1)]
262
  norm_cfg = dict(type='BN', requires_grad=True)
263
  model = dict(
@@ -265,28 +175,27 @@ model = dict(
265
  frozen_backbone=False,
266
  backbone=dict(
267
  type='TemporalViTEncoder',
268
- pretrained=None,
269
- # '/dccstor/geofm-finetuning/pretrain_ckpts/mae_weights/2023-04-29_21-50-47/epoch-725-loss-0.0365.pt',
270
- img_size=224,
271
- patch_size=16,
272
- num_frames=1,
273
- tubelet_size=1,
274
- in_chans=6,
275
- embed_dim=768,
276
  depth=12,
277
- num_heads=12,
278
  mlp_ratio=4.0,
279
  norm_pix_loss=False),
280
  neck=dict(
281
  type='ConvTransformerTokensToEmbeddingNeck',
282
- embed_dim=768,
283
- output_embed_dim=768,
284
  drop_cls_token=True,
285
  Hp=14,
286
  Wp=14),
287
  decode_head=dict(
288
- num_classes=2,
289
- in_channels=768,
290
  type='FCNHead',
291
  in_index=-1,
292
  channels=256,
@@ -295,12 +204,11 @@ model = dict(
295
  dropout_ratio=0.1,
296
  norm_cfg=dict(type='BN', requires_grad=True),
297
  align_corners=False,
298
- loss_decode=dict(
299
- type='DiceLoss', use_sigmoid=False, loss_weight=1,
300
- ignore_index=-1)),
301
  auxiliary_head=dict(
302
- num_classes=2,
303
- in_channels=768,
304
  type='FCNHead',
305
  in_index=-1,
306
  channels=256,
@@ -309,10 +217,8 @@ model = dict(
309
  dropout_ratio=0.1,
310
  norm_cfg=dict(type='BN', requires_grad=True),
311
  align_corners=False,
312
- loss_decode=dict(
313
- type='DiceLoss', use_sigmoid=False, loss_weight=1,
314
- ignore_index=-1)),
315
  train_cfg=dict(),
316
- test_cfg=dict(mode='slide', stride=(112, 112), crop_size=(224, 224)))
317
  gpu_ids = range(0, 1)
318
- auto_resume = False
 
1
+ import os
2
+ custom_imports = dict(imports=['geospatial_fm'])
3
+
4
  dist_params = dict(backend='nccl')
5
  log_level = 'INFO'
6
  load_from = None
7
  resume_from = None
8
  cudnn_benchmark = True
9
+
10
  dataset_type = 'GeospatialDataset'
11
+
12
+ # TO BE DEFINED BY USER: data directory
13
+ data_root = '<path to data root>'
14
+
15
  num_frames = 1
16
  img_size = 224
17
  num_workers = 4
 
28
  bands = [0, 1, 2, 3, 4, 5]
29
  tile_size = 224
30
  orig_nsize = 512
31
+ crop_size = (tile_size, tile_size)
32
  img_suffix = '_merged.tif'
33
  seg_map_suffix = '.mask.tif'
34
  ignore_index = -1
35
  image_nodata = -9999
36
  image_nodata_replace = 0
37
  image_to_float32 = True
38
+
39
+ # model
40
+ # TO BE DEFINED BY USER: model path
41
+ pretrained_weights_path = '<path to pretrained weights>'
42
  num_layers = 12
43
  patch_size = 16
44
  embed_dim = 768
45
  num_heads = 12
46
  tubelet_size = 1
47
+ output_embed_dim = num_frames*embed_dim
48
+ max_intervals=10000
49
+ evaluation_interval=1000
50
+
51
+ # TO BE DEFINED BY USER: experiment name and project directory
52
+ experiment = '<experiment name>'
53
+ project_dir = '<project directory name>'
54
+ work_dir = os.path.join(project_dir, experiment)
55
+ save_path = work_dir
56
+
57
+ save_path = work_dir
58
  train_pipeline = [
59
+ dict(type='LoadGeospatialImageFromFile', to_float32=image_to_float32, channels_last=True),
60
  dict(type='LoadGeospatialAnnotations', reduce_zero_label=False),
61
+ dict(type='BandsExtract', bands=bands),
62
  dict(type='RandomFlip', prob=0.5),
63
  dict(type='ToTensor', keys=['img', 'gt_semantic_seg']),
64
+ # to channels first
65
+ dict(type="TorchPermute", keys=["img"], order=(2, 0, 1)),
66
+ dict(type='TorchNormalize', **img_norm_cfg),
67
+ dict(type='TorchRandomCrop', crop_size=(tile_size, tile_size)),
68
+ dict(type='Reshape', keys=['img'], new_shape=(len(bands), num_frames, tile_size, tile_size)),
69
+ dict(type='Reshape', keys=['gt_semantic_seg'], new_shape=(1, tile_size, tile_size)),
 
 
 
 
 
 
 
70
  dict(
71
  type='CastTensor',
72
  keys=['gt_semantic_seg'],
 
74
  dict(type='Collect', keys=['img', 'gt_semantic_seg'])
75
  ]
76
  test_pipeline = [
77
+ dict(type='LoadGeospatialImageFromFile', to_float32=image_to_float32, channels_last=True),
78
+ dict(type='BandsExtract', bands=bands),
79
  dict(type='ToTensor', keys=['img']),
80
+ # to channels first
81
+ dict(type="TorchPermute", keys=["img"], order=(2, 0, 1)),
82
+ dict(type='TorchNormalize', **img_norm_cfg),
 
 
 
 
 
 
 
83
  dict(
84
  type='Reshape',
85
  keys=['img'],
86
+ new_shape=(len(bands), num_frames, -1, -1),
87
  look_up=dict({
88
  '2': 1,
89
  '3': 2
 
98
  'scale_factor', 'img_norm_cfg'
99
  ])
100
  ]
101
+
102
+ CLASSES = ('Unburnt land', 'Burn scar')
103
+
104
  data = dict(
105
+ samples_per_gpu=samples_per_gpu,
106
+ workers_per_gpu=num_workers,
107
  train=dict(
108
+ type=dataset_type,
109
+ CLASSES=CLASSES,
110
+ data_root=data_root,
111
  img_dir='training',
112
  ann_dir='training',
113
+ img_suffix=img_suffix,
114
+ seg_map_suffix=seg_map_suffix,
115
+ pipeline=train_pipeline,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
116
  ignore_index=-1),
117
  val=dict(
118
+ type=dataset_type,
119
+ CLASSES=CLASSES,
120
+ data_root=data_root,
121
  img_dir='validation',
122
  ann_dir='validation',
123
+ img_suffix=img_suffix,
124
+ seg_map_suffix=seg_map_suffix,
125
+ pipeline=test_pipeline,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
126
  ignore_index=-1),
127
  test=dict(
128
+ type=dataset_type,
129
+ CLASSES=CLASSES,
130
+ data_root=data_root,
131
  img_dir='validation',
132
  ann_dir='validation',
133
+ img_suffix=img_suffix,
134
+ seg_map_suffix=seg_map_suffix,
135
+ pipeline=test_pipeline,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
136
  ignore_index=-1))
137
+
138
  optimizer = dict(type='Adam', lr=1.3e-05, betas=(0.9, 0.999))
139
  optimizer_config = dict(grad_clip=None)
140
  lr_config = dict(
 
154
  checkpoint_config = dict(
155
  by_epoch=True,
156
  interval=10,
157
+ out_dir=save_path
 
158
  )
159
  evaluation = dict(
160
+ interval=evaluation_interval,
161
  metric='mIoU',
162
  pre_eval=True,
163
  save_best='mIoU',
164
  by_epoch=False)
165
+
166
+ loss_func=dict(
167
+ type='DiceLoss', use_sigmoid=False, loss_weight=1,
168
+ ignore_index=-1)
169
+
170
+ runner = dict(type='IterBasedRunner', max_iters=max_intervals)
171
  workflow = [('train', 1)]
172
  norm_cfg = dict(type='BN', requires_grad=True)
173
  model = dict(
 
175
  frozen_backbone=False,
176
  backbone=dict(
177
  type='TemporalViTEncoder',
178
+ pretrained=pretrained_weights_path,
179
+ img_size=img_size,
180
+ patch_size=patch_size,
181
+ num_frames=num_frames,
182
+ tubelet_size=tubelet_size,
183
+ in_chans=len(bands),
184
+ embed_dim=embed_dim,
 
185
  depth=12,
186
+ num_heads=num_heads,
187
  mlp_ratio=4.0,
188
  norm_pix_loss=False),
189
  neck=dict(
190
  type='ConvTransformerTokensToEmbeddingNeck',
191
+ embed_dim=embed_dim*num_frames,
192
+ output_embed_dim=output_embed_dim,
193
  drop_cls_token=True,
194
  Hp=14,
195
  Wp=14),
196
  decode_head=dict(
197
+ num_classes=len(CLASSES),
198
+ in_channels=output_embed_dim,
199
  type='FCNHead',
200
  in_index=-1,
201
  channels=256,
 
204
  dropout_ratio=0.1,
205
  norm_cfg=dict(type='BN', requires_grad=True),
206
  align_corners=False,
207
+ loss_decode=
208
+ loss_decode=loss_func),
 
209
  auxiliary_head=dict(
210
+ num_classes=len(CLASSES),
211
+ in_channels=output_embed_dim,
212
  type='FCNHead',
213
  in_index=-1,
214
  channels=256,
 
217
  dropout_ratio=0.1,
218
  norm_cfg=dict(type='BN', requires_grad=True),
219
  align_corners=False,
220
+ loss_decode=loss_func),
 
 
221
  train_cfg=dict(),
222
+ test_cfg=dict(mode='slide', stride=(tile_size/2, tile_size/2), crop_size=(tile_size, tile_size)))
223
  gpu_ids = range(0, 1)
224
+ auto_resume = False