Spaces:
Running
Running
File size: 7,231 Bytes
a80d6bb c74a070 a80d6bb c74a070 a80d6bb c74a070 a80d6bb c74a070 a80d6bb c74a070 a80d6bb c74a070 a80d6bb c74a070 a80d6bb c74a070 a80d6bb c74a070 a80d6bb c74a070 a80d6bb c74a070 a80d6bb c74a070 a80d6bb c74a070 a80d6bb c74a070 a80d6bb c74a070 a80d6bb c74a070 a80d6bb c74a070 a80d6bb c74a070 a80d6bb c74a070 a80d6bb c74a070 a80d6bb |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 |
import torch
import torch.nn as nn
import torchvision.transforms as tvf
from .modules import InterestPointModule, CorrespondenceModule
def warp_homography_batch(sources, homographies):
    """
    Batch warp keypoints given homographies. From https://github.com/TRI-ML/KP2D.

    Every 2D point ``[x, y]`` is lifted to homogeneous coordinates, multiplied
    by the homography of its batch element and divided by the resulting third
    coordinate.

    Parameters
    ----------
    sources: torch.Tensor (B,H,W,2)
        Keypoints vector; the last dimension holds ``(x, y)``.
    homographies: torch.Tensor (B,3,3)
        Homographies, one per batch element.

    Returns
    -------
    warped_sources: torch.Tensor (B,H,W,2)
        Warped keypoints vector. A new tensor; ``sources`` is not modified.
    """
    B, H, W, _ = sources.shape
    # reshape (rather than view) also accepts inputs whose memory layout does
    # not allow a zero-copy flattening, e.g. permuted tensors.
    points = sources.reshape(B, H * W, 2)

    # [X,    [M11, M12, M13    [x,                      [M11, M12        [M13,
    #  Y,  =  M21, M22, M23  *  y,   =   [x, y]  *       M21, M22    +    M23,
    #  Z]     M31, M32, M33]    1]                       M31, M32].T      M33]
    #
    # Appending the constant 1 is avoided by splitting the homography into its
    # first two columns (linear part) and its last column (offset). The batched
    # matmul handles every sample at once and also works for B == 0.
    linear = homographies[:, :, :2].transpose(1, 2)  # (B, 2, 3)
    offset = homographies[:, :, 2].unsqueeze(1)  # (B, 1, 3)
    homogeneous = torch.matmul(points, linear) + offset  # (B, H*W, 3)

    # Perspective division: (X, Y, Z) -> (X / Z, Y / Z).
    warped = homogeneous[:, :, :2] / homogeneous[:, :, 2:3]
    return warped.reshape(B, H, W, 2)
class PointModel(nn.Module):
    """Interest point model with an inference path and a paired training path.

    Parameters
    ----------
    is_test: bool
        Selects the branch taken by ``forward``. It is also forwarded to
        ``InterestPointModule``.
    """

    def __init__(self, is_test=True):
        super().__init__()
        self.is_test = is_test
        self.interestpoint_module = InterestPointModule(is_test=self.is_test)
        self.correspondence_module = CorrespondenceModule()
        # Only applied in the test branch of ``forward``.
        self.norm_rgb = tvf.Normalize(mean=[0.5, 0.5, 0.5], std=[0.225, 0.225, 0.225])

    @staticmethod
    def _normalize_coord(coord, height, width):
        """Map (B,2,h,w) pixel coordinates to a (B,h,w,2) grid in [-1, 1]."""
        coord_norm = coord.clone()
        # Channel 0 is scaled by the width and channel 1 by the height.
        coord_norm[:, 0] = (coord_norm[:, 0] / (float(width - 1) / 2.0)) - 1.0
        coord_norm[:, 1] = (coord_norm[:, 1] / (float(height - 1) / 2.0)) - 1.0
        return coord_norm.permute(0, 2, 3, 1)

    @staticmethod
    def _denormalize_coord(coord_norm, height, width):
        """Map a (B,h,w,2) grid in [-1, 1] back to (B,2,h,w) pixel coordinates."""
        coord = coord_norm.clone()
        coord[:, :, :, 0] = (coord[:, :, :, 0] + 1) * (float(width - 1) / 2.0)
        coord[:, :, :, 1] = (coord[:, :, :, 1] + 1) * (float(height - 1) / 2.0)
        return coord.permute(0, 3, 1, 2)

    @staticmethod
    def _sample(features, grid):
        """Sample ``features`` at ``grid`` without back-propagating into the grid."""
        # align_corners=False is the grid_sample default; passing it explicitly
        # keeps the result unchanged and avoids the per-call UserWarning.
        return torch.nn.functional.grid_sample(
            features, grid.detach(), align_corners=False
        )

    @staticmethod
    def _blend(desc_a, desc_b, aware):
        """Weight two descriptor maps by channels 0 and 1 of ``aware`` and sum."""
        weight_a = aware[:, 0, :, :].unsqueeze(1).contiguous()
        weight_b = aware[:, 1, :, :].unsqueeze(1).contiguous()
        return torch.mul(desc_a, weight_a) + torch.mul(desc_b, weight_b)

    def forward(self, *args):
        """Run the model.

        Test mode (``is_test=True``)
            ``args[0]`` is the image. It is normalized with ``norm_rgb`` and
            passed to the interest point module; the resulting
            ``(score, coord, desc)`` tuple is returned as is.

        Training mode (``is_test=False``)
            ``args[0]`` and ``args[1]`` are the source and target images
            (4-D, the last two dimensions are read as H and W) and
            ``args[2]`` holds the homographies handed to
            ``warp_homography_batch``. The images are NOT passed through
            ``norm_rgb`` in this branch.
            NOTE(review): presumably the training pipeline normalizes the
            images beforehand -- confirm against the data loader.

        Returns
        -------
        tuple or dict
            The ``(score, coord, desc)`` tuple in test mode; in training mode
            a dict with the scores, coordinates, blended descriptors, warped
            target quantities, border mask and confidence matrix.
        """
        if self.is_test:
            img = self.norm_rgb(args[0])
            score, coord, desc = self.interestpoint_module(img)
            return score, coord, desc

        source_score, source_coord, source_desc_block = self.interestpoint_module(
            args[0]
        )
        target_score, target_coord, target_desc_block = self.interestpoint_module(
            args[1]
        )

        _, _, H, W = args[0].shape
        B, _, hc, wc = source_score.shape
        device = source_score.device

        # Normalize the coordinates from pixel units to [-1, 1] on both axes,
        # the range grid_sample expects, and move the channel axis last.
        source_coord_norm = self._normalize_coord(source_coord, H, W)
        target_coord_norm = self._normalize_coord(target_coord, H, W)

        # Warp the source keypoints with the homographies, then de-normalize
        # the warped grid back to pixel units.
        target_coord_warped_norm = warp_homography_batch(source_coord_norm, args[2])
        target_coord_warped = self._denormalize_coord(target_coord_warped_norm, H, W)

        # Border mask: False on the outermost ring of cells, True elsewhere.
        # Allocated as bool directly on the target device.
        border_mask_ori = torch.ones(B, hc, wc, dtype=torch.bool, device=device)
        border_mask_ori[:, 0] = False
        border_mask_ori[:, hc - 1] = False
        border_mask_ori[:, :, 0] = False
        border_mask_ori[:, :, wc - 1] = False

        # Keep only warped points that fall strictly inside the [-1, 1] range.
        warped_x = target_coord_warped_norm[:, :, :, 0]
        warped_y = target_coord_warped_norm[:, :, :, 1]
        oob_mask2 = warped_x.lt(1) & warped_x.gt(-1) & warped_y.lt(1) & warped_y.gt(-1)
        border_mask = border_mask_ori & oob_mask2

        # score
        target_score_warped = self._sample(target_score, target_coord_warped_norm)

        # descriptor: two descriptor maps are sampled at the keypoints and
        # blended with the two channels of the third element of the block.
        source_aware = source_desc_block[2]
        source_desc = self._blend(
            self._sample(source_desc_block[0], source_coord_norm),
            self._sample(source_desc_block[1], source_coord_norm),
            source_aware,
        )

        target_aware = target_desc_block[2]
        target_desc = self._blend(
            self._sample(target_desc_block[0], target_coord_norm),
            self._sample(target_desc_block[1], target_coord_norm),
            target_aware,
        )

        # Same blend for the target, but everything (including the weights)
        # is sampled at the warped source locations.
        target_aware_warped = self._sample(
            target_desc_block[2], target_coord_warped_norm
        )
        target_desc_warped = self._blend(
            self._sample(target_desc_block[0], target_coord_warped_norm),
            self._sample(target_desc_block[1], target_coord_warped_norm),
            target_aware_warped,
        )

        confidence_matrix = self.correspondence_module(source_desc, target_desc)
        # Clamp away from exactly 0 and 1.
        confidence_matrix = torch.clamp(confidence_matrix, 1e-12, 1 - 1e-12)

        output = {
            "source_score": source_score,
            "source_coord": source_coord,
            "source_desc": source_desc,
            "source_aware": source_aware,
            "target_score": target_score,
            "target_coord": target_coord,
            "target_score_warped": target_score_warped,
            "target_coord_warped": target_coord_warped,
            "target_desc_warped": target_desc_warped,
            "target_aware_warped": target_aware_warped,
            "border_mask": border_mask,
            "confidence_matrix": confidence_matrix,
        }
        return output
|