Spaces:
Running
Running
File size: 7,019 Bytes
a80d6bb |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 |
import io
import cv2
import numpy as np
import h5py
import torch
from numpy.linalg import inv
import re
# Optional storage clients (for internal use only). When the client module
# cannot be imported for any reason, both names fall back to None.
try:
    from .client import MEGADEPTH_CLIENT, SCANNET_CLIENT
except Exception:
    MEGADEPTH_CLIENT = None
    SCANNET_CLIENT = None
# --- DATA IO ---
def load_array_from_s3(
    path, client, cv_type,
    use_h5py=False,
):
    """Fetch an object through ``client`` and decode it into a numpy array.

    Args:
        path: key/URL handed to ``client.Get``.
        client: object exposing ``Get(path)``; presumably returns the raw
            bytes of the object (confirm against the client implementation).
        cv_type: OpenCV imread flag passed to ``cv2.imdecode``; unused when
            ``use_h5py`` is True.
        use_h5py (bool): if True, treat the payload as an HDF5 file and read
            its ``/depth`` dataset instead of decoding an image.
    Returns:
        data (np.ndarray): the decoded image or the ``/depth`` dataset.
    Raises:
        AssertionError: if decoding produced ``None``.
        Exception: any decoding error is re-raised after the path is printed.
    """
    byte_str = client.Get(path)
    try:
        if not use_h5py:
            # np.frombuffer is the documented replacement for the deprecated
            # binary mode of np.fromstring.
            raw_array = np.frombuffer(byte_str, np.uint8)
            data = cv2.imdecode(raw_array, cv_type)
        else:
            # Close the HDF5 handle as soon as the dataset has been copied out.
            with h5py.File(io.BytesIO(byte_str), 'r') as f:
                data = np.array(f['/depth'])
    except Exception:
        print(f"==> Data loading failure: {path}")
        raise  # bare raise keeps the original traceback
    assert data is not None
    return data
def imread_gray(path, augment_fn=None, client=SCANNET_CLIENT):
    """Read an image as grayscale, optionally augmenting it in colour first.

    Args:
        path: local file path or an ``s3://`` URL.
        augment_fn (callable, optional): called with the RGB image; its
            result is converted from RGB to grayscale.
        client: storage client used for ``s3://`` paths.
    Returns:
        image (np.ndarray): (h, w)
    """
    # Augmentation works on colour images, so decode in colour only when an
    # augmentation function is given.
    cv_type = cv2.IMREAD_GRAYSCALE if augment_fn is None else cv2.IMREAD_COLOR
    if str(path).startswith('s3://'):
        image = load_array_from_s3(str(path), client, cv_type)
    else:
        image = cv2.imread(str(path), cv_type)
    if augment_fn is not None:
        # The image was already decoded in colour above. Re-reading it with
        # cv2.imread here would discard the S3 result and fail on s3:// paths.
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        image = augment_fn(image)
        image = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
    return image  # (h, w)
def get_resized_wh(w, h, resize=None):
    """Scale (w, h) so that the longer edge becomes ``resize``.

    Args:
        w: original width.
        h: original height.
        resize (optional): target length of the longer edge; None keeps the
            size unchanged.
    Returns:
        (w_new, h_new): rounded to integers when ``resize`` is given,
            otherwise the inputs as passed.
    """
    if resize is None:
        return w, h
    factor = resize / max(h, w)
    return int(round(w * factor)), int(round(h * factor))
def get_divisible_wh(w, h, df=None):
    """Round (w, h) down to the nearest multiples of ``df``.

    Args:
        w: width.
        h: height.
        df (optional): divisor; None keeps the size unchanged.
    Returns:
        (w_new, h_new): integers when ``df`` is given, otherwise the inputs
            as passed.
    """
    if df is None:
        return w, h
    return int(w // df * df), int(h // df * df)
def pad_bottom_right(inp, pad_size, ret_mask=False):
    """Zero-pad the last two axes of ``inp`` to ``pad_size`` x ``pad_size``.

    The input is placed in the top-left corner; padding is added on the
    bottom and right.

    Args:
        inp (np.ndarray): 2-D (h, w) or 3-D (c, h, w) array.
        pad_size (int): target side length; must be at least max(h, w).
        ret_mask (bool): if True, also build a boolean mask that is True on
            the original (unpadded) region.
    Returns:
        padded (np.ndarray): padded array with the dtype of ``inp``.
        mask (np.ndarray or None): mask with the shape of ``padded``, or
            None when ``ret_mask`` is False.
    Raises:
        NotImplementedError: if ``inp`` is neither 2-D nor 3-D.
    """
    assert isinstance(pad_size, int) and pad_size >= max(inp.shape[-2:]), f"{pad_size} < {max(inp.shape[-2:])}"
    if inp.ndim not in (2, 3):
        raise NotImplementedError()

    rows, cols = inp.shape[-2:]
    # Empty for 2-D input, (c,) for 3-D input.
    out_shape = tuple(inp.shape[:-2]) + (pad_size, pad_size)

    padded = np.zeros(out_shape, dtype=inp.dtype)
    padded[..., :rows, :cols] = inp

    mask = None
    if ret_mask:
        mask = np.zeros(out_shape, dtype=bool)
        mask[..., :rows, :cols] = True
    return padded, mask
# --- MEGADEPTH ---
def read_megadepth_gray(path, resize=None, df=None, padding=False, augment_fn=None):
    """Load a MegaDepth image as a normalized grayscale tensor.

    Args:
        path: local path or ``s3://`` URL of the image.
        resize (int, optional): the longer edge of resized images. None for no resize.
        df (int, optional): if given, the resized width and height are rounded
            down to multiples of ``df``.
        padding (bool): If True, zero-pad the resized image (bottom/right) to a square.
        augment_fn (callable, optional): augments images with pre-defined visual effects
    Returns:
        image (torch.tensor): (1, h, w), values divided by 255
        mask (torch.tensor or None): (h, w) boolean mask of the unpadded
            region; None when ``padding`` is False
        scale (torch.tensor): [w/w_new, h/h_new]
    """
    gray = imread_gray(path, augment_fn, client=MEGADEPTH_CLIENT)

    # Target size: longer-edge resize first, then snap down to multiples of df.
    h, w = gray.shape[0], gray.shape[1]
    target_w, target_h = get_resized_wh(w, h, resize)
    target_w, target_h = get_divisible_wh(target_w, target_h, df)

    gray = cv2.resize(gray, (target_w, target_h))
    scale = torch.tensor([w / target_w, h / target_h], dtype=torch.float)

    mask = None
    if padding:
        gray, mask = pad_bottom_right(gray, max(target_h, target_w), ret_mask=True)

    # (h, w) -> (1, h, w) and normalized
    image = torch.from_numpy(gray).float()[None] / 255
    if mask is not None:
        mask = torch.from_numpy(mask)
    return image, mask, scale
def read_megadepth_depth(path, pad_to=None):
    """Load a MegaDepth depth map from an HDF5 file.

    Args:
        path: local path or ``s3://`` URL of an HDF5 file with a ``depth``
            dataset.
        pad_to (int, optional): if given, zero-pad the map (bottom/right) to a
            ``pad_to`` x ``pad_to`` square.
    Returns:
        depth (torch.tensor): (h, w) float tensor
    """
    if str(path).startswith('s3://'):
        depth = load_array_from_s3(str(path), MEGADEPTH_CLIENT, None, use_h5py=True)
    else:
        # Close the file handle once the dataset has been copied into memory.
        with h5py.File(path, 'r') as f:
            depth = np.array(f['depth'])
    if pad_to is not None:
        depth, _ = pad_bottom_right(depth, pad_to, ret_mask=False)
    depth = torch.from_numpy(depth).float()  # (h, w)
    return depth
# --- ScanNet ---
def read_scannet_gray(path, resize=(640, 480), augment_fn=None):
    """Load a ScanNet image as a normalized grayscale tensor.

    Args:
        path: local path or ``s3://`` URL of the image.
        resize (tuple): align image to depthmap, in (w, h).
        augment_fn (callable, optional): augments images with pre-defined visual effects
    Returns:
        image (torch.tensor): (1, h, w), values divided by 255
    """
    resized = cv2.resize(imread_gray(path, augment_fn), resize)
    # (h, w) -> (1, h, w) and normalized
    return torch.from_numpy(resized).float()[None] / 255
def read_scannet_depth(path):
    """Load a ScanNet depth image and divide its values by 1000.

    Args:
        path: local path or ``s3://`` URL of the depth image.
    Returns:
        depth (torch.tensor): (h, w) float tensor. Presumably the division
            converts millimetres to metres; confirm against the dataset
            documentation.
    """
    location = str(path)
    if location.startswith('s3://'):
        raw = load_array_from_s3(location, SCANNET_CLIENT, cv2.IMREAD_UNCHANGED)
    else:
        raw = cv2.imread(location, cv2.IMREAD_UNCHANGED)
    return torch.from_numpy(raw / 1000).float()  # (h, w)
def read_scannet_pose(path):
    """Read ScanNet's Camera2World pose and return its inverse (World2Camera).

    Args:
        path: space-delimited text file containing the pose matrix.
    Returns:
        pose_w2c (np.ndarray): (4, 4)
    """
    pose_c2w = np.loadtxt(path, delimiter=' ')
    return np.linalg.inv(pose_c2w)
def read_scannet_intrinsic(path):
    """Read ScanNet's intrinsic matrix, dropping its last row and column.

    Args:
        path: space-delimited text file containing the matrix.
    Returns:
        np.ndarray: the matrix without its last row and column, i.e. 3x3
            when the file stores a 4x4 matrix.
    """
    full_matrix = np.loadtxt(path, delimiter=' ')
    return full_matrix[:-1, :-1]
def read_gl3d_gray(path, resize):
    """Load a GL3D image as a square, normalized grayscale tensor.

    Args:
        path: local image path passed to ``cv2.imread``.
        resize: side length of the square output; converted with ``int``.
    Returns:
        img (torch.tensor): (1, resize, resize), values divided by 255
    """
    gray = cv2.imread(path, cv2.IMREAD_GRAYSCALE)
    gray = cv2.resize(gray, (int(resize), int(resize)))
    # (h, w) -> (1, h, w) and normalized
    return torch.from_numpy(gray).float()[None] / 255
def read_gl3d_depth(file_path):
    """Read a PFM file and return its contents as a float32 tensor.

    The header consists of three text lines: ``PF`` (3 channels) or ``Pf``
    (1 channel), then ``<width> <height>``, then a scale value whose sign
    selects the byte order (negative means little-endian). The remaining
    bytes are float32 samples. The rows are flipped vertically before
    returning.

    Args:
        file_path: path of the PFM file.
    Returns:
        torch.tensor: (height, width) for ``Pf`` or (height, width, 3) for
            ``PF``, dtype float32.
    Raises:
        Exception: if the first line is not ``PF``/``Pf`` or the dimension
            line is malformed.
        ValueError: if the payload size does not match the header.
    """
    with open(file_path, 'rb') as fin:
        header = fin.readline().decode('UTF-8').rstrip()
        if header == 'PF':
            color = True
        elif header == 'Pf':
            color = False
        else:
            raise Exception('Not a PFM file.')

        dim_match = re.match(r'^(\d+)\s(\d+)\s$', fin.readline().decode('UTF-8'))
        if not dim_match:
            raise Exception('Malformed PFM header.')
        width, height = map(int, dim_match.groups())

        scale = float(fin.readline().decode('UTF-8').rstrip())
        data_type = '<f' if scale < 0 else '>f'  # sign of scale encodes endianness
        data_string = fin.read()

    # np.frombuffer replaces the deprecated binary mode of np.fromstring.
    data = np.frombuffer(data_string, dtype=data_type)
    shape = (height, width, 3) if color else (height, width)
    data = np.reshape(data, shape)
    # np.array copies the flipped view into a fresh C-contiguous, native-endian
    # float32 array. torch.from_numpy rejects both negative strides and
    # non-native byte order, so this also makes big-endian files loadable.
    data = np.array(np.flip(data, 0), dtype=np.float32, order='C')
    return torch.from_numpy(data).float()