import argparse from pathlib import Path from pprint import pprint from typing import Optional import pkg_resources from omegaconf import OmegaConf from ..models import get_model from ..settings import TRAINING_PATH from ..utils.experiments import load_experiment def parse_config_path(name_or_path: Optional[str], defaults: str) -> Path: default_configs = {} for c in pkg_resources.resource_listdir("gluefactory", str(defaults)): if c.endswith(".yaml"): default_configs[Path(c).stem] = Path( pkg_resources.resource_filename("gluefactory", defaults + c) ) if name_or_path is None: return None if name_or_path in default_configs: return default_configs[name_or_path] path = Path(name_or_path) if not path.exists(): raise FileNotFoundError( f"Cannot find the config file: {name_or_path}. " f"Not in the default configs {list(default_configs.keys())} " "and not an existing path." ) return Path(path) def extract_benchmark_conf(conf, benchmark): mconf = OmegaConf.create( { "model": conf.get("model", {}), } ) if "benchmarks" in conf.keys(): return OmegaConf.merge(mconf, conf.benchmarks.get(benchmark, {})) else: return mconf def parse_eval_args(benchmark, args, configs_path, default=None): conf = {"data": {}, "model": {}, "eval": {}} if args.conf: conf_path = parse_config_path(args.conf, configs_path) custom_conf = OmegaConf.load(conf_path) conf = extract_benchmark_conf(OmegaConf.merge(conf, custom_conf), benchmark) args.tag = ( args.tag if args.tag is not None else conf_path.name.replace(".yaml", "") ) cli_conf = OmegaConf.from_cli(args.dotlist) conf = OmegaConf.merge(conf, cli_conf) conf.checkpoint = args.checkpoint if args.checkpoint else conf.get("checkpoint") if conf.checkpoint and not conf.checkpoint.endswith(".tar"): checkpoint_conf = OmegaConf.load( TRAINING_PATH / conf.checkpoint / "config.yaml" ) conf = OmegaConf.merge(extract_benchmark_conf(checkpoint_conf, benchmark), conf) if default: conf = OmegaConf.merge(default, conf) if args.tag is not None: name = args.tag elif args.conf and conf.checkpoint: name = f"{args.conf}_{conf.checkpoint}" elif args.conf: name = args.conf elif conf.checkpoint: name = conf.checkpoint if len(args.dotlist) > 0 and not args.tag: name = name + "_" + ":".join(args.dotlist) print("Running benchmark:", benchmark) print("Experiment tag:", name) print("Config:") pprint(OmegaConf.to_container(conf)) return name, conf def load_model(model_conf, checkpoint): if checkpoint: model = load_experiment(checkpoint, conf=model_conf).eval() else: model = get_model("two_view_pipeline")(model_conf).eval() if not model.is_initialized(): raise ValueError( "The provided model has non-initialized parameters. " + "Try to load a checkpoint instead." ) return model def get_eval_parser(): parser = argparse.ArgumentParser() parser.add_argument("--tag", type=str, default=None) parser.add_argument("--checkpoint", type=str, default=None) parser.add_argument("--conf", type=str, default=None) parser.add_argument("--overwrite", action="store_true") parser.add_argument("--overwrite_eval", action="store_true") parser.add_argument("--plot", action="store_true") parser.add_argument("dotlist", nargs="*") return parser