Vincentqyw
fix: roma
8b973ee
raw
history blame
No virus
4.21 kB
import torch
import cv2
import numpy as np
from torchvision import transforms
from torch.utils.data import Dataset
from pathlib import Path
class PatchesDataset(Dataset):
    """
    HPatches dataset class.

    Every sequence folder contributes five samples: the reference image
    ``1.ppm`` paired with each of ``2.ppm`` .. ``6.ppm`` and the matching
    homography file ``H_1_<i>``.

    # Note: output_shape = (output_width, output_height)
    # Note: this returns Pytorch tensors, resized to output_shape (if specified)
    # Note: the homography will be adjusted according to output_shape.

    Parameters
    ----------
    root_dir : str
        Path to the dataset (a directory containing one sub-folder per sequence).
    use_color : bool
        Return color images or convert to grayscale. Color images keep the
        BGR channel order in which OpenCV decodes them.
    data_transform : Function
        Stored on the instance as ``self.data_transform``; ``__getitem__``
        does not apply it.
    output_shape: tuple
        If specified, the images and homographies will be resized to the desired shape.
    type: str
        Dataset subset to return from ['i', 'v', 'all']:
        i - illumination sequences
        v - viewpoint sequences
        all - all sequences
        Any value other than 'i' or 'v' selects all sequences.
    """

    def __init__(
        self,
        root_dir,
        use_color=True,
        data_transform=None,
        output_shape=None,
        type="all",
    ):
        super().__init__()
        self.type = type
        self.root_dir = root_dir
        self.data_transform = data_transform
        self.output_shape = output_shape
        self.use_color = use_color

        # Loop-invariant layout of an HPatches sequence folder.
        num_images = 5
        file_ext = ".ppm"

        # iterdir() yields entries in arbitrary, filesystem-dependent order;
        # sorting keeps the index -> sample mapping reproducible.
        folder_paths = sorted(p for p in Path(root_dir).iterdir() if p.is_dir())

        image_paths = []
        warped_image_paths = []
        homographies = []
        for path in folder_paths:
            # Subset filter: only 'i' and 'v' restrict by the folder's first letter.
            if self.type in ("i", "v") and not path.name.startswith(self.type):
                continue
            for i in range(2, 2 + num_images):
                image_paths.append(str(Path(path, "1" + file_ext)))
                warped_image_paths.append(str(Path(path, str(i) + file_ext)))
                homographies.append(np.loadtxt(str(Path(path, "H_1_" + str(i)))))

        self.files = {
            "image_paths": image_paths,
            "warped_image_paths": warped_image_paths,
            "homography": homographies,
        }

    def scale_homography(self, homography, original_scale, new_scale, pre):
        """
        Adjust a 3x3 homography for a resize from original_scale to new_scale.

        Parameters
        ----------
        homography : array-like
            The 3x3 homography to adjust.
        original_scale : sequence of 2 numbers
            Size before resizing.
        new_scale : sequence of 2 numbers
            Size after resizing, in the same axis order as original_scale.
        pre : bool
            If True, left-multiply by diag(new / original, 1).
            If False, right-multiply by diag(original / new, 1).

        Returns
        -------
        numpy.ndarray
            The adjusted homography.
        """
        scales = np.divide(new_scale, original_scale)
        if pre:
            s = np.diag(np.append(scales, 1.0))
            return np.matmul(s, homography)
        sinv = np.diag(np.append(1.0 / scales, 1.0))
        return np.matmul(homography, sinv)

    def __len__(self):
        """Return the number of (image, warped image, homography) samples."""
        return len(self.files["image_paths"])

    def _read_image(self, path):
        """Read an image from disk as BGR, or as grayscale when use_color is falsy."""
        img = cv2.imread(path, cv2.IMREAD_COLOR)
        if img is None:
            # cv2.imread signals a missing/undecodable file by returning None;
            # fail here with the path instead of a confusing error downstream.
            raise OSError(f"Could not read image: {path}")
        if self.use_color:
            return img
        return cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)

    def __getitem__(self, idx):
        """
        Load one sample.

        Returns
        -------
        dict
            'image' and 'warped_image' as float tensors produced by
            torchvision's ToTensor, 'homography' as a numpy array (rescaled
            when output_shape is set) and 'index' holding idx.

        Raises
        ------
        OSError
            If either image file cannot be read.
        """
        image = self._read_image(self.files["image_paths"][idx])
        warped_image = self._read_image(self.files["warped_image_paths"][idx])
        homography = np.array(self.files["homography"][idx])
        sample = {
            "image": image,
            "warped_image": warped_image,
            "homography": homography,
            "index": idx,
        }

        # Apply transformations
        if self.output_shape is not None:
            # shape[:2] is (height, width); reverse it to match the
            # (width, height) convention of output_shape.
            # Source-image resize acts on the input side of the homography...
            sample["homography"] = self.scale_homography(
                sample["homography"],
                sample["image"].shape[:2][::-1],
                self.output_shape,
                pre=False,
            )
            # ...and the warped-image resize acts on its output side.
            sample["homography"] = self.scale_homography(
                sample["homography"],
                sample["warped_image"].shape[:2][::-1],
                self.output_shape,
                pre=True,
            )

            for key in ["image", "warped_image"]:
                sample[key] = cv2.resize(sample[key], self.output_shape)
                if self.use_color is False:
                    # Give grayscale images an explicit trailing channel axis.
                    sample[key] = np.expand_dims(sample[key], axis=2)

        transform = transforms.ToTensor()
        for key in ["image", "warped_image"]:
            sample[key] = transform(sample[key]).type("torch.FloatTensor")
        return sample