import glob
import os
import re
from typing import List, Optional, Tuple, Union

import numpy as np
import torch
from loguru import logger
from torch import nn

from df_local.config import Csv, config
from df_local.model import init_model
from df_local.utils import check_finite_module
from libdf import DF


def get_epoch(cp) -> int:
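    """Parse the epoch number from a checkpoint filename such as `model_120.ckpt`."""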
    return int(os.path.basename(cp).split(".")[0].split("_")[-1])


def load_model(
    cp_dir: Optional[str],
    df_state: DF,
    jit: bool = False,
    mask_only: bool = False,
    train_df_only: bool = False,
    extension: str = "ckpt",
    epoch: Union[str, int, None] = "latest",
) -> Tuple[nn.Module, int]:
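    """Initialize the model via `init_model` and optionally restore its weights from `cp_dir`.

    Returns the model together with the epoch of the restored checkpoint
    (0 if no checkpoint directory or matching file was found).
    """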
    if mask_only and train_df_only:
        raise ValueError("Only one of `mask_only` `train_df_only` can be enabled")
    model = init_model(df_state, run_df=mask_only is False, train_mask=train_df_only is False)
    if jit:
        model = torch.jit.script(model)
    blacklist: List[str] = config("CP_BLACKLIST", [], Csv(), save=False, section="train")  # type: ignore
    if cp_dir is not None:
        epoch = read_cp(
            model, "model", cp_dir, blacklist=blacklist, extension=extension, epoch=epoch
        )
        epoch = 0 if epoch is None else epoch
    else:
        epoch = 0
    return model, epoch


def read_cp(
    obj: Union[torch.optim.Optimizer, nn.Module],
    name: str,
    dirname: str,
    epoch: Union[str, int, None] = "latest",
    extension="ckpt",
    blacklist=[],
    log: bool = True,
):
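    """Load a checkpoint state dict from `dirname` into `obj` (a model or optimizer).

    `epoch` may be "latest", "best", or a concrete epoch number. Keys matching an
    entry in `blacklist` are dropped before loading. Returns the epoch of the
    loaded checkpoint, or None if no checkpoint was found in `dirname`.
    """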
    checkpoints = []
    if isinstance(epoch, str):
        assert epoch in ("best", "latest")
    if epoch == "best":
        checkpoints = glob.glob(os.path.join(dirname, f"{name}*.{extension}.best"))
        if len(checkpoints) == 0:
            logger.warning("Could not find `best` checkpoint. Checking for default...")
    if len(checkpoints) == 0:
        checkpoints = glob.glob(os.path.join(dirname, f"{name}*.{extension}"))
        checkpoints += glob.glob(os.path.join(dirname, f"{name}*.{extension}.best"))
    if len(checkpoints) == 0:
        return None
    if isinstance(epoch, int):
        latest = next((x for x in checkpoints if get_epoch(x) == epoch), None)
        if latest is None:
            logger.error(f"Could not find checkpoint of epoch {epoch}")
            exit(1)
    else:
        latest = max(checkpoints, key=get_epoch)
        epoch = get_epoch(latest)
    if log:
        logger.info("Found checkpoint {} with epoch {}".format(latest, epoch))
    latest = torch.load(latest, map_location="cpu")
    latest = {k.replace("clc", "df"): v for k, v in latest.items()}
    if blacklist:
        reg = re.compile("".join(f"({b})|" for b in blacklist)[:-1])
        len_before = len(latest)
        latest = {k: v for k, v in latest.items() if reg.search(k) is None}
        if len(latest) < len_before:
            logger.info("Filtered checkpoint modules: {}".format(blacklist))
    if isinstance(obj, nn.Module):
        # Retry loading with keys that caused a size mismatch removed from the state dict.
        while True:
            try:
                missing, unexpected = obj.load_state_dict(latest, strict=False)
            except RuntimeError as e:
                e_str = str(e)
                logger.warning(e_str)
                if "size mismatch" in e_str:
                    latest = {k: v for k, v in latest.items() if k not in e_str}
                    continue
                raise e
            break
        for key in missing:
            logger.warning(f"Missing key: '{key}'")
        for key in unexpected:
            if key.endswith(".h0"):
                continue
            logger.warning(f"Unexpected key: {key}")
        return epoch
    # Optimizers (and other stateful objects) are loaded strictly.
    obj.load_state_dict(latest)
    return epoch


def write_cp(
    obj: Union[torch.optim.Optimizer, nn.Module],
    name: str,
    dirname: str,
    epoch: int,
    extension="ckpt",
    metric: Optional[float] = None,
    cmp="min",
):
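    """Write a checkpoint for `obj` at `epoch` and prune old checkpoints.

    If `metric` improves on the best value recorded in the `.best` file
    (according to `cmp`), an additional `.best` checkpoint is saved as well.
    """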
    check_finite_module(obj)
    n_keep = config("n_checkpoint_history", default=3, cast=int, section="train")
    n_keep_best = config("n_best_checkpoint_history", default=5, cast=int, section="train")
    if metric is not None:
        assert cmp in ("min", "max")
        metric = float(metric)  # Make sure it is not an integer
        # Each line contains a previous best with entries: (epoch, metric)
        with open(os.path.join(dirname, ".best"), "a+") as prev_best_f:
            prev_best_f.seek(0)  # "a+" creates a file in read/write mode without truncating
            lines = prev_best_f.readlines()
            if len(lines) == 0:
                prev_best = float("inf" if cmp == "min" else "-inf")
            else:
                prev_best = float(lines[-1].strip().split(" ")[1])
            cmp = "__lt__" if cmp == "min" else "__gt__"
            if getattr(metric, cmp)(prev_best):
                logger.info(f"Saving new best checkpoint at epoch {epoch} with metric: {metric}")
                prev_best_f.seek(0, os.SEEK_END)
                np.savetxt(prev_best_f, np.array([[float(epoch), metric]]))
                cp_name = os.path.join(dirname, f"{name}_{epoch}.{extension}.best")
                torch.save(obj.state_dict(), cp_name)
                cleanup(name, dirname, extension + ".best", nkeep=n_keep_best)
    cp_name = os.path.join(dirname, f"{name}_{epoch}.{extension}")
    logger.info(f"Writing checkpoint {cp_name} with epoch {epoch}")
    torch.save(obj.state_dict(), cp_name)
    cleanup(name, dirname, extension, nkeep=n_keep)


def cleanup(name: str, dirname: str, extension: str, nkeep=5):
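    """Remove all but the `nkeep` newest checkpoints matching `name` and `extension`."""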
    if nkeep < 0:
        return
    checkpoints = glob.glob(os.path.join(dirname, f"{name}*.{extension}"))
    if len(checkpoints) == 0:
        return
    checkpoints = sorted(checkpoints, key=get_epoch, reverse=True)
    for cp in checkpoints[nkeep:]:
        logger.debug("Removing old checkpoint: {}".format(cp))
        os.remove(cp)


def check_patience(
    dirname: str, max_patience: int, new_metric: float, cmp: str = "min", raise_: bool = True
):
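    """Update the early-stopping patience stored in `dirname`.

    Resets the counter when `new_metric` improves, otherwise increments it.
    Returns False (or raises a ValueError if `raise_` is set) once the counter
    reaches `max_patience`; returns True otherwise.
    """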
| cmp = "__lt__" if cmp == "min" else "__gt__" | |
| new_metric = float(new_metric) # Make sure it is not an integer | |
| prev_patience, prev_metric = read_patience(dirname) | |
| if prev_patience is None or getattr(new_metric, cmp)(prev_metric): | |
| # We have a better new_metric, reset patience | |
| write_patience(dirname, 0, new_metric) | |
| else: | |
| # We don't have a better metric, decrement patience | |
| new_patience = prev_patience + 1 | |
| write_patience(dirname, new_patience, prev_metric) | |
| if new_patience >= max_patience: | |
| if raise_: | |
| raise ValueError( | |
| f"No improvements on validation metric ({new_metric}) for {max_patience} epochs. " | |
| "Stopping." | |
| ) | |
| else: | |
| return False | |
| return True | |


def read_patience(dirname: str) -> Tuple[Optional[int], float]:
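    """Return the patience counter and best metric stored in the `.patience` file."""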
    fn = os.path.join(dirname, ".patience")
    if not os.path.isfile(fn):
        return None, 0.0
    patience, metric = np.loadtxt(fn)
    return int(patience), float(metric)


def write_patience(dirname: str, new_patience: int, metric: float):
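    """Store the patience counter and its associated metric in the `.patience` file."""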
    return np.savetxt(os.path.join(dirname, ".patience"), [new_patience, metric])


def test_check_patience():
    import tempfile

    with tempfile.TemporaryDirectory() as d:
        check_patience(d, 3, 1.0)
        check_patience(d, 3, 1.0)
        check_patience(d, 3, 1.0)
        assert check_patience(d, 3, 1.0, raise_=False) is False
    with tempfile.TemporaryDirectory() as d:
        check_patience(d, 3, 1.0)
        check_patience(d, 3, 0.9)
        check_patience(d, 3, 1.0)
        check_patience(d, 3, 1.0)
        assert check_patience(d, 3, 1.0, raise_=False) is False
    with tempfile.TemporaryDirectory() as d:
        check_patience(d, 3, 1.0, cmp="max")
        check_patience(d, 3, 1.9, cmp="max")
        check_patience(d, 3, 1.0, cmp="max")
        check_patience(d, 3, 1.0, cmp="max")
        assert check_patience(d, 3, 1.0, cmp="max", raise_=False) is False