| """ |
| PRIMA: Boosting Animal Mesh Recovery with Biological Priors and Test-Time Adaptation |
| |
| Official implementation of the paper: |
| "PRIMA: Boosting Animal Mesh Recovery with Biological Priors and Test-Time Adaptation" |
| by Xiaohang Yu, Ti Wang, and Mackenzie Weygandt Mathis |
| Licensed under a modified MIT license |
| """ |
|
|
import time
import warnings
from importlib.util import find_spec
from pathlib import Path
from typing import Callable, List, Optional, Union
|
|
import hydra
from omegaconf import DictConfig, OmegaConf
from pytorch_lightning import Callback
from pytorch_lightning.loggers import Logger
from pytorch_lightning.utilities import rank_zero_only
|
|
from . import pylogger, rich_utils

log = pylogger.get_pylogger(__name__)
|
|
|
|
def task_wrapper(task_func: Callable) -> Callable:
    """Optional decorator that wraps the task function in extra utilities.

    Makes multiruns more resistant to failure.

    Utilities:
    - Calling `utils.extras()` before the task is started
    - Calling `utils.close_loggers()` after the task is finished
    - Logging the exception if one occurs
    - Logging the total execution time of the task
    - Logging the output dir
    """

    def wrap(cfg: DictConfig):
        start_time = time.time()
        try:
            # apply optional utilities (warnings filter, tag enforcement, config printing)
            extras(cfg)

            # execute the task
            ret = task_func(cfg=cfg)
        except Exception:
            # save the exception with its traceback to the log file
            log.exception("Exception occurred during task execution!")
            raise
        finally:
            # always log the execution time and close the loggers, even on failure
            path = Path(cfg.paths.output_dir, "exec_time.log")
            content = f"'{cfg.task_name}' execution time: {time.time() - start_time} (s)"
            save_file(path, content)
            close_loggers()

        log.info(f"Output dir: {cfg.paths.output_dir}")

        return ret

    return wrap
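# A minimal usage sketch for `task_wrapper` (the `train` task and its return
# value are illustrative, not necessarily this repo's actual entry point):
#
#     @task_wrapper
#     def train(cfg: DictConfig) -> dict:
#         ...  # build the datamodule/model/trainer from `cfg` and run fit
#         return {"train/loss": 0.1}
#
# Any exception raised inside `train` is logged with its traceback, the
# execution time is written to `<output_dir>/exec_time.log`, and all loggers
# are closed before the exception is re-raised.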
|
|
|
|
def extras(cfg: DictConfig) -> None:
    """Applies optional utilities before the task is started.

    Utilities:
    - Ignoring python warnings
    - Setting tags from command line
    - Rich config printing
    """

    # return early if no extras config is provided
    if not cfg.get("extras"):
        log.warning("Extras config not found! <cfg.extras=null>")
        return

    # disable python warnings
    if cfg.extras.get("ignore_warnings"):
        log.info("Disabling python warnings! <cfg.extras.ignore_warnings=True>")
        warnings.filterwarnings("ignore")

    # enforce tags, prompting for them on the command line if none are provided
    if cfg.extras.get("enforce_tags"):
        log.info("Enforcing tags! <cfg.extras.enforce_tags=True>")
        rich_utils.enforce_tags(cfg, save_to_file=True)

    # pretty-print the config tree using Rich
    if cfg.extras.get("print_config"):
        log.info("Printing config tree with Rich! <cfg.extras.print_config=True>")
        rich_utils.print_config_tree(cfg, resolve=True, save_to_file=True)
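# For reference, a config that exercises every branch above might look like
# this (an illustrative sketch, not the project's shipped defaults):
#
#     cfg = OmegaConf.create(
#         {
#             "extras": {
#                 "ignore_warnings": True,
#                 "enforce_tags": True,
#                 "print_config": True,
#             }
#         }
#     )
#     extras(cfg)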
|
|
|
|
@rank_zero_only
def save_file(path: Union[str, Path], content: str) -> None:
    """Save file in rank zero mode (only on one process in multi-GPU setup)."""
    with open(path, "w+") as file:
        file.write(content)
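# Under multi-GPU (DDP) training, `rank_zero_only` turns the call into a no-op
# on every process except global rank 0, so the file is written exactly once:
#
#     save_file(Path(cfg.paths.output_dir) / "tags.log", "dev,baseline")  # illustrative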
|
|
|
|
def instantiate_callbacks(callbacks_cfg: DictConfig) -> List[Callback]:
    """Instantiates callbacks from config."""
    callbacks: List[Callback] = []

    if not callbacks_cfg:
        log.warning("Callbacks config is empty.")
        return callbacks

    if not isinstance(callbacks_cfg, DictConfig):
        raise TypeError("Callbacks config must be a DictConfig!")

    for _, cb_conf in callbacks_cfg.items():
        if isinstance(cb_conf, DictConfig) and "_target_" in cb_conf:
            log.info(f"Instantiating callback <{cb_conf._target_}>")
            callbacks.append(hydra.utils.instantiate(cb_conf))

    return callbacks
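# An illustrative config for this function: `_target_` points at a real
# pytorch_lightning class, while the hyperparameter values are placeholders:
#
#     callbacks_cfg = OmegaConf.create(
#         {
#             "model_checkpoint": {
#                 "_target_": "pytorch_lightning.callbacks.ModelCheckpoint",
#                 "monitor": "val/loss",
#                 "save_top_k": 1,
#             }
#         }
#     )
#     callbacks = instantiate_callbacks(callbacks_cfg)  # -> [ModelCheckpoint(...)]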
|
|
|
|
def instantiate_loggers(logger_cfg: DictConfig) -> List[Logger]:
    """Instantiates loggers from config."""
    logger: List[Logger] = []

    if not logger_cfg:
        log.warning("Logger config is empty.")
        return logger

    if not isinstance(logger_cfg, DictConfig):
        raise TypeError("Logger config must be a DictConfig!")

    for _, lg_conf in logger_cfg.items():
        if isinstance(lg_conf, DictConfig) and "_target_" in lg_conf:
            log.info(f"Instantiating logger <{lg_conf._target_}>")
            logger.append(hydra.utils.instantiate(lg_conf))

    return logger
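# Works the same way as `instantiate_callbacks`; an illustrative config
# (the `save_dir` value is a placeholder):
#
#     logger_cfg = OmegaConf.create(
#         {"csv": {"_target_": "pytorch_lightning.loggers.CSVLogger", "save_dir": "logs/"}}
#     )
#     loggers = instantiate_loggers(logger_cfg)  # -> [CSVLogger(...)]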
|
|
|
|
@rank_zero_only
def log_hyperparameters(object_dict: dict) -> None:
    """Controls which config parts are saved by lightning loggers.

    Additionally saves:
    - Number of model parameters
    """

    hparams = {}

    cfg = object_dict["cfg"]
    model = object_dict["model"]
    trainer = object_dict["trainer"]

    if not trainer.logger:
        log.warning("Logger not found! Skipping hyperparameter logging...")
        return

    # save the number of total, trainable, and frozen model parameters
    hparams["model/params/total"] = sum(p.numel() for p in model.parameters())
    hparams["model/params/trainable"] = sum(
        p.numel() for p in model.parameters() if p.requires_grad
    )
    hparams["model/params/non_trainable"] = sum(
        p.numel() for p in model.parameters() if not p.requires_grad
    )

    # save all top-level config sections
    for k in cfg.keys():
        hparams[k] = cfg.get(k)

    # resolve OmegaConf interpolations into plain containers before logging
    def _resolve(_cfg):
        if isinstance(_cfg, DictConfig):
            _cfg = OmegaConf.to_container(_cfg, resolve=True)
        return _cfg

    hparams = {k: _resolve(v) for k, v in hparams.items()}

    # send hparams to all loggers
    trainer.logger.log_hyperparams(hparams)
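# Typical call site once the config, model, and trainer are built (a sketch;
# the variable names are illustrative):
#
#     object_dict = {"cfg": cfg, "model": model, "trainer": trainer}
#     log_hyperparameters(object_dict)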
|
|
|
|
def get_metric_value(metric_dict: dict, metric_name: str) -> Optional[float]:
    """Safely retrieves the value of a metric logged in a LightningModule."""

    if not metric_name:
        log.info("Metric name is None! Skipping metric value retrieval...")
        return None

    if metric_name not in metric_dict:
        raise Exception(
            f"Metric value not found! <metric_name={metric_name}>\n"
            "Make sure the metric name logged in the LightningModule is correct!\n"
            "Make sure the `optimized_metric` name in the `hparams_search` config is correct!"
        )

    metric_value = metric_dict[metric_name].item()
    log.info(f"Retrieved metric value! <{metric_name}={metric_value}>")

    return metric_value
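# `trainer.callback_metrics` holds the tensors logged via `self.log(...)` in
# the LightningModule, so a hyperparameter search can read its objective like
# this (the metric name is illustrative):
#
#     metric_value = get_metric_value(trainer.callback_metrics, "val/loss")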
|
|
|
|
def close_loggers() -> None:
    """Makes sure all loggers are closed properly (prevents logging failure during multirun)."""

    log.info("Closing loggers...")

    # wandb keeps a global run open; finish it explicitly so the next
    # multirun job starts from a clean state
    if find_spec("wandb"):
        import wandb

        if wandb.run:
            log.info("Closing wandb!")
            wandb.finish()
|
|