import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
import os
os.environ['PYTHON_EGG_CACHE'] = 'tmp/'  # a writable directory
import numpy as np
import math
import pdb
import time

from .submodule import pspnet, bfmodule, conv
from .conv4d import sepConv4d, sepConv4dBlock, butterfly4D


def _bilinear_resize(x, size):
    """Bilinearly resize ``x`` to spatial ``size``.

    ``align_corners=False`` is passed explicitly; it is the behaviour the
    deprecated ``F.upsample(..., mode='bilinear')`` selected implicitly, so
    results are unchanged while the deprecation warnings go away.
    """
    return F.interpolate(x, size, mode='bilinear', align_corners=False)


def _dilated_conv_stack(in_planes, out_planes):
    """Build the 7-layer dilated-convolution stack shared by the fusion and
    out-of-range heads.

    Layers are created in order (first to last) and returned as a list so the
    caller decides how to register them. Only the input width of the first
    layer and the output width of the last layer vary between heads.
    """
    return [
        conv(in_planes, 128, kernel_size=3, stride=1, padding=1, dilation=1),
        conv(128, 128, kernel_size=3, stride=1, padding=2, dilation=2),
        conv(128, 128, kernel_size=3, stride=1, padding=4, dilation=4),
        conv(128, 96, kernel_size=3, stride=1, padding=8, dilation=8),
        conv(96, 64, kernel_size=3, stride=1, padding=16, dilation=16),
        conv(64, 32, kernel_size=3, stride=1, padding=1, dilation=1),
        nn.Conv2d(32, out_planes, kernel_size=3, stride=1, padding=1, bias=True),
    ]


class flow_reg(nn.Module):
    """
    Soft winner-take-all that selects the most likely displacement.
    Set ent=True to enable entropy output.
    Set maxdisp to adjust maximum allowed displacement towards one side.
        maxdisp=4 searches for a 9x9 region.
    Set fac to squeeze search window.
        maxdisp=4 and fac=2 gives search window of 9x5
    """
    def __init__(self, size, ent=False, maxdisp=int(4), fac=1):
        """
        size: 3-element sequence unpacked as (B, W, H); the displacement
              buffers are tiled to shape (B, u, v, H, W).
        ent: when True, forward() also returns local/global entropy maps.
        maxdisp: maximum displacement searched towards one side.
        fac: squeeze factor applied to the second displacement axis.
        """
        B, W, H = size
        super(flow_reg, self).__init__()
        self.ent = ent
        self.md = maxdisp
        self.fac = fac
        self.truncated = True
        self.wsize = 3  # by default using truncation 7x7

        squeezed = int(maxdisp // self.fac)
        flowrangey = range(-maxdisp, maxdisp + 1)
        flowrangex = range(-squeezed, squeezed + 1)
        meshgrid = np.meshgrid(flowrangex, flowrangey)
        vol_shape = [1, 2 * maxdisp + 1, 2 * squeezed + 1, 1, 1]
        flowy = np.tile(np.reshape(meshgrid[0], vol_shape), (B, 1, 1, H, W))
        flowx = np.tile(np.reshape(meshgrid[1], vol_shape), (B, 1, 1, H, W))
        self.register_buffer('flowx', torch.Tensor(flowx))
        self.register_buffer('flowy', torch.Tensor(flowy))

        self.pool3d = nn.MaxPool3d((self.wsize * 2 + 1, self.wsize * 2 + 1, 1),
                                   stride=1, padding=(self.wsize, self.wsize, 0))

    def forward(self, x):
        """
        x: cost volume of shape (b, u, v, h, w).

        Returns (flow, entropy): flow is (b, 2, h, w), the softmax-weighted
        expectation of the displacement buffers; entropy is (b, 2, h, w)
        holding [local, global] normalised entropies when self.ent is set,
        otherwise None.
        """
        b, u, v, h, w = x.shape
        oldx = x

        if self.truncated:
            # truncated softmax: keep only a (2*wsize+1)^2 window around the
            # per-pixel argmax and push everything else to -inf
            x = x.view(b, u * v, h, w)

            idx = x.argmax(1)[:, np.newaxis]
            # Allocate on x's own device (not the "current" CUDA device);
            # dtype matches the original: half on CUDA, float on CPU.
            mask_dtype = torch.half if x.is_cuda else torch.float
            mask = torch.zeros(b, u * v, h, w, dtype=mask_dtype, device=x.device)
            mask.scatter_(1, idx, 1)
            mask = mask.view(b, 1, u, v, -1)
            # max-pooling the one-hot argmax dilates it into the window
            mask = self.pool3d(mask)[:, 0].view(b, u, v, h, w)

            ninf = torch.full_like(oldx, -np.inf)
            # torch.where needs a boolean condition on current PyTorch
            x = torch.where(mask.bool(), oldx, ninf)
        else:
            self.wsize = (np.sqrt(u * v) - 1) / 2

        b, u, v, h, w = x.shape
        x = F.softmax(x.view(b, -1, h, w), 1).view(b, u, v, h, w)
        outx = torch.sum(torch.sum(x * self.flowx, 1), 1, keepdim=True)
        outy = torch.sum(torch.sum(x * self.flowy, 1), 1, keepdim=True)

        if self.ent:
            # local: entropy of the (possibly truncated) distribution,
            # normalised by the log of the window size
            local_entropy = (-x * torch.clamp(x, 1e-9, 1 - 1e-9).log()).sum(1).sum(1)[:, np.newaxis]
            if self.wsize == 0:
                local_entropy[:] = 1.
            else:
                local_entropy /= np.log((self.wsize * 2 + 1) ** 2)

            # global: entropy of the untruncated softmax over the full volume
            x = F.softmax(oldx.view(b, -1, h, w), 1).view(b, u, v, h, w)
            global_entropy = (-x * torch.clamp(x, 1e-9, 1 - 1e-9).log()).sum(1).sum(1)[:, np.newaxis]
            global_entropy /= np.log(x.shape[1] * x.shape[2])
            return torch.cat([outx, outy], 1), torch.cat([local_entropy, global_entropy], 1)
        else:
            return torch.cat([outx, outy], 1), None


class WarpModule(nn.Module):
    """
    taken from https://github.com/NVlabs/PWC-Net/blob/master/PyTorch/models/PWCNet.py
    """
    def __init__(self, size):
        """size: 3-element sequence unpacked as (B, W, H); a (B, 2, H, W)
        pixel-coordinate grid is registered as the buffer ``grid``."""
        super(WarpModule, self).__init__()
        B, W, H = size
        # mesh grid
        xx = torch.arange(0, W).view(1, -1).repeat(H, 1)
        yy = torch.arange(0, H).view(-1, 1).repeat(1, W)
        xx = xx.view(1, 1, H, W).repeat(B, 1, 1, 1)
        yy = yy.view(1, 1, H, W).repeat(B, 1, 1, 1)
        self.register_buffer('grid', torch.cat((xx, yy), 1).float())

    def forward(self, x, flo):
        """
        warp an image/tensor (im2) back to im1, according to the optical flow
        x: [B, C, H, W] (im2)
        flo: [B, 2, H, W] flow

        Returns (warped, mask): mask is a boolean [B, H, W] tensor that is
        True where the normalised sampling coordinate lies strictly inside
        (-1, 1) on both axes; warped is zeroed where mask is False.
        """
        B, C, H, W = x.size()
        vgrid = self.grid + flo

        # scale grid to [-1,1]
        vgrid[:, 0, :, :] = 2.0 * vgrid[:, 0, :, :] / max(W - 1, 1) - 1.0
        vgrid[:, 1, :, :] = 2.0 * vgrid[:, 1, :, :] / max(H - 1, 1) - 1.0

        vgrid = vgrid.permute(0, 2, 3, 1)
        output = nn.functional.grid_sample(x, vgrid, align_corners=True)
        mask = ((vgrid[:, :, :, 0].abs() < 1) * (vgrid[:, :, :, 1].abs() < 1)) > 0
        return output * mask.unsqueeze(1).float(), mask


def get_grid(B, H, W, device=None):
    """Return a float pixel-coordinate grid of shape (1, 1, H, W, 2).

    B: unused; kept so existing call sites keep working.
    H, W: grid height and width.
    device: where to place the grid. When None, CUDA is used if available
            (the previous behaviour) and CPU otherwise, so the function no
            longer crashes on CPU-only hosts.
    """
    if device is None:
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    meshgrid_base = np.meshgrid(range(0, W), range(0, H))[::-1]
    basey = np.reshape(meshgrid_base[0], [1, 1, 1, H, W])
    basex = np.reshape(meshgrid_base[1], [1, 1, 1, H, W])
    stacked = np.concatenate((basex.reshape((-1, H, W, 1)),
                              basey.reshape((-1, H, W, 1))), -1)
    grid = torch.tensor(stacked).float().to(device)
    return grid.view(1, 1, H, W, 2)


class VCN(nn.Module):
    """
    VCN.
    md defines maximum displacement for each level, following a coarse-to-fine-warping scheme
    fac defines squeeze parameter for the coarsest level
    """
    def __init__(self, size, md=[4, 4, 4, 4, 4], fac=1., exp_unc=False):  # exp_uncertainty
        """
        size: 3-element sequence; size[0] scales the batch dimension of the
              per-level buffers and size[1], size[2] are divided by the level
              stride (64/32/16/8/4) — presumably (batch, width, height);
              confirm against the caller.
        md: per-level maximum displacement, coarsest level first.
        fac: squeeze factor used at the coarsest level only.
        exp_unc: when True the depth-change head is built with 2 outputs
                 instead of 1.
        """
        super(VCN, self).__init__()
        self.md = list(md)  # copy: never alias the shared default list
        self.fac = fac
        use_entropy = True
        withbn = True

        ## pspnet
        self.pspnet = pspnet(is_proj=False)

        ### Volumetric-UNet
        fdima1 = 128  # 6/5/4
        fdima2 = 64   # 3/2
        fdimb1 = 16   # 6/5/4/3
        fdimb2 = 12   # 2

        full = False
        self.f6 = butterfly4D(fdima1, fdimb1, withbn=withbn, full=full)
        self.p6 = sepConv4d(fdimb1, fdimb1, with_bn=False, full=full)

        self.f5 = butterfly4D(fdima1, fdimb1, withbn=withbn, full=full)
        self.p5 = sepConv4d(fdimb1, fdimb1, with_bn=False, full=full)

        self.f4 = butterfly4D(fdima1, fdimb1, withbn=withbn, full=full)
        self.p4 = sepConv4d(fdimb1, fdimb1, with_bn=False, full=full)

        self.f3 = butterfly4D(fdima2, fdimb1, withbn=withbn, full=full)
        self.p3 = sepConv4d(fdimb1, fdimb1, with_bn=False, full=full)

        full = True
        self.f2 = butterfly4D(fdima2, fdimb2, withbn=withbn, full=full)
        self.p2 = sepConv4d(fdimb2, fdimb2, with_bn=False, full=full)

        self.flow_reg64 = flow_reg([fdimb1 * size[0], size[1] // 64, size[2] // 64], ent=use_entropy, maxdisp=self.md[0], fac=self.fac)
        self.flow_reg32 = flow_reg([fdimb1 * size[0], size[1] // 32, size[2] // 32], ent=use_entropy, maxdisp=self.md[1])
        self.flow_reg16 = flow_reg([fdimb1 * size[0], size[1] // 16, size[2] // 16], ent=use_entropy, maxdisp=self.md[2])
        self.flow_reg8 = flow_reg([fdimb1 * size[0], size[1] // 8, size[2] // 8], ent=use_entropy, maxdisp=self.md[3])
        self.flow_reg4 = flow_reg([fdimb2 * size[0], size[1] // 4, size[2] // 4], ent=use_entropy, maxdisp=self.md[4])

        self.warp5 = WarpModule([size[0], size[1] // 32, size[2] // 32])
        self.warp4 = WarpModule([size[0], size[1] // 16, size[2] // 16])
        self.warp3 = WarpModule([size[0], size[1] // 8, size[2] // 8])
        self.warp2 = WarpModule([size[0], size[1] // 4, size[2] // 4])

        ## hypotheses fusion modules, adopted from the refinement module of PWCNet
        # https://github.com/NVlabs/PWC-Net/blob/master/PyTorch/models/PWCNet.py
        # Input widths per level (shared by the fusion and out-of-range heads).
        in_c6 = 128 + 4 * fdimb1
        in_c5 = 128 + 4 * fdimb1 * 2
        in_c4 = 128 + 4 * fdimb1 * 3
        in_c3 = 64 + 16 * fdimb1
        in_c2 = 64 + 16 * fdimb1 + 4 * fdimb2

        # NOTE: attribute assignment order below (conv1..conv7 for c6, c5, c4,
        # c3, c2) matches the original layout so parameter ordering and RNG
        # consumption at construction time are unchanged.
        # c6
        (self.dc6_conv1, self.dc6_conv2, self.dc6_conv3, self.dc6_conv4,
         self.dc6_conv5, self.dc6_conv6, self.dc6_conv7) = _dilated_conv_stack(in_c6, 2 * fdimb1)
        # c5
        (self.dc5_conv1, self.dc5_conv2, self.dc5_conv3, self.dc5_conv4,
         self.dc5_conv5, self.dc5_conv6, self.dc5_conv7) = _dilated_conv_stack(in_c5, 2 * fdimb1 * 2)
        # c4
        (self.dc4_conv1, self.dc4_conv2, self.dc4_conv3, self.dc4_conv4,
         self.dc4_conv5, self.dc4_conv6, self.dc4_conv7) = _dilated_conv_stack(in_c4, 2 * fdimb1 * 3)
        # c3
        (self.dc3_conv1, self.dc3_conv2, self.dc3_conv3, self.dc3_conv4,
         self.dc3_conv5, self.dc3_conv6, self.dc3_conv7) = _dilated_conv_stack(in_c3, 8 * fdimb1)
        # c2
        (self.dc2_conv1, self.dc2_conv2, self.dc2_conv3, self.dc2_conv4,
         self.dc2_conv5, self.dc2_conv6, self.dc2_conv7) = _dilated_conv_stack(in_c2, 4 * 2 * fdimb1 + 2 * fdimb2)

        self.dc6_conv = nn.Sequential(self.dc6_conv1, self.dc6_conv2, self.dc6_conv3, self.dc6_conv4,
                                      self.dc6_conv5, self.dc6_conv6, self.dc6_conv7)
        self.dc5_conv = nn.Sequential(self.dc5_conv1, self.dc5_conv2, self.dc5_conv3, self.dc5_conv4,
                                      self.dc5_conv5, self.dc5_conv6, self.dc5_conv7)
        self.dc4_conv = nn.Sequential(self.dc4_conv1, self.dc4_conv2, self.dc4_conv3, self.dc4_conv4,
                                      self.dc4_conv5, self.dc4_conv6, self.dc4_conv7)
        self.dc3_conv = nn.Sequential(self.dc3_conv1, self.dc3_conv2, self.dc3_conv3, self.dc3_conv4,
                                      self.dc3_conv5, self.dc3_conv6, self.dc3_conv7)
        self.dc2_conv = nn.Sequential(self.dc2_conv1, self.dc2_conv2, self.dc2_conv3, self.dc2_conv4,
                                      self.dc2_conv5, self.dc2_conv6, self.dc2_conv7)

        ## Out-of-range detection (single-channel output per level)
        self.dc6_convo = nn.Sequential(*_dilated_conv_stack(in_c6, 1))
        self.dc5_convo = nn.Sequential(*_dilated_conv_stack(in_c5, 1))
        self.dc4_convo = nn.Sequential(*_dilated_conv_stack(in_c4, 1))
        self.dc3_convo = nn.Sequential(*_dilated_conv_stack(in_c3, 1))
        self.dc2_convo = nn.Sequential(*_dilated_conv_stack(in_c2, 1))

        # affine-exp
        # (v1, v4, v5, v6 are not used by forward() in this file but are kept
        # so the set of registered parameters is unchanged)
        self.f3d2v1 = conv(64, 32, kernel_size=3, stride=1, padding=1, dilation=1)
        self.f3d2v2 = conv(1, 32, kernel_size=3, stride=1, padding=1, dilation=1)
        self.f3d2v3 = conv(1, 32, kernel_size=3, stride=1, padding=1, dilation=1)
        self.f3d2v4 = conv(1, 32, kernel_size=3, stride=1, padding=1, dilation=1)
        self.f3d2v5 = conv(64, 32, kernel_size=3, stride=1, padding=1, dilation=1)
        self.f3d2v6 = conv(12 * 81, 32, kernel_size=3, stride=1, padding=1, dilation=1)
        self.f3d2 = bfmodule(128 - 64, 1)

        # depth change net
        # (v5, v6 are likewise unused by forward() in this file)
        self.dcnetv1 = conv(64, 32, kernel_size=3, stride=1, padding=1, dilation=1)
        self.dcnetv2 = conv(1, 32, kernel_size=3, stride=1, padding=1, dilation=1)
        self.dcnetv3 = conv(1, 32, kernel_size=3, stride=1, padding=1, dilation=1)
        self.dcnetv4 = conv(1, 32, kernel_size=3, stride=1, padding=1, dilation=1)
        self.dcnetv5 = conv(12 * 81, 32, kernel_size=3, stride=1, padding=1, dilation=1)
        self.dcnetv6 = conv(4, 32, kernel_size=3, stride=1, padding=1, dilation=1)
        if exp_unc:
            self.dcnet = bfmodule(128, 2)
        else:
            self.dcnet = bfmodule(128, 1)

        # He-style normal init for every 3-D convolution in the model
        for m in self.modules():
            if isinstance(m, nn.Conv3d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.kernel_size[2] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
                if hasattr(m.bias, 'data'):
                    m.bias.data.zero_()

        # Per-level lookup tables, indexed by level 0 (coarsest) .. 4 (finest).
        self.facs = [self.fac, 1, 1, 1, 1]
        self.warp_modules = nn.ModuleList([None, self.warp5, self.warp4, self.warp3, self.warp2])
        self.f_modules = nn.ModuleList([self.f6, self.f5, self.f4, self.f3, self.f2])
        self.p_modules = nn.ModuleList([self.p6, self.p5, self.p4, self.p3, self.p2])
        self.reg_modules = nn.ModuleList([self.flow_reg64, self.flow_reg32, self.flow_reg16, self.flow_reg8, self.flow_reg4])
        self.oor_modules = nn.ModuleList([self.dc6_convo, self.dc5_convo, self.dc4_convo, self.dc3_convo, self.dc2_convo])
        self.fuse_modules = nn.ModuleList([self.dc6_conv, self.dc5_conv, self.dc4_conv, self.dc3_conv, self.dc2_conv])

    def corrf(self, refimg_fea, targetimg_fea, maxdisp, fac=1):
        """
        slow correlation function

        refimg_fea, targetimg_fea: (b, c, h, w) feature maps.
        Returns a float32 cost volume of shape (b, c, u, v, h, w) with
        u = 2*maxdisp+1 and v = 2*int(maxdisp//fac)+1, holding the
        element-wise feature product for each shift, passed through
        leaky-ReLU(0.1). Positions whose shifted partner falls outside the
        image stay zero.
        """
        b, c, height, width = refimg_fea.shape
        squeezed = int(maxdisp // fac)
        # Allocate on the features' own device (not the "current" CUDA
        # device); float32 as before.
        cost = torch.zeros((b, c, 2 * maxdisp + 1, 2 * squeezed + 1, height, width),
                           dtype=torch.float32, device=refimg_fea.device)  # b,c,u,v,h,w
        for i in range(2 * maxdisp + 1):
            ind = i - maxdisp  # shift along the width axis
            for j in range(2 * squeezed + 1):
                indd = j - squeezed  # shift along the height axis
                feata = refimg_fea[:, :, max(0, -indd):height - indd, max(0, -ind):width - ind]
                featb = targetimg_fea[:, :, max(0, +indd):height + indd, max(0, ind):width + ind]
                cost[:, :, i, j, max(0, -indd):height - indd, max(0, -ind):width - ind] = feata * featb  # standard
        cost = F.leaky_relu(cost, 0.1, inplace=True)
        return cost

    def cost_matching(self, up_flow, c1, c2, flowh, enth, level):
        """
        up_flow: upsample coarse flow
        c1: normalized feature of image 1
        c2: normalized feature of image 2
        flowh: flow hypotheses
        enth: entropy

        level: pyramid level index, 0 (coarsest) .. 4 (finest).
        Returns (flow, flowh, enth, oor): the fused flow, the accumulated
        flow hypotheses and entropies (this level's prepended to the
        upsampled coarser ones), and the out-of-range logits (None in eval
        mode for levels other than 4).
        """
        # normalize
        c1n = c1 / (c1.norm(dim=1, keepdim=True) + 1e-9)
        c2n = c2 / (c2.norm(dim=1, keepdim=True) + 1e-9)

        # cost volume
        if level == 0:
            warp = c2n
        else:
            warp, _ = self.warp_modules[level](c2n, up_flow)

        feat = self.corrf(c1n, warp, self.md[level], fac=self.facs[level])
        feat = self.f_modules[level](feat)
        cost = self.p_modules[level](feat)  # b, 16, u,v,h,w

        # soft WTA
        b, c, u, v, h, w = cost.shape
        cost = cost.view(-1, u, v, h, w)  # bx16, 9,9,h,w, also predict uncertainty from here
        flowhh, enthh = self.reg_modules[level](cost)  # bx16, 2, h, w
        flowhh = flowhh.view(b, c, 2, h, w)
        if level > 0:
            flowhh = flowhh + up_flow[:, np.newaxis]
        flowhh = flowhh.view(b, -1, h, w)  # b, 16*2, h, w
        enthh = enthh.view(b, -1, h, w)    # b, 16*1, h, w

        # append coarse hypotheses
        if level == 0:
            flowh = flowhh
            enth = enthh
        else:
            target_size = [flowhh.shape[2], flowhh.shape[3]]
            flowh = torch.cat((flowhh, _bilinear_resize(flowh.detach() * 2, target_size)), 1)  # b, k2--k2, h, w
            enth = torch.cat((enthh, _bilinear_resize(enth, target_size)), 1)

        # Shared input of the out-of-range and fusion heads (built once).
        x = torch.cat((enth.detach(), flowh.detach(), c1), 1)

        if self.training or level == 4:
            oor = self.oor_modules[level](x)[:, 0]
        else:
            oor = None

        # hypotheses fusion
        va = self.fuse_modules[level](x)
        va = va.view(b, -1, 2, h, w)
        flow = (flowh.view(b, -1, 2, h, w) * F.softmax(va, 1)).sum(1)  # b, 2k, 2, h, w
        return flow, flowh, enth, oor

    @staticmethod
    def _patch_offsets(coords, pw):
        """Gather the (2*pw+1)^2 neighbourhood of every pixel of ``coords``
        (b, c, h, w) and subtract the centre value, giving (b, c, k, h, w)."""
        b, c, lh, lw = coords.shape
        ksize = pw * 2 + 1
        patches = F.unfold(coords, (ksize, ksize), padding=(pw)).view(b, c, ksize ** 2, lh, lw)
        return patches - coords[:, :, np.newaxis]

    @staticmethod
    def _fit_affine(pref, ptar):
        """Least-squares fit of a 2x2 matrix mapping ``pref`` onto ``ptar``
        at every pixel; both are (b, 2, k, h, w).

        Returns (exp, Error), each (b, 1, h, w): exp is the square root of
        the absolute determinant of the fitted matrix (clamped away from
        zero) and Error is the mean residual norm of the fit.
        """
        b, _, k, lh, lw = pref.shape
        pref = pref.permute(0, 3, 4, 1, 2).reshape(b * lh * lw, 2, k)
        ptar = ptar.permute(0, 3, 4, 1, 2).reshape(b * lh * lw, 2, k)

        # closed-form inverse of the 2x2 normal matrix pref @ pref^T
        prefprefT = pref.matmul(pref.permute(0, 2, 1))
        ppdet = prefprefT[:, 0, 0] * prefprefT[:, 1, 1] - prefprefT[:, 1, 0] * prefprefT[:, 0, 1]
        ppinv = torch.cat((prefprefT[:, 1, 1:], -prefprefT[:, 0, 1:],
                           -prefprefT[:, 1:, 0], prefprefT[:, 0:1, 0]), 1).view(-1, 2, 2) \
            / ppdet.clamp(1e-10, np.inf)[:, np.newaxis, np.newaxis]

        Affine = ptar.matmul(pref.permute(0, 2, 1)).matmul(ppinv)
        Error = (Affine.matmul(pref) - ptar).norm(2, 1).mean(1).view(b, 1, lh, lw)

        Avol = (Affine[:, 0, 0] * Affine[:, 1, 1] - Affine[:, 1, 0] * Affine[:, 0, 1]) \
            .view(b, 1, lh, lw).abs().clamp(1e-10, np.inf)
        return Avol.sqrt(), Error

    def affine(self, pref, flow, pw=1):
        """
        pref: reference coordinates, (b, 2, h, w)
        flow: (b, 2, h, w) displacement added to pref
        pw: patch half-width; the fit uses a (2*pw+1)^2 neighbourhood.
            (Previously this argument was silently overwritten with 1.)

        Returns (exp, Error, mask): exp clamped to [0.5, 2] and reset to 1
        where Error > 0.1; mask (b, h, w) marks pixels with exp in (0.5, 2)
        and Error < 0.1.
        """
        ptar = pref + flow
        exp, Error = self._fit_affine(self._patch_offsets(pref, pw),
                                      self._patch_offsets(ptar, pw))  # b, 2,9,h,w inputs
        mask = (exp > 0.5) & (exp < 2) & (Error < 0.1)
        mask = mask[:, 0]

        exp = exp.clamp(0.5, 2)
        exp[Error > 0.1] = 1
        return exp, Error, mask

    def affine_mask(self, pref, flow, pw=3):
        """
        pref: reference coordinates
        pw: patch width

        flow: first two channels are the displacement; the remaining
              channel(s) are used as a validity mask that weights each
              neighbour in the fit (presumably a single 0/1 channel —
              confirm against the caller).

        Returns (exp, Error, mask) as in affine(), with an Error threshold
        of 0.2 and the mask additionally requiring a valid centre pixel and
        more than 4 valid neighbours.
        """
        flmask = flow[:, 2:]
        flow = flow[:, :2]
        b, _, lh, lw = flow.shape
        ptar = pref + flow
        pref = self._patch_offsets(pref, pw)
        ptar = self._patch_offsets(ptar, pw)  # b, 2,9,h,w

        ksize = pw * 2 + 1
        conf_flow = F.unfold(flmask, (ksize, ksize), padding=(pw)).view(b, 1, ksize ** 2, lh, lw)
        count = conf_flow.sum(2, keepdim=True)
        # re-normalise so the weights sum to the patch size; patches with
        # count == 0 yield NaN here and are excluded by the mask below
        conf_flow = (ksize ** 2) * conf_flow / count
        pref = pref * conf_flow
        ptar = ptar * conf_flow

        exp, Error = self._fit_affine(pref, ptar)
        mask = (exp > 0.5) & (exp < 2) & (Error < 0.2) & (flmask.bool()) & (count[:, 0] > 4)
        mask = mask[:, 0]

        exp = exp.clamp(0.5, 2)
        exp[Error > 0.2] = 1
        return exp, Error, mask

    def weight_parameters(self):
        """Return every parameter whose qualified name contains 'weight'."""
        return [param for name, param in self.named_parameters() if 'weight' in name]

    def bias_parameters(self):
        """Return every parameter whose qualified name contains 'bias'."""
        return [param for name, param in self.named_parameters() if 'bias' in name]

    def forward(self, im, disc_aux=None):
        """
        im: image batch whose first half along dim 0 is paired with its
            second half (features are split at im.shape[0]//2).
        disc_aux: required in training mode. Indexed as: [0] ground-truth
            flow (channels-last), [1] a mask combined into the depth-change
            mask, [2] a tensor whose last-dim channel 0 is depth and
            channels 4:7 are 3-D flow, [-1] a flag that freezes the flow
            network when truthy.

        Training returns (flow2*4, flow3*8, flow4*16, flow5*32, flow6*64,
        loss, dchange2[:,0], iexp2[:,0]); eval returns
        (flow2, oor2, dchange2, iexp2).

        Raises ValueError when called in training mode without disc_aux.
        """
        if self.training and disc_aux is None:
            raise ValueError('disc_aux must be provided when the model is in training mode')

        bs = im.shape[0] // 2
        im_h, im_w = im.size()[2], im.size()[3]

        # if only fine-tuning expansion: run the flow network frozen
        reset = bool(self.training and disc_aux[-1])
        prev_grad = torch.is_grad_enabled()
        if reset:
            self.eval()
            torch.set_grad_enabled(False)

        try:
            c06, c05, c04, c03, c02 = self.pspnet(im)
            c16 = c06[:bs]; c26 = c06[bs:]
            c15 = c05[:bs]; c25 = c05[bs:]
            c14 = c04[:bs]; c24 = c04[bs:]
            c13 = c03[:bs]; c23 = c03[bs:]
            c12 = c02[:bs]; c22 = c02[bs:]

            ## matching 6
            flow6, flow6h, ent6h, oor6 = self.cost_matching(None, c16, c26, None, None, level=0)

            ## matching 5
            up_flow6 = _bilinear_resize(flow6, [im_h // 32, im_w // 32]) * 2
            flow5, flow5h, ent5h, oor5 = self.cost_matching(up_flow6, c15, c25, flow6h, ent6h, level=1)

            ## matching 4
            up_flow5 = _bilinear_resize(flow5, [im_h // 16, im_w // 16]) * 2
            flow4, flow4h, ent4h, oor4 = self.cost_matching(up_flow5, c14, c24, flow5h, ent5h, level=2)

            ## matching 3
            up_flow4 = _bilinear_resize(flow4, [im_h // 8, im_w // 8]) * 2
            flow3, flow3h, ent3h, oor3 = self.cost_matching(up_flow4, c13, c23, flow4h, ent4h, level=3)

            ## matching 2
            up_flow3 = _bilinear_resize(flow3, [im_h // 4, im_w // 4]) * 2
            flow2, flow2h, ent2h, oor2 = self.cost_matching(up_flow3, c12, c22, flow3h, ent3h, level=4)
        finally:
            # Always restore the caller's grad/train state, even on error.
            if reset:
                torch.set_grad_enabled(prev_grad)
                self.train()

        # expansion
        b, _, h, w = flow2.shape
        grid2 = get_grid(b, h, w, device=flow2.device)[:, 0].permute(0, 3, 1, 2).repeat(b, 1, 1, 1).clone()
        exp2, err2, _ = self.affine(grid2, flow2.detach(), pw=1)
        x = torch.cat((
            self.f3d2v2(-exp2.log()),
            self.f3d2v3(err2),
        ), 1)
        dchange2 = -exp2.log() + 1. / 200 * self.f3d2(x)[0]

        # depth change net
        iexp2 = _bilinear_resize(dchange2.clone(), [im_h, im_w])
        x = torch.cat((
            self.dcnetv1(c12.detach()),
            self.dcnetv2(dchange2.detach()),
            self.dcnetv3(-exp2.log()),
            self.dcnetv4(err2),
        ), 1)
        dcneto = 1. / 200 * self.dcnet(x)[0]
        dchange2 = dchange2.detach() + dcneto[:, :1]

        flow2 = _bilinear_resize(flow2.detach(), [im_h, im_w]) * 4
        dchange2 = _bilinear_resize(dchange2, [im_h, im_w])

        if self.training:
            flowl0 = disc_aux[0].permute(0, 3, 1, 2).clone()
            gt_depth = disc_aux[2][:, :, :, 0]
            gt_f3d = disc_aux[2][:, :, :, 4:7].permute(0, 3, 1, 2).clone()
            gt_dchange = (1 + gt_f3d[:, 2] / gt_depth)
            maskdc = (gt_dchange < 2) & (gt_dchange > 0.5) & disc_aux[1]

            grid0 = get_grid(b, 4 * h, 4 * w, device=flowl0.device)[:, 0].permute(0, 3, 1, 2).repeat(b, 1, 1, 1)
            gt_expi, gt_expi_err, maskoe = self.affine_mask(grid0, flowl0, pw=3)
            gt_exp = 1. / gt_expi[:, 0]

            loss = 0.1 * (dchange2[:, 0] - gt_dchange.log()).abs()[maskdc].mean()
            loss += 0.1 * (iexp2[:, 0] - gt_exp.log()).abs()[maskoe].mean()
            return flow2 * 4, flow3 * 8, flow4 * 16, flow5 * 32, flow6 * 64, loss, dchange2[:, 0], iexp2[:, 0]

        else:
            return flow2, oor2, dchange2, iexp2