---
# Viewer metadata captured with this file: size 3,676 bytes, id 3bb9581
# (presumably a revision hash). Kept as a comment so the document parses.
# lightning.pytorch==2.1.1
# Seed value for the run.
seed_everything: 0
# Trainer settings: device selection, logging, callbacks and run length.
trainer:
  accelerator: cpu
  strategy: auto
  devices: auto
  num_nodes: 1
  logger: true  # will use tensorboardlogger
  callbacks:
    - class_path: RichProgressBar
    # Learning rate is logged once per epoch.
    - class_path: LearningRateMonitor
      init_args:
        logging_interval: epoch
    # Early stopping watches val/loss with a patience of 30.
    - class_path: EarlyStopping
      init_args:
        monitor: val/loss
        patience: 30
  max_epochs: 200
  # Validation runs every epoch.
  check_val_every_n_epoch: 1
  log_every_n_steps: 1
  enable_checkpointing: true
  default_root_dir: ./../data/fine_tuning/granite_geospatial_uki_flood_detection_v1
# Data module: band layout, file locations, split lists and normalisation
# statistics for the segmentation dataset.
data:
  class_path: GenericNonGeoSegmentationDataModule
  init_args:
    batch_size: 16
    num_workers: 1
    constant_scale: 0.0001
    num_classes: 2

    # Bands present in the data files, in file order.
    dataset_bands:
      - VV
      - VH
      - BLUE
      - GREEN
      - RED
      - NIR_NARROW
      - SWIR_1
      - SWIR_2
      - CLOUD
    # Bands selected for fine-tuning, in the order given here.
    output_bands:
      - BLUE
      - GREEN
      - RED
      - NIR_NARROW
      - SWIR_1
      - SWIR_2
      - VV
      - VH
      - CLOUD
    # NOTE(review): 4, 3, 2 are the 0-based positions of RED, GREEN, BLUE in
    # dataset_bands, not in output_bands - confirm which list is indexed.
    rgb_indices: [4, 3, 2]

    # Train, validation and test all read from the same image and label
    # directories; the split files decide which samples belong to each.
    train_data_root: ./../data/regions/uki/images/
    train_label_data_root: ./../data/regions/uki/labels_without_cloud/
    train_split: ./../data/regions/uki/splits/flood_train_data.txt
    val_data_root: ./../data/regions/uki/images/
    val_label_data_root: ./../data/regions/uki/labels_without_cloud/
    val_split: ./../data/regions/uki/splits/flood_val_data.txt
    test_data_root: ./../data/regions/uki/images/
    test_label_data_root: ./../data/regions/uki/labels_without_cloud/
    test_split: ./../data/regions/uki/splits/flood_test_data.txt

    # Quoted because a leading * would otherwise be read as a YAML alias.
    img_grep: '*_image.tif'
    label_grep: '*_label.tif'

    # Replacement values for missing labels and missing image data.
    no_label_replace: -1
    no_data_replace: 0

    # Per-band means, listed in output_bands order.
    means:
      - 0.08867253281911215  # BLUE
      - 0.09101736325581869  # GREEN
      - 0.08757093732833862  # RED
      - 0.1670982579167684  # NIR_NARROW
      - 0.09420119639078776  # SWIR_1
      - 0.07141083437601725  # SWIR_2
      - -0.0017641318140774339  # VV
      - -0.002356150351719506  # VH
      - 0.00002777560551961263  # CLOUD
    # NOTE(review): nine values, presumably in the same band order as
    # means - confirm.
    stds:
      - 0.13656951175974685
      - 0.13202436625655786
      - 0.1307223895526036
      - 0.18946390520629108
      - 0.11561659013865118
      - 0.09351007561544347
      - 0.001035692652952644
      - 0.000864295592912648
      - 0.00004478924301636066
# Segmentation task: model factory arguments, loss settings and one
# auxiliary head.
# NOTE(review): loss, aux_*, freeze_* and model_factory are placed as
# siblings of model_args - confirm against the task's constructor.
model:
  class_path: terratorch.tasks.SemanticSegmentationTask
  init_args:
    model_factory: EncoderDecoderFactory
    model_args:
      # Backbone
      backbone: granite_geospatial_uki
      backbone_pretrained: false
      backbone_pretrain_img_size: 512
      backbone_bands:
        - BLUE
        - GREEN
        - RED
        - NIR_NARROW
        - SWIR_1
        - SWIR_2
        - VV
        - VH
        - CLOUD
      # Necks, listed in the order they were given.
      necks:
        - name: SelectIndices
          indices: [-1]
        - name: ReshapeTokensToImage
      # Decoder
      decoder: FCNDecoder
      decoder_channels: 256
      decoder_num_convs: 4
      # Head
      head_dropout: 0.1
      head_channel_list: [256]
      num_classes: 2

    # Loss configuration; -1 is also the no_label_replace value in the data
    # section.
    loss: ce
    ignore_index: -1
    class_weights: [0.3, 0.7]

    # One auxiliary head; its loss weight is keyed by the head's name.
    aux_heads:
      - name: aux_head
        decoder: FCNDecoder
        decoder_args:
          decoder_channels: 256
          decoder_in_index: -1
          decoder_num_convs: 2
          head_dropout: 0.1
    aux_loss:
      aux_head: 1.0

    freeze_backbone: false
    freeze_decoder: false
# Optimizer: AdamW with its learning rate and weight decay.
optimizer:
  class_path: torch.optim.AdamW
  init_args:
    weight_decay: 0.05
    # The decimal point is kept so YAML 1.1 loaders resolve this as a float.
    lr: 6.0e-5
# Learning-rate scheduler; it monitors the same val/loss metric that the
# EarlyStopping callback under trainer.callbacks uses.
lr_scheduler:
  class_path: ReduceLROnPlateau
  init_args:
    monitor: val/loss