|
|
|
|
|
|
|
|
|
"""isort:skip_file""" |
|
|
|
import logging |
|
from hydra.core.config_store import ConfigStore |
|
from fairseq.dataclass.configs import FairseqConfig |
|
from omegaconf import DictConfig, OmegaConf |
|
|
|
|
|
logger = logging.getLogger(__name__) |
|
|
|
|
|
def hydra_init(cfg_name="config") -> None:
    """Register the fairseq config schema with Hydra's ``ConfigStore``.

    ``FairseqConfig`` itself is stored under ``cfg_name``; afterwards the
    default value of every ``FairseqConfig`` dataclass field is stored under
    the field's own name. Fields whose default is ``None`` are skipped.

    Args:
        cfg_name: name under which the top-level ``FairseqConfig`` node is
            stored.

    Raises:
        BaseException: any error raised while storing a field default is
            logged together with the field name and value, then re-raised
            unchanged.
    """
    cs = ConfigStore.instance()
    cs.store(name=f"{cfg_name}", node=FairseqConfig)

    for k, field in FairseqConfig.__dataclass_fields__.items():
        v = field.default
        if v is None:
            # No default node to register for this field. Report through the
            # module logger rather than printing to stdout.
            logger.debug("skipping config group %s: default is None", k)
            continue

        try:
            cs.store(name=k, node=v)
        except BaseException:
            # Identify the offending field before propagating the failure.
            logger.error(f"{k} - {v}")
            raise
|
|
|
|
|
def add_defaults(cfg: DictConfig) -> None:
    """Add default values stored in dataclasses that hydra doesn't know about.

    For every ``FairseqConfig`` field annotated as ``Any`` that has a
    non-``None`` entry in ``cfg``, the entry's ``_name`` is used to look up a
    dataclass (task registry, model registry, or the generic ``REGISTRIES``
    table). When one is found, ``cfg[k]`` is replaced by the result of
    ``merge_with_parent(dataclass, entry)``.

    Args:
        cfg: the config to update in place. Its struct flag is set to
            ``False`` by this function and is not restored afterwards.
    """
    # Imported locally, in this order, as in the original module (which is
    # marked ``isort:skip_file``).
    from fairseq.registry import REGISTRIES
    from fairseq.tasks import TASK_DATACLASS_REGISTRY
    from fairseq.models import ARCH_MODEL_NAME_REGISTRY, MODEL_DATACLASS_REGISTRY
    from fairseq.dataclass.utils import merge_with_parent
    from typing import Any

    # Allow keys to be added to / replaced in cfg below.
    OmegaConf.set_struct(cfg, False)

    for k, v in FairseqConfig.__dataclass_fields__.items():
        field_cfg = cfg.get(k)
        if field_cfg is None or v.type != Any:
            continue

        if isinstance(field_cfg, str):
            # A bare string is shorthand for {"_name": <string>}. Parent the
            # new node to cfg so it looks like a node obtained from cfg
            # itself (the original code assigned ``_parent`` to itself,
            # which was a no-op and left the node parentless).
            field_cfg = DictConfig({"_name": field_cfg})
            field_cfg.__dict__["_parent"] = cfg

        name = getattr(field_cfg, "_name", None)

        dc = None
        if k == "task":
            dc = TASK_DATACLASS_REGISTRY.get(name)
        elif k == "model":
            # Map an architecture name to its model name when one is registered.
            name = ARCH_MODEL_NAME_REGISTRY.get(name, name)
            dc = MODEL_DATACLASS_REGISTRY.get(name)
        elif k in REGISTRIES:
            dc = REGISTRIES[k]["dataclass_registry"].get(name)

        if dc is not None:
            cfg[k] = merge_with_parent(dc, field_cfg)
|
|