annezj commited on
Commit
671ad95
·
verified ·
1 Parent(s): 016c16d

Upload config.yaml

Browse files
Files changed (1) hide show
  1. config.yaml +147 -0
config.yaml ADDED
@@ -0,0 +1,147 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # lightning.pytorch==2.1.1
2
+ seed_everything: 0
3
+ trainer:
4
+ accelerator: cpu
5
+ strategy: auto
6
+ devices: auto
7
+ num_nodes: 1
8
+ logger: True # will use tensorboardlogger
9
+
10
+ callbacks:
11
+ - class_path: RichProgressBar
12
+ - class_path: LearningRateMonitor
13
+ init_args:
14
+ logging_interval: epoch
15
+ - class_path: EarlyStopping
16
+ init_args:
17
+ monitor: val/loss
18
+ patience: 30
19
+
20
+ max_epochs: 200
21
+ check_val_every_n_epoch: 1
22
+ log_every_n_steps: 1
23
+ enable_checkpointing: true
24
+ default_root_dir: ./../data/fine_tuning/granite_geospatial_uki_flood_detection_v1
25
+ data:
26
+ class_path: GenericNonGeoSegmentationDataModule
27
+ init_args:
28
+ batch_size: 16
29
+ num_workers: 1
30
+ constant_scale: 0.0001
31
+ dataset_bands: # what bands are in your data
32
+ - VV
33
+ - VH
34
+ - BLUE
35
+ - GREEN
36
+ - RED
37
+ - NIR_NARROW
38
+ - SWIR_1
39
+ - SWIR_2
40
+ - CLOUD
41
+ output_bands: # which bands do you want to fine-tune
42
+ - BLUE
43
+ - GREEN
44
+ - RED
45
+ - NIR_NARROW
46
+ - SWIR_1
47
+ - SWIR_2
48
+ - VV
49
+ - VH
50
+ - CLOUD
51
+ rgb_indices:
52
+ - 4
53
+ - 3
54
+ - 2
55
+ train_data_root: ./../data/regions/uki/images/
56
+ train_label_data_root: ./../data/regions/uki/labels_without_cloud/
57
+ val_data_root: ./../data/regions/uki/images/
58
+ val_label_data_root: ./../data/regions/uki/labels_without_cloud/
59
+ test_data_root: ./../data/regions/uki/images/
60
+ test_label_data_root: ./../data/regions/uki/labels_without_cloud/
61
+ train_split: ./../data/regions/uki/splits/flood_train_data.txt
62
+ test_split: ./../data/regions/uki/splits/flood_test_data.txt
63
+ val_split: ./../data/regions/uki/splits/flood_val_data.txt
64
+ img_grep: "*_image.tif"
65
+ label_grep: "*_label.tif"
66
+ no_label_replace: -1
67
+ no_data_replace: 0
68
+ means:
69
+ - 0.08867253281911215 # BLUE
70
+ - 0.09101736325581869 # GREEN
71
+ - 0.08757093732833862 # RED
72
+ - 0.1670982579167684 # NIR_NARROW
73
+ - 0.09420119639078776 # SWIR_1
74
+ - 0.07141083437601725 # SWIR_2
75
+ - -0.0017641318140774339 # VV
76
+ - -0.002356150351719506 # VH
77
+ - 0.00002777560551961263 # CLOUD
78
+
79
+ stds:
80
+ - 0.13656951175974685
81
+ - 0.13202436625655786
82
+ - 0.1307223895526036
83
+ - 0.18946390520629108
84
+ - 0.11561659013865118
85
+ - 0.09351007561544347
86
+ - 0.001035692652952644
87
+ - 0.000864295592912648
88
+ - 0.00004478924301636066
89
+
90
+ num_classes: 2
91
+
92
+ model:
93
+ class_path: terratorch.tasks.SemanticSegmentationTask
94
+ init_args:
95
+ model_args:
96
+ decoder: FCNDecoder
97
+ backbone_pretrained: false
98
+ backbone: granite_geospatial_uki
99
+ backbone_pretrain_img_size: 512
100
+ decoder_channels: 256
101
+ backbone_bands:
102
+ - BLUE
103
+ - GREEN
104
+ - RED
105
+ - NIR_NARROW
106
+ - SWIR_1
107
+ - SWIR_2
108
+ - VV
109
+ - VH
110
+ - CLOUD
111
+ num_classes: 2
112
+ head_dropout: 0.1
113
+ decoder_num_convs: 4
114
+ head_channel_list:
115
+ - 256
116
+ necks:
117
+ - name: SelectIndices
118
+ indices:
119
+ - -1
120
+ - name: ReshapeTokensToImage
121
+ loss: ce
122
+ aux_heads:
123
+ - name: aux_head
124
+ decoder: FCNDecoder
125
+ decoder_args:
126
+ decoder_channels: 256
127
+ decoder_in_index: -1
128
+ decoder_num_convs: 2
129
+ head_dropout: 0.1
130
+ aux_loss:
131
+ aux_head: 1.0
132
+ ignore_index: -1
133
+ class_weights:
134
+ - 0.3
135
+ - 0.7
136
+ freeze_backbone: false
137
+ freeze_decoder: false
138
+ model_factory: EncoderDecoderFactory
139
+ optimizer:
140
+ class_path: torch.optim.AdamW
141
+ init_args:
142
+ lr: 6.e-5
143
+ weight_decay: 0.05
144
+ lr_scheduler:
145
+ class_path: ReduceLROnPlateau
146
+ init_args:
147
+ monitor: val/loss