| import os |
| import math |
|
|
|
|
| class Config: |
| def __init__(self) -> None: |
| |
| |
| self.sys_home_dir = [os.path.expanduser("~"), "/mnt/data"][0] |
| self.data_root_dir = os.path.join(self.sys_home_dir, "datasets/dis") |
|
|
| |
| self.task = ["DIS5K", "COD", "HRSOD", "General", "General-2K", "Matting"][0] |
| self.testsets = { |
| |
| "DIS5K": ",".join( |
| ["DIS-VD", "DIS-TE1", "DIS-TE2", "DIS-TE3", "DIS-TE4"][:1] |
| ), |
| "COD": ",".join(["CHAMELEON", "NC4K", "TE-CAMO", "TE-COD10K"]), |
| "HRSOD": ",".join( |
| ["DAVIS-S", "TE-HRSOD", "TE-UHRSD", "DUT-OMRON", "TE-DUTS"] |
| ), |
| |
| "General": ",".join(["DIS-VD", "TE-P3M-500-NP"]), |
| "General-2K": ",".join(["DIS-VD", "TE-P3M-500-NP"]), |
| "Matting": ",".join(["TE-P3M-500-NP", "TE-AM-2k"]), |
| }[self.task] |
| datasets_all = "+".join( |
| [ |
| ds |
| for ds in ( |
| os.listdir(os.path.join(self.data_root_dir, self.task)) |
| if os.path.isdir(os.path.join(self.data_root_dir, self.task)) |
| else [] |
| ) |
| if ds not in self.testsets.split(",") |
| ] |
| ) |
| self.training_set = { |
| "DIS5K": ["DIS-TR", "DIS-TR+DIS-TE1+DIS-TE2+DIS-TE3+DIS-TE4"][0], |
| "COD": "TR-COD10K+TR-CAMO", |
| "HRSOD": [ |
| "TR-DUTS", |
| "TR-HRSOD", |
| "TR-UHRSD", |
| "TR-DUTS+TR-HRSOD", |
| "TR-DUTS+TR-UHRSD", |
| "TR-HRSOD+TR-UHRSD", |
| "TR-DUTS+TR-HRSOD+TR-UHRSD", |
| ][5], |
| "General": datasets_all, |
| "General-2K": datasets_all, |
| "Matting": datasets_all, |
| }[self.task] |
| self.prompt4loc = ["dense", "sparse"][0] |
|
|
| |
| self.load_all = False |
| self.compile = True |
| |
| |
| |
| self.precisionHigh = True |
|
|
| |
| self.ms_supervision = True |
| self.out_ref = self.ms_supervision and True |
| self.dec_ipt = True |
| self.dec_ipt_split = True |
| self.cxt_num = [0, 3][1] |
| self.mul_scl_ipt = ["", "add", "cat"][2] |
| self.dec_att = ["", "ASPP", "ASPPDeformable"][2] |
| self.squeeze_block = [ |
| "", |
| "BasicDecBlk_x1", |
| "ResBlk_x4", |
| "ASPP_x3", |
| "ASPPDeformable_x3", |
| ][1] |
| self.dec_blk = ["BasicDecBlk", "ResBlk"][0] |
|
|
| |
| self.batch_size = 4 |
| self.finetune_last_epochs = [ |
| 0, |
| { |
| "DIS5K": -40, |
| "COD": -20, |
| "HRSOD": -20, |
| "General": -40, |
| "General-2K": -20, |
| "Matting": -20, |
| }[self.task], |
| ][ |
| 1 |
| ] |
| self.lr = (1e-4 if "DIS5K" in self.task else 1e-5) * math.sqrt( |
| self.batch_size / 4 |
| ) |
| self.size = ( |
| (1024, 1024) if self.task not in ["General-2K"] else (2560, 1440) |
| ) |
| self.num_workers = max( |
| 4, self.batch_size |
| ) |
|
|
| |
| self.bb = [ |
| "vgg16", |
| "vgg16bn", |
| "resnet50", |
| "swin_v1_t", |
| "swin_v1_s", |
| "swin_v1_b", |
| "swin_v1_l", |
| "pvt_v2_b0", |
| "pvt_v2_b1", |
| "pvt_v2_b2", |
| "pvt_v2_b5", |
| ][6] |
| self.lateral_channels_in_collection = { |
| "vgg16": [512, 256, 128, 64], |
| "vgg16bn": [512, 256, 128, 64], |
| "resnet50": [1024, 512, 256, 64], |
| "pvt_v2_b2": [512, 320, 128, 64], |
| "pvt_v2_b5": [512, 320, 128, 64], |
| "swin_v1_b": [1024, 512, 256, 128], |
| "swin_v1_l": [1536, 768, 384, 192], |
| "swin_v1_t": [768, 384, 192, 96], |
| "swin_v1_s": [768, 384, 192, 96], |
| "pvt_v2_b0": [256, 160, 64, 32], |
| "pvt_v2_b1": [512, 320, 128, 64], |
| }[self.bb] |
| if self.mul_scl_ipt == "cat": |
| self.lateral_channels_in_collection = [ |
| channel * 2 for channel in self.lateral_channels_in_collection |
| ] |
| self.cxt = ( |
| self.lateral_channels_in_collection[1:][::-1][-self.cxt_num :] |
| if self.cxt_num |
| else [] |
| ) |
|
|
| |
| self.lat_blk = ["BasicLatBlk"][0] |
| self.dec_channels_inter = ["fixed", "adap"][0] |
| self.refine = ["", "itself", "RefUNet", "Refiner", "RefinerPVTInChannels4"][0] |
| self.progressive_ref = self.refine and True |
| self.ender = self.progressive_ref and False |
| self.scale = self.progressive_ref and 2 |
| self.auxiliary_classification = ( |
| False |
| ) |
| self.refine_iteration = 1 |
| self.freeze_bb = False |
| self.model = [ |
| "BiRefNet", |
| "BiRefNetC2F", |
| ][0] |
|
|
| |
| self.preproc_methods = ["flip", "enhance", "rotate", "pepper", "crop"][:4] |
| self.optimizer = ["Adam", "AdamW"][1] |
| self.lr_decay_epochs = [ |
| 1e5 |
| ] |
| self.lr_decay_rate = 0.5 |
| |
| if self.task in ["Matting"]: |
| self.lambdas_pix_last = { |
| "bce": 30 * 1, |
| "iou": 0.5 * 0, |
| "iou_patch": 0.5 * 0, |
| "mae": 100 * 1, |
| "mse": 30 * 0, |
| "triplet": 3 * 0, |
| "reg": 100 * 0, |
| "ssim": 10 * 1, |
| "cnt": 5 * 0, |
| "structure": 5 * 0, |
| } |
| elif self.task in ["General", "General-2K"]: |
| self.lambdas_pix_last = { |
| "bce": 30 * 1, |
| "iou": 0.5 * 1, |
| "iou_patch": 0.5 * 0, |
| "mae": 100 * 1, |
| "mse": 30 * 0, |
| "triplet": 3 * 0, |
| "reg": 100 * 0, |
| "ssim": 10 * 1, |
| "cnt": 5 * 0, |
| "structure": 5 * 0, |
| } |
| else: |
| self.lambdas_pix_last = { |
| |
| |
| "bce": 30 * 1, |
| "iou": 0.5 * 1, |
| "iou_patch": 0.5 * 0, |
| "mae": 30 * 0, |
| "mse": 30 * 0, |
| "triplet": 3 * 0, |
| "reg": 100 * 0, |
| "ssim": 10 * 1, |
| "cnt": 5 * 0, |
| "structure": 5 |
| * 0, |
| } |
| self.lambdas_cls = {"ce": 5.0} |
|
|
| |
| self.weights_root_dir = os.path.join(self.sys_home_dir, "weights/cv") |
| self.weights = { |
| "pvt_v2_b2": os.path.join(self.weights_root_dir, "pvt_v2_b2.pth"), |
| "pvt_v2_b5": os.path.join( |
| self.weights_root_dir, ["pvt_v2_b5.pth", "pvt_v2_b5_22k.pth"][0] |
| ), |
| "swin_v1_b": os.path.join( |
| self.weights_root_dir, |
| [ |
| "swin_base_patch4_window12_384_22kto1k.pth", |
| "swin_base_patch4_window12_384_22k.pth", |
| ][0], |
| ), |
| "swin_v1_l": os.path.join( |
| self.weights_root_dir, |
| [ |
| "swin_large_patch4_window12_384_22kto1k.pth", |
| "swin_large_patch4_window12_384_22k.pth", |
| ][0], |
| ), |
| "swin_v1_t": os.path.join( |
| self.weights_root_dir, |
| ["swin_tiny_patch4_window7_224_22kto1k_finetune.pth"][0], |
| ), |
| "swin_v1_s": os.path.join( |
| self.weights_root_dir, |
| ["swin_small_patch4_window7_224_22kto1k_finetune.pth"][0], |
| ), |
| "pvt_v2_b0": os.path.join(self.weights_root_dir, ["pvt_v2_b0.pth"][0]), |
| "pvt_v2_b1": os.path.join(self.weights_root_dir, ["pvt_v2_b1.pth"][0]), |
| } |
|
|
| |
| self.verbose_eval = True |
| self.only_S_MAE = False |
| self.SDPA_enabled = False |
|
|
| |
| self.device = [0, "cpu"][0] |
|
|
| self.batch_size_valid = 1 |
| self.rand_seed = 7 |
| run_sh_file = [f for f in os.listdir(".") if "train.sh" == f] + [ |
| os.path.join("..", f) for f in os.listdir("..") if "train.sh" == f |
| ] |
| if run_sh_file: |
| with open(run_sh_file[0], "r") as f: |
| lines = f.readlines() |
| self.save_last = int( |
| [ |
| l.strip() |
| for l in lines |
| if "'{}')".format(self.task) in l and "val_last=" in l |
| ][0] |
| .split("val_last=")[-1] |
| .split()[0] |
| ) |
| self.save_step = int( |
| [ |
| l.strip() |
| for l in lines |
| if "'{}')".format(self.task) in l and "step=" in l |
| ][0] |
| .split("step=")[-1] |
| .split()[0] |
| ) |
|
|
|
|
| |
if __name__ == "__main__":
    import argparse

    # Small CLI: each --print_<attr> flag prints the Config attribute <attr>.
    parser = argparse.ArgumentParser(
        description="Only choose one argument to activate."
    )
    parser.add_argument("--print_task", action="store_true", help="print task name")
    parser.add_argument(
        "--print_testsets", action="store_true", help="print validation set"
    )
    args = parser.parse_args()

    config = Config()
    # vars() is the public way to read a Namespace; sorting keeps the same
    # name-ordered output the private Namespace._get_kwargs() produced.
    for arg_name, arg_value in sorted(vars(args).items()):
        if arg_value:
            print(getattr(config, arg_name[len("print_"):]))
|
|