Vincentqyw
fix: roma
c74a070
raw
history blame
No virus
2.19 kB
import torch
import torch.nn.functional as F
def local_correlation(
feature0,
feature1,
local_radius,
padding_mode="zeros",
flow=None,
sample_mode="bilinear",
):
r = local_radius
K = (2 * r + 1) ** 2
B, c, h, w = feature0.size()
feature0 = feature0.half()
feature1 = feature1.half()
corr = torch.empty((B, K, h, w), device=feature0.device, dtype=feature0.dtype)
if flow is None:
# If flow is None, assume feature0 and feature1 are aligned
coords = torch.meshgrid(
(
torch.linspace(-1 + 1 / h, 1 - 1 / h, h, device="cuda"),
torch.linspace(-1 + 1 / w, 1 - 1 / w, w, device="cuda"),
)
)
coords = torch.stack((coords[1], coords[0]), dim=-1)[None].expand(B, h, w, 2)
else:
coords = flow.permute(0, 2, 3, 1) # If using flow, sample around flow target.
local_window = torch.meshgrid(
(
torch.linspace(
-2 * local_radius / h, 2 * local_radius / h, 2 * r + 1, device="cuda"
),
torch.linspace(
-2 * local_radius / w, 2 * local_radius / w, 2 * r + 1, device="cuda"
),
)
)
local_window = (
torch.stack((local_window[1], local_window[0]), dim=-1)[None]
.expand(1, 2 * r + 1, 2 * r + 1, 2)
.reshape(1, (2 * r + 1) ** 2, 2)
)
for _ in range(B):
with torch.no_grad():
local_window_coords = (
(coords[_, :, :, None] + local_window[:, None, None])
.reshape(1, h, w * (2 * r + 1) ** 2, 2)
.float()
)
window_feature = F.grid_sample(
feature1[_ : _ + 1].float(),
local_window_coords,
padding_mode=padding_mode,
align_corners=False,
mode=sample_mode, #
)
window_feature = window_feature.reshape(c, h, w, (2 * r + 1) ** 2)
corr[_] = (
(feature0[_, ..., None] / (c**0.5) * window_feature)
.sum(dim=0)
.permute(2, 0, 1)
)
torch.cuda.empty_cache()
return corr