File size: 5,084 Bytes
404d2af
 
 
 
 
 
 
 
 
 
 
 
8b973ee
 
 
 
404d2af
 
8b973ee
404d2af
8b973ee
 
404d2af
 
 
8b973ee
 
 
 
 
404d2af
8b973ee
404d2af
 
 
 
 
8b973ee
 
404d2af
 
8b973ee
 
404d2af
8b973ee
 
 
404d2af
 
8b973ee
 
 
404d2af
8b973ee
 
404d2af
 
 
 
8b973ee
404d2af
8b973ee
404d2af
 
 
 
8b973ee
404d2af
 
 
 
 
 
 
 
8b973ee
 
404d2af
8b973ee
404d2af
 
 
 
 
8b973ee
 
 
 
 
 
 
 
 
404d2af
 
 
 
 
 
8b973ee
404d2af
 
 
 
 
 
8b973ee
404d2af
 
 
 
 
 
8b973ee
404d2af
 
 
 
 
 
 
8b973ee
404d2af
 
 
 
8b973ee
 
404d2af
 
8b973ee
404d2af
8b973ee
 
 
404d2af
 
8b973ee
404d2af
8b973ee
404d2af
 
8b973ee
404d2af
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
# Copyright 2019-present NAVER Corp.
# CC BY-NC-SA 3.0
# Available only for non-commercial use

import os, pdb
import numpy as np
from PIL import Image

from .dataset import Dataset
from .pair_dataset import PairDataset, StillPairDataset


class AachenImages(Dataset):
    """Loads all images from the Aachen Day-Night dataset.

    Images are searched under ``<root>/images_upright``. A directory is kept
    when at least one component of its path (relative to that folder) is one
    of the whitespace-separated words in ``select``; only files ending in
    ``.jpg`` are collected.

    Args:
        select: whitespace-separated directory names to keep
            (default ``"db day night"``).
        root: dataset root folder.

    Attributes:
        imgs: image paths relative to ``<root>/images_upright``.
        nimg: number of images found.
    """

    def __init__(self, select="db day night", root="data/aachen"):
        Dataset.__init__(self)
        self.root = root
        self.img_dir = "images_upright"
        self.select = set(select.split())
        assert self.select, "Nothing was selected"

        self.imgs = []
        img_root = os.path.join(root, self.img_dir)
        for dirpath, dirnames, filenames in os.walk(img_root):
            # Visit sub-directories in sorted order (os.walk honours in-place
            # edits of dirnames) so image indices do not depend on the
            # filesystem's listing order.
            dirnames.sort()
            rel_dir = dirpath[len(img_root) + 1 :]
            # os.walk joins sub-directories with os.sep, so split on it rather
            # than a hard-coded "/" (which breaks nested selections on Windows).
            if not (self.select & set(rel_dir.split(os.sep))):
                continue
            self.imgs += [
                os.path.join(rel_dir, f) for f in sorted(filenames) if f.endswith(".jpg")
            ]

        self.nimg = len(self.imgs)
        assert self.nimg, "Empty Aachen dataset"

    def get_key(self, idx):
        """Return the path of image ``idx`` relative to ``<root>/images_upright``."""
        return self.imgs[idx]


class AachenImages_DB(AachenImages):
    """Aachen images restricted to the database ("db") folder.

    On top of the parent's image list, ``db_image_idxs`` maps every image tag
    (see :meth:`get_tag`) to the index of that image in ``self.imgs``.
    """

    def __init__(self, **kw):
        AachenImages.__init__(self, select="db", **kw)
        self.db_image_idxs = dict(
            (self.get_tag(img_idx), img_idx) for img_idx in range(len(self.imgs))
        )

    def get_tag(self, idx):
        """Return the tag of image ``idx``.

        The tag is the base name of the image path with its last four
        characters (the ``.jpg`` suffix) removed.
        """
        stem = self.imgs[idx][:-4]
        _, tag = os.path.split(stem)
        return tag


class AachenPairs_StyleTransferDayNight(AachenImages_DB, StillPairDataset):
    """Synthetic day-night pairs of images.

    Each pair couples an Aachen database image with a night version of it,
    obtained using automatic style transfer from web night images.

    Args:
        root: folder holding the style-transferred images. A file is used when
            its name contains ``.jpg.st_``; the part before that marker must be
            the tag of a database image.
        **kw: forwarded to :class:`AachenImages_DB`.
    """

    def __init__(self, root="data/aachen/style_transfer", **kw):
        StillPairDataset.__init__(self)
        AachenImages_DB.__init__(self, **kw)
        old_root = os.path.join(self.root, self.img_dir)

        # Re-root every image on a folder shared by the database images and the
        # style-transferred ones. os.path.commonprefix compares characters, not
        # path components, so it can stop in the middle of a folder name
        # (e.g. "data/aachen/images_" for ".../images_upright" vs
        # ".../images_st"); dirname() trims it back to a directory boundary.
        self.root = os.path.dirname(os.path.commonprefix((old_root, root)))
        self.img_dir = ""

        def newpath(folder, f):
            # path of folder/f relative to self.root, without a leading
            # separator (otherwise os.path.join would treat it as absolute)
            return os.path.join(folder, f)[len(self.root) :].lstrip("/" + os.sep)

        self.imgs = [newpath(old_root, f) for f in self.imgs]

        self.image_pairs = []
        # sorted() makes the pair order independent of the filesystem
        for fname in sorted(os.listdir(root)):
            if ".jpg.st_" not in fname:
                continue  # not a style-transferred image (stray file/folder)
            tag = fname.split(".jpg.st_")[0]
            self.image_pairs.append((self.db_image_idxs[tag], len(self.imgs)))
            self.imgs.append(newpath(root, fname))

        self.nimg = len(self.imgs)
        self.npairs = len(self.image_pairs)
        assert self.nimg and self.npairs


class AachenPairs_OpticalFlow(AachenImages_DB, PairDataset):
    """Image pairs from the Aachen database with optical flow.

    ``<root>/flow`` and ``<root>/mask`` must contain the same set of PNG files,
    each named ``<tag_a>_<tag_b>.png`` where both tags are database image tags.
    """

    def __init__(self, root="data/aachen/optical_flow", **kw):
        PairDataset.__init__(self)
        AachenImages_DB.__init__(self, **kw)
        self.root_flow = root

        # find out the subset of valid pairs from the list of flow files
        flows = {
            f for f in os.listdir(os.path.join(root, "flow")) if f.endswith(".png")
        }
        masks = {
            f for f in os.listdir(os.path.join(root, "mask")) if f.endswith(".png")
        }
        assert flows == masks, "Missing flow or mask pairs"

        def make_pair(f):
            # "<tag_a>_<tag_b>.png" -> (index of tag_a, index of tag_b)
            return tuple(self.db_image_idxs[v] for v in f[:-4].split("_"))

        # Iterate in sorted order: iteration order of a set of strings depends
        # on per-process hash randomization, so pair indices would otherwise
        # change from run to run.
        self.image_pairs = [make_pair(f) for f in sorted(flows)]
        self.npairs = len(self.image_pairs)
        assert self.nimg and self.npairs

    def get_mask_filename(self, pair_idx):
        """Return the path of the mask PNG for pair ``pair_idx``."""
        tag_a, tag_b = map(self.get_tag, self.image_pairs[pair_idx])
        return os.path.join(self.root_flow, "mask", f"{tag_a}_{tag_b}.png")

    def get_mask(self, pair_idx):
        """Load the mask of pair ``pair_idx`` as a NumPy array."""
        return np.asarray(Image.open(self.get_mask_filename(pair_idx)))

    def get_flow_filename(self, pair_idx):
        """Return the path of the flow PNG for pair ``pair_idx``."""
        tag_a, tag_b = map(self.get_tag, self.image_pairs[pair_idx])
        return os.path.join(self.root_flow, "flow", f"{tag_a}_{tag_b}.png")

    def get_flow(self, pair_idx):
        """Return the flow field of pair ``pair_idx``.

        The PNG-encoded flow is loaded with ``_png2flow``. If that raises
        IOError, the raw flow file stored at the same path without its ".png"
        suffix is read instead. It is then passed to ``_flow2png`` (presumably
        to cache it as PNG; confirm in PairDataset) whose result is returned.

        Raw file layout (as parsed below, matching the Middlebury ``.flo``
        format): float32 magic 202021.25, int32 width, int32 height, then
        height*width*2 float32 values, reshaped to (H, W, 2).

        Raises:
            IOError: if neither file can be read, or the raw file has a bad
                magic number.
        """
        fname = self.get_flow_filename(pair_idx)
        try:
            return self._png2flow(fname)
        except IOError:
            raw_fname = fname[:-4]
            # `with` guarantees the raw file is closed (it previously leaked)
            with open(raw_fname, "rb") as fh:
                magic = np.fromfile(fh, np.float32, 1)
                # explicit size check: an empty/truncated file yields an empty
                # array, whose truth value is ambiguous
                if magic.size != 1 or magic[0] != 202021.25:
                    raise IOError(f"Invalid flow file (bad magic number): {raw_fname}")
                W, H = np.fromfile(fh, np.int32, 2)
                flow = np.fromfile(fh, np.float32).reshape((H, W, 2))
            return self._flow2png(flow, fname)

    def get_pair(self, idx, output=()):
        """Return ``(img1, img2, meta)`` for pair ``idx``.

        Args:
            idx: pair index.
            output: iterable of names, or a whitespace-separated string, among
                "flow", "aflow" and "mask".

        ``meta`` gets "flow" and "aflow" when either of those is requested.
        "aflow" is the flow plus each pixel's own (x, y) coordinates. ``meta``
        gets "mask" when "mask" is requested.
        """
        if isinstance(output, str):
            output = output.split()

        img1, img2 = map(self.get_image, self.image_pairs[idx])
        meta = {}

        if "flow" in output or "aflow" in output:
            flow = self.get_flow(idx)
            assert flow.shape[:2] == img1.size[::-1]
            meta["flow"] = flow
            H, W = flow.shape[:2]
            # mgrid gives (row, col) grids; reversed -> (x, y), moved to last axis
            meta["aflow"] = flow + np.mgrid[:H, :W][::-1].transpose(1, 2, 0)

        if "mask" in output:
            mask = self.get_mask(idx)
            assert mask.shape[:2] == img1.size[::-1]
            meta["mask"] = mask

        return img1, img2, meta


if __name__ == "__main__":
    # Build one instance of each dataset from the default roots under
    # data/aachen, print them and drop into the debugger for inspection.
    # The names used to be printed without ever being defined (NameError).
    # The module uses relative imports, so run it as part of its package
    # (python -m <package>.aachen).
    aachen_db_images = AachenImages_DB()
    aachen_style_transfer_pairs = AachenPairs_StyleTransferDayNight()
    aachen_flow_pairs = AachenPairs_OpticalFlow()
    print(aachen_db_images)
    print(aachen_style_transfer_pairs)
    print(aachen_flow_pairs)
    pdb.set_trace()