Paolo-Fraccaro commited on
Commit
cf9d671
·
1 Parent(s): adbc22d

Create sen1floods11_Prithvi_100M.py

Browse files
Files changed (1) hide show
  1. sen1floods11_Prithvi_100M.py +291 -0
sen1floods11_Prithvi_100M.py ADDED
@@ -0,0 +1,291 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+
3
+ # base options
4
+ dist_params = dict(backend='nccl')
5
+ log_level = 'INFO'
6
+ load_from = None
7
+ resume_from = None
8
+ cudnn_benchmark = True
9
+
10
+ custom_imports = dict(imports=["geospatial_fm"])
11
+
12
+
13
+ ### Configs
14
+ # Data
15
+ # TO BE DEFINED BY USER: Data root to sen1floods11 downloaded dataset
16
+ data_root = "<path to dataset>"
17
+
18
+ dataset_type = "GeospatialDataset"
19
+ num_classes=2
20
+ num_frames = 1
21
+ img_size = 224
22
+ num_workers = 2
23
+ samples_per_gpu = 4
24
+ CLASSES=(0,1)
25
+
26
+ img_norm_cfg = dict(means=[0.14245495, 0.13921481, 0.12434631, 0.31420089, 0.20743526,0.12046503],
27
+ stds=[0.04036231, 0.04186983, 0.05267646, 0.0822221 , 0.06834774, 0.05294205])
28
+
29
+ bands = [1, 2, 3, 8, 11, 12]
30
+ tile_size = img_size
31
+ orig_nsize = 512
32
+ crop_size = (tile_size, tile_size)
33
+
34
+ img_dir = data_root + "v1.1/data/flood_events/HandLabeled/S2Hand"
35
+ ann_dir = data_root + "v1.1/data/flood_events/HandLabeled/LabelHand"
36
+ img_suffix = f"_S2Hand.tif"
37
+ seg_map_suffix = f"_LabelHand.tif"
38
+
39
+ splits = {
40
+ "train": "data_splits/train_split.txt",
41
+ "val": "data_splits/val_split.txt",
42
+ "test": "data_splits/test_split.txt",
43
+ }
44
+ splits = {k: os.path.abspath(v) for (k, v) in splits.items()}
45
+
46
+ ignore_index = 2
47
+ label_nodata = -1
48
+ image_nodata = -9999
49
+ image_nodata_replace = 0
50
+ constant = 0.0001
51
+
52
+ # Model
53
+ # TO BE DEFINED BY USER: path to pretrained backbone weights
54
+ pretrained_weights_path = "<path to pretrained weights>"
55
+ num_layers = 12
56
+ patch_size = 16
57
+ embed_dim = 768
58
+ num_heads = 12
59
+ tubelet_size = 1
60
+
61
+ # TRAINING
62
+ epochs=100
63
+ eval_epoch_interval = 5
64
+
65
+ # TO BE DEFINED BY USER: Save directory
66
+ experiment = "<experiment name>"
67
+ project_dir = "<project dir>"
68
+ work_dir = os.path.join(project_dir, experiment)
69
+ save_path = work_dir
70
+
71
+ # Pipelines
72
+ train_pipeline = [
73
+ dict(
74
+ type="LoadGeospatialImageFromFile",
75
+ to_float32=False,
76
+ nodata=image_nodata,
77
+ nodata_replace=image_nodata_replace,
78
+ channels_last=False
79
+ ),
80
+ dict(
81
+ type="LoadGeospatialAnnotations",
82
+ reduce_zero_label=False,
83
+ nodata=label_nodata,
84
+ nodata_replace=ignore_index,
85
+ ),
86
+ dict(type="BandsExtract", bands=bands),
87
+ dict(type="ConstantMultiply", constant=constant),
88
+ dict(type="RandomFlip", prob=0.5),
89
+ dict(type="ToTensor", keys=["img", "gt_semantic_seg"]),
90
+ # to channels first
91
+ dict(type="TorchPermute", keys=["img"], order=(2, 0, 1)),
92
+ dict(type="TorchNormalize", **img_norm_cfg),
93
+ dict(type="TorchRandomCrop", crop_size=crop_size),
94
+ dict(
95
+ type="Reshape",
96
+ keys=["img"],
97
+ new_shape=(len(bands), num_frames, tile_size, tile_size),
98
+ ),
99
+ dict(type="Reshape", keys=["gt_semantic_seg"], new_shape=(1, tile_size, tile_size)),
100
+ dict(type="CastTensor", keys=["gt_semantic_seg"], new_type="torch.LongTensor"),
101
+ dict(type="Collect", keys=["img", "gt_semantic_seg"]),
102
+ ]
103
+
104
+
105
+ test_pipeline = [
106
+ dict(
107
+ type="LoadGeospatialImageFromFile",
108
+ to_float32=False,
109
+ nodata=image_nodata,
110
+ nodata_replace=image_nodata_replace,
111
+ channels_last=False
112
+ ),
113
+ dict(type="BandsExtract", bands=bands),
114
+ dict(type="ConstantMultiply", constant=constant),
115
+ dict(type="ToTensor", keys=["img"]),
116
+ # to channels first
117
+ dict(type="TorchPermute", keys=["img"], order=(2, 0, 1)),
118
+ dict(type="TorchNormalize", **img_norm_cfg),
119
+ dict(
120
+ type="Reshape",
121
+ keys=["img"],
122
+ new_shape=(len(bands), num_frames, -1, -1),
123
+ look_up={'2': 1, '3': 2}
124
+ ),
125
+ dict(type="CastTensor", keys=["img"], new_type="torch.FloatTensor"),
126
+ dict(
127
+ type="CollectTestList",
128
+ keys=["img"],
129
+ meta_keys=[
130
+ "img_info",
131
+ "seg_fields",
132
+ "img_prefix",
133
+ "seg_prefix",
134
+ "filename",
135
+ "ori_filename",
136
+ "img",
137
+ "img_shape",
138
+ "ori_shape",
139
+ "pad_shape",
140
+ "scale_factor",
141
+ "img_norm_cfg",
142
+ ],
143
+ ),
144
+ ]
145
+
146
+ # Dataset
147
+ data = dict(
148
+ samples_per_gpu=samples_per_gpu,
149
+ workers_per_gpu=num_workers,
150
+ train=dict(
151
+ type=dataset_type,
152
+ CLASSES=CLASSES,
153
+ data_root=data_root,
154
+ img_dir=img_dir,
155
+ ann_dir=ann_dir,
156
+ img_suffix=img_suffix,
157
+ seg_map_suffix=seg_map_suffix,
158
+ pipeline=train_pipeline,
159
+ ignore_index=ignore_index,
160
+ split=splits["train"],
161
+ ),
162
+ val=dict(
163
+ type=dataset_type,
164
+ CLASSES=CLASSES,
165
+ data_root=data_root,
166
+ img_dir=img_dir,
167
+ ann_dir=ann_dir,
168
+ img_suffix=img_suffix,
169
+ seg_map_suffix=seg_map_suffix,
170
+ pipeline=test_pipeline,
171
+ ignore_index=ignore_index,
172
+ split=splits["val"],
173
+ gt_seg_map_loader_cfg=dict(nodata=label_nodata, nodata_replace=ignore_index)
174
+ ),
175
+ test=dict(
176
+ type=dataset_type,
177
+ CLASSES=CLASSES,
178
+ data_root=data_root,
179
+ img_dir=img_dir,
180
+ ann_dir=ann_dir,
181
+ img_suffix=img_suffix,
182
+ seg_map_suffix=seg_map_suffix,
183
+ pipeline=test_pipeline,
184
+ ignore_index=ignore_index,
185
+ split=splits["test"],
186
+ gt_seg_map_loader_cfg=dict(nodata=label_nodata, nodata_replace=ignore_index),
187
+ ),
188
+ )
189
+
190
+ # Training
191
+ optimizer = dict(type="Adam", lr=6e-5, weight_decay=0.05)
192
+ optimizer_config = dict(grad_clip=None)
193
+ lr_config = dict(
194
+ policy="poly",
195
+ warmup="linear",
196
+ warmup_iters=1500,
197
+ warmup_ratio=1e-6,
198
+ power=1.0,
199
+ min_lr=0.0,
200
+ by_epoch=False,
201
+ )
202
+
203
+ log_config = dict(
204
+ interval=10,
205
+ hooks=[
206
+ dict(type='TextLoggerHook', by_epoch=True),
207
+ dict(type='TensorboardLoggerHook', by_epoch=True),
208
+ ])
209
+
210
+ checkpoint_config = dict(
211
+ by_epoch=True, interval=10, out_dir=save_path
212
+ )
213
+
214
+ evaluation = dict(
215
+ interval=eval_epoch_interval, metric="mIoU", pre_eval=True, save_best="mIoU", by_epoch=True
216
+ )
217
+
218
+ runner = dict(type="EpochBasedRunner", max_epochs=epochs)
219
+
220
+ workflow = [("train", 1),("val", 1)]
221
+
222
+ norm_cfg = dict(type="BN", requires_grad=True)
223
+
224
+ ce_weights = [0.3, 0.7]
225
+
226
+ model = dict(
227
+ type="TemporalEncoderDecoder",
228
+ frozen_backbone=False,
229
+ backbone=dict(
230
+ type="TemporalViTEncoder",
231
+ pretrained=pretrained_weights_path,
232
+ img_size=img_size,
233
+ patch_size=patch_size,
234
+ num_frames=num_frames,
235
+ tubelet_size=1,
236
+ in_chans=len(bands),
237
+ embed_dim=embed_dim,
238
+ depth=num_layers,
239
+ num_heads=num_heads,
240
+ mlp_ratio=4.0,
241
+ norm_pix_loss=False,
242
+ ),
243
+ neck=dict(
244
+ type="ConvTransformerTokensToEmbeddingNeck",
245
+ embed_dim=num_frames*embed_dim,
246
+ output_embed_dim=embed_dim,
247
+ drop_cls_token=True,
248
+ Hp=img_size // patch_size,
249
+ Wp=img_size // patch_size,
250
+ ),
251
+ decode_head=dict(
252
+ num_classes=num_classes,
253
+ in_channels=embed_dim,
254
+ type="FCNHead",
255
+ in_index=-1,
256
+ ignore_index=ignore_index,
257
+ channels=256,
258
+ num_convs=1,
259
+ concat_input=False,
260
+ dropout_ratio=0.1,
261
+ norm_cfg=norm_cfg,
262
+ align_corners=False,
263
+ loss_decode=dict(
264
+ type="CrossEntropyLoss",
265
+ use_sigmoid=False,
266
+ loss_weight=1,
267
+ class_weight=ce_weights,
268
+ ),
269
+ ),
270
+ auxiliary_head=dict(
271
+ num_classes=num_classes,
272
+ in_channels=embed_dim,
273
+ ignore_index=ignore_index,
274
+ type="FCNHead",
275
+ in_index=-1,
276
+ channels=256,
277
+ num_convs=2,
278
+ concat_input=False,
279
+ dropout_ratio=0.1,
280
+ norm_cfg=norm_cfg,
281
+ align_corners=False,
282
+ loss_decode=dict(
283
+ type="CrossEntropyLoss",
284
+ use_sigmoid=False,
285
+ loss_weight=1,
286
+ class_weight=ce_weights,
287
+ ),
288
+ ),
289
+ train_cfg=dict(),
290
+ test_cfg=dict(mode="slide", stride=(int(tile_size/2), int(tile_size/2)), crop_size=(tile_size, tile_size)),
291
+ )