auto-labeler / app /data.py
dillonlaird's picture
initial commit
6723494
raw
history blame
No virus
4.33 kB
import pickle as pkl
import numpy as np
import numpy.typing as npt
from PIL import Image
from PIL.Image import Image as ImageType
from pathlib import Path
def build_data(data_path: Path) -> dict:
    """Build an initial record for every image file directly inside ``data_path``.

    Files whose names end in ``.png``, ``.jpg`` or ``.jpeg`` are picked up.
    The match is case-insensitive, so ``IMG.JPG`` is found as well.

    Args:
        data_path: Directory to scan (not recursive). A directory that does
            not exist yields an empty result.

    Returns:
        A dict keyed by the file stem. Each value is a fresh record holding
        the image path under ``"image"``, an empty ``"labels"`` list, and
        ``None`` for ``"emb"`` and ``"meta_data"``. Entries are grouped by
        extension (png, then jpg, then jpeg) and sorted by path within each
        group. When two files share a stem, the one from the later group
        replaces the earlier record.
    """
    # List the directory once and sort it so the result does not depend on
    # the order in which the filesystem happens to enumerate entries.
    candidates = sorted(data_path.glob("*"))

    data = {}
    for suffix in (".png", ".jpg", ".jpeg"):
        for image_path in candidates:
            # Compare on the lower-cased name so upper/mixed-case extensions
            # are not silently skipped on case-sensitive filesystems.
            if not image_path.name.lower().endswith(suffix):
                continue
            data[image_path.stem] = {
                "image": image_path,
                "labels": [],
                "emb": None,
                "meta_data": None,
            }
    return data
class Data:
    """Pickle-backed index of images and the artifacts saved for them.

    ``self.data`` maps an image name to a record dict. Depending on which
    methods have been called, a record may hold the keys ``"image"``,
    ``"masks"``, ``"labels"``, ``"bboxes"``, ``"emb"``, ``"emb.<i>"`` and
    ``"meta_data"``. Files written by this class (masks, embeddings, images)
    are placed in the directory that contains the pickle file, and every
    mutating method re-writes the pickle file afterwards.
    """

    def __init__(self, data_path: Path):
        """Load the index stored at ``data_path`` or create an empty one.

        Args:
            data_path: Location of the pickle file. A plain string is
                accepted as well and is converted to a ``Path``.
        """
        # Normalise once: the methods below derive sibling file paths from
        # ``self.data_path.parent``, which a plain string does not provide.
        self.data_path = Path(data_path)
        if self.data_path.exists():
            # NOTE(security): unpickling can execute arbitrary code, so this
            # must only ever be pointed at index files from a trusted source.
            with open(self.data_path, "rb") as f:
                self.data = pkl.load(f)
        else:
            self.data_path.parent.mkdir(parents=True, exist_ok=True)
            self.data = {}
            self._save_data()

    def _save_data(self) -> None:
        """Write the whole in-memory index back to the pickle file."""
        with open(self.data_path, "wb") as f:
            pkl.dump(self.data, f)

    def __contains__(self, image: str) -> bool:
        """Return whether ``image`` has a record in the index."""
        return image in self.data

    def emb_exists(self, image: str) -> bool:
        """Return whether a non-``None`` ``"emb"`` entry is recorded for ``image``.

        Raises:
            KeyError: If ``image`` has no record.
        """
        return self.data[image].get("emb") is not None

    def save_labels(
        self,
        image: str,
        masks: list[ImageType],
        bboxes: list[tuple[int, ...]],
        labels: list[str],
    ) -> None:
        """Replace the labels of ``image`` with the given masks, boxes and names.

        Existing labels are cleared first. Each mask is then saved as
        ``<image>.<label>.<index>.png`` next to the pickle file. Masks and
        labels are paired with ``zip``, so surplus entries in the longer of
        the two get no mask file.

        Raises:
            KeyError: If ``image`` has no record.
        """
        self.clear_labels(image)

        mask_paths = []
        for i, (mask, label) in enumerate(zip(masks, labels)):
            mask_path = self.data_path.parent / f"{image}.{label}.{i}.png"
            mask.save(mask_path)
            mask_paths.append(str(mask_path))

        record = self.data[image]
        record["masks"] = mask_paths
        record["labels"] = labels
        record["bboxes"] = bboxes
        self._save_data()

    def save_meta_data(self, image: str, meta_data: dict) -> None:
        """Store ``meta_data`` on the record of ``image`` and persist the index."""
        self.data[image]["meta_data"] = meta_data
        self._save_data()

    def save_emb(self, image: str, emb: npt.NDArray) -> None:
        """Save ``emb`` as ``<image>.emb.npy`` and record its path under ``"emb"``."""
        emb_path = self.data_path.parent / f"{image}.emb.npy"
        np.save(emb_path, emb)
        self.data[image]["emb"] = emb_path
        self._save_data()

    def save_hq_emb(self, image: str, embs: list[npt.NDArray]) -> None:
        """Save each array as ``<image>.emb.<i>.npy`` and record it under ``"emb.<i>"``."""
        for i, emb in enumerate(embs):
            emb_path = self.data_path.parent / f"{image}.emb.{i}.npy"
            np.save(emb_path, emb)
            self.data[image][f"emb.{i}"] = emb_path
        self._save_data()

    def save_image(self, image: str, image_pil: ImageType) -> None:
        """Save ``image_pil`` as ``<image>.png`` and start a fresh record for it.

        Any record previously stored under ``image`` is replaced by one that
        contains only the ``"image"`` path.
        """
        image_path = self.data_path.parent / f"{image}.png"
        image_pil.save(image_path)
        self.data[image] = {"image": image_path}
        self._save_data()

    def clear_labels(self, image: str) -> None:
        """Delete the mask files of ``image`` and empty its label entries.

        Raises:
            KeyError: If ``image`` has no record.
        """
        record = self.data[image]
        for mask_path in record.get("masks", []):
            Path(mask_path).unlink(missing_ok=True)
        # Reset all three entries together. Leaving "masks"/"bboxes" behind
        # would make get_labels try to open the files that were just removed.
        for key in ("masks", "labels", "bboxes"):
            if key in record:
                record[key] = []
        self._save_data()

    def get_all_images(self) -> list:
        """Return the names of all images in the index."""
        return list(self.data.keys())

    def get_image(self, image: str) -> ImageType:
        """Open and return the image file recorded for ``image``."""
        return Image.open(self.data[image]["image"])

    def get_emb(self, image: str) -> npt.NDArray:
        """Load and return the array recorded under ``"emb"`` for ``image``."""
        return np.load(self.data[image]["emb"])

    def get_hq_emb(self, image: str) -> list[npt.NDArray]:
        """Load the arrays recorded under ``"emb.0"``, ``"emb.1"``, ... in order.

        Loading stops at the first index that has no entry. The result is an
        empty list when ``"emb.0"`` is absent.
        """
        record = self.data[image]
        embs = []
        i = 0
        while f"emb.{i}" in record:
            embs.append(np.load(record[f"emb.{i}"]))
            i += 1
        return embs

    def get_labels(
        self, image: str
    ) -> tuple[list[ImageType], list[tuple[int, ...]], list[str]]:
        """Return ``(masks, bboxes, labels)`` for ``image``.

        Three empty lists are returned unless the record holds all of
        ``"masks"``, ``"labels"`` and ``"bboxes"``. Masks are opened from
        their stored paths and each bounding box is converted to a tuple.
        """
        record = self.data[image]
        if not all(key in record for key in ("masks", "labels", "bboxes")):
            return [], [], []
        return (
            [Image.open(mask) for mask in record["masks"]],
            [tuple(e) for e in record["bboxes"]],
            record["labels"],
        )

    def get_meta_data(self, image: str) -> dict:
        """Return the ``"meta_data"`` entry of ``image``.

        Raises:
            KeyError: If ``image`` has no record or no ``"meta_data"`` entry.
        """
        return self.data[image]["meta_data"]