"""Module defining encoders."""
import os
import importlib
from onmt.encoders.encoder import EncoderBase
from onmt.encoders.transformer import TransformerEncoder
from onmt.encoders.ggnn_encoder import GGNNEncoder
from onmt.encoders.rnn_encoder import RNNEncoder
from onmt.encoders.cnn_encoder import CNNEncoder
from onmt.encoders.mean_encoder import MeanEncoder


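# Map each supported encoder type name to its class. "rnn" and "brnn" share
# RNNEncoder. Encoders added with @register_encoder are inserted here as well.
str2enc = {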
    "ggnn": GGNNEncoder,
    "rnn": RNNEncoder,
    "brnn": RNNEncoder,
    "cnn": CNNEncoder,
    "transformer": TransformerEncoder,
    "mean": MeanEncoder,
}

__all__ = [
    "EncoderBase",
    "TransformerEncoder",
    "GGNNEncoder",
    "RNNEncoder",
    "CNNEncoder",
    "MeanEncoder",
    "str2enc",
]


def get_encoders_cls(encoder_names):
    """Return a dict mapping each name in `encoder_names` to its encoder class.

    Raises ValueError if a name is not a key of `str2enc`.
    """
    encoders_cls = {}
    for name in encoder_names:
        if name not in str2enc:
            raise ValueError("%s encoder not supported!" % name)
        encoders_cls[name] = str2enc[name]
    return encoders_cls
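

# Usage sketch (illustrative only; the names passed in are arbitrary examples,
# and "lstm" is assumed not to have been added by a registered plug-in):
#
#     get_encoders_cls(["transformer", "brnn"])
#     # -> {"transformer": TransformerEncoder, "brnn": RNNEncoder}
#
#     get_encoders_cls(["lstm"])
#     # -> ValueError: lstm encoder not supported!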


def register_encoder(name):
    """Return a class decorator that registers an encoder class under `name`."""

    def register_encoder_cls(cls):
        if name in str2enc:
            raise ValueError("Cannot register duplicate encoder ({})".format(name))
        if not issubclass(cls, EncoderBase):
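            # `__name__` (two trailing underscores); the previous `__name_` typo
            # raised AttributeError here instead of the intended ValueError.
            raise ValueError(f"encoder ({name}: {cls.__name__}) must extend EncoderBase")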
        str2enc[name] = cls
        __all__.append(cls.__name__)  # list the new class name in __all__ too
        return cls

    return register_encoder_cls
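

# Usage sketch for adding a custom encoder. The file name "my_encoder.py", the
# key "my_enc" and the class MyEncoder are hypothetical, for illustration only:
#
#     # onmt/encoders/my_encoder.py
#     from onmt.encoders import register_encoder
#     from onmt.encoders.encoder import EncoderBase
#
#     @register_encoder("my_enc")
#     class MyEncoder(EncoderBase):
#         ...  # implement the encoder interface defined by EncoderBase
#
# Because the loop at the bottom of this file imports every public module in
# this directory, placing such a file here is enough for "my_enc" to appear in
# str2enc. Registering a key that already exists (e.g. "rnn") raises
# ValueError, and a class that does not subclass EncoderBase is rejected.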


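# Import every public module and subpackage in this directory (names starting
# with "_" or "." are skipped) so that encoders decorated with
# @register_encoder are added to str2enc when this package is imported.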
encoder_dir = os.path.dirname(__file__)
for file in os.listdir(encoder_dir):
    path = os.path.join(encoder_dir, file)
    if (
        not file.startswith("_")
        and not file.startswith(".")
        and (file.endswith(".py") or os.path.isdir(path))
    ):
        file_name = file[: file.find(".py")] if file.endswith(".py") else file
        module = importlib.import_module("onmt.encoders." + file_name)