# Model configuration: a single `backbone` entry holding shared interaction
# settings plus three per-branch sub-configs (branch1, branch2, branch3).
#
# Across the branches, input size grows while width shrinks:
#   branch1: img_size=128, embed_dim=768, num_heads=12
#   branch2: img_size=192, embed_dim=384, num_heads=6
#   branch3: img_size=368, embed_dim=192, num_heads=3
# NOTE(review): the exact meaning of each key is defined by the backbone class
# that consumes this dict, which is not visible here -- confirm there.
model = dict(
    backbone=dict(
        # Settings shared by all branches (names suggest they configure the
        # cross-branch interaction modules -- TODO confirm against backbone).
        n_points=4,
        deform_num_heads=16,
        cffn_ratio=0.25,
        deform_ratio=0.5,
        with_cffn=True,
        interact_attn_type='deform',
        interaction_drop_path_rate=0.4,
        separate_head=True,
        # Widest branch, smallest input size.
        branch1=dict(
            img_size=128,
            patch_size=16,
            pretrain_img_size=224,
            pretrain_patch_size=16,
            depth=12,
            embed_dim=768,
            num_heads=12,
            mlp_ratio=4,
            init_scale=1.0,
            qkv_bias=True,
            drop_rate=0.0,
            drop_path_rate=0.2,
            # Twelve pairs [0, 0] .. [11, 11]; a fresh list is built per branch.
            interaction_indexes=[[idx, idx] for idx in range(12)],
            use_cls_token=True,
            use_flash_attn=True,
            with_cp=True,
            pretrained="pretrained/deit_base_patch16_224-b5f2ef4d.pth",
        ),
        # Middle branch: intermediate input size and width.
        branch2=dict(
            img_size=192,
            patch_size=16,
            pretrain_img_size=224,
            pretrain_patch_size=16,
            depth=12,
            embed_dim=384,
            num_heads=6,
            mlp_ratio=4,
            init_scale=1.0,
            qkv_bias=True,
            drop_rate=0.0,
            drop_path_rate=0.05,
            # Twelve pairs [0, 0] .. [11, 11]; a fresh list is built per branch.
            interaction_indexes=[[idx, idx] for idx in range(12)],
            use_cls_token=True,
            use_flash_attn=True,
            with_cp=True,
            pretrained="pretrained/deit_small_patch16_224-cd65a155.pth",
        ),
        # Narrowest branch, largest input size.
        branch3=dict(
            img_size=368,
            patch_size=16,
            pretrain_img_size=224,
            pretrain_patch_size=16,
            depth=12,
            embed_dim=192,
            num_heads=3,
            mlp_ratio=4,
            init_scale=1.0,
            qkv_bias=True,
            drop_rate=0.0,
            drop_path_rate=0.05,
            # Twelve pairs [0, 0] .. [11, 11]; a fresh list is built per branch.
            interaction_indexes=[[idx, idx] for idx in range(12)],
            use_cls_token=True,
            use_flash_attn=True,
            with_cp=True,
            pretrained="pretrained/deit_tiny_patch16_224-a1311bcf.pth",
        ),
    ),
)