File size: 4,926 Bytes
2252f3d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 |
# -*- coding: utf-8 -*-
# Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is
# holder of all proprietary rights on this computer program.
# You can only use this computer program if you have closed
# a license agreement with MPG or you get the right to use the computer
# program from someone who is authorized to grant you that right.
# Any use of the computer program without a valid license is prohibited and
# liable to prosecution.
#
# Copyright©2019 Max-Planck-Gesellschaft zur Förderung
# der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute
# for Intelligent Systems. All rights reserved.
#
# Contact: ps-license@tuebingen.mpg.de
import torch
def index(feat, uv):
    '''
    Bilinearly sample features at normalized coordinates.

    :param feat: [B, C, H, W] image features, or [B, C, D, H, W] volume
        features when ``uv`` carries three coordinates
    :param uv: [B, 2, N] (or [B, 3, N]) coordinates ordered (x, y[, z]) and
        normalized to [-1, 1], as expected by ``F.grid_sample``; points outside
        that range sample zeros (grid_sample's default zero padding)
    :return: [B, C, N] features sampled at the given coordinates
    '''
    uv = uv.transpose(1, 2)  # [B, N, 2] or [B, N, 3]
    B, N, _ = uv.shape
    C = feat.shape[1]
    if uv.shape[-1] == 3:
        # Volumetric lookup: 5-D grid for a [B, C, D, H, W] feature volume.
        grid = uv.unsqueeze(2).unsqueeze(3)  # [B, N, 1, 1, 3]
    else:
        grid = uv.unsqueeze(2)  # [B, N, 1, 2]
    # NOTE: newer PyTorch changed F.grid_sample's default corner alignment,
    # which was observed to degrade training; align_corners=True keeps the old
    # behavior. On old PyTorch versions lacking this argument, remove it.
    samples = torch.nn.functional.grid_sample(
        feat, grid, align_corners=True)  # [B, C, N, 1] (or [B, C, N, 1, 1])
    return samples.view(B, C, N)  # [B, C, N]
def grid_sample(image, optical):
    '''
    Bilinearly sample ``image`` at normalized grid locations using plain tensor
    ops (floor / gather / arithmetic). For grid points inside [-1, 1] this
    matches ``F.grid_sample(image, optical, align_corners=True)``. Gradients
    flow to both ``image`` (through gather) and ``optical`` (through the
    interpolation weights).

    NOTE: points outside [-1, 1] are NOT zero-padded like F.grid_sample's
    default: corner indices are clamped to the image border while the weights
    are computed from the unclamped coordinates.

    :param image: [N, C, IH, IW] tensor to sample from
    :param optical: [N, H, W, 2] grid of (x, y) coordinates normalized to [-1, 1]
    :return: [N, C, H, W] sampled values
    '''
    N, C, IH, IW = image.shape
    _, H, W, _ = optical.shape

    # Normalized [-1, 1] -> pixel coordinates (align_corners=True convention).
    ix = ((optical[..., 0] + 1) / 2) * (IW - 1)
    iy = ((optical[..., 1] + 1) / 2) * (IH - 1)

    # Integer corner coordinates are treated as constants by autograd.
    with torch.no_grad():
        ix_w = torch.floor(ix)  # west column
        iy_n = torch.floor(iy)  # north row
        ix_e = ix_w + 1  # east column
        iy_s = iy_n + 1  # south row

    # Each corner's weight is the area of the sub-rectangle opposite it; kept
    # in the autograd graph so gradients reach ``optical``.
    nw = (ix_e - ix) * (iy_s - iy)
    ne = (ix - ix_w) * (iy_s - iy)
    sw = (ix_e - ix) * (iy - iy_n)
    se = (ix - ix_w) * (iy - iy_n)

    # reshape (not view) so non-contiguous images are accepted.
    flat = image.reshape(N, C, IH * IW)

    def _corner(cx, cy):
        """Gather image values at integer corner (cx, cy), clamped to the border."""
        with torch.no_grad():
            cx = cx.clamp(0, IW - 1).long()
            cy = cy.clamp(0, IH - 1).long()
            # Flat index in int64 so large images don't lose float precision;
            # expand shares the index across channels without copying.
            idx = (cy * IW + cx).reshape(N, 1, H * W).expand(N, C, H * W)
        return torch.gather(flat, 2, idx).reshape(N, C, H, W)

    return (_corner(ix_w, iy_n) * nw.unsqueeze(1) +
            _corner(ix_e, iy_n) * ne.unsqueeze(1) +
            _corner(ix_w, iy_s) * sw.unsqueeze(1) +
            _corner(ix_e, iy_s) * se.unsqueeze(1))
def orthogonal(points, calibrations, transforms=None):
    '''
    Compute the orthogonal projections of 3D points into the image plane by
    the given projection matrix.

    :param points: [B, 3, N] Tensor of 3D points
    :param calibrations: [B, 3, 4] Tensor of projection matrices [R | t]
    :param transforms: optional [B, 2, 3] Tensor of image transform matrices
        [S | s] applied to the projected x, y: xy' = S @ xy + s
    :return: xyz: [B, 3, N] Tensor of xyz coordinates in the image plane
        (z is left untouched by ``transforms``)
    '''
    rot = calibrations[:, :3, :3]
    trans = calibrations[:, :3, 3:4]
    pts = torch.baddbmm(trans, rot, points)  # [B, 3, N]
    if transforms is not None:
        # Dim 0 is the batch; slice rows/cols of each per-sample matrix.
        scale = transforms[:, :2, :2]  # [B, 2, 2]
        shift = transforms[:, :2, 2:3]  # [B, 2, 1]
        xy = torch.baddbmm(shift, scale, pts[:, :2, :])  # [B, 2, N]
        # Rebuild instead of writing into ``pts`` in place, which would modify
        # a tensor autograd may have saved for the backward pass.
        pts = torch.cat([xy, pts[:, 2:3, :]], 1)
    return pts
def perspective(points, calibrations, transforms=None):
    '''
    Compute the perspective projections of 3D points into the image plane by
    the given projection matrix.

    :param points: [B, 3, N] Tensor of 3D points
    :param calibrations: [B, 3, 4] Tensor of projection matrices [R | t]
    :param transforms: optional [B, 2, 3] Tensor of image transform matrices
        [S | s] applied to the projected x, y: xy' = S @ xy + s
    :return: xyz: [B, 3, N] Tensor whose first two rows are the projected xy
        coordinates and whose last row is the depth z before division
    '''
    rot = calibrations[:, :3, :3]
    trans = calibrations[:, :3, 3:4]
    homo = torch.baddbmm(trans, rot, points)  # [B, 3, N]
    # Perspective divide; no guard against z == 0.
    xy = homo[:, :2, :] / homo[:, 2:3, :]  # [B, 2, N]
    if transforms is not None:
        # Dim 0 is the batch; slice rows/cols of each per-sample matrix.
        scale = transforms[:, :2, :2]  # [B, 2, 2]
        shift = transforms[:, :2, 2:3]  # [B, 2, 1]
        xy = torch.baddbmm(shift, scale, xy)
    xyz = torch.cat([xy, homo[:, 2:3, :]], 1)  # [B, 3, N]
    return xyz
|