# -*- coding: utf-8 -*-
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from .attention import SelfAttention


class MultiAttention(nn.Module):
    def __init__(self, input_size=1024, output_size=1024, freq=10000, pos_enc=None,
                 num_segments=None, heads=1, fusion=None):
        """ Class wrapping the MultiAttention part of PGL-SUM; its key modules and parameters.

        :param int input_size: The expected input feature size.
        :param int output_size: The hidden feature size of the attention mechanisms.
        :param int freq: The frequency of the sinusoidal positional encoding.
        :param None | str pos_enc: The selected positional encoding [absolute, relative].
        :param None | int num_segments: The selected number of segments to split the videos.
        :param int heads: The selected number of global heads.
        :param None | str fusion: The selected type of feature fusion.
        """
        super(MultiAttention, self).__init__()

        # Global Attention, considering differences among all frames
        self.attention = SelfAttention(input_size=input_size, output_size=output_size,
                                       freq=freq, pos_enc=pos_enc, heads=heads)

        self.num_segments = num_segments
        if self.num_segments is not None:
            assert self.num_segments >= 2, "num_segments must be None or 2+"
            self.local_attention = nn.ModuleList()
            for _ in range(self.num_segments):
                # Local Attention, considering differences among frames of the same segment, with a reduced hidden size
                self.local_attention.append(SelfAttention(input_size=input_size, output_size=output_size // num_segments,
                                                          freq=freq, pos_enc=pos_enc, heads=4))

        self.permitted_fusions = ["add", "mult", "avg", "max"]
        self.fusion = fusion
        if self.fusion is not None:
            self.fusion = self.fusion.lower()
            assert self.fusion in self.permitted_fusions, f"Fusion method must be: {*self.permitted_fusions,}"

    def forward(self, x):
        """ Compute the weighted frame features, based on the global and local (multi-head) attention mechanisms.

        :param torch.Tensor x: Tensor with shape [T, input_size] containing the frame features.
        :return: A tuple of:
            weighted_value: Tensor with shape [T, input_size] containing the weighted frame features.
            attn_weights: Tensor with shape [T, T] containing the attention weights.
        """
        weighted_value, attn_weights = self.attention(x)  # global attention

        if self.num_segments is not None and self.fusion is not None:
            # Split the sequence into num_segments chunks of ceil(T / num_segments) frames; the last chunk may be shorter
            segment_size = math.ceil(x.shape[0] / self.num_segments)
            for segment in range(self.num_segments):
                left_pos = segment * segment_size
                right_pos = (segment + 1) * segment_size
                local_x = x[left_pos:right_pos]
                weighted_local_value, attn_local_weights = self.local_attention[segment](local_x)  # local attentions

                # Normalize the feature vectors
                weighted_value[left_pos:right_pos] = F.normalize(weighted_value[left_pos:right_pos].clone(), p=2, dim=1)
                weighted_local_value = F.normalize(weighted_local_value, p=2, dim=1)
                if self.fusion == "add":
                    weighted_value[left_pos:right_pos] += weighted_local_value
                elif self.fusion == "mult":
                    weighted_value[left_pos:right_pos] *= weighted_local_value
                elif self.fusion == "avg":
                    weighted_value[left_pos:right_pos] += weighted_local_value
                    weighted_value[left_pos:right_pos] /= 2
                elif self.fusion == "max":
                    weighted_value[left_pos:right_pos] = torch.max(weighted_value[left_pos:right_pos].clone(),
                                                                   weighted_local_value)

        return weighted_value, attn_weights
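
    # Worked example of the segmentation above (hypothetical numbers): for T = 500 frames and
    # num_segments = 3, segment_size = math.ceil(500 / 3) = 167, so the local attentions see the
    # slices [0:167], [167:334] and [334:500]; when T is not an exact multiple, the last slice is
    # simply shorter (166 frames here). Each local output is L2-normalized and fused with the
    # matching slice of the global output via add/mult/avg/max.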


class PGL_SUM(nn.Module):
    def __init__(self, input_size=1024, output_size=1024, freq=10000, pos_enc=None,
                 num_segments=None, heads=1, fusion=None):
        """ Class wrapping the PGL-SUM model; its key modules and parameters.

        :param int input_size: The expected input feature size.
        :param int output_size: The hidden feature size of the attention mechanisms.
        :param int freq: The frequency of the sinusoidal positional encoding.
        :param None | str pos_enc: The selected positional encoding [absolute, relative].
        :param None | int num_segments: The selected number of segments to split the videos.
        :param int heads: The selected number of global heads.
        :param None | str fusion: The selected type of feature fusion.
        """
        super(PGL_SUM, self).__init__()

        self.attention = MultiAttention(input_size=input_size, output_size=output_size, freq=freq,
                                        pos_enc=pos_enc, num_segments=num_segments, heads=heads, fusion=fusion)
        self.linear_1 = nn.Linear(in_features=input_size, out_features=input_size)
        self.linear_2 = nn.Linear(in_features=self.linear_1.out_features, out_features=1)

        self.drop = nn.Dropout(p=0.5)
        self.norm_y = nn.LayerNorm(normalized_shape=input_size, eps=1e-6)
        self.norm_linear = nn.LayerNorm(normalized_shape=self.linear_1.out_features, eps=1e-6)
        self.relu = nn.ReLU()
        self.sigmoid = nn.Sigmoid()

    def forward(self, frame_features):
        """ Produce frame importance scores from the frame features, using the PGL-SUM model.

        :param torch.Tensor frame_features: Tensor of shape [T, input_size] containing the frame features produced by
                                            using the pool5 layer of GoogleNet.
        :return: A tuple of:
            y: Tensor with shape [1, T] containing the frame importance scores in [0, 1].
            attn_weights: Tensor with shape [T, T] containing the attention weights.
        """
        residual = frame_features
        weighted_value, attn_weights = self.attention(frame_features)
        y = weighted_value + residual
        y = self.drop(y)
        y = self.norm_y(y)

        # 2-layer NN (Regressor Network)
        y = self.linear_1(y)
        y = self.relu(y)
        y = self.drop(y)
        y = self.norm_linear(y)

        y = self.linear_2(y)
        y = self.sigmoid(y)
        y = y.view(1, -1)

        return y, attn_weights
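
    # Hypothetical shapes for illustration: GoogleNet pool5 features of a 300-frame video, i.e.
    # frame_features of shape [300, 1024], yield y of shape [1, 300] (one sigmoid score per frame)
    # and attn_weights of shape [300, 300], as described in the docstring above.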


if __name__ == '__main__':
    pass
    """Uncomment for a quick proof of concept.
    model = PGL_SUM(input_size=256, output_size=256, num_segments=3, fusion="Add").cuda()
    _input = torch.randn(500, 256).cuda()  # [seq_len, hidden_size]
    output, weights = model(_input)
    print(f"Output shape: {output.shape}\tattention shape: {weights.shape}")
    """