# Source: FFV1MT_MS.py, commit 5c8ad42 ("Update FFV1MT_MS.py"), uploaded by sunana.
import numpy as np
import torch
from torch import nn
from torch.nn import functional as F
from MT import FeatureTransformer
from torch.cuda.amp import autocast as autocast
from flow_tools import viz_img_seq, save_img_seq, plt_show_img_flow
from copy import deepcopy
from V1 import V1
import matplotlib.pyplot as plt
from io import BytesIO
from PIL import Image
def conv(in_planes, out_planes, kernel_size=3, stride=1, dilation=1, isReLU=True):
    """Return an ``nn.Sequential`` holding a biased Conv2d, optionally followed by GELU.

    Padding is ``((kernel_size - 1) * dilation) // 2``, which preserves the
    spatial size for odd kernel sizes at stride 1.

    Note: despite its name, ``isReLU=True`` appends ``nn.GELU()``, not ReLU.
    With ``isReLU=False`` the container holds only the convolution.
    """
    pad = ((kernel_size - 1) * dilation) // 2
    layers = [
        nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride,
                  dilation=dilation, padding=pad, bias=True),
    ]
    if isReLU:
        layers.append(nn.GELU())
    return nn.Sequential(*layers)
def plt_attention(attention, h, w):
    """Plot each iteration's attention map for one query location; return the PNG as an array.

    Args:
        attention: sequence of attention tensors, one per iteration. Each is
            indexed as ``attn[0, :, :, h, w]``, so it must have at least 5 dims.
            Presumably the layout is [B, H, W, H, W] (TODO confirm against MT).
        h, w: query row / column in attention-map coordinates. They are also
            where the marker dots are drawn.

    Returns:
        ``np.ndarray`` of the rendered PNG, as decoded by PIL.
    """
    n = len(attention)
    # Two-row grid. Round up so an odd number of maps (or a single map) still
    # fits; ``n // 2`` made ``add_subplot`` index past the grid.
    cols = max(1, (n + 1) // 2)
    fig = plt.figure(figsize=(10, 8))
    try:
        for i, attn in enumerate(attention):
            viz = attn[0, :, :, h, w].detach().cpu().numpy()
            ax = fig.add_subplot(2, cols, i + 1)
            ax.imshow(viz, cmap="rainbow", interpolation="bilinear")
            # Two concentric translucent dots mark the query location.
            ax.scatter(w, h, color='grey', s=300, alpha=0.5)
            ax.scatter(w, h, color='red', s=150, alpha=0.5)
            ax.set_title(" Final Iteration" if i == n - 1 else " Iteration %d" % (i + 1))
            ax.set_xticks([])
            ax.set_yticks([])
        fig.tight_layout()
        buf = BytesIO()
        fig.savefig(buf, format='png')
    finally:
        # Always release the figure, even if plotting fails.
        plt.close(fig)
    buf.seek(0)
    return np.array(Image.open(buf))
class FlowDecoder(nn.Module):
    """Stack of 1x1 convs that regresses a 2-channel flow map from a feature map.

    The first two stages run in sequence. Every later stage takes the
    channel-concatenation of the two most recent stage outputs. A final 3x3
    conv with no activation predicts the flow. The original author notes this
    design "can reduce 25% of training time".
    """

    def __init__(self, ch_in):
        super().__init__()
        # Layers are created in the original order, so parameter initialisation
        # and state_dict keys are unchanged.
        self.conv1 = conv(ch_in, 256, kernel_size=1)
        self.conv2 = conv(256, 128, kernel_size=1)
        self.conv3 = conv(256 + 128, 96, kernel_size=1)
        self.conv4 = conv(96 + 128, 64, kernel_size=1)
        self.conv5 = conv(96 + 64, 32, kernel_size=1)
        self.feat_dim = 32  # output width of the last hidden stage (conv5)
        self.predict_flow = conv(64 + 32, 2, isReLU=False)

    def forward(self, x):
        # Keep a sliding pair (older, newer); each stage consumes both and
        # becomes the new "newer".
        older = self.conv1(x)
        newer = self.conv2(older)
        for stage in (self.conv3, self.conv4, self.conv5):
            older, newer = newer, stage(torch.cat([older, newer], dim=1))
        return self.predict_flow(torch.cat([older, newer], dim=1))
class FFV1DNN(nn.Module):
    """Optical-flow model: V1 front end -> MT transformer -> shared flow decoder.

    ``self.ffv1`` turns a list of frames into spatio-temporal features.
    ``self.MT`` refines them layer by layer. ``self.decoder`` maps every
    feature level (the raw V1 features plus each MT output) to a 2-channel flow
    field, which is then upsampled 8x.
    """

    def __init__(self,
                 num_scales=8,
                 num_cells=256,
                 upsample_factor=8,
                 feature_channels=256,
                 scale_factor=16,
                 num_layers=6,
                 ):
        super(FFV1DNN, self).__init__()
        self.ffv1 = V1(spatial_num=num_cells // num_scales, scale_num=num_scales, scale_factor=scale_factor,
                       kernel_radius=7, num_ft=num_cells // num_scales,
                       kernel_size=6, average_time=True)
        self.v1_kz = 7
        self.scale_factor = scale_factor
        # Per-level factor f = (1 / scale_factor) ** (1 / (num_scales - 1)), so
        # num_scales - 1 steps span 1 -> 1 / scale_factor.
        scale_each_level = np.exp(1 / (num_scales - 1) * np.log(1 / scale_factor))
        self.scale_num = num_scales
        self.scale_each_level = scale_each_level
        v1_channel = self.ffv1.num_after_st
        self.num_scales = num_scales
        self.MT_channel = feature_channels
        assert self.MT_channel == v1_channel, (
            "feature_channels (%d) must equal the V1 output channels (%d)" % (self.MT_channel, v1_channel))
        self.feature_channels = feature_channels
        self.upsample_factor = upsample_factor
        self.num_layers = num_layers
        # Convex upsampling: concat(flow, feature) -> 9-neighbour mask logits
        # for each of the upsample_factor**2 sub-pixels (see upsample_flow).
        self.upsampler_1 = nn.Sequential(nn.Conv2d(2 + feature_channels, 256, 3, 1, 1),
                                         nn.ReLU(inplace=True),
                                         nn.Conv2d(256, 256, 3, 1, 1),
                                         nn.ReLU(inplace=True),
                                         nn.Conv2d(256, upsample_factor ** 2 * 9, 3, 1, 1))
        self.decoder = FlowDecoder(feature_channels)
        # Not referenced by any forward pass in this file; kept so existing
        # state_dicts keep their keys.
        self.conv_feat = nn.ModuleList([conv(v1_channel, feature_channels, 1) for _ in range(num_scales)])
        self.MT = FeatureTransformer(d_model=feature_channels, num_layers=self.num_layers)

    def upsample_flow(self, flow, feature, upsampler=None, bilinear=False, upsample_factor=4):
        """Upsample ``flow`` by ``upsample_factor``, scaling the flow values by the same factor.

        Args:
            flow: [B, C, H, W] flow tensor (C == 2 for optical flow).
            feature: [B, F, H, W] guidance features. Used only for convex mode.
            upsampler: module mapping ``C + F`` channels to
                ``9 * upsample_factor**2`` mask logits. Used only for convex mode.
            bilinear: if True, use bilinear interpolation (align_corners=True);
                otherwise use learned convex upsampling.
            upsample_factor: integer upsampling factor K.

        Returns:
            [B, C, K*H, K*W] upsampled flow.
        """
        if bilinear:
            return F.interpolate(flow, scale_factor=upsample_factor,
                                 mode='bilinear', align_corners=True) * upsample_factor
        # Convex upsampling: each fine pixel is a softmax-weighted combination
        # of its coarse pixel's 3x3 neighbourhood (zero-padded at the border).
        b, flow_channel, h, w = flow.shape
        mask = upsampler(torch.cat((flow, feature), dim=1))
        mask = mask.view(b, 1, 9, upsample_factor, upsample_factor, h, w)  # [B, 1, 9, K, K, H, W]
        mask = torch.softmax(mask, dim=2)
        up_flow = F.unfold(upsample_factor * flow, [3, 3], padding=1)
        up_flow = up_flow.view(b, flow_channel, 9, 1, 1, h, w)  # [B, C, 9, 1, 1, H, W]
        up_flow = torch.sum(mask * up_flow, dim=2)  # [B, C, K, K, H, W]
        up_flow = up_flow.permute(0, 1, 4, 2, 5, 3)  # [B, C, H, K, W, K]
        return up_flow.reshape(b, flow_channel, upsample_factor * h,
                               upsample_factor * w)  # [B, C, K*H, K*W]

    @staticmethod
    def _prepare_images(image_list, to_gray=True):
        """Scale frames to [0, 1] when they look like 0-255, and optionally convert RGB to gray."""
        with torch.no_grad():
            # Heuristic taken from the first frame only.
            if image_list[0].max() > 10:
                image_list = [img / 255.0 for img in image_list]  # [B, C, H, W] 0-1
            if to_gray and image_list[0].shape[1] == 3:
                # Gray = R*0.299 + G*0.587 + B*0.114, keeping a singleton channel dim.
                image_list = [
                    (img[:, 0, :, :] * 0.299 + img[:, 1, :, :] * 0.587 + img[:, 2, :, :] * 0.114).unsqueeze(1)
                    for img in image_list
                ]
        return image_list

    def _bilinear_up8(self, flow):
        """Bilinearly upsample a coarse flow by the fixed 8x used throughout this class."""
        return self.upsample_flow(flow, feature=None, bilinear=True, upsample_factor=8)

    def _decode_flows(self, st_component, motion_feature, attn):
        """Decode V1 + MT features into (convex-upsampled, bilinear-upsampled) flow lists.

        Index 0 of both lists is the flow decoded from the raw V1 features. It
        is only ever bilinearly upsampled, because there is no MT feature to
        guide a convex upsample.
        """
        flows = [self.decoder(st_component)] + [self.decoder(feat) for feat in motion_feature]
        flows_bi = [self._bilinear_up8(flow) for flow in flows]
        flows_up = [flows_bi[0]] + [
            self.upsample_flow(flow, upsampler=self.upsampler_1, feature=guide, upsample_factor=8)
            for flow, guide in zip(flows[1:], attn)
        ]
        # Guards against attn being shorter than motion_feature (zip truncates).
        assert len(flows_bi) == len(flows_up)
        return flows_up, flows_bi

    def _run_mt_save_mem(self, st_component):
        """Call ``MT.forward_save_mem`` and return only (motion_feature, attn).

        The original call sites disagreed on whether this returns 2 or 3 values.
        Taking the first two elements works for both.
        """
        out = self.MT.forward_save_mem(st_component)
        return out[0], out[1]

    def _v1_only(self, st_component):
        """Flows for ``num_layers == 0``: decode the V1 features directly."""
        return {"flow_seq": [self._bilinear_up8(self.decoder(st_component))]}

    def forward(self, image_list, mix_enable=True, layer=6):
        """Predict flows from a list of frames.

        Args:
            image_list: list of frame tensors, [B, 1, H, W] or RGB [B, 3, H, W].
                RGB frames are converted to gray. Values above 10 are assumed
                to be 0-255 and are rescaled.
            mix_enable: enable CUDA autocast.
            layer: if not None, overrides (and persistently sets) the number of
                MT layers.

        Returns:
            dict with ``"flow_seq"`` (convex-upsampled flows) and
            ``"flow_seq_bi"`` (bilinear flows). Only ``"flow_seq"`` is returned
            when ``num_layers == 0``.
        """
        if layer is not None:
            self.MT.num_layers = layer
            self.num_layers = layer
        image_list = self._prepare_images(image_list, to_gray=True)
        with autocast(enabled=mix_enable):
            st_component = self.ffv1(image_list)
            if self.num_layers == 0:
                return self._v1_only(st_component)
            motion_feature, attn = self._run_mt_save_mem(st_component)
            flows_up, flows_bi = self._decode_flows(st_component, motion_feature, attn)
        return {"flow_seq": flows_up, "flow_seq_bi": flows_bi}

    def forward_test(self, image_list, mix_enable=True, layer=6):
        """Same as ``forward`` but without the RGB-to-gray conversion.

        NOTE(review): unlike ``forward``, 3-channel frames are passed to V1
        unchanged. Confirm that this difference is intended.
        """
        if layer is not None:
            self.MT.num_layers = layer
            self.num_layers = layer
        image_list = self._prepare_images(image_list, to_gray=False)
        with autocast(enabled=mix_enable):
            st_component = self.ffv1(image_list)
            if self.num_layers == 0:
                return self._v1_only(st_component)
            motion_feature, attn = self._run_mt_save_mem(st_component)
            flows_up, flows_bi = self._decode_flows(st_component, motion_feature, attn)
        return {"flow_seq": flows_up, "flow_seq_bi": flows_bi}

    @torch.no_grad()
    def forward_viz(self, image_list, layer=None, x=50, y=50):
        """Run the model and render visualisations. No gradients are tracked.

        Args:
            image_list: frames, as in ``forward``.
            layer: optional override of the MT layer count (sets MT.num_layers only).
            x, y: query location as percentages (0-100) of the MT feature width / height.

        Returns:
            dict with ``"flow_seq"``, ``"activation"`` (V1 visualisation),
            ``"attention"`` (image array from ``plt_attention``), and
            ``"flow"`` (result of ``plt_show_img_flow``).
        """
        x = x / 100
        y = y / 100
        if layer is not None:
            self.MT.num_layers = layer
        results_dict = {}
        image_list = self._prepare_images(image_list, to_gray=True)
        image_list_ori = deepcopy(image_list)
        _, _, H, W = image_list[0].shape
        MT_size = (H // 8, W // 8)
        with autocast(enabled=True):
            st_component = self.ffv1(image_list)
            activation = self.ffv1.visualize_activation(st_component)
            motion_feature, attn, attn_viz = self.MT(st_component)
            flows_up, _ = self._decode_flows(st_component, motion_feature, attn)
            results_dict["flow_seq"] = flows_up
            # The last flow is left out of the plotted sequence.
            flow = plt_show_img_flow(image_list_ori, flows_up[:-1])
            h = int(MT_size[0] * y)
            w = int(MT_size[1] * x)
            attention = plt_attention(attn_viz, h=h, w=w)
        results_dict["activation"] = activation
        results_dict["attention"] = attention
        results_dict["flow"] = flow
        plt.clf()
        plt.cla()
        plt.close()
        return results_dict

    def num_parameters(self):
        """Return the total element count of all trainable (requires_grad) parameters."""
        return sum(p.data.nelement() for p in self.parameters() if p.requires_grad)

    def init_weights(self):
        """Kaiming-normal weights and zero biases for every Conv1d/Conv2d/ConvTranspose2d.

        Iterates ``self.modules()``, so it also re-initialises convolutions
        inside submodules such as ``ffv1`` and ``MT``. The previous version
        iterated ``named_modules()``, whose ``(name, module)`` tuples never
        matched the isinstance checks, so the method did nothing.
        """
        for module in self.modules():
            if isinstance(module, (nn.Conv2d, nn.Conv1d, nn.ConvTranspose2d)):
                nn.init.kaiming_normal_(module.weight)
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)

    @staticmethod
    def demo(file=None):
        """Smoke test / timing loop on random CUDA input; optionally restores a checkpoint."""
        import time
        from utils import torch_utils as utils
        frame_list = [torch.randn([4, 1, 512, 512], device="cuda")] * 11
        model = FFV1DNN(num_scales=8, scale_factor=16, num_cells=256, upsample_factor=8, num_layers=6,
                        feature_channels=256).cuda()
        if file is not None:
            model = utils.restore_model(model, file)
        print(model.num_parameters())
        for _ in range(100):
            start = time.time()
            # Use the gradient-tracking forward. forward_viz runs under
            # torch.no_grad(), so calling backward() on its output always raised.
            output = model(frame_list, layer=6)
            torch.mean(output["flow_seq"][-1]).backward()
            print(torch.any(torch.isnan(output["flow_seq"][-1])))
            end = time.time()
            print(end - start)
            print("#================================++#")
if __name__ == "__main__":
    # Run the random-input smoke test without loading a checkpoint.
    FFV1DNN.demo(file=None)