File size: 2,186 Bytes
c608946
 
 
c74a070
c608946
 
 
 
 
c74a070
 
c608946
 
c74a070
c608946
 
 
c74a070
c608946
 
 
c74a070
 
 
 
 
 
c608946
c74a070
c608946
c74a070
 
 
 
 
 
 
 
 
 
 
 
 
 
c608946
 
c74a070
 
 
 
 
c608946
c74a070
 
 
 
 
c608946
c74a070
 
 
 
 
 
c608946
c74a070
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
import torch
import torch.nn.functional as F


def local_correlation(
    feature0,
    feature1,
    local_radius,
    padding_mode="zeros",
    flow=None,
    sample_mode="bilinear",
):
    """Correlate each pixel of ``feature0`` with a local window of ``feature1``.

    For every pixel of ``feature0``, a ``(2r+1) x (2r+1)`` window of taps is
    sampled from ``feature1`` with ``F.grid_sample`` around a target location.
    The taps are spaced one pixel apart, measured at ``feature0``'s
    resolution. The function returns the dot product between the
    ``feature0`` vector and every sampled ``feature1`` vector, divided by
    ``sqrt(c)``.

    Args:
        feature0: Tensor of shape ``(B, c, h, w)``.
        feature1: Tensor with ``B`` batch entries and ``c`` channels. It is
            sampled in normalized coordinates, so its spatial size need not
            equal ``(h, w)``.
        local_radius: Window radius ``r``; the window has ``(2r+1)**2`` taps.
        padding_mode: Forwarded to ``F.grid_sample`` (value used for taps
            falling outside ``feature1``).
        flow: Optional ``(B, 2, h, w)`` tensor of window centres in
            ``grid_sample`` normalized ``(x, y)`` coordinates. When ``None``,
            the centre of each ``feature0`` pixel is used, i.e. the two
            feature maps are assumed to be aligned.
        sample_mode: Forwarded to ``F.grid_sample`` as ``mode``.

    Returns:
        float16 tensor of shape ``(B, (2r+1)**2, h, w)`` on ``feature0``'s
        device. Channel ``k = i * (2r+1) + j`` holds the correlation with the
        tap offset by ``i - r`` rows and ``j - r`` columns.

    Note:
        Sampling of ``feature1`` runs under ``torch.no_grad()``, so gradients
        only flow into ``feature0``.
    """
    r = local_radius
    side = 2 * r + 1
    K = side**2
    B, c, h, w = feature0.size()
    # Build every grid on the inputs' device so CPU tensors and tensors on any
    # GPU work (a hard-coded "cuda" mismatches grid_sample's input device).
    device = feature0.device
    feature0 = feature0.half()
    feature1 = feature1.half()
    corr = torch.empty((B, K, h, w), device=device, dtype=feature0.dtype)
    if flow is None:
        # Normalized centres of feature0's pixels (align_corners=False).
        ys, xs = torch.meshgrid(
            torch.linspace(-1 + 1 / h, 1 - 1 / h, h, device=device),
            torch.linspace(-1 + 1 / w, 1 - 1 / w, w, device=device),
            indexing="ij",
        )
        coords = torch.stack((xs, ys), dim=-1)[None].expand(B, h, w, 2)
    else:
        # Sample around the flow target instead of the identity position.
        coords = flow.permute(0, 2, 3, 1)
    # Window offsets: one feature0 pixel spans 2/h (resp. 2/w) normalized units.
    dys, dxs = torch.meshgrid(
        torch.linspace(-2 * r / h, 2 * r / h, side, device=device),
        torch.linspace(-2 * r / w, 2 * r / w, side, device=device),
        indexing="ij",
    )
    local_window = torch.stack((dxs, dys), dim=-1).reshape(1, K, 2)
    for b in range(B):
        with torch.no_grad():
            # (h, w, 1, 2) + (1, 1, K, 2) gives every tap for every pixel;
            # the taps are flattened into the width so grid_sample sees a
            # 2-D grid of shape (1, h, w*K, 2).
            local_window_coords = (
                (coords[b, :, :, None] + local_window[:, None, None])
                .reshape(1, h, w * K, 2)
                .float()
            )
            window_feature = F.grid_sample(
                feature1[b : b + 1].float(),
                local_window_coords,
                padding_mode=padding_mode,
                align_corners=False,
                mode=sample_mode,
            )
            window_feature = window_feature.reshape(c, h, w, K)
        corr[b] = (
            (feature0[b, ..., None] / (c**0.5) * window_feature)
            .sum(dim=0)
            .permute(2, 0, 1)
        )
    torch.cuda.empty_cache()
    return corr