""" Simply load images from a folder or nested folders (does not have any split). """ import argparse import logging import tarfile import matplotlib.pyplot as plt import numpy as np import torch from omegaconf import OmegaConf from ..settings import DATA_PATH from ..utils.image import ImagePreprocessor, load_image from ..utils.tools import fork_rng from ..visualization.viz2d import plot_image_grid from .base_dataset import BaseDataset logger = logging.getLogger(__name__) def read_homography(path): with open(path) as f: result = [] for line in f.readlines(): while " " in line: # Remove double spaces line = line.replace(" ", " ") line = line.replace(" \n", "").replace("\n", "") # Split and discard empty strings elements = list(filter(lambda s: s, line.split(" "))) if elements: result.append(elements) return np.array(result).astype(float) class HPatches(BaseDataset, torch.utils.data.Dataset): default_conf = { "preprocessing": ImagePreprocessor.default_conf, "data_dir": "hpatches-sequences-release", "subset": None, "ignore_large_images": True, "grayscale": False, } # Large images that were ignored in previous papers ignored_scenes = ( "i_contruction", "i_crownnight", "i_dc", "i_pencils", "i_whitebuilding", "v_artisans", "v_astronautis", "v_talent", ) url = "http://icvl.ee.ic.ac.uk/vbalnt/hpatches/hpatches-sequences-release.tar.gz" def _init(self, conf): assert conf.batch_size == 1 self.preprocessor = ImagePreprocessor(conf.preprocessing) self.root = DATA_PATH / conf.data_dir if not self.root.exists(): logger.info("Downloading the HPatches dataset.") self.download() self.sequences = sorted([x.name for x in self.root.iterdir()]) if not self.sequences: raise ValueError("No image found!") self.items = [] # (seq, q_idx, is_illu) for seq in self.sequences: if conf.ignore_large_images and seq in self.ignored_scenes: continue if conf.subset is not None and conf.subset != seq[0]: continue for i in range(2, 7): self.items.append((seq, i, seq[0] == "i")) def download(self): data_dir = self.root.parent data_dir.mkdir(exist_ok=True, parents=True) tar_path = data_dir / self.url.rsplit("/", 1)[-1] torch.hub.download_url_to_file(self.url, tar_path) with tarfile.open(tar_path) as tar: tar.extractall(data_dir) tar_path.unlink() def get_dataset(self, split): assert split in ["val", "test"] return self def _read_image(self, seq: str, idx: int) -> dict: img = load_image(self.root / seq / f"{idx}.ppm", self.conf.grayscale) return self.preprocessor(img) def __getitem__(self, idx): seq, q_idx, is_illu = self.items[idx] data0 = self._read_image(seq, 1) data1 = self._read_image(seq, q_idx) H = read_homography(self.root / seq / f"H_1_{q_idx}") H = data1["transform"] @ H @ np.linalg.inv(data0["transform"]) return { "H_0to1": H.astype(np.float32), "scene": seq, "idx": idx, "is_illu": is_illu, "name": f"{seq}/{idx}.ppm", "view0": data0, "view1": data1, } def __len__(self): return len(self.items) def visualize(args): conf = { "batch_size": 1, "num_workers": 8, "prefetch_factor": 1, } conf = OmegaConf.merge(conf, OmegaConf.from_cli(args.dotlist)) dataset = HPatches(conf) loader = dataset.get_data_loader("test") logger.info("The dataset has %d elements.", len(loader)) with fork_rng(seed=dataset.conf.seed): images = [] for _, data in zip(range(args.num_items), loader): images.append( (data[f"view{i}"]["image"][0].permute(1, 2, 0) for i in range(2)) ) plot_image_grid(images, dpi=args.dpi) plt.tight_layout() plt.show() if __name__ == "__main__": from .. import logger # overwrite the logger parser = argparse.ArgumentParser() parser.add_argument("--num_items", type=int, default=8) parser.add_argument("--dpi", type=int, default=100) parser.add_argument("dotlist", nargs="*") args = parser.parse_intermixed_args() visualize(args)