Plonk /
nicolas-dufour's picture
squash: merge all unpushed commits
history blame
2.61 kB
import os
from models.module import DiffGeolocalizer
import hydra
import wandb
from os.path import isfile, join
from shutil import copyfile
import torch
from omegaconf import OmegaConf
from omegaconf import open_dict
from hydra.core.hydra_config import HydraConfig
from hydra.utils import instantiate
from pytorch_lightning.callbacks import LearningRateMonitor
from lightning_fabric.utilities.rank_zero import _get_rank
from models.module import DiffGeolocalizer
# Set PyTorch's float32 matmul precision mode to "high" for the whole process.
torch.set_float32_matmul_precision("high") # TODO do we need that?
# Registering the "eval" resolver allows for advanced config
# interpolation with arithmetic operations in hydra.
# NOTE(review): the resolver is Python's built-in eval, so any "${eval:...}"
# expression in a config is executed as code -- only load trusted configs.
OmegaConf.register_new_resolver("eval", eval)
def load_model(cfg, dict_config, wandb_id):
    """Build the trainer and restore the model from a checkpoint.

    Args:
        cfg: Config object; reads ``cfg.logger``, ``cfg.trainer``,
            ``cfg.checkpoint`` and ``cfg.model``.
        dict_config: Unused here; kept so existing callers keep working.
        wandb_id: Path to a text file whose content is passed to the logger
            as its ``id``.

    Returns:
        A ``(trainer, model)`` tuple.
    """
    # Read the id inside a context manager so the file handle is closed, and
    # strip surrounding whitespace so a trailing newline in the file does not
    # end up inside the run id.
    with open(wandb_id, "r") as id_file:
        run_id = id_file.read().strip()

    # resume="allow" is forwarded to the logger together with the stored id.
    logger = instantiate(cfg.logger, id=run_id, resume="allow")
    model = DiffGeolocalizer.load_from_checkpoint(cfg.checkpoint, cfg=cfg.model)
    trainer = instantiate(cfg.trainer, strategy=cfg.trainer.strategy, logger=logger)
    return trainer, model
def hydra_boilerplate(cfg):
    """Resolve the config and delegate to ``load_model``.

    Args:
        cfg: Config object; must provide ``cfg.wandb_id`` in addition to
            whatever ``load_model`` reads.

    Returns:
        The ``(trainer, model)`` pair produced by ``load_model``.
    """
    # Converting with resolve=True resolves every interpolation in the config
    # before anything is instantiated.
    resolved = OmegaConf.to_container(cfg, resolve=True)
    return load_model(cfg, resolved, cfg.wandb_id)
import copy
def generate_datamodules(cfg_):
    """Yield one instantiated datamodule per config file in ``cfg_.test_dir``.

    For every entry of the directory, a deep copy of ``cfg_`` is made and its
    ``datamodule`` and ``dataset`` sections are replaced by those of the loaded
    file, while ``dataset.test_transform`` is kept from ``cfg_``.

    Args:
        cfg_: Base config; reads ``cfg_.test_dir`` and
            ``cfg_.dataset.test_transform``. It is not modified.

    Yields:
        The object produced by ``instantiate(cfg.datamodule)`` for each file,
        in sorted file-name order.
    """
    # os.listdir returns entries in arbitrary order; sort them so the
    # evaluation order is reproducible across machines and runs.
    for config_name in sorted(os.listdir(cfg_.test_dir)):
        cfg = copy.deepcopy(cfg_)
        override = OmegaConf.load(join(cfg.test_dir, config_name))
        # open_dict lets keys be assigned on the copied config.
        with open_dict(cfg):
            cfg.datamodule = override.datamodule
            cfg.dataset = override.dataset
            # Keep the base config's test transform instead of the file's.
            cfg.dataset.test_transform = cfg_.dataset.test_transform
        yield instantiate(cfg.datamodule)
# Script entry point: restores a trainer/model pair and runs trainer.test on
# every datamodule yielded by generate_datamodules.
# NOTE(review): this section appears truncated -- indentation is gone and
# several lines look missing (see the notes below). Restore it from version
# control before running; no main() call is visible in this excerpt either.
if __name__ == "__main__":
import sys
# Inject an extra override so that cfg.pt_model_path holds Hydra's runtime
# config sources.
# NOTE(review): the first operand of this concatenation (presumably the
# program name, sys.argv[:1]) and the closing ")" appear to be missing.
sys.argv = (
+ ["+pt_model_path=${hydra:runtime.config_sources}"]
+ sys.argv[1:]
# NOTE(review): a @hydra.main(...) decorator is presumably missing here -- confirm.
def main(cfg):
# print(hydra.runtime.config_sources)
with open_dict(cfg):
# Takes the "path" field of the second config source as the run directory.
path = cfg.pt_model_path[1]["path"]
cfg.wandb_id = join(path, "wandb_id.txt")
# NOTE(review): the trailing "= 1" looks like the tail of a separate
# assignment whose target was lost -- confirm against the original file.
cfg.checkpoint = join(path, "last.ckpt") = 1
# NOTE(review): the opening of this tuple target (presumably
# "(trainer, model") appears to be missing.
) = hydra_boilerplate(cfg)
for datamodule in generate_datamodules(cfg):
# Attach the current datamodule to the model before testing on it.
model.datamodule = datamodule
print("Testing on", datamodule.test_dataset.class_name)
trainer.test(model, datamodule=datamodule)