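# NOTE(review): the original first three lines read "Spaces:" / "Runtime error" /
# "Runtime error". They look like a status banner captured from a web page together
# with the file (presumably a hosted-demo page -- confirm); they are not Python code.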
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
from torch.autograd import Variable | |
import os | |
os.environ['PYTHON_EGG_CACHE'] = 'tmp/' # a writable directory | |
import numpy as np | |
import math | |
import pdb | |
import time | |
from .submodule import pspnet, bfmodule, conv | |
from .conv4d import sepConv4d, sepConv4dBlock, butterfly4D | |
class flow_reg(nn.Module):
    """
    Soft winner-take-all that selects the most likely displacement.

    Set ent=True to enable entropy output.
    Set maxdisp to adjust maximum allowed displacement towards one side.
        maxdisp=4 searches for a 9x9 region.
    Set fac to squeeze search window.
        maxdisp=4 and fac=2 gives search window of 9x5

    `size` is unpacked as (B, W, H); the displacement buffers are tiled to
    shape (B, 2*maxdisp+1, 2*(maxdisp//fac)+1, H, W).
    """
    def __init__(self, size, ent=False, maxdisp=int(4), fac=1):
        B, W, H = size
        super(flow_reg, self).__init__()
        self.ent = ent
        self.md = maxdisp
        self.fac = fac
        self.truncated = True
        self.wsize = 3  # by default using truncation 7x7

        md_squeezed = int(maxdisp // self.fac)
        flowrangey = range(-maxdisp, maxdisp + 1)
        flowrangex = range(-md_squeezed, md_squeezed + 1)
        # meshgrid[0] varies along the second (squeezed) window axis,
        # meshgrid[1] along the first (full-range) window axis.
        meshgrid = np.meshgrid(flowrangex, flowrangey)
        window = [1, 2 * maxdisp + 1, 2 * md_squeezed + 1, 1, 1]
        flowy = np.tile(np.reshape(meshgrid[0], window), (B, 1, 1, H, W))
        flowx = np.tile(np.reshape(meshgrid[1], window), (B, 1, 1, H, W))
        self.register_buffer('flowx', torch.Tensor(flowx))
        self.register_buffer('flowy', torch.Tensor(flowy))

        # dilates the one-hot argmax mask into a (2*wsize+1)^2 window
        self.pool3d = nn.MaxPool3d((self.wsize * 2 + 1, self.wsize * 2 + 1, 1),
                                   stride=1,
                                   padding=(self.wsize, self.wsize, 0))

    @staticmethod
    def _entropy(prob):
        """Entropy of `prob` over its two window axes (dims 1 and 2); keeps a channel dim."""
        return (-prob * torch.clamp(prob, 1e-9, 1 - 1e-9).log()).sum(1).sum(1)[:, np.newaxis]

    def forward(self, x):
        """
        x: cost volume of shape (b, u, v, h, w).

        Returns (flow, entropy): flow is (b, 2, h, w); entropy is
        (b, 2, h, w) = [local, global] when self.ent is set, else None.
        """
        b, u, v, h, w = x.shape
        oldx = x

        if self.truncated:
            # truncated softmax: keep only a window around the per-pixel argmax
            wsize = self.wsize
            flat = x.view(b, u * v, h, w)
            idx = flat.argmax(1, keepdim=True)
            # same mask dtypes as before (half on GPU, float on CPU), but
            # allocated on the input's own device instead of the current one
            mask_dtype = torch.half if x.is_cuda else torch.float32
            mask = torch.zeros((b, u * v, h, w), dtype=mask_dtype, device=x.device)
            mask.scatter_(1, idx, 1)
            mask = mask.view(b, 1, u, v, -1)
            mask = self.pool3d(mask)[:, 0].view(b, u, v, h, w)
            ninf = torch.full_like(oldx, -np.inf)
            # torch.where needs a boolean condition; uint8 masks are rejected
            # by recent PyTorch releases
            x = torch.where(mask.bool(), oldx, ninf)
        else:
            # kept local so that forward() does not rewrite the module's
            # configured truncation size
            wsize = (np.sqrt(u * v) - 1) / 2

        x = F.softmax(x.view(b, -1, h, w), 1).view(b, u, v, h, w)
        outx = torch.sum(torch.sum(x * self.flowx, 1), 1, keepdim=True)
        outy = torch.sum(torch.sum(x * self.flowy, 1), 1, keepdim=True)
        flow = torch.cat([outx, outy], 1)

        if not self.ent:
            return flow, None

        # local: entropy of the (truncated) distribution, normalised by window size
        local_entropy = self._entropy(x)
        if wsize == 0:
            local_entropy = torch.ones_like(local_entropy)
        else:
            local_entropy = local_entropy / math.log((wsize * 2 + 1) ** 2)

        # global: entropy of the untruncated distribution
        full = F.softmax(oldx.view(b, -1, h, w), 1).view(b, u, v, h, w)
        global_entropy = self._entropy(full) / math.log(u * v)

        return flow, torch.cat([local_entropy, global_entropy], 1)
class WarpModule(nn.Module):
    """
    taken from https://github.com/NVlabs/PWC-Net/blob/master/PyTorch/models/PWCNet.py

    Holds a pixel-coordinate grid buffer of shape (B, 2, H, W), where
    `size` is unpacked as (B, W, H); channel 0 is x, channel 1 is y.
    """
    def __init__(self, size):
        super(WarpModule, self).__init__()
        B, W, H = size
        # mesh grid
        xx = torch.arange(0, W).view(1, 1, 1, W).expand(B, 1, H, W)
        yy = torch.arange(0, H).view(1, 1, H, 1).expand(B, 1, H, W)
        # buffer shape is kept at (B, 2, H, W) so existing checkpoints still load
        self.register_buffer('grid', torch.cat((xx, yy), 1).float())

    def forward(self, x, flo):
        """
        warp an image/tensor (im2) back to im1, according to the optical flow

        x: [B, C, H, W] (im2)
        flo: [B, 2, H, W] flow

        Returns (warped, mask): `warped` is zeroed wherever the sampling
        location is not strictly inside the image, `mask` is the matching
        boolean map of shape [B, H, W].
        """
        B, C, H, W = x.size()
        # Every batch row of the buffer is identical, so broadcast the first
        # one; this lets the runtime batch differ from the construction batch.
        vgrid = self.grid[:1] + flo

        # scale grid to [-1,1]
        gx = 2.0 * vgrid[:, 0, :, :] / max(W - 1, 1) - 1.0
        gy = 2.0 * vgrid[:, 1, :, :] / max(H - 1, 1) - 1.0
        vgrid = torch.stack((gx, gy), dim=-1)  # B, H, W, 2

        output = F.grid_sample(x, vgrid, align_corners=True)
        mask = (vgrid[:, :, :, 0].abs() < 1) & (vgrid[:, :, :, 1].abs() < 1)
        return output * mask.unsqueeze(1).float(), mask
def get_grid(B, H, W, device=None):
    """
    Build a float32 pixel-coordinate grid of shape (1, 1, H, W, 2).

    grid[..., 0] is the x (column) index and grid[..., 1] the y (row) index.
    B is accepted for interface compatibility but is not used.
    device: target device; when None the grid goes to CUDA if it is
        available (the previous behaviour) and to the CPU otherwise, instead
        of crashing on machines without a GPU.
    """
    if device is None:
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    xs = torch.arange(W, dtype=torch.float32, device=device).view(1, W).expand(H, W)
    ys = torch.arange(H, dtype=torch.float32, device=device).view(H, 1).expand(H, W)
    return torch.stack((xs, ys), -1).view(1, 1, H, W, 2)
class VCN(nn.Module):
    """
    VCN.
    md defines maximum displacement for each level, following a coarse-to-fine-warping scheme
    fac defines squeeze parameter for the coarsest level
    """
    @staticmethod
    def _context_layers(in_ch, out_ch):
        """Return the 7 layers (6 dilated `conv` blocks + a plain Conv2d head) of one context network."""
        # (in channels, out channels, dilation); padding always equals dilation
        plan = ((in_ch, 128, 1), (128, 128, 2), (128, 128, 4),
                (128, 96, 8), (96, 64, 16), (64, 32, 1))
        layers = [conv(cin, cout, kernel_size=3, stride=1, padding=d, dilation=d)
                  for cin, cout, d in plan]
        layers.append(nn.Conv2d(32, out_ch, kernel_size=3, stride=1, padding=1, bias=True))
        return layers

    def __init__(self, size, md=[4,4,4,4,4], fac=1.,exp_unc=False):  # exp_uncertainty
        """
        size: indexable triple forwarded to flow_reg / WarpModule, which
            unpack it as (B, W, H); size[1] and size[2] are divided by
            64/32/16/8/4 for the five pyramid levels.
        md: per-level maximum displacement, coarsest level first.
        fac: squeeze factor of the coarsest search window.
        exp_unc: when True the depth-change head has 2 output channels
            instead of 1.
        """
        super(VCN, self).__init__()
        # copy so the instance never aliases the shared default list
        self.md = list(md)
        self.fac = fac
        use_entropy = True
        withbn = True

        ## pspnet
        self.pspnet = pspnet(is_proj=False)

        ### Volumetric-UNet
        fdima1 = 128  # 6/5/4
        fdima2 = 64   # 3/2
        fdimb1 = 16   # 6/5/4/3
        fdimb2 = 12   # 2

        full = False
        self.f6 = butterfly4D(fdima1, fdimb1, withbn=withbn, full=full)
        self.p6 = sepConv4d(fdimb1, fdimb1, with_bn=False, full=full)
        self.f5 = butterfly4D(fdima1, fdimb1, withbn=withbn, full=full)
        self.p5 = sepConv4d(fdimb1, fdimb1, with_bn=False, full=full)
        self.f4 = butterfly4D(fdima1, fdimb1, withbn=withbn, full=full)
        self.p4 = sepConv4d(fdimb1, fdimb1, with_bn=False, full=full)
        self.f3 = butterfly4D(fdima2, fdimb1, withbn=withbn, full=full)
        self.p3 = sepConv4d(fdimb1, fdimb1, with_bn=False, full=full)
        full = True
        self.f2 = butterfly4D(fdima2, fdimb2, withbn=withbn, full=full)
        self.p2 = sepConv4d(fdimb2, fdimb2, with_bn=False, full=full)

        self.flow_reg64 = flow_reg([fdimb1*size[0], size[1]//64, size[2]//64], ent=use_entropy, maxdisp=self.md[0], fac=self.fac)
        self.flow_reg32 = flow_reg([fdimb1*size[0], size[1]//32, size[2]//32], ent=use_entropy, maxdisp=self.md[1])
        self.flow_reg16 = flow_reg([fdimb1*size[0], size[1]//16, size[2]//16], ent=use_entropy, maxdisp=self.md[2])
        self.flow_reg8 = flow_reg([fdimb1*size[0], size[1]//8, size[2]//8], ent=use_entropy, maxdisp=self.md[3])
        self.flow_reg4 = flow_reg([fdimb2*size[0], size[1]//4, size[2]//4], ent=use_entropy, maxdisp=self.md[4])

        self.warp5 = WarpModule([size[0], size[1]//32, size[2]//32])
        self.warp4 = WarpModule([size[0], size[1]//16, size[2]//16])
        self.warp3 = WarpModule([size[0], size[1]//8, size[2]//8])
        self.warp2 = WarpModule([size[0], size[1]//4, size[2]//4])

        ## hypotheses fusion modules, adopted from the refinement module of PWCNet
        # https://github.com/NVlabs/PWC-Net/blob/master/PyTorch/models/PWCNet.py
        # (level, input channels, fusion output channels), coarsest level first
        context_io = (
            (6, 128 + 4*fdimb1,             2*fdimb1),
            (5, 128 + 4*fdimb1*2,           2*fdimb1*2),
            (4, 128 + 4*fdimb1*3,           2*fdimb1*3),
            (3, 64 + 16*fdimb1,             8*fdimb1),
            (2, 64 + 16*fdimb1 + 4*fdimb2,  4*2*fdimb1 + 2*fdimb2),
        )
        # Attribute names (dcN_conv1..7, dcN_conv, dcN_convo) and creation
        # order are unchanged, so state_dict keys and the sequence of random
        # initialisations stay the same.
        for lvl, in_ch, out_ch in context_io:
            for k, layer in enumerate(self._context_layers(in_ch, out_ch), 1):
                setattr(self, 'dc%d_conv%d' % (lvl, k), layer)
        for lvl, _, _ in context_io:
            stages = [getattr(self, 'dc%d_conv%d' % (lvl, k)) for k in range(1, 8)]
            setattr(self, 'dc%d_conv' % lvl, nn.Sequential(*stages))

        ## Out-of-range detection
        for lvl, in_ch, _ in context_io:
            setattr(self, 'dc%d_convo' % lvl, nn.Sequential(*self._context_layers(in_ch, 1)))

        # affine-exp
        self.f3d2v1 = conv(64, 32, kernel_size=3, stride=1, padding=1, dilation=1)
        self.f3d2v2 = conv(1, 32, kernel_size=3, stride=1, padding=1, dilation=1)
        self.f3d2v3 = conv(1, 32, kernel_size=3, stride=1, padding=1, dilation=1)
        self.f3d2v4 = conv(1, 32, kernel_size=3, stride=1, padding=1, dilation=1)
        self.f3d2v5 = conv(64, 32, kernel_size=3, stride=1, padding=1, dilation=1)
        self.f3d2v6 = conv(12*81, 32, kernel_size=3, stride=1, padding=1, dilation=1)
        self.f3d2 = bfmodule(128-64, 1)

        # depth change net
        self.dcnetv1 = conv(64, 32, kernel_size=3, stride=1, padding=1, dilation=1)
        self.dcnetv2 = conv(1, 32, kernel_size=3, stride=1, padding=1, dilation=1)
        self.dcnetv3 = conv(1, 32, kernel_size=3, stride=1, padding=1, dilation=1)
        self.dcnetv4 = conv(1, 32, kernel_size=3, stride=1, padding=1, dilation=1)
        self.dcnetv5 = conv(12*81, 32, kernel_size=3, stride=1, padding=1, dilation=1)
        self.dcnetv6 = conv(4, 32, kernel_size=3, stride=1, padding=1, dilation=1)
        if exp_unc:
            self.dcnet = bfmodule(128, 2)
        else:
            self.dcnet = bfmodule(128, 1)

        # re-initialise every 3D convolution with a fan-out scaled normal
        for m in self.modules():
            if isinstance(m, nn.Conv3d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.kernel_size[2] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
                if m.bias is not None:
                    m.bias.data.zero_()

        # per-level lookup tables, index 0 = coarsest level (1/64)
        self.facs = [self.fac, 1, 1, 1, 1]
        self.warp_modules = nn.ModuleList([None, self.warp5, self.warp4, self.warp3, self.warp2])
        self.f_modules = nn.ModuleList([self.f6, self.f5, self.f4, self.f3, self.f2])
        self.p_modules = nn.ModuleList([self.p6, self.p5, self.p4, self.p3, self.p2])
        self.reg_modules = nn.ModuleList([self.flow_reg64, self.flow_reg32, self.flow_reg16, self.flow_reg8, self.flow_reg4])
        self.oor_modules = nn.ModuleList([self.dc6_convo, self.dc5_convo, self.dc4_convo, self.dc3_convo, self.dc2_convo])
        self.fuse_modules = nn.ModuleList([self.dc6_conv, self.dc5_conv, self.dc4_conv, self.dc3_conv, self.dc2_conv])
def corrf(self, refimg_fea, targetimg_fea,maxdisp, fac=1):
    """
    slow correlation function

    refimg_fea, targetimg_fea: feature maps of shape (b, c, h, w).
    maxdisp: maximum displacement along the first window axis (width shift).
    fac: squeeze factor; the second window axis (height shift) spans
        +-int(maxdisp//fac).

    Returns the cost volume (b, c, 2*maxdisp+1, 2*int(maxdisp//fac)+1, h, w)
    after a leaky ReLU: entry [.., i, j, y, x] holds
    ref[y, x] * target[y + (j - maxdisp//fac), x + (i - maxdisp)], and stays
    zero where the shifted location falls outside the image.
    """
    b, c, height, width = refimg_fea.shape
    maxdisp_v = int(maxdisp // fac)
    # allocate on the input's device with the input's dtype
    cost = refimg_fea.new_zeros((b, c, 2 * maxdisp + 1, 2 * maxdisp_v + 1, height, width))  # b,c,u,v,h,w

    for i in range(2 * maxdisp + 1):
        ind = i - maxdisp
        x0, x1 = max(0, -ind), min(width, width - ind)
        if x1 <= x0:
            # shift is at least as wide as the map: no overlap (a negative
            # slice stop would otherwise wrap around)
            continue
        for j in range(2 * maxdisp_v + 1):
            indd = j - maxdisp_v
            y0, y1 = max(0, -indd), min(height, height - indd)
            if y1 <= y0:
                continue
            feata = refimg_fea[:, :, y0:y1, x0:x1]
            featb = targetimg_fea[:, :, y0 + indd:y1 + indd, x0 + ind:x1 + ind]
            cost[:, :, i, j, y0:y1, x0:x1] = feata * featb  # standard

    cost = F.leaky_relu(cost, 0.1, inplace=True)
    return cost
def cost_matching(self,up_flow, c1, c2, flowh, enth, level): | |
""" | |
up_flow: upsample coarse flow | |
c1: normalized feature of image 1 | |
c2: normalized feature of image 2 | |
flowh: flow hypotheses | |
enth: entropy | |
""" | |
# normalize | |
c1n = c1 / (c1.norm(dim=1, keepdim=True)+1e-9) | |
c2n = c2 / (c2.norm(dim=1, keepdim=True)+1e-9) | |
# cost volume | |
if level == 0: | |
warp = c2n | |
else: | |
warp,_ = self.warp_modules[level](c2n, up_flow) | |
feat = self.corrf(c1n,warp,self.md[level],fac=self.facs[level]) | |
feat = self.f_modules[level](feat) | |
cost = self.p_modules[level](feat) # b, 16, u,v,h,w | |
# soft WTA | |
b,c,u,v,h,w = cost.shape | |
cost = cost.view(-1,u,v,h,w) # bx16, 9,9,h,w, also predict uncertainty from here | |
flowhh,enthh = self.reg_modules[level](cost) # bx16, 2, h, w | |
flowhh = flowhh.view(b,c,2,h,w) | |
if level > 0: | |
flowhh = flowhh + up_flow[:,np.newaxis] | |
flowhh = flowhh.view(b,-1,h,w) # b, 16*2, h, w | |
enthh = enthh.view(b,-1,h,w) # b, 16*1, h, w | |
# append coarse hypotheses | |
if level == 0: | |
flowh = flowhh | |
enth = enthh | |
else: | |
flowh = torch.cat((flowhh, F.upsample(flowh.detach()*2, [flowhh.shape[2],flowhh.shape[3]], mode='bilinear')),1) # b, k2--k2, h, w | |
enth = torch.cat((enthh, F.upsample(enth, [flowhh.shape[2],flowhh.shape[3]], mode='bilinear')),1) | |
if self.training or level==4: | |
x = torch.cat((enth.detach(), flowh.detach(), c1),1) | |
oor = self.oor_modules[level](x)[:,0] | |
else: oor = None | |
# hypotheses fusion | |
x = torch.cat((enth.detach(), flowh.detach(), c1),1) | |
va = self.fuse_modules[level](x) | |
va = va.view(b,-1,2,h,w) | |
flow = ( flowh.view(b,-1,2,h,w) * F.softmax(va,1) ).sum(1) # b, 2k, 2, h, w | |
return flow, flowh, enth, oor | |
def affine(self,pref,flow, pw=1):
    """
    Fit a per-pixel 2x2 affine map from reference to target patch offsets.

    pref: reference coordinates, (b, 2, h, w)
    flow: displacement added to pref to obtain target coordinates, (b, 2, h, w)
    pw: patch half-width; patches are (2*pw+1) x (2*pw+1). It used to be
        overwritten with 1 inside the body; the argument is now honoured
        (the default is still 1).

    Returns (exp, Error, mask):
        exp   (b, 1, h, w): sqrt(|det A|) clamped to [0.5, 2], set to 1 where Error > 0.1
        Error (b, 1, h, w): mean residual norm of the fit over the patch
        mask  (b, h, w):    True where 0.5 < sqrt(|det A|) < 2 and Error < 0.1
    """
    b, _, lh, lw = flow.shape
    k = pw * 2 + 1
    ptar = pref + flow

    # patch coordinates relative to the patch centre
    pref = F.unfold(pref, (k, k), padding=(pw)).view(b, 2, k**2, lh, lw) - pref[:, :, np.newaxis]
    ptar = F.unfold(ptar, (k, k), padding=(pw)).view(b, 2, k**2, lh, lw) - ptar[:, :, np.newaxis]  # b, 2,k*k,h,w

    pref = pref.permute(0, 3, 4, 1, 2).reshape(b * lh * lw, 2, k**2)
    ptar = ptar.permute(0, 3, 4, 1, 2).reshape(b * lh * lw, 2, k**2)

    # least squares: A = ptar pref^T (pref pref^T)^-1, with a closed-form 2x2 inverse
    prefprefT = pref.matmul(pref.permute(0, 2, 1))
    ppdet = prefprefT[:, 0, 0] * prefprefT[:, 1, 1] - prefprefT[:, 1, 0] * prefprefT[:, 0, 1]
    adjugate = torch.cat((prefprefT[:, 1, 1:], -prefprefT[:, 0, 1:],
                          -prefprefT[:, 1:, 0], prefprefT[:, 0:1, 0]), 1).view(-1, 2, 2)
    ppinv = adjugate / ppdet.clamp(1e-10, np.inf)[:, np.newaxis, np.newaxis]

    Affine = ptar.matmul(pref.permute(0, 2, 1)).matmul(ppinv)
    Error = (Affine.matmul(pref) - ptar).norm(2, 1).mean(1).view(b, 1, lh, lw)

    Avol = (Affine[:, 0, 0] * Affine[:, 1, 1] - Affine[:, 1, 0] * Affine[:, 0, 1]).view(b, 1, lh, lw).abs().clamp(1e-10, np.inf)
    exp = Avol.sqrt()
    # the mask is computed on the unclamped expansion
    mask = (exp > 0.5) & (exp < 2) & (Error < 0.1)
    mask = mask[:, 0]

    exp = exp.clamp(0.5, 2)
    exp[Error > 0.1] = 1
    return exp, Error, mask
def affine_mask(self,pref,flow, pw=3):
    """
    Confidence-weighted variant of the per-pixel affine fit.

    pref: reference coordinates, (b, 2, h, w)
    flow: (b, 3, h, w); channels 0-1 are the displacement, channel 2 is a
        validity mask used to weight patch entries
    pw: patch width

    Returns (exp, Error, mask):
        exp   (b, 1, h, w): sqrt(|det A|) clamped to [0.5, 2], set to 1 where Error > 0.2
        Error (b, 1, h, w): mean residual norm of the fit over the patch
        mask  (b, h, w):    True where the fit is plausible, the centre pixel
                            is valid and more than 4 patch entries are valid
    """
    flmask = flow[:, 2:]
    flow = flow[:, :2]
    b, _, lh, lw = flow.shape
    k = pw * 2 + 1
    ptar = pref + flow

    # patch coordinates relative to the patch centre
    pref = F.unfold(pref, (k, k), padding=(pw)).view(b, 2, k**2, lh, lw) - pref[:, :, np.newaxis]
    ptar = F.unfold(ptar, (k, k), padding=(pw)).view(b, 2, k**2, lh, lw) - ptar[:, :, np.newaxis]  # b, 2,k*k,h,w

    conf_flow = F.unfold(flmask, (k, k), padding=(pw)).view(b, 1, k**2, lh, lw)
    count = conf_flow.sum(2, keepdim=True)
    # A patch with no valid entry used to give 0/0 = NaN in exp; divide by 1
    # there instead. Such pixels stay excluded through `count > 4` below.
    safe_count = torch.where(count > 0, count, torch.ones_like(count))
    conf_flow = (k**2) * conf_flow / safe_count
    pref = pref * conf_flow
    ptar = ptar * conf_flow

    pref = pref.permute(0, 3, 4, 1, 2).reshape(b * lh * lw, 2, k**2)
    ptar = ptar.permute(0, 3, 4, 1, 2).reshape(b * lh * lw, 2, k**2)

    # least squares: A = ptar pref^T (pref pref^T)^-1, with a closed-form 2x2 inverse
    prefprefT = pref.matmul(pref.permute(0, 2, 1))
    ppdet = prefprefT[:, 0, 0] * prefprefT[:, 1, 1] - prefprefT[:, 1, 0] * prefprefT[:, 0, 1]
    adjugate = torch.cat((prefprefT[:, 1, 1:], -prefprefT[:, 0, 1:],
                          -prefprefT[:, 1:, 0], prefprefT[:, 0:1, 0]), 1).view(-1, 2, 2)
    ppinv = adjugate / ppdet.clamp(1e-10, np.inf)[:, np.newaxis, np.newaxis]

    Affine = ptar.matmul(pref.permute(0, 2, 1)).matmul(ppinv)
    Error = (Affine.matmul(pref) - ptar).norm(2, 1).mean(1).view(b, 1, lh, lw)

    Avol = (Affine[:, 0, 0] * Affine[:, 1, 1] - Affine[:, 1, 0] * Affine[:, 0, 1]).view(b, 1, lh, lw).abs().clamp(1e-10, np.inf)
    exp = Avol.sqrt()
    # the mask is computed on the unclamped expansion
    mask = (exp > 0.5) & (exp < 2) & (Error < 0.2) & (flmask.bool()) & (count[:, 0] > 4)
    mask = mask[:, 0]

    exp = exp.clamp(0.5, 2)
    exp[Error > 0.2] = 1
    return exp, Error, mask
def weight_parameters(self): | |
return [param for name, param in self.named_parameters() if 'weight' in name] | |
def bias_parameters(self): | |
return [param for name, param in self.named_parameters() if 'bias' in name] | |
def forward(self,im,disc_aux=None):
    """
    im: both frames stacked along the batch axis (first half = image 1,
        second half = image 2), shape (2*bs, C, H, W).
    disc_aux: auxiliary training data, indexed here as
        disc_aux[0]  ground-truth flow, permuted from (b, H, W, c) to (b, c, H, W);
                     channel 2 is used as a validity mask by affine_mask
        disc_aux[1]  boolean mask and-ed into the depth-change loss mask
        disc_aux[2]  (b, H, W, k) array; channel 0 is used as depth and
                     channels 4:7 as 3D flow
        disc_aux[-1] truthy -> only the expansion heads are fine-tuned
        It must be provided whenever the module is in training mode.

    Returns
        training: (flow2*4, flow3*8, flow4*16, flow5*32, flow6*64, loss,
                   dchange2[:, 0], iexp2[:, 0])
        eval:     (flow2, oor2, dchange2, iexp2)
    """
    def resize(t, size):
        # align_corners=False is what the deprecated F.upsample(mode='bilinear') resolves to
        return F.interpolate(t, size=size, mode='bilinear', align_corners=False)

    bs = im.shape[0] // 2
    full_h, full_w = im.size()[2], im.size()[3]

    # if only fine-tuning expansion: run the matching part frozen (eval mode, no grad)
    reset = bool(self.training and disc_aux[-1])
    grad_was_enabled = torch.is_grad_enabled()
    if reset:
        self.eval()
        torch.set_grad_enabled(False)
    try:
        c06, c05, c04, c03, c02 = self.pspnet(im)
        c16, c26 = c06[:bs], c06[bs:]
        c15, c25 = c05[:bs], c05[bs:]
        c14, c24 = c04[:bs], c04[bs:]
        c13, c23 = c03[:bs], c03[bs:]
        c12, c22 = c02[:bs], c02[bs:]

        ## matching 6
        flow6, flow6h, ent6h, oor6 = self.cost_matching(None, c16, c26, None, None, level=0)

        ## matching 5
        up_flow6 = resize(flow6, [full_h // 32, full_w // 32]) * 2
        flow5, flow5h, ent5h, oor5 = self.cost_matching(up_flow6, c15, c25, flow6h, ent6h, level=1)

        ## matching 4
        up_flow5 = resize(flow5, [full_h // 16, full_w // 16]) * 2
        flow4, flow4h, ent4h, oor4 = self.cost_matching(up_flow5, c14, c24, flow5h, ent5h, level=2)

        ## matching 3
        up_flow4 = resize(flow4, [full_h // 8, full_w // 8]) * 2
        flow3, flow3h, ent3h, oor3 = self.cost_matching(up_flow4, c13, c23, flow4h, ent4h, level=3)

        ## matching 2
        up_flow3 = resize(flow3, [full_h // 4, full_w // 4]) * 2
        flow2, flow2h, ent2h, oor2 = self.cost_matching(up_flow3, c12, c22, flow3h, ent3h, level=4)
    finally:
        if reset:
            # restore the caller's grad mode rather than forcing it on, and
            # do so even if matching raised
            torch.set_grad_enabled(grad_was_enabled)
            self.train()

    # expansion
    b, _, h, w = flow2.shape
    # the grid is moved to the flow's device, which need not be the default one
    grid = get_grid(b, h, w)[:, 0].permute(0, 3, 1, 2).repeat(b, 1, 1, 1).to(flow2.device)
    exp2, err2, _ = self.affine(grid, flow2.detach(), pw=1)
    neg_log_exp2 = -exp2.log()
    x = torch.cat((
        self.f3d2v2(neg_log_exp2),
        self.f3d2v3(err2),
    ), 1)
    dchange2 = neg_log_exp2 + 1. / 200 * self.f3d2(x)[0]

    # depth change net
    iexp2 = resize(dchange2.clone(), [full_h, full_w])

    x = torch.cat((self.dcnetv1(c12.detach()),
                   self.dcnetv2(dchange2.detach()),
                   self.dcnetv3(neg_log_exp2),
                   self.dcnetv4(err2),
                   ), 1)
    dcneto = 1. / 200 * self.dcnet(x)[0]
    dchange2 = dchange2.detach() + dcneto[:, :1]

    flow2 = resize(flow2.detach(), [full_h, full_w]) * 4
    dchange2 = resize(dchange2, [full_h, full_w])

    if self.training:
        flowl0 = disc_aux[0].permute(0, 3, 1, 2).clone()
        gt_depth = disc_aux[2][:, :, :, 0]
        gt_f3d = disc_aux[2][:, :, :, 4:7].permute(0, 3, 1, 2).clone()
        gt_dchange = (1 + gt_f3d[:, 2] / gt_depth)
        maskdc = (gt_dchange < 2) & (gt_dchange > 0.5) & disc_aux[1]

        gt_grid = get_grid(b, 4 * h, 4 * w)[:, 0].permute(0, 3, 1, 2).repeat(b, 1, 1, 1).to(flowl0.device)
        gt_expi, gt_expi_err, maskoe = self.affine_mask(gt_grid, flowl0, pw=3)
        gt_exp = 1. / gt_expi[:, 0]

        loss = 0.1 * (dchange2[:, 0] - gt_dchange.log()).abs()[maskdc].mean()
        loss += 0.1 * (iexp2[:, 0] - gt_exp.log()).abs()[maskoe].mean()

        # NOTE(review): flow2 was already upsampled and scaled by 4 above, so
        # this returns 16x the quarter-resolution flow. Kept as is so callers
        # see the same values -- confirm whether this is intended.
        return flow2 * 4, flow3 * 8, flow4 * 16, flow5 * 32, flow6 * 64, loss, dchange2[:, 0], iexp2[:, 0]
    else:
        return flow2, oor2, dchange2, iexp2