Spaces:
Sleeping
Sleeping
import time | |
import numpy as np | |
import torch | |
import torch.nn.functional as F | |
def rnd_sample(inputs, n_sample):
    """Randomly subsample rows, using one shared permutation for every tensor.

    Args:
        inputs: sequence of tensors whose first dimensions are indexed by the
            same permutation (its length is taken from ``inputs[0]``).
        n_sample: maximum number of rows to keep; if it exceeds the row count,
            all rows are returned in shuffled order.

    Returns:
        List with one subsampled tensor per entry of ``inputs``.
    """
    permutation = torch.randperm(inputs[0].shape[0])
    chosen = permutation[:n_sample]
    # The same ``chosen`` rows are taken from every tensor so they stay aligned.
    return [tensor[chosen] for tensor in inputs]
def _grid_positions(h, w, bs): | |
x_rng = torch.arange(0, w.int()) | |
y_rng = torch.arange(0, h.int()) | |
xv, yv = torch.meshgrid(x_rng, y_rng, indexing="xy") | |
return ( | |
torch.reshape(torch.stack((yv, xv), axis=-1), (1, -1, 2)) | |
.repeat(bs, 1, 1) | |
.float() | |
) | |
def getK(ori_img_size, cur_feat_size, K):
    """Rescale intrinsics from the original image size to a feature-map size.

    Args:
        ori_img_size: [bs, 2] tensor of original sizes, ordered (w, h).
        cur_feat_size: 2-element tensor of the feature-map size.
            WARNING: its order is [h, w], so it is flipped before dividing.
        K: [bs, 3, 3] intrinsics.

    Returns:
        [bs, 3, 3] intrinsics whose first row is divided by the horizontal
        scale factor and second row by the vertical one; the third row is kept.
    """
    # Flip [h, w] -> [w, h] so the columns line up with ori_img_size.
    scale = ori_img_size / cur_feat_size[[1, 0]]
    scale_x = scale[:, 0].unsqueeze(-1)
    scale_y = scale[:, 1].unsqueeze(-1)
    first_row = K[:, 0] / scale_x
    second_row = K[:, 1] / scale_y
    return torch.stack([first_row, second_row, K[:, 2]], dim=1)
def gather_nd(params, indices):
    """Gather slices of ``params`` addressed by the last axis of ``indices``.

    Equivalent to ``tf.gather_nd`` without batched gather (``batch_dims``).
    ``indices`` is a k-dimensional integer tensor, best thought of as a
    (k-1)-dimensional tensor of index tuples into ``params``:

        output[i_0, ..., i_{k-2}] = params[indices[i_0, ..., i_{k-2}]]

    Args:
        params (Tensor): "n" dimensions, shape [x_0, x_1, ..., x_{n-1}].
        indices (Tensor): "k" dimensions of integers, shape
            [y_0, ..., y_{k-2}, m] with m <= n. A 1-D ``indices`` of shape [m]
            addresses a single slice.

    Returns:
        Tensor of shape [y_0, ..., y_{k-2}] + params.shape[m:], on
        ``params``'s device.

    Raises:
        ValueError: if m > n.
    """
    orig_shape = list(indices.shape)
    m = orig_shape[-1]
    n = len(params.shape)
    if m > n:
        raise ValueError(
            f"the last dimension of indices must less or equal to the rank of params. Got indices:{indices.shape}, params:{params.shape}. {m} > {n}"
        )
    out_shape = orig_shape[:-1] + list(params.shape)[m:]
    # int(): np.prod([]) is the float 1.0 for 1-D indices, which reshape rejects.
    num_samples = int(np.prod(orig_shape[:-1]))
    # Index with a tuple of m long tensors (one per leading axis of params).
    # This stays on-device instead of round-tripping through a Python list.
    flat = indices.reshape(num_samples, m).to(device=params.device, dtype=torch.long)
    output = params[flat.unbind(dim=1)]  # (num_samples, ...)
    return output.reshape(out_shape).contiguous()
def interpolate(pos, inputs, nd=True):
    """Bilinearly sample ``inputs`` at the continuous positions in ``pos``.

    Every corner index is clamped into the grid, so no sample is dropped.
    Positions outside the map are effectively clamped to its border.

    Args:
        pos: [kpt_n, 2] tensor. Column 0 indexes axis 0 of ``inputs`` (rows)
            and column 1 indexes axis 1 (columns).
        inputs: [H, W] map (use ``nd=False``) or a feature map with a trailing
            channel axis, e.g. [H, W, 128] (use ``nd=True``).
        nd: True when ``inputs`` has a trailing channel axis; the per-sample
            weights are then broadcast across it. With a 2-D ``inputs`` this
            must be False, otherwise the weights broadcast to [kpt_n, kpt_n].

    Returns:
        [kpt_n] tensor (nd=False) or [kpt_n, C] tensor (nd=True).
    """
    height = inputs.shape[0]
    width = inputs.shape[1]
    rows = pos[:, 0]
    cols = pos[:, 1]

    # Integer neighbours, clamped so every gather stays in bounds.
    row_lo = torch.clamp(torch.floor(rows).int(), 0, height - 1)
    row_hi = torch.clamp(torch.ceil(rows).int(), 0, height - 1)
    col_lo = torch.clamp(torch.floor(cols).int(), 0, width - 1)
    col_hi = torch.clamp(torch.ceil(cols).int(), 0, width - 1)

    # Fractional offsets from the clamped top-left corner drive all four weights.
    frac_r = rows - row_lo.float()
    frac_c = cols - col_lo.float()
    weights = [
        (1 - frac_r) * (1 - frac_c),  # top-left
        (1 - frac_r) * frac_c,  # top-right
        frac_r * (1 - frac_c),  # bottom-left
        frac_r * frac_c,  # bottom-right
    ]
    if nd:
        weights = [wt.unsqueeze(-1) for wt in weights]

    corners = [
        (row_lo, col_lo),
        (row_lo, col_hi),
        (row_hi, col_lo),
        (row_hi, col_hi),
    ]
    samples = [gather_nd(inputs, torch.stack([r, c], axis=-1)) for r, c in corners]

    # Accumulate left to right (TL, TR, BL, BR), matching the original sum order.
    interpolated_val = weights[0] * samples[0]
    for wt, val in zip(weights[1:], samples[1:]):
        interpolated_val = interpolated_val + wt * val
    return interpolated_val
def _gather_bilinear_corners(inputs, i_floor, j_floor, i_ceil, j_ceil):
    """Gather ``inputs`` at the top-left, top-right, bottom-left, bottom-right corners (in that order)."""
    corners = (
        (i_floor, j_floor),  # top-left
        (i_floor, j_ceil),  # top-right
        (i_ceil, j_floor),  # bottom-left
        (i_ceil, j_ceil),  # bottom-right
    )
    return [gather_nd(inputs, torch.stack([ci, cj], axis=-1)) for ci, cj in corners]


def validate_and_interpolate(
    pos, inputs, validate_corner=True, validate_val=None, nd=False
):
    """Bilinearly sample ``inputs`` at ``pos``, dropping samples that fail validation.

    Args:
        pos: [N, 2] float tensor. Column 0 is the row coordinate (checked
            against H) and column 1 the column coordinate (checked against W).
        inputs: [H, W] tensor (``nd=False``) or [H, W, C] tensor (``nd=True``).
        validate_corner: if True, drop samples whose four integer neighbours
            (floor/ceil of each coordinate) are not all inside [0, H) x [0, W).
            If False, out-of-range indices are not filtered.
        validate_val: if not None, drop samples where any of the four corner
            values of ``inputs`` is <= 0. NOTE: only None vs. not-None matters;
            the threshold is always 0 whatever value is passed. The resulting
            mask is 1-D only for ``nd=False``, so this is presumably meant for
            2-D inputs such as depth maps.
        nd: True when ``inputs`` has a trailing channel axis.

    Returns:
        [interpolated_val, pos, ids]:
            interpolated_val: [N'] (nd=False) or [N', C] (nd=True) samples.
            pos: [N', 2] surviving positions.
            ids: [N'] long indices of the survivors into the input ``pos``.
    """
    if nd:
        h, w, _ = inputs.shape
    else:
        h, w = inputs.shape
    # Allocate ids on pos's device. The masks below are built from pos, and a
    # CPU tensor cannot be indexed by a CUDA mask.
    ids = torch.arange(0, pos.shape[0], device=pos.device)
    i = pos[:, 0]
    j = pos[:, 1]
    # The four bilinear corners are combinations of these integer neighbours.
    i_floor = torch.floor(i).int()
    j_floor = torch.floor(j).int()
    i_ceil = torch.ceil(i).int()
    j_ceil = torch.ceil(j).int()

    if validate_corner:
        # All four corners lie in [0, h) x [0, w) exactly when both floors are
        # non-negative and both ceils are below the size.
        valid_corner = (i_floor >= 0) & (j_floor >= 0) & (i_ceil < h) & (j_ceil < w)
        i_floor, j_floor = i_floor[valid_corner], j_floor[valid_corner]
        i_ceil, j_ceil = i_ceil[valid_corner], j_ceil[valid_corner]
        ids = ids[valid_corner]

    if validate_val is not None:
        # Keep only samples whose four corner values are strictly positive.
        tl, tr, bl, br = _gather_bilinear_corners(
            inputs, i_floor, j_floor, i_ceil, j_ceil
        )
        valid_val = (tl > 0) & (tr > 0) & (bl > 0) & (br > 0)
        i_floor, j_floor = i_floor[valid_val], j_floor[valid_val]
        i_ceil, j_ceil = i_ceil[valid_val], j_ceil[valid_val]
        ids = ids[valid_val]

    # Bilinear interpolation at the surviving positions.
    i = i[ids]
    j = j[ids]
    di = i - i_floor.float()
    dj = j - j_floor.float()
    w_top_left = (1 - di) * (1 - dj)
    w_top_right = (1 - di) * dj
    w_bottom_left = di * (1 - dj)
    w_bottom_right = di * dj
    if nd:
        # Broadcast the per-sample weights over the channel axis.
        w_top_left = w_top_left[..., None]
        w_top_right = w_top_right[..., None]
        w_bottom_left = w_bottom_left[..., None]
        w_bottom_right = w_bottom_right[..., None]
    v_top_left, v_top_right, v_bottom_left, v_bottom_right = _gather_bilinear_corners(
        inputs, i_floor, j_floor, i_ceil, j_ceil
    )
    interpolated_val = (
        w_top_left * v_top_left
        + w_top_right * v_top_right
        + w_bottom_left * v_bottom_left
        + w_bottom_right * v_bottom_right
    )
    pos = torch.stack([i, j], axis=1)
    return [interpolated_val, pos, ids]
# Example shapes: pos0 [2, 230400, 2]; depth0 [2, 480, 480].
def getWarp(pos0, rel_pose, depth0, K0, depth1, K1, bs):
    """Warp positions from image 0 into image 1, keeping depth-consistent matches.

    For each batch item:
    1. Sample depth0 at pos0, dropping positions whose corners are out of
       bounds or have non-positive depth.
    2. Back-project the survivors with K0, move them with rel_pose, and
       project them with K1.
    3. Sample depth1 at the projected positions with the same dropping rule.
    4. Keep only points whose projected depth (last row of the transformed
       points) differs from the sampled depth1 by less than 0.05, in the
       depth maps' units.

    Args:
        pos0: [bs, N, 2] positions in (row, col) order.
        rel_pose: per-item pose applied to homogeneous camera-0 points;
            presumably [bs, 3, 4] ([R | t]). TODO confirm.
        depth0, depth1: [bs, H, W] depth maps.
        K0: [bs, 3, 3] intrinsics of image 0 (inverted per item).
        K1: intrinsics of image 1, presumably [bs, 3, 3].
        bs: number of batch items to process.

    Returns:
        (all_pos0, all_pos1, all_ids): lists of length bs. Per item they hold
        the inlier positions in image 0 and image 1 ([M, 2], (row, col)) and
        the inliers' indices into pos0[b] ([M]).
    """
    all_pos0, all_pos1, all_ids = [], [], []
    for b in range(bs):
        # Depth at the source positions; invalid positions are dropped.
        z0, src_pos, src_ids = validate_and_interpolate(
            pos0[b], depth0[b], validate_val=0
        )
        n_pts = src_pos.shape[0]

        # (row, col) -> homogeneous pixel (x, y, 1), then back-project to rays.
        pix0 = torch.cat(
            [src_pos.flip(-1), torch.ones((n_pts, 1), device=src_pos.device)],
            dim=-1,
        )
        rays0 = torch.linalg.inv(K0[b]) @ pix0.T
        # Scale rays by depth and append a row of ones for rel_pose.
        pts0 = torch.cat(
            [z0.unsqueeze(0) * rays0, torch.ones((1, n_pts), device=z0.device)],
            dim=0,
        )
        pts1 = rel_pose[b] @ pts0
        # Divide by the last row, project with K1, convert back to (row, col).
        proj1 = (K1[b] @ (pts1 / pts1[-1:])).T[:, :2]
        annotated_depth, dst_pos, kept = validate_and_interpolate(
            proj1.flip(-1), depth1[b], validate_val=0
        )

        src_ids = src_ids[kept]
        src_pos = src_pos[kept]
        estimated_depth = pts1[-1, kept]
        inlier = torch.abs(estimated_depth - annotated_depth) < 0.05

        all_ids.append(src_ids[inlier])
        all_pos0.append(src_pos[inlier])
        all_pos1.append(dst_pos[inlier])
    # all_pos0 & all_pos1: [inlier_num, 2] * batch_size
    return all_pos0, all_pos1, all_ids
# Example shapes: pos0 [2, 230400, 2]; depth0 [2, 480, 480].
def getWarpNoValidate(pos0, rel_pose, depth0, K0, depth1, K1, bs):
    """Warp positions from image 0 into image 1 without the depth-consistency test.

    The projection matches ``getWarp``: sample depth0, back-project with K0,
    apply rel_pose, and project with K1. Points are still dropped when their
    depth0 corners are invalid, or when the projection lands outside image 1
    or on non-positive depth1. However, the projected depth is not compared
    against depth1.

    Args:
        pos0: [bs, N, 2] positions in (row, col) order.
        rel_pose: per-item pose applied to homogeneous camera-0 points;
            presumably [bs, 3, 4] ([R | t]). TODO confirm.
        depth0, depth1: [bs, H, W] depth maps.
        K0: [bs, 3, 3] intrinsics of image 0 (inverted per item).
        K1: intrinsics of image 1, presumably [bs, 3, 3].
        bs: number of batch items to process.

    Returns:
        (all_pos0, all_pos1, all_ids): lists of length bs. Per item they hold
        the surviving positions in image 0 and image 1 ([M, 2], (row, col))
        and their indices into pos0[b] ([M]).
    """
    all_pos0, all_pos1, all_ids = [], [], []
    for b in range(bs):
        z0, src_pos, src_ids = validate_and_interpolate(
            pos0[b], depth0[b], validate_val=0
        )
        n_pts = src_pos.shape[0]

        # (row, col) -> homogeneous pixel (x, y, 1), then back-project to rays.
        pix0 = torch.cat(
            [src_pos.flip(-1), torch.ones((n_pts, 1), device=src_pos.device)],
            dim=-1,
        )
        rays0 = torch.linalg.inv(K0[b]) @ pix0.T
        pts0 = torch.cat(
            [z0.unsqueeze(0) * rays0, torch.ones((1, n_pts), device=z0.device)],
            dim=0,
        )
        pts1 = rel_pose[b] @ pts0
        proj1 = (K1[b] @ (pts1 / pts1[-1:])).T[:, :2]
        # Only the bounds / positive-depth filtering in image 1 is applied.
        _, dst_pos, kept = validate_and_interpolate(
            proj1.flip(-1), depth1[b], validate_val=0
        )

        all_ids.append(src_ids[kept])
        all_pos0.append(src_pos[kept])
        all_pos1.append(dst_pos)
    # all_pos0 & all_pos1: [inlier_num, 2] * batch_size
    return all_pos0, all_pos1, all_ids
# Example shapes (from the batched variants): pos0 [2, 230400, 2]; depth0 [2, 480, 480].
def getWarpNoValidate2(pos0, rel_pose, depth0, K0, depth1, K1):
    """Project single-image positions from image 0 into image 1 with no filtering.

    Depth is sampled with ``interpolate`` (clamped corners), so every input
    position yields an output. ``depth1`` is accepted but not used.

    Args:
        pos0: [N, 2] positions in (row, col) order.
        rel_pose: pose applied to homogeneous camera-0 points; presumably
            [3, 4] ([R | t]). TODO confirm.
        depth0: [H, W] depth map of image 0.
        K0: [3, 3] intrinsics of image 0 (inverted here).
        depth1: unused.
        K1: intrinsics of image 1, presumably [3, 3].

    Returns:
        [N, 2] projected positions in image 1, (row, col) order.
    """
    z0 = interpolate(pos0, depth0, nd=False)
    n_pts = pos0.shape[0]
    # (row, col) -> homogeneous pixel (x, y, 1).
    pix0 = torch.cat(
        [pos0.flip(-1), torch.ones((n_pts, 1), device=pos0.device)], dim=-1
    )
    rays0 = torch.linalg.inv(K0) @ pix0.T
    pts0 = torch.cat(
        [z0.unsqueeze(0) * rays0, torch.ones((1, n_pts), device=z0.device)],
        dim=0,
    )
    pts1 = rel_pose @ pts0
    proj1 = (K1 @ (pts1 / pts1[-1:])).T[:, :2]
    # Back to (row, col).
    return proj1.flip(-1)
def get_dist_mat(feat1, feat2, dist_type):
    """Pairwise distance (or similarity) matrix between two descriptor sets.

    Args:
        feat1: [N, D] tensor.
        feat2: [M, D] tensor.
        dist_type: one of
            "cosine_dist": dot products clamped to [-1, 1]. This is a
                similarity (larger = closer) and equals cosine similarity
                only for L2-normalized descriptors.
            "euclidean_dist": sqrt(2 - 2 * dot). This equals the Euclidean
                distance for L2-normalized descriptors. The radicand is
                clamped to >= 1e-6 so sqrt stays finite and differentiable.
            "euclidean_dist_no_norm": sqrt(max(|a|^2 - 2 a.b + |b|^2, 0) + 1e-6),
                the Euclidean distance for unnormalized descriptors.

    Returns:
        [N, M] tensor.

    Raises:
        NotImplementedError: if ``dist_type`` is not one of the above.
    """
    eps = 1e-6
    dot = torch.matmul(feat1, feat2.t())
    if dist_type == "cosine_dist":
        return torch.clamp(dot, -1, 1)
    if dist_type == "euclidean_dist":
        return torch.sqrt(torch.clamp(2 - 2 * dot, min=eps))
    if dist_type == "euclidean_dist_no_norm":
        sq_norm1 = torch.sum(feat1 * feat1, dim=-1, keepdim=True)
        sq_norm2 = torch.sum(feat2 * feat2, dim=-1, keepdim=True)
        # Clamp first: round-off can make the squared distance slightly negative.
        return torch.sqrt(
            torch.clamp(sq_norm1 - 2 * dot + sq_norm2.t(), min=0.0) + eps
        )
    raise NotImplementedError(f"Unsupported dist_type: {dist_type!r}")