File size: 4,331 Bytes
6723494
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
import pickle as pkl
import numpy as np
import numpy.typing as npt

from PIL import Image
from PIL.Image import Image as ImageType
from pathlib import Path


def build_data(data_path: Path) -> dict:
    """Create a fresh record for every PNG/JPEG file directly inside *data_path*.

    Extensions are matched case-insensitively (``.png``, ``.jpg``, ``.jpeg``),
    so files such as ``photo.JPG`` are found on case-sensitive filesystems too.
    Only regular files directly in the directory are considered; subdirectories
    are not searched. A directory that does not exist yields ``{}``.

    Args:
        data_path: Directory to scan. A ``str`` is also accepted.

    Returns:
        Mapping of file stem -> ``{"image": Path, "labels": [], "emb": None,
        "meta_data": None}``. When several files share a stem, ``.jpeg`` wins
        over ``.jpg``, which wins over ``.png``.
    """
    suffix_rank = {".png": 0, ".jpg": 1, ".jpeg": 2}
    image_paths = sorted(
        (
            path
            for path in Path(data_path).glob("*")
            if path.suffix.lower() in suffix_rank and path.is_file()
        ),
        # The dict below keeps the last path seen per stem, so this ordering
        # decides which file wins a stem collision (png, then jpg, then jpeg),
        # and sorting by name makes the result independent of directory order.
        key=lambda path: (suffix_rank[path.suffix.lower()], path.name),
    )
    return {
        path.stem: {"image": path, "labels": [], "emb": None, "meta_data": None}
        for path in image_paths
    }


class Data:
    """Pickle-backed store of per-image records.

    The index is a single pickled ``dict`` at ``data_path`` that maps an image
    name to a record dict (keys such as ``"image"``, ``"labels"``, ``"masks"``,
    ``"bboxes"``, ``"emb"``, ``"emb.<i>"``, ``"meta_data"``). Side files (mask
    PNGs, ``.npy`` embeddings, saved images) are written next to the index in
    ``data_path.parent``. Every mutating method rewrites the whole index.
    """

    def __init__(self, data_path: Path):
        """Open the index at *data_path*, creating an empty one if it is missing.

        Args:
            data_path: Location of the pickle index. A ``str`` is also accepted.
        """
        # Coerce once so the ``.parent`` lookups here and in the other methods
        # also work when a plain string path is passed in.
        self.data_path = Path(data_path)
        if self.data_path.exists():
            # NOTE(security): pickle.load can execute arbitrary code embedded in
            # the file; only open index files this application wrote itself.
            with open(self.data_path, "rb") as f:
                self.data = pkl.load(f)
        else:
            self.data = {}
            self.data_path.parent.mkdir(parents=True, exist_ok=True)
            self._save_data()

    def _save_data(self) -> None:
        """Persist ``self.data`` to ``data_path``.

        The pickle is written to a sibling ``.tmp`` file that is then renamed
        over the index, so a crash mid-dump leaves the previous index intact
        instead of a truncated one.
        """
        tmp_path = self.data_path.with_name(self.data_path.name + ".tmp")
        with open(tmp_path, "wb") as f:
            pkl.dump(self.data, f)
        tmp_path.replace(self.data_path)

    def __contains__(self, image: str) -> bool:
        """Return True if *image* has a record in the index."""
        return image in self.data

    def emb_exists(self, image: str) -> bool:
        """Return True if *image* has a non-None ``"emb"`` entry.

        Raises:
            KeyError: if *image* has no record.
        """
        return self.data[image].get("emb") is not None

    def save_labels(
        self, image: str, masks: list[ImageType], bboxes: list[tuple[int, ...]], labels: list[str]
    ) -> None:
        """Replace the labels of *image* with new masks, boxes and label names.

        Existing mask files are deleted first. Each mask is saved as
        ``<image>.<label>.<i>.png`` next to the index. Masks and labels are
        paired with ``zip``, so any surplus in the longer list gets no mask file.

        Raises:
            KeyError: if *image* has no record.
        """
        self.clear_labels(image)
        # NOTE(review): ``label`` is used verbatim in a file name; a label that
        # contains a path separator would point outside data_path.parent.
        label_paths = []
        for i, (mask, label) in enumerate(zip(masks, labels)):
            label_path = self.data_path.parent / f"{image}.{label}.{i}.png"
            mask.save(label_path)
            label_paths.append(str(label_path))
        entry = self.data[image]
        entry["masks"] = label_paths
        entry["labels"] = labels
        entry["bboxes"] = bboxes
        self._save_data()

    def save_meta_data(self, image: str, meta_data: dict) -> None:
        """Store *meta_data* on the record of *image* and persist the index."""
        self.data[image]["meta_data"] = meta_data
        self._save_data()

    def save_emb(self, image: str, emb: npt.NDArray) -> None:
        """Save *emb* to ``<image>.emb.npy`` and record its path under ``"emb"``."""
        emb_path = self.data_path.parent / f"{image}.emb.npy"
        np.save(emb_path, emb)
        self.data[image]["emb"] = emb_path
        self._save_data()

    def save_hq_emb(self, image: str, embs: list[npt.NDArray]) -> None:
        """Save each array in *embs* to ``<image>.emb.<i>.npy`` under key ``"emb.<i>"``.

        Any ``"emb.<i>"`` entries left over from an earlier, longer list are
        removed (record key and file).
        """
        entry = self.data[image]
        for i, emb in enumerate(embs):
            emb_path = self.data_path.parent / f"{image}.emb.{i}.npy"
            np.save(emb_path, emb)
            entry[f"emb.{i}"] = emb_path
        # get_hq_emb reads emb.0, emb.1, ... until the first gap, so stale
        # higher indices would otherwise be returned alongside the new arrays.
        stale = len(embs)
        while f"emb.{stale}" in entry:
            Path(entry.pop(f"emb.{stale}")).unlink(missing_ok=True)
            stale += 1
        self._save_data()

    def save_image(self, image: str, image_pil: ImageType) -> None:
        """Save *image_pil* as ``<image>.png`` and start a fresh record for it.

        Any existing record for *image* is replaced (its side files are left on
        disk). The new record has the same keys that ``build_data`` creates.
        """
        image_path = self.data_path.parent / f"{image}.png"
        image_pil.save(image_path)
        self.data[image] = {
            "image": image_path,
            "labels": [],
            "emb": None,
            "meta_data": None,
        }
        self._save_data()

    def clear_labels(self, image: str) -> None:
        """Delete the mask files of *image* and empty its label lists.

        Raises:
            KeyError: if *image* has no record.
        """
        entry = self.data[image]
        for mask_path in entry.get("masks", []):
            Path(mask_path).unlink(missing_ok=True)
        # Empty masks/bboxes too: leaving "masks" naming the files just deleted
        # would make get_labels() fail when it tries to open them.
        for key in ("masks", "bboxes", "labels"):
            if key in entry:
                entry[key] = []
        self._save_data()

    def get_all_images(self) -> list:
        """Return the names of all images in the index."""
        return list(self.data)

    def get_image(self, image: str) -> ImageType:
        """Open the stored image file of *image* with ``PIL.Image.open``."""
        return Image.open(self.data[image]["image"])

    def get_emb(self, image: str) -> npt.NDArray:
        """Load the array saved by ``save_emb`` for *image*."""
        return np.load(self.data[image]["emb"])

    def get_hq_emb(self, image: str) -> list[npt.NDArray]:
        """Load ``emb.0``, ``emb.1``, ... for *image* up to the first missing index."""
        entry = self.data[image]
        embs = []
        i = 0
        while f"emb.{i}" in entry:
            embs.append(np.load(entry[f"emb.{i}"]))
            i += 1
        return embs

    def get_labels(
        self, image: str
    ) -> tuple[list[ImageType], list[tuple[int, ...]], list[str]]:
        """Return ``(masks, bboxes, labels)`` for *image*.

        Returns three empty lists when any of ``"masks"``, ``"labels"`` or
        ``"bboxes"`` is absent from the record.
        """
        entry = self.data[image]
        if "masks" not in entry or "labels" not in entry or "bboxes" not in entry:
            return [], [], []
        return (
            [Image.open(mask) for mask in entry["masks"]],
            [tuple(e) for e in entry["bboxes"]],
            entry["labels"],
        )

    def get_meta_data(self, image: str) -> "dict | None":
        """Return the stored meta data of *image*, or None if none was saved."""
        # .get(): records written by older versions of save_image lack the key.
        return self.data[image].get("meta_data")