# lightning.pytorch==2.1.1
seed_everything: 42

### Trainer configuration
trainer:
  accelerator: auto
  strategy: auto
  devices: auto
  num_nodes: 1
  # precision: 16-mixed
  logger:
    class_path: TensorBoardLogger
    init_args:
      save_dir: ./experiments
      name: finetune_region
  callbacks:
    - class_path: RichProgressBar
    - class_path: LearningRateMonitor
      init_args:
        logging_interval: epoch
    - class_path: EarlyStopping
      init_args:
        monitor: val/loss
        patience: 100
  max_epochs: 300
  check_val_every_n_epoch: 1
  log_every_n_steps: 20
  enable_checkpointing: true
  default_root_dir: ./experiments

### Data configuration
data:
  class_path: GenericNonGeoPixelwiseRegressionDataModule
  init_args:
    batch_size: 64
    num_workers: 8
    train_transform:
      - class_path: albumentations.HorizontalFlip
        init_args:
          p: 0.5
      - class_path: albumentations.Rotate
        init_args:
          limit: 30
          border_mode: 0 # cv2.BORDER_CONSTANT
          value: 0
          # mask_value: 1
          p: 0.5
      - class_path: ToTensorV2
    # Specify all bands present in the input data.
    # -1 entries are placeholders for bands that are in the data but will be discarded.
    dataset_bands:
      - -1
      - BLUE
      - GREEN
      - RED
      - NIR_NARROW
      - SWIR_1
      - SWIR_2
      - -1
      - -1
      - -1
      - -1
    output_bands: # Specify the bands to use from the input data.
      - BLUE
      - GREEN
      - RED
      - NIR_NARROW
      - SWIR_1
      - SWIR_2
    rgb_indices:
      - 2
      - 1
      - 0
    # Directory roots of the training, validation, and test data splits:
    train_data_root: train_images
    train_label_data_root: train_labels
    val_data_root: val_images
    val_label_data_root: val_labels
    test_data_root: test_images
    test_label_data_root: test_labels
    means: # Mean value of the training dataset per band
      - 547.36707
      - 898.5121
      - 1020.9082
      - 2665.5352
      - 2340.584
      - 1610.1407
    stds: # Standard deviation of the training dataset per band
      - 411.4701
      - 558.54065
      - 815.94025
      - 812.4403
      - 1113.7145
      - 1067.641
    # No-data value in the label data
    no_label_replace: -1
    # No-data value in the input data
    no_data_replace: 0

### Model configuration
model:
  class_path: terratorch.tasks.PixelwiseRegressionTask
  init_args:
    model_args:
      decoder: UperNetDecoder
      pretrained: false
      backbone: prithvi_swin_B
      backbone_drop_path_rate: 0.3
      decoder_channels: 32
      in_channels: 6
      bands:
        - BLUE
        - GREEN
        - RED
        - NIR_NARROW
        - SWIR_1
        - SWIR_2
      num_frames: 1
      head_dropout: 0.16194593880230534
      head_final_act: torch.nn.ReLU
      head_learned_upscale_layers: 2
    loss: rmse
    ignore_index: -1
    freeze_backbone: false
    freeze_decoder: false
    model_factory: PrithviModelFactory
    # Uncomment this block for tiled inference
    # tiled_inference_parameters:
    #   h_crop: 224
    #   h_stride: 192
    #   w_crop: 224
    #   w_stride: 192
    #   average_patches: true

optimizer:
  class_path: torch.optim.AdamW
  init_args:
    lr: 0.00031406904191973693
    weight_decay: 0.03283253068408954
lr_scheduler:
  class_path: ReduceLROnPlateau
  init_args:
    monitor: val/loss
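
# Usage sketch (not part of the original config): a file like this is typically
# passed to TerraTorch's LightningCLI-based command line. The file name and
# checkpoint path below are placeholders, not values defined in this config.
#
#   terratorch fit --config <path/to/this_config.yaml>
#   terratorch test --config <path/to/this_config.yaml> --ckpt_path <path/to/checkpoint.ckpt>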