from timm.models import create_model from .swin_transformer import SwinTransformer from . import focalnet def build_model(config): model_type = config.TYPE print(f"Creating model: {model_type}") if "swin" in model_type: model = SwinTransformer( num_classes=0, img_size=config.IMG_SIZE, patch_size=config.SWIN.PATCH_SIZE, in_chans=config.SWIN.IN_CHANS, embed_dim=config.SWIN.EMBED_DIM, depths=config.SWIN.DEPTHS, num_heads=config.SWIN.NUM_HEADS, window_size=config.SWIN.WINDOW_SIZE, mlp_ratio=config.SWIN.MLP_RATIO, qkv_bias=config.SWIN.QKV_BIAS, qk_scale=config.SWIN.QK_SCALE, drop_rate=config.DROP_RATE, drop_path_rate=config.DROP_PATH_RATE, ape=config.SWIN.APE, patch_norm=config.SWIN.PATCH_NORM, use_checkpoint=False ) elif "focal" in model_type: model = create_model( model_type, pretrained=False, img_size=config.IMG_SIZE, num_classes=0, drop_path_rate=config.DROP_PATH_RATE, use_conv_embed=config.FOCAL.USE_CONV_EMBED, use_layerscale=config.FOCAL.USE_LAYERSCALE, use_postln=config.FOCAL.USE_POSTLN ) elif "vit" in model_type: model = create_model( model_type, pretrained=is_pretrained, img_size=config.DATA.IMG_SIZE, num_classes=config.MODEL.NUM_CLASSES, ) elif "resnet" in model_type: model = create_model( model_type, pretrained=is_pretrained, num_classes=config.MODEL.NUM_CLASSES ) else: model = create_model( model_type, pretrained=is_pretrained, num_classes=config.MODEL.NUM_CLASSES ) return model