File size: 2,084 Bytes
c509e76
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
import torchvision.models as models
from model.densenetccnl import *
from model.unetnc import *
from model.gienet import *


def get_model(name, n_classes=1, filters=64,version=None,in_channels=3, is_batchnorm=True, norm='batch', model_path=None, use_sigmoid=True, layers=3,img_size=512):
    """Instantiate the network registered under ``name``.

    Args:
        name: registry key resolved by ``_get_model_instance`` (e.g.
            ``'dnetccnl'``, ``'dnetccnl512'``, ``'unetnc'``, ``'gie'``).
        n_classes: output channel count forwarded to the network (ignored by
            ``'displacement'``, which is hard-coded to 2).
        filters, version, is_batchnorm, norm, model_path, use_sigmoid, layers:
            accepted for interface compatibility; no branch below uses them.
        in_channels: input channel count forwarded to the network (ignored by
            ``'displacement'``, which is hard-coded to 5).
        img_size: input resolution; only forwarded for ``'dnetccnl512'``.

    Returns:
        The constructed model instance.

    Raises:
        ValueError: if ``name`` has no registered class or no construction rule.
    """
    model_cls = _get_model_instance(name)
    if model_cls is None:
        # A lookup miss may surface as None; fail here with a clear message
        # instead of calling None (previously a TypeError for 'displacement').
        raise ValueError('Model {} not available'.format(name))

    # Generators sharing the (input_nc, output_nc, num_downs) constructor.
    unet_style = ('unetnc', 'gie', 'giecbam', 'gie2head', 'giemask',
                  'giemask2', 'giedilated', 'bmp')

    if name == 'dnetccnl':
        # Fixed at 128 px / 32 filters regardless of img_size and filters.
        return model_cls(img_size=128, in_channels=in_channels, out_channels=n_classes, filters=32)
    if name == 'dnetccnl512':
        return model_cls(img_size=img_size, in_channels=in_channels, out_channels=n_classes, filters=32)
    if name in unet_style:
        return model_cls(input_nc=in_channels, output_nc=n_classes, num_downs=7)
    if name == 'displacement':
        # Hard-coded configuration: n_classes and in_channels are not used.
        return model_cls(n_classes=2, num_filter=32, BatchNorm='GN', in_channels=5)
    raise ValueError('No construction rule for model {}'.format(name))

def _get_model_instance(name):
    """Return the model class registered under ``name``.

    Args:
        name: registry key, e.g. ``'dnetccnl'`` or ``'unetnc'``.

    Returns:
        The (uninstantiated) model class.

    Raises:
        ValueError: if ``name`` is not a registered key.
    """
    # Built outside the try so a missing import (NameError) surfaces as
    # itself instead of being misreported as an unknown model name.
    registry = {
        'dnetccnl': dnetccnl,
        'dnetccnl512': dnetccnl512,
        'unetnc': UnetGenerator,
        'gie': GieGenerator,
        'giecbam': GiecbamGenerator,
        'giedilated': DilatedSingleUnet,
        'gie2head': Gie2headGenerator,
        'giemask': GiemaskGenerator,
        'giemask2': Giemask2Generator,
        'bmp': BmpGenerator,
        # NOTE(review): get_model has a 'displacement' branch, but no class is
        # registered for it here -- TODO add the correct class.
    }
    try:
        return registry[name]
    except KeyError:
        raise ValueError('Model {} not available'.format(name)) from None