# File size: 1,954 Bytes
# 2fafc55
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
from timm.models import create_model
from .swin_transformer import SwinTransformer
from . import focalnet

def _num_classes(config):
    """Return the classifier-head width to request from timm.

    The swin/focal branches of ``build_model`` read model fields directly off
    ``config`` (``config.IMG_SIZE``, ``config.DROP_RATE``, ...), whereas the
    timm branches were written against a nested ``config.MODEL.NUM_CLASSES``
    layout. Accept either: use ``config.MODEL.NUM_CLASSES`` when a ``MODEL``
    node exists, otherwise ``config.NUM_CLASSES``, otherwise 0 (no
    classification head, matching the swin/focal branches).
    """
    model_cfg = getattr(config, "MODEL", config)
    return getattr(model_cfg, "NUM_CLASSES", 0)


def build_model(config, is_pretrained=False):
    """Build the model whose name is given by ``config.TYPE``.

    Dispatch is a case-sensitive substring match on ``config.TYPE``, checked
    in this order: "swin", "focal", "vit", then anything else.

    - "swin": a local ``SwinTransformer`` built from ``config.SWIN.*``,
      ``config.IMG_SIZE``, ``config.DROP_RATE`` and ``config.DROP_PATH_RATE``,
      with ``num_classes=0``.
    - "focal": ``timm.create_model`` with ``config.FOCAL.*`` options and
      ``num_classes=0``; never loads pretrained weights. (Presumably the
      ``focalnet`` module import registers these names with timm — confirm.)
    - "vit": ``timm.create_model`` with ``img_size=config.IMG_SIZE``.
    - anything else (e.g. "resnet"): ``timm.create_model`` with no image size.

    Args:
        config: Config node holding the model settings listed above.
        is_pretrained: Passed as ``pretrained`` to timm for the "vit" and
            fallback branches. Previously this name was referenced but never
            defined, so those branches always raised ``NameError``.

    Returns:
        The constructed model instance.
    """
    model_type = config.TYPE
    print(f"Creating model: {model_type}")

    if "swin" in model_type:
        return SwinTransformer(
            num_classes=0,
            img_size=config.IMG_SIZE,
            patch_size=config.SWIN.PATCH_SIZE,
            in_chans=config.SWIN.IN_CHANS,
            embed_dim=config.SWIN.EMBED_DIM,
            depths=config.SWIN.DEPTHS,
            num_heads=config.SWIN.NUM_HEADS,
            window_size=config.SWIN.WINDOW_SIZE,
            mlp_ratio=config.SWIN.MLP_RATIO,
            qkv_bias=config.SWIN.QKV_BIAS,
            qk_scale=config.SWIN.QK_SCALE,
            drop_rate=config.DROP_RATE,
            drop_path_rate=config.DROP_PATH_RATE,
            ape=config.SWIN.APE,
            patch_norm=config.SWIN.PATCH_NORM,
            use_checkpoint=False,
        )

    if "focal" in model_type:
        return create_model(
            model_type,
            pretrained=False,
            img_size=config.IMG_SIZE,
            num_classes=0,
            drop_path_rate=config.DROP_PATH_RATE,
            use_conv_embed=config.FOCAL.USE_CONV_EMBED,
            use_layerscale=config.FOCAL.USE_LAYERSCALE,
            use_postln=config.FOCAL.USE_POSTLN,
        )

    if "vit" in model_type:
        # Same top-level IMG_SIZE field the swin/focal branches read; the old
        # config.DATA.IMG_SIZE path did not match this config layout.
        return create_model(
            model_type,
            pretrained=is_pretrained,
            img_size=config.IMG_SIZE,
            num_classes=_num_classes(config),
        )

    # "resnet" and every other timm model were handled identically before;
    # one fallback branch covers both.
    return create_model(
        model_type,
        pretrained=is_pretrained,
        num_classes=_num_classes(config),
    )