| import time
|
| import warnings
|
| from importlib.util import find_spec
|
| from pathlib import Path
|
| from typing import Callable, List
|
|
|
| import hydra
|
| from omegaconf import DictConfig
|
| from pytorch_lightning import Callback
|
| from pytorch_lightning.loggers import LightningLoggerBase
|
| from pytorch_lightning.utilities import rank_zero_only
|
|
|
| from src.utils import pylogger, rich_utils
|
|
|
| log = pylogger.get_pylogger(__name__)
|
|
|
|
|
def task_wrapper(task_func: Callable) -> Callable:
    """Optional decorator that wraps the task function in extra utilities.

    Makes multirun more resistant to failure.

    Utilities:
    - Calling the `utils.extras()` before the task is started
    - Calling the `utils.close_loggers()` after the task is finished
    - Logging the exception if occurs
    - Logging the task total execution time
    - Logging the output dir

    Args:
        task_func: Task callable invoked as ``task_func(cfg=cfg)``; it must
            return a ``(metric_dict, object_dict)`` pair.

    Returns:
        The wrapping callable, which takes the config and returns the same
        ``(metric_dict, object_dict)`` pair produced by ``task_func``.
    """
    # Imported locally so this block does not rely on a new file-level import.
    from functools import wraps

    @wraps(task_func)  # keep the task's __name__/__doc__ on the wrapper
    def wrap(cfg: DictConfig):
        # apply extra utilities
        extras(cfg)

        # Start the clock before entering `try` so the `finally` clause can
        # always read it.
        start_time = time.time()
        try:
            metric_dict, object_dict = task_func(cfg=cfg)
        except Exception:
            log.exception("")  # record the traceback through the module logger
            raise  # bare raise re-raises the active exception unchanged
        finally:
            try:
                # Write the execution time even if the task raised.
                path = Path(cfg.paths.output_dir, "exec_time.log")
                content = f"'{cfg.task_name}' execution time: {time.time() - start_time} (s)"
                save_file(path, content)
            finally:
                # Close loggers even if writing the timing file failed.
                close_loggers()

        log.info(f"Output dir: {cfg.paths.output_dir}")

        return metric_dict, object_dict

    return wrap
|
|
|
|
|
def extras(cfg: DictConfig) -> None:
    """Run the optional pre-task utilities selected under ``cfg.extras``.

    Each step runs only when its flag is truthy:
    - ``ignore_warnings``: ignore python warnings
    - ``enforce_tags``: enforce tags via ``rich_utils.enforce_tags``
    - ``print_config``: print the config tree via ``rich_utils.print_config_tree``

    If ``cfg`` has no truthy ``extras`` entry, a warning is logged and the
    function returns without doing anything else.
    """
    # return if no `extras` config
    if not cfg.get("extras"):
        log.warning("Extras config not found! <cfg.extras=null>")
        return

    options = cfg.extras

    # disable python warnings
    if options.get("ignore_warnings"):
        log.info("Disabling python warnings! <cfg.extras.ignore_warnings=True>")
        warnings.filterwarnings("ignore")

    # enforce tags (delegated to rich_utils)
    if options.get("enforce_tags"):
        log.info("Enforcing tags! <cfg.extras.enforce_tags=True>")
        rich_utils.enforce_tags(cfg, save_to_file=True)

    # pretty print config tree using Rich library
    if options.get("print_config"):
        log.info("Printing config tree with Rich! <cfg.extras.print_config=True>")
        rich_utils.print_config_tree(cfg, resolve=True, save_to_file=True)
|
|
|
|
|
@rank_zero_only
def save_file(path: str, content: str) -> None:
    """Write ``content`` to ``path``, replacing any existing file content.

    Wrapped in ``rank_zero_only`` so that, in a multi-GPU setup, the write
    happens on one process only.
    """
    with open(path, mode="w+") as handle:
        handle.write(content)
|
|
|
|
|
def instantiate_callbacks(callbacks_cfg: DictConfig) -> List[Callback]:
    """Build the callbacks described by a config.

    Only entries that are a ``DictConfig`` containing a ``_target_`` key are
    instantiated (through ``hydra.utils.instantiate``); all other entries are
    ignored.

    Args:
        callbacks_cfg: Config whose values describe individual callbacks.

    Returns:
        The instantiated callbacks, or an empty list when the config is empty.

    Raises:
        TypeError: If a non-empty ``callbacks_cfg`` is not a ``DictConfig``.
    """
    if not callbacks_cfg:
        log.warning("Callbacks config is empty.")
        return []

    if not isinstance(callbacks_cfg, DictConfig):
        raise TypeError("Callbacks config must be a DictConfig!")

    instantiated: List[Callback] = []
    for _name, entry in callbacks_cfg.items():
        has_target = isinstance(entry, DictConfig) and "_target_" in entry
        if not has_target:
            continue
        log.info(f"Instantiating callback <{entry._target_}>")
        instantiated.append(hydra.utils.instantiate(entry))

    return instantiated
|
|
|
|
|
def instantiate_loggers(logger_cfg: DictConfig) -> List[LightningLoggerBase]:
    """Build the loggers described by a config.

    Only entries that are a ``DictConfig`` containing a ``_target_`` key are
    instantiated (through ``hydra.utils.instantiate``); all other entries are
    ignored.

    Args:
        logger_cfg: Config whose values describe individual loggers.

    Returns:
        The instantiated loggers, or an empty list when the config is empty.

    Raises:
        TypeError: If a non-empty ``logger_cfg`` is not a ``DictConfig``.
    """
    if not logger_cfg:
        log.warning("Logger config is empty.")
        return []

    if not isinstance(logger_cfg, DictConfig):
        raise TypeError("Logger config must be a DictConfig!")

    instantiated: List[LightningLoggerBase] = []
    for _name, entry in logger_cfg.items():
        has_target = isinstance(entry, DictConfig) and "_target_" in entry
        if not has_target:
            continue
        log.info(f"Instantiating logger <{entry._target_}>")
        instantiated.append(hydra.utils.instantiate(entry))

    return instantiated
|
|
|
|
|
@rank_zero_only
def log_hyperparameters(object_dict: dict) -> None:
    """Send a selected subset of the config to every logger of the trainer.

    Reads ``cfg``, ``model`` and ``trainer`` from ``object_dict``. Besides the
    chosen config sections, the logged hyperparameters include the model's
    parameter counts (total, trainable, non-trainable).

    If the trainer has no loggers, a warning is logged and nothing is sent.
    """
    cfg = object_dict["cfg"]
    model = object_dict["model"]
    trainer = object_dict["trainer"]

    if not trainer.loggers:
        log.warning("Logger not found! Skipping hyperparameter logging...")
        return

    # Entries are evaluated top to bottom, so the key order seen by the
    # loggers is the order written here.
    hparams = {
        "model": cfg["model"],
        # number of model parameters
        "model/params/total": sum(param.numel() for param in model.parameters()),
        "model/params/trainable": sum(
            param.numel() for param in model.parameters() if param.requires_grad
        ),
        "model/params/non_trainable": sum(
            param.numel() for param in model.parameters() if not param.requires_grad
        ),
        # required config sections
        "datamodule": cfg["datamodule"],
        "trainer": cfg["trainer"],
        # optional config sections (None when absent)
        "callbacks": cfg.get("callbacks"),
        "extras": cfg.get("extras"),
        "task_name": cfg.get("task_name"),
        "tags": cfg.get("tags"),
        "ckpt_path": cfg.get("ckpt_path"),
        "seed": cfg.get("seed"),
    }

    # send hparams to all loggers
    for active_logger in trainer.loggers:
        active_logger.log_hyperparams(hparams)
|
|
|
|
|
def get_metric_value(metric_dict: dict, metric_name: str) -> float:
    """Look up a logged metric by name and return it as a plain value.

    Args:
        metric_dict: Mapping from metric name to an object exposing ``.item()``.
        metric_name: Name of the metric to fetch; may be empty or ``None``.

    Returns:
        The result of calling ``.item()`` on the stored metric, or ``None``
        when ``metric_name`` is falsy.

    Raises:
        Exception: If ``metric_name`` is not a key of ``metric_dict``.
    """
    if not metric_name:
        log.info("Metric name is None! Skipping metric value retrieval...")
        return None

    if metric_name in metric_dict:
        value = metric_dict[metric_name].item()
        log.info(f"Retrieved metric value! <{metric_name}={value}>")
        return value

    raise Exception(
        f"Metric value not found! <metric_name={metric_name}>\n"
        "Make sure metric name logged in LightningModule is correct!\n"
        "Make sure `optimized_metric` name in `hparams_search` config is correct!"
    )
|
|
|
|
|
def close_loggers() -> None:
    """Shut down active loggers so a following run starts clean.

    Currently handles wandb only: when the ``wandb`` package is importable and
    has an active run, that run is finished.
    """
    log.info("Closing loggers...")

    # Nothing to close when wandb is not installed.
    if not find_spec("wandb"):
        return

    import wandb

    if wandb.run:
        log.info("Closing wandb!")
        wandb.finish()
|
|
|