|
|
import os |
|
|
from functools import partial |
|
|
import torch |
|
|
|
|
|
from .vmamba import VSSM |
|
|
from .csms6s import flops_selective_scan_fn,flops_selective_scan_ref |
|
|
|
|
|
|
|
|
def build_vssm_model(config, **kwargs):
    """Instantiate a ``VSSM`` network from a configuration object.

    Args:
        config: Configuration node exposing ``MODEL``, ``TRAIN`` and ``DATA``
            sections with the fields read below.
        **kwargs: Accepted for call-site compatibility; not used here.

    Returns:
        A ``VSSM`` instance when ``config.MODEL.TYPE`` is ``"vssm"``,
        otherwise ``None``.
    """
    # Guard clause: only the "vssm" model type is handled by this factory.
    if config.MODEL.TYPE not in ("vssm",):
        return None

    model_cfg = config.MODEL
    vssm_cfg = model_cfg.VSSM

    # The dt rank is either the literal "auto" or something convertible to int.
    raw_dt_rank = vssm_cfg.SSM_DT_RANK
    if raw_dt_rank == "auto":
        dt_rank = "auto"
    else:
        dt_rank = int(raw_dt_rank)

    return VSSM(
        # --- stem / stage layout ---
        patch_size=vssm_cfg.PATCH_SIZE,
        in_chans=vssm_cfg.IN_CHANS,
        num_classes=model_cfg.NUM_CLASSES,
        depths=vssm_cfg.DEPTHS,
        dims=vssm_cfg.EMBED_DIM,
        # --- state-space (SSM) settings ---
        ssm_d_state=vssm_cfg.SSM_D_STATE,
        ssm_ratio=vssm_cfg.SSM_RATIO,
        ssm_rank_ratio=vssm_cfg.SSM_RANK_RATIO,
        ssm_dt_rank=dt_rank,
        ssm_act_layer=vssm_cfg.SSM_ACT_LAYER,
        ssm_conv=vssm_cfg.SSM_CONV,
        ssm_conv_bias=vssm_cfg.SSM_CONV_BIAS,
        ssm_drop_rate=vssm_cfg.SSM_DROP_RATE,
        ssm_init=vssm_cfg.SSM_INIT,
        forward_type=vssm_cfg.SSM_FORWARDTYPE,
        # --- MLP settings ---
        mlp_ratio=vssm_cfg.MLP_RATIO,
        mlp_act_layer=vssm_cfg.MLP_ACT_LAYER,
        mlp_drop_rate=vssm_cfg.MLP_DROP_RATE,
        # --- remaining options ---
        drop_path_rate=model_cfg.DROP_PATH_RATE,
        patch_norm=vssm_cfg.PATCH_NORM,
        norm_layer=vssm_cfg.NORM_LAYER,
        downsample_version=vssm_cfg.DOWNSAMPLE,
        patchembed_version=vssm_cfg.PATCHEMBED,
        gmlp=vssm_cfg.GMLP,
        use_checkpoint=config.TRAIN.USE_CHECKPOINT,
        posembed=vssm_cfg.POSEMBED,
        imgsize=config.DATA.IMG_SIZE,
    )
|
|
|
|
|
|
|
|
def build_model(config, is_pretrain=False):
    """Build the model described by ``config``.

    Delegates to ``build_vssm_model``.

    Args:
        config: Configuration object passed through to ``build_vssm_model``.
        is_pretrain: Flag forwarded to the builder as a keyword argument.

    Returns:
        Whatever ``build_vssm_model`` returns: the constructed model, or
        ``None`` when the configured model type is not handled by it.
    """
    # ``build_vssm_model`` is declared as ``(config, **kwargs)``, so it takes
    # only one positional argument. ``is_pretrain`` must be passed by keyword;
    # passing it positionally raises ``TypeError``.
    return build_vssm_model(config, is_pretrain=is_pretrain)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|