import time
import warnings
from importlib.util import find_spec
from pathlib import Path
from typing import Callable, List, Optional

import hydra
from omegaconf import DictConfig, OmegaConf
from pytorch_lightning import Callback
from pytorch_lightning.loggers import Logger
from pytorch_lightning.utilities import rank_zero_only

from . import pylogger, rich_utils

log = pylogger.get_pylogger(__name__)

def task_wrapper(task_func: Callable) -> Callable:
    """Optional decorator that wraps the task function in extra utilities.

    Makes multirun more resistant to failure.

    Utilities:
    - Calling `utils.extras()` before the task is started
    - Calling `utils.close_loggers()` after the task is finished
    - Logging the exception if one occurs
    - Logging the task total execution time
    - Logging the output dir
    """

    def wrap(cfg: DictConfig):
        # apply extra utilities
        extras(cfg)

        # execute the task
        try:
            start_time = time.time()
            ret = task_func(cfg=cfg)
        except Exception as ex:
            log.exception("")  # save exception to `.log` file
            raise ex
        finally:
            path = Path(cfg.paths.output_dir, "exec_time.log")
            content = f"'{cfg.task_name}' execution time: {time.time() - start_time} (s)"
            save_file(path, content)  # save task execution time (even if exception occurs)
            close_loggers()  # close loggers (even if exception occurs so multirun won't fail)
            log.info(f"Output dir: {cfg.paths.output_dir}")

        return ret

    return wrap
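
# Example (illustrative sketch): a Hydra entrypoint using `task_wrapper`.
# The config path/name and the body of `train` below are assumptions, not part of this module.
#
#   @task_wrapper
#   def train(cfg: DictConfig) -> dict:
#       ...  # build datamodule, model and trainer, then run fitting
#       return trainer.callback_metrics
#
#   @hydra.main(version_base="1.3", config_path="../configs", config_name="train.yaml")
#   def main(cfg: DictConfig) -> None:
#       train(cfg)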

def extras(cfg: DictConfig) -> None:
    """Applies optional utilities before the task is started.

    Utilities:
    - Ignoring python warnings
    - Setting tags from command line
    - Rich config printing
    """
    # return if no `extras` config
    if not cfg.get("extras"):
        log.warning("Extras config not found! <cfg.extras=null>")
        return

    # disable python warnings
    if cfg.extras.get("ignore_warnings"):
        log.info("Disabling python warnings! <cfg.extras.ignore_warnings=True>")
        warnings.filterwarnings("ignore")

    # prompt user to input tags from command line if none are provided in the config
    if cfg.extras.get("enforce_tags"):
        log.info("Enforcing tags! <cfg.extras.enforce_tags=True>")
        rich_utils.enforce_tags(cfg, save_to_file=True)

    # pretty print config tree using Rich library
    if cfg.extras.get("print_config"):
        log.info("Printing config tree with Rich! <cfg.extras.print_config=True>")
        rich_utils.print_config_tree(cfg, resolve=True, save_to_file=True)
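
# Example (illustrative sketch): an `extras` config that enables all three utilities above.
# The key names come from this function; the YAML layout itself is an assumption.
#
#   extras:
#     ignore_warnings: True
#     enforce_tags: True
#     print_config: True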

@rank_zero_only
def save_file(path: str, content: str) -> None:
    """Save file in rank zero mode (only on one process in multi-GPU setup)."""
    with open(path, "w+") as file:
        file.write(content)

def instantiate_callbacks(callbacks_cfg: DictConfig) -> List[Callback]:
    """Instantiates callbacks from config."""
    callbacks: List[Callback] = []

    if not callbacks_cfg:
        log.warning("Callbacks config is empty.")
        return callbacks

    if not isinstance(callbacks_cfg, DictConfig):
        raise TypeError("Callbacks config must be a DictConfig!")

    for _, cb_conf in callbacks_cfg.items():
        if isinstance(cb_conf, DictConfig) and "_target_" in cb_conf:
            log.info(f"Instantiating callback <{cb_conf._target_}>")
            callbacks.append(hydra.utils.instantiate(cb_conf))

    return callbacks
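
# Example (illustrative sketch): a callbacks config this function can consume.
# Each entry needs a `_target_` key; the checkpoint settings below are assumptions.
#
#   model_checkpoint:
#     _target_: pytorch_lightning.callbacks.ModelCheckpoint
#     monitor: "val/loss"
#     mode: "min"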

def instantiate_loggers(logger_cfg: DictConfig) -> List[Logger]:
    """Instantiates loggers from config."""
    logger: List[Logger] = []

    if not logger_cfg:
        log.warning("Logger config is empty.")
        return logger

    if not isinstance(logger_cfg, DictConfig):
        raise TypeError("Logger config must be a DictConfig!")

    for _, lg_conf in logger_cfg.items():
        if isinstance(lg_conf, DictConfig) and "_target_" in lg_conf:
            log.info(f"Instantiating logger <{lg_conf._target_}>")
            logger.append(hydra.utils.instantiate(lg_conf))

    return logger
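
# Example (illustrative sketch): a logger config with a single CSV logger.
# The `save_dir` value is an assumption.
#
#   csv:
#     _target_: pytorch_lightning.loggers.CSVLogger
#     save_dir: "logs/"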

def log_hyperparameters(object_dict: dict) -> None:
    """Controls which config parts are saved by lightning loggers.

    Additionally saves:
    - Number of model parameters
    """
    hparams = {}

    cfg = object_dict["cfg"]
    model = object_dict["model"]
    trainer = object_dict["trainer"]

    if not trainer.logger:
        log.warning("Logger not found! Skipping hyperparameter logging...")
        return

    # save number of model parameters
    hparams["model/params/total"] = sum(p.numel() for p in model.parameters())
    hparams["model/params/trainable"] = sum(
        p.numel() for p in model.parameters() if p.requires_grad
    )
    hparams["model/params/non_trainable"] = sum(
        p.numel() for p in model.parameters() if not p.requires_grad
    )

    for k in cfg.keys():
        hparams[k] = cfg.get(k)

    # resolve all interpolations
    def _resolve(_cfg):
        if isinstance(_cfg, DictConfig):
            _cfg = OmegaConf.to_container(_cfg, resolve=True)
        return _cfg

    hparams = {k: _resolve(v) for k, v in hparams.items()}

    # send hparams to all loggers
    trainer.logger.log_hyperparams(hparams)
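
# Example (illustrative sketch): how a training task might call this helper.
# The `cfg`, `model`, and `trainer` objects are assumed to exist in the caller.
#
#   object_dict = {"cfg": cfg, "model": model, "trainer": trainer}
#   log_hyperparameters(object_dict)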

def get_metric_value(metric_dict: dict, metric_name: str) -> Optional[float]:
    """Safely retrieves value of the metric logged in LightningModule."""
    if not metric_name:
        log.info("Metric name is None! Skipping metric value retrieval...")
        return None

    if metric_name not in metric_dict:
        raise Exception(
            f"Metric value not found! <metric_name={metric_name}>\n"
            "Make sure metric name logged in LightningModule is correct!\n"
            "Make sure `optimized_metric` name in `hparams_search` config is correct!"
        )

    metric_value = metric_dict[metric_name].item()
    log.info(f"Retrieved metric value! <{metric_name}={metric_value}>")

    return metric_value
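
# Example (illustrative sketch): retrieving an optimized metric after training,
# e.g. for a hyperparameter search. The metric name "val/acc" is an assumption.
#
#   metric_dict = trainer.callback_metrics
#   metric_value = get_metric_value(metric_dict, "val/acc")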

def close_loggers() -> None:
    """Makes sure all loggers are closed properly (prevents logging failure during multirun)."""
    log.info("Closing loggers...")

    if find_spec("wandb"):  # check if wandb is installed
        import wandb

        if wandb.run:
            log.info("Closing wandb!")
            wandb.finish()