Vincentqyw
fix: roma
8b973ee
# Copyright 2019-present NAVER Corp.
# CC BY-NC-SA 3.0
# Available only for non-commercial use
import os, pdb
import numpy as np
from PIL import Image
from .dataset import Dataset
from .pair_dataset import PairDataset, StillPairDataset
class AachenImages(Dataset):
"""Loads all images from the Aachen Day-Night dataset"""
def __init__(self, select="db day night", root="data/aachen"):
Dataset.__init__(self)
self.root = root
self.img_dir = "images_upright"
self.select = set(select.split())
assert self.select, "Nothing was selected"
self.imgs = []
root = os.path.join(root, self.img_dir)
for dirpath, _, filenames in os.walk(root):
r = dirpath[len(root) + 1 :]
if not (self.select & set(r.split("/"))):
continue
self.imgs += [os.path.join(r, f) for f in filenames if f.endswith(".jpg")]
self.nimg = len(self.imgs)
assert self.nimg, "Empty Aachen dataset"
def get_key(self, idx):
return self.imgs[idx]
class AachenImages_DB(AachenImages):
"""Only database (db) images."""
def __init__(self, **kw):
AachenImages.__init__(self, select="db", **kw)
self.db_image_idxs = {self.get_tag(i): i for i, f in enumerate(self.imgs)}
def get_tag(self, idx):
# returns image tag == img number (name)
return os.path.split(self.imgs[idx][:-4])[1]
class AachenPairs_StyleTransferDayNight(AachenImages_DB, StillPairDataset):
"""synthetic day-night pairs of images
(night images obtained using autoamtic style transfer from web night images)
"""
def __init__(self, root="data/aachen/style_transfer", **kw):
StillPairDataset.__init__(self)
AachenImages_DB.__init__(self, **kw)
old_root = os.path.join(self.root, self.img_dir)
self.root = os.path.commonprefix((old_root, root))
self.img_dir = ""
newpath = lambda folder, f: os.path.join(folder, f)[len(self.root) :]
self.imgs = [newpath(old_root, f) for f in self.imgs]
self.image_pairs = []
for fname in os.listdir(root):
tag = fname.split(".jpg.st_")[0]
self.image_pairs.append((self.db_image_idxs[tag], len(self.imgs)))
self.imgs.append(newpath(root, fname))
self.nimg = len(self.imgs)
self.npairs = len(self.image_pairs)
assert self.nimg and self.npairs
class AachenPairs_OpticalFlow(AachenImages_DB, PairDataset):
"""Image pairs from Aachen db with optical flow."""
def __init__(self, root="data/aachen/optical_flow", **kw):
PairDataset.__init__(self)
AachenImages_DB.__init__(self, **kw)
self.root_flow = root
# find out the subsest of valid pairs from the list of flow files
flows = {
f for f in os.listdir(os.path.join(root, "flow")) if f.endswith(".png")
}
masks = {
f for f in os.listdir(os.path.join(root, "mask")) if f.endswith(".png")
}
assert flows == masks, "Missing flow or mask pairs"
make_pair = lambda f: tuple(self.db_image_idxs[v] for v in f[:-4].split("_"))
self.image_pairs = [make_pair(f) for f in flows]
self.npairs = len(self.image_pairs)
assert self.nimg and self.npairs
def get_mask_filename(self, pair_idx):
tag_a, tag_b = map(self.get_tag, self.image_pairs[pair_idx])
return os.path.join(self.root_flow, "mask", f"{tag_a}_{tag_b}.png")
def get_mask(self, pair_idx):
return np.asarray(Image.open(self.get_mask_filename(pair_idx)))
def get_flow_filename(self, pair_idx):
tag_a, tag_b = map(self.get_tag, self.image_pairs[pair_idx])
return os.path.join(self.root_flow, "flow", f"{tag_a}_{tag_b}.png")
def get_flow(self, pair_idx):
fname = self.get_flow_filename(pair_idx)
try:
return self._png2flow(fname)
except IOError:
flow = open(fname[:-4], "rb")
help = np.fromfile(flow, np.float32, 1)
assert help == 202021.25
W, H = np.fromfile(flow, np.int32, 2)
flow = np.fromfile(flow, np.float32).reshape((H, W, 2))
return self._flow2png(flow, fname)
def get_pair(self, idx, output=()):
if isinstance(output, str):
output = output.split()
img1, img2 = map(self.get_image, self.image_pairs[idx])
meta = {}
if "flow" in output or "aflow" in output:
flow = self.get_flow(idx)
assert flow.shape[:2] == img1.size[::-1]
meta["flow"] = flow
H, W = flow.shape[:2]
meta["aflow"] = flow + np.mgrid[:H, :W][::-1].transpose(1, 2, 0)
if "mask" in output:
mask = self.get_mask(idx)
assert mask.shape[:2] == img1.size[::-1]
meta["mask"] = mask
return img1, img2, meta
if __name__ == "__main__":
print(aachen_db_images)
print(aachen_style_transfer_pairs)
print(aachen_flow_pairs)
pdb.set_trace()