# CoNR/model/shader.py
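# Shader network of CoNR: a coarse-to-fine decoder (CINN) built from DecoderBlocks.
# At each pyramid level and for each reference view it predicts a residual flow,
# confidence masks and "message" features, warps the reference image/features with
# the accumulated flow, and fuses the per-view results with a softmax over the
# mask logits.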
import torch
import torch.nn as nn
import torch.nn.functional as F
from .warplayer import warp_features
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
class DecoderBlock(nn.Module):
def __init__(self, in_planes, c=224, out_msgs=0, out_locals=0, block_nums=1, out_masks=1, out_local_flows=32, out_msgs_flows=32, out_feat_flows=0):
super(DecoderBlock, self).__init__()
self.conv0 = nn.Sequential(
nn.Conv2d(in_planes, c, 3, 2, 1),
nn.PReLU(c),
nn.Conv2d(c, c, 3, 2, 1),
nn.PReLU(c),
)
self.convblocks = nn.ModuleList()
for i in range(block_nums):
self.convblocks.append(nn.Sequential(
nn.Conv2d(c, c, 3, 1, 1),
nn.PReLU(c),
nn.Conv2d(c, c, 3, 1, 1),
nn.PReLU(c),
nn.Conv2d(c, c, 3, 1, 1),
nn.PReLU(c),
nn.Conv2d(c, c, 3, 1, 1),
nn.PReLU(c),
nn.Conv2d(c, c, 3, 1, 1),
nn.PReLU(c),
nn.Conv2d(c, c, 3, 1, 1),
nn.PReLU(c),
))
self.out_flows = 2
self.out_msgs = out_msgs
self.out_msgs_flows = out_msgs_flows if out_msgs > 0 else 0
self.out_locals = out_locals
self.out_local_flows = out_local_flows if out_locals > 0 else 0
self.out_masks = out_masks
self.out_feat_flows = out_feat_flows
self.conv_last = nn.Sequential(
nn.ConvTranspose2d(c, c, 4, 2, 1),
nn.PReLU(c),
nn.ConvTranspose2d(c, self.out_flows+self.out_msgs+self.out_msgs_flows +
self.out_locals+self.out_local_flows+self.out_masks+self.out_feat_flows, 4, 2, 1),
)
def forward(self, accumulated_flow, *other):
x = [accumulated_flow]
for each in other:
if each is not None:
                assert accumulated_flow.shape[-1] == each.shape[-1], "decoder expected input of shape {}, but got {}".format(
                    accumulated_flow.shape, each.shape)
x.append(each)
feat = self.conv0(torch.cat(x, dim=1))
for convblock1 in self.convblocks:
feat = convblock1(feat) + feat
feat = self.conv_last(feat)
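        # slice conv_last's output channel-wise, in the same order the channel
        # counts were summed up in __init__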
prev = 0
flow = feat[:, prev:prev+self.out_flows, :, :]
prev += self.out_flows
message = feat[:, prev:prev+self.out_msgs,
:, :] if self.out_msgs > 0 else None
prev += self.out_msgs
message_flow = feat[:, prev:prev + self.out_msgs_flows,
:, :] if self.out_msgs_flows > 0 else None
prev += self.out_msgs_flows
local_message = feat[:, prev:prev + self.out_locals,
:, :] if self.out_locals > 0 else None
prev += self.out_locals
local_message_flow = feat[:, prev:prev+self.out_local_flows,
:, :] if self.out_local_flows > 0 else None
prev += self.out_local_flows
mask = torch.sigmoid(
feat[:, prev:prev+self.out_masks, :, :]) if self.out_masks > 0 else None
prev += self.out_masks
feat_flow = feat[:, prev:prev+self.out_feat_flows,
:, :] if self.out_feat_flows > 0 else None
prev += self.out_feat_flows
return flow, mask, message, message_flow, local_message, local_message_flow, feat_flow
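
# A minimal usage sketch showing how DecoderBlock splits its output channels; the
# channel counts below are illustrative only, the real model uses the defaults
# defined in CINN.
def _decoder_block_example():
    # forward() concatenates its inputs along the channel axis, so in_planes must
    # match the 2 flow channels plus the 8 feature channels passed in below
    block = DecoderBlock(in_planes=2 + 8, c=32, out_msgs=4, out_locals=0,
                         block_nums=1, out_masks=1)
    flow = torch.zeros(1, 2, 64, 64)
    feats = torch.randn(1, 8, 64, 64)
    out_flow, mask, msg, msg_flow, loc, loc_flow, feat_flow = block(flow, feats)
    # out_flow: (1, 2, 64, 64); mask: (1, 1, 64, 64), already passed through sigmoid;
    # msg: (1, 4, 64, 64); msg_flow: (1, 32, 64, 64); loc, loc_flow and feat_flow are None
    return out_flow, mask, msg, msg_flow
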
class CINN(nn.Module):
def __init__(self, DIM_SHADER_REFERENCE, target_feature_chns=[512, 256, 128, 64, 64], feature_chns=[2048, 1024, 512, 256, 64], out_msgs_chn=[2048, 1024, 512, 256, 64, 64], out_locals_chn=[2048, 1024, 512, 256, 64, 0], block_num=[1, 1, 1, 1, 1, 2], block_chn_num=[224, 224, 224, 224, 224, 224]):
super(CINN, self).__init__()
self.in_msgs_chn = [0, *out_msgs_chn[:-1]]
self.in_locals_chn = [0, *out_locals_chn[:-1]]
self.decoder_blocks = nn.ModuleList()
self.feed_weighted = True
if self.feed_weighted:
in_planes = 2+2+DIM_SHADER_REFERENCE*2
else:
in_planes = 2+DIM_SHADER_REFERENCE
for each_target_feature_chns, each_feature_chns, each_out_msgs_chn, each_out_locals_chn, each_in_msgs_chn, each_in_locals_chn, each_block_num, each_block_chn_num in zip(target_feature_chns, feature_chns, out_msgs_chn, out_locals_chn, self.in_msgs_chn, self.in_locals_chn, block_num, block_chn_num):
self.decoder_blocks.append(
DecoderBlock(in_planes+each_target_feature_chns+each_feature_chns+each_in_locals_chn+each_in_msgs_chn, c=each_block_chn_num, block_nums=each_block_num, out_msgs=each_out_msgs_chn, out_locals=each_out_locals_chn, out_masks=2+each_out_locals_chn))
for i in range(len(feature_chns), len(out_locals_chn)):
#print("append extra block", i, "msg",
# out_msgs_chn[i], "local", out_locals_chn[i], "block", block_num[i])
self.decoder_blocks.append(
DecoderBlock(in_planes+self.in_msgs_chn[i]+self.in_locals_chn[i], c=block_chn_num[i], block_nums=block_num[i], out_msgs=out_msgs_chn[i], out_locals=out_locals_chn[i], out_masks=2+out_msgs_chn[i], out_feat_flows=0))
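    # Rescale the accumulated flow to the resolution of the next feature level,
    # resize the mask/message tensors to that resolution, and warp the reference
    # image, messages and features with their respective flows.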
def apply_flow(self, mask, message, message_flow, local_message, local_message_flow, x_reference, accumulated_flow, each_x_reference_features=None, each_x_reference_features_flow=None):
if each_x_reference_features is not None:
size_from = each_x_reference_features
else:
size_from = x_reference
f_size = (size_from.shape[2], size_from.shape[3])
accumulated_flow = self.flow_rescale(
accumulated_flow, size_from)
# mask = warp_features(F.interpolate(
# mask, size=f_size, mode="bilinear"), accumulated_flow) if mask is not None else None
mask = F.interpolate(
mask, size=f_size, mode="bilinear") if mask is not None else None
message = F.interpolate(
message, size=f_size, mode="bilinear") if message is not None else None
message_flow = self.flow_rescale(
message_flow, size_from) if message_flow is not None else None
message = warp_features(
message, message_flow) if message_flow is not None else message
local_message = F.interpolate(
local_message, size=f_size, mode="bilinear") if local_message is not None else None
local_message_flow = self.flow_rescale(
local_message_flow, size_from) if local_message_flow is not None else None
local_message = warp_features(
local_message, local_message_flow) if local_message_flow is not None else local_message
warp_x_reference = warp_features(F.interpolate(
x_reference, size=f_size, mode="bilinear"), accumulated_flow)
each_x_reference_features_flow = self.flow_rescale(
each_x_reference_features_flow, size_from) if (each_x_reference_features is not None and each_x_reference_features_flow is not None) else None
warp_each_x_reference_features = warp_features(
each_x_reference_features, each_x_reference_features_flow) if each_x_reference_features_flow is not None else each_x_reference_features
return mask, message, local_message, warp_x_reference, accumulated_flow, warp_each_x_reference_features, each_x_reference_features_flow
def forward(self, x_target_features=[], x_reference=None, x_reference_features=[]):
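        # x_reference and every entry of x_reference_features are NMCHW tensors
        # (N batch, M reference views, C channels); entries of x_target_features are
        # NCHW. One decoder block is run per pyramid level, looping over the M views.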
y_flow = []
y_feat_flow = []
y_local_message = []
y_warp_x_reference = []
y_warp_x_reference_features = []
y_weighted_flow = []
y_weighted_mask = []
y_weighted_message = []
y_weighted_x_reference = []
y_weighted_x_reference_features = []
for pyrlevel, ifblock in enumerate(self.decoder_blocks):
stacked_wref = []
stacked_feat = []
stacked_anci = []
stacked_flow = []
stacked_mask = []
stacked_mesg = []
stacked_locm = []
stacked_feat_flow = []
for view_id in range(x_reference.shape[1]): # NMCHW
if pyrlevel == 0:
# create from zero flow
feat_ev = x_reference_features[pyrlevel][:,
view_id, :, :, :] if pyrlevel < len(x_reference_features) else None
accumulated_flow = torch.zeros_like(
feat_ev[:, :2, :, :]).to(device)
accumulated_feat_flow = torch.zeros_like(
feat_ev[:, :32, :, :]).to(device)
# domestic inputs
warp_x_reference = F.interpolate(x_reference[:, view_id, :, :, :], size=(
feat_ev.shape[-2], feat_ev.shape[-1]), mode="bilinear")
warp_x_reference_features = feat_ev
local_message = None
# federated inputs
weighted_flow = accumulated_flow if self.feed_weighted else None
weighted_wref = warp_x_reference if self.feed_weighted else None
weighted_message = None
else:
# resume from last layer
accumulated_flow = y_flow[-1][:, view_id, :, :, :]
accumulated_feat_flow = y_feat_flow[-1][:,
view_id, :, :, :] if y_feat_flow[-1] is not None else None
# domestic inputs
warp_x_reference = y_warp_x_reference[-1][:,
view_id, :, :, :]
warp_x_reference_features = y_warp_x_reference_features[-1][:,
view_id, :, :, :] if y_warp_x_reference_features[-1] is not None else None
local_message = y_local_message[-1][:, view_id, :,
:, :] if len(y_local_message) > 0 else None
# federated inputs
weighted_flow = y_weighted_flow[-1] if self.feed_weighted else None
weighted_wref = y_weighted_x_reference[-1] if self.feed_weighted else None
weighted_message = y_weighted_message[-1] if len(
y_weighted_message) > 0 else None
scaled_x_target = x_target_features[pyrlevel][:, :, :, :].detach() if pyrlevel < len(
x_target_features) else None
# compute flow
residual_flow, mask, message, message_flow, local_message, local_message_flow, residual_feat_flow = ifblock(
accumulated_flow, scaled_x_target, warp_x_reference, warp_x_reference_features, weighted_flow, weighted_wref, weighted_message, local_message)
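                # accumulate the predicted residual flow; the feature flow reuses the
                # accumulated pixel flow, since all blocks are built with
                # out_feat_flows=0 and residual_feat_flow is therefore None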
accumulated_flow = residual_flow + accumulated_flow
accumulated_feat_flow = accumulated_flow
feat_ev = x_reference_features[pyrlevel+1][:,
view_id, :, :, :] if pyrlevel+1 < len(x_reference_features) else None
mask, message, local_message, warp_x_reference, accumulated_flow, warp_x_reference_features, accumulated_feat_flow = self.apply_flow(
mask, message, message_flow, local_message, local_message_flow, x_reference[:, view_id, :, :, :], accumulated_flow, feat_ev, accumulated_feat_flow)
stacked_flow.append(accumulated_flow)
if accumulated_feat_flow is not None:
stacked_feat_flow.append(accumulated_feat_flow)
stacked_mask.append(mask)
if message is not None:
stacked_mesg.append(message)
if local_message is not None:
stacked_locm.append(local_message)
stacked_wref.append(warp_x_reference)
if warp_x_reference_features is not None:
stacked_feat.append(warp_x_reference_features)
stacked_flow = torch.stack(stacked_flow, dim=1) # M*NCHW -> NMCHW
stacked_feat_flow = torch.stack(stacked_feat_flow, dim=1) if len(
stacked_feat_flow) > 0 else None
stacked_mask = torch.stack(
stacked_mask, dim=1)
stacked_mesg = torch.stack(stacked_mesg, dim=1) if len(
stacked_mesg) > 0 else None
stacked_locm = torch.stack(stacked_locm, dim=1) if len(
stacked_locm) > 0 else None
stacked_wref = torch.stack(stacked_wref, dim=1)
stacked_feat = torch.stack(stacked_feat, dim=1) if len(
stacked_feat) > 0 else None
stacked_anci = torch.stack(stacked_anci, dim=1) if len(
stacked_anci) > 0 else None
y_flow.append(stacked_flow)
y_feat_flow.append(stacked_feat_flow)
y_warp_x_reference.append(stacked_wref)
y_warp_x_reference_features.append(stacked_feat)
# compute normalized confidence
            stacked_contrib = F.softmax(stacked_mask, dim=1)
# torch.sum to remove temp dimension M from NMCHW --> NCHW
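            # mask channel 0 gates the flow and warped reference, channel 1 gates the
            # warped features, and the remaining channels gate the messages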
weighted_flow = torch.sum(
stacked_mask[:, :, 0:1, :, :] * stacked_contrib[:, :, 0:1, :, :] * stacked_flow, dim=1)
weighted_mask = torch.sum(
stacked_contrib[:, :, 0:1, :, :] * stacked_mask[:, :, 0:1, :, :], dim=1)
weighted_wref = torch.sum(
stacked_mask[:, :, 0:1, :, :] * stacked_contrib[:, :, 0:1, :, :] * stacked_wref, dim=1) if stacked_wref is not None else None
weighted_feat = torch.sum(
stacked_mask[:, :, 1:2, :, :] * stacked_contrib[:, :, 1:2, :, :] * stacked_feat, dim=1) if stacked_feat is not None else None
weighted_mesg = torch.sum(
stacked_mask[:, :, 2:, :, :] * stacked_contrib[:, :, 2:, :, :] * stacked_mesg, dim=1) if stacked_mesg is not None else None
            y_weighted_flow.append(weighted_flow)
            y_weighted_mask.append(weighted_mask)
            if weighted_mesg is not None:
                y_weighted_message.append(weighted_mesg)
            if stacked_locm is not None:
                y_local_message.append(stacked_locm)
            y_weighted_x_reference.append(weighted_wref)
            if weighted_feat is not None:
                y_weighted_x_reference_features.append(weighted_feat)
return {
"y_last_remote_features": [weighted_mesg],
}
def flow_rescale(self, prev_flow, each_x_reference_features):
if prev_flow is None:
prev_flow = torch.zeros_like(
each_x_reference_features[:, :2]).to(device)
else:
up_scale_factor = each_x_reference_features.shape[-1] / \
prev_flow.shape[-1]
if up_scale_factor != 1:
prev_flow = F.interpolate(prev_flow, scale_factor=up_scale_factor, mode="bilinear",
align_corners=False, recompute_scale_factor=False) * up_scale_factor
return prev_flow