Spaces:
Running
Running
import io | |
import cv2 | |
import numpy as np | |
import h5py | |
import torch | |
from numpy.linalg import inv | |
import re | |
try: | |
# for internal use only
from .client import MEGADEPTH_CLIENT, SCANNET_CLIENT | |
except Exception: | |
MEGADEPTH_CLIENT = SCANNET_CLIENT = None | |
# --- DATA IO --- | |
def load_array_from_s3(
    path,
    client,
    cv_type,
    use_h5py=False,
):
    """Fetch an object through ``client`` and decode it into a NumPy array.

    Args:
        path: location handed verbatim to ``client.Get``.
        client: object exposing ``Get(path)``; assumed to return the raw
            bytes of the stored object (TODO confirm against the client
            implementation, which is not visible here).
        cv_type: OpenCV imread flag forwarded to ``cv2.imdecode``. Ignored
            when ``use_h5py`` is True.
        use_h5py (bool): when True the payload is opened as an in-memory
            HDF5 file and its ``/depth`` dataset is returned.
    Returns:
        data (np.ndarray): the decoded image or depth array.
    Raises:
        Exception: any decoding error is re-raised after the failing path
            has been printed.
        AssertionError: if decoding yielded ``None``.
    """
    byte_str = client.Get(path)
    try:
        if not use_h5py:
            # np.frombuffer is the supported replacement for the deprecated
            # binary mode of np.fromstring; it wraps the bytes without a copy.
            raw_array = np.frombuffer(byte_str, np.uint8)
            data = cv2.imdecode(raw_array, cv_type)
        else:
            # Close the HDF5 handle deterministically; np.array() copies the
            # dataset into memory before the file is closed.
            with h5py.File(io.BytesIO(byte_str), "r") as h5_file:
                data = np.array(h5_file["/depth"])
    except Exception:
        print(f"==> Data loading failure: {path}")
        raise  # bare raise keeps the original traceback intact
    assert data is not None
    return data
def imread_gray(path, augment_fn=None, client=SCANNET_CLIENT):
    """Read an image from disk or S3 and return it as a grayscale array.

    Args:
        path: local file path or an ``s3://`` location.
        augment_fn (callable, optional): applied to the RGB image before the
            final grayscale conversion. When given, the image is loaded in
            color so the augmentation sees three channels.
        client: S3 client forwarded to ``load_array_from_s3`` for ``s3://``
            paths.
    Returns:
        image (np.ndarray): (h, w)
    """
    cv_type = cv2.IMREAD_GRAYSCALE if augment_fn is None else cv2.IMREAD_COLOR
    if str(path).startswith("s3://"):
        image = load_array_from_s3(str(path), client, cv_type)
    else:
        image = cv2.imread(str(path), cv_type)
    if augment_fn is not None:
        # The image was already loaded as BGR color above (cv_type is
        # IMREAD_COLOR on this branch). Re-reading it with cv2.imread here
        # would drop the S3-loaded array and fail for "s3://" paths.
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        image = augment_fn(image)
        image = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
    return image  # (h, w)
def get_resized_wh(w, h, resize=None):
    """Scale (w, h) so that the longer edge equals ``resize``.

    Args:
        w: original width.
        h: original height.
        resize (optional): target length of the longer edge. None keeps the
            size unchanged.
    Returns:
        (w_new, h_new): the rescaled size, rounded to integers when a
        resize is requested.
    """
    if resize is None:
        return w, h
    # One common factor for both axes keeps the aspect ratio.
    factor = resize / max(h, w)
    return int(round(w * factor)), int(round(h * factor))
def get_divisible_wh(w, h, df=None):
    """Round (w, h) down to the nearest multiples of ``df``.

    Args:
        w: width.
        h: height.
        df (optional): divisor both sides must be a multiple of. None keeps
            the size unchanged.
    Returns:
        (w_new, h_new): the floored size.
    """
    if df is None:
        return w, h
    floored_w = int(w // df * df)
    floored_h = int(h // df * df)
    return floored_w, floored_h
def pad_bottom_right(inp, pad_size, ret_mask=False):
    """Zero-pad the last two axes of ``inp`` to a square of side ``pad_size``.

    The input is placed in the top-left corner; padding is added on the
    bottom and right.

    Args:
        inp (np.ndarray): 2-D (h, w) or 3-D (c, h, w) array.
        pad_size (int): side of the padded square; must be at least as large
            as the longer of the last two axes of ``inp``.
        ret_mask (bool): also return a boolean mask that is True where the
            original data lies.
    Returns:
        padded (np.ndarray): same dtype as ``inp``, last two axes
            ``(pad_size, pad_size)``.
        mask (np.ndarray | None): boolean array shaped like ``padded``, or
            None when ``ret_mask`` is False.
    Raises:
        AssertionError: if ``pad_size`` is not an int or is too small.
        NotImplementedError: if ``inp`` is neither 2-D nor 3-D.
    """
    longest_side = max(inp.shape[-2:])
    assert (
        isinstance(pad_size, int) and pad_size >= longest_side
    ), f"{pad_size} < {longest_side}"
    if inp.ndim not in (2, 3):
        raise NotImplementedError()

    rows, cols = inp.shape[-2:]
    # Leading axes (none for 2-D, channels for 3-D) are carried over as-is.
    canvas_shape = inp.shape[:-2] + (pad_size, pad_size)

    padded = np.zeros(canvas_shape, dtype=inp.dtype)
    padded[..., :rows, :cols] = inp

    mask = None
    if ret_mask:
        mask = np.zeros(canvas_shape, dtype=bool)
        mask[..., :rows, :cols] = True
    return padded, mask
# --- MEGADEPTH --- | |
def read_megadepth_gray(path, resize=None, df=None, padding=False, augment_fn=None):
    """Load a MegaDepth image as a normalized grayscale tensor.

    Args:
        path: image location, forwarded to ``imread_gray``.
        resize (int, optional): the longer edge of resized images. None for no resize.
        df (int, optional): if given, the resized width and height are
            floored to multiples of this value.
        padding (bool): If set to 'True', zero-pad resized images to squared size.
        augment_fn (callable, optional): augments images with pre-defined visual effects
    Returns:
        image (torch.tensor): (1, h, w), values divided by 255
        mask (torch.tensor): (h, w), True on valid (non-padded) pixels;
            None when ``padding`` is False
        scale (torch.tensor): [w/w_new, h/h_new]
    """
    gray = imread_gray(path, augment_fn, client=MEGADEPTH_CLIENT)

    # Work out the target size: longer-edge resize first, then floor to df.
    h, w = gray.shape[0], gray.shape[1]
    target_w, target_h = get_divisible_wh(*get_resized_wh(w, h, resize), df)
    gray = cv2.resize(gray, (target_w, target_h))
    scale = torch.tensor([w / target_w, h / target_h], dtype=torch.float)

    mask = None
    if padding:
        gray, valid = pad_bottom_right(gray, max(target_h, target_w), ret_mask=True)
        mask = torch.from_numpy(valid)

    # (h, w) -> (1, h, w) and normalized
    image = torch.from_numpy(gray).float()[None] / 255
    return image, mask, scale
def read_megadepth_depth(path, pad_to=None):
    """Load a MegaDepth depth map stored as an HDF5 ``depth`` dataset.

    Args:
        path: local HDF5 file path or an ``s3://`` location.
        pad_to (int, optional): if given, zero-pad the depth map on the
            bottom/right to a ``pad_to`` x ``pad_to`` square.
    Returns:
        depth (torch.tensor): (h, w) float tensor.
    """
    if str(path).startswith("s3://"):
        # Pass a plain string, as the other S3 callers in this file do.
        depth = load_array_from_s3(str(path), MEGADEPTH_CLIENT, None, use_h5py=True)
    else:
        # Close the file handle once the dataset has been copied into memory.
        with h5py.File(path, "r") as h5_file:
            depth = np.array(h5_file["depth"])
    if pad_to is not None:
        depth, _ = pad_bottom_right(depth, pad_to, ret_mask=False)
    depth = torch.from_numpy(depth).float()  # (h, w)
    return depth
# --- ScanNet --- | |
def read_scannet_gray(path, resize=(640, 480), augment_fn=None):
    """Load a ScanNet image as a normalized grayscale tensor.

    Args:
        path: image location, forwarded to ``imread_gray``.
        resize (tuple): align image to depthmap, in (w, h).
        augment_fn (callable, optional): augments images with pre-defined visual effects
    Returns:
        image (torch.tensor): (1, h, w), values divided by 255
    """
    resized = cv2.resize(imread_gray(path, augment_fn), resize)
    tensor = torch.from_numpy(resized).float()
    # add the leading channel axis, then scale by 1/255
    return tensor[None] / 255
def read_scannet_depth(path):
    """Load a ScanNet depth image and divide its values by 1000.

    Args:
        path: local file path or an ``s3://`` location.
    Returns:
        depth (torch.tensor): (h, w) float tensor.
    """
    location = str(path)
    if location.startswith("s3://"):
        raw = load_array_from_s3(location, SCANNET_CLIENT, cv2.IMREAD_UNCHANGED)
    else:
        raw = cv2.imread(location, cv2.IMREAD_UNCHANGED)
    # presumably millimetres -> metres; confirm against the dataset format
    return torch.from_numpy(raw / 1000).float()  # (h, w)
def read_scannet_pose(path):
    """Load a ScanNet camera-to-world pose and return its inverse.

    Args:
        path: text file holding the space-delimited pose matrix.
    Returns:
        pose_w2c (np.ndarray): (4, 4) world-to-camera matrix.
    """
    pose_c2w = np.loadtxt(path, delimiter=" ")
    return inv(pose_c2w)
def read_scannet_intrinsic(path):
    """Load a ScanNet intrinsic file and return it without its last row and column.

    Args:
        path: text file holding the space-delimited intrinsic matrix.
    Returns:
        intrinsic (np.ndarray): the loaded matrix with the final row and
            column removed (3x3 for a 4x4 file).
    """
    full_matrix = np.loadtxt(path, delimiter=" ")
    return full_matrix[:-1, :-1]
def read_gl3d_gray(path, resize):
    """Load a GL3D image as a square, normalized grayscale tensor.

    Args:
        path: image file path passed to ``cv2.imread``.
        resize: side length of the square output; converted with ``int``.
    Returns:
        img (torch.tensor): (1, resize, resize), values divided by 255
    """
    gray = cv2.imread(path, cv2.IMREAD_GRAYSCALE)
    side = int(resize)
    gray = cv2.resize(gray, (side, side))
    # (h, w) -> (1, h, w) and normalized
    return torch.from_numpy(gray).float()[None] / 255
def read_gl3d_depth(file_path):
    """Read a PFM file and return its contents as a float32 tensor.

    The header has three text lines: ``PF`` (3-channel) or ``Pf``
    (1-channel), then ``width height``, then a scale whose sign gives the
    byte order (negative: little-endian, otherwise big-endian). The rest of
    the file is raw 32-bit floats. Rows are flipped vertically before
    returning.

    Args:
        file_path: path of the PFM file.
    Returns:
        depth (torch.tensor): (height, width) for ``Pf`` files or
            (height, width, 3) for ``PF`` files, dtype float32.
    Raises:
        ValueError: if the magic line or the dimension line is malformed, or
            if the payload size does not match the declared dimensions.
    """
    with open(file_path, "rb") as fin:
        header = fin.readline().decode("UTF-8").rstrip()
        if header == "PF":
            color = True
        elif header == "Pf":
            color = False
        else:
            raise ValueError("Not a PFM file.")

        dim_match = re.match(r"^(\d+)\s(\d+)\s$", fin.readline().decode("UTF-8"))
        if dim_match is None:
            raise ValueError("Malformed PFM header.")
        width, height = map(int, dim_match.groups())

        scale = float(fin.readline().decode("UTF-8").rstrip())
        data_type = "<f" if scale < 0 else ">f"  # little- vs big-endian
        payload = fin.read()

    # np.frombuffer is the supported replacement for the deprecated binary
    # mode of np.fromstring.
    data = np.frombuffer(payload, dtype=data_type)
    shape = (height, width, 3) if color else (height, width)
    data = np.flip(np.reshape(data, shape), 0)
    # Copy into a C-contiguous, writable, native-endian float32 array:
    # torch.from_numpy rejects negative strides and non-native byte order,
    # and the frombuffer view is read-only.
    data = np.array(data, dtype=np.float32, order="C")
    return torch.from_numpy(data)