# PIIP/classification/piip_3branch_tsb_368-192-128_cls_token_deit1.py
# Author: wzk1015; commit 3734857 ("cls model,log,config"); 2.3 kB
# Three-branch classification backbone config.
#
# Every branch is a depth-12 transformer stack with 16x16 patches that is
# initialised from a DeiT checkpoint trained at 224 px (see `pretrained`).
# The branches trade width for input resolution: the widest branch
# (embed_dim 768) receives the smallest image (128 px) and the narrowest
# (embed_dim 192) receives the largest (368 px).
#
# NOTE(review): the builder that consumes these keys is not part of this file;
# the descriptive comments below are inferred from key names — confirm against
# the backbone implementation before relying on them.
model = dict(
    backbone=dict(
        # Backbone-level settings; judging by the key names these configure
        # the deformable-attention interaction between branches.
        n_points=4,
        deform_num_heads=16,
        cffn_ratio=0.25,
        deform_ratio=0.5,
        with_cffn=True,
        interact_attn_type='deform',
        interaction_drop_path_rate=0.4,
        separate_head=True,

        # Branch 1: DeiT-Base-sized (768-d, 12 heads) at 128 px.
        branch1=dict(
            img_size=128,
            patch_size=16,
            pretrain_img_size=224,
            pretrain_patch_size=16,
            depth=12,
            embed_dim=768,
            num_heads=12,
            mlp_ratio=4,
            init_scale=1.0,
            qkv_bias=True,
            drop_rate=0.0,
            drop_path_rate=0.2,
            # Expands to [[0, 0], [1, 1], ..., [11, 11]] — one pair for each
            # index 0..depth-1. Built per branch so each gets its own list.
            interaction_indexes=[[i, i] for i in range(12)],
            use_cls_token=True,
            use_flash_attn=True,
            with_cp=True,
            pretrained="pretrained/deit_base_patch16_224-b5f2ef4d.pth",
        ),

        # Branch 2: DeiT-Small-sized (384-d, 6 heads) at 192 px.
        branch2=dict(
            img_size=192,
            patch_size=16,
            pretrain_img_size=224,
            pretrain_patch_size=16,
            depth=12,
            embed_dim=384,
            num_heads=6,
            mlp_ratio=4,
            init_scale=1.0,
            qkv_bias=True,
            drop_rate=0.0,
            drop_path_rate=0.05,
            # Same [[0, 0] .. [11, 11]] pairing as the other branches.
            interaction_indexes=[[i, i] for i in range(12)],
            use_cls_token=True,
            use_flash_attn=True,
            with_cp=True,
            pretrained="pretrained/deit_small_patch16_224-cd65a155.pth",
        ),

        # Branch 3: DeiT-Tiny-sized (192-d, 3 heads) at 368 px.
        branch3=dict(
            img_size=368,
            patch_size=16,
            pretrain_img_size=224,
            pretrain_patch_size=16,
            depth=12,
            embed_dim=192,
            num_heads=3,
            mlp_ratio=4,
            init_scale=1.0,
            qkv_bias=True,
            drop_rate=0.0,
            drop_path_rate=0.05,
            # Same [[0, 0] .. [11, 11]] pairing as the other branches.
            interaction_indexes=[[i, i] for i in range(12)],
            use_cls_token=True,
            use_flash_attn=True,
            with_cp=True,
            pretrained="pretrained/deit_tiny_patch16_224-a1311bcf.pth",
        ),
    ),
)