"""Training/evaluation configuration for a two-class segmentation model.

The model is a ``TemporalEncoderDecoder`` with a ``TemporalViTEncoder``
backbone; the component types are provided by the ``geospatial_fm`` package
listed in ``custom_imports``.

NOTE(review): the layout (``data``/``model``/``runner``/``lr_config`` ...)
looks like an MMCV/MMSegmentation config file -- confirm against the
training entry point that consumes it.
"""
import os

# base options
dist_params = dict(backend="nccl")
log_level = "INFO"
load_from = None
resume_from = None
cudnn_benchmark = True

custom_imports = dict(imports=["geospatial_fm"])

# ---------------------------------------------------------------------------
# Data
# ---------------------------------------------------------------------------
data_root = "/home"
dataset_type = "GeospatialDataset"
num_classes = 2
num_frames = 1
img_size = 224
num_workers = 2
samples_per_gpu = 4
CLASSES = (0, 1)

# One mean/std per entry in `bands` (six values each).
img_norm_cfg = dict(
    means=[
        0.14245495,
        0.13921481,
        0.12434631,
        0.31420089,
        0.20743526,
        0.12046503,
    ],
    stds=[
        0.04036231,
        0.04186983,
        0.05267646,
        0.0822221,
        0.06834774,
        0.05294205,
    ],
)

# Band indices handed to the "BandsExtract" pipeline step.
# NOTE(review): the image directory/suffix below say "S1" while six band
# indices are selected here -- confirm the source rasters contain them.
bands = [1, 2, 3, 8, 11, 12]
tile_size = img_size
orig_nsize = 512
crop_size = (tile_size, tile_size)

img_dir = data_root + "/files/S1/"
ann_dir = data_root + "/files/Labels/"
img_suffix = "_S1Hand.tif"
seg_map_suffix = "_LabelHand.tif"

splits = {
    "train": "/home/flood_train_data.csv",
    "val": "/home/flood_val_data.csv",
    "test": "/home/flood_test_data.csv",
}
# Normalise every split path to an absolute path.
splits = {k: os.path.abspath(v) for (k, v) in splits.items()}

# Label no-data pixels (-1) are rewritten to `ignore_index` by the annotation
# loader, and the same index is passed to the datasets and both heads.
ignore_index = 2
label_nodata = -1
# Image no-data pixels (-9999) are replaced by `image_nodata_replace`.
image_nodata = -9999
image_nodata_replace = 0
# Factor passed to the "ConstantMultiply" pipeline step.
constant = 0.0001

# ---------------------------------------------------------------------------
# Model
# ---------------------------------------------------------------------------
# TO BE DEFINED BY USER: path to pretrained backbone weights
pretrained_weights_path = "/home/Prithvi_100M.pt"
num_layers = 12
patch_size = 16
embed_dim = 768
num_heads = 12
tubelet_size = 1

# ---------------------------------------------------------------------------
# Schedule / output
# ---------------------------------------------------------------------------
epochs = 30
eval_epoch_interval = 5
experiment = "/home/output"
save_path = experiment

# ---------------------------------------------------------------------------
# Pipelines
# ---------------------------------------------------------------------------
train_pipeline = [
    dict(
        type="LoadGeospatialImageFromFile",
        to_float32=False,
        nodata=image_nodata,
        nodata_replace=image_nodata_replace,
    ),
    dict(
        type="LoadGeospatialAnnotations",
        reduce_zero_label=False,
        nodata=label_nodata,
        nodata_replace=ignore_index,
    ),
    dict(type="BandsExtract", bands=bands),
    dict(type="ConstantMultiply", constant=constant),
    dict(type="RandomFlip", prob=0.5),
    dict(type="ToTensor", keys=["img", "gt_semantic_seg"]),
    # to channels first
    dict(type="TorchPermute", keys=["img"], order=(2, 0, 1)),
    dict(type="TorchNormalize", **img_norm_cfg),
    dict(type="TorchRandomCrop", crop_size=crop_size),
    dict(
        type="Reshape",
        keys=["img"],
        new_shape=(len(bands), num_frames, tile_size, tile_size),
    ),
    dict(
        type="Reshape",
        keys=["gt_semantic_seg"],
        new_shape=(1, tile_size, tile_size),
    ),
    dict(
        type="CastTensor",
        keys=["gt_semantic_seg"],
        new_type="torch.LongTensor",
    ),
    dict(type="Collect", keys=["img", "gt_semantic_seg"]),
]

# Same image preprocessing as training, but without annotation loading,
# random flip or random crop.
test_pipeline = [
    dict(
        type="LoadGeospatialImageFromFile",
        to_float32=False,
        nodata=image_nodata,
        nodata_replace=image_nodata_replace,
    ),
    dict(type="BandsExtract", bands=bands),
    dict(type="ConstantMultiply", constant=constant),
    dict(type="ToTensor", keys=["img"]),
    # to channels first
    dict(type="TorchPermute", keys=["img"], order=(2, 0, 1)),
    dict(type="TorchNormalize", **img_norm_cfg),
    dict(
        type="Reshape",
        keys=["img"],
        # NOTE(review): the -1 entries are presumably resolved through
        # `look_up` from the input tensor's dimensions -- confirm in the
        # Reshape transform.
        new_shape=(len(bands), num_frames, -1, -1),
        look_up={"2": 1, "3": 2},
    ),
    dict(type="CastTensor", keys=["img"], new_type="torch.FloatTensor"),
    dict(
        type="CollectTestList",
        keys=["img"],
        meta_keys=[
            "img_info",
            "seg_fields",
            "img_prefix",
            "seg_prefix",
            "filename",
            "ori_filename",
            "img",
            "img_shape",
            "ori_shape",
            "pad_shape",
            "scale_factor",
            "img_norm_cfg",
        ],
    ),
]

# ---------------------------------------------------------------------------
# Datasets
# ---------------------------------------------------------------------------
# val/test use `test_pipeline`, which does not load annotations, and instead
# pass the label no-data settings through `gt_seg_map_loader_cfg`.
data = dict(
    samples_per_gpu=samples_per_gpu,
    workers_per_gpu=num_workers,
    train=dict(
        type=dataset_type,
        CLASSES=CLASSES,
        data_root=data_root,
        img_dir=img_dir,
        ann_dir=ann_dir,
        img_suffix=img_suffix,
        seg_map_suffix=seg_map_suffix,
        pipeline=train_pipeline,
        ignore_index=ignore_index,
        split=splits["train"],
    ),
    val=dict(
        type=dataset_type,
        CLASSES=CLASSES,
        data_root=data_root,
        img_dir=img_dir,
        ann_dir=ann_dir,
        img_suffix=img_suffix,
        seg_map_suffix=seg_map_suffix,
        pipeline=test_pipeline,
        ignore_index=ignore_index,
        split=splits["val"],
        gt_seg_map_loader_cfg=dict(
            nodata=label_nodata,
            nodata_replace=ignore_index,
        ),
    ),
    test=dict(
        type=dataset_type,
        CLASSES=CLASSES,
        data_root=data_root,
        img_dir=img_dir,
        ann_dir=ann_dir,
        img_suffix=img_suffix,
        seg_map_suffix=seg_map_suffix,
        pipeline=test_pipeline,
        ignore_index=ignore_index,
        split=splits["test"],
        gt_seg_map_loader_cfg=dict(
            nodata=label_nodata,
            nodata_replace=ignore_index,
        ),
    ),
)

# ---------------------------------------------------------------------------
# Optimisation, logging, checkpoints, evaluation
# ---------------------------------------------------------------------------
optimizer = dict(type="SGD", lr=6e-5, weight_decay=0.05)
optimizer_config = dict(grad_clip=None)
# The learning-rate schedule is configured per iteration (by_epoch=False),
# while logging, checkpointing and evaluation below are configured per epoch.
lr_config = dict(
    policy="poly",
    warmup="linear",
    warmup_iters=1500,
    warmup_ratio=1e-6,
    power=1.0,
    min_lr=0.0,
    by_epoch=False,
)

log_config = dict(
    interval=10,
    hooks=[
        dict(type="TextLoggerHook", by_epoch=True),
        dict(type="TensorboardLoggerHook", by_epoch=True),
    ],
)

checkpoint_config = dict(
    by_epoch=True,
    interval=10,
    out_dir=save_path,
)

evaluation = dict(
    interval=eval_epoch_interval,
    metric="mIoU",
    pre_eval=True,
    save_best="mIoU",
    by_epoch=True,
)

runner = dict(type="EpochBasedRunner", max_epochs=epochs)
workflow = [("train", 1), ("val", 1)]

# ---------------------------------------------------------------------------
# Model definition
# ---------------------------------------------------------------------------
norm_cfg = dict(type="BN", requires_grad=True)
# One cross-entropy weight per entry in CLASSES; used by both heads.
ce_weights = [0.3, 0.7]

model = dict(
    type="TemporalEncoderDecoder",
    frozen_backbone=False,
    backbone=dict(
        type="TemporalViTEncoder",
        pretrained=pretrained_weights_path,
        img_size=img_size,
        patch_size=patch_size,
        num_frames=num_frames,
        # Reference the module-level setting instead of repeating the literal.
        tubelet_size=tubelet_size,
        in_chans=len(bands),
        embed_dim=embed_dim,
        depth=num_layers,
        num_heads=num_heads,
        mlp_ratio=4.0,
        norm_pix_loss=False,
    ),
    neck=dict(
        type="ConvTransformerTokensToEmbeddingNeck",
        embed_dim=num_frames * embed_dim,
        output_embed_dim=embed_dim,
        drop_cls_token=True,
        # Patches per side: img_size // patch_size.
        Hp=img_size // patch_size,
        Wp=img_size // patch_size,
    ),
    decode_head=dict(
        num_classes=num_classes,
        in_channels=embed_dim,
        type="FCNHead",
        in_index=-1,
        ignore_index=ignore_index,
        channels=256,
        num_convs=1,
        concat_input=False,
        dropout_ratio=0.1,
        norm_cfg=norm_cfg,
        align_corners=False,
        loss_decode=dict(
            type="CrossEntropyLoss",
            use_sigmoid=False,
            loss_weight=1,
            class_weight=ce_weights,
        ),
    ),
    # Same settings as decode_head except num_convs=2.
    auxiliary_head=dict(
        num_classes=num_classes,
        in_channels=embed_dim,
        ignore_index=ignore_index,
        type="FCNHead",
        in_index=-1,
        channels=256,
        num_convs=2,
        concat_input=False,
        dropout_ratio=0.1,
        norm_cfg=norm_cfg,
        align_corners=False,
        loss_decode=dict(
            type="CrossEntropyLoss",
            use_sigmoid=False,
            loss_weight=1,
            class_weight=ce_weights,
        ),
    ),
    train_cfg=dict(),
    # Sliding-window inference: tile_size windows with a half-tile stride.
    # Integer division keeps the stride an int without a float round-trip.
    test_cfg=dict(
        mode="slide",
        stride=(tile_size // 2, tile_size // 2),
        crop_size=(tile_size, tile_size),
    ),
)