File size: 8,052 Bytes
482ab8a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
import json
import os
import random
import signal

import albumentations as A
import cv2
import h5py
import numpy as np
import torch
import torchvision.transforms as T
from albumentations.pytorch.functional import img_to_tensor, mask_to_tensor
from skimage import segmentation
from termcolor import cprint
from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD


class ImageDataset(torch.utils.data.Dataset):
    """Image manipulation-detection dataset.

    Loads RGB images (and optional binary manipulation masks) listed in a
    JSON datalist, optionally served from a flattened HDF5 database, and
    optionally computes a cached SLIC superpixel map per image.
    """

    def __init__(
        self,
        dataset_name: str,
        datalist: str,
        mode: str,
        transform=None,
        uncorrect_label=False,
        spixel: bool = False,
        num_spixel: int = 100,
    ):
        """
        Args:
            dataset_name: name used to locate ``data/<name>_dataset.hdf5``.
            datalist: path to a JSON file mapping image ids to metadata;
                each entry needs "subset", "path", "label" and may have "mask".
            mode: "train" or "val"; entries are filtered by their "subset".
            transform: optional albumentations transform applied to image+mask.
            uncorrect_label: if True, keep the datalist label even when the
                transformed mask ends up empty (e.g. cropped out).
            spixel: if True, compute and cache a SLIC superpixel map per image.
            num_spixel: target number of SLIC segments.
        """
        super().__init__()

        assert os.path.exists(datalist), f"{datalist} does not exist"
        assert mode in ["train", "val"], f"{mode} unsupported mode"

        with open(datalist, "r") as f:
            self.datalist = json.load(f)

        # Keep only the entries belonging to the requested split.
        self.datalist = dict(
            filter(lambda x: x[1]["subset"] == mode, self.datalist.items())
        )
        if len(self.datalist) == 0:
            raise NotImplementedError(f"no item in {datalist} {mode} dataset")
        self.video_id_list = list(self.datalist.keys())
        self.transform = transform
        self.uncorrect_label = uncorrect_label

        self.dataset_name = dataset_name
        h5_path = os.path.join("data", dataset_name + "_dataset.hdf5")
        self.use_h5 = os.path.exists(h5_path)
        if self.use_h5:
            cprint(
                f"{dataset_name} {mode} HDF5 database found, loading into memory...",
                "blue",
            )
            try:
                # driver="core" maps the whole file into memory; guard with a
                # timeout so a huge file cannot stall startup indefinitely.
                with timeout(seconds=60):
                    self.database = h5py.File(h5_path, "r", driver="core")
            except Exception as e:
                # Fall back to regular on-disk access if the in-memory load
                # failed or timed out.
                self.database = h5py.File(h5_path, "r")
                cprint(
                    "Failed to load {} HDF5 database to memory due to {}".format(
                        dataset_name, str(e)
                    ),
                    "red",
                )
        else:
            cprint(
                f"{dataset_name} {mode} HDF5 database not found, using raw images.",
                "blue",
            )

        self.spixel = False
        self.num_spixel = num_spixel
        if spixel:
            self.spixel = True
            # Cache of image_id -> SLIC map so each image is segmented once.
            self.spixel_dict = {}

    def __getitem__(self, index):
        """Return a sample dict: normalized image, float label, mask, id, and
        the unnormalized image (plus "spixel" when superpixels are enabled)."""
        image_id = self.video_id_list[index]
        info = self.datalist[image_id]
        label = float(info["label"])
        if self.use_h5:
            try:
                # HDF5 keys are the file paths with "/" flattened to "-".
                image = self.database[info["path"].replace("/", "-")][()]
            except Exception as e:
                cprint(
                    "Failed to load {} from {} due to {}".format(
                        image_id, self.dataset_name, str(e)
                    ),
                    "red",
                )
                image = cv2.imread(info["path"])
                image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        else:
            assert os.path.exists(info["path"]), f"{info['path']} does not exist!"
            image = cv2.imread(info["path"])
            image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        if self.spixel:
            if image_id not in self.spixel_dict:
                self.spixel_dict[image_id] = segmentation.slic(
                    image, n_segments=self.num_spixel, channel_axis=2, start_label=0
                )
            # BUGFIX: always bind the (possibly cached) map locally; previously
            # a cache hit with transform=None left `spixel` undefined below.
            spixel = self.spixel_dict[image_id]

        image_size = image.shape[:2]

        # 1 means modified area, 0 means pristine
        if "mask" in info:
            if self.use_h5:
                try:
                    mask = self.database[info["mask"].replace("/", "-")][()]
                except Exception as e:
                    cprint(
                        "Failed to load {} mask from {} due to {}".format(
                            image_id, self.dataset_name, str(e)
                        ),
                        "red",
                    )
                    mask = cv2.imread(info["mask"], cv2.IMREAD_GRAYSCALE)
            else:
                mask = cv2.imread(info["mask"], cv2.IMREAD_GRAYSCALE)
        else:
            # No mask provided: pristine images get an all-zero mask,
            # manipulated ones an all-one mask.
            if label == 0:
                mask = np.zeros(image_size)
            else:
                mask = np.ones(image_size)

        if self.transform is not None:
            if self.spixel:
                transformed = self.transform(
                    image=image, masks=[mask, spixel]
                )  # TODO I am not sure if this is correct for scaling
                mask = transformed["masks"][0]
                spixel = transformed["masks"][1]
            else:
                transformed = self.transform(image=image, mask=mask)
                mask = transformed["mask"]

            image = transformed["image"]
            if not self.uncorrect_label:
                # A crop may have removed every manipulated pixel; refresh the
                # label from the transformed mask.
                label = float(mask.max() != 0.0)

        # Some sources store masks at a different resolution than the image.
        if label == 1.0 and image.shape[:-1] != mask.shape:
            mask = cv2.resize(mask, dsize=(image.shape[1], image.shape[0]))

        unnormalized_image = img_to_tensor(image)
        image = img_to_tensor(
            image,
            normalize={"mean": IMAGENET_DEFAULT_MEAN, "std": IMAGENET_DEFAULT_STD},
        )
        mask = mask_to_tensor(mask, num_classes=1, sigmoid=True)

        output = {
            "image": image,  # tensor of 3, H, W
            "label": label,  # float
            "mask": mask,  # tensor of 1, H, W
            "id": image_id,  # string
            "unnormalized_image": unnormalized_image,  # tensor of 3, H, W
        }
        if self.spixel:
            output["spixel"] = torch.from_numpy(spixel).unsqueeze(0)
        return output

    def __len__(self):
        """Number of samples in the selected split."""
        return len(self.video_id_list)


def crop_to_smallest_collate_fn(batch, max_size=128, uncorrect_label=False):
    """Collate a batch of variable-size samples by random-cropping every
    spatial tensor to the smallest H/W present in the batch (capped at
    ``max_size``), then stacking.

    Args:
        batch: list of sample dicts as produced by ``ImageDataset``.
        max_size: upper bound on the crop height and width.
        uncorrect_label: if True, keep each sample's label even when its
            cropped mask ends up empty.

    Returns:
        dict with stacked "image"/"mask"/"unnormalized_image" (and "spixel"
        if present), a float tensor "label", and lists for everything else.
    """
    # Find the smallest H and W over the batch, starting from the cap.
    smallest_size = [max_size, max_size]
    for item in batch:
        if item["mask"].shape[-2:] != item["image"].shape[-2:]:
            # BUGFIX: continuation strings were missing the f prefix, so the
            # shapes were printed as literal "{...}" placeholders.
            cprint(
                f"{item['id']} has inconsistent image-mask sizes,"
                f"with image size {item['image'].shape[-2:]} and mask size"
                f"{item['mask'].shape[-2:]}!",
                "red",
            )
        image_size = item["image"].shape[-2:]
        if image_size[0] < smallest_size[0]:
            smallest_size[0] = image_size[0]
        if image_size[1] < smallest_size[1]:
            smallest_size[1] = image_size[1]

    # Crop all spatial tensors in each item to the smallest size with a
    # random offset, then bucket every field into lists.
    result = {}
    for item in batch:
        image_size = item["image"].shape[-2:]
        x1 = random.randint(0, image_size[1] - smallest_size[1])
        y1 = random.randint(0, image_size[0] - smallest_size[0])
        x2 = x1 + smallest_size[1]
        y2 = y1 + smallest_size[0]
        for k in ["image", "mask", "unnormalized_image", "spixel"]:
            if k not in item:
                continue
            item[k] = item[k][:, y1:y2, x1:x2]
        # The crop may have removed every manipulated pixel; recompute the
        # label once per item (was needlessly repeated per key before).
        if not uncorrect_label:
            item["label"] = float(item["mask"].max() != 0.0)
        for k, v in item.items():
            if k in result:
                result[k].append(v)
            else:
                result[k] = [v]

    # Stack tensor fields; labels become one float tensor.
    for k, v in result.items():
        if k in ["image", "mask", "unnormalized_image", "spixel"]:
            result[k] = torch.stack(v, dim=0)
        elif k in ["label"]:
            result[k] = torch.tensor(v).float()

    return result


class timeout:
    """Context manager that raises TimeoutError if its body runs longer than
    ``seconds``. Unix-only: uses SIGALRM, so the resolution is whole seconds
    and it must be used from the main thread.
    """

    def __init__(self, seconds=1, error_message="Timeout"):
        self.seconds = seconds
        self.error_message = error_message

    def handle_timeout(self, signum, frame):
        """SIGALRM handler: abort the guarded block."""
        raise TimeoutError(self.error_message)

    def __enter__(self):
        # BUGFIX: remember the previous SIGALRM handler so exiting the block
        # no longer clobbers other users of the signal.
        self._old_handler = signal.signal(signal.SIGALRM, self.handle_timeout)
        signal.alarm(self.seconds)
        return self

    def __exit__(self, exc_type, exc_value, exc_traceback):
        # Cancel any pending alarm and restore the prior handler.
        signal.alarm(0)
        signal.signal(signal.SIGALRM, self._old_handler)