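# Source page metadata (non-code residue, preserved as comments so the module parses):
#   Vincentqyw — "fix: roma" — c74a070
#   raw | history blame | No virus | 1.6 kB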
import torch
import torch.nn.functional as F
def local_correlation(
    feature0, feature1, local_radius, padding_mode="zeros", flow=None
):
    """Correlate each feature0 pixel with a (2r+1) x (2r+1) window of feature1.

    For every pixel of ``feature0``, ``feature1`` is bilinearly sampled
    (``F.grid_sample``, ``align_corners=False``) at a square grid of offsets
    around a centre point. The channel-wise dot products are returned,
    scaled by ``1 / sqrt(c)``.

    Args:
        feature0: Query features of shape (b, c, h, w).
        feature1: Features sampled with ``F.grid_sample``. Must have batch
            size ``b`` and ``c`` channels. Window offsets are measured in
            feature0 pixels (they are computed from feature0's h, w), so
            feature1 presumably has the same spatial size. TODO confirm
            against callers.
        local_radius: Window radius ``r``. The window has ``(2r+1)**2``
            offsets spaced one pixel apart.
        padding_mode: Forwarded to ``F.grid_sample`` for samples that fall
            outside ``feature1``.
        flow: Optional window centres of shape (b, 2, h, w), in the
            normalized [-1, 1] coordinates ``grid_sample`` expects
            (channel 0 = x, channel 1 = y). When ``None``, each window is
            centred on the same pixel location in ``feature1`` as in
            ``feature0``.

    Returns:
        Tensor of shape (b, (2r+1)**2, h, w). Channel ``k = iy * (2r+1) + ix``
        holds the correlation at pixel offset ``(ix - r, iy - r)`` in (x, y).

    NOTE(review): the offset grid is created in torch's default dtype, and
    ``grid_sample`` requires grid and input dtypes to match. Half-precision
    features therefore presumably rely on autocast. Confirm with callers.
    """
    device = feature0.device
    b, c, h, w = feature0.size()
    r = local_radius
    side = 2 * r + 1
    num_offsets = side**2

    if flow is None:
        # Pixel centres in normalized coordinates (align_corners=False), so
        # each window is centred on the same location in feature1.
        ys = torch.linspace(-1 + 1 / h, 1 - 1 / h, h, device=device)
        xs = torch.linspace(-1 + 1 / w, 1 - 1 / w, w, device=device)
        # Built by broadcasting rather than torch.meshgrid. meshgrid warns
        # when called without ``indexing=``, and that argument needs
        # torch>=1.10.
        coords = torch.stack(
            (xs[None, :].expand(h, w), ys[:, None].expand(h, w)), dim=-1
        )[None].expand(b, h, w, 2)
    else:
        # Sample around the flow target: (b, 2, h, w) -> (b, h, w, 2).
        coords = flow.permute(0, 2, 3, 1)

    # One pixel spans 2/h (resp. 2/w) normalized units, so these are integer
    # pixel offsets -r..r along y and x.
    dy = torch.linspace(-2 * r / h, 2 * r / h, side, device=device)
    dx = torch.linspace(-2 * r / w, 2 * r / w, side, device=device)
    # (side, side, 2) grid of (x, y) offsets, flattened row-major (y outer,
    # x inner). It is given a batch dim so it broadcasts against coords
    # exactly as before.
    offsets = (
        torch.stack(
            (dx[None, :].expand(side, side), dy[:, None].expand(side, side)),
            dim=-1,
        )
        .reshape(1, num_offsets, 2)
        .expand(b, num_offsets, 2)
    )
    # (b, h, w, K, 2) sample points. grid_sample needs a 4-D grid, so the K
    # window offsets are folded into the width axis.
    sample_coords = (coords[:, :, :, None] + offsets[:, None, None]).reshape(
        b, h, w * num_offsets, 2
    )
    # grid_sample returns (b, c, h, w*K); unfold K back out.
    window_feature = F.grid_sample(
        feature1, sample_coords, padding_mode=padding_mode, align_corners=False
    ).reshape(b, c, h, w, num_offsets)
    corr = torch.einsum("bchw, bchwk -> bkhw", feature0, window_feature) / (c**0.5)
    return corr