Spaces:
Runtime error
Runtime error
| # Copyright (c) Facebook, Inc. and its affiliates. | |
| # | |
| # This source code is licensed under the MIT license found in the | |
| # LICENSE file in the root directory of this source tree. | |
| import ast | |
| import inspect | |
| import logging | |
| import os | |
| import re | |
| from argparse import ArgumentError, ArgumentParser, Namespace | |
| from dataclasses import _MISSING_TYPE, MISSING, is_dataclass | |
| from enum import Enum | |
| from typing import Any, Dict, List, Optional, Tuple, Type | |
| from fairseq.dataclass import FairseqDataclass | |
| from fairseq.dataclass.configs import FairseqConfig | |
| from hydra.core.global_hydra import GlobalHydra | |
| from hydra.experimental import compose, initialize | |
| from omegaconf import DictConfig, OmegaConf, open_dict, _utils | |
| logger = logging.getLogger(__name__) | |
def eval_str_list(x, x_type=float):
    """Turn ``x`` into a list whose items are converted with ``x_type``.

    Args:
        x: ``None``, a string holding a Python literal (for example
            ``"[1, 2]"`` or ``"1,2"``), an iterable, or a single value.
        x_type: callable applied to every item (defaults to ``float``).

    Returns:
        ``None`` when ``x`` is ``None``; an empty list for an empty string;
        otherwise a list of converted items. When converting item by item
        raises ``TypeError`` (for instance because the value is not
        iterable), the whole value is converted and wrapped in a
        one-element list.
    """
    if x is None:
        return None

    if isinstance(x, str):
        if not x:
            return []
        # Strings are parsed as Python literals, never evaluated as code.
        x = ast.literal_eval(x)

    try:
        return [x_type(item) for item in x]
    except TypeError:
        return [x_type(x)]
def interpret_dc_type(field_type):
    """Reduce a dataclass field annotation to the type used for parsing.

    Args:
        field_type: the annotation object of a dataclass field. It must be a
            real type object, not a string annotation.

    Returns:
        ``str`` for ``typing.Any``; the first type argument for an optional
        annotation (``Optional[X]``, ``Union[X, ..., None]`` or, where the
        interpreter supports it, ``X | None``); otherwise ``field_type``
        unchanged.

    Raises:
        RuntimeError: if ``field_type`` is a string.
    """
    import types

    if isinstance(field_type, str):
        raise RuntimeError("field should be a type")

    if field_type == Any:
        return str

    typestring = str(field_type)
    if re.match(
        r"(typing.|^)Union\[(.*), NoneType\]$", typestring
    ) or typestring.startswith("typing.Optional"):
        return field_type.__args__[0]

    # PEP 604 unions (``X | None``) are not printed in the ``typing.Union[...]``
    # form matched above, so recognise them by type instead of by string.
    # ``types.UnionType`` is absent on older interpreters, hence the getattr.
    union_type = getattr(types, "UnionType", None)
    if union_type is not None and isinstance(field_type, union_type):
        members = field_type.__args__
        if members[-1] is type(None):
            return members[0]

    return field_type
def gen_parser_from_dataclass(
    parser: ArgumentParser,
    dataclass_instance: FairseqDataclass,
    delete_default: bool = False,
    with_prefix: Optional[str] = None,
) -> None:
    """
    Register the attributes of a dataclass instance as arguments on ``parser``.

    Attributes whose type is a FairseqDataclass subclass are expanded
    recursively by instantiating that type and adding its attributes too.

    If `with_prefix` is provided, all the keys in the resulting parser are
    prefixed with it, i.e. a flat namespace is built from a structured
    dataclass (see transformer_config.py for example).

    Args:
        parser: parser that receives the arguments.
        dataclass_instance: object queried through its ``_get_*`` accessors.
        delete_default: when true, the ``default`` entry is dropped before
            the argument is added.
        with_prefix: optional prefix for generated option names.

    Arguments for which ``parser.add_argument`` raises ``ArgumentError`` are
    skipped silently.
    """
    has_prefix = with_prefix is not None and with_prefix != ""

    def argparse_name(name: str):
        """Return the option string for ``name`` or None to skip the field."""
        if name == "_name":
            # private member, skip
            return None
        if name == "data" and not has_prefix:
            # normally data is positional args, so we don't add the -- nor the prefix
            return name
        dashed = name.replace("_", "-")
        if has_prefix:
            # the prefix replaces the leading "--" of the plain option name
            return with_prefix + "-" + dashed
        return "--" + dashed

    def get_kwargs_from_dc(
        dataclass_instance: FairseqDataclass, k: str
    ) -> Dict[str, Any]:
        """Build the ``add_argument`` keyword arguments for attribute ``k``."""
        inter_type = interpret_dc_type(dataclass_instance._get_type(k))
        field_default = dataclass_instance._get_default(k)

        is_enum_class = isinstance(inter_type, type) and issubclass(inter_type, Enum)
        field_choices = (
            [member.value for member in inter_type] if is_enum_class else None
        )

        field_help = dataclass_instance._get_help(k)
        field_const = dataclass_instance._get_argparse_const(k)

        has_default = field_default is not MISSING
        kwargs = {}

        if isinstance(field_default, str) and field_default.startswith("${"):
            # interpolation string: passed through untouched
            kwargs["default"] = field_default
        else:
            if not has_default:
                kwargs["required"] = True
            if field_choices is not None:
                kwargs["choices"] = field_choices

            type_repr = str(inter_type)
            is_sequence = (
                isinstance(inter_type, type)
                and (issubclass(inter_type, List) or issubclass(inter_type, Tuple))
            ) or ("List" in type_repr or "Tuple" in type_repr)

            if is_sequence:
                # the element type is picked from the annotation's string form
                if "int" in type_repr:
                    kwargs["type"] = lambda x: eval_str_list(x, int)
                elif "float" in type_repr:
                    kwargs["type"] = lambda x: eval_str_list(x, float)
                elif "str" in type_repr:
                    kwargs["type"] = lambda x: eval_str_list(x, str)
                else:
                    raise NotImplementedError(
                        "parsing of type " + str(inter_type) + " is not implemented"
                    )
                if has_default:
                    # sequence defaults are stored as a comma-joined string
                    kwargs["default"] = (
                        None
                        if field_default is None
                        else ",".join(map(str, field_default))
                    )
            elif is_enum_class or "Enum" in type_repr:
                kwargs["type"] = str
                if has_default:
                    kwargs["default"] = (
                        field_default.value
                        if isinstance(field_default, Enum)
                        else field_default
                    )
            elif inter_type is bool:
                # a flag flips the default: True -> store_false, else store_true
                kwargs["action"] = (
                    "store_false" if field_default is True else "store_true"
                )
                kwargs["default"] = field_default
            else:
                kwargs["type"] = inter_type
                if has_default:
                    kwargs["default"] = field_default

        # build the help with the hierarchical prefix
        if has_prefix and field_help is not None:
            field_help = with_prefix[2:] + ": " + field_help

        kwargs["help"] = field_help
        if field_const is not None:
            kwargs["const"] = field_const
            kwargs["nargs"] = "?"

        return kwargs

    for k in dataclass_instance._get_all_attributes():
        field_name = argparse_name(dataclass_instance._get_name(k))
        field_type = dataclass_instance._get_type(k)
        if field_name is None:
            continue

        if inspect.isclass(field_type) and issubclass(field_type, FairseqDataclass):
            # for fields that are of type FairseqDataclass, we can recursively
            # add their fields to the namespace (so we add the args from model, task, etc. to the root namespace).
            # if a prefix is specified, the subfields are not copied directly to the root namespace
            # but are prefixed with the name of the current field.
            prefix = field_name if with_prefix is not None else None
            gen_parser_from_dataclass(parser, field_type(), delete_default, prefix)
            continue

        kwargs = get_kwargs_from_dc(dataclass_instance, k)

        field_args = [field_name]
        alias = dataclass_instance._get_argparse_alias(k)
        if alias is not None:
            field_args.append(alias)

        if "default" in kwargs:
            default_value = kwargs["default"]
            if isinstance(default_value, str) and default_value.startswith("${"):
                if kwargs["help"] is None:
                    # this is a field with a name that will be added elsewhere
                    continue
                del kwargs["default"]
            if delete_default:
                kwargs.pop("default", None)

        try:
            parser.add_argument(*field_args, **kwargs)
        except ArgumentError:
            pass
| def _set_legacy_defaults(args, cls): | |
| """Helper to set default arguments based on *add_args*.""" | |
| if not hasattr(cls, "add_args"): | |
| return | |
| import argparse | |
| parser = argparse.ArgumentParser( | |
| argument_default=argparse.SUPPRESS, allow_abbrev=False | |
| ) | |
| cls.add_args(parser) | |
| # copied from argparse.py: | |
| defaults = argparse.Namespace() | |
| for action in parser._actions: | |
| if action.dest is not argparse.SUPPRESS: | |
| if not hasattr(defaults, action.dest): | |
| if action.default is not argparse.SUPPRESS: | |
| setattr(defaults, action.dest, action.default) | |
| for key, default_value in vars(defaults).items(): | |
| if not hasattr(args, key): | |
| setattr(args, key, default_value) | |
def _override_attr(
    sub_node: str, data_class: Type[FairseqDataclass], args: Namespace
) -> List[str]:
    """Build ``<sub_node>.<field>=<value>`` override strings for ``data_class``.

    For every field of ``data_class`` whose name does not start with ``_``,
    the value is read from ``args`` when ``args`` has an attribute of that
    name, and otherwise from the field's default (calling ``default_factory``
    when one is set). The value is then normalised and rendered as a string.

    Args:
        sub_node: dotted path used as the prefix of each override.
        data_class: dataclass type to walk. Anything that is not a
            FairseqDataclass subclass yields an empty list.
        args: flat namespace providing the values.

    Returns:
        The list of override strings, in field order.
    """
    overrides = []

    if not inspect.isclass(data_class) or not issubclass(data_class, FairseqDataclass):
        return overrides

    def get_default(f):
        """Return the field default, preferring ``default_factory`` when set."""
        if not isinstance(f.default_factory, _MISSING_TYPE):
            return f.default_factory()
        return f.default

    for k, v in data_class.__dataclass_fields__.items():
        if k.startswith("_"):
            # private member, skip
            continue

        val = get_default(v) if not hasattr(args, k) else getattr(args, k)

        field_type = interpret_dc_type(v.type)
        if (
            isinstance(val, str)
            and not val.startswith("${")  # not interpolation
            and field_type != str
            and (
                not inspect.isclass(field_type) or not issubclass(field_type, Enum)
            )  # not choices enum
        ):
            # upgrade old models that stored complex parameters as string
            val = ast.literal_eval(val)

        if isinstance(val, tuple):
            val = list(val)

        v_type = getattr(v.type, "__origin__", None)
        if (
            (v_type is List or v_type is list or v_type is Optional)
            # skip interpolation
            and not (isinstance(val, str) and val.startswith("${"))
        ):
            # if type is int but val is float, then we will crash later - try to convert here
            if hasattr(v.type, "__args__"):
                t_args = v.type.__args__
                if len(t_args) == 1 and (t_args[0] is float or t_args[0] is int):
                    val = list(map(t_args[0], val))
        elif val is not None and (
            field_type is int or field_type is bool or field_type is float
        ):
            try:
                val = field_type(val)
            except Exception:
                # Best-effort conversion: errors here are often from
                # interpolation args, so the value is kept as-is. Only
                # ordinary exceptions are swallowed, so KeyboardInterrupt
                # and SystemExit still propagate.
                pass

        if val is None:
            overrides.append("{}.{}=null".format(sub_node, k))
        elif val == "":
            overrides.append("{}.{}=''".format(sub_node, k))
        elif isinstance(val, str):
            # escape single quotes because the value is wrapped in them
            val = val.replace("'", r"\'")
            overrides.append("{}.{}='{}'".format(sub_node, k, val))
        elif isinstance(val, FairseqDataclass):
            # nested dataclass: recurse with an extended dotted path
            overrides += _override_attr(f"{sub_node}.{k}", type(val), args)
        elif isinstance(val, Namespace):
            # nested flat namespace: compute its overrides and re-root them
            sub_overrides, _ = override_module_args(val)
            for so in sub_overrides:
                overrides.append(f"{sub_node}.{k}.{so}")
        else:
            overrides.append("{}.{}={}".format(sub_node, k, val))

    return overrides
def migrate_registry(
    name, value, registry, args, overrides, deletes, use_name_as_val=False
):
    """Record how the component ``name`` should appear in the composed config.

    Exactly one of three things happens, and ``overrides`` / ``deletes`` are
    mutated in place:

    * ``value`` is a key of ``registry``: ``name=value`` and
      ``name._name=value`` are appended to ``overrides``, followed by the
      attribute overrides computed for ``registry[value]``.
    * otherwise, if ``use_name_as_val`` is true and ``value`` is not None:
      only ``name=value`` is appended to ``overrides``.
    * otherwise ``name`` is appended to ``deletes``.
    """
    if value in registry:
        overrides.append(f"{name}={value}")
        overrides.append(f"{name}._name={value}")
        overrides.extend(_override_attr(name, registry[value], args))
        return

    if use_name_as_val and value is not None:
        overrides.append(f"{name}={value}")
        return

    deletes.append(name)
def override_module_args(args: Namespace) -> Tuple[List[str], List[str]]:
    """Use the fields in ``args`` to override those in the config.

    Returns:
        A pair ``(overrides, deletes)``: the override strings to pass when
        composing the config, and the names of the config entries collected
        for removal.
    """
    overrides = []
    deletes = []

    for group_name, group_field in FairseqConfig.__dataclass_fields__.items():
        overrides.extend(_override_attr(group_name, group_field.type, args))

    if args is None:
        return overrides, deletes

    if hasattr(args, "task"):
        from fairseq.tasks import TASK_DATACLASS_REGISTRY

        migrate_registry(
            "task", args.task, TASK_DATACLASS_REGISTRY, args, overrides, deletes
        )
    else:
        deletes.append("task")

    # these options will be set to "None" if they have not yet been migrated
    # so we can populate them with the entire flat args
    CORE_REGISTRIES = {"criterion", "optimizer", "lr_scheduler"}

    from fairseq.registry import REGISTRIES

    for registry_name, registry_entry in REGISTRIES.items():
        if not hasattr(args, registry_name):
            deletes.append(registry_name)
            continue
        migrate_registry(
            registry_name,
            getattr(args, registry_name),
            registry_entry["dataclass_registry"],
            args,
            overrides,
            deletes,
            use_name_as_val=registry_name not in CORE_REGISTRIES,
        )

    # The model is kept only when the architecture is registered and its
    # model class exposes a "__dataclass" attribute; otherwise it is deleted.
    model_dc = None
    if hasattr(args, "arch"):
        from fairseq.models import ARCH_MODEL_REGISTRY, ARCH_MODEL_NAME_REGISTRY

        if args.arch in ARCH_MODEL_REGISTRY:
            model_dc = getattr(ARCH_MODEL_REGISTRY[args.arch], "__dataclass", None)

    if model_dc is None:
        deletes.append("model")
    else:
        model_name = ARCH_MODEL_NAME_REGISTRY[args.arch]
        overrides.append("model={}".format(model_name))
        overrides.append("model._name={}".format(args.arch))
        # override model params with those exist in args
        overrides.extend(_override_attr("model", model_dc, args))

    return overrides, deletes
class omegaconf_no_object_check:
    """Context manager that makes omegaconf's primitive-type predicate always pass.

    While the context is active, the predicate on ``omegaconf._utils`` is
    replaced by a function that returns True for any argument. On exit the
    predicate captured when this object was constructed is put back.
    """

    @staticmethod
    def _predicate_name():
        """Return the name of the predicate attribute present on ``_utils``."""
        # Changed in https://github.com/omry/omegaconf/pull/911 - both are kept for back compat.
        if hasattr(_utils, "is_primitive_type"):
            return "is_primitive_type"
        return "is_primitive_type_annotation"

    def __init__(self):
        # remember the real predicate so __exit__ can restore it
        self.old_is_primitive = getattr(_utils, self._predicate_name())

    def __enter__(self):
        setattr(_utils, self._predicate_name(), lambda _: True)

    def __exit__(self, type, value, traceback):
        setattr(_utils, self._predicate_name(), self.old_is_primitive)
def convert_namespace_to_omegaconf(args: Namespace) -> DictConfig:
    """Convert a flat argparse.Namespace to a structured DictConfig.

    The config named "config" is composed with overrides derived from
    ``args``. Entries reported for deletion are set to None, and the result
    is rebuilt as a resolved config. For task, model, optimizer,
    lr_scheduler and criterion, an entry that is still None while ``args``
    names a component is filled with a copy of the whole flat namespace,
    completed with that component's legacy defaults and tagged with
    ``_name``. The returned config has struct mode enabled.
    """
    # Field values provided in args override their counterparts inside the config object
    overrides, deletes = override_module_args(args)

    # configs will be in fairseq/config after installation
    config_path = os.path.join("..", "config")

    GlobalHydra.instance().clear()

    with initialize(config_path=config_path):
        try:
            composed_cfg = compose("config", overrides=overrides, strict=False)
        except BaseException:
            # log the offending overrides, then let the original error propagate
            logger.error("Error when composing. Overrides: " + str(overrides))
            raise

        for removed_key in deletes:
            composed_cfg[removed_key] = None

    resolved = OmegaConf.to_container(composed_cfg, resolve=True, enum_to_str=True)
    cfg = OmegaConf.create(resolved)

    # hack to be able to set Namespace in dict config. this should be removed when we update to newer
    # omegaconf version that supports object flags, or when we migrate all existing models
    from omegaconf import _utils

    def _install_legacy_namespace(key, component_name, registry):
        """Store a copy of the flat args under ``cfg.<key>`` and complete it."""
        setattr(cfg, key, Namespace(**vars(args)))
        _set_legacy_defaults(getattr(cfg, key), registry[component_name])
        getattr(cfg, key)._name = component_name

    with omegaconf_no_object_check():
        if cfg.task is None and getattr(args, "task", None):
            from fairseq.tasks import TASK_REGISTRY

            _install_legacy_namespace("task", args.task, TASK_REGISTRY)

        if cfg.model is None and getattr(args, "arch", None):
            from fairseq.models import ARCH_MODEL_REGISTRY

            _install_legacy_namespace("model", args.arch, ARCH_MODEL_REGISTRY)

        if cfg.optimizer is None and getattr(args, "optimizer", None):
            from fairseq.optim import OPTIMIZER_REGISTRY

            _install_legacy_namespace("optimizer", args.optimizer, OPTIMIZER_REGISTRY)

        if cfg.lr_scheduler is None and getattr(args, "lr_scheduler", None):
            from fairseq.optim.lr_scheduler import LR_SCHEDULER_REGISTRY

            _install_legacy_namespace(
                "lr_scheduler", args.lr_scheduler, LR_SCHEDULER_REGISTRY
            )

        if cfg.criterion is None and getattr(args, "criterion", None):
            from fairseq.criterions import CRITERION_REGISTRY

            _install_legacy_namespace("criterion", args.criterion, CRITERION_REGISTRY)

    OmegaConf.set_struct(cfg, True)
    return cfg
def overwrite_args_by_name(cfg: DictConfig, overrides: Dict[str, Any]):
    """Apply ``overrides`` to ``cfg`` in place, matching keys by name.

    Each key ``k`` of ``cfg`` is handled according to its current value:

    * a ``DictConfig``: when ``overrides[k]`` is a dict, its items are
      written into ``cfg[k]``, recursing for dict values whose target is not
      None; otherwise the whole ``overrides`` mapping is applied recursively
      to ``cfg[k]``.
    * a ``Namespace``: every item of ``overrides`` is set as an attribute.
    * anything else, when ``k`` is in ``overrides``: if ``k`` is a registry
      name and ``overrides[k]`` is in that registry's dataclass registry,
      ``cfg[k]`` is rebuilt from the registered dataclass, overridden
      recursively and tagged with ``_name``; otherwise ``cfg[k]`` is simply
      replaced by ``overrides[k]``.

    Args:
        cfg: config to modify; it is opened with ``open_dict`` meanwhile.
        overrides: mapping of names to new values.
    """
    # this will be deprecated when we get rid of argparse and model_overrides logic

    from fairseq.registry import REGISTRIES

    with open_dict(cfg):
        for k in cfg.keys():
            # "k in cfg" will return false if its a "mandatory value (e.g. ???)"
            if k in cfg and isinstance(cfg[k], DictConfig):
                if k in overrides and isinstance(overrides[k], dict):
                    for ok, ov in overrides[k].items():
                        if isinstance(ov, dict) and cfg[k][ok] is not None:
                            overwrite_args_by_name(cfg[k][ok], ov)
                        else:
                            cfg[k][ok] = ov
                else:
                    overwrite_args_by_name(cfg[k], overrides)
            elif k in cfg and isinstance(cfg[k], Namespace):
                for override_key, val in overrides.items():
                    setattr(cfg[k], override_key, val)
            elif k in overrides:
                if (
                    k in REGISTRIES
                    and overrides[k] in REGISTRIES[k]["dataclass_registry"]
                ):
                    cfg[k] = DictConfig(
                        REGISTRIES[k]["dataclass_registry"][overrides[k]]
                    )
                    overwrite_args_by_name(cfg[k], overrides)
                    cfg[k]._name = overrides[k]
                else:
                    cfg[k] = overrides[k]
def merge_with_parent(dc: FairseqDataclass, cfg: DictConfig, remove_missing=False):
    """Merge ``cfg`` on top of ``dc`` and return the merged config.

    Args:
        dc: dataclass providing the base values.
        cfg: config whose values are merged over ``dc``.
        remove_missing: when true, keys of ``cfg`` that ``dc`` does not have
            are deleted from ``cfg`` (recursively, in place) before merging.

    Returns:
        The merged config, with its ``_parent`` entry copied from ``cfg``
        and struct mode enabled.
    """

    def _drop_unknown_keys(src_keys, target_cfg):
        """Delete from ``src_keys`` every key that ``target_cfg`` lacks."""
        if is_dataclass(target_cfg):
            known = set(target_cfg.__dataclass_fields__)
        else:
            known = set(target_cfg.keys())

        # iterate over a snapshot because entries are deleted on the way
        for key in list(src_keys.keys()):
            if key not in known:
                del src_keys[key]
                continue
            if not OmegaConf.is_config(src_keys[key]):
                continue
            child_target = getattr(target_cfg, key)
            if child_target is not None and (
                is_dataclass(child_target) or hasattr(child_target, "keys")
            ):
                _drop_unknown_keys(src_keys[key], child_target)

    if remove_missing:
        with open_dict(cfg):
            _drop_unknown_keys(cfg, dc)

    merged_cfg = OmegaConf.merge(dc, cfg)
    merged_cfg.__dict__["_parent"] = cfg.__dict__["_parent"]
    OmegaConf.set_struct(merged_cfg, True)
    return merged_cfg