Vincentqyw
add: roma
62c7319
raw
history blame
No virus
920 Bytes
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
class NormedCorrelationKernel(nn.Module):  # similar to softmax kernel
    """Pairwise cosine-similarity kernel between two sets of feature vectors.

    Given ``x`` of shape ``(b, n, d)`` and ``y`` of shape ``(b, m, d)``,
    produces a ``(b, n, m)`` tensor whose entry ``[i, j, k]`` is the dot
    product of ``x[i, j]`` and ``y[i, k]`` divided by the product of their
    L2 norms plus ``eps``.
    """

    def __init__(self):
        super().__init__()

    def forward(self, x, y, eps=1e-6):
        """Compute the normalized correlation of every ``x`` vector with every ``y`` vector.

        Args:
            x: Tensor of shape ``(b, n, d)``.
            y: Tensor of shape ``(b, m, d)``.
            eps: Added to the product of norms so that zero-norm vectors
                produce 0 instead of a division by zero.

        Returns:
            Tensor of shape ``(b, n, m)`` of cosine similarities (up to ``eps``).
        """
        # Implemented as ``forward`` rather than overriding ``__call__`` so that
        # nn.Module hooks run; calling ``self(x, y, eps=...)`` still works.
        dots = torch.einsum("bnd,bmd->bnm", x, y)
        # (b, n, 1) * (b, 1, m) -> (b, n, m) outer product of the vector norms.
        denom = x.norm(dim=-1).unsqueeze(-1) * y.norm(dim=-1).unsqueeze(-2) + eps
        return dots / denom
class NormedCorr(nn.Module):
    """Dense normalized (cosine) correlation volume between two feature maps.

    Every spatial position of ``x`` is compared with every spatial position
    of ``y`` via :class:`NormedCorrelationKernel`. Both inputs must share the
    shape ``(b, d, h, w)``; the result has shape ``(b, h * w, h, w)``, where
    the channel axis enumerates positions of ``y`` (row-major over ``h, w``)
    and the two spatial axes are positions of ``x``.
    """

    def __init__(
        self,
    ):
        super().__init__()
        self.corr = NormedCorrelationKernel()

    def reshape(self, x):
        """Flatten a ``(b, d, h, w)`` feature map into ``(b, h * w, d)`` tokens."""
        return rearrange(x, "b d h w -> b (h w) d")

    def forward(self, x, y, **kwargs):
        """Build the correlation volume of ``x`` against ``y``.

        Args:
            x: Feature map of shape ``(b, d, h, w)``.
            y: Feature map with exactly the same shape as ``x``.
            **kwargs: Accepted for interface compatibility and ignored.

        Returns:
            Tensor of shape ``(b, h * w, h, w)``.

        Raises:
            ValueError: If ``x`` and ``y`` differ in shape or are not 4-D.
        """
        # Explicit checks instead of ``assert`` so validation survives ``python -O``.
        # Without them, maps with equal h*w but different (h, w) would be silently
        # rearranged using y's spatial layout.
        if x.shape != y.shape:
            raise ValueError(
                f"x and y must have the same shape, got {tuple(x.shape)} and {tuple(y.shape)}"
            )
        if y.dim() != 4:
            raise ValueError(f"expected 4-D (b, d, h, w) inputs, got shape {tuple(y.shape)}")
        h, w = y.shape[-2:]
        # (b, h*w [x positions], h*w [y positions])
        corr_xy = self.corr(self.reshape(x), self.reshape(y))
        # Unfold x positions back into (h, w); y positions become channels.
        corr_volume = rearrange(corr_xy, "b (h w) c -> b c h w", h=h, w=w)
        return corr_volume