Spaces:
Build error
Build error
| import time | |
| import warnings | |
| from importlib.util import find_spec | |
| from pathlib import Path | |
| from typing import Callable, List | |
| import hydra | |
| from omegaconf import DictConfig, OmegaConf | |
| from pytorch_lightning import Callback | |
| from pytorch_lightning.loggers import Logger | |
| from pytorch_lightning.utilities import rank_zero_only | |
| from . import pylogger, rich_utils | |
| log = pylogger.get_pylogger(__name__) | |
| def task_wrapper(task_func: Callable) -> Callable: | |
| """Optional decorator that wraps the task function in extra utilities. | |
| Makes multirun more resistant to failure. | |
| Utilities: | |
| - Calling the `utils.extras()` before the task is started | |
| - Calling the `utils.close_loggers()` after the task is finished | |
| - Logging the exception if occurs | |
| - Logging the task total execution time | |
| - Logging the output dir | |
| """ | |
| def wrap(cfg: DictConfig): | |
| # apply extra utilities | |
| extras(cfg) | |
| # execute the task | |
| try: | |
| start_time = time.time() | |
| ret = task_func(cfg=cfg) | |
| except Exception as ex: | |
| log.exception("") # save exception to `.log` file | |
| raise ex | |
| finally: | |
| path = Path(cfg.paths.output_dir, "exec_time.log") | |
| content = f"'{cfg.task_name}' execution time: {time.time() - start_time} (s)" | |
| save_file(path, content) # save task execution time (even if exception occurs) | |
| close_loggers() # close loggers (even if exception occurs so multirun won't fail) | |
| log.info(f"Output dir: {cfg.paths.output_dir}") | |
| return ret | |
| return wrap | |
| def extras(cfg: DictConfig) -> None: | |
| """Applies optional utilities before the task is started. | |
| Utilities: | |
| - Ignoring python warnings | |
| - Setting tags from command line | |
| - Rich config printing | |
| """ | |
| # return if no `extras` config | |
| if not cfg.get("extras"): | |
| log.warning("Extras config not found! <cfg.extras=null>") | |
| return | |
| # disable python warnings | |
| if cfg.extras.get("ignore_warnings"): | |
| log.info("Disabling python warnings! <cfg.extras.ignore_warnings=True>") | |
| warnings.filterwarnings("ignore") | |
| # prompt user to input tags from command line if none are provided in the config | |
| if cfg.extras.get("enforce_tags"): | |
| log.info("Enforcing tags! <cfg.extras.enforce_tags=True>") | |
| rich_utils.enforce_tags(cfg, save_to_file=True) | |
| # pretty print config tree using Rich library | |
| if cfg.extras.get("print_config"): | |
| log.info("Printing config tree with Rich! <cfg.extras.print_config=True>") | |
| rich_utils.print_config_tree(cfg, resolve=True, save_to_file=True) | |
| def save_file(path: str, content: str) -> None: | |
| """Save file in rank zero mode (only on one process in multi-GPU setup).""" | |
| with open(path, "w+") as file: | |
| file.write(content) | |
| def instantiate_callbacks(callbacks_cfg: DictConfig) -> List[Callback]: | |
| """Instantiates callbacks from config.""" | |
| callbacks: List[Callback] = [] | |
| if not callbacks_cfg: | |
| log.warning("Callbacks config is empty.") | |
| return callbacks | |
| if not isinstance(callbacks_cfg, DictConfig): | |
| raise TypeError("Callbacks config must be a DictConfig!") | |
| for _, cb_conf in callbacks_cfg.items(): | |
| if isinstance(cb_conf, DictConfig) and "_target_" in cb_conf: | |
| log.info(f"Instantiating callback <{cb_conf._target_}>") | |
| callbacks.append(hydra.utils.instantiate(cb_conf)) | |
| return callbacks | |
| def instantiate_loggers(logger_cfg: DictConfig) -> List[Logger]: | |
| """Instantiates loggers from config.""" | |
| logger: List[Logger] = [] | |
| if not logger_cfg: | |
| log.warning("Logger config is empty.") | |
| return logger | |
| if not isinstance(logger_cfg, DictConfig): | |
| raise TypeError("Logger config must be a DictConfig!") | |
| for _, lg_conf in logger_cfg.items(): | |
| if isinstance(lg_conf, DictConfig) and "_target_" in lg_conf: | |
| log.info(f"Instantiating logger <{lg_conf._target_}>") | |
| logger.append(hydra.utils.instantiate(lg_conf)) | |
| return logger | |
| def log_hyperparameters(object_dict: dict) -> None: | |
| """Controls which config parts are saved by lightning loggers. | |
| Additionally saves: | |
| - Number of model parameters | |
| """ | |
| hparams = {} | |
| cfg = object_dict["cfg"] | |
| model = object_dict["model"] | |
| trainer = object_dict["trainer"] | |
| if not trainer.logger: | |
| log.warning("Logger not found! Skipping hyperparameter logging...") | |
| return | |
| # save number of model parameters | |
| hparams["model/params/total"] = sum(p.numel() for p in model.parameters()) | |
| hparams["model/params/trainable"] = sum( | |
| p.numel() for p in model.parameters() if p.requires_grad | |
| ) | |
| hparams["model/params/non_trainable"] = sum( | |
| p.numel() for p in model.parameters() if not p.requires_grad | |
| ) | |
| for k in cfg.keys(): | |
| hparams[k] = cfg.get(k) | |
| # Resolve all interpolations | |
| def _resolve(_cfg): | |
| if isinstance(_cfg, DictConfig): | |
| _cfg = OmegaConf.to_container(_cfg, resolve=True) | |
| return _cfg | |
| hparams = {k: _resolve(v) for k, v in hparams.items()} | |
| # send hparams to all loggers | |
| trainer.logger.log_hyperparams(hparams) | |
| def get_metric_value(metric_dict: dict, metric_name: str) -> float: | |
| """Safely retrieves value of the metric logged in LightningModule.""" | |
| if not metric_name: | |
| log.info("Metric name is None! Skipping metric value retrieval...") | |
| return None | |
| if metric_name not in metric_dict: | |
| raise Exception( | |
| f"Metric value not found! <metric_name={metric_name}>\n" | |
| "Make sure metric name logged in LightningModule is correct!\n" | |
| "Make sure `optimized_metric` name in `hparams_search` config is correct!" | |
| ) | |
| metric_value = metric_dict[metric_name].item() | |
| log.info(f"Retrieved metric value! <{metric_name}={metric_value}>") | |
| return metric_value | |
| def close_loggers() -> None: | |
| """Makes sure all loggers closed properly (prevents logging failure during multirun).""" | |
| log.info("Closing loggers...") | |
| if find_spec("wandb"): # if wandb is installed | |
| import wandb | |
| if wandb.run: | |
| log.info("Closing wandb!") | |
| wandb.finish() | |