# MT.py (commit 4fbf139, "Update MT.py", by sunana; 13.7 kB)
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import numpy as np
class ConvGRU(nn.Module):
    """Convolutional GRU cell whose gates are 3x3 convolutions.

    Every gate reads the channel-wise concatenation ``[hidden, input]`` and
    produces ``hidden_dim`` channels; ``padding=1`` keeps the spatial size.

    Args:
        hidden_dim: channel count of the hidden state ``h``.
        input_dim: channel count of the input ``x``.
    """

    def __init__(self, hidden_dim=128, input_dim=192 + 128):
        super(ConvGRU, self).__init__()
        joint_dim = hidden_dim + input_dim
        # Creation order (z, r, q) is kept so seeded initialisation is unchanged.
        self.convz = nn.Conv2d(joint_dim, hidden_dim, 3, padding=1)
        self.convr = nn.Conv2d(joint_dim, hidden_dim, 3, padding=1)
        self.convq = nn.Conv2d(joint_dim, hidden_dim, 3, padding=1)

    def forward(self, h, x):
        """Apply one GRU update.

        Args:
            h: hidden state, channels-first with ``hidden_dim`` channels.
            x: input features, channels-first with ``input_dim`` channels.

        Returns:
            The new hidden state, same shape as ``h``.
        """
        stacked = torch.cat([h, x], dim=1)
        update_gate = torch.sigmoid(self.convz(stacked))
        reset_gate = torch.sigmoid(self.convr(stacked))
        candidate = torch.tanh(self.convq(torch.cat([reset_gate * h, x], dim=1)))
        # Convex blend of the old state and the candidate, weighted by the update gate.
        return (1 - update_gate) * h + update_gate * candidate
class SepConvGRU(nn.Module):
    """Separable convolutional GRU: a horizontal (1x5) pass followed by a vertical (5x1) pass.

    Each pass is a complete GRU update with its own gate convolutions; the
    vertical pass consumes the hidden state produced by the horizontal one.

    Args:
        hidden_dim: channel count of the hidden state ``h``.
        input_dim: channel count of the input ``x``.
    """

    def __init__(self, hidden_dim=128, input_dim=192 + 128):
        super(SepConvGRU, self).__init__()
        joint_dim = hidden_dim + input_dim
        row_kernel, row_pad = (1, 5), (0, 2)
        col_kernel, col_pad = (5, 1), (2, 0)
        # Attribute names and creation order are kept (state_dict keys, seeded init).
        self.convz1 = nn.Conv2d(joint_dim, hidden_dim, row_kernel, padding=row_pad)
        self.convr1 = nn.Conv2d(joint_dim, hidden_dim, row_kernel, padding=row_pad)
        self.convq1 = nn.Conv2d(joint_dim, hidden_dim, row_kernel, padding=row_pad)
        self.convz2 = nn.Conv2d(joint_dim, hidden_dim, col_kernel, padding=col_pad)
        self.convr2 = nn.Conv2d(joint_dim, hidden_dim, col_kernel, padding=col_pad)
        self.convq2 = nn.Conv2d(joint_dim, hidden_dim, col_kernel, padding=col_pad)

    @staticmethod
    def _gru_step(h, x, conv_z, conv_r, conv_q):
        """Return ``h`` after one GRU update driven by the given gate convolutions."""
        hx = torch.cat([h, x], dim=1)
        z = torch.sigmoid(conv_z(hx))
        r = torch.sigmoid(conv_r(hx))
        q = torch.tanh(conv_q(torch.cat([r * h, x], dim=1)))
        return (1 - z) * h + z * q

    def forward(self, h, x):
        """Update hidden state ``h`` with input ``x`` (both channels-first).

        Returns:
            The new hidden state, same shape as ``h``.
        """
        h = self._gru_step(h, x, self.convz1, self.convr1, self.convq1)  # horizontal
        return self._gru_step(h, x, self.convz2, self.convr2, self.convq2)  # vertical
class GRU(nn.Module):
    """Separable conv-GRU that operates on token sequences laid out on an ``H x W`` grid.

    Token tensors ``[B, H*W, C]`` are reshaped to channels-first images, updated
    with a horizontal (1x5) and then a vertical (5x1) GRU pass, and flattened
    back to ``[B, H*W, hidden_dim]``.

    Args:
        hidden_dim: channel count of the hidden tokens.
        input_dim: channel count of the input tokens (may differ from ``hidden_dim``).
    """

    def __init__(self, hidden_dim=128, input_dim=192 + 128):
        super(GRU, self).__init__()
        self.convz1 = nn.Conv2d(hidden_dim + input_dim, hidden_dim, (1, 5), padding=(0, 2))
        self.convr1 = nn.Conv2d(hidden_dim + input_dim, hidden_dim, (1, 5), padding=(0, 2))
        self.convq1 = nn.Conv2d(hidden_dim + input_dim, hidden_dim, (1, 5), padding=(0, 2))
        self.convz2 = nn.Conv2d(hidden_dim + input_dim, hidden_dim, (5, 1), padding=(2, 0))
        self.convr2 = nn.Conv2d(hidden_dim + input_dim, hidden_dim, (5, 1), padding=(2, 0))
        self.convq2 = nn.Conv2d(hidden_dim + input_dim, hidden_dim, (5, 1), padding=(2, 0))

    def forward(self, hidden, x, shape):
        """Run one horizontal + vertical GRU update on grid-shaped tokens.

        Args:
            hidden: hidden tokens ``[B, H*W, hidden_dim]``.
            x: input tokens ``[B, H*W, input_dim]``.
            shape: ``(H, W)`` of the token grid.

        Returns:
            Updated hidden tokens ``[B, H*W, hidden_dim]`` (a permuted view,
            not necessarily contiguous).
        """
        b, _, c = hidden.shape
        h, w = shape
        hidden = hidden.view(b, h, w, c).permute(0, 3, 1, 2).contiguous()
        # x carries input_dim channels, which need not equal hidden's c,
        # so let view() infer its channel count.
        x = x.view(b, h, w, -1).permute(0, 3, 1, 2).contiguous()
        # horizontal
        hx = torch.cat([hidden, x], dim=1)
        z = torch.sigmoid(self.convz1(hx))
        r = torch.sigmoid(self.convr1(hx))
        q = torch.tanh(self.convq1(torch.cat([r * hidden, x], dim=1)))
        hidden = (1 - z) * hidden + z * q
        # vertical
        hx = torch.cat([hidden, x], dim=1)
        z = torch.sigmoid(self.convz2(hx))
        r = torch.sigmoid(self.convr2(hx))
        q = torch.tanh(self.convq2(torch.cat([r * hidden, x], dim=1)))
        hidden = (1 - z) * hidden + z * q
        return hidden.flatten(-2).permute(0, 2, 1)
class PositionEmbeddingSine(nn.Module):
    """2-D sinusoidal position embedding in the style of "Attention Is All You Need", extended to images.

    The output has ``2 * num_pos_feats`` channels: the first ``num_pos_feats``
    encode the row (y) coordinate, the rest encode the column (x) coordinate.
    Within each half, even channels hold ``sin`` and odd channels hold ``cos``.
    Only the input's size and device are used; its values are ignored.

    Args:
        num_pos_feats: channels per axis. Must be even; otherwise the sin and
            cos halves differ in length and ``torch.stack`` fails.
        temperature: base of the geometric frequency progression.
        normalize: if True, coordinates are divided by the last coordinate
            (plus a small eps) and multiplied by ``scale``.
        scale: post-normalisation multiplier; defaults to ``2 * pi``.

    Raises:
        ValueError: if ``scale`` is given while ``normalize`` is False.
    """

    def __init__(self, num_pos_feats=64, temperature=10000, normalize=True, scale=None):
        super().__init__()
        self.num_pos_feats = num_pos_feats
        self.temperature = temperature
        self.normalize = normalize
        if scale is not None and normalize is False:
            raise ValueError("normalize should be True if scale is passed")
        self.scale = 2 * math.pi if scale is None else scale

    @staticmethod
    def _sin_cos(embed, dim_t):
        """Turn ``[B, H, W]`` coordinates into ``[B, H, W, F]`` interleaved sin/cos features."""
        phase = embed[:, :, :, None] / dim_t
        return torch.stack((phase[:, :, :, 0::2].sin(), phase[:, :, :, 1::2].cos()), dim=4).flatten(3)

    def forward(self, x):
        """Return the embedding for ``x`` of shape ``[B, C, H, W]`` as ``[B, 2 * num_pos_feats, H, W]``."""
        batch, _, height, width = x.size()
        # All-ones mask: every pixel is treated as valid (no padding support).
        mask = torch.ones((batch, height, width), device=x.device)
        y_embed = mask.cumsum(1, dtype=torch.float32)  # 1..H down each column
        x_embed = mask.cumsum(2, dtype=torch.float32)  # 1..W along each row
        if self.normalize:
            eps = 1e-6
            y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
            x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale

        # Frequencies shared by each sin/cos pair: temperature ** (2 * floor(i / 2) / F).
        freq_idx = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
        dim_t = self.temperature ** (2 * (freq_idx // 2) / self.num_pos_feats)

        pos_y = self._sin_cos(y_embed, dim_t)
        pos_x = self._sin_cos(x_embed, dim_t)
        return torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
def feature_add_position(feature0, feature_channels, scale=0.5):
    """Blend a sinusoidal 2-D position embedding into a feature map, then restore its magnitude.

    Args:
        feature0: feature map ``[B, C, H, W]``.
        feature_channels: ``C``. The embedding uses ``C // 2`` features per axis,
            so its channel count only matches ``feature0`` when ``C`` is even.
        scale: strength of the positional term (further multiplied by pi).

    Returns:
        A tensor shaped like ``feature0``. Each sample is rescaled so that its
        mean absolute value equals the mean absolute value of the whole input batch.
    """
    magnitude = feature0.abs().mean()
    # The embedding module is rebuilt on every call; it holds no learnable parameters.
    embedding = PositionEmbeddingSine(num_pos_feats=feature_channels // 2)(feature0)
    # Rescale the embedding so its mean equals `magnitude` before mixing it in.
    # NOTE(review): divides by embedding.mean(); nothing here keeps that away from zero.
    mixed = feature0 + (magnitude * embedding / embedding.mean()) * scale * torch.pi
    # Per-sample renormalisation back to the batch-wide mean |feature|.
    return mixed * magnitude / mixed.abs().mean(dim=(1, 2, 3), keepdim=True)
def feature_add_image_content(feature0, add_fea, scale=0.4):
    """Blend an auxiliary tensor into a feature map, then restore the feature magnitude.

    Args:
        feature0: feature map ``[B, C, H, W]``.
        add_fea: tensor broadcastable to ``feature0``; it is divided by its own
            mean, so that mean should not be (close to) zero.
        scale: strength of the added term (further multiplied by pi).

    Returns:
        A tensor shaped like ``feature0``. Each sample is rescaled so that its
        mean absolute value equals the mean absolute value of the whole input batch.
    """
    magnitude = feature0.abs().mean()
    mixed = feature0 + (magnitude * add_fea / add_fea.mean()) * scale * torch.pi
    return mixed * magnitude / mixed.abs().mean(dim=(1, 2, 3), keepdim=True)
class AttUp(nn.Module):
    """Update attention tokens from a message using a conv fusion block and a separable conv-GRU.

    Args:
        c: channel width of both the attention tokens and the message.
    """

    def __init__(self, c=512):
        super(AttUp, self).__init__()
        self.proj = nn.Linear(c, c, bias=False)
        self.norm = nn.LayerNorm(c)
        # Layout (convs at indices 0, 2, 4) is kept so state_dict keys are unchanged.
        self.conv = nn.Sequential(
            nn.Conv2d(2 * c, c, kernel_size=1, stride=1, padding=0),
            nn.GELU(),
            nn.Conv2d(c, c, kernel_size=3, stride=1, padding=1),
            nn.GELU(),
            nn.Conv2d(c, c, kernel_size=3, stride=1, padding=1),
            nn.GELU(),
        )
        self.gru = SepConvGRU(c, c)

    def forward(self, att, message, shape):
        """Fuse ``message`` into ``att`` and return the updated attention tokens.

        Args:
            att: attention tokens ``[B, H*W, c]``.
            message: message tokens ``[B, H*W, c]``.
            shape: ``(H, W)`` of the token grid.

        Returns:
            Updated attention tokens ``[B, H*W, c]``.
        """
        b, _, c = att.shape
        h, w = shape

        def to_image(tokens):
            # [B, H*W, C] -> contiguous [B, C, H, W]
            return tokens.view(b, h, w, c).permute(0, 3, 1, 2).contiguous()

        message_img = to_image(self.norm(self.proj(message)))
        att_img = to_image(att)
        fused = self.conv(torch.cat([att_img, message_img], dim=1))
        # GRU output [B, C, H, W] -> tokens [B, H*W, C]
        return self.gru(att_img, fused).flatten(-2).permute(0, 2, 1)
class TransformerLayer(nn.Module):
    """Single-head graph-propagation layer that updates value and attention tokens.

    Each call builds a dense ``[B, L, L]`` affinity from the cosine similarity
    of the projected attention tokens, symmetrically normalises it, propagates
    the projected values through it, and uses the resulting message to update
    the attention tokens (``AttUp``) and the value tokens (``GRU``).

    Args:
        d_model: token channel width.
        nhead: stored only; ``forward`` is single-head.
        no_ffn: if True, skip the feed-forward block over ``[value, message]``.
        ffn_dim_expansion: hidden-width multiplier of the feed-forward block.
    """

    def __init__(self,
                 d_model=256,
                 nhead=1,
                 no_ffn=False,
                 ffn_dim_expansion=4
                 ):
        super(TransformerLayer, self).__init__()
        self.dim = d_model
        self.nhead = nhead
        self.no_ffn = no_ffn
        # affinity projection: Linear -> ReLU -> Linear
        self.att_proj = nn.Sequential(nn.Linear(d_model, d_model, bias=False), nn.ReLU(inplace=True),
                                      nn.Linear(d_model, d_model, bias=False))
        self.v_proj = nn.Linear(d_model, d_model, bias=False)
        self.merge = nn.Linear(d_model, d_model, bias=False)
        self.gru = GRU(d_model, d_model)
        self.attn_updater = AttUp(d_model)
        # NOTE(review): not used by forward(); kept so existing references keep working.
        self.drop = nn.Dropout(p=0.8)
        self.norm1 = nn.LayerNorm(d_model)
        # optional feed-forward block over the concatenation [value, message]
        if not self.no_ffn:
            in_channels = d_model * 2
            self.mlp = nn.Sequential(
                nn.Linear(in_channels, in_channels * ffn_dim_expansion, bias=False),
                nn.GELU(),
                nn.Linear(in_channels * ffn_dim_expansion, in_channels * ffn_dim_expansion, bias=False),
                nn.GELU(),
                nn.Linear(in_channels * ffn_dim_expansion, d_model, bias=False),
            )
            self.norm2 = nn.LayerNorm(d_model)

    def forward(self, att, value,
                shape, iteration=0):
        """Run one propagation step.

        Args:
            att: attention tokens ``[B, L, C]`` with ``L == H * W``.
            value: value tokens ``[B, L, C]``.
            shape: ``(H, W)`` of the token grid.
            iteration: when 0, a sinusoidal position embedding is first added to ``att``.

        Returns:
            ``(value, att, A)``: updated value tokens ``[B, L, C]``, updated
            attention tokens ``[B, L, C]`` and the normalised affinity ``[B, L, L]``.
        """
        max_exp_scale = 3 * torch.pi
        # Lower bound on each projected token's L2 norm. Without it, an all-zero
        # token gives 0/0 = NaN, which then spreads through the whole affinity.
        norm_eps = 1e-6
        B, L, C = value.shape
        if iteration == 0:
            att = feature_add_position(att.transpose(-1, -2).view(
                B, C, shape[0], shape[1]), C).reshape(B, C, -1).transpose(-1, -2)
        val_proj = self.v_proj(value)
        att_proj = self.att_proj(att)  # [B, L, C]
        # Clamp the squared norm (not the norm itself) so the backward pass of
        # the square root stays finite for clamped rows too.
        norm_fac = torch.sum(att_proj ** 2, dim=-1, keepdim=True).clamp_min(norm_eps ** 2) ** 0.5
        # per-sample sharpness in (1, 1 + 3*pi)
        scale = max_exp_scale * torch.sigmoid(torch.mean(att_proj, dim=[-1, -2], keepdim=True)) + 1
        A = torch.exp(scale * torch.matmul(att_proj / norm_fac, att_proj.permute(0, 2, 1) / norm_fac.permute(0, 2, 1)))
        # Global (whole-batch) max rescale. Up to the 1e-6 eps below, the
        # symmetric normalisation cancels this factor.
        A = A / A.max()
        D = torch.sum(A, dim=-1, keepdim=True)
        D = 1 / (torch.sqrt(D) + 1e-6)  # inverse square-root node degrees
        A = D * A * D.transpose(-1, -2)  # D^-1/2 A D^-1/2
        message = torch.matmul(A, val_proj)  # [B, L, C]
        message = self.merge(message)  # [B, L, C]
        message = self.norm1(message)
        if not self.no_ffn:
            message = self.mlp(torch.cat([value, message], dim=-1))
            message = self.norm2(message)
        att = self.attn_updater(att, message, shape)
        value = self.gru(value, message, shape)
        return value, att, A
class FeatureTransformer(nn.Module):
    """Iteratively refine a feature map with a single weight-shared ``TransformerLayer``.

    Args:
        num_layers: number of times the shared layer is applied.
        d_model: channel width; must equal the input's channel count.
    """

    def __init__(self,
                 num_layers=6,
                 d_model=128
                 ):
        super(FeatureTransformer, self).__init__()
        self.d_model = d_model
        # One layer applied num_layers times (weights shared across iterations).
        self.layers = TransformerLayer(self.d_model, no_ffn=False, ffn_dim_expansion=2)
        self.re_proj = nn.Sequential(nn.Linear(d_model, d_model), nn.GELU(), nn.Linear(d_model, d_model))
        self.num_layers = num_layers
        # Learnable scalars used by normalize(); nn.Parameter already requires grad.
        self.norm_sigma = nn.Parameter(torch.tensor(1.0))
        self.norm_k = nn.Parameter(torch.tensor(1.8))
        # Xavier-init every parameter with more than one dimension; biases,
        # LayerNorm affine terms and the scalar parameters keep their defaults.
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

    def normalize(self, x):  # TODO
        """Return ``|norm_k| * x / (mean(x over dims 1 and 2) + norm_sigma ** 2)``.

        Callers pass ``[B, H*W, C]`` tokens, so the mean is one scalar per sample.
        """
        sum_activation = torch.mean(x, dim=[1, 2], keepdim=True) + torch.square(self.norm_sigma)
        return self.norm_k.abs() * x / sum_activation

    def _decode(self, value):
        """Map value tokens to squared, normalised "motion energy" tokens."""
        return self.normalize(torch.square(self.re_proj(value)))

    @staticmethod
    def _to_map(tokens, b, h, w, c):
        """Reshape ``[B, H*W, C]`` tokens into a contiguous ``[B, C, H, W]`` map."""
        return tokens.view(b, h, w, c).permute(0, 3, 1, 2).contiguous()

    def forward(self, feature0):
        """Refine ``feature0`` and collect per-iteration outputs.

        Args:
            feature0: feature map ``[B, C, H, W]`` with ``C == d_model``.

        Returns:
            ``(feature_list, attn_list, attn_viz_list)``:
            ``feature_list`` and ``attn_list`` hold one ``[B, C, H, W]`` map per
            iteration (decoded values and attention tokens). ``attn_viz_list``
            holds the affinity reshaped to ``[B, H, W, H, W]`` for every even
            iteration.
        """
        feature_list = []
        attn_list = []
        attn_viz_list = []
        b, c, h, w = feature0.shape
        assert self.d_model == c
        value = feature0.flatten(-2).permute(0, 2, 1)  # [B, H*W, C]
        att = feature0.flatten(-2).permute(0, 2, 1)  # [B, H*W, C]
        for i in range(self.num_layers):
            value, att, attn_viz = self.layers(att=att, value=value, shape=[h, w], iteration=i)
            value_decode = self._decode(value)
            if i % 2 == 0:
                attn_viz_list.append(attn_viz.reshape(b, h, w, h, w))
            attn_list.append(self._to_map(att, b, h, w, c))
            feature_list.append(self._to_map(value_decode, b, h, w, c))
        return feature_list, attn_list, attn_viz_list

    def forward_save_mem(self, feature0, add_position_embedding=True):
        """Like ``forward`` but returns no affinity matrices.

        Args:
            feature0: feature map ``[B, C, H, W]`` with ``C == d_model``.
            add_position_embedding: ignored; accepted for interface
                compatibility (the layer adds the embedding at iteration 0).

        Returns:
            ``(feature_list, attn_list)`` as in ``forward``.
        """
        feature_list = []
        attn_list = []
        b, c, h, w = feature0.shape
        assert self.d_model == c
        value = feature0.flatten(-2).permute(0, 2, 1)  # [B, H*W, C]
        att = feature0.flatten(-2).permute(0, 2, 1)  # [B, H*W, C]
        for i in range(self.num_layers):
            value, att, _ = self.layers(att=att, value=value, shape=[h, w], iteration=i)
            value_decode = self._decode(value)
            attn_list.append(self._to_map(att, b, h, w, c))
            feature_list.append(self._to_map(value_decode, b, h, w, c))
        return feature_list, attn_list

    @staticmethod
    def demo():
        """Time 100 forward+backward passes on random input (CUDA if available, else CPU)."""
        import time
        device = "cuda" if torch.cuda.is_available() else "cpu"
        frame_list = torch.randn([4, 256, 64, 64], device=device)
        model = FeatureTransformer(6, 256).to(device)
        for _ in range(100):
            start = time.time()
            output = model(frame_list)
            torch.mean(output[-1][-1]).backward()
            if device == "cuda":
                # CUDA work is queued asynchronously; wait for it so the timing is real.
                torch.cuda.synchronize()
            end = time.time()
            print(end - start)
            print("#================================++#")
if __name__ == "__main__":
    # Running this file as a script executes the timing demo.
    FeatureTransformer.demo()