|
import torch
import torch.nn as nn
import torch.nn.functional as F
|
|
|
import modules.components.upr_net.correlation as correlation |
|
import modules.components.upr_net.softsplat as softsplat |
|
from modules.components.upr_net.m2m import * |
|
from modules.components.upr_net.backwarp import backwarp |
|
|
|
from ..components import register |
|
|
|
from utils.padder import InputPadder |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def photometric_consistency(img0, img1, flow01):
    # L1 brightness-constancy error: img0 vs. img1 backward-warped by flow01.
    return (img0 - backwarp(img1, flow01)).abs().sum(dim=1, keepdim=True)
|
|
|
|
|
def flow_consistency(flow01, flow10):
    # Forward-backward consistency: flow01 and the warped flow10 should
    # cancel out wherever the two flows agree.
    return (flow01 + backwarp(flow10, flow01)).abs().sum(dim=1, keepdim=True)
|
|
|
|
|
gaussian_kernel = torch.tensor([[1, 2, 1], |
|
[2, 4, 2], |
|
[1, 2, 1]]) / 16 |
|
gaussian_kernel = gaussian_kernel.repeat(2, 1, 1, 1)  # depthwise kernel for the 2 flow channels
|
|
|
|
|
def gaussian(x):
    # 3x3 Gaussian blur with reflect padding; the module-level kernel is
    # created on CPU at import time, so move it to the input's device/dtype.
    x = torch.nn.functional.pad(x, (1, 1, 1, 1), mode='reflect')
    kernel = gaussian_kernel.to(dtype=x.dtype, device=x.device)
    out = torch.nn.functional.conv2d(x, kernel, groups=x.shape[1])

    return out
|
|
|
|
|
def variance_flow(flow):
    # Normalize flow to [-1, 1] grid units, then measure the local variance
    # per channel via Gaussian-smoothed moments: Var[f] = E[f^2] - E[f]^2.
    flow = flow * torch.tensor(data=[2.0 / (flow.shape[3] - 1.0), 2.0 / (flow.shape[2] - 1.0)], dtype=flow.dtype,
                               device=flow.device).view(1, 2, 1, 1)
    return (gaussian(flow ** 2) - gaussian(flow) ** 2 + 1e-4).sqrt().sum(dim=1, keepdim=True)
|
|
|
|
|
class FeatPyramid(nn.Module): |
|
"""A 3-level feature pyramid, which by default is shared by the motion |
|
estimator and synthesis network. |
|
""" |
|
|
|
def __init__(self): |
|
super(FeatPyramid, self).__init__() |
|
self.conv_stage0 = nn.Sequential( |
|
nn.Conv2d(in_channels=3, out_channels=32, kernel_size=3, |
|
stride=1, padding=1), |
|
nn.LeakyReLU(inplace=False, negative_slope=0.1), |
|
nn.Conv2d(in_channels=32, out_channels=32, kernel_size=3, |
|
stride=1, padding=1), |
|
nn.LeakyReLU(inplace=False, negative_slope=0.1), |
|
nn.Conv2d(in_channels=32, out_channels=32, kernel_size=3, |
|
stride=1, padding=1), |
|
nn.LeakyReLU(inplace=False, negative_slope=0.1), |
|
nn.Conv2d(in_channels=32, out_channels=32, kernel_size=3, |
|
stride=1, padding=1), |
|
nn.LeakyReLU(inplace=False, negative_slope=0.1)) |
|
self.conv_stage1 = nn.Sequential( |
|
nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, |
|
stride=2, padding=1), |
|
nn.LeakyReLU(inplace=False, negative_slope=0.1), |
|
nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, |
|
stride=1, padding=1), |
|
nn.LeakyReLU(inplace=False, negative_slope=0.1), |
|
nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, |
|
stride=1, padding=1), |
|
nn.LeakyReLU(inplace=False, negative_slope=0.1), |
|
nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, |
|
stride=1, padding=1), |
|
nn.LeakyReLU(inplace=False, negative_slope=0.1)) |
|
self.conv_stage2 = nn.Sequential( |
|
nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, |
|
stride=2, padding=1), |
|
nn.LeakyReLU(inplace=False, negative_slope=0.1), |
|
nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, |
|
stride=1, padding=1), |
|
nn.LeakyReLU(inplace=False, negative_slope=0.1), |
|
nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, |
|
stride=1, padding=1), |
|
nn.LeakyReLU(inplace=False, negative_slope=0.1), |
|
nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, |
|
stride=1, padding=1), |
|
nn.LeakyReLU(inplace=False, negative_slope=0.1)) |
|
|
|
def forward(self, img): |
|
C0 = self.conv_stage0(img) |
|
C1 = self.conv_stage1(C0) |
|
C2 = self.conv_stage2(C1) |
|
return [C0, C1, C2] |
|
|
|
|
|
|
|
|
|
|
|
class MotionEstimator(nn.Module): |
|
"""Bi-directional optical flow estimator |
|
1) construct partial cost volume with the CNN features from the stage 2 of |
|
the feature pyramid; |
|
    2) estimate bi-directional flows by feeding the cost volume, the CNN
       features of both warped images, and the feature and flow estimated
       in the previous iteration.
|
""" |
|
|
|
def __init__(self): |
|
super(MotionEstimator, self).__init__() |
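        # Input to conv_layer1: 81 (9x9 correlation volume) + 128 + 128
        # (warped stage-2 features) + 128 (feature from the previous
        # iteration) + 4 (bi-directional flow from the previous iteration)
        # = 469 channels.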
|
|
|
self.conv_layer1 = nn.Sequential( |
|
nn.Conv2d(in_channels=469, out_channels=320, |
|
kernel_size=1, stride=1, padding=0), |
|
nn.LeakyReLU(inplace=False, negative_slope=0.1)) |
|
self.conv_layer2 = nn.Sequential( |
|
nn.Conv2d(in_channels=320, out_channels=256, |
|
kernel_size=3, stride=1, padding=1), |
|
nn.LeakyReLU(inplace=False, negative_slope=0.1)) |
|
self.conv_layer3 = nn.Sequential( |
|
nn.Conv2d(in_channels=256, out_channels=224, |
|
kernel_size=3, stride=1, padding=1), |
|
nn.LeakyReLU(inplace=False, negative_slope=0.1)) |
|
self.conv_layer4 = nn.Sequential( |
|
nn.Conv2d(in_channels=224, out_channels=192, |
|
kernel_size=3, stride=1, padding=1), |
|
nn.LeakyReLU(inplace=False, negative_slope=0.1)) |
|
self.conv_layer5 = nn.Sequential( |
|
nn.Conv2d(in_channels=192, out_channels=128, |
|
kernel_size=3, stride=1, padding=1), |
|
nn.LeakyReLU(inplace=False, negative_slope=0.1)) |
|
self.conv_layer6 = nn.Sequential( |
|
nn.Conv2d(in_channels=128, out_channels=4, |
|
kernel_size=3, stride=1, padding=1)) |
|
|
|
def forward(self, feat0, feat1, last_feat, last_flow): |
|
corr_fn = correlation.FunctionCorrelation |
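        # Forward-splat each stage-2 feature map halfway along the previous
        # flow estimate before building the cost volume; the 0.25 factor
        # appears to rescale flow magnitudes to the stride-4 feature grid.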
|
        feat0 = softsplat.FunctionSoftsplat(
            tenInput=feat0, tenFlow=last_flow[:, :2] * 0.5 * 0.25,
            tenMetric=None, strType='average')
        feat1 = softsplat.FunctionSoftsplat(
            tenInput=feat1, tenFlow=last_flow[:, 2:] * 0.5 * 0.25,
            tenMetric=None, strType='average')
|
|
|
volume = F.leaky_relu( |
|
input=corr_fn(tenFirst=feat0, tenSecond=feat1), |
|
negative_slope=0.1, inplace=False) |
|
input_feat = torch.cat([volume, feat0, feat1, last_feat, last_flow], 1) |
|
feat = self.conv_layer1(input_feat) |
|
feat = self.conv_layer2(feat) |
|
feat = self.conv_layer3(feat) |
|
feat = self.conv_layer4(feat) |
|
feat = self.conv_layer5(feat) |
|
flow = self.conv_layer6(feat) |
|
|
|
return flow, feat |
|
|
|
|
|
|
|
|
|
|
|
class SynthesisNetwork(nn.Module): |
|
def __init__(self, splat_mode='average', branch=1): |
|
super(SynthesisNetwork, self).__init__() |
|
input_channels = 9 + 4 + 6 |
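        # 9: last interpolation + the two warped frames (3 channels each);
        # 4: bi-directional flow; 6: the two input frames.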
|
self.encoder_conv = nn.Sequential( |
|
nn.Conv2d(in_channels=input_channels, out_channels=64, |
|
kernel_size=3, stride=1, padding=1), |
|
nn.PReLU(num_parameters=64), |
|
nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, |
|
stride=1, padding=1), |
|
nn.PReLU(num_parameters=64)) |
|
self.encoder_down1 = nn.Sequential( |
|
nn.Conv2d(in_channels=64 + 32 + 32, out_channels=128, |
|
kernel_size=3, stride=2, padding=1), |
|
nn.PReLU(num_parameters=128), |
|
nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, |
|
stride=1, padding=1), |
|
nn.PReLU(num_parameters=128), |
|
nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, |
|
stride=1, padding=1), |
|
nn.PReLU(num_parameters=128)) |
|
self.encoder_down2 = nn.Sequential( |
|
nn.Conv2d(in_channels=128 + 64 + 64, out_channels=256, |
|
kernel_size=3, stride=2, padding=1), |
|
nn.PReLU(num_parameters=256), |
|
nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, |
|
stride=1, padding=1), |
|
nn.PReLU(num_parameters=256), |
|
nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, |
|
stride=1, padding=1), |
|
nn.PReLU(num_parameters=256)) |
|
self.decoder_up1 = nn.Sequential( |
|
torch.nn.ConvTranspose2d(in_channels=256 + 128 + 128, |
|
out_channels=128, kernel_size=4, stride=2, |
|
padding=1, bias=True), |
|
nn.PReLU(num_parameters=128), |
|
nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, |
|
stride=1, padding=1), |
|
nn.PReLU(num_parameters=128)) |
|
self.decoder_up2 = nn.Sequential( |
|
torch.nn.ConvTranspose2d(in_channels=128 + 128, |
|
out_channels=64, kernel_size=4, stride=2, |
|
padding=1, bias=True), |
|
nn.PReLU(num_parameters=64), |
|
nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, |
|
stride=1, padding=1), |
|
nn.PReLU(num_parameters=64)) |
|
self.decoder_conv = nn.Sequential( |
|
nn.Conv2d(in_channels=64 + 64, out_channels=64, kernel_size=3, |
|
stride=1, padding=1), |
|
nn.PReLU(num_parameters=64), |
|
nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, |
|
stride=1, padding=1), |
|
nn.PReLU(num_parameters=64)) |
|
self.pred = nn.Conv2d(in_channels=64, out_channels=5, kernel_size=3, |
|
stride=1, padding=1) |
|
self.splat_mode = splat_mode |
|
self.branch = branch |
|
|
|
        # Multi-branch (many-to-many) variant: a small conv block predicts
        # per-branch flow residuals and per-branch weights.
        if self.branch > 1:
            self.convblock = nn.Sequential(
                nn.Conv2d(32 * 2 + 2, 32, 3, 1, 1),
                nn.ReLU(),
                ResBlock(32, 16),
                nn.Conv2d(32, 3 * self.branch, 3, 1, 1, bias=False)
            )
        # Learned scalars that weight the splatting-confidence terms.
        if self.splat_mode == 'softmax' or branch > 1:
            self.alpha = torch.nn.Parameter(torch.ones(1, 1, 1, 1))
            self.alpha_splat_photo_consistency = torch.nn.Parameter(torch.ones(1, 1, 1, 1))
            self.alpha_splat_flow_consistency = torch.nn.Parameter(torch.ones(1, 1, 1, 1))
            self.alpha_splat_variation_flow = torch.nn.Parameter(torch.ones(1, 1, 1, 1))
|
|
|
def get_splat_weight(self, img0, img1, flow01, flow10): |
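        """Splatting-confidence metric: the sum of three inverse consistency
        errors (photometric, forward-backward flow, local flow variance),
        scaled by the learned alpha; None when plain averaging is used."""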
|
if self.splat_mode == 'softmax' or self.branch > 1: |
|
M_splat = 1 / (1 + self.alpha_splat_photo_consistency * photometric_consistency(img0, img1, flow01).detach()) + \ |
|
1 / (1 + self.alpha_splat_flow_consistency * flow_consistency(flow01, flow10).detach()) + \ |
|
1 / (1 + self.alpha_splat_variation_flow * variance_flow(flow01).detach()) |
|
return M_splat * self.alpha |
|
else: |
|
return None |
|
|
|
def get_warped_representations(self, bi_flow, c0, c1, m_splat_0, m_splat_1, i0=None, i1=None, time_period=0.5): |
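        """Forward-splat pyramid features (and optionally the two input
        frames) from times 0 and 1 to the intermediate time ``time_period``."""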
|
flow_0t = bi_flow[:, :2] * time_period |
|
flow_1t = bi_flow[:, 2:4] * (1 - time_period) |
|
warped_c0 = softsplat.FunctionSoftsplat( |
|
tenInput=c0, tenFlow=flow_0t, |
|
tenMetric=None, strType='average') |
|
warped_c1 = softsplat.FunctionSoftsplat( |
|
tenInput=c1, tenFlow=flow_1t, |
|
tenMetric=None, strType='average') |
|
if (i0 is None) and (i1 is None): |
|
return warped_c0, warped_c1 |
|
else: |
|
warped_img0 = softsplat.FunctionSoftsplat( |
|
tenInput=i0, tenFlow=flow_0t, |
|
tenMetric=m_splat_0, strType=self.splat_mode) |
|
warped_img1 = softsplat.FunctionSoftsplat( |
|
tenInput=i1, tenFlow=flow_1t, |
|
tenMetric=m_splat_1, strType=self.splat_mode) |
|
flow_0t_1t = torch.cat((flow_0t, flow_1t), 1) |
|
return warped_img0, warped_img1, warped_c0, warped_c1, flow_0t_1t |
|
|
|
def forward(self, last_i, i0, i1, c0_pyr, c1_pyr, bi_flow_pyr, time_period=0.5, multi_flow=False): |
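        """Synthesize the intermediate frame at this scale.

        With multi_flow (branch > 1), per-branch flow residuals and weights
        are predicted by self.convblock and the frames are fused via
        forwarp_mframe_mask; otherwise the frames are forward-splatted once.
        """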
|
m_splat_0_0 = self.get_splat_weight(i0, i1, bi_flow_pyr[0][:, :2], bi_flow_pyr[0][:, 2:4]) |
|
m_splat_1_0 = self.get_splat_weight(i1, i0, bi_flow_pyr[0][:, 2:4], bi_flow_pyr[0][:, :2]) |
|
if multi_flow: |
|
tenFwd = bi_flow_pyr[0][:, :2] |
|
tenBwd = bi_flow_pyr[0][:, 2:4] |
|
|
|
c0_warp = backwarp(c0_pyr[0], tenBwd) |
|
c1_warp = backwarp(c1_pyr[0], tenFwd) |
|
out0 = self.convblock(torch.cat([c0_pyr[0], c1_warp, tenFwd], 1)) |
|
out1 = self.convblock(torch.cat([c1_pyr[0], c0_warp, tenBwd], 1)) |
|
delta_flow_fwd, WeiMF = torch.split(out0, [2 * self.branch, self.branch], 1) |
|
delta_flow_bwd, WeiMB = torch.split(out1, [2 * self.branch, self.branch], 1) |
|
|
|
tenFwd = delta_flow_fwd + tenFwd.repeat(1, self.branch, 1, 1) |
|
tenBwd = delta_flow_bwd + tenBwd.repeat(1, self.branch, 1, 1) |
|
N_, _, H_, W_ = i0.shape |
|
|
|
i0_ = i0.repeat(1, self.branch, 1, 1) |
|
i1_ = i1.repeat(1, self.branch, 1, 1) |
|
|
|
fltTime = time_period.repeat(1, self.branch, 1, 1) |
|
|
|
            # Fold the branch dimension into the batch dimension so that each
            # branch is splatted independently below.
            tenFwd = tenFwd.reshape(N_ * self.branch, 2, H_, W_)
            tenBwd = tenBwd.reshape(N_ * self.branch, 2, H_, W_)

            WeiMF = WeiMF.reshape(N_ * self.branch, 1, H_, W_)
            WeiMB = WeiMB.reshape(N_ * self.branch, 1, H_, W_)

            i0_ = i0_.reshape(N_ * self.branch, 3, H_, W_)
            i1_ = i1_.reshape(N_ * self.branch, 3, H_, W_)

            fltTime = fltTime.reshape(N_ * self.branch, 1, 1, 1)
|
|
|
tenPhotoone = self.get_splat_weight(i0_, i1_, tenFwd, tenBwd) * WeiMF |
|
tenPhototwo = self.get_splat_weight(i1_, i0_, tenBwd, tenFwd) * WeiMB |
|
|
|
t0 = fltTime |
|
flow0 = tenFwd * t0 |
|
metric0 = tenPhotoone |
|
|
|
t1 = 1.0 - fltTime |
|
flow1 = tenBwd * t1 |
|
metric1 = tenPhototwo |
|
|
|
flow0 = flow0.reshape(N_, self.branch, 2, H_, W_).permute(1, 0, 2, 3, 4) |
|
flow1 = flow1.reshape(N_, self.branch, 2, H_, W_).permute(1, 0, 2, 3, 4) |
|
|
|
metric0 = metric0.reshape(N_, self.branch, 1, H_, W_).permute(1, 0, 2, 3, 4) |
|
metric1 = metric1.reshape(N_, self.branch, 1, H_, W_).permute(1, 0, 2, 3, 4) |
|
|
|
i0_ = i0_.reshape(N_, self.branch, 3, H_, W_).permute(1, 0, 2, 3, 4) |
|
i1_ = i1_.reshape(N_, self.branch, 3, H_, W_).permute(1, 0, 2, 3, 4) |
|
|
|
t0 = t0.reshape(N_, self.branch, 1, 1, 1).permute(1, 0, 2, 3, 4) |
|
t1 = t1.reshape(N_, self.branch, 1, 1, 1).permute(1, 0, 2, 3, 4) |
|
flow0, flow1 = flow0.contiguous(), flow1.contiguous() |
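            # Many-to-many fusion: forward-warp every branch of both frames
            # and merge them with hole masks (forwarp_mframe_mask from m2m).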
|
|
|
tenOutputF, maskF, tenOutputB, maskB = forwarp_mframe_mask(i0_, flow0, t0, i1_, flow1, t1, metric0, metric1) |
|
|
|
warped_img0 = tenOutputF + maskF * i0 |
|
warped_img1 = tenOutputB + maskB * i1 |
|
warped_c0, warped_c1 = \ |
|
self.get_warped_representations( |
|
bi_flow_pyr[0], c0_pyr[0], c1_pyr[0], m_splat_0_0, m_splat_1_0, |
|
time_period=time_period) |
|
flow_0t = bi_flow_pyr[0][:, :2] * time_period |
|
flow_1t = bi_flow_pyr[0][:, 2:4] * (1 - time_period) |
|
flow_0t_1t = torch.cat((flow_0t, flow_1t), 1) |
|
else: |
|
warped_img0, warped_img1, warped_c0, warped_c1, flow_0t_1t = \ |
|
self.get_warped_representations( |
|
bi_flow_pyr[0], c0_pyr[0], c1_pyr[0], m_splat_0_0, m_splat_1_0, i0, i1, |
|
time_period=time_period) |
|
input_feat = torch.cat( |
|
(last_i, warped_img0, warped_img1, i0, i1, flow_0t_1t), 1) |
|
s0 = self.encoder_conv(input_feat) |
|
s1 = self.encoder_down1(torch.cat((s0, warped_c0, warped_c1), 1)) |
|
warped_c0, warped_c1 = self.get_warped_representations( |
|
bi_flow_pyr[1], c0_pyr[1], c1_pyr[1], None, None, |
|
time_period=time_period) |
|
s2 = self.encoder_down2(torch.cat((s1, warped_c0, warped_c1), 1)) |
|
warped_c0, warped_c1 = self.get_warped_representations( |
|
bi_flow_pyr[2], c0_pyr[2], c1_pyr[2], None, None, |
|
time_period=time_period) |
|
|
|
x = self.decoder_up1(torch.cat((s2, warped_c0, warped_c1), 1)) |
|
x = self.decoder_up2(torch.cat((x, s1), 1)) |
|
x = self.decoder_conv(torch.cat((x, s0), 1)) |
|
|
|
|
|
refine = self.pred(x) |
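        # The prediction head outputs a 3-channel image residual plus two
        # blending masks; the warped frames are fused with time-weighted
        # masks, then the residual is added and the result clamped to [0, 1].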
|
refine_res = torch.sigmoid(refine[:, :3]) * 2 - 1 |
|
refine_mask0 = torch.sigmoid(refine[:, 3:4]) |
|
refine_mask1 = torch.sigmoid(refine[:, 4:5]) |
|
merged_img = (warped_img0 * refine_mask0 * (1 - time_period) + \ |
|
warped_img1 * refine_mask1 * time_period) |
|
merged_img = merged_img / (refine_mask0 * (1 - time_period) + \ |
|
refine_mask1 * time_period) |
|
interp_img = merged_img + refine_res |
|
interp_img = torch.clamp(interp_img, 0, 1) |
|
|
|
extra_dict = {} |
|
extra_dict["refine_res"] = refine_res |
|
extra_dict["warped_img0"] = warped_img0 |
|
extra_dict["warped_img1"] = warped_img1 |
|
extra_dict["merged_img"] = merged_img |
|
if multi_flow: |
|
extra_dict['tenFwd'] = tenFwd.view(N_, self.branch, 2, H_, W_) |
|
extra_dict['tenBwd'] = tenBwd.view(N_, self.branch, 2, H_, W_) |
|
|
|
return interp_img, extra_dict |
|
|
|
|
class MotionRefineNet(torch.nn.Module):
    """Refines bi-directional flows into per-branch flows, with residuals
    predicted by an encoder-decoder over an image pyramid (ImgPyramid and
    EncDec come from the m2m star-import)."""

    def __init__(self, branch):
        super(MotionRefineNet, self).__init__()
        self.branch = branch
        self.img_pyramid = ImgPyramid()
        self.motion_encdec = EncDec(branch)

    def forward(self, flow0, flow1, im0, im1):
        c0 = self.img_pyramid(im0)
        c1 = self.img_pyramid(im1)

        flow_res = self.motion_encdec(flow0, flow1, im0, im1, c0, c1)

        flow0 = flow0.repeat(1, self.branch, 1, 1) + flow_res[0]
        flow1 = flow1.repeat(1, self.branch, 1, 1) + flow_res[1]

        return flow0, flow1, flow_res[2], flow_res[3]

|
|
|
|
|
|
|
@register('upr_net') |
|
class Model(nn.Module): |
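    """UPR-Net style interpolation model: a shared feature pyramid, a
    recurrent bi-directional motion estimator, and a synthesis network.
    branch > 1 switches the finest level to the multi-flow variant."""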
|
def __init__(self, pyr_level=3, nr_lvl_skipped=0, splat_mode='average', branch=1): |
|
super(Model, self).__init__() |
|
self.pyr_level = pyr_level |
|
self.feat_pyramid = FeatPyramid() |
|
self.nr_lvl_skipped = nr_lvl_skipped |
|
self.motion_estimator = MotionEstimator() |
|
self.synthesis_network = SynthesisNetwork(splat_mode, branch) |
|
self.splat_mode = splat_mode |
|
self.branch = branch |
|
|
|
def forward_one_lvl(self, |
|
img0, img1, last_feat, last_flow, last_interp=None, |
|
time_period=0.5, skip_me=False, multi_flow=False): |
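        """Run a single pyramid level: estimate bi-directional flow (reused
        from the coarser level when skip_me is True), then synthesize the
        intermediate frame at this scale."""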
|
|
|
|
|
feat0_pyr = self.feat_pyramid(img0) |
|
feat1_pyr = self.feat_pyramid(img1) |
|
|
|
|
|
if not skip_me: |
|
flow, feat = self.motion_estimator( |
|
feat0_pyr[-1], feat1_pyr[-1], |
|
last_feat, last_flow) |
|
else: |
|
flow = last_flow |
|
feat = last_feat |
|
|
|
|
|
|
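        # The flow is estimated at 1/4 of this level's resolution: upsample
        # it x4, then build a 3-level flow pyramid for the synthesis network,
        # halving both resolution and magnitude at each step.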
|
ori_resolution_flow = F.interpolate( |
|
input=flow, scale_factor=4.0, |
|
mode="bilinear", align_corners=False) |
|
|
|
|
|
bi_flow_pyr = [] |
|
tmp_flow = ori_resolution_flow |
|
bi_flow_pyr.append(tmp_flow) |
|
for i in range(2): |
|
tmp_flow = F.interpolate( |
|
input=tmp_flow, scale_factor=0.5, |
|
mode="bilinear", align_corners=False) * 0.5 |
|
bi_flow_pyr.append(tmp_flow) |
|
|
|
|
|
if last_interp is None: |
|
flow_0t = ori_resolution_flow[:, :2] * time_period |
|
flow_1t = ori_resolution_flow[:, 2:4] * (1 - time_period) |
|
warped_img0 = softsplat.FunctionSoftsplat( |
|
tenInput=img0, tenFlow=flow_0t, |
|
tenMetric=None, strType='average') |
|
warped_img1 = softsplat.FunctionSoftsplat( |
|
tenInput=img1, tenFlow=flow_1t, |
|
tenMetric=None, strType='average') |
|
last_interp = warped_img0 * (1 - time_period) \ |
|
+ warped_img1 * time_period |
|
|
|
|
|
interp_img, extra_dict = self.synthesis_network( |
|
last_interp, img0, img1, feat0_pyr, feat1_pyr, bi_flow_pyr, |
|
time_period=time_period, multi_flow=multi_flow) |
|
return flow, feat, interp_img, extra_dict |
|
|
|
def forward(self, img0, img1, time_step, |
|
pyr_level=None, nr_lvl_skipped=None, **kwargs): |
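        """Coarse-to-fine inference: iterate levels from coarsest to finest,
        feeding each level's flow, feature, and interpolation into the next.
        nr_lvl_skipped > 0 reuses coarse results at the finest level(s)."""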
|
|
|
if pyr_level is None: pyr_level = self.pyr_level |
|
if nr_lvl_skipped is None: nr_lvl_skipped = self.nr_lvl_skipped |
|
N, _, H, W = img0.shape |
|
flow0_pred = [] |
|
flow1_pred = [] |
|
interp_imgs = [] |
|
skipped_levels = [] if nr_lvl_skipped == 0 else \ |
|
list(range(pyr_level))[::-1][-nr_lvl_skipped:] |
|
|
|
padder = InputPadder(img0.shape, divisor=int(4 * 2**pyr_level)) |
|
img0, img1 = padder.pad(img0, img1) |
|
N, _, H, W = img0.shape |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
for level in list(range(pyr_level))[::-1]: |
|
if level != 0: |
|
scale_factor = 1 / 2 ** level |
|
img0_this_lvl = F.interpolate( |
|
input=img0, scale_factor=scale_factor, |
|
mode="bilinear", align_corners=False) |
|
img1_this_lvl = F.interpolate( |
|
input=img1, scale_factor=scale_factor, |
|
mode="bilinear", align_corners=False) |
|
else: |
|
img0_this_lvl = img0 |
|
img1_this_lvl = img1 |
|
|
|
|
|
skip_me = False |
|
|
|
|
|
if level == pyr_level - 1: |
|
last_flow = torch.zeros( |
|
(N, 4, H // (2 ** (level + 2)), W // (2 ** (level + 2))) |
|
).to(img0.device) |
|
last_feat = torch.zeros( |
|
(N, 128, H // (2 ** (level + 2)), W // (2 ** (level + 2))) |
|
).to(img0.device) |
|
last_interp = None |
|
|
|
elif level in skipped_levels[:-1]: |
|
continue |
|
|
|
            elif (level == 0) and len(skipped_levels) > 0:
                if len(skipped_levels) == pyr_level:
                    last_flow = torch.zeros(
                        (N, 4, H // 4, W // 4)).to(img0.device)
                    # Zero-init last_feat too: forward_one_lvl reads it when
                    # skip_me is True, and it is otherwise undefined here.
                    last_feat = torch.zeros(
                        (N, 128, H // 4, W // 4)).to(img0.device)
                    last_interp = None
|
else: |
|
resize_factor = 2 ** len(skipped_levels) |
|
last_flow = F.interpolate( |
|
input=flow, scale_factor=resize_factor, |
|
mode="bilinear", align_corners=False) * resize_factor |
|
last_interp = F.interpolate( |
|
input=interp_img, scale_factor=resize_factor, |
|
mode="bilinear", align_corners=False) |
|
skip_me = True |
|
|
|
|
|
else: |
|
last_flow = F.interpolate(input=flow, scale_factor=2.0, |
|
mode="bilinear", align_corners=False) * 2 |
|
last_feat = F.interpolate(input=feat, scale_factor=2.0, |
|
mode="bilinear", align_corners=False) * 2 |
|
last_interp = F.interpolate( |
|
input=interp_img, scale_factor=2.0, |
|
mode="bilinear", align_corners=False) |
|
|
|
flow, feat, interp_img, extra_dict = self.forward_one_lvl( |
|
img0_this_lvl, img1_this_lvl, |
|
last_feat, last_flow, last_interp, |
|
time_step, skip_me=skip_me, multi_flow=(self.branch > 1 and level == 0)) |
|
if level == 0 and self.branch > 1: |
|
flow0_pred.append(extra_dict['tenFwd']) |
|
flow1_pred.append(extra_dict['tenBwd']) |
|
elif level == 0 and self.branch == 1: |
|
flow0_pred.append( |
|
F.interpolate(input=flow[:, :2], scale_factor=4.0, |
|
mode="bilinear", align_corners=False).unsqueeze(1) * 4 * 0.5) |
|
flow1_pred.append( |
|
F.interpolate(input=flow[:, 2:], scale_factor=4.0, |
|
mode="bilinear", align_corners=False).unsqueeze(1) * 4 * 0.5) |
|
else: |
|
flow0_pred.append( |
|
padder.unpad(F.interpolate(input=flow[:, :2], scale_factor=4.0, |
|
mode="bilinear", align_corners=False) * 4 * 0.5)) |
|
flow1_pred.append( |
|
padder.unpad(F.interpolate(input=flow[:, 2:], scale_factor=4.0, |
|
mode="bilinear", align_corners=False) * 4 * 0.5)) |
|
interp_imgs.append(padder.unpad(F.interpolate(interp_img, scale_factor=2 ** level))) |
|
|
|
|
|
|
|
|
|
interp_img = padder.unpad(interp_img) |
|
|
|
return {"imgt_preds": interp_imgs[-2:], "flow0_pred": flow0_pred[::-1], "flow1_pred": flow1_pred[::-1], |
|
'imgt_pred': interp_img, "flowfwd": flow0_pred[-1][:, 0], "flowbwd": flow1_pred[-1][:, 0]} |
|
|
|
|
|
if __name__ == "__main__": |
|
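    # Minimal smoke test, a sketch under these assumptions: the compiled
    # correlation/softsplat extensions are available (they typically require
    # a CUDA device), and the default hyper-parameters above are used.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = Model(pyr_level=3, nr_lvl_skipped=0).to(device).eval()
    img0 = torch.rand(1, 3, 256, 448, device=device)
    img1 = torch.rand(1, 3, 256, 448, device=device)
    with torch.no_grad():
        out = model(img0, img1, time_step=0.5)
    print(out["imgt_pred"].shape)  # expect torch.Size([1, 3, 256, 448])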
|
|