"""Base class for detectors that can cache their detections on disk."""
import lzma
import pickle
from pathlib import Path
from typing import Optional

import torch
from tops import logger


class BaseDetector:
    """Mixin-style base providing an on-disk cache around ``self.forward``.

    ``forward`` is not defined in this class; it is called by
    ``forward_and_cache`` and is presumably supplied by subclasses.
    Cache files are LZMA-compressed ``torch.save`` payloads stored under
    ``<cache_directory>/<ClassName>/``.
    """

    def __init__(self, cache_directory: Optional[str]) -> None:
        """Prepare the per-class cache directory.

        Args:
            cache_directory: Root directory for cache files, or ``None`` to
                disable caching. When given, a sub-directory named after the
                concrete class is created (including missing parents).
        """
        if cache_directory is None:
            # Always define the attribute so later code can test for
            # "caching disabled" instead of hitting an AttributeError.
            self.cache_directory = None
            return
        self.cache_directory = Path(cache_directory, str(self.__class__.__name__))
        self.cache_directory.mkdir(exist_ok=True, parents=True)

    def save_to_cache(self, detection, cache_path: Path, after_preprocess=True):
        """Serialize detections to ``cache_path``.

        Args:
            detection: Iterable of objects exposing
                ``state_dict(after_preprocess=...)``.
            cache_path: Destination file; written LZMA-compressed.
            after_preprocess: Forwarded unchanged to each ``state_dict`` call.
        """
        logger.log(f"Caching detection to: {cache_path}")
        states = [
            det.state_dict(after_preprocess=after_preprocess) for det in detection
        ]
        with lzma.open(cache_path, "wb") as fp:
            torch.save(states, fp, pickle_protocol=pickle.HIGHEST_PROTOCOL)

    def load_from_cache(self, cache_path: Path):
        """Load detections previously written by ``save_to_cache``.

        Each stored state must contain a ``"cls"`` entry whose
        ``from_state_dict(state_dict=...)`` rebuilds the detection object.

        Args:
            cache_path: LZMA-compressed cache file to read.

        Returns:
            A list with one reconstructed object per stored state.
        """
        logger.log(f"Loading detection from cache path: {cache_path}")
        with lzma.open(cache_path, "rb") as fp:
            # SECURITY: torch.load unpickles the file, which can execute
            # arbitrary code. Only load cache files from a trusted location.
            # NOTE(review): recent torch releases default to
            # weights_only=True, which rejects pickled class objects such as
            # state["cls"] -- confirm against the torch version in use.
            state_dict = torch.load(fp)
        return [
            state["cls"].from_state_dict(state_dict=state) for state in state_dict
        ]

    def forward_and_cache(self, im: torch.Tensor, cache_id: str, load_cache: bool):
        """Run ``self.forward(im)``, reading/writing the on-disk cache.

        Args:
            im: Input passed straight to ``self.forward``.
            cache_id: Cache key; the file is ``<cache_id>.torch`` inside the
                cache directory. ``None`` bypasses the cache entirely.
            load_cache: When true and a cache file exists, return its content
                instead of running ``forward``.

        Returns:
            The detections, either loaded from cache or freshly computed.
        """
        # No key or no cache directory means caching is disabled.
        if cache_id is None or self.cache_directory is None:
            return self.forward(im)
        cache_path = self.cache_directory.joinpath(cache_id + ".torch")
        if load_cache and cache_path.is_file():
            try:
                return self.load_from_cache(cache_path)
            except Exception as e:
                # A broken cache entry can surface as many exception types
                # (lzma, pickle, EOF, missing keys), hence the broad catch.
                # The entry is only a cache: report the cause and fall
                # through to recompute and overwrite it instead of
                # terminating the whole process.
                logger.warn(
                    f"The cache file was corrupted: {cache_path} ({e!r}). "
                    "Recomputing detections."
                )
        detections = self.forward(im)
        self.save_to_cache(detections, cache_path)
        return detections