# lightning.pytorch==2.1.1
seed_everything: 0

### Trainer configuration
trainer:
  accelerator: auto
  strategy: auto
  devices: auto
  num_nodes: 1
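  # Uncomment the next line to train with 16-bit mixed precision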
  # precision: 16-mixed
  logger:
    # To log with TensorBoard instead, uncomment the TensorBoardLogger line below and comment out the CSVLogger line
    #class_path: TensorBoardLogger
    class_path: lightning.pytorch.loggers.csv_logs.CSVLogger
    init_args:
      save_dir: ./experiments
      name: fine_tune_suhi
  callbacks:
    - class_path: RichProgressBar
    - class_path: LearningRateMonitor
      init_args:
        logging_interval: epoch
    - class_path: EarlyStopping
      init_args:
        monitor: val/loss
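        # patience equals max_epochs (600), so early stopping is effectively disabled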
        patience: 600
  max_epochs: 600
  check_val_every_n_epoch: 1
  log_every_n_steps: 10
  enable_checkpointing: true
  default_root_dir: ./experiments
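# Output dtype for saved predictions (assumed to apply when writing inference outputs)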
out_dtype: float32

### Data configuration
data:
  class_path: GenericNonGeoPixelwiseRegressionDataModule
  init_args:
    batch_size: 64
    num_workers: 8
    train_transform:
      - class_path: albumentations.HorizontalFlip
        init_args:
          p: 0.5
      - class_path: albumentations.Rotate
        init_args:
          limit: 30
          border_mode: 0 # cv2.BORDER_CONSTANT
          value: 0
          mask_value: 1
          p: 0.5
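      # ToTensorV2 converts the augmented numpy arrays to torch tensors; keep it as the last transform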
      - class_path: ToTensorV2
    # Specify all bands which are in the input data. 
    dataset_bands:
    # 6 HLS bands
      - BLUE
      - GREEN
      - RED
      - NIR_NARROW
      - SWIR_1
      - SWIR_2
    # ERA5-Land t2m_spatial_avg
      - 7
    # ERA5-Land t2m_sunrise_avg
      - 8
    # ERA5-Land t2m_midnight_avg
      - 9
    # ERA5-Land t2m_delta_avg
      - 10
    # cos_tod
      - 11
    # sin_tod
      - 12
    # cos_doy
      - 13
    # sin_doy
      - 14
    # Specify the bands which are used from the input data.
    # Bands 8 - 14 were discarded in the final model
    output_bands:
      - BLUE
      - GREEN
      - RED
      - NIR_NARROW
      - SWIR_1
      - SWIR_2
      - 7
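    # Indices into output_bands used for RGB visualisation (RED, GREEN, BLUE)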
    rgb_indices:
      - 2
      - 1
      - 0
    # Directory roots for the training, validation and test data splits:
    train_data_root: train/inputs
    train_label_data_root: train/targets
    val_data_root: val/inputs
    val_label_data_root: val/targets
    test_data_root: test/inputs
    test_label_data_root: test/targets
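    # Glob patterns used to match input and label files within the data roots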
    img_grep: "*.inputs.tif"
    label_grep: "*.lst.tif"
    # Replace nodata values in the input data with this value
    no_data_replace: 0
    # Replace nodata values in the label (target) data with this value
    no_label_replace: -9999
    # Mean value of the training dataset per band  
    means:
    - 702.4754028320312
    - 1023.23291015625
    - 1118.8924560546875 
    - 2440.750732421875 
    - 2052.705810546875 
    - 1514.15087890625 
    - 21.031919479370117 
    # Standard deviation of the training dataset per band
    stds:
    - 554.8255615234375 
    - 613.5565185546875 
    - 745.929443359375
    - 715.0111083984375 
    - 761.47607421875 
    - 734.991943359375 
    - 8.66781997680664 

### Model configuration
model:
  class_path: terratorch.tasks.PixelwiseRegressionTask
  init_args:
    model_args:
      decoder: UperNetDecoder
      pretrained: false
      backbone: prithvi_swin_L
      img_size: 224
      backbone_drop_path_rate: 0.3
      decoder_channels: 256
      in_channels: 7
      bands:
      - BLUE
      - GREEN
      - RED
      - NIR_NARROW
      - SWIR_1
      - SWIR_2
      - 7
      num_frames: 1
    loss: rmse
    aux_heads:
      - name: aux_head
        decoder: IdentityDecoder
        decoder_args:
          head_dropout: 0.5
          head_channel_list:
          - 1
          head_final_act: torch.nn.LazyLinear
    aux_loss:
      aux_head: 0.4
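    # Nodata value to ignore in the loss and metrics; matches no_label_replace above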
    ignore_index: -9999
    freeze_backbone: false
    freeze_decoder: false
    model_factory: PrithviModelFactory
    # Parameters for tiled inference; comment out this block to disable it
    tiled_inference_parameters:
      h_crop: 224
      h_stride: 224
      w_crop: 224
      w_stride: 224
      average_patches: true
optimizer:
  class_path: torch.optim.AdamW
  init_args:
    lr: 0.0001
    weight_decay: 0.05
lr_scheduler:
  class_path: ReduceLROnPlateau
  init_args:
    monitor: val/loss
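
# Example usage (assuming the terratorch CLI is installed):
#   terratorch fit --config <path to this config>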