Spaces:
Running
Running
File size: 4,184 Bytes
2c8b554 1ccca05 2c8b554 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 |
import torch
import torch.nn as nn
import torch.nn.functional as F
from .exceptions import EmptyTensorError
from .utils import interpolate_dense_features, upscale_positions
def process_multiscale(image, model, scales=(.5, 1, 2)):
    """Detect and describe keypoints on an image pyramid.

    The image is resized to every factor in ``scales`` (in the given order).
    At each level the dense features of the previous processed level are
    resized and added to the current ones, detections that fall on positions
    already detected at an earlier level are suppressed, and the surviving
    detections are refined, described and mapped back to the coordinates of
    the input image.

    Args:
        image: 4-D tensor ``(batch, channels, height, width)``; the batch
            size must be 1.
        model: object exposing ``dense_feature_extraction`` (callable, with a
            ``num_channels`` attribute), ``detection`` and ``localization``.
        scales: iterable of scale factors handed to ``F.interpolate``.

    Returns:
        A tuple ``(keypoints, scores, descriptors)`` of NumPy arrays:
        ``keypoints`` has shape ``(N, 3)`` -- the position along the height
        axis, the position along the width axis, and ``1 / scale`` of the
        level the keypoint was found at; ``scores`` has shape ``(N,)``;
        ``descriptors`` has shape ``(N, num_channels)``.

    Raises:
        AssertionError: if the batch size of ``image`` is not 1.
    """
    b, _, h_init, w_init = image.size()
    device = image.device
    assert b == 1, 'process_multiscale expects a batch of exactly one image'

    # Accumulators (kept on the CPU) for the results of all pyramid levels.
    all_keypoints = torch.zeros([3, 0])
    all_descriptors = torch.zeros([
        model.dense_feature_extraction.num_channels, 0
    ])
    all_scores = torch.zeros(0)

    previous_dense_features = None
    banned = None
    for idx, scale in enumerate(scales):
        current_image = F.interpolate(
            image, scale_factor=scale,
            mode='bilinear', align_corners=True
        )
        _, _, h_level, w_level = current_image.size()

        dense_features = model.dense_feature_extraction(current_image)
        del current_image

        _, _, h, w = dense_features.size()

        # Sum the feature maps.
        if previous_dense_features is not None:
            dense_features += F.interpolate(
                previous_dense_features, size=[h, w],
                mode='bilinear', align_corners=True
            )
            # Release the tensor by rebinding instead of `del`: the name must
            # stay bound because the `continue` below skips the reassignment
            # at the end of the loop body.  A level that yields no keypoints
            # therefore passes no features on to the next level, exactly as
            # when the very first level is skipped.
            previous_dense_features = None

        # Recover detections.
        detections = model.detection(dense_features)
        if banned is not None:
            # Resize the mask of already-detected positions to this level and
            # drop detections that coincide with it.
            banned = F.interpolate(banned.float(), size=[h, w]).bool()
            detections = torch.min(detections, ~banned)
            banned = torch.max(
                torch.max(detections, dim=1)[0].unsqueeze(1), banned
            )
        else:
            banned = torch.max(detections, dim=1)[0].unsqueeze(1)
        # Rows of fmap_pos: (channel, i, j) of every detection.
        fmap_pos = torch.nonzero(detections[0].cpu()).t()
        del detections

        # Recover displacements.
        displacements = model.localization(dense_features)[0].cpu()
        displacements_i = displacements[
            0, fmap_pos[0, :], fmap_pos[1, :], fmap_pos[2, :]
        ]
        displacements_j = displacements[
            1, fmap_pos[0, :], fmap_pos[1, :], fmap_pos[2, :]
        ]
        del displacements

        # Keep only detections whose refinement is below 0.5 on both axes.
        mask = (
            (torch.abs(displacements_i) < 0.5) &
            (torch.abs(displacements_j) < 0.5)
        )
        fmap_pos = fmap_pos[:, mask]
        valid_displacements = torch.stack([
            displacements_i[mask],
            displacements_j[mask]
        ], dim=0)
        del mask, displacements_i, displacements_j

        fmap_keypoints = fmap_pos[1:, :].float() + valid_displacements
        del valid_displacements

        try:
            raw_descriptors, _, ids = interpolate_dense_features(
                fmap_keypoints.to(device),
                dense_features[0]
            )
        except EmptyTensorError:
            # Nothing usable at this level; move on to the next scale.
            continue
        fmap_pos = fmap_pos.to(device)
        fmap_keypoints = fmap_keypoints.to(device)
        fmap_pos = fmap_pos[:, ids]
        fmap_keypoints = fmap_keypoints[:, ids]
        del ids

        keypoints = upscale_positions(fmap_keypoints, scaling_steps=2)
        del fmap_keypoints

        descriptors = F.normalize(raw_descriptors, dim=0).cpu()
        del raw_descriptors

        # Map positions from the resized image back to the input image.
        keypoints[0, :] *= h_init / h_level
        keypoints[1, :] *= w_init / w_level

        fmap_pos = fmap_pos.cpu()
        keypoints = keypoints.cpu()

        # Third row: inverse of the scale this level was processed at.
        keypoints = torch.cat([
            keypoints,
            torch.ones([1, keypoints.size(1)]) * 1 / scale,
        ], dim=0)

        # Score = dense feature value at the detection, divided by the
        # 1-based index of the pyramid level.
        scores = dense_features[
            0, fmap_pos[0, :], fmap_pos[1, :], fmap_pos[2, :]
        ].cpu() / (idx + 1)
        del fmap_pos

        all_keypoints = torch.cat([all_keypoints, keypoints], dim=1)
        all_descriptors = torch.cat([all_descriptors, descriptors], dim=1)
        all_scores = torch.cat([all_scores, scores], dim=0)
        del keypoints, descriptors

        previous_dense_features = dense_features
        del dense_features
    del previous_dense_features, banned

    keypoints = all_keypoints.t().detach().numpy()
    del all_keypoints
    scores = all_scores.detach().numpy()
    del all_scores
    descriptors = all_descriptors.t().detach().numpy()
    del all_descriptors
    return keypoints, scores, descriptors
|