File size: 2,183 Bytes
749745d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
from collections import OrderedDict
import torch
from torch import nn

from maskrcnn_benchmark.modeling import registry
from . import bert_model
from . import rnn_model
from . import clip_model
from . import word_utils
from . import roberta_fused_model
from . import roberta_fused_model_v2
from . import roberta_fused_model_tiny

@registry.LANGUAGE_BACKBONES.register("bert-base-uncased")
def build_bert_backbone(cfg):
    """Build the ``bert-base-uncased`` language backbone.

    Args:
        cfg: config node, forwarded unchanged to ``bert_model.BertEncoder``.

    Returns:
        nn.Sequential: a container whose only child is the encoder,
        registered under the name ``"body"`` (reachable as ``model.body``).
    """
    named_children = OrderedDict()
    named_children["body"] = bert_model.BertEncoder(cfg)
    return nn.Sequential(named_children)


@registry.LANGUAGE_BACKBONES.register("roberta-base")
def build_roberta_backbone(cfg):
    """Build the ``roberta-base`` language backbone.

    Like the ``bert-base-uncased`` entry, this wraps ``bert_model.BertEncoder``.
    Presumably the encoder selects the RoBERTa weights/tokenizer from ``cfg``
    (TODO confirm in bert_model).

    This function was previously also named ``build_bert_backbone``, which
    rebound the module-level name of the BERT builder. It now has its own
    name. Registration is keyed by the ``"roberta-base"`` string passed to
    ``register``, not by the function name, so registry lookups are unchanged.

    Args:
        cfg: config node, forwarded unchanged to ``bert_model.BertEncoder``.

    Returns:
        nn.Sequential: a container whose only child is the encoder,
        registered under the name ``"body"`` (reachable as ``model.body``).
    """
    body = bert_model.BertEncoder(cfg)
    model = nn.Sequential(OrderedDict([("body", body)]))
    return model


@registry.LANGUAGE_BACKBONES.register("rnn")
def build_rnn_backbone(cfg):
    """Build the ``rnn`` language backbone.

    Args:
        cfg: config node, forwarded unchanged to the RNN encoder constructor.

    Returns:
        nn.Sequential: a container whose only child is the encoder,
        registered under the name ``"body"`` (reachable as ``model.body``).
    """
    # NOTE(review): "RNNEnoder" is spelled exactly as the original referenced
    # it; presumably that is the real class name in rnn_model -- confirm
    # there before "fixing" the spelling here.
    encoder = rnn_model.RNNEnoder(cfg)
    return nn.Sequential(OrderedDict(body=encoder))


@registry.LANGUAGE_BACKBONES.register("clip")
def build_clip_backbone(cfg):
    """Build the ``clip`` language backbone.

    Args:
        cfg: config node, forwarded unchanged to
            ``clip_model.CLIPTransformer``.

    Returns:
        nn.Sequential: a container whose only child is the transformer,
        registered under the name ``"body"`` (reachable as ``model.body``).
    """
    transformer = clip_model.CLIPTransformer(cfg)
    children = OrderedDict((("body", transformer),))
    wrapped = nn.Sequential(children)
    return wrapped


@registry.LANGUAGE_BACKBONES.register("roberta-fused")
def build_roberta_fused_backbone(cfg):
    """Build the ``roberta-fused`` language backbone.

    This function was previously named ``build_clip_backbone``, a copy-paste
    leftover that rebound the module-level name of the real CLIP builder. It
    now has its own name. Registration is keyed by the ``"roberta-fused"``
    string passed to ``register``, so registry lookups are unchanged.

    Args:
        cfg: config node, forwarded unchanged to
            ``roberta_fused_model.RobertaFusedEncoder``.

    Returns:
        nn.Sequential: a container whose only child is the encoder,
        registered under the name ``"body"`` (reachable as ``model.body``).
    """
    body = roberta_fused_model.RobertaFusedEncoder(cfg)
    model = nn.Sequential(OrderedDict([("body", body)]))
    return model


@registry.LANGUAGE_BACKBONES.register("roberta-fused-v2")
def build_roberta_fused_v2_backbone(cfg):
    """Build the ``roberta-fused-v2`` language backbone.

    This function was previously named ``build_clip_backbone``, a copy-paste
    leftover that rebound the module-level name of the real CLIP builder. It
    now has its own name. Registration is keyed by the ``"roberta-fused-v2"``
    string passed to ``register``, so registry lookups are unchanged.

    Args:
        cfg: config node, forwarded unchanged to
            ``roberta_fused_model_v2.RobertaFusedEncoder``.

    Returns:
        nn.Sequential: a container whose only child is the encoder,
        registered under the name ``"body"`` (reachable as ``model.body``).
    """
    body = roberta_fused_model_v2.RobertaFusedEncoder(cfg)
    model = nn.Sequential(OrderedDict([("body", body)]))
    return model

@registry.LANGUAGE_BACKBONES.register("roberta-fused-tiny")
def build_roberta_fused_tiny_backbone(cfg):
    """Build the ``roberta-fused-tiny`` language backbone.

    This function was previously named ``build_clip_backbone``. As the last
    such definition in the file, it left the module-level name
    ``build_clip_backbone`` pointing at this tiny-RoBERTa builder rather than
    the CLIP one. It now has its own name. Registration is keyed by the
    ``"roberta-fused-tiny"`` string passed to ``register``, so registry
    lookups are unchanged.

    Args:
        cfg: config node, forwarded unchanged to
            ``roberta_fused_model_tiny.RobertaFusedEncoder``.

    Returns:
        nn.Sequential: a container whose only child is the encoder,
        registered under the name ``"body"`` (reachable as ``model.body``).
    """
    body = roberta_fused_model_tiny.RobertaFusedEncoder(cfg)
    model = nn.Sequential(OrderedDict([("body", body)]))
    return model

def build_backbone(cfg):
    """Instantiate the language backbone selected by the config.

    Looks up ``cfg.MODEL.LANGUAGE_BACKBONE.MODEL_TYPE`` in
    ``registry.LANGUAGE_BACKBONES`` and calls the registered builder with
    ``cfg``.

    Args:
        cfg: config node; must provide ``cfg.MODEL.LANGUAGE_BACKBONE.MODEL_TYPE``.

    Returns:
        Whatever the registered builder returns. For example, the builders
        registered in this file wrap their encoder in an ``nn.Sequential``
        under the child name ``"body"``.

    Raises:
        AssertionError: if ``MODEL_TYPE`` is not a registered backbone name.
    """
    model_type = cfg.MODEL.LANGUAGE_BACKBONE.MODEL_TYPE
    if model_type not in registry.LANGUAGE_BACKBONES:
        # Raised explicitly instead of via `assert`, so the check and its
        # message survive `python -O`. AssertionError keeps the exception
        # type the original assert produced.
        raise AssertionError(
            "cfg.MODEL.LANGUAGE_BACKBONE.MODEL_TYPE: {} is not registered in "
            "registry (available: {})".format(
                model_type,
                # assumes the registry iterates over its string keys
                # (dict-like) -- TODO confirm in maskrcnn_benchmark.utils.registry
                ", ".join(sorted(registry.LANGUAGE_BACKBONES)),
            )
        )
    return registry.LANGUAGE_BACKBONES[model_type](cfg)