Spaces:
Paused
Paused
# -*- coding: utf-8 -*-
# Copyright (c) Facebook, Inc. and its affiliates.
import functools | |
import inspect | |
def configurable(init_func=None, *, from_config=None):
    """
    Decorate a function or a class's __init__ method so that it can be called
    with a :class:`CfgNode` object using a :func:`from_config` function that translates
    :class:`CfgNode` to arguments.

    Examples:
    ::
        # Usage 1: Decorator on __init__:
        class A:
            @configurable
            def __init__(self, a, b=2, c=3):
                pass

            @classmethod
            def from_config(cls, cfg):   # 'cfg' must be the first argument
                # Returns kwargs to be passed to __init__
                return {"a": cfg.A, "b": cfg.B}

        a1 = A(a=1, b=2)  # regular construction
        a2 = A(cfg)       # construct with a cfg
        a3 = A(cfg, b=3, c=4)  # construct with extra overwrite

        # Usage 2: Decorator on any function. Needs an extra from_config argument:
        @configurable(from_config=lambda cfg: {"a": cfg.A, "b": cfg.B})
        def a_func(a, b=2, c=3):
            pass

        a1 = a_func(a=1, b=2)  # regular call
        a2 = a_func(cfg)       # call with a cfg
        a3 = a_func(cfg, b=3, c=4)  # call with extra overwrite

    Args:
        init_func (callable): a class's ``__init__`` method in usage 1. The
            class must have a ``from_config`` classmethod which takes `cfg` as
            the first argument.
        from_config (callable): the from_config function in usage 2. It must take `cfg`
            as its first argument.

    Returns:
        callable: in usage 1, the wrapped ``__init__``. In usage 2, a decorator
        that wraps the target function; the wrapped function exposes the
        supplied function as its ``from_config`` attribute. When called with
        no arguments at all, ``configurable`` itself is returned.

    Raises:
        AssertionError: if ``init_func`` is not a plain function named
            ``__init__``, if ``from_config`` is given together with
            ``init_func``, or if ``from_config`` is not a plain function.
    """
    if init_func is not None:
        assert (
            inspect.isfunction(init_func)
            and from_config is None
            and init_func.__name__ == "__init__"
        ), "Incorrect use of @configurable. Check API documentation for examples."

        # Preserve __name__/__doc__/__wrapped__ of the decorated __init__ so
        # that introspection and documentation tools still see the original.
        @functools.wraps(init_func)
        def wrapped(self, *args, **kwargs):
            try:
                from_config_func = type(self).from_config
            except AttributeError as e:
                raise AttributeError(
                    "Class with @configurable must have a 'from_config' classmethod."
                ) from e
            # Looked up on the class, a classmethod is a bound method; a plain
            # function or staticmethod is not, and is rejected here.
            if not inspect.ismethod(from_config_func):
                raise TypeError("Class with @configurable must have a 'from_config' classmethod.")

            if _called_with_cfg(*args, **kwargs):
                explicit_args = _get_args_from_config(from_config_func, *args, **kwargs)
                init_func(self, **explicit_args)
            else:
                init_func(self, *args, **kwargs)

        return wrapped

    else:
        if from_config is None:
            return configurable  # @configurable() is made equivalent to @configurable
        assert inspect.isfunction(
            from_config
        ), "from_config argument of configurable must be a function!"

        def wrapper(orig_func):
            # Keep the metadata of the decorated function on the wrapper.
            @functools.wraps(orig_func)
            def wrapped(*args, **kwargs):
                if _called_with_cfg(*args, **kwargs):
                    explicit_args = _get_args_from_config(from_config, *args, **kwargs)
                    return orig_func(**explicit_args)
                else:
                    return orig_func(*args, **kwargs)

            wrapped.from_config = from_config
            return wrapped

        return wrapper
def _called_with_cfg(*args, **kwargs):
    """
    Decide whether a call should be routed through ``from_config``.

    A call counts as "called with a cfg" when its first positional argument,
    or its ``cfg`` keyword argument, is a ``dict`` instance (dict subclasses
    included).

    Returns:
        bool: whether the arguments contain a config object and should be
            considered forwarded to from_config.
    """
    # `from_config`'s first argument is forced to be "cfg", so checking the
    # first positional argument and the "cfg" keyword covers all cases.
    if args and isinstance(args[0], dict):
        return True
    return isinstance(kwargs.get("cfg"), dict)
def _get_args_from_config(from_config_func, *args, **kwargs):
    """
    Use `from_config` to obtain explicit arguments.

    Args:
        from_config_func (callable): function or bound classmethod whose first
            parameter must be named ``cfg``.
        *args: positional arguments forwarded to ``from_config_func``.
        **kwargs: keyword arguments. Those that ``from_config_func`` accepts are
            forwarded to it; the remaining ones are merged into the returned
            dict unchanged, overriding keys produced by ``from_config_func``.

    Returns:
        dict: arguments to be used for cls.__init__

    Raises:
        TypeError: if the first parameter of ``from_config_func`` is not named
            ``cfg`` (including when it takes no parameters at all).
    """
    signature = inspect.signature(from_config_func)
    # `next(..., None)` instead of indexing: a parameterless from_config must
    # produce the descriptive TypeError below rather than a bare IndexError.
    first_param = next(iter(signature.parameters), None)
    if first_param != "cfg":
        if inspect.isfunction(from_config_func):
            name = from_config_func.__name__
        else:
            name = f"{from_config_func.__self__}.from_config"
        raise TypeError(f"{name} must take 'cfg' as the first argument!")

    support_var_arg = any(
        param.kind in (param.VAR_POSITIONAL, param.VAR_KEYWORD)
        for param in signature.parameters.values()
    )
    if support_var_arg:
        # forward all arguments to from_config, if from_config accepts them
        return from_config_func(*args, **kwargs)

    # forward supported arguments to from_config
    supported_arg_names = signature.parameters.keys()
    supported_kwargs = {k: v for k, v in kwargs.items() if k in supported_arg_names}
    extra_kwargs = {k: v for k, v in kwargs.items() if k not in supported_arg_names}
    ret = from_config_func(*args, **supported_kwargs)
    # forward the other arguments to __init__
    ret.update(extra_kwargs)
    return ret