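"""Hydra-based entry point for fairseq training.

Wraps ``fairseq_cli.train.main`` in a Hydra application so a ``FairseqConfig``
can be composed from config files and command-line overrides. The entry point
returns the best validation value of ``checkpoint.best_checkpoint_metric``,
which lets Hydra sweepers optimize over training runs.

Example invocation (paths and overrides are illustrative only):

    fairseq-hydra-train task.data=/path/to/data --config-name config
"""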
import logging
import os

from fairseq.dataclass.initialize import add_defaults, hydra_init
from fairseq_cli.train import main as pre_main
from fairseq import distributed_utils, metrics
from fairseq.dataclass.configs import FairseqConfig
from fairseq.utils import reset_logging

import hydra
from hydra.core.hydra_config import HydraConfig
import torch
from omegaconf import OmegaConf, open_dict


logger = logging.getLogger("fairseq_cli.hydra_train")


@hydra.main(config_path=os.path.join("..", "fairseq", "config"), config_name="config")
def hydra_main(cfg: FairseqConfig) -> float:
    add_defaults(cfg)

    if cfg.common.reset_logging:
        # Hydra takes over logging; restore fairseq's own logging configuration
        reset_logging()
    else:
        # keep Hydra's job logging config in the cfg so distributed (DDP)
        # worker processes can reconfigure logging consistently
        with open_dict(cfg):
            cfg.job_logging_cfg = OmegaConf.to_container(
                HydraConfig.get().job_logging, resolve=True
            )

    # resolve interpolations and enums, then freeze the config
    cfg = OmegaConf.create(
        OmegaConf.to_container(cfg, resolve=True, enum_to_str=True)
    )
    OmegaConf.set_struct(cfg, True)

    try:
        if cfg.common.profile:
            # emit NVTX ranges so the run can be inspected with the CUDA profiler
            with torch.cuda.profiler.profile():
                with torch.autograd.profiler.emit_nvtx():
                    distributed_utils.call_main(cfg, pre_main)
        else:
            distributed_utils.call_main(cfg, pre_main)
    except BaseException as e:
        if not cfg.common.suppress_crashes:
            raise
        else:
            logger.error("Crashed! " + str(e))

    # report the best validation value so Hydra sweepers can optimize over it
    try:
        best_val = metrics.get_smoothed_value(
            "valid", cfg.checkpoint.best_checkpoint_metric
        )
    except Exception:
        best_val = None

    if best_val is None:
        best_val = float("inf")

    return best_val


def cli_main():
    try:
        from hydra._internal.utils import get_args

        cfg_name = get_args().config_name or "config"
    except Exception:
        logger.warning("Failed to get config name from hydra args")
        cfg_name = "config"

    hydra_init(cfg_name)
    hydra_main()


if __name__ == "__main__":
    cli_main()