| from .segmenter import CGFormer |
| from .segmenter_rcc import CGFormer_RCC_sbert |
| from .segmenter_sbert import CGFormer_sbert |
| from .segmenter_rz_sbert import CGFormer_Refzomsbert |
| from loguru import logger |
| import torch |
| import torch.nn as nn |
| from .backbone import MultiModalSwinTransformer |
|
|
|
|
def build_model(args):
    """Build the segmentation model selected by ``args``.

    A ``MultiModalSwinTransformer`` backbone is created from the requested
    Swin variant, its weights are initialised (from ``args.swin_pretrain``
    when set, otherwise via the backbone's default ``init_weights()``), and
    the backbone is wrapped in the segmenter class matching
    ``args.dataset`` / ``args.metric_learning``.

    Attributes read from ``args``: ``swin_type``, ``swin_pretrain``,
    ``window12``, ``mha``, ``fusion_drop``, ``dataset``, ``metric_learning``.

    Args:
        args: Configuration namespace exposing the attributes listed above.

    Returns:
        The constructed segmenter: ``CGFormer_sbert``, ``CGFormer_RCC_sbert``,
        ``CGFormer_Refzomsbert`` or plain ``CGFormer``.

    Raises:
        ValueError: If ``args.swin_type`` is not one of ``'tiny'``,
            ``'small'``, ``'base'`` or ``'large'``.
    """
    # Swin variant -> (embed_dim, depths, num_heads).
    swin_configs = {
        'tiny': (96, [2, 2, 6, 2], [3, 6, 12, 24]),
        'small': (96, [2, 2, 18, 2], [3, 6, 12, 24]),
        'base': (128, [2, 2, 18, 2], [4, 8, 16, 32]),
        'large': (192, [2, 2, 18, 2], [6, 12, 24, 48]),
    }
    try:
        embed_dim, depths, num_heads = swin_configs[args.swin_type]
    except KeyError:
        # Explicit error instead of ``assert False``: asserts are stripped
        # under ``python -O`` and would let an unknown type fall through.
        raise ValueError(
            f"Unsupported swin_type {args.swin_type!r}; "
            f"expected one of {sorted(swin_configs)}"
        ) from None

    # ``swin_pretrain`` may be unset (random init below), so guard the
    # substring test against ``None``.
    pretrain_path = args.swin_pretrain or ''
    if 'window12' in pretrain_path or args.window12:
        logger.info('Window size 12!')
        window_size = 12
    else:
        window_size = 7

    # ``args.mha`` is a dash-separated list of integers (e.g. "4-4-4-4"),
    # passed to the backbone as ``num_heads_fusion``.
    if args.mha:
        mha = [int(a) for a in args.mha.split('-')]
    else:
        mha = [1, 1, 1, 1]

    out_indices = (0, 1, 2, 3)
    backbone = MultiModalSwinTransformer(
        embed_dim=embed_dim,
        depths=depths,
        num_heads=num_heads,
        window_size=window_size,
        ape=False,
        drop_path_rate=0.3,
        patch_norm=True,
        out_indices=out_indices,
        use_checkpoint=False,
        num_heads_fusion=mha,
        fusion_drop=args.fusion_drop,
    )
    if args.swin_pretrain:
        logger.info('Initializing Multi-modal Swin Transformer weights from ' + args.swin_pretrain)
        backbone.init_weights(pretrained=args.swin_pretrain)
    else:
        logger.info('Randomly initialize Multi-modal Swin Transformer weights.')
        backbone.init_weights()

    # The sbert-based variants are only used when metric learning is enabled;
    # every other combination falls back to the plain CGFormer.
    if ('refcocog' in args.dataset) and args.metric_learning:
        model = CGFormer_sbert(backbone, args)
    elif (args.dataset in ['refcoco', 'refcoco+']) and args.metric_learning:
        model = CGFormer_RCC_sbert(backbone, args)
    elif ('ref-zom' in args.dataset) and args.metric_learning:
        logger.info("Loading CGFormer_Refzomsbert...")
        model = CGFormer_Refzomsbert(backbone, args)
    else:
        model = CGFormer(backbone, args)

    return model
|
|
| |
def build_segmenter(args, DDP=True, OPEN=False):
    """Build the model and, in DDP mode, its optimizer parameter groups.

    Args:
        args: Configuration namespace forwarded to ``build_model``. In DDP
            mode ``sync_bn``, ``lr_backbone`` and ``lr_text_encoder`` are
            also read from it.
        DDP: When true, return ``(model, param_list)``; when false, return
            only the model.
        OPEN: When true (DDP mode only), every parameter under
            ``model.backbone`` has ``requires_grad`` switched off before the
            groups are collected.

    Returns:
        ``model`` if ``DDP`` is false, otherwise ``(model, param_list)``
        where ``param_list`` holds four groups of trainable parameters,
        selected by substring of the parameter name:

        1. names containing neither ``"backbone"`` nor ``"text_encoder"``;
        2. names containing ``"pwam"``;
        3. names containing ``"backbone"`` but not ``"pwam"``
           (``lr=args.lr_backbone``);
        4. names containing ``"text_encoder"``
           (``lr=args.lr_text_encoder``).

        Groups 1 and 2 carry no ``"lr"`` key (presumably the optimizer's
        default learning rate applies; confirm at the call site).
    """
    model = build_model(args)
    if not DDP:
        return model

    if args.sync_bn:
        model = nn.SyncBatchNorm.convert_sync_batchnorm(model)

    single_model = model
    if OPEN:
        for weight in single_model.backbone.parameters():
            weight.requires_grad_(False)

    # One pass over the parameters; the four membership tests are
    # independent, so a parameter lands in every group whose rule it matches.
    head_params = []
    pwam_params = []
    backbone_params = []
    text_params = []
    for name, weight in single_model.named_parameters():
        if not weight.requires_grad:
            continue
        in_backbone = "backbone" in name
        in_text = "text_encoder" in name
        in_pwam = "pwam" in name
        if not in_backbone and not in_text:
            head_params.append(weight)
        if in_pwam:
            pwam_params.append(weight)
        if in_backbone and not in_pwam:
            backbone_params.append(weight)
        if in_text:
            text_params.append(weight)

    param_list = [
        {"params": head_params},
        {"params": pwam_params},
        {"params": backbone_params, "lr": args.lr_backbone},
        {"params": text_params, "lr": args.lr_text_encoder},
    ]
    return model, param_list
|
|