3v324v23's picture
add
c310e19
raw
history blame
4.4 kB
# # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
# from collections import OrderedDict
# from torch import nn
# from . import fpn as fpn_module
# from . import resnet
# def build_resnet_backbone(cfg):
# body = resnet.ResNet(cfg)
# model = nn.Sequential(OrderedDict([("body", body)]))
# return model
# def build_resnet_fpn_backbone(cfg):
# body = resnet.ResNet(cfg)
# in_channels_stage2 = cfg.MODEL.RESNETS.RES2_OUT_CHANNELS
# out_channels = cfg.MODEL.BACKBONE.OUT_CHANNELS
# fpn = fpn_module.FPN(
# in_channels_list=[
# in_channels_stage2,
# in_channels_stage2 * 2,
# in_channels_stage2 * 4,
# in_channels_stage2 * 8,
# ],
# out_channels=out_channels,
# top_blocks=fpn_module.LastLevelMaxPool(),
# )
# model = nn.Sequential(OrderedDict([("body", body), ("fpn", fpn)]))
# return model
# _BACKBONES = {"resnet": build_resnet_backbone, "resnet-fpn": build_resnet_fpn_backbone}
# def build_backbone(cfg):
# assert cfg.MODEL.BACKBONE.CONV_BODY.startswith(
# "R-"
# ), "Only ResNet and ResNeXt models are currently implemented"
# # Models using FPN end with "-FPN"
# if cfg.MODEL.BACKBONE.CONV_BODY.endswith("-FPN"):
# return build_resnet_fpn_backbone(cfg)
# return build_resnet_backbone(cfg)
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
from collections import OrderedDict
from torch import nn
from maskrcnn_benchmark.modeling import registry
from maskrcnn_benchmark.modeling.make_layers import conv_with_kaiming_uniform
from . import fpn as fpn_module
# from . import resnet
@registry.BACKBONES.register("R-50-C4")
@registry.BACKBONES.register("R-50-C5")
@registry.BACKBONES.register("R-101-C4")
@registry.BACKBONES.register("R-101-C5")
def build_resnet_backbone(cfg):
    """Build a plain ResNet backbone (no FPN) from ``cfg``.

    Args:
        cfg: config node. It is passed to ``resnet.ResNet`` and read for
            ``MODEL.RESNETS.BACKBONE_OUT_CHANNELS``.

    Returns:
        An ``nn.Sequential`` holding a single ``"body"`` child, with an
        ``out_channels`` attribute set from the config.
    """
    # The module-level ``from . import resnet`` is commented out in this file,
    # so the name must be imported here; without it this call raises NameError.
    from . import resnet

    body = resnet.ResNet(cfg)
    model = nn.Sequential(OrderedDict([("body", body)]))
    model.out_channels = cfg.MODEL.RESNETS.BACKBONE_OUT_CHANNELS
    return model
@registry.BACKBONES.register("R-18-FPN")
@registry.BACKBONES.register("R-34-FPN")
@registry.BACKBONES.register("R-50-FPN")
@registry.BACKBONES.register("R-101-FPN")
@registry.BACKBONES.register("R-152-FPN")
def build_resnet_fpn_backbone(cfg):
    """Build a ResNet body followed by an FPN.

    The ResNet implementation is chosen at call time: when
    ``cfg.MODEL.RESNET34`` is truthy the ``resnet34`` module is used and
    constructed from ``cfg.MODEL.RESNETS.LAYERS``; otherwise the regular
    ``resnet`` module is used and constructed from ``cfg`` itself.

    Args:
        cfg: config node providing the ``MODEL.RESNET34``, ``MODEL.RESNETS``
            and ``MODEL.FPN`` settings read below.

    Returns:
        An ``nn.Sequential`` with ``"body"`` and ``"fpn"`` children and an
        ``out_channels`` attribute equal to
        ``cfg.MODEL.RESNETS.BACKBONE_OUT_CHANNELS``.
    """
    # Imports are local so only the selected ResNet variant is loaded.
    if cfg.MODEL.RESNET34:
        from . import resnet34

        trunk = resnet34.ResNet(layers=cfg.MODEL.RESNETS.LAYERS)
    else:
        from . import resnet

        trunk = resnet.ResNet(cfg)

    stage2_channels = cfg.MODEL.RESNETS.RES2_OUT_CHANNELS
    fpn_channels = cfg.MODEL.RESNETS.BACKBONE_OUT_CHANNELS

    # One FPN input per stage: the stage-2 width scaled by 1, 2, 4 and 8.
    stage_channels = [stage2_channels * scale for scale in (1, 2, 4, 8)]
    make_conv = conv_with_kaiming_uniform(
        cfg.MODEL.FPN.USE_GN, cfg.MODEL.FPN.USE_RELU
    )
    pyramid = fpn_module.FPN(
        in_channels_list=stage_channels,
        out_channels=fpn_channels,
        conv_block=make_conv,
        top_blocks=fpn_module.LastLevelMaxPool(),
    )

    layers = OrderedDict()
    layers["body"] = trunk
    layers["fpn"] = pyramid
    model = nn.Sequential(layers)
    model.out_channels = fpn_channels
    return model
@registry.BACKBONES.register("R-50-FPN-RETINANET")
@registry.BACKBONES.register("R-101-FPN-RETINANET")
def build_resnet_fpn_p3p7_backbone(cfg):
    """Build a ResNet body followed by an FPN topped with ``LastLevelP6P7``.

    Args:
        cfg: config node providing the ``MODEL.RESNETS``, ``MODEL.FPN`` and
            ``MODEL.RETINANET.USE_C5`` settings read below.

    Returns:
        An ``nn.Sequential`` with ``"body"`` and ``"fpn"`` children and an
        ``out_channels`` attribute equal to
        ``cfg.MODEL.RESNETS.BACKBONE_OUT_CHANNELS``.
    """
    # The module-level ``from . import resnet`` is commented out in this file,
    # so the name must be imported here; without it this call raises NameError.
    from . import resnet

    body = resnet.ResNet(cfg)
    in_channels_stage2 = cfg.MODEL.RESNETS.RES2_OUT_CHANNELS
    out_channels = cfg.MODEL.RESNETS.BACKBONE_OUT_CHANNELS

    # The P6/P7 block takes the last-stage width when USE_C5 is set,
    # otherwise the FPN output width.
    if cfg.MODEL.RETINANET.USE_C5:
        in_channels_p6p7 = in_channels_stage2 * 8
    else:
        in_channels_p6p7 = out_channels

    fpn = fpn_module.FPN(
        in_channels_list=[
            # NOTE(review): a 0 entry presumably tells FPN to skip the
            # stage-2 level -- confirm against fpn_module.FPN.
            0,
            in_channels_stage2 * 2,
            in_channels_stage2 * 4,
            in_channels_stage2 * 8,
        ],
        out_channels=out_channels,
        conv_block=conv_with_kaiming_uniform(
            cfg.MODEL.FPN.USE_GN, cfg.MODEL.FPN.USE_RELU
        ),
        top_blocks=fpn_module.LastLevelP6P7(in_channels_p6p7, out_channels),
    )
    model = nn.Sequential(OrderedDict([("body", body), ("fpn", fpn)]))
    model.out_channels = out_channels
    return model
def build_backbone(cfg):
    """Build the backbone named by ``cfg.MODEL.BACKBONE.CONV_BODY``.

    The name is looked up in ``registry.BACKBONES`` and the registered
    builder is called with ``cfg``.

    Args:
        cfg: config node; ``MODEL.BACKBONE.CONV_BODY`` selects the builder.

    Returns:
        Whatever the selected builder returns.

    Raises:
        AssertionError: if the name is not present in ``registry.BACKBONES``.
    """
    conv_body = cfg.MODEL.BACKBONE.CONV_BODY
    assert conv_body in registry.BACKBONES, (
        "cfg.MODEL.BACKBONE.CONV_BODY: {} are not registered in registry".format(
            conv_body
        )
    )
    builder = registry.BACKBONES[conv_body]
    return builder(cfg)