import torchvision.models as models
from model.densenetccnl import *
from model.unetnc import *
from model.gienet import *

# Registry names whose constructors are all called with the same
# (input_nc, output_nc, num_downs=7) keyword arguments.
_UNET_STYLE_MODELS = (
    'unetnc',
    'gie',
    'giecbam',
    'gie2head',
    'giemask',
    'giemask2',
    'giedilated',
    'bmp',
)


def get_model(name, n_classes=1, filters=64, version=None, in_channels=3,
              is_batchnorm=True, norm='batch', model_path=None,
              use_sigmoid=True, layers=3, img_size=512):
    """Instantiate the network registered under ``name``.

    Args:
        name: Registry key understood by ``_get_model_instance``.
        n_classes: Output channel count passed to the constructor
            (ignored by 'displacement', which always uses 2).
        in_channels: Input channel count passed to the constructor
            (ignored by 'displacement', which always uses 5).
        img_size: Image size for 'dnetccnl512'; 'dnetccnl' always uses 128.
        filters, version, is_batchnorm, norm, model_path, use_sigmoid, layers:
            Accepted for call-site compatibility; currently ignored.

    Returns:
        The constructed model instance.

    Raises:
        ValueError: If ``name`` has no registered model class, or no
            construction rule exists for it.
    """
    model_cls = _get_model_instance(name)
    if model_cls is None:
        # Fail loudly instead of returning None (unknown names) or crashing
        # with "'NoneType' object is not callable" ('displacement').
        raise ValueError('Model {} not available'.format(name))

    if name == 'dnetccnl':
        # Deliberately hard-coded: img_size=128 and filters=32 regardless of
        # the caller's img_size/filters arguments.
        return model_cls(img_size=128, in_channels=in_channels,
                         out_channels=n_classes, filters=32)
    if name == 'dnetccnl512':
        return model_cls(img_size=img_size, in_channels=in_channels,
                         out_channels=n_classes, filters=32)
    if name in _UNET_STYLE_MODELS:
        return model_cls(input_nc=in_channels, output_nc=n_classes,
                         num_downs=7)
    if name == 'displacement':
        # NOTE(review): 'displacement' is not in the registry, so this branch
        # is unreachable until its class is registered in
        # _get_model_instance. TODO: add the class name.
        return model_cls(n_classes=2, num_filter=32, BatchNorm='GN',
                         in_channels=5)

    raise ValueError('Model {} has no construction rule'.format(name))


def _get_model_instance(name):
    """Return the model class registered under ``name``, or None if unknown."""
    # Built per call so the names from the star imports are resolved only
    # when a model is requested (matching the original lazy lookup).
    registry = {
        'dnetccnl': dnetccnl,
        'dnetccnl512': dnetccnl512,
        'unetnc': UnetGenerator,
        'gie': GieGenerator,
        'giecbam': GiecbamGenerator,
        'giedilated': DilatedSingleUnet,
        'gie2head': Gie2headGenerator,
        'giemask': GiemaskGenerator,
        'giemask2': Giemask2Generator,
        'bmp': BmpGenerator,
    }
    return registry.get(name)