File size: 3,400 Bytes
a82c831
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
# lightning.pytorch==2.1.1
# Seed value passed to the seed_everything option.
seed_everything: 42

### Trainer configuration
trainer:
  # Hardware selection; "auto" values are left for the framework to resolve.
  accelerator: auto
  strategy: auto
  devices: auto
  num_nodes: 1
  # precision: 16-mixed

  # Run length and validation/logging cadence.
  max_epochs: 300
  check_val_every_n_epoch: 1
  log_every_n_steps: 20

  # Output locations: the logger's save_dir and the default root directory
  # both point at ./experiments.
  default_root_dir: ./experiments
  enable_checkpointing: true
  logger:
    class_path: TensorBoardLogger
    init_args:
      name: finetune_region
      save_dir: ./experiments

  # Callbacks, kept in their original order.
  callbacks:
    - class_path: RichProgressBar
    # Learning rate is logged once per epoch.
    - class_path: LearningRateMonitor
      init_args: {logging_interval: epoch}
    # Early stopping watches the same metric as the LR scheduler (val/loss).
    - class_path: EarlyStopping
      init_args: {monitor: val/loss, patience: 100}

### Data configuration
data:
  class_path: GenericNonGeoPixelwiseRegressionDataModule
  init_args:
    batch_size: 64
    num_workers: 8
    # Augmentations for the training split only; no val/test transforms are
    # set in this file.
    train_transform:
      - class_path: albumentations.HorizontalFlip
        init_args:
          p: 0.5
      - class_path: albumentations.Rotate
        init_args:
          limit: 30
          border_mode: 0  # cv2.BORDER_CONSTANT
          value: 0
          # NOTE(review): mask_value is left unset, so the fill used for label
          # pixels in the rotated border is the library default, while nodata
          # labels are -1 (no_label_replace below). Confirm the border is not
          # being treated as valid target values.
          # mask_value: 1
          p: 0.5
      - class_path: ToTensorV2
    # Specify all bands which are in the input data.
    # -1 entries are placeholders for bands that are present in the data but
    # that we will discard.
    dataset_bands:
      - -1
      - BLUE
      - GREEN
      - RED
      - NIR_NARROW
      - SWIR_1
      - SWIR_2
      - -1
      - -1
      - -1
      - -1
    # Specify the bands which are used from the input data.
    output_bands:
      - BLUE
      - GREEN
      - RED
      - NIR_NARROW
      - SWIR_1
      - SWIR_2
    # Presumably indices into output_bands in R, G, B order
    # (RED is entry 2, GREEN is entry 1, BLUE is entry 0 above).
    rgb_indices:
      - 2
      - 1
      - 0
    # Directory roots to training, validation and test datasplits:
    train_data_root: train_images
    train_label_data_root: train_labels
    val_data_root: val_images
    val_label_data_root: val_labels
    test_data_root: test_images
    test_label_data_root: test_labels
    # Mean value of the training dataset per band (six values, presumably in
    # output_bands order).
    means:
      - 547.36707
      - 898.5121
      - 1020.9082
      - 2665.5352
      - 2340.584
      - 1610.1407
    # Standard deviation of the training dataset per band (six values,
    # presumably in output_bands order).
    stds:
      - 411.4701
      - 558.54065
      - 815.94025
      - 812.4403
      - 1113.7145
      - 1067.641
    # Nodata value in label data
    no_label_replace: -1
    # Nodata value in the input data
    no_data_replace: 0

### Model configuration
model:
  class_path: terratorch.tasks.PixelwiseRegressionTask
  init_args:
    model_factory: PrithviModelFactory
    model_args:
      # Backbone: started without pretrained weights (pretrained: false).
      backbone: prithvi_swin_B
      backbone_drop_path_rate: 0.3
      pretrained: false
      # Input description: six channels, named by the band list, one frame.
      in_channels: 6
      num_frames: 1
      bands: [BLUE, GREEN, RED, NIR_NARROW, SWIR_1, SWIR_2]
      # Decoder.
      decoder: UperNetDecoder
      decoder_channels: 32
      # Regression head.
      head_dropout: 0.16194593880230534
      head_final_act: torch.nn.ReLU
      head_learned_upscale_layers: 2
    # Task-level settings.
    loss: rmse
    # Same value as no_label_replace in the data section.
    ignore_index: -1
    # Neither the backbone nor the decoder is frozen.
    freeze_backbone: false
    freeze_decoder: false
    # Uncomment the block below to enable tiled inference.
    # tiled_inference_parameters:
    #   h_crop: 224
    #   h_stride: 192
    #   w_crop: 224
    #   w_stride: 192
    #   average_patches: true
### Optimizer configuration
optimizer:
  class_path: torch.optim.AdamW
  # Learning rate and weight decay; literals are kept at full precision.
  init_args: {lr: 0.00031406904191973693, weight_decay: 0.03283253068408954}
### Learning-rate scheduler configuration
lr_scheduler:
  class_path: ReduceLROnPlateau
  # The scheduler watches val/loss, the same metric EarlyStopping monitors.
  init_args: {monitor: val/loss}