|
import os |
|
from typing import List |
|
|
|
import hydra |
|
from omegaconf import DictConfig |
|
from pytorch_lightning import LightningDataModule, LightningModule, Trainer, seed_everything |
|
from pytorch_lightning.loggers import LightningLoggerBase |
|
|
|
from src import utils |
|
|
|
log = utils.get_logger(__name__)


def test(config: DictConfig) -> None:
    """Run the minimal testing pipeline: evaluate a checkpoint on the test set.

    The function seeds the random number generators (when a seed is
    configured), resolves the checkpoint path, instantiates the datamodule,
    model, loggers and trainer from their Hydra configs, and finally calls
    ``trainer.test`` with the checkpoint.

    Args:
        config (DictConfig): Configuration composed by Hydra. The keys read
            here are ``ckpt_path``, ``datamodule``, ``model`` and ``trainer``,
            plus the optional ``seed`` and ``logger``.

    Returns:
        None
    """
    # Compare against None rather than relying on truthiness so that an
    # explicit seed of 0 is honoured instead of being treated as "no seed".
    seed = config.get("seed")
    if seed is not None:
        seed_everything(seed, workers=True)

    # A relative checkpoint path is resolved against the directory the program
    # was launched from, as reported by Hydra.
    if not os.path.isabs(config.ckpt_path):
        config.ckpt_path = os.path.join(hydra.utils.get_original_cwd(), config.ckpt_path)

    log.info(f"Instantiating datamodule <{config.datamodule._target_}>")
    datamodule: LightningDataModule = hydra.utils.instantiate(config.datamodule)

    log.info(f"Instantiating model <{config.model._target_}>")
    model: LightningModule = hydra.utils.instantiate(config.model)

    # Build one logger per config entry that declares a `_target_`.
    # `config.get` tolerates both a missing `logger` key and `logger: null`;
    # the isinstance check skips null or scalar entries, on which the
    # `"_target_" in ...` test would fail or be a substring test.
    logger: List[LightningLoggerBase] = []
    logger_configs = config.get("logger")
    if logger_configs:
        for lg_conf in logger_configs.values():
            if isinstance(lg_conf, DictConfig) and "_target_" in lg_conf:
                log.info(f"Instantiating logger <{lg_conf._target_}>")
                logger.append(hydra.utils.instantiate(lg_conf))

    log.info(f"Instantiating trainer <{config.trainer._target_}>")
    trainer: Trainer = hydra.utils.instantiate(config.trainer, logger=logger)

    # NOTE(review): depending on the installed Lightning version,
    # `trainer.logger` can be None when no logger was configured, so guard
    # the call instead of assuming a logger object exists.
    if trainer.logger is not None:
        trainer.logger.log_hyperparams({"ckpt_path": config.ckpt_path})

    log.info("Starting testing!")
    trainer.test(model=model, datamodule=datamodule, ckpt_path=config.ckpt_path)
|
|