File size: 2,013 Bytes
ad16788
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
"""Import optimizer class dynamically."""
import argparse

from espnet.utils.dynamic_import import dynamic_import
from espnet.utils.fill_missing_args import fill_missing_args


class OptimizerFactoryInterface:
    """Optimizer adaptor.

    A subclass describes one optimizer: ``add_arguments`` declares its
    command-line options, ``from_args`` creates the optimizer from parsed
    options, and ``build`` offers the same creation from plain keyword
    arguments.
    """

    @staticmethod
    def from_args(target, args: argparse.Namespace):
        """Create the optimizer from an already-parsed argparse Namespace.

        Subclasses must override this hook; the base version always raises.

        Args:
            target: for pytorch `model.parameters()`,
                for chainer `model`
            args (argparse.Namespace): parsed command-line args

        Raises:
            NotImplementedError: always, in this base class.

        """
        raise NotImplementedError()

    @staticmethod
    def add_arguments(parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
        """Declare optimizer-specific options on ``parser``.

        The base class declares nothing and returns ``parser`` as given.
        """
        return parser

    @classmethod
    def build(cls, target, **kwargs):
        """Create the optimizer from Python keyword arguments.

        ``kwargs`` are packed into an ``argparse.Namespace``. That namespace
        is passed to ``fill_missing_args`` together with ``cls.add_arguments``
        (presumably to supply defaults for options absent from ``kwargs``),
        and the result is handed to ``cls.from_args``.

        Args:
            target: for pytorch `model.parameters()`,
                for chainer `model`

        Returns:
            new Optimizer

        """
        raw_namespace = argparse.Namespace(**kwargs)
        completed = fill_missing_args(raw_namespace, cls.add_arguments)
        return cls.from_args(target, completed)


def dynamic_import_optimizer(name: str, backend: str) -> OptimizerFactoryInterface:
    """Import optimizer class dynamically.

    ``name`` is resolved in one of two ways:

    * a registered alias (or any name without ``:``) is looked up in the
      ``OPTIMIZER_FACTORY_DICT`` of the requested backend package;
    * a ``module:class`` path is imported with ``dynamic_import`` and must
      name a subclass of ``OptimizerFactoryInterface``.

    Args:
        name (str): alias name or dynamic import syntax `module:class`
        backend (str): backend name e.g., chainer or pytorch

    Returns:
        OptimizerFactoryInterface or FunctionalOptimizerAdaptor

    Raises:
        NotImplementedError: if ``backend`` is neither "pytorch" nor "chainer".
        KeyError: if ``name`` has no ``:`` and is not a registered alias.
        TypeError: if a ``module:class`` path does not resolve to a subclass
            of ``OptimizerFactoryInterface``.

    """
    # Validate the backend first so an unsupported backend is always
    # reported, whichever form ``name`` takes.
    if backend == "pytorch":
        from espnet.optimizer.pytorch import OPTIMIZER_FACTORY_DICT
    elif backend == "chainer":
        from espnet.optimizer.chainer import OPTIMIZER_FACTORY_DICT
    else:
        raise NotImplementedError(f"unsupported backend: {backend}")

    # Registered aliases win; a plain name that is not registered still
    # raises KeyError from the dict lookup, exactly as before.
    if name in OPTIMIZER_FACTORY_DICT or ":" not in name:
        return OPTIMIZER_FACTORY_DICT[name]

    # `module:class` syntax: import the factory directly instead of going
    # through the alias table.
    factory_class = dynamic_import(name)
    # Explicit check rather than `assert`, which is stripped under `python -O`.
    # The isinstance guard keeps issubclass from raising on non-class objects.
    if not (
        isinstance(factory_class, type)
        and issubclass(factory_class, OptimizerFactoryInterface)
    ):
        raise TypeError(f"{name} is not a subclass of OptimizerFactoryInterface")
    return factory_class