File size: 3,019 Bytes
ad16788
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
from typing import Mapping
from typing import Optional
from typing import Tuple

from typeguard import check_argument_types
from typeguard import check_return_type

from espnet2.utils.nested_dict_action import NestedDictAction
from espnet2.utils.types import str_or_none


class ClassChoices:
    """Helper class to manage the options for variable objects and its configuration.

    A ``ClassChoices`` maps a case-insensitive choice name to a class and can
    register two command-line options on a parser: ``--<name>`` selects the
    class and ``--<name>_conf`` carries keyword arguments for it.

    Example:

    >>> class A:
    ...     def __init__(self, foo=3):  pass
    >>> class B:
    ...     def __init__(self, bar="aaaa"):  pass
    >>> choices = ClassChoices("var", dict(a=A, b=B), default="a")
    >>> import argparse
    >>> parser = argparse.ArgumentParser()
    >>> choices.add_arguments(parser)
    >>> args = parser.parse_args(["--var", "a", "--var_conf", "foo=4"])
    >>> args.var
    'a'
    >>> args.var_conf
    {'foo': 4}
    >>> class_obj = choices.get_class(args.var)
    >>> a_object = class_obj(**args.var_conf)

    """

    # Spellings that stand for "no class"; they may not be used as choice keys.
    _RESERVED = ("none", "null", "nil")

    def __init__(
        self,
        name: str,
        classes: Mapping[str, type],
        type_check: Optional[type] = None,
        default: Optional[str] = None,
        optional: bool = False,
    ):
        """Build the choice table.

        Args:
            name: Option name, used as ``--<name>`` and ``--<name>_conf``.
            classes: Mapping from choice name to class. Keys are lower-cased.
            type_check: If given, every class must be a subclass of it.
            default: Default choice name. ``None`` forces ``optional=True``.
            optional: Whether ``None`` is an accepted choice.

        Raises:
            ValueError: If a (lower-cased) key is "none", "nil" or "null", or
                if a class is not a subclass of ``type_check``.
        """
        assert check_argument_types()
        self.name = name
        self.base_type = type_check
        # Normalize keys to lower case so lookups are case-insensitive.
        self.classes = {k.lower(): v for k, v in classes.items()}
        if any(reserved in self.classes for reserved in self._RESERVED):
            raise ValueError('"none", "nil", and "null" are reserved.')
        if type_check is not None:
            for v in self.classes.values():
                if not issubclass(v, type_check):
                    raise ValueError(f"must be {type_check.__name__}, but got {v}")

        # Without a default there must be a way to select "nothing".
        self.optional = optional or default is None
        self.default = default

    def choices(self) -> Tuple[Optional[str], ...]:
        """Return the accepted choice names, plus ``None`` when optional."""
        retval = tuple(self.classes)
        if self.optional:
            return retval + (None,)
        return retval

    def get_class(self, name: Optional[str]) -> Optional[type]:
        """Return the class registered under ``name`` (case-insensitive).

        Args:
            name: Choice name. ``None`` always yields ``None``. The reserved
                spellings "none", "null" and "nil" (any case) yield ``None``
                only when this choice is optional.

        Returns:
            The registered class, or ``None``.

        Raises:
            ValueError: If ``name`` is not a registered choice.
        """
        assert check_argument_types()
        if name is None:
            return None
        key = name.lower()
        # Membership test: ``key`` may be any one of the reserved spellings.
        if self.optional and key in self._RESERVED:
            return None
        if key in self.classes:
            # Index with the normalized key; stored keys are all lower case.
            class_obj = self.classes[key]
            assert check_return_type(class_obj)
            return class_obj
        raise ValueError(
            f"--{self.name} must be one of {self.choices()}: "
            f"--{self.name} {name.lower()}"
        )

    def add_arguments(self, parser):
        """Register ``--<name>`` and ``--<name>_conf`` on ``parser``.

        Args:
            parser: Object with an argparse-style ``add_argument`` method,
                e.g. ``argparse.ArgumentParser``.
        """
        parser.add_argument(
            f"--{self.name}",
            # Lower-case so the value matches the normalized keys. str_or_none
            # presumably maps "none"-like strings to None; confirm in
            # espnet2.utils.types.
            type=lambda x: str_or_none(x.lower()),
            default=self.default,
            choices=self.choices(),
            help=f"The {self.name} type",
        )
        parser.add_argument(
            f"--{self.name}_conf",
            action=NestedDictAction,
            # NOTE(review): argparse reuses this default object across
            # parse_args calls; confirm NestedDictAction does not mutate it.
            default=dict(),
            help=f"The keyword arguments for {self.name}",
        )