# Uploaded by cyun9286 - commit f53b39e ("Add application file"), 5.27 kB.
import torch
import torch.nn.functional as F
from utils.utils import bilinear_sampler, coords_grid
try:
    import alt_cuda_corr
except ImportError:
    # alt_cuda_corr is an optional compiled extension. When it has not been
    # built, bind a None placeholder instead of leaving the name undefined.
    # Only ImportError is caught, so unrelated failures are not silently hidden.
    alt_cuda_corr = None
class CorrBlock2:
    """All-pairs correlation pyramid built by downsampling the second feature map.

    Every pyramid level is a fresh correlation between the full-resolution
    ``fmap1`` and a copy of ``fmap2`` that is bilinearly halved once per level.
    Lookups sample a (2r+1) x (2r+1) window around the given coordinates; the
    window offsets can be scaled per pixel by an optional dilation map.
    """

    def __init__(self, fmap1, fmap2, args):
        """Build the correlation pyramid.

        Args:
            fmap1: first feature map, unpacked as (batch, dim, h1, w1).
            fmap2: second feature map, unpacked as (batch, dim, h2, w2); it is
                halved with bilinear interpolation between pyramid levels.
            args: config object providing ``corr_levels`` (number of pyramid
                levels) and ``corr_radius`` (lookup window radius).
        """
        self.num_levels = args.corr_levels
        self.radius = args.corr_radius
        self.args = args
        self.corr_pyramid = []
        # all pairs correlation, one level per downsampling step of fmap2
        for level in range(self.num_levels):
            corr = CorrBlock2.corr(fmap1, fmap2, 1)
            batch, h1, w1, heads, h2, w2 = corr.shape
            # One (heads, h2, w2) map per source pixel, so each source pixel
            # can be sampled independently during lookup.
            corr = corr.reshape(batch * h1 * w1, heads, h2, w2)
            self.corr_pyramid.append(corr)
            if level + 1 < self.num_levels:
                # Halve fmap2 for the next level. This is skipped after the last
                # level, whose downsampled map would never be used.
                fmap2 = F.interpolate(fmap2, scale_factor=0.5, mode='bilinear', align_corners=False)

    def __call__(self, coords, dilation=None):
        """Look up a (2r+1) x (2r+1) correlation window at every pyramid level.

        Args:
            coords: lookup coordinates, (batch, 2, h1, w1). Level ``i`` uses
                ``coords / 2**i``.
            dilation: optional per-pixel multiplier for the window offsets, with
                batch * h1 * w1 elements (e.g. (batch, 1, h1, w1)). Defaults to ones.

        Returns:
            float32 tensor (batch, C, h1, w1). C is the concatenation over levels
            of the sampled window values (presumably num_levels * (2r+1)**2;
            depends on bilinear_sampler's output, TODO confirm).
        """
        r = self.radius
        coords = coords.permute(0, 2, 3, 1)
        batch, h1, w1, _ = coords.shape
        device = coords.device
        if dilation is None:
            dilation = torch.ones(batch, 1, h1, w1, device=device)

        # The offset window (and its dilation scaling) is the same at every
        # level, so build it once instead of once per level.
        dx = torch.linspace(-r, r, 2 * r + 1, device=device)
        dy = torch.linspace(-r, r, 2 * r + 1, device=device)
        # NOTE(review): offsets are stacked as (dy, dx) and added to coords
        # unchanged. Because dx == dy, the set of sampled points is the same
        # either way; only the channel order of the window depends on this.
        # Do not change it without retraining.
        delta = torch.stack(torch.meshgrid(dy, dx), axis=-1)
        delta_lvl = delta.view(1, 2 * r + 1, 2 * r + 1, 2)
        # reshape (not view) also accepts a non-contiguous dilation tensor.
        delta_lvl = delta_lvl * dilation.reshape(batch * h1 * w1, 1, 1, 1)
        centroid = coords.reshape(batch * h1 * w1, 1, 1, 2)

        out_pyramid = []
        for i in range(self.num_levels):
            corr = self.corr_pyramid[i]
            coords_lvl = centroid / 2**i + delta_lvl
            corr = bilinear_sampler(corr, coords_lvl)
            out_pyramid.append(corr.view(batch, h1, w1, -1))
        out = torch.cat(out_pyramid, dim=-1)
        return out.permute(0, 3, 1, 2).contiguous().float()

    @staticmethod
    def corr(fmap1, fmap2, num_head):
        """Scaled dot-product correlation between every pair of pixels.

        Args:
            fmap1: (batch, dim, h1, w1) features.
            fmap2: (batch, dim, h2, w2) features; the spatial size may differ
                from fmap1.
            num_head: number of channel groups; ``dim`` must be divisible by it.

        Returns:
            (batch, h1, w1, num_head, h2, w2) correlations divided by sqrt(dim).
        """
        batch, dim, h1, w1 = fmap1.shape
        h2, w2 = fmap2.shape[2:]
        fmap1 = fmap1.reshape(batch, num_head, dim // num_head, h1 * w1)
        fmap2 = fmap2.reshape(batch, num_head, dim // num_head, h2 * w2)
        corr = fmap1.transpose(2, 3) @ fmap2
        corr = corr.reshape(batch, num_head, h1, w1, h2, w2).permute(0, 2, 3, 1, 4, 5)
        # NOTE(review): normalizes by the full channel count, not dim // num_head.
        return corr / torch.sqrt(torch.tensor(dim).float())
class CorrBlock:
    """All-pairs correlation pyramid built by average-pooling the cost volume.

    The full-resolution correlation between ``fmap1`` and ``fmap2`` is computed
    once. Each further level is a 2x2, stride-2 average pool of the previous
    level over its (h2, w2) axes.
    """

    def __init__(self, fmap1, fmap2, num_levels=4, radius=4):
        """Build the correlation pyramid.

        Args:
            fmap1: (batch, dim, h1, w1) features.
            fmap2: (batch, dim, h2, w2) features.
            num_levels: number of pyramid levels (including full resolution).
            radius: lookup window radius used by ``__call__``.
        """
        self.num_levels = num_levels
        self.radius = radius
        self.corr_pyramid = []
        # all pairs correlation
        corr = CorrBlock.corr(fmap1, fmap2)
        batch, h1, w1, dim, h2, w2 = corr.shape
        # One (1, h2, w2) map per source pixel, so pooling and sampling act on
        # the target axes only.
        corr = corr.reshape(batch * h1 * w1, dim, h2, w2)
        self.corr_pyramid.append(corr)
        for _ in range(self.num_levels - 1):
            corr = F.avg_pool2d(corr, 2, stride=2)
            self.corr_pyramid.append(corr)

    def __call__(self, coords):
        """Look up a (2r+1) x (2r+1) correlation window at every pyramid level.

        Args:
            coords: lookup coordinates, (batch, 2, h1, w1). Level ``i`` uses
                ``coords / 2**i``.

        Returns:
            float32 tensor (batch, C, h1, w1). C is the concatenation over levels
            of the sampled window values (presumably num_levels * (2r+1)**2;
            depends on bilinear_sampler's output, TODO confirm).
        """
        r = self.radius
        coords = coords.permute(0, 2, 3, 1)
        batch, h1, w1, _ = coords.shape

        # The offset window is the same at every level, so build it once,
        # directly on the target device.
        dx = torch.linspace(-r, r, 2 * r + 1, device=coords.device)
        dy = torch.linspace(-r, r, 2 * r + 1, device=coords.device)
        # NOTE(review): offsets are stacked as (dy, dx) and added to coords
        # unchanged. Because dx == dy, the set of sampled points is the same
        # either way; only the channel order of the window depends on this.
        # Do not change it without retraining.
        delta = torch.stack(torch.meshgrid(dy, dx), axis=-1)
        delta_lvl = delta.view(1, 2 * r + 1, 2 * r + 1, 2)
        centroid = coords.reshape(batch * h1 * w1, 1, 1, 2)

        out_pyramid = []
        for i in range(self.num_levels):
            corr = self.corr_pyramid[i]
            coords_lvl = centroid / 2**i + delta_lvl
            corr = bilinear_sampler(corr, coords_lvl)
            out_pyramid.append(corr.view(batch, h1, w1, -1))
        out = torch.cat(out_pyramid, dim=-1)
        return out.permute(0, 3, 1, 2).contiguous().float()

    @staticmethod
    def corr(fmap1, fmap2):
        """Scaled dot-product correlation between every pair of pixels.

        Args:
            fmap1: (batch, dim, ht, wd) features.
            fmap2: (batch, dim, h2, w2) features; the spatial size may differ
                from fmap1 (previously it had to match).

        Returns:
            (batch, ht, wd, 1, h2, w2) correlations divided by sqrt(dim).
        """
        batch, dim, ht, wd = fmap1.shape
        h2, w2 = fmap2.shape[2:]
        fmap1 = fmap1.reshape(batch, dim, ht * wd)
        fmap2 = fmap2.reshape(batch, dim, h2 * w2)
        corr = torch.matmul(fmap1.transpose(1, 2), fmap2)
        corr = corr.view(batch, ht, wd, 1, h2, w2)
        return corr / torch.sqrt(torch.tensor(dim).float())
class AlternateCorrBlock:
    """Correlation lookup backed by the compiled ``alt_cuda_corr`` extension.

    No all-pairs matrix is built here. The block keeps average-pooled feature
    maps and hands features plus coordinates to ``alt_cuda_corr.forward`` at
    lookup time. That kernel presumably evaluates only the radius-r window;
    confirm against the extension's source.
    """

    def __init__(self, fmap1, fmap2, num_levels=4, radius=4):
        """Store the feature pyramid.

        Args:
            fmap1: first feature map; ``shape[1]`` is used as the channel count.
            fmap2: second feature map, pooled 2x2 / stride 2 per level.
            num_levels: number of levels queried by ``__call__``.
            radius: lookup window radius passed to the extension.
        """
        self.num_levels = num_levels
        self.radius = radius
        # NOTE(review): this stores num_levels + 1 entries, including pooled
        # fmap1 copies. __call__ only reads pyramid[0][0] and pyramid[i][1]
        # for i < num_levels.
        self.pyramid = [(fmap1, fmap2)]
        for _ in range(self.num_levels):
            fmap1 = F.avg_pool2d(fmap1, 2, stride=2)
            fmap2 = F.avg_pool2d(fmap2, 2, stride=2)
            self.pyramid.append((fmap1, fmap2))

    def __call__(self, coords):
        """Query the extension at every level and stack the results.

        Args:
            coords: lookup coordinates, (B, 2, H, W). Level ``i`` uses
                ``coords / 2**i``.

        Returns:
            (B, C, H, W) correlations divided by sqrt(channel count), where C
            collects all levels' outputs.

        Raises:
            ImportError: if the ``alt_cuda_corr`` extension is not available.
        """
        # The module-level import is optional. The name may be missing or None
        # when the extension was not built, so fail with a clear message.
        cuda_corr = globals().get("alt_cuda_corr")
        if cuda_corr is None:
            raise ImportError(
                "AlternateCorrBlock requires the compiled 'alt_cuda_corr' extension"
            )

        coords = coords.permute(0, 2, 3, 1)
        B, H, W, _ = coords.shape
        dim = self.pyramid[0][0].shape[1]
        r = self.radius
        # fmap1 is always used at full resolution, so move its channel axis
        # last (and make it contiguous) once rather than once per level.
        fmap1_i = self.pyramid[0][0].permute(0, 2, 3, 1).contiguous()
        corr_list = []
        for i in range(self.num_levels):
            fmap2_i = self.pyramid[i][1].permute(0, 2, 3, 1).contiguous()
            coords_i = (coords / 2**i).reshape(B, 1, H, W, 2).contiguous()
            corr, = cuda_corr.forward(fmap1_i, fmap2_i, coords_i, r)
            corr_list.append(corr.squeeze(1))
        corr = torch.stack(corr_list, dim=1)
        corr = corr.reshape(B, -1, H, W)
        return corr / torch.sqrt(torch.tensor(dim).float())