# File size: 3,350 bytes (revision 62c1d5f)
# lightning.pytorch==2.4.0
# Seed handed to Lightning's seed_everything so runs are repeatable.
seed_everything: 42
### Trainer configuration
trainer:
  accelerator: auto
  strategy: auto
  devices: auto
  num_nodes: 1
  # Uncomment to train with 16-bit mixed precision.
  # precision: 16-mixed
  logger:
    class_path: TensorBoardLogger
    init_args:
      save_dir: ./experiments
      name: finetune_region
  callbacks:
    - class_path: RichProgressBar
    - class_path: LearningRateMonitor
      init_args:
        logging_interval: epoch
    # Watches val/loss; patience counts validation runs, which happen once
    # per epoch here (check_val_every_n_epoch: 1).
    - class_path: EarlyStopping
      init_args:
        monitor: val/loss
        patience: 100
  max_epochs: 300
  check_val_every_n_epoch: 1
  log_every_n_steps: 20
  enable_checkpointing: true
  default_root_dir: ./experiments
### Data configuration
data:
  class_path: GenericNonGeoPixelwiseRegressionDataModule
  init_args:
    batch_size: 64
    num_workers: 8
    # Augmentations for the training split.
    train_transform:
      - class_path: albumentations.HorizontalFlip
        init_args:
          p: 0.5
      - class_path: albumentations.RandomRotate90
        init_args:
          p: 0.5
      - class_path: albumentations.VerticalFlip
        init_args:
          p: 0.5
      - class_path: ToTensorV2
    # List every band contained in the input data.
    # -1 is a placeholder for a band that is in the data but will be discarded.
    dataset_bands:
      - -1
      - BLUE
      - GREEN
      - RED
      - NIR_NARROW
      - SWIR_1
      - SWIR_2
      - -1
      - -1
      - -1
      - -1
    # Bands selected from the input data and used by the model.
    output_bands:
      - BLUE
      - GREEN
      - RED
      - NIR_NARROW
      - SWIR_1
      - SWIR_2
    # Positions of RED, GREEN and BLUE within output_bands
    # (presumably used to render RGB previews — confirm).
    rgb_indices:
      - 2
      - 1
      - 0
    # Directory roots of the training, validation and test data splits.
    train_data_root: train_images
    train_label_data_root: train_labels
    val_data_root: val_images
    val_label_data_root: val_labels
    test_data_root: test_images
    test_label_data_root: test_labels
    # Per-band mean of the training dataset, one value per output band.
    means:
      - 556.025024
      - 910.020020
      - 1039.141968
      - 2665.447266
      - 2361.062256
      - 1633.309326
    # Per-band standard deviation of the training dataset, same order as means.
    stds:
      - 413.787903
      - 562.086670
      - 819.830444
      - 816.528381
      - 1120.049438
      - 1072.057861
    # Value that no-data pixels in the label data are replaced with; keep it
    # equal to the model's ignore_index.
    no_label_replace: -1
    # Value that no-data pixels in the input data are replaced with.
    no_data_replace: 0
### Model configuration
model:
  class_path: terratorch.tasks.PixelwiseRegressionTask
  init_args:
    model_args:
      decoder: UperNetDecoder
      pretrained: false
      backbone: prithvi_swin_B
      backbone_drop_path_rate: 0.3
      decoder_channels: 32
      # Presumably must equal the number of entries in `bands` — keep in sync.
      in_channels: 6
      bands:
        - BLUE
        - GREEN
        - RED
        - NIR_NARROW
        - SWIR_1
        - SWIR_2
      num_frames: 1
      head_dropout: 0.16
      head_final_act: torch.nn.ReLU
      head_learned_upscale_layers: 2
    loss: rmse
    # Label value to ignore; keep equal to the data module's no_label_replace.
    ignore_index: -1
    freeze_backbone: false
    freeze_decoder: false
    model_factory: PrithviModelFactory
    # Uncomment this block to enable tiled inference.
    # tiled_inference_parameters:
    #   h_crop: 224
    #   h_stride: 192
    #   w_crop: 224
    #   w_stride: 192
    #   average_patches: true
optimizer:
  class_path: torch.optim.AdamW
  init_args:
    lr: 5.0e-05
    weight_decay: 0.3
lr_scheduler:
  class_path: ReduceLROnPlateau
  init_args:
    # Logged metric the scheduler watches; same metric as EarlyStopping.
    monitor: val/loss
# NOTE(review): presumably the dtype used when writing predictions — confirm.
out_dtype: float32