File size: 3,676 Bytes
3bb9581
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
# lightning.pytorch==2.1.1
# Seed handed to the CLI's seed_everything option; fixed at 0 for this config.
seed_everything: 0
# Lightning Trainer settings: single-node CPU run, up to 200 epochs,
# validating after every epoch.
trainer:
  accelerator: cpu
  strategy: auto
  devices: auto
  num_nodes: 1
  # `true` asks Lightning for its default logger (TensorBoardLogger).
  logger: true

  callbacks:
    - class_path: RichProgressBar
    # Record the learning rate once per epoch.
    - class_path: LearningRateMonitor
      init_args:
        logging_interval: epoch
    # Monitors val/loss, the same metric the lr_scheduler section monitors.
    - class_path: EarlyStopping
      init_args:
        monitor: val/loss
        patience: 30

  max_epochs: 200
  check_val_every_n_epoch: 1
  log_every_n_steps: 1
  enable_checkpointing: true
  # Root directory for this run's outputs (relative path).
  default_root_dir: ./../data/fine_tuning/granite_geospatial_uki_flood_detection_v1
# Data module for the UKI flood-detection tiles: nine bands per image
# (two SAR, six optical, one cloud) and two segmentation classes.
data:
  class_path: GenericNonGeoSegmentationDataModule
  init_args:
    batch_size: 16
    num_workers: 1
    constant_scale: 0.0001
    # Bands present in the image files, in file order.
    dataset_bands:
      - VV
      - VH
      - BLUE
      - GREEN
      - RED
      - NIR_NARROW
      - SWIR_1
      - SWIR_2
      - CLOUD
    # Bands to fine-tune on, in the order used by `means`, `stds` and the
    # model's `backbone_bands`.
    output_bands:
      - BLUE
      - GREEN
      - RED
      - NIR_NARROW
      - SWIR_1
      - SWIR_2
      - VV
      - VH
      - CLOUD
    # Channels used for RGB plots, as positions in `output_bands`
    # (RED=2, GREEN=1, BLUE=0).
    # NOTE(review): was 4/3/2, i.e. the RGB positions in `dataset_bands`;
    # confirm against the terratorch version in use.
    rgb_indices:
      - 2
      - 1
      - 0
    # Train, val and test share the same image and label directories; the
    # split files below decide which tiles belong to which stage.
    train_data_root: ./../data/regions/uki/images/
    train_label_data_root: ./../data/regions/uki/labels_without_cloud/
    val_data_root: ./../data/regions/uki/images/
    val_label_data_root: ./../data/regions/uki/labels_without_cloud/
    test_data_root: ./../data/regions/uki/images/
    test_label_data_root: ./../data/regions/uki/labels_without_cloud/
    train_split: ./../data/regions/uki/splits/flood_train_data.txt
    test_split: ./../data/regions/uki/splits/flood_test_data.txt
    val_split: ./../data/regions/uki/splits/flood_val_data.txt
    img_grep: "*_image.tif"
    label_grep: "*_label.tif"
    # -1 is also the `ignore_index` configured in the model section.
    no_label_replace: -1
    no_data_replace: 0
    # Per-band normalisation statistics, listed in `output_bands` order.
    means:
      - 0.08867253281911215    # BLUE
      - 0.09101736325581869    # GREEN
      - 0.08757093732833862    # RED
      - 0.1670982579167684     # NIR_NARROW
      - 0.09420119639078776    # SWIR_1
      - 0.07141083437601725    # SWIR_2
      - -0.0017641318140774339 # VV
      - -0.002356150351719506  # VH
      - 0.00002777560551961263 # CLOUD

    # Presumably the same band order as `means` - confirm.
    stds:
      - 0.13656951175974685    # BLUE
      - 0.13202436625655786    # GREEN
      - 0.1307223895526036     # RED
      - 0.18946390520629108    # NIR_NARROW
      - 0.11561659013865118    # SWIR_1
      - 0.09351007561544347    # SWIR_2
      - 0.001035692652952644   # VV
      - 0.000864295592912648   # VH
      - 0.00004478924301636066 # CLOUD

    num_classes: 2

# Segmentation task: granite_geospatial_uki backbone with an FCN decoder,
# assembled by EncoderDecoderFactory, plus one auxiliary FCN head.
model:
  class_path: terratorch.tasks.SemanticSegmentationTask
  init_args:
    model_factory: EncoderDecoderFactory
    model_args:
      # --- backbone ---
      backbone: granite_geospatial_uki
      # NOTE(review): false here while the output path says "fine_tuning";
      # confirm the intended source of the backbone weights.
      backbone_pretrained: false
      backbone_pretrain_img_size: 512
      # Same nine bands, same order, as `output_bands` in the data section.
      backbone_bands:
        - BLUE
        - GREEN
        - RED
        - NIR_NARROW
        - SWIR_1
        - SWIR_2
        - VV
        - VH
        - CLOUD
      # --- necks ---
      necks:
        - name: SelectIndices
          indices: [-1]
        - name: ReshapeTokensToImage
      # --- decoder and head ---
      decoder: FCNDecoder
      decoder_channels: 256
      decoder_num_convs: 4
      head_dropout: 0.1
      head_channel_list: [256]
      num_classes: 2
    # --- loss ---
    loss: ce
    # -1 is also the `no_label_replace` value in the data section.
    ignore_index: -1
    # One weight per class (num_classes is 2).
    class_weights: [0.3, 0.7]
    # --- auxiliary head ---
    aux_heads:
      - name: aux_head
        decoder: FCNDecoder
        decoder_args:
          decoder_channels: 256
          decoder_in_index: -1
          decoder_num_convs: 2
          head_dropout: 0.1
    # Keyed by the aux head's name.
    aux_loss:
      aux_head: 1.0
    # --- trainable parts ---
    freeze_backbone: false
    freeze_decoder: false
# AdamW with learning rate 6e-5 and weight decay 0.05.
optimizer:
  class_path: torch.optim.AdamW
  init_args:
    weight_decay: 0.05
    lr: 6.0e-5
# Plateau-based LR scheduler keyed on val/loss, the same metric the
# EarlyStopping callback in the trainer section monitors.
lr_scheduler:
  class_path: ReduceLROnPlateau
  init_args:
    monitor: val/loss