|
from pathlib import Path |
|
from typing import Sequence |
|
|
|
import rich |
|
import rich.syntax |
|
import rich.tree |
|
from hydra.core.hydra_config import HydraConfig |
|
from lightning.pytorch.utilities import rank_zero_only |
|
from omegaconf import DictConfig, OmegaConf, open_dict |
|
from rich.prompt import Prompt |
|
|
|
from matcha.utils import pylogger |
|
|
|
# Module-level logger obtained from the project's ``pylogger`` helper, named after this module.
log = pylogger.get_pylogger(__name__)
|
|
|
|
|
@rank_zero_only
def print_config_tree(
    cfg: DictConfig,
    print_order: Sequence[str] = (
        "data",
        "model",
        "callbacks",
        "logger",
        "trainer",
        "paths",
        "extras",
    ),
    resolve: bool = False,
    save_to_file: bool = False,
) -> None:
    """Prints the contents of a DictConfig as a tree structure using the Rich library.

    Fields listed in ``print_order`` that exist in ``cfg`` are printed first, in that order; every
    remaining top-level field of ``cfg`` follows in the config's own iteration order. Fields named in
    ``print_order`` but absent from ``cfg`` are skipped with a warning.

    :param cfg: A DictConfig composed by Hydra.
    :param print_order: Determines in what order config components are printed. Default is ``("data", "model",
        "callbacks", "logger", "trainer", "paths", "extras")``.
    :param resolve: Whether to resolve reference fields of DictConfig. Default is ``False``. Honored for both
        dict-valued and list-valued top-level fields.
    :param save_to_file: Whether to export config to the hydra output folder (``cfg.paths.output_dir``,
        file ``config_tree.log``). Default is ``False``.
    """
    style = "dim"
    tree = rich.tree.Tree("CONFIG", style=style, guide_style=style)

    # Build the print queue: explicitly ordered fields first, then everything else.
    queue = []
    for field in print_order:
        if field in cfg:
            queue.append(field)
        else:
            log.warning(f"Field '{field}' not found in config. Skipping '{field}' config printing...")

    # A set gives O(1) membership tests; top-level config keys are unique, so it needs no updating
    # while the remaining fields are appended.
    already_queued = set(queue)
    queue.extend(field for field in cfg if field not in already_queued)

    for field in queue:
        branch = tree.add(field, style=style, guide_style=style)

        config_group = cfg[field]
        # Any OmegaConf container (DictConfig or ListConfig) goes through to_yaml so that ``resolve``
        # is honored; str() on a ListConfig would show raw, unresolved interpolations.
        if OmegaConf.is_config(config_group):
            branch_content = OmegaConf.to_yaml(config_group, resolve=resolve)
        else:
            branch_content = str(config_group)

        branch.add(rich.syntax.Syntax(branch_content, "yaml"))

    rich.print(tree)

    if save_to_file:
        # Explicit UTF-8 so non-ASCII config values do not depend on the platform's default encoding.
        with open(Path(cfg.paths.output_dir, "config_tree.log"), "w", encoding="utf-8") as file:
            rich.print(tree, file=file)
|
|
|
|
|
@rank_zero_only
def enforce_tags(cfg: DictConfig, save_to_file: bool = False) -> None:
    """Prompts user to input tags from command line if no tags are provided in config.

    When ``cfg.tags`` is missing or empty, the user is prompted for a comma-separated list (default
    ``"dev"``). The entries are stripped, empty ones are dropped, and the result is stored in ``cfg.tags``.

    :param cfg: A DictConfig composed by Hydra. Mutated in place: ``cfg.tags`` is set when it was empty.
    :param save_to_file: Whether to export tags to the hydra output folder (``cfg.paths.output_dir``,
        file ``tags.log``). Default is ``False``.
    :raises ValueError: If no tags are configured and the Hydra job config contains an ``id`` key
        (treated here as a multirun, where interactive prompting is not possible).
    """
    if not cfg.get("tags"):
        # NOTE(review): relies on ``hydra.job.id`` being present only for multirun jobs -- confirm.
        if "id" in HydraConfig().cfg.hydra.job:
            raise ValueError("Specify tags before launching a multirun!")

        log.warning("No tags provided in config. Prompting user to input tags...")
        raw_tags = Prompt.ask("Enter a list of comma separated tags", default="dev")
        # Strip before filtering so whitespace-only entries (e.g. "a, ,b" or "a, ") are dropped
        # instead of becoming empty-string tags.
        tags = [tag for tag in (part.strip() for part in raw_tags.split(",")) if tag]

        # open_dict lifts struct-mode restrictions so the ``tags`` key can be added if absent.
        with open_dict(cfg):
            cfg.tags = tags

        log.info(f"Tags: {cfg.tags}")

    if save_to_file:
        # Explicit UTF-8 so non-ASCII tags do not depend on the platform's default encoding.
        with open(Path(cfg.paths.output_dir, "tags.log"), "w", encoding="utf-8") as file:
            rich.print(cfg.tags, file=file)
|
|