taesiri's picture
Initial Commit
8390f90
r""" Provides functions that manipulate boxes and points """
import math
import torch.nn.functional as F
import torch
class Geometry(object):
@classmethod
def initialize(cls, img_size):
cls.img_size = img_size
cls.spatial_side = int(img_size / 8)
norm_grid1d = torch.linspace(-1, 1, cls.spatial_side)
cls.norm_grid_x = norm_grid1d.view(1, -1).repeat(cls.spatial_side, 1).view(1, 1, -1)
cls.norm_grid_y = norm_grid1d.view(-1, 1).repeat(1, cls.spatial_side).view(1, 1, -1)
cls.grid = torch.stack(list(reversed(torch.meshgrid(norm_grid1d, norm_grid1d)))).permute(1, 2, 0)
cls.feat_idx = torch.arange(0, cls.spatial_side).float()
@classmethod
def normalize_kps(cls, kps):
kps = kps.clone().detach()
kps[kps != -2] -= (cls.img_size // 2)
kps[kps != -2] /= (cls.img_size // 2)
return kps
@classmethod
def unnormalize_kps(cls, kps):
kps = kps.clone().detach()
kps[kps != -2] *= (cls.img_size // 2)
kps[kps != -2] += (cls.img_size // 2)
return kps
@classmethod
def attentive_indexing(cls, kps, thres=0.1):
r"""kps: normalized keypoints x, y (N, 2)
returns attentive index map(N, spatial_side, spatial_side)
"""
nkps = kps.size(0)
kps = kps.view(nkps, 1, 1, 2)
eps = 1e-5
attmap = (cls.grid.unsqueeze(0).repeat(nkps, 1, 1, 1) - kps).pow(2).sum(dim=3)
attmap = (attmap + eps).pow(0.5)
attmap = (thres - attmap).clamp(min=0).view(nkps, -1)
attmap = attmap / attmap.sum(dim=1, keepdim=True)
attmap = attmap.view(nkps, cls.spatial_side, cls.spatial_side)
return attmap
@classmethod
def apply_gaussian_kernel(cls, corr, sigma=17):
bsz, side, side = corr.size()
center = corr.max(dim=2)[1]
center_y = center // cls.spatial_side
center_x = center % cls.spatial_side
y = cls.feat_idx.view(1, 1, cls.spatial_side).repeat(bsz, center_y.size(1), 1) - center_y.unsqueeze(2)
x = cls.feat_idx.view(1, 1, cls.spatial_side).repeat(bsz, center_x.size(1), 1) - center_x.unsqueeze(2)
y = y.unsqueeze(3).repeat(1, 1, 1, cls.spatial_side)
x = x.unsqueeze(2).repeat(1, 1, cls.spatial_side, 1)
gauss_kernel = torch.exp(-(x.pow(2) + y.pow(2)) / (2 * sigma ** 2))
filtered_corr = gauss_kernel * corr.view(bsz, -1, cls.spatial_side, cls.spatial_side)
filtered_corr = filtered_corr.view(bsz, side, side)
return filtered_corr
@classmethod
def transfer_kps(cls, confidence_ts, src_kps, n_pts, normalized):
r""" Transfer keypoints by weighted average """
if not normalized:
src_kps = Geometry.normalize_kps(src_kps)
confidence_ts = cls.apply_gaussian_kernel(confidence_ts)
pdf = F.softmax(confidence_ts, dim=2)
prd_x = (pdf * cls.norm_grid_x).sum(dim=2)
prd_y = (pdf * cls.norm_grid_y).sum(dim=2)
prd_kps = []
for idx, (x, y, src_kp, np) in enumerate(zip(prd_x, prd_y, src_kps, n_pts)):
max_pts = src_kp.size()[1]
prd_xy = torch.stack([x, y]).t()
src_kp = src_kp[:, :np].t()
attmap = cls.attentive_indexing(src_kp).view(np, -1)
prd_kp = (prd_xy.unsqueeze(0) * attmap.unsqueeze(-1)).sum(dim=1).t()
pads = (torch.zeros((2, max_pts - np)) - 2)
prd_kp = torch.cat([prd_kp, pads], dim=1)
prd_kps.append(prd_kp)
return torch.stack(prd_kps)
@staticmethod
def get_coord1d(coord4d, ksz):
i, j, k, l = coord4d
coord1d = i * (ksz ** 3) + j * (ksz ** 2) + k * (ksz) + l
return coord1d
@staticmethod
def get_distance(coord1, coord2):
delta_y = int(math.pow(coord1[0] - coord2[0], 2))
delta_x = int(math.pow(coord1[1] - coord2[1], 2))
dist = delta_y + delta_x
return dist
@staticmethod
def interpolate4d(tensor4d, size):
bsz, h1, w1, h2, w2 = tensor4d.size()
tensor4d = tensor4d.view(bsz, h1, w1, -1).permute(0, 3, 1, 2)
tensor4d = F.interpolate(tensor4d, size, mode='bilinear', align_corners=True)
tensor4d = tensor4d.view(bsz, h2, w2, -1).permute(0, 3, 1, 2)
tensor4d = F.interpolate(tensor4d, size, mode='bilinear', align_corners=True)
tensor4d = tensor4d.view(bsz, size[0], size[0], size[0], size[0])
return tensor4d
@staticmethod
def init_idx4d(ksz):
i0 = torch.arange(0, ksz).repeat(ksz ** 3)
i1 = torch.arange(0, ksz).unsqueeze(1).repeat(1, ksz).view(-1).repeat(ksz ** 2)
i2 = torch.arange(0, ksz).unsqueeze(1).repeat(1, ksz ** 2).view(-1).repeat(ksz)
i3 = torch.arange(0, ksz).unsqueeze(1).repeat(1, ksz ** 3).view(-1)
idx4d = torch.stack([i3, i2, i1, i0]).t().numpy()
return idx4d