|
from .utils import IntermediateLayerGetter |
|
from ._deeplab import DeepLabHead, DeepLabHeadV3Plus, DeepLabV3 |
|
from .enhanced_deeplab import EnhancedDeepLabHead, EnhancedDeepLabHeadV3Plus, EnhancedDeepLabV3 |
|
from .backbone import ( |
|
resnet, |
|
mobilenetv2, |
|
hrnetv2, |
|
xception |
|
) |
|
|
|
def _segm_hrnet(name, backbone_name, num_classes, pretrained_backbone,
                use_eoaNet=True, msa_scales=None, eog_beta=0.5):
    """Build an enhanced DeepLabV3 / DeepLabV3+ model on an HRNetV2 backbone.

    Args:
        name (str): head architecture, ``'deeplabv3'`` or ``'deeplabv3plus'``.
        backbone_name (str): key looked up in the ``hrnetv2`` module. The
            integer after the last ``'_'`` is parsed as the base channel
            width (e.g. ``'hrnetv2_48'`` -> 48).
        num_classes (int): number of classes, forwarded to the head.
        pretrained_backbone: forwarded positionally to the backbone
            constructor.
        use_eoaNet (bool): forwarded to the enhanced head.
        msa_scales (list, optional): forwarded to the enhanced head.
            ``None`` (the default) selects ``[1, 2, 4]``.
        eog_beta (float): forwarded to the enhanced head.

    Returns:
        An ``EnhancedDeepLabV3`` built from the wrapped backbone and the head.

    Raises:
        NotImplementedError: if ``name`` is not a supported architecture.
    """
    # Fail fast: without this check an unknown name would only surface as an
    # unbound ``return_layers`` after the backbone had already been built.
    if name not in ('deeplabv3plus', 'deeplabv3'):
        raise NotImplementedError(
            "Unsupported architecture: {!r} "
            "(expected 'deeplabv3' or 'deeplabv3plus')".format(name))
    if msa_scales is None:
        # Fresh list per call; a list literal in the signature would be shared.
        msa_scales = [1, 2, 4]

    backbone = hrnetv2.__dict__[backbone_name](pretrained_backbone)

    # Head input width is c * (1 + 2 + 4 + 8), with c parsed from the name.
    hrnet_channels = int(backbone_name.split('_')[-1])
    inplanes = sum(hrnet_channels * 2 ** i for i in range(4))
    low_level_planes = 256
    aspp_dilate = [12, 24, 36]

    if name == 'deeplabv3plus':
        return_layers = {'stage4': 'out', 'layer1': 'low_level'}
        classifier = EnhancedDeepLabHeadV3Plus(
            inplanes, low_level_planes, num_classes, aspp_dilate,
            use_eoaNet=use_eoaNet, msa_scales=msa_scales, eog_beta=eog_beta)
    else:  # 'deeplabv3'
        return_layers = {'stage4': 'out'}
        classifier = EnhancedDeepLabHead(
            inplanes, num_classes, aspp_dilate,
            use_eoaNet=use_eoaNet, msa_scales=msa_scales, eog_beta=eog_beta)

    backbone = IntermediateLayerGetter(backbone, return_layers=return_layers, hrnet_flag=True)
    model = EnhancedDeepLabV3(backbone, classifier)
    return model
|
|
|
def _segm_resnet(name, backbone_name, num_classes, output_stride, pretrained_backbone,
                 use_eoaNet=True, msa_scales=None, eog_beta=0.5):
    """Build an enhanced DeepLabV3 / DeepLabV3+ model on a ResNet backbone.

    Args:
        name (str): head architecture, ``'deeplabv3'`` or ``'deeplabv3plus'``.
        backbone_name (str): key looked up in the ``resnet`` module.
        num_classes (int): number of classes, forwarded to the head.
        output_stride (int): ``8`` selects dilation in the last two stages
            with ASPP rates ``[12, 24, 36]``; any other value selects
            dilation in the last stage only with rates ``[6, 12, 18]``.
        pretrained_backbone: forwarded as ``pretrained`` to the backbone
            constructor.
        use_eoaNet (bool): forwarded to the enhanced head.
        msa_scales (list, optional): forwarded to the enhanced head.
            ``None`` (the default) selects ``[1, 2, 4]``.
        eog_beta (float): forwarded to the enhanced head.

    Returns:
        An ``EnhancedDeepLabV3`` built from the wrapped backbone and the head.

    Raises:
        NotImplementedError: if ``name`` is not a supported architecture.
    """
    # Fail fast: without this check an unknown name would only surface as an
    # unbound ``return_layers`` after the backbone had already been built.
    if name not in ('deeplabv3plus', 'deeplabv3'):
        raise NotImplementedError(
            "Unsupported architecture: {!r} "
            "(expected 'deeplabv3' or 'deeplabv3plus')".format(name))
    if msa_scales is None:
        # Fresh list per call; a list literal in the signature would be shared.
        msa_scales = [1, 2, 4]

    if output_stride == 8:
        replace_stride_with_dilation = [False, True, True]
        aspp_dilate = [12, 24, 36]
    else:
        replace_stride_with_dilation = [False, False, True]
        aspp_dilate = [6, 12, 18]

    backbone = resnet.__dict__[backbone_name](
        pretrained=pretrained_backbone,
        replace_stride_with_dilation=replace_stride_with_dilation)

    # NOTE(review): channel widths are fixed regardless of backbone_name;
    # confirm they hold for every variant exposed by the resnet module.
    inplanes = 2048
    low_level_planes = 256

    if name == 'deeplabv3plus':
        return_layers = {'layer4': 'out', 'layer1': 'low_level'}
        classifier = EnhancedDeepLabHeadV3Plus(
            inplanes, low_level_planes, num_classes, aspp_dilate,
            use_eoaNet=use_eoaNet, msa_scales=msa_scales, eog_beta=eog_beta)
    else:  # 'deeplabv3'
        return_layers = {'layer4': 'out'}
        classifier = EnhancedDeepLabHead(
            inplanes, num_classes, aspp_dilate,
            use_eoaNet=use_eoaNet, msa_scales=msa_scales, eog_beta=eog_beta)

    backbone = IntermediateLayerGetter(backbone, return_layers=return_layers)
    model = EnhancedDeepLabV3(backbone, classifier)
    return model
|
|
|
|
|
def _segm_xception(name, backbone_name, num_classes, output_stride, pretrained_backbone,
                   use_eoaNet=True, msa_scales=None, eog_beta=0.5):
    """Build an enhanced DeepLabV3 / DeepLabV3+ model on an Xception backbone.

    Args:
        name (str): head architecture, ``'deeplabv3'`` or ``'deeplabv3plus'``.
        backbone_name (str): accepted for signature parity with the other
            builders; not read here, ``xception.xception`` is always used.
        num_classes (int): number of classes, forwarded to the head.
        output_stride (int): ``8`` selects dilation in the last two entries
            with ASPP rates ``[12, 24, 36]``; any other value selects
            dilation in the last entry only with rates ``[6, 12, 18]``.
        pretrained_backbone: if truthy the backbone is requested with
            ``pretrained='imagenet'``, otherwise ``pretrained=False``.
        use_eoaNet (bool): forwarded to the enhanced head.
        msa_scales (list, optional): forwarded to the enhanced head.
            ``None`` (the default) selects ``[1, 2, 4]``.
        eog_beta (float): forwarded to the enhanced head.

    Returns:
        An ``EnhancedDeepLabV3`` built from the wrapped backbone and the head.

    Raises:
        NotImplementedError: if ``name`` is not a supported architecture.
    """
    # Fail fast: without this check an unknown name would only surface as an
    # unbound ``return_layers`` after the backbone had already been built.
    if name not in ('deeplabv3plus', 'deeplabv3'):
        raise NotImplementedError(
            "Unsupported architecture: {!r} "
            "(expected 'deeplabv3' or 'deeplabv3plus')".format(name))
    if msa_scales is None:
        # Fresh list per call; a list literal in the signature would be shared.
        msa_scales = [1, 2, 4]

    if output_stride == 8:
        replace_stride_with_dilation = [False, False, True, True]
        aspp_dilate = [12, 24, 36]
    else:
        replace_stride_with_dilation = [False, False, False, True]
        aspp_dilate = [6, 12, 18]

    backbone = xception.xception(
        pretrained='imagenet' if pretrained_backbone else False,
        replace_stride_with_dilation=replace_stride_with_dilation)

    inplanes = 2048
    low_level_planes = 128

    if name == 'deeplabv3plus':
        return_layers = {'conv4': 'out', 'block1': 'low_level'}
        classifier = EnhancedDeepLabHeadV3Plus(
            inplanes, low_level_planes, num_classes, aspp_dilate,
            use_eoaNet=use_eoaNet, msa_scales=msa_scales, eog_beta=eog_beta)
    else:  # 'deeplabv3'
        return_layers = {'conv4': 'out'}
        classifier = EnhancedDeepLabHead(
            inplanes, num_classes, aspp_dilate,
            use_eoaNet=use_eoaNet, msa_scales=msa_scales, eog_beta=eog_beta)

    backbone = IntermediateLayerGetter(backbone, return_layers=return_layers)
    model = EnhancedDeepLabV3(backbone, classifier)
    return model
|
|
|
|
|
def _segm_mobilenet(name, backbone_name, num_classes, output_stride, pretrained_backbone,
                    use_eoaNet=True, msa_scales=None, eog_beta=0.5):
    """Build an enhanced DeepLabV3 / DeepLabV3+ model on a MobileNetV2 backbone.

    Args:
        name (str): head architecture, ``'deeplabv3'`` or ``'deeplabv3plus'``.
        backbone_name (str): accepted for signature parity with the other
            builders; not read here, ``mobilenetv2.mobilenet_v2`` is always
            used.
        num_classes (int): number of classes, forwarded to the head.
        output_stride (int): forwarded to the backbone constructor. ``8``
            selects ASPP rates ``[12, 24, 36]``; any other value selects
            ``[6, 12, 18]``.
        pretrained_backbone: forwarded as ``pretrained`` to the backbone
            constructor.
        use_eoaNet (bool): forwarded to the enhanced head.
        msa_scales (list, optional): forwarded to the enhanced head.
            ``None`` (the default) selects ``[1, 2, 4]``.
        eog_beta (float): forwarded to the enhanced head.

    Returns:
        An ``EnhancedDeepLabV3`` built from the wrapped backbone and the head.

    Raises:
        NotImplementedError: if ``name`` is not a supported architecture.
    """
    # Fail fast: without this check an unknown name would only surface as an
    # unbound ``return_layers`` after the backbone had already been built.
    if name not in ('deeplabv3plus', 'deeplabv3'):
        raise NotImplementedError(
            "Unsupported architecture: {!r} "
            "(expected 'deeplabv3' or 'deeplabv3plus')".format(name))
    if msa_scales is None:
        # Fresh list per call; a list literal in the signature would be shared.
        msa_scales = [1, 2, 4]

    if output_stride == 8:
        aspp_dilate = [12, 24, 36]
    else:
        aspp_dilate = [6, 12, 18]

    backbone = mobilenetv2.mobilenet_v2(pretrained=pretrained_backbone, output_stride=output_stride)

    # Re-expose the feature stack as two named sub-modules so the layer getter
    # can tap them by name; the last entry of ``features`` is dropped, and the
    # original ``features`` / ``classifier`` attributes are cleared.
    backbone.low_level_features = backbone.features[0:4]
    backbone.high_level_features = backbone.features[4:-1]
    backbone.features = None
    backbone.classifier = None

    inplanes = 320
    low_level_planes = 24

    if name == 'deeplabv3plus':
        return_layers = {'high_level_features': 'out', 'low_level_features': 'low_level'}
        classifier = EnhancedDeepLabHeadV3Plus(
            inplanes, low_level_planes, num_classes, aspp_dilate,
            use_eoaNet=use_eoaNet, msa_scales=msa_scales, eog_beta=eog_beta)
    else:  # 'deeplabv3'
        return_layers = {'high_level_features': 'out'}
        classifier = EnhancedDeepLabHead(
            inplanes, num_classes, aspp_dilate,
            use_eoaNet=use_eoaNet, msa_scales=msa_scales, eog_beta=eog_beta)

    backbone = IntermediateLayerGetter(backbone, return_layers=return_layers)
    model = EnhancedDeepLabV3(backbone, classifier)
    return model
|
|
|
def _load_model(arch_type, backbone, num_classes, output_stride, pretrained_backbone, **kwargs):
    """Dispatch to the backbone-specific builder and return the model.

    Args:
        arch_type (str): head architecture name, forwarded to the builder.
        backbone (str): ``'mobilenetv2'``, ``'xception'``, or a name starting
            with ``'resnet'`` or ``'hrnetv2'``.
        num_classes (int): number of classes.
        output_stride (int): forwarded to every builder except the HRNet one,
            which does not take an output stride.
        pretrained_backbone: forwarded to the builder.
        **kwargs: optional ``use_eoaNet`` (default ``True``), ``msa_scales``
            (default ``[1, 2, 4]``) and ``eog_beta`` (default ``0.5``). Other
            keys are ignored.

    Returns:
        The model produced by the selected builder.

    Raises:
        NotImplementedError: if ``backbone`` matches no known family.
    """
    # Options shared by every builder; the list default is created per call.
    head_options = {
        'use_eoaNet': kwargs.get('use_eoaNet', True),
        'msa_scales': kwargs.get('msa_scales', [1, 2, 4]),
        'eog_beta': kwargs.get('eog_beta', 0.5),
    }

    if backbone == 'mobilenetv2':
        model = _segm_mobilenet(arch_type, backbone, num_classes, output_stride=output_stride,
                                pretrained_backbone=pretrained_backbone, **head_options)
    elif backbone.startswith('resnet'):
        model = _segm_resnet(arch_type, backbone, num_classes, output_stride=output_stride,
                             pretrained_backbone=pretrained_backbone, **head_options)
    elif backbone.startswith('hrnetv2'):
        model = _segm_hrnet(arch_type, backbone, num_classes,
                            pretrained_backbone=pretrained_backbone, **head_options)
    elif backbone == 'xception':
        model = _segm_xception(arch_type, backbone, num_classes, output_stride=output_stride,
                               pretrained_backbone=pretrained_backbone, **head_options)
    else:
        # Name the rejected value so a typo is obvious from the traceback.
        raise NotImplementedError("Unsupported backbone: {!r}".format(backbone))
    return model
|
|
|
|
|
|
|
def deeplabv3_hrnetv2_48(num_classes=21, output_stride=4, pretrained_backbone=False,
                         use_eoaNet=True, msa_scales=None, eog_beta=0.5):
    """Constructs a DeepLabV3 model with a HRNetV2-48 backbone.

    Args:
        num_classes (int): number of classes.
        output_stride (int): output stride for deeplab. Passed on to
            ``_load_model``, whose HRNet path does not use it.
        pretrained_backbone (bool): If True, use the pretrained backbone.
        use_eoaNet (bool): If True, use Entropy-Optimized Attention Network.
        msa_scales (list, optional): Scales for Multi-Scale Attention.
            ``None`` (the default) selects ``[1, 2, 4]``.
        eog_beta (float): Entropy threshold for Entropy-Optimized Gating.

    Returns:
        The model built by ``_load_model``.
    """
    if msa_scales is None:
        # Fresh list per call; a list literal in the signature would be shared.
        msa_scales = [1, 2, 4]
    return _load_model('deeplabv3', 'hrnetv2_48', num_classes, output_stride=output_stride,
                       pretrained_backbone=pretrained_backbone, use_eoaNet=use_eoaNet,
                       msa_scales=msa_scales, eog_beta=eog_beta)
|
|
|
def deeplabv3_hrnetv2_32(num_classes=21, output_stride=4, pretrained_backbone=True,
                         use_eoaNet=True, msa_scales=None, eog_beta=0.5):
    """Constructs a DeepLabV3 model with a HRNetV2-32 backbone.

    Args:
        num_classes (int): number of classes.
        output_stride (int): output stride for deeplab. Passed on to
            ``_load_model``, whose HRNet path does not use it.
        pretrained_backbone (bool): If True, use the pretrained backbone.
        use_eoaNet (bool): If True, use Entropy-Optimized Attention Network.
        msa_scales (list, optional): Scales for Multi-Scale Attention.
            ``None`` (the default) selects ``[1, 2, 4]``.
        eog_beta (float): Entropy threshold for Entropy-Optimized Gating.

    Returns:
        The model built by ``_load_model``.
    """
    if msa_scales is None:
        # Fresh list per call; a list literal in the signature would be shared.
        msa_scales = [1, 2, 4]
    return _load_model('deeplabv3', 'hrnetv2_32', num_classes, output_stride=output_stride,
                       pretrained_backbone=pretrained_backbone, use_eoaNet=use_eoaNet,
                       msa_scales=msa_scales, eog_beta=eog_beta)
|
|
|
def deeplabv3_resnet50(num_classes=21, output_stride=8, pretrained_backbone=True,
                       use_eoaNet=True, msa_scales=None, eog_beta=0.5):
    """Constructs a DeepLabV3 model with a ResNet-50 backbone.

    Args:
        num_classes (int): number of classes.
        output_stride (int): output stride for deeplab.
        pretrained_backbone (bool): If True, use the pretrained backbone.
        use_eoaNet (bool): If True, use Entropy-Optimized Attention Network.
        msa_scales (list, optional): Scales for Multi-Scale Attention.
            ``None`` (the default) selects ``[1, 2, 4]``.
        eog_beta (float): Entropy threshold for Entropy-Optimized Gating.

    Returns:
        The model built by ``_load_model``.
    """
    if msa_scales is None:
        # Fresh list per call; a list literal in the signature would be shared.
        msa_scales = [1, 2, 4]
    return _load_model('deeplabv3', 'resnet50', num_classes, output_stride=output_stride,
                       pretrained_backbone=pretrained_backbone, use_eoaNet=use_eoaNet,
                       msa_scales=msa_scales, eog_beta=eog_beta)
|
|
|
def deeplabv3_resnet101(num_classes=21, output_stride=8, pretrained_backbone=True,
                        use_eoaNet=True, msa_scales=None, eog_beta=0.5):
    """Constructs a DeepLabV3 model with a ResNet-101 backbone.

    Args:
        num_classes (int): number of classes.
        output_stride (int): output stride for deeplab.
        pretrained_backbone (bool): If True, use the pretrained backbone.
        use_eoaNet (bool): If True, use Entropy-Optimized Attention Network.
        msa_scales (list, optional): Scales for Multi-Scale Attention.
            ``None`` (the default) selects ``[1, 2, 4]``.
        eog_beta (float): Entropy threshold for Entropy-Optimized Gating.

    Returns:
        The model built by ``_load_model``.
    """
    if msa_scales is None:
        # Fresh list per call; a list literal in the signature would be shared.
        msa_scales = [1, 2, 4]
    return _load_model('deeplabv3', 'resnet101', num_classes, output_stride=output_stride,
                       pretrained_backbone=pretrained_backbone, use_eoaNet=use_eoaNet,
                       msa_scales=msa_scales, eog_beta=eog_beta)
|
|
|
def deeplabv3_mobilenet(num_classes=21, output_stride=8, pretrained_backbone=True,
                        use_eoaNet=True, msa_scales=None, eog_beta=0.5):
    """Constructs a DeepLabV3 model with a MobileNetv2 backbone.

    Args:
        num_classes (int): number of classes.
        output_stride (int): output stride for deeplab.
        pretrained_backbone (bool): If True, use the pretrained backbone.
        use_eoaNet (bool): If True, use Entropy-Optimized Attention Network.
        msa_scales (list, optional): Scales for Multi-Scale Attention.
            ``None`` (the default) selects ``[1, 2, 4]``.
        eog_beta (float): Entropy threshold for Entropy-Optimized Gating.

    Returns:
        The model built by ``_load_model``.
    """
    if msa_scales is None:
        # Fresh list per call; a list literal in the signature would be shared.
        msa_scales = [1, 2, 4]
    return _load_model('deeplabv3', 'mobilenetv2', num_classes, output_stride=output_stride,
                       pretrained_backbone=pretrained_backbone, use_eoaNet=use_eoaNet,
                       msa_scales=msa_scales, eog_beta=eog_beta)
|
|
|
def deeplabv3_xception(num_classes=21, output_stride=8, pretrained_backbone=True,
                       use_eoaNet=True, msa_scales=None, eog_beta=0.5):
    """Constructs a DeepLabV3 model with a Xception backbone.

    Args:
        num_classes (int): number of classes.
        output_stride (int): output stride for deeplab.
        pretrained_backbone (bool): If True, use the pretrained backbone.
        use_eoaNet (bool): If True, use Entropy-Optimized Attention Network.
        msa_scales (list, optional): Scales for Multi-Scale Attention.
            ``None`` (the default) selects ``[1, 2, 4]``.
        eog_beta (float): Entropy threshold for Entropy-Optimized Gating.

    Returns:
        The model built by ``_load_model``.
    """
    if msa_scales is None:
        # Fresh list per call; a list literal in the signature would be shared.
        msa_scales = [1, 2, 4]
    return _load_model('deeplabv3', 'xception', num_classes, output_stride=output_stride,
                       pretrained_backbone=pretrained_backbone, use_eoaNet=use_eoaNet,
                       msa_scales=msa_scales, eog_beta=eog_beta)
|
|
|
|
|
|
|
def deeplabv3plus_hrnetv2_48(num_classes=21, output_stride=4, pretrained_backbone=False,
                             use_eoaNet=True, msa_scales=None, eog_beta=0.5):
    """Constructs a DeepLabV3+ model with a HRNetV2-48 backbone.

    Args:
        num_classes (int): number of classes.
        output_stride (int): output stride for deeplab. Passed on to
            ``_load_model``, whose HRNet path does not use it.
        pretrained_backbone (bool): If True, use the pretrained backbone.
        use_eoaNet (bool): If True, use Entropy-Optimized Attention Network.
        msa_scales (list, optional): Scales for Multi-Scale Attention.
            ``None`` (the default) selects ``[1, 2, 4]``.
        eog_beta (float): Entropy threshold for Entropy-Optimized Gating.

    Returns:
        The model built by ``_load_model``.
    """
    if msa_scales is None:
        # Fresh list per call; a list literal in the signature would be shared.
        msa_scales = [1, 2, 4]
    return _load_model('deeplabv3plus', 'hrnetv2_48', num_classes, output_stride=output_stride,
                       pretrained_backbone=pretrained_backbone, use_eoaNet=use_eoaNet,
                       msa_scales=msa_scales, eog_beta=eog_beta)
|
|
|
def deeplabv3plus_hrnetv2_32(num_classes=21, output_stride=4, pretrained_backbone=True,
                             use_eoaNet=True, msa_scales=None, eog_beta=0.5):
    """Constructs a DeepLabV3+ model with a HRNetV2-32 backbone.

    Args:
        num_classes (int): number of classes.
        output_stride (int): output stride for deeplab. Passed on to
            ``_load_model``, whose HRNet path does not use it.
        pretrained_backbone (bool): If True, use the pretrained backbone.
        use_eoaNet (bool): If True, use Entropy-Optimized Attention Network.
        msa_scales (list, optional): Scales for Multi-Scale Attention.
            ``None`` (the default) selects ``[1, 2, 4]``.
        eog_beta (float): Entropy threshold for Entropy-Optimized Gating.

    Returns:
        The model built by ``_load_model``.
    """
    if msa_scales is None:
        # Fresh list per call; a list literal in the signature would be shared.
        msa_scales = [1, 2, 4]
    return _load_model('deeplabv3plus', 'hrnetv2_32', num_classes, output_stride=output_stride,
                       pretrained_backbone=pretrained_backbone, use_eoaNet=use_eoaNet,
                       msa_scales=msa_scales, eog_beta=eog_beta)
|
|
|
def deeplabv3plus_resnet50(num_classes=21, output_stride=8, pretrained_backbone=True,
                           use_eoaNet=True, msa_scales=None, eog_beta=0.5):
    """Constructs a DeepLabV3+ model with a ResNet-50 backbone.

    Args:
        num_classes (int): number of classes.
        output_stride (int): output stride for deeplab.
        pretrained_backbone (bool): If True, use the pretrained backbone.
        use_eoaNet (bool): If True, use Entropy-Optimized Attention Network.
        msa_scales (list, optional): Scales for Multi-Scale Attention.
            ``None`` (the default) selects ``[1, 2, 4]``.
        eog_beta (float): Entropy threshold for Entropy-Optimized Gating.

    Returns:
        The model built by ``_load_model``.
    """
    if msa_scales is None:
        # Fresh list per call; a list literal in the signature would be shared.
        msa_scales = [1, 2, 4]
    return _load_model('deeplabv3plus', 'resnet50', num_classes, output_stride=output_stride,
                       pretrained_backbone=pretrained_backbone, use_eoaNet=use_eoaNet,
                       msa_scales=msa_scales, eog_beta=eog_beta)
|
|
|
|
|
def deeplabv3plus_resnet101(num_classes=21, output_stride=8, pretrained_backbone=True,
                            use_eoaNet=True, msa_scales=None, eog_beta=0.5):
    """Constructs a DeepLabV3+ model with a ResNet-101 backbone.

    Args:
        num_classes (int): number of classes.
        output_stride (int): output stride for deeplab.
        pretrained_backbone (bool): If True, use the pretrained backbone.
        use_eoaNet (bool): If True, use Entropy-Optimized Attention Network.
        msa_scales (list, optional): Scales for Multi-Scale Attention.
            ``None`` (the default) selects ``[1, 2, 4]``.
        eog_beta (float): Entropy threshold for Entropy-Optimized Gating.

    Returns:
        The model built by ``_load_model``.
    """
    if msa_scales is None:
        # Fresh list per call; a list literal in the signature would be shared.
        msa_scales = [1, 2, 4]
    return _load_model('deeplabv3plus', 'resnet101', num_classes, output_stride=output_stride,
                       pretrained_backbone=pretrained_backbone, use_eoaNet=use_eoaNet,
                       msa_scales=msa_scales, eog_beta=eog_beta)
|
|
|
|
|
def deeplabv3plus_mobilenet(num_classes=21, output_stride=8, pretrained_backbone=True,
                            use_eoaNet=True, msa_scales=None, eog_beta=0.5):
    """Constructs a DeepLabV3+ model with a MobileNetv2 backbone.

    Args:
        num_classes (int): number of classes.
        output_stride (int): output stride for deeplab.
        pretrained_backbone (bool): If True, use the pretrained backbone.
        use_eoaNet (bool): If True, use Entropy-Optimized Attention Network.
        msa_scales (list, optional): Scales for Multi-Scale Attention.
            ``None`` (the default) selects ``[1, 2, 4]``.
        eog_beta (float): Entropy threshold for Entropy-Optimized Gating.

    Returns:
        The model built by ``_load_model``.
    """
    if msa_scales is None:
        # Fresh list per call; a list literal in the signature would be shared.
        msa_scales = [1, 2, 4]
    return _load_model('deeplabv3plus', 'mobilenetv2', num_classes, output_stride=output_stride,
                       pretrained_backbone=pretrained_backbone, use_eoaNet=use_eoaNet,
                       msa_scales=msa_scales, eog_beta=eog_beta)
|
|
|
def deeplabv3plus_xception(num_classes=21, output_stride=8, pretrained_backbone=True,
                           use_eoaNet=True, msa_scales=None, eog_beta=0.5):
    """Constructs a DeepLabV3+ model with a Xception backbone.

    Args:
        num_classes (int): number of classes.
        output_stride (int): output stride for deeplab.
        pretrained_backbone (bool): If True, use the pretrained backbone.
        use_eoaNet (bool): If True, use Entropy-Optimized Attention Network.
        msa_scales (list, optional): Scales for Multi-Scale Attention.
            ``None`` (the default) selects ``[1, 2, 4]``.
        eog_beta (float): Entropy threshold for Entropy-Optimized Gating.

    Returns:
        The model built by ``_load_model``.
    """
    if msa_scales is None:
        # Fresh list per call; a list literal in the signature would be shared.
        msa_scales = [1, 2, 4]
    return _load_model('deeplabv3plus', 'xception', num_classes, output_stride=output_stride,
                       pretrained_backbone=pretrained_backbone, use_eoaNet=use_eoaNet,
                       msa_scales=msa_scales, eog_beta=eog_beta)