import pickle as pkl
from pathlib import Path

import numpy as np
import numpy.typing as npt
from PIL import Image
from PIL.Image import Image as ImageType


def build_data(data_path: Path) -> dict:
    """Build a fresh index of the images found directly inside *data_path*.

    Files whose extension is ``.png``, ``.jpg`` or ``.jpeg`` (matched
    case-insensitively) are indexed by their stem. If several files share a
    stem, ``.jpeg`` overwrites ``.jpg``, which overwrites ``.png``.

    Args:
        data_path: Directory to scan (not recursive). A missing path or a
            non-directory yields an empty index.

    Returns:
        Mapping ``stem -> {"image", "labels", "emb", "meta_data"}`` where
        ``image`` is the file's path, ``labels`` a new empty list and the
        remaining fields ``None``.
    """
    suffix_rank = {".png": 0, ".jpg": 1, ".jpeg": 2}
    data_path = Path(data_path)
    if not data_path.is_dir():
        return {}

    image_paths = [
        path
        for path in data_path.iterdir()
        if path.is_file() and path.suffix.lower() in suffix_rank
    ]
    # Stable sort by extension so that, on a stem collision, the later
    # extension overwrites the earlier entry (png -> jpg -> jpeg).
    image_paths.sort(key=lambda path: suffix_rank[path.suffix.lower()])

    data = {}
    for image_path in image_paths:
        data[image_path.stem] = {
            "image": image_path,
            "labels": [],
            "emb": None,
            "meta_data": None,
        }
    return data


class Data:
    """Pickle-backed store of per-image labels, embeddings and metadata.

    The index is a dict ``image name -> entry dict`` pickled at
    ``data_path``. Images, masks and embeddings are written as separate files
    next to that pickle; the index only keeps their paths. Every mutating
    method rewrites the pickle.

    NOTE(review): ``pickle.load`` can execute arbitrary code embedded in the
    file -- only open index files produced by this class on trusted storage.
    """

    def __init__(self, data_path: Path):
        """Open the index at *data_path*, creating an empty one if missing."""
        # Normalise to Path: the save_* methods rely on ``.parent`` / ``/``.
        self.data_path = Path(data_path)
        if self.data_path.exists():
            with open(self.data_path, "rb") as f:
                self.data = pkl.load(f)
        else:
            self.data_path.parent.mkdir(parents=True, exist_ok=True)
            self.data = {}
            self._save_data()

    def _save_data(self) -> None:
        """Persist the in-memory index to ``data_path``."""
        # Dump to a sibling temp file and swap it in, so a crash mid-write
        # cannot leave a truncated (unloadable) index behind.
        tmp_path = self.data_path.with_name(self.data_path.name + ".tmp")
        with open(tmp_path, "wb") as f:
            pkl.dump(self.data, f)
        tmp_path.replace(self.data_path)

    def __contains__(self, image: str) -> bool:
        """Return True if *image* has an entry in the index."""
        return image in self.data

    def emb_exists(self, image: str) -> bool:
        """Return True if a single embedding path is recorded for *image*.

        Raises:
            KeyError: if *image* is not in the index.
        """
        return "emb" in self.data[image] and self.data[image]["emb"] is not None

    def _reset_labels(self, image: str) -> None:
        """Delete *image*'s mask files and empty its label fields (no save)."""
        entry = self.data[image]
        for label_path in entry.get("masks", []):
            Path(label_path).unlink(missing_ok=True)
        # Reset every label-related field that exists, not just "labels":
        # leaving the deleted mask paths behind would make get_labels try to
        # open files that no longer exist.
        for key in ("masks", "labels", "bboxes"):
            if key in entry:
                entry[key] = []

    def save_labels(
        self,
        image: str,
        masks: list[ImageType],
        bboxes: list[tuple[int, ...]],
        labels: list[str],
    ) -> None:
        """Replace the labels of *image* with new masks, boxes and names.

        Previously saved mask files for *image* are deleted. Mask ``i`` is
        written as ``<image>.<label>.<i>.png`` next to the index file.

        Raises:
            KeyError: if *image* is not in the index.
            ValueError: if ``masks`` and ``labels`` differ in length.
        """
        # Validate before deleting anything: zip() would silently truncate
        # and leave "masks" and "labels" misaligned in the index.
        if len(masks) != len(labels):
            raise ValueError(
                f"masks and labels must have the same length, "
                f"got {len(masks)} and {len(labels)}"
            )
        self._reset_labels(image)
        label_paths = []
        for i, (mask, label) in enumerate(zip(masks, labels)):
            label_path = self.data_path.parent / f"{image}.{label}.{i}.png"
            mask.save(label_path)
            label_paths.append(str(label_path))
        entry = self.data[image]
        entry["masks"] = label_paths
        entry["labels"] = labels
        entry["bboxes"] = bboxes
        self._save_data()

    def save_meta_data(self, image: str, meta_data: dict) -> None:
        """Store *meta_data* for *image* (KeyError if *image* is unknown)."""
        self.data[image]["meta_data"] = meta_data
        self._save_data()

    def save_emb(self, image: str, emb: npt.NDArray) -> None:
        """Save *emb* to ``<image>.emb.npy`` and record its path."""
        emb_path = self.data_path.parent / f"{image}.emb.npy"
        np.save(emb_path, emb)
        self.data[image]["emb"] = emb_path
        self._save_data()

    def save_hq_emb(self, image: str, embs: list[npt.NDArray]) -> None:
        """Save each array in *embs* to ``<image>.emb.<i>.npy``.

        Paths are recorded under the keys ``"emb.0"``, ``"emb.1"``, ...
        Entries (and files) beyond ``len(embs)`` left by an earlier call are
        removed.
        """
        entry = self.data[image]
        for i, emb in enumerate(embs):
            emb_path = self.data_path.parent / f"{image}.emb.{i}.npy"
            np.save(emb_path, emb)
            entry[f"emb.{i}"] = emb_path
        # get_hq_emb reads consecutive "emb.<i>" keys, so stale entries from a
        # previous, longer list would otherwise be returned as well.
        i = len(embs)
        while f"emb.{i}" in entry:
            Path(entry.pop(f"emb.{i}")).unlink(missing_ok=True)
            i += 1
        self._save_data()

    def save_image(self, image: str, image_pil: ImageType) -> None:
        """Save *image_pil* as ``<image>.png`` and start a fresh entry.

        Any existing entry for *image* (labels, embeddings, metadata) is
        replaced by one containing only the image path; old files are not
        deleted.
        """
        image_path = self.data_path.parent / f"{image}.png"
        image_pil.save(image_path)
        self.data[image] = {}
        self.data[image]["image"] = image_path
        self._save_data()

    def clear_labels(self, image: str) -> None:
        """Delete *image*'s mask files and empty its labels, then persist."""
        self._reset_labels(image)
        self._save_data()

    def get_all_images(self) -> list:
        """Return the names of all indexed images."""
        return list(self.data.keys())

    def get_image(self, image: str) -> ImageType:
        """Open and return the stored image for *image*."""
        return Image.open(self.data[image]["image"])

    def get_emb(self, image: str) -> npt.NDArray:
        """Load the single embedding saved by ``save_emb``."""
        return np.load(self.data[image]["emb"])

    def get_hq_emb(self, image: str) -> list[npt.NDArray]:
        """Load the embeddings saved by ``save_hq_emb``, in index order."""
        entry = self.data[image]
        embs = []
        i = 0
        while f"emb.{i}" in entry:
            embs.append(np.load(entry[f"emb.{i}"]))
            i += 1
        return embs

    def get_labels(
        self, image: str
    ) -> tuple[list[ImageType], list[tuple[int, ...]], list[str]]:
        """Return ``(masks, bboxes, labels)`` for *image*.

        Returns three empty lists if any of the three fields was never saved.
        """
        entry = self.data[image]
        if "masks" not in entry or "labels" not in entry or "bboxes" not in entry:
            return [], [], []
        return (
            [Image.open(mask) for mask in entry["masks"]],
            [tuple(e) for e in entry["bboxes"]],
            entry["labels"],
        )

    def get_meta_data(self, image: str) -> "dict | None":
        """Return the metadata for *image*, or None if none was saved.

        Entries created by ``save_image`` have no "meta_data" key; they
        report None, matching the default used by ``build_data``.
        """
        return self.data[image].get("meta_data")