|
import time |
|
import numpy as np |
|
import torch |
|
import torch.nn.functional as F |
|
|
|
|
|
def rnd_sample(inputs, n_sample):
    """Randomly subsample every tensor in ``inputs`` with one shared permutation.

    A single random permutation of the first-axis length of ``inputs[0]`` is
    drawn, truncated to ``n_sample`` entries (fewer if the axis is shorter),
    and applied to each tensor in ``inputs``, so corresponding rows stay
    aligned across the tensors.

    Args:
        inputs: sequence of tensors indexed along their first axis.
        n_sample: maximum number of rows to keep.

    Returns:
        List of subsampled tensors, in the same order as ``inputs``.
    """
    permutation = torch.randperm(inputs[0].shape[0])
    chosen = permutation[:n_sample]
    return [tensor[chosen] for tensor in inputs]
|
|
|
|
|
def _grid_positions(h, w, bs): |
|
x_rng = torch.arange(0, w.int()) |
|
y_rng = torch.arange(0, h.int()) |
|
xv, yv = torch.meshgrid(x_rng, y_rng, indexing='xy') |
|
return torch.reshape( |
|
torch.stack((yv, xv), axis=-1), |
|
(1, -1, 2) |
|
).repeat(bs, 1, 1).float() |
|
|
|
|
|
def getK(ori_img_size, cur_feat_size, K):
    """Rescale the first two rows of batched matrices ``K`` to a feature-map resolution.

    The two entries of ``cur_feat_size`` are swapped (``[[1, 0]]``) before
    ``ori_img_size`` is divided by them. Row 0 of each ``K[b]`` is then divided
    by ``ratio[b, 0]`` and row 1 by ``ratio[b, 1]``. Row 2 is kept unchanged.

    NOTE(review): presumably ``ori_img_size`` is ``(bs, 2)`` in (width, height)
    order and ``cur_feat_size`` is (height, width), with ``K`` a batch of 3x3
    intrinsics. Confirm against callers.

    Returns:
        Tensor with the three rows re-stacked along axis 1.
    """
    ratio = ori_img_size / cur_feat_size[[1, 0]]
    row0 = K[:, 0] / ratio[:, 0:1]
    row1 = K[:, 1] / ratio[:, 1:2]
    return torch.stack([row0, row1, K[:, 2]], dim=1)
|
|
|
|
|
def gather_nd(params, indices):
    """Gather slices of ``params`` addressed by the last axis of ``indices``.

    Equivalent to ``tf.gather_nd`` with ``batch_dims=0``. Batched gather is not
    supported. ``indices`` has ``k`` dimensions and is read as a
    ``(k-1)``-dimensional array of index vectors of length ``m``. Each vector
    selects one slice of ``params``::

        output[i_0, ..., i_{k-2}] = params[indices[i_0, ..., i_{k-2}]]

    Args:
        params (Tensor): ``n``-dimensional tensor of shape ``[x_0, ..., x_{n-1}]``.
        indices (Tensor): ``k``-dimensional integer tensor of shape
            ``[y_0, ..., y_{k-2}, m]`` with ``m <= n``. It is moved to
            ``params.device`` and cast to ``int64`` before indexing.

    Returns:
        Tensor of shape ``[y_0, ..., y_{k-2}] + params.shape[m:]``.

    Raises:
        ValueError: if ``m > n``.
    """
    orig_shape = list(indices.shape)
    m = orig_shape[-1]
    n = len(params.shape)
    if m > n:
        raise ValueError(
            f'the last dimension of indices must less or equal to the rank of params. Got indices:{indices.shape}, params:{params.shape}. {m} > {n}'
        )
    out_shape = orig_shape[:-1] + list(params.shape)[m:]

    # int(...) matters: np.prod([]) is the float 1.0, which reshape rejects, so a
    # single index vector (indices of rank 1) used to crash here.
    num_samples = int(np.prod(orig_shape[:-1]))

    # Index with a tuple of m index tensors (one per leading axis of params)
    # instead of round-tripping through a Python list. This keeps the gather on
    # params' device and avoids a host sync.
    flat = indices.reshape(num_samples, m).to(device=params.device, dtype=torch.long)
    output = params[tuple(flat.unbind(dim=1))]
    return output.reshape(out_shape).contiguous()
|
|
|
|
|
|
|
def interpolate(pos, inputs, nd=True):
    """Bilinearly sample ``inputs`` at fractional positions, clamping corners to the map.

    ``pos[:, 0]`` indexes the first axis of ``inputs`` (size ``h``) and
    ``pos[:, 1]`` the second (size ``w``). The four surrounding integer corners
    are clamped to ``[0, h-1] x [0, w-1]``. The bilinear weights are measured
    from the clamped top-left corner.

    Args:
        pos: positions to sample; entries may be fractional.
        inputs: map whose first two axes are sampled.
        nd: if True, a trailing axis is appended to each weight so it
            broadcasts over extra (e.g. channel) dimensions of ``inputs``.

    Returns:
        Weighted sum of the four corner values, one per position.
    """
    height = inputs.shape[0]
    width = inputs.shape[1]
    rows = pos[:, 0]
    cols = pos[:, 1]

    row_lo = torch.clamp(torch.floor(rows).int(), 0, height - 1)
    row_hi = torch.clamp(torch.ceil(rows).int(), 0, height - 1)
    col_lo = torch.clamp(torch.floor(cols).int(), 0, width - 1)
    col_hi = torch.clamp(torch.ceil(cols).int(), 0, width - 1)

    frac_r = rows - row_lo.float()
    frac_c = cols - col_lo.float()

    # (weight, row index, col index) for top-left, top-right, bottom-left, bottom-right.
    terms = [
        ((1 - frac_r) * (1 - frac_c), row_lo, col_lo),
        ((1 - frac_r) * frac_c, row_lo, col_hi),
        (frac_r * (1 - frac_c), row_hi, col_lo),
        (frac_r * frac_c, row_hi, col_hi),
    ]

    total = None
    for weight, r_idx, c_idx in terms:
        if nd:
            weight = weight[..., None]
        contribution = weight * gather_nd(inputs, torch.stack([r_idx, c_idx], dim=-1))
        # Accumulate left to right, matching ((tl + tr) + bl) + br.
        total = contribution if total is None else total + contribution
    return total
|
|
|
|
|
def validate_and_interpolate(pos, inputs, validate_corner=True, validate_val=None, nd=False):
    """Bilinearly sample ``inputs`` at ``pos``, dropping positions that fail validation.

    Args:
        pos: positions; ``pos[:, 0]`` indexes the first axis of ``inputs`` and
            ``pos[:, 1]`` the second. Entries may be fractional.
        inputs: 2-D map ``(h, w)``, or 3-D ``(h, w, c)`` when ``nd`` is True.
        validate_corner: if True, drop positions whose 2x2 neighbourhood leaves
            the map, i.e. ``floor(pos[:, 0]) < 0``, ``floor(pos[:, 1]) < 0``,
            ``ceil(pos[:, 0]) >= h`` or ``ceil(pos[:, 1]) >= w``. If False,
            the caller must guarantee in-range positions.
        validate_val: if not None, additionally drop positions where any of the
            four corner values is not strictly positive. Only ``None`` versus
            not-``None`` matters; the value itself is not used as a threshold.
            NOTE(review): with ``nd=True`` the corner check yields a
            per-channel mask, and indexing ``ids`` with it fails, so this
            option is effectively 2-D only.
        nd: if True, ``inputs`` has a trailing axis and the weights are
            broadcast over it.

    Returns:
        ``[interpolated_val, pos, ids]``:
        - ``interpolated_val``: the bilinear samples of the kept positions.
        - ``pos``: the kept positions, stacked as ``(N_kept, 2)``.
        - ``ids``: indices of the kept rows in the input ``pos``, on ``pos.device``.
    """
    if nd:
        h, w, _ = inputs.shape
    else:
        h, w = inputs.shape
    # Keep the bookkeeping on pos's device. Boolean masks derived from
    # pos/inputs live there, and indexing a CPU tensor with a CUDA mask raises.
    ids = torch.arange(0, pos.shape[0], device=pos.device)

    i = pos[:, 0]
    j = pos[:, 1]

    # Integer corners of the 2x2 neighbourhood: floor gives top/left, ceil gives bottom/right.
    i_top = torch.floor(i).int()
    j_left = torch.floor(j).int()
    i_bottom = torch.ceil(i).int()
    j_right = torch.ceil(j).int()

    def corner_values(ii, jj):
        # Same result as gather_nd(inputs, stack([ii, jj], -1)), without a host round-trip.
        return inputs[ii.long(), jj.long()]

    if validate_corner:
        valid_corner = (i_top >= 0) & (j_left >= 0) & (i_bottom < h) & (j_right < w)
        i_top = i_top[valid_corner]
        j_left = j_left[valid_corner]
        i_bottom = i_bottom[valid_corner]
        j_right = j_right[valid_corner]
        ids = ids[valid_corner]

    if validate_val is not None:
        # Done after the corner filter so every gathered index is in range.
        valid_depth = (
            (corner_values(i_top, j_left) > 0)
            & (corner_values(i_top, j_right) > 0)
            & (corner_values(i_bottom, j_left) > 0)
            & (corner_values(i_bottom, j_right) > 0)
        )
        i_top = i_top[valid_depth]
        j_left = j_left[valid_depth]
        i_bottom = i_bottom[valid_depth]
        j_right = j_right[valid_depth]
        ids = ids[valid_depth]

    i = i[ids]
    j = j[ids]
    dist_i = i - i_top.float()
    dist_j = j - j_left.float()
    w_top_left = (1 - dist_i) * (1 - dist_j)
    w_top_right = (1 - dist_i) * dist_j
    w_bottom_left = dist_i * (1 - dist_j)
    w_bottom_right = dist_i * dist_j

    if nd:
        w_top_left = w_top_left[..., None]
        w_top_right = w_top_right[..., None]
        w_bottom_left = w_bottom_left[..., None]
        w_bottom_right = w_bottom_right[..., None]

    interpolated_val = (
        w_top_left * corner_values(i_top, j_left)
        + w_top_right * corner_values(i_top, j_right)
        + w_bottom_left * corner_values(i_bottom, j_left)
        + w_bottom_right * corner_values(i_bottom, j_right)
    )

    pos = torch.stack([i, j], dim=1)
    return [interpolated_val, pos, ids]
|
|
|
|
|
|
|
|
|
def getWarp(pos0, rel_pose, depth0, K0, depth1, K1, bs, depth_thld=0.05):
    """Warp positions from view 0 into view 1 and keep depth-consistent matches.

    For every batch element ``b`` in ``range(bs)``:

    1. ``pos0[b]`` is filtered, and its depth sampled from ``depth0[b]``, by
       ``validate_and_interpolate`` (corner check plus positive-value check).
    2. The surviving positions are column-swapped, back-projected with
       ``inv(K0[b])``, scaled by the sampled depth and made homogeneous. They
       are then transformed by ``rel_pose[b]``, divided by the last row of the
       result and projected with ``K1[b]``.
    3. The projected positions (swapped back) are validated and sampled
       against ``depth1[b]`` the same way.
    4. A match is kept only if the last component of the transformed point
       differs from the value sampled in ``depth1[b]`` by less than
       ``depth_thld``.

    NOTE(review): the shapes only line up if ``rel_pose[b]`` is 3x4 and
    ``K1[b]`` is 3x3; in that case the last component is the depth in view 1.
    Confirm against callers.

    Args:
        pos0, depth0, K0: per-batch positions, depth maps and matrices for view 0.
        rel_pose: per-batch transform from view 0 to view 1.
        depth1, K1: per-batch depth maps and matrices for view 1.
        bs: number of batch elements to process.
        depth_thld: absolute tolerance of the depth-consistency test. The
            default 0.05 is the value this function previously hard-coded.

    Returns:
        ``(all_pos0, all_pos1, all_ids)``: lists of length ``bs`` holding the kept
        view-0 positions, their view-1 positions, and their indices into ``pos0[b]``.
    """
    def swap_axis(data):
        # Exchange the two columns of an (N, 2) tensor.
        return torch.stack([data[:, 1], data[:, 0]], dim=-1)

    all_pos0 = []
    all_pos1 = []
    all_ids = []
    for b in range(bs):
        z0, new_pos0, ids = validate_and_interpolate(pos0[b], depth0[b], validate_val=0)
        n_pts = new_pos0.shape[0]

        uv0_homo = torch.cat(
            [swap_axis(new_pos0), torch.ones((n_pts, 1), device=new_pos0.device)], dim=-1)
        xy0_homo = torch.matmul(torch.linalg.inv(K0[b]), uv0_homo.t())
        xyz0_homo = torch.cat(
            [z0.unsqueeze(0) * xy0_homo, torch.ones((1, n_pts), device=z0.device)], dim=0)

        xyz1 = torch.matmul(rel_pose[b], xyz0_homo)
        xy1_homo = xyz1 / xyz1[-1:, :]
        uv1 = torch.matmul(K1[b], xy1_homo).t()[:, 0:2]

        annotated_depth, new_pos1, new_ids = validate_and_interpolate(
            swap_axis(uv1), depth1[b], validate_val=0)

        ids = ids[new_ids]
        new_pos0 = new_pos0[new_ids]
        estimated_depth = xyz1.t()[new_ids][:, -1]

        inlier_mask = torch.abs(estimated_depth - annotated_depth) < depth_thld

        # ids may live on a different device than the depth-derived mask.
        all_ids.append(ids[inlier_mask.to(ids.device)])
        all_pos0.append(new_pos0[inlier_mask])
        all_pos1.append(new_pos1[inlier_mask])

    return all_pos0, all_pos1, all_ids
|
|
|
|
|
|
|
|
|
def getWarpNoValidate(pos0, rel_pose, depth0, K0, depth1, K1, bs):
    """Warp positions from view 0 into view 1 without the depth-consistency test.

    Mirrors the projection pipeline of ``getWarp``. Corner and positive-value
    validation is still applied in both views via ``validate_and_interpolate``
    with ``validate_val=0``. The value sampled in view 1 is discarded and no
    inlier threshold is applied.

    Returns:
        ``(all_pos0, all_pos1, all_ids)``: lists of length ``bs`` holding the
        surviving view-0 positions, their view-1 positions, and their indices
        into ``pos0[b]``.
    """
    def flip_columns(coords):
        # Exchange the two columns of an (N, 2) tensor.
        return torch.stack([coords[:, 1], coords[:, 0]], axis=-1)

    out_pos0, out_pos1, out_ids = [], [], []
    for b in range(bs):
        src_depth, src_pos, src_ids = validate_and_interpolate(pos0[b], depth0[b], validate_val=0)
        count = src_pos.shape[0]

        pixels_h = torch.cat([flip_columns(src_pos), torch.ones((count, 1)).to(src_pos.device)], axis=-1)
        rays = torch.matmul(torch.linalg.inv(K0[b]), pixels_h.t())
        points_h = torch.cat([src_depth.unsqueeze(0) * rays,
                              torch.ones((1, count)).to(src_depth.device)], axis=0)

        warped = torch.matmul(rel_pose[b], points_h)
        normalized = warped / warped[-1, :].unsqueeze(0)
        dst_uv = torch.matmul(K1[b], normalized).t()[:, 0:2]

        _, dst_pos, dst_ids = validate_and_interpolate(flip_columns(dst_uv), depth1[b], validate_val=0)

        out_ids.append(src_ids[dst_ids])
        out_pos0.append(src_pos[dst_ids])
        out_pos1.append(dst_pos)

    return out_pos0, out_pos1, out_ids
|
|
|
|
|
|
|
|
|
def getWarpNoValidate2(pos0, rel_pose, depth0, K0, depth1, K1):
    """Project positions of a single sample from view 0 into view 1, with no filtering.

    The value at each position is sampled from ``depth0`` with the clamped
    bilinear ``interpolate`` (no corner or value validation). Positions are
    column-swapped and back-projected with ``inv(K0)``, scaled by the sampled
    value and made homogeneous. They are then transformed by ``rel_pose``,
    divided by the last row and projected with ``K1``. ``depth1`` is accepted
    but not used.

    Returns:
        Projected positions in view 1, column-swapped back to ``pos0``'s layout.
    """
    def flip_columns(coords):
        # Exchange the two columns of an (N, 2) tensor.
        return torch.stack([coords[:, 1], coords[:, 0]], axis=-1)

    count = pos0.shape[0]
    sampled_depth = interpolate(pos0, depth0, nd=False)

    pixels_h = torch.cat([flip_columns(pos0), torch.ones((count, 1)).to(pos0.device)], axis=-1)
    rays = torch.matmul(torch.linalg.inv(K0), pixels_h.t())
    points_h = torch.cat([sampled_depth.unsqueeze(0) * rays,
                          torch.ones((1, count)).to(sampled_depth.device)], axis=0)

    warped = torch.matmul(rel_pose, points_h)
    normalized = warped / warped[-1, :].unsqueeze(0)
    projected = torch.matmul(K1, normalized).t()[:, 0:2]
    return flip_columns(projected)
|
|
|
|
|
|
|
def get_dist_mat(feat1, feat2, dist_type):
    """Pairwise comparison matrix between the rows of ``feat1`` and ``feat2``.

    Args:
        feat1, feat2: 2-D feature tensors; entry ``[a, b]`` compares
            ``feat1[a]`` with ``feat2[b]``.
        dist_type: one of
            - ``'cosine_dist'``: the dot product clamped to ``[-1, 1]`` (a
              similarity, despite the name);
            - ``'euclidean_dist'``: ``sqrt(max(2 - 2 * dot, 1e-6))``, i.e. the
              Euclidean distance if rows are unit-norm;
            - ``'euclidean_dist_no_norm'``: ``sqrt(max(|a|^2 - 2 * dot + |b|^2, 0) + 1e-6)``.

    Raises:
        NotImplementedError: for any other ``dist_type``.
    """
    eps = 1e-6
    dots = feat1 @ feat2.t()

    if dist_type == 'cosine_dist':
        return dots.clamp(-1, 1)

    if dist_type == 'euclidean_dist':
        return (2 - 2 * dots).clamp(min=eps).sqrt()

    if dist_type == 'euclidean_dist_no_norm':
        sq_norm1 = (feat1 * feat1).sum(dim=-1, keepdim=True)
        sq_norm2 = (feat2 * feat2).sum(dim=-1, keepdim=True)
        return ((sq_norm1 - 2 * dots + sq_norm2.t()).clamp(min=0.) + eps).sqrt()

    raise NotImplementedError()
|
|