# Align3R/dust3r/datasets/base/base_stereo_view_dataset.py
# Copyright (C) 2024-present Naver Corporation. All rights reserved.
# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
#
# --------------------------------------------------------
# base class for implementing datasets
# --------------------------------------------------------
import PIL.Image
import numpy as np
import torch

from dust3r.datasets.base.easy_dataset import EasyDataset
from dust3r.datasets.utils.transforms import ImgNorm
from dust3r.utils.geometry import depthmap_to_absolute_camera_coordinates
import dust3r.datasets.utils.cropping as cropping


class BaseStereoViewDataset(EasyDataset):
    """ Define all basic options.

    Usage:
        class MyDataset(BaseStereoViewDataset):
            def _get_views(self, idx, resolution, rng):
                # overload here
                views = []
                views.append(dict(img=..., ...))
                return views
    """

    def __init__(self, *,  # only keyword arguments
                 split=None,
                 resolution=None,  # square_size or (width, height) or list of [(width,height), ...]
                 transform=ImgNorm,
                 aug_crop=False,
                 aug_f=False,
                 seed=None,
                 depth_prior_name='depthpro'):
        self.num_views = 2
        self.split = split
        self.depth_prior_name = depth_prior_name
        self._set_resolutions(resolution)
        self.aug_f = aug_f
        # a transform given as a string (e.g. 'ImgNorm') is evaluated first,
        # so that the resulting callable is what actually gets stored
        if isinstance(transform, str):
            transform = eval(transform)
        self.transform = transform
        self.aug_crop = aug_crop
        self.seed = seed
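
    # Illustrative sketch (hypothetical subclass and made-up values): a
    # concrete dataset would typically be constructed with keyword arguments
    # only, e.g.
    #
    #   dataset = MyDataset(split='train',
    #                       resolution=[(512, 384), (512, 256)],
    #                       aug_crop=16, seed=777)
    #
    # `transform` may also be passed as a string such as 'ImgNorm', which is
    # eval'ed at construction time.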

    def __len__(self):
        return len(self.scenes)

    def get_stats(self):
        return f"{len(self)} pairs"

    def __repr__(self):
resolutions_str = '['+';'.join(f'{w}x{h}' for w, h in self._resolutions)+']'
return f"""{type(self).__name__}({self.get_stats()},
{self.split=},
{self.seed=},
resolutions={resolutions_str},
{self.transform=})""".replace('self.', '').replace('\n', '').replace(' ', '')

    def _get_views(self, idx, resolution, rng):
        raise NotImplementedError()

    def pixel_to_pointcloud(self, depth_map, focal_length_px):
        """
        Convert a depth map to a 3D point cloud.

        Args:
            depth_map (numpy.ndarray): The input depth map with shape (H, W), where each value represents the depth at that pixel.
            focal_length_px (float): The focal length of the camera in pixels.

        Returns:
            numpy.ndarray: The resulting point cloud with shape (H, W, 3), where each
                point is an (X, Y, Z) triple, min-max normalized to [0, 1] per axis.
        """
        height, width = depth_map.shape
        # assume the principal point lies at the image center
        cx = width / 2
        cy = height / 2
        # create a meshgrid of pixel coordinates
        u = np.arange(width)
        v = np.arange(height)
        u, v = np.meshgrid(u, v)
        # back-project pixel coordinates to camera coordinates (pinhole model)
        Z = depth_map
        X = (u - cx) * Z / focal_length_px
        Y = (v - cy) * Z / focal_length_px
        # stack the coordinates into a point cloud (H, W, 3)
        point_cloud = np.dstack((X, Y, Z)).astype(np.float32)
        point_cloud = self.normalize_pointcloud(point_cloud)
        return point_cloud
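
    # Illustrative sketch (made-up shapes and focal length; `dataset` is an
    # instance of a concrete subclass): back-project a small depth map.
    #
    #   depth = np.linspace(1.0, 5.0, 16, dtype=np.float32).reshape(4, 4)
    #   pc = dataset.pixel_to_pointcloud(depth, focal_length_px=100.0)
    #   assert pc.shape == (4, 4, 3)                  # one (X, Y, Z) per pixel
    #   assert pc.min() >= 0.0 and pc.max() <= 1.0    # min-max normalized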

    def normalize_pointcloud(self, point_cloud):
        # min-max normalize each coordinate axis independently to [0, 1]
        min_vals = np.min(point_cloud, axis=(0, 1))
        max_vals = np.max(point_cloud, axis=(0, 1))
        normalized_point_cloud = (point_cloud - min_vals) / (max_vals - min_vals)
        return normalized_point_cloud
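
    # Caveat: if any axis is constant, max_vals == min_vals and the division
    # above produces NaN/inf for that axis. A guarded variant could look like
    # the following (a sketch, not the original behavior):
    #
    #   span = np.where(max_vals > min_vals, max_vals - min_vals, 1.0)
    #   normalized_point_cloud = (point_cloud - min_vals) / span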

    def __getitem__(self, idx):
        if isinstance(idx, tuple):
            # the tuple also carries the aspect-ratio index
            idx, ar_idx = idx
else:
assert len(self._resolutions) == 1
ar_idx = 0
# set-up the rng
if self.seed: # reseed for each __getitem__
self._rng = np.random.default_rng(seed=self.seed + idx)
elif not hasattr(self, '_rng'):
seed = torch.initial_seed() # this is different for each dataloader process
self._rng = np.random.default_rng(seed=seed)
        # subclass-specific (overloaded) code
        resolution = self._resolutions[ar_idx]  # DO NOT CHANGE THIS (compatible with BatchedRandomSampler)
views = self._get_views(idx, resolution, self._rng)
assert len(views) == self.num_views
# check data-types
for v, view in enumerate(views):
assert 'pts3d' not in view, f"pts3d should not be there, they will be computed afterwards based on intrinsics+depthmap for view {view_name(view)}"
view['idx'] = (idx, ar_idx, v)
# encode the image
width, height = view['img'].size
view['true_shape'] = np.int32((height, width))
view['img'] = self.transform(view['img'])
assert 'camera_intrinsics' in view
if 'camera_pose' not in view:
view['camera_pose'] = np.full((4, 4), np.nan, dtype=np.float32)
else:
assert np.isfinite(view['camera_pose']).all(), f'NaN in camera pose for view {view_name(view)}'
assert 'pts3d' not in view
assert 'valid_mask' not in view
assert np.isfinite(view['depthmap']).all(), f'NaN in depthmap for view {view_name(view)}'
pts3d, valid_mask = depthmap_to_absolute_camera_coordinates(**view)
view['pts3d'] = pts3d
            view['valid_mask'] = valid_mask & np.isfinite(pts3d).all(axis=-1)
# check all datatypes
for key, val in view.items():
res, err_msg = is_good_type(key, val)
assert res, f"{err_msg} with {key}={val} for view {view_name(view)}"
# last thing done!
for view in views:
            # rectify portrait views so that all views end up in landscape orientation
            transpose_to_landscape(view)
            # this allows checking whether the RNG is in the same state each time
view['rng'] = int.from_bytes(self._rng.bytes(4), 'big')
return views
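
    # Illustrative sketch (assumed usage, mirroring the BatchedRandomSampler
    # comment above): __getitem__ accepts either a plain index or an
    # (index, aspect-ratio-index) tuple selecting one of self._resolutions.
    #
    #   views = dataset[10]      # only valid when a single resolution is configured
    #   views = dataset[10, 0]   # sample 10 rendered at self._resolutions[0]
    #   view1, view2 = views     # self.num_views == 2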

    def _set_resolutions(self, resolutions):
assert resolutions is not None, 'undefined resolution'
if not isinstance(resolutions, list):
resolutions = [resolutions]
self._resolutions = []
for resolution in resolutions:
if isinstance(resolution, int):
width = height = resolution
else:
width, height = resolution
assert isinstance(width, int), f'Bad type for {width=} {type(width)=}, should be int'
assert isinstance(height, int), f'Bad type for {height=} {type(height)=}, should be int'
            assert width >= height, f'resolution must be landscape or square, got {width=} {height=}'
self._resolutions.append((width, height))
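
    # Illustrative sketch (made-up values): the three accepted resolution
    # formats, as normalized by _set_resolutions into a list of (W, H) pairs.
    #
    #   self._set_resolutions(224)                        # -> [(224, 224)]
    #   self._set_resolutions((512, 384))                 # -> [(512, 384)]
    #   self._set_resolutions([(512, 384), (512, 256)])   # -> [(512, 384), (512, 256)]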

    def _crop_resize_if_necessary(self, image, depthmap, pred_depth, intrinsics, resolution, rng=None, info=None):
        """ This function:
            - first crops the image to a window centered on the principal point,
            - then downsizes it with LANCZOS interpolation, which preserves more
              detail than bilinear interpolation when downsampling.
        """
if not isinstance(image, PIL.Image.Image):
image = PIL.Image.fromarray(image)
# downscale with lanczos interpolation so that image.size == resolution
# cropping centered on the principal point
W, H = image.size
cx, cy = intrinsics[:2, 2].round().astype(int)
        min_margin_x = min(cx, W-cx)
        min_margin_y = min(cy, H-cy)
        # the new window will be a rectangle of size (2*min_margin_x, 2*min_margin_y) centered on (cx, cy)
        l, t = cx - min_margin_x, cy - min_margin_y
        r, b = cx + min_margin_x, cy + min_margin_y
        crop_bbox = (l, t, r, b)
        image, depthmap, pred_depth, intrinsics = cropping.crop_image_depthmap(image, depthmap, pred_depth, intrinsics, crop_bbox)
# transpose the resolution if necessary
W, H = image.size # new size
assert resolution[0] >= resolution[1]
if H > 1.1*W:
# image is portrait mode
resolution = resolution[::-1]
        elif 0.9 < H/W < 1.1 and resolution[0] != resolution[1]:
            # image is roughly square, so we choose (portrait, landscape) randomly
if rng.integers(2):
resolution = resolution[::-1]
# center-crop
target_resolution = np.array(resolution)
if self.aug_f:
crop_scale = rng.choice([0.8, 0.9, 1.0], size=1, replace=False)[0]
image, depthmap, pred_depth, intrinsics = cropping.center_crop_image_depthmap(image, depthmap, pred_depth, intrinsics, crop_scale)
if self.aug_crop > 1:
target_resolution += rng.integers(0, self.aug_crop)
image, depthmap, pred_depth, intrinsics = cropping.rescale_image_depthmap(image, depthmap, pred_depth, intrinsics, target_resolution)
# actual cropping (if necessary) with bilinear interpolation
intrinsics2 = cropping.camera_matrix_of_crop(intrinsics, image.size, resolution, offset_factor=0.5)
crop_bbox = cropping.bbox_from_intrinsics_in_out(intrinsics, intrinsics2, resolution)
image, depthmap, pred_depth, intrinsics2 = cropping.crop_image_depthmap(image, depthmap, pred_depth, intrinsics, crop_bbox)
return image, depthmap, pred_depth, intrinsics2
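
    # Illustrative sketch (hypothetical arrays and shapes): how a concrete
    # subclass might call this from _get_views to produce a 512x384 view.
    #
    #   img, depth, pred, K = self._crop_resize_if_necessary(
    #       rgb_array,           # (480, 640, 3) uint8 image
    #       depth_array,         # (480, 640) float32 ground-truth depth
    #       pred_depth_array,    # (480, 640) float32 prior depth (e.g. depthpro)
    #       intrinsics,          # (3, 3) float32 camera matrix
    #       resolution=(512, 384), rng=self._rng, info='scene/frame_000')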


def is_good_type(key, v):
""" returns (is_good, err_msg)
"""
if isinstance(v, (str, int, tuple)):
return True, None
if v.dtype not in (np.float32, torch.float32, bool, np.int32, np.int64, np.uint8):
return False, f"bad {v.dtype=}"
return True, None


def view_name(view, batch_index=None):
def sel(x): return x[batch_index] if batch_index not in (None, slice(None)) else x
db = sel(view['dataset'])
label = sel(view['label'])
instance = sel(view['instance'])
return f"{db}/{label}/{instance}"


def transpose_to_landscape(view):
height, width = view['true_shape']
if width < height:
# rectify portrait to landscape
assert view['img'].shape == (3, height, width)
view['img'] = view['img'].swapaxes(1, 2)
assert view['valid_mask'].shape == (height, width)
view['valid_mask'] = view['valid_mask'].swapaxes(0, 1)
assert view['depthmap'].shape == (height, width)
view['depthmap'] = view['depthmap'].swapaxes(0, 1)
assert view['pts3d'].shape == (height, width, 3)
view['pts3d'] = view['pts3d'].swapaxes(0, 1)
assert view['pred_depth'].shape == (height, width)
view['pred_depth'] = view['pred_depth'].swapaxes(0, 1)
        # swap the x- and y-rows of the intrinsics to match the transposed image axes
        view['camera_intrinsics'] = view['camera_intrinsics'][[1, 0, 2]]
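

# Illustrative sketch (made-up symbols): the row permutation above maps a
# portrait camera matrix to its landscape counterpart, i.e. with
#
#   K = [[fx, 0., cx],
#        [0., fy, cy],
#        [0., 0., 1.]]
#
# K[[1, 0, 2]] gives
#
#   [[0., fy, cy],
#    [fx, 0., cx],
#    [0., 0., 1.]]
#
# so that after swapping the image axes (u, v) -> (v, u), the first row
# produces the old v coordinate and the second row the old u coordinate.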