import glob
import os
import re
import sys
from typing import List, Optional, Tuple, Union

import numpy as np
import torch
from loguru import logger
from torch import nn

from df_local.config import Csv, config
from df_local.model import init_model
from df_local.utils import check_finite_module
from libdf import DF


def get_epoch(cp) -> int:
    """Extract the epoch number from a checkpoint path such as `model_120.ckpt`."""
    return int(os.path.basename(cp).split(".")[0].split("_")[-1])


def load_model(
    cp_dir: Optional[str],
    df_state: DF,
    jit: bool = False,
    mask_only: bool = False,
    train_df_only: bool = False,
    extension: str = "ckpt",
    epoch: Union[str, int, None] = "latest",
) -> Tuple[nn.Module, int]:
    if mask_only and train_df_only:
        raise ValueError("Only one of `mask_only` or `train_df_only` can be enabled")
    model = init_model(df_state, run_df=not mask_only, train_mask=not train_df_only)
    if jit:
        model = torch.jit.script(model)
    blacklist: List[str] = config("CP_BLACKLIST", [], Csv(), save=False, section="train")  # type: ignore
    if cp_dir is not None:
        epoch = read_cp(
            model, "model", cp_dir, blacklist=blacklist, extension=extension, epoch=epoch
        )
        epoch = 0 if epoch is None else epoch
    else:
        epoch = 0
    return model, epoch


def read_cp(
    obj: Union[torch.optim.Optimizer, nn.Module],
    name: str,
    dirname: str,
    epoch: Union[str, int, None] = "latest",
    extension: str = "ckpt",
    blacklist: Optional[List[str]] = None,
    log: bool = True,
):
    checkpoints: List[str] = []
    if isinstance(epoch, str):
        assert epoch in ("best", "latest")
        if epoch == "best":
            checkpoints = glob.glob(os.path.join(dirname, f"{name}*.{extension}.best"))
            if len(checkpoints) == 0:
                logger.warning("Could not find `best` checkpoint. Checking for default...")
    if len(checkpoints) == 0:
        checkpoints = glob.glob(os.path.join(dirname, f"{name}*.{extension}"))
        checkpoints += glob.glob(os.path.join(dirname, f"{name}*.{extension}.best"))
    if len(checkpoints) == 0:
        return None
    if isinstance(epoch, int):
        latest = next((x for x in checkpoints if get_epoch(x) == epoch), None)
        if latest is None:
            logger.error(f"Could not find checkpoint of epoch {epoch}")
            sys.exit(1)
    else:
        latest = max(checkpoints, key=get_epoch)
        epoch = get_epoch(latest)
    if log:
        logger.info("Found checkpoint {} with epoch {}".format(latest, epoch))
    latest = torch.load(latest, map_location="cpu")
    # Rename legacy `clc` modules to their current `df` names.
    latest = {k.replace("clc", "df"): v for k, v in latest.items()}
    if blacklist:
        # Drop all state dict entries matching one of the blacklist patterns.
        reg = re.compile("|".join(f"({b})" for b in blacklist))
        len_before = len(latest)
        latest = {k: v for k, v in latest.items() if reg.search(k) is None}
        if len(latest) < len_before:
            logger.info("Filtered checkpoint modules: {}".format(blacklist))
    if isinstance(obj, nn.Module):
        while True:
            try:
                missing, unexpected = obj.load_state_dict(latest, strict=False)
            except RuntimeError as e:
                e_str = str(e)
                logger.warning(e_str)
                if "size mismatch" in e_str:
                    # Drop the mismatched entries (their keys appear in the
                    # error message) and retry with the remaining ones.
                    latest = {k: v for k, v in latest.items() if k not in e_str}
                    continue
                raise e
            break
        for key in missing:
            logger.warning(f"Missing key: '{key}'")
        for key in unexpected:
            if key.endswith(".h0"):
                continue  # Initial RNN hidden states are not model parameters
            logger.warning(f"Unexpected key: {key}")
        return epoch
    obj.load_state_dict(latest)
    return epoch


def write_cp(
    obj: Union[torch.optim.Optimizer, nn.Module],
    name: str,
    dirname: str,
    epoch: int,
    extension: str = "ckpt",
    metric: Optional[float] = None,
    cmp: str = "min",
):
    check_finite_module(obj)
    n_keep = config("n_checkpoint_history", default=3, cast=int, section="train")
    n_keep_best = config("n_best_checkpoint_history", default=5, cast=int, section="train")
    if metric is not None:
        assert cmp in ("min", "max")
        metric = float(metric)  # Make sure it is not an integer
        # Each line of the `.best` file holds one previous best: "<epoch> <metric>"
        with open(os.path.join(dirname, ".best"), "a+") as prev_best_f:
            prev_best_f.seek(0)  # "a+" creates the file if needed; rewind to read it
            lines = prev_best_f.readlines()
            if len(lines) == 0:
                prev_best = float("inf" if cmp == "min" else "-inf")
            else:
                prev_best = float(lines[-1].strip().split(" ")[1])
            cmp = "__lt__" if cmp == "min" else "__gt__"
            if getattr(metric, cmp)(prev_best):
                logger.info(f"Saving new best checkpoint at epoch {epoch} with metric: {metric}")
                prev_best_f.seek(0, os.SEEK_END)
                np.savetxt(prev_best_f, np.array([[float(epoch), metric]]))
                cp_name = os.path.join(dirname, f"{name}_{epoch}.{extension}.best")
                torch.save(obj.state_dict(), cp_name)
                cleanup(name, dirname, extension + ".best", nkeep=n_keep_best)
    cp_name = os.path.join(dirname, f"{name}_{epoch}.{extension}")
    logger.info(f"Writing checkpoint {cp_name} with epoch {epoch}")
    torch.save(obj.state_dict(), cp_name)
    cleanup(name, dirname, extension, nkeep=n_keep)


def cleanup(name: str, dirname: str, extension: str, nkeep: int = 5):
    """Remove all but the `nkeep` most recent checkpoints; `nkeep < 0` keeps all."""
    if nkeep < 0:
        return
    checkpoints = glob.glob(os.path.join(dirname, f"{name}*.{extension}"))
    if len(checkpoints) == 0:
        return
    checkpoints = sorted(checkpoints, key=get_epoch, reverse=True)
    for cp in checkpoints[nkeep:]:
        logger.debug("Removing old checkpoint: {}".format(cp))
        os.remove(cp)


def check_patience(
    dirname: str, max_patience: int, new_metric: float, cmp: str = "min", raise_: bool = True
):
    cmp = "__lt__" if cmp == "min" else "__gt__"
    new_metric = float(new_metric)  # Make sure it is not an integer
    prev_patience, prev_metric = read_patience(dirname)
    if prev_patience is None or getattr(new_metric, cmp)(prev_metric):
        # We have a better new_metric; reset the patience counter
        write_patience(dirname, 0, new_metric)
    else:
        # No improvement; increment the patience counter
        new_patience = prev_patience + 1
        write_patience(dirname, new_patience, prev_metric)
        if new_patience >= max_patience:
            if raise_:
                raise ValueError(
                    f"No improvements on validation metric ({new_metric}) for {max_patience} epochs. "
                    "Stopping."
                )
            else:
                return False
    return True


def read_patience(dirname: str) -> Tuple[Optional[int], float]:
    fn = os.path.join(dirname, ".patience")
    if not os.path.isfile(fn):
        return None, 0.0
    patience, metric = np.loadtxt(fn)
    return int(patience), float(metric)


def write_patience(dirname: str, new_patience: int, metric: float):
    return np.savetxt(os.path.join(dirname, ".patience"), [new_patience, metric])


def test_check_patience():
    import tempfile

    with tempfile.TemporaryDirectory() as d:
        check_patience(d, 3, 1.0)
        check_patience(d, 3, 1.0)
        check_patience(d, 3, 1.0)
        assert check_patience(d, 3, 1.0, raise_=False) is False
    with tempfile.TemporaryDirectory() as d:
        check_patience(d, 3, 1.0)
        check_patience(d, 3, 0.9)
        check_patience(d, 3, 1.0)
        check_patience(d, 3, 1.0)
        assert check_patience(d, 3, 1.0, raise_=False) is False
    with tempfile.TemporaryDirectory() as d:
        check_patience(d, 3, 1.0, cmp="max")
        check_patience(d, 3, 1.9, cmp="max")
        check_patience(d, 3, 1.0, cmp="max")
        check_patience(d, 3, 1.0, cmp="max")
        assert check_patience(d, 3, 1.0, cmp="max", raise_=False) is False
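

# ---------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the module API): how the helpers
# above could be wired into a training loop that resumes from the latest
# checkpoint, tracks the best validation metric via the `.best` file, and
# stops early via the `.patience` file. Assumes the df config has already been
# initialized (as in the training entry point). `train_one_epoch` and
# `validate` are hypothetical callables supplied by the caller; everything
# else uses the functions defined above.
def _example_training_loop(
    cp_dir: str,
    df_state: DF,
    train_one_epoch,  # hypothetical: (model, opt) -> None
    validate,  # hypothetical: (model) -> float, lower is better (cmp="min")
    max_epochs: int = 100,
    max_patience: int = 5,
):
    os.makedirs(cp_dir, exist_ok=True)
    # Resume the model from the latest checkpoint in `cp_dir` (epoch 0 if none).
    model, start_epoch = load_model(cp_dir, df_state, epoch="latest")
    opt = torch.optim.Adam(model.parameters())
    # Optimizer state is stored under its own name prefix, e.g. `opt_<epoch>.ckpt`.
    read_cp(opt, "opt", cp_dir, log=False)
    for epoch in range(start_epoch, max_epochs):
        train_one_epoch(model, opt)
        val_metric = validate(model)
        # `write_cp` saves `model_<epoch>.ckpt`, additionally writes a `.best`
        # checkpoint when the metric improves, and prunes old files via `cleanup`.
        write_cp(model, "model", cp_dir, epoch, metric=val_metric)
        write_cp(opt, "opt", cp_dir, epoch)
        # Returns False (instead of raising) once `max_patience` epochs pass
        # without improvement.
        if not check_patience(cp_dir, max_patience, val_metric, raise_=False):
            logger.info("Early stopping at epoch {}".format(epoch))
            break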