Spaces:
Running
on
Zero
Running
on
Zero
File size: 2,084 Bytes
c509e76 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 |
import torchvision.models as models
from model.densenetccnl import *
from model.unetnc import *
from model.gienet import *
def get_model(name, n_classes=1, filters=64, version=None, in_channels=3,
              is_batchnorm=True, norm='batch', model_path=None,
              use_sigmoid=True, layers=3, img_size=512):
    """Instantiate the model registered under ``name``.

    The class is resolved with ``_get_model_instance`` and then constructed
    with architecture-specific keyword arguments.

    Args:
        name: Registry key, e.g. ``'dnetccnl'``, ``'dnetccnl512'``,
            ``'unetnc'``, ``'gie'``, ``'giecbam'``, ``'gie2head'``,
            ``'giemask'``, ``'giemask2'``, ``'giedilated'``, ``'bmp'``.
        n_classes: Number of outputs passed to the constructor
            (``'displacement'`` always uses 2).
        filters, version, is_batchnorm, norm, model_path, use_sigmoid, layers:
            Accepted for call-site compatibility; not read by this function.
        in_channels: Number of inputs passed to the constructor
            (``'displacement'`` always uses 5).
        img_size: Used only by ``'dnetccnl512'``; ``'dnetccnl'`` is always
            built with ``img_size=128``.

    Returns:
        The constructed model, or ``None`` when ``name`` is not registered
        in ``_get_model_instance`` (which prints a "not available" message).
    """
    # Architectures whose constructors share the (input_nc, output_nc,
    # num_downs) signature and are all built with num_downs=7.
    unet_style = ('unetnc', 'gie', 'giecbam', 'gie2head', 'giemask',
                  'giemask2', 'giedilated', 'bmp')

    model_cls = _get_model_instance(name)
    if model_cls is None:
        # Unregistered name: stop here instead of trying to call None.
        return None

    if name == 'dnetccnl':
        # Fixed 128 input size and 32 filters, regardless of the
        # img_size / filters arguments.
        return model_cls(img_size=128, in_channels=in_channels,
                         out_channels=n_classes, filters=32)
    if name == 'dnetccnl512':
        return model_cls(img_size=img_size, in_channels=in_channels,
                         out_channels=n_classes, filters=32)
    if name in unet_style:
        return model_cls(input_nc=in_channels, output_nc=n_classes,
                         num_downs=7)
    if name == 'displacement':
        # NOTE(review): 'displacement' has no entry in _get_model_instance in
        # this file, so this branch is only reached once one is added there.
        return model_cls(n_classes=2, num_filter=32, BatchNorm='GN',
                         in_channels=5)
    # Registered but without constructor arguments above: hand back the
    # class itself, uninstantiated (same as the original fall-through).
    return model_cls
def _get_model_instance(name):
    """Return the model class registered under ``name``.

    Args:
        name: Registry key such as ``'dnetccnl'`` or ``'unetnc'``.

    Returns:
        The registered class, or ``None`` when ``name`` is not a registered
        key; a "not available" message is printed in that case.
    """
    # Built outside the try so that a class missing from the star imports
    # raises NameError instead of being misreported as "not available".
    registry = {
        'dnetccnl': dnetccnl,
        'dnetccnl512': dnetccnl512,
        'unetnc': UnetGenerator,
        'gie': GieGenerator,
        'giecbam': GiecbamGenerator,
        'giedilated': DilatedSingleUnet,
        'gie2head': Gie2headGenerator,
        'giemask': GiemaskGenerator,
        'giemask2': Giemask2Generator,
        'bmp': BmpGenerator,
    }
    try:
        return registry[name]
    except (KeyError, TypeError):
        # KeyError: unregistered name; TypeError: unhashable name.
        print('Model {} not available'.format(name))
        return None
|