|
from typing import Any, Dict, List, Tuple |
|
|
|
import hydra |
|
import rootutils |
|
from lightning import Callback, LightningDataModule, LightningModule, Trainer |
|
from lightning.pytorch.loggers import Logger |
|
from omegaconf import DictConfig |
|
|
|
# Locate the project root (presumably the nearest parent directory containing a
# `.project-root` marker file) and, since pythonpath=True, make it importable.
# This must run before the `from src.utils import ...` statement below, which
# depends on the project root being resolvable.
rootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from src.utils import ( |
|
RankedLogger, |
|
extras, |
|
instantiate_callbacks, |
|
instantiate_loggers, |
|
log_hyperparameters, |
|
task_wrapper, |
|
) |
|
|
|
# Module-level logger from src.utils. rank_zero_only=True presumably restricts
# output to the rank-0 process in multi-process runs — confirm in RankedLogger.
log = RankedLogger(__name__, rank_zero_only=True)
|
|
|
|
|
@task_wrapper
def evaluate(cfg: DictConfig) -> Tuple[Dict[str, Any], Dict[str, Any]]:
    """Evaluate a given checkpoint on a datamodule's test set.

    This method is wrapped in the optional @task_wrapper decorator, which controls the behavior
    on failure. Useful for multiruns, saving info about the crash, etc.

    :param cfg: DictConfig configuration composed by Hydra. Must provide ``ckpt_path`` as well
        as ``data``, ``model`` and ``trainer`` sections with ``_target_`` entries.
    :raises ValueError: If ``cfg.ckpt_path`` is empty or ``None``.
    :return: A tuple ``(metric_dict, object_dict)``: the trainer's callback metrics after the
        test run, and a dict of all instantiated objects (cfg, datamodule, model, callbacks,
        logger, trainer).
    """
    # Validate explicitly instead of using `assert`: asserts are stripped under `python -O`,
    # which would pass ckpt_path=None to `trainer.test`. Per the Lightning docs, that makes it
    # test the passed model's current (here: freshly instantiated, untrained) weights.
    ckpt_path = cfg.ckpt_path
    if not ckpt_path:
        raise ValueError(f"`ckpt_path` must point to the checkpoint to evaluate, got {ckpt_path!r}.")

    log.info(f"Instantiating datamodule <{cfg.data._target_}>")
    datamodule: LightningDataModule = hydra.utils.instantiate(cfg.data)

    log.info(f"Instantiating model <{cfg.model._target_}>")
    model: LightningModule = hydra.utils.instantiate(cfg.model)

    # `callbacks` / `logger` sections are optional, hence `cfg.get` rather than attribute access.
    log.info("Instantiating callbacks...")
    callbacks: List[Callback] = instantiate_callbacks(cfg.get("callbacks"))

    log.info("Instantiating loggers...")
    logger: List[Logger] = instantiate_loggers(cfg.get("logger"))

    log.info(f"Instantiating trainer <{cfg.trainer._target_}>")
    trainer: Trainer = hydra.utils.instantiate(cfg.trainer, callbacks=callbacks, logger=logger)

    # Everything instantiated above; used for hyperparameter logging and returned to the caller.
    object_dict = {
        "cfg": cfg,
        "datamodule": datamodule,
        "model": model,
        "callbacks": callbacks,
        "logger": logger,
        "trainer": trainer,
    }

    # Only log hyperparameters when at least one logger was configured (non-empty list).
    if logger:
        log.info("Logging hyperparameters!")
        log_hyperparameters(object_dict)

    log.info("Starting testing!")
    trainer.test(model=model, datamodule=datamodule, ckpt_path=ckpt_path)

    # Metrics logged during the test run, as collected by the trainer.
    metric_dict = trainer.callback_metrics

    return metric_dict, object_dict
|
|
|
|
|
@hydra.main(version_base="1.3", config_path="../configs", config_name="eval.yaml")
def main(cfg: DictConfig) -> None:
    """Hydra entry point that evaluates a checkpoint.

    Hydra composes ``cfg`` from ``eval.yaml`` in ``../configs`` (relative to this file). The
    config is first passed through ``extras`` and then handed to ``evaluate``. Whatever
    ``evaluate`` returns is not used here.

    :param cfg: DictConfig configuration composed by Hydra.
    """
    # Apply the `extras` utilities from src.utils to the composed config before running.
    extras(cfg)

    # Run the test loop; the returned (metrics, objects) pair is intentionally discarded.
    evaluate(cfg)
|
|
|
|
|
if __name__ == "__main__":
    # `main` is wrapped by @hydra.main, which builds `cfg` itself, so it is called without arguments.
    main()
|
|