"""Language-backbone builders registered in ``registry.LANGUAGE_BACKBONES``.

Each builder wraps a text encoder in an ``nn.Sequential`` whose single child is
named ``"body"``. ``build_backbone`` dispatches to the builder registered under
``cfg.MODEL.LANGUAGE_BACKBONE.MODEL_TYPE``.
"""
from collections import OrderedDict

import torch
from torch import nn

from maskrcnn_benchmark.modeling import registry
from . import bert_model
from . import rnn_model
from . import clip_model
from . import word_utils
from . import roberta_fused_model
from . import roberta_fused_model_v2
from . import roberta_fused_model_tiny


def _wrap_body(body):
    """Wrap ``body`` in an ``nn.Sequential`` with a single child named ``"body"``."""
    return nn.Sequential(OrderedDict([("body", body)]))


# Each builder below has a distinct name. Previously several builders shared a
# name, so later definitions silently rebound the module-level name (e.g.
# ``build_clip_backbone`` ended up referring to the roberta-fused-tiny builder).
# Registration itself was unaffected because the decorator runs before rebinding.


@registry.LANGUAGE_BACKBONES.register("bert-base-uncased")
def build_bert_backbone(cfg):
    """Build the ``"bert-base-uncased"`` backbone: ``bert_model.BertEncoder(cfg)`` wrapped as ``body``."""
    return _wrap_body(bert_model.BertEncoder(cfg))


@registry.LANGUAGE_BACKBONES.register("roberta-base")
def build_roberta_backbone(cfg):
    """Build the ``"roberta-base"`` backbone: ``bert_model.BertEncoder(cfg)`` wrapped as ``body``.

    NOTE(review): this reuses ``BertEncoder``; presumably the encoder chooses the
    RoBERTa weights/tokenizer from ``cfg`` — confirm in ``bert_model``.
    """
    return _wrap_body(bert_model.BertEncoder(cfg))


@registry.LANGUAGE_BACKBONES.register("rnn")
def build_rnn_backbone(cfg):
    """Build the ``"rnn"`` backbone: ``rnn_model.RNNEnoder(cfg)`` wrapped as ``body``."""
    # NOTE: ``RNNEnoder`` is spelled as defined in ``rnn_model``; do not "fix" it here alone.
    return _wrap_body(rnn_model.RNNEnoder(cfg))


@registry.LANGUAGE_BACKBONES.register("clip")
def build_clip_backbone(cfg):
    """Build the ``"clip"`` backbone: ``clip_model.CLIPTransformer(cfg)`` wrapped as ``body``."""
    return _wrap_body(clip_model.CLIPTransformer(cfg))


@registry.LANGUAGE_BACKBONES.register("roberta-fused")
def build_roberta_fused_backbone(cfg):
    """Build the ``"roberta-fused"`` backbone from ``roberta_fused_model.RobertaFusedEncoder``."""
    return _wrap_body(roberta_fused_model.RobertaFusedEncoder(cfg))


@registry.LANGUAGE_BACKBONES.register("roberta-fused-v2")
def build_roberta_fused_v2_backbone(cfg):
    """Build the ``"roberta-fused-v2"`` backbone from ``roberta_fused_model_v2.RobertaFusedEncoder``."""
    return _wrap_body(roberta_fused_model_v2.RobertaFusedEncoder(cfg))


@registry.LANGUAGE_BACKBONES.register("roberta-fused-tiny")
def build_roberta_fused_tiny_backbone(cfg):
    """Build the ``"roberta-fused-tiny"`` backbone from ``roberta_fused_model_tiny.RobertaFusedEncoder``."""
    return _wrap_body(roberta_fused_model_tiny.RobertaFusedEncoder(cfg))


def build_backbone(cfg):
    """Build the language backbone selected by ``cfg.MODEL.LANGUAGE_BACKBONE.MODEL_TYPE``.

    Args:
        cfg: config node; ``cfg.MODEL.LANGUAGE_BACKBONE.MODEL_TYPE`` must be a
            key registered in ``registry.LANGUAGE_BACKBONES``.

    Returns:
        Whatever the registered builder returns (for the builders in this
        module, an ``nn.Sequential`` with a single ``"body"`` child).

    Raises:
        AssertionError: if the model type has no registered builder.
    """
    model_type = cfg.MODEL.LANGUAGE_BACKBONE.MODEL_TYPE
    # Message names the key actually read (MODEL_TYPE, not TYPE).
    assert model_type in registry.LANGUAGE_BACKBONES, (
        "cfg.MODEL.LANGUAGE_BACKBONE.MODEL_TYPE: {} is not registered in registry".format(
            model_type
        )
    )
    return registry.LANGUAGE_BACKBONES[model_type](cfg)