from models.segmentation_models.cen import ChannelExchangingNetwork
from models.segmentation_models.deeplabv3p import DeepLabV3p_r101, DeepLabV3p_r18, DeepLabV3p_r50
from models.segmentation_models.linearfuse.segformer import WeTrLinearFusion
from models.segmentation_models.linearfusebothmask.segformer import LinearFusionBothMask
from models.segmentation_models.linearfusecons.segformer import LinearFusionConsistency
from models.segmentation_models.linearfusemaemaskedcons.segformer import LinearFusionMAEMaskedConsistency
from models.segmentation_models.linearfusemaskedcons.segformer import LinearFusionMaskedConsistency
from models.segmentation_models.linearfusemaskedconsmixbatch.segformer import LinearFusionMaskedConsistencyMixBatch
from models.segmentation_models.linearfusesepdecodermaskedcons.segformer import LinearFusionSepDecoderMaskedConsistency
from models.segmentation_models.linearfusetokenmix.segformer import LinearFusionTokenMix
from models.segmentation_models.randomexchangecons.segformer import RandomExchangePredConsistency
from models.segmentation_models.randomfusion.segformer import WeTrRandomFusion
from models.segmentation_models.randomfusiondmlp.segformer import WeTrRandomFusionDMLP
from models.segmentation_models.refinenet import MyRefineNet
from models.segmentation_models.segformer.segformer import SegFormer
from models.segmentation_models.tokenfusion.segformer import WeTr
from models.segmentation_models.tokenfusionbothmask.segformer import TokenFusionBothMask
from models.segmentation_models.tokenfusionmaemaskedconsistency.segformer import TokenFusionMAEMaskedConsistency
from models.segmentation_models.tokenfusionmaskedconsistency.segformer import TokenFusionMaskedConsistency
from models.segmentation_models.tokenfusionmaskedconsistencymixbatch.segformer import TokenFusionMaskedConsistencyMixBatch
from models.segmentation_models.unifiedrepresentation.segformer import UnifiedRepresentationNetwork
from models.segmentation_models.unifiedrepresentationmoddrop.segformer import UnifiedRepresentationNetworkModDrop

# ResNet depth selected by the ``base_model`` tag, shared by the refinenet
# and cen branches below.
_RESNET_LAYERS = {"r18": 18, "r50": 50, "r101": 101}


def _resolve_pretrained(args):
    """Return the pretrained-initialization flag from *args*.

    Defaults to ``True`` when ``pretrained_init`` is absent (membership test
    works for both argparse.Namespace and dict-like configs). Logs the
    resolved value, matching the original behavior.
    """
    pretrained = True
    if "pretrained_init" in args:
        pretrained = args.pretrained_init
    print("Using pretrained SegFormer? ", pretrained)
    return pretrained


def get_model(args, **kwargs):
    """Build and return the segmentation model selected by ``args.seg_model``.

    Args:
        args: Config namespace. Always reads ``seg_model`` and
            ``num_classes``; depending on the branch, also reads
            ``base_model``, ``l1_lambda``, ``cons_lambda``,
            ``exchange_threshold``, ``exchange_percent`` and
            ``pretrained_init``.
        **kwargs: Extra keyword arguments forwarded only to
            ``TokenFusionBothMask``.

    Returns:
        An instantiated model (nn.Module subclass — assumed, defined in the
        imported project modules).

    Raises:
        Exception: If ``seg_model`` (or, for backbone-based models,
            ``base_model``) is not a configured option.
    """
    if args.seg_model == "dlv3p":
        dlv3p_variants = {
            "r18": DeepLabV3p_r18,
            "r50": DeepLabV3p_r50,
            "r101": DeepLabV3p_r101,
        }
        if args.base_model not in dlv3p_variants:
            raise Exception(f"{args.base_model} not configured")
        return dlv3p_variants[args.base_model](args.num_classes, args)
    elif args.seg_model == "refinenet":
        # Fix: previously an unrecognized base_model silently returned None;
        # now it raises, consistent with the dlv3p branch.
        if args.base_model not in _RESNET_LAYERS:
            raise Exception(f"{args.base_model} not configured")
        return MyRefineNet(
            num_layers=_RESNET_LAYERS[args.base_model],
            num_classes=args.num_classes,
        )
    elif args.seg_model == "cen":
        if args.base_model not in _RESNET_LAYERS:
            raise Exception(f"{args.base_model} not configured")
        return ChannelExchangingNetwork(
            num_layers=_RESNET_LAYERS[args.base_model],
            num_classes=args.num_classes,
            num_parallel=2,
            l1_lambda=args.l1_lambda,
            bn_threshold=args.exchange_threshold,
        )
    elif args.seg_model == "segformer":
        return SegFormer(args.base_model, args, num_classes=args.num_classes)
    elif args.seg_model == "tokenfusion":
        return WeTr(
            args.base_model,
            args,
            l1_lambda=args.l1_lambda,
            num_classes=args.num_classes,
        )
    elif args.seg_model == "randomfusion":
        return WeTrRandomFusion(args.base_model, args, num_classes=args.num_classes)
    elif args.seg_model == "randomfusiondmlp":
        return WeTrRandomFusionDMLP(args.base_model, args, num_classes=args.num_classes)
    elif args.seg_model == "randomexchangepredconsistency":
        return RandomExchangePredConsistency(
            args.base_model,
            args,
            cons_lambda=args.cons_lambda,
            num_classes=args.num_classes,
        )
    elif args.seg_model == "linearfusion":
        return WeTrLinearFusion(
            args.base_model,
            args,
            num_classes=args.num_classes,
            pretrained=_resolve_pretrained(args),
        )
    elif args.seg_model == "linearfusionconsistency":
        return LinearFusionConsistency(
            args.base_model,
            args,
            cons_lambda=args.cons_lambda,
            num_classes=args.num_classes,
        )
    elif args.seg_model == "linearfusionmaskedcons":
        return LinearFusionMaskedConsistency(
            args.base_model,
            args,
            num_classes=args.num_classes,
            pretrained=_resolve_pretrained(args),
        )
    elif args.seg_model == "linearfusionmaskedconsmixbatch":
        return LinearFusionMaskedConsistencyMixBatch(
            args.base_model, args, num_classes=args.num_classes
        )
    elif args.seg_model == "linearfusionsepdecodermaskedcons":
        return LinearFusionSepDecoderMaskedConsistency(
            args.base_model, args, num_classes=args.num_classes
        )
    elif args.seg_model == "linearfusionmaemaskedcons":
        return LinearFusionMAEMaskedConsistency(
            args.base_model, args, num_classes=args.num_classes
        )
    elif args.seg_model == "tokenfusionmaskedcons":
        return TokenFusionMaskedConsistency(
            args.base_model,
            args,
            l1_lambda=args.l1_lambda,
            num_classes=args.num_classes,
        )
    elif args.seg_model == "tokenfusionmaskedconsmixbatch":
        return TokenFusionMaskedConsistencyMixBatch(
            args.base_model,
            args,
            l1_lambda=args.l1_lambda,
            num_classes=args.num_classes,
        )
    elif args.seg_model == "tokenfusionbothmask":
        # The only branch that forwards **kwargs.
        return TokenFusionBothMask(
            args.base_model,
            args,
            l1_lambda=args.l1_lambda,
            num_classes=args.num_classes,
            **kwargs,
        )
    elif args.seg_model == "linearfusebothmask":
        return LinearFusionBothMask(args.base_model, args, num_classes=args.num_classes)
    elif args.seg_model == "linearfusiontokenmix":
        return LinearFusionTokenMix(
            args.base_model,
            args,
            num_classes=args.num_classes,
            exchange_percent=args.exchange_percent,
        )
    elif args.seg_model == "tokenfusionmaemaskedcons":
        return TokenFusionMAEMaskedConsistency(
            args.base_model,
            args,
            l1_lambda=args.l1_lambda,
            num_classes=args.num_classes,
        )
    elif args.seg_model == "unifiedrepresentationnetwork":
        return UnifiedRepresentationNetwork(
            args.base_model, args, num_classes=args.num_classes
        )
    elif args.seg_model == "unifiedrepresentationnetworkmoddrop":
        return UnifiedRepresentationNetworkModDrop(
            args.base_model, args, num_classes=args.num_classes
        )
    else:
        raise Exception(f"{args.seg_model} not configured")