Spaces:
Sleeping
Sleeping
import torch | |
import cuda_corr | |
class CorrLayer(torch.autograd.Function): | |
def forward(ctx, fmap1, fmap2, coords, ii, jj, radius, dropout): | |
""" forward correlation """ | |
ctx.save_for_backward(fmap1, fmap2, coords, ii, jj) | |
ctx.radius = radius | |
ctx.dropout = dropout | |
corr, = cuda_corr.forward(fmap1, fmap2, coords, ii, jj, radius) | |
return corr | |
def backward(ctx, grad): | |
""" backward correlation """ | |
fmap1, fmap2, coords, ii, jj = ctx.saved_tensors | |
if ctx.dropout < 1: | |
perm = torch.rand(len(ii), device="cuda") < ctx.dropout | |
coords = coords[:,perm] | |
grad = grad[:,perm] | |
ii = ii[perm] | |
jj = jj[perm] | |
fmap1_grad, fmap2_grad = \ | |
cuda_corr.backward(fmap1, fmap2, coords, ii, jj, grad, ctx.radius) | |
return fmap1_grad, fmap2_grad, None, None, None, None, None | |
class PatchLayer(torch.autograd.Function): | |
def forward(ctx, net, coords, radius): | |
""" forward patchify """ | |
ctx.radius = radius | |
ctx.save_for_backward(net, coords) | |
patches, = cuda_corr.patchify_forward(net, coords, radius) | |
return patches | |
def backward(ctx, grad): | |
""" backward patchify """ | |
net, coords = ctx.saved_tensors | |
grad, = cuda_corr.patchify_backward(net, coords, grad, ctx.radius) | |
return grad, None, None | |
def patchify(net, coords, radius, mode='bilinear'): | |
""" extract patches """ | |
patches = PatchLayer.apply(net, coords, radius) | |
if mode == 'bilinear': | |
offset = (coords - coords.floor()).to(net.device) | |
dx, dy = offset[:,:,None,None,None].unbind(dim=-1) | |
d = 2 * radius + 1 | |
x00 = (1-dy) * (1-dx) * patches[...,:d,:d] | |
x01 = (1-dy) * ( dx) * patches[...,:d,1:] | |
x10 = ( dy) * (1-dx) * patches[...,1:,:d] | |
x11 = ( dy) * ( dx) * patches[...,1:,1:] | |
return x00 + x01 + x10 + x11 | |
return patches | |
def corr(fmap1, fmap2, coords, ii, jj, radius=1, dropout=1): | |
return CorrLayer.apply(fmap1, fmap2, coords, ii, jj, radius, dropout) | |