# -*- coding: utf-8 -*-
import torch
import torch.nn as nn
import numpy as np
class SelfAttention(nn.Module):
def __init__(self, input_size=1024, output_size=1024, freq=10000, heads=1, pos_enc=None):
""" The basic (multi-head) Attention 'cell' containing the learnable parameters of Q, K and V
:param int input_size: Feature input size of Q, K, V.
:param int output_size: Feature -hidden- size of Q, K, V.
:param int freq: The frequency of the sinusoidal positional encoding.
:param int heads: Number of heads for the attention module.
:param str | None pos_enc: The type of the positional encoding [supported: Absolute, Relative].
"""
super(SelfAttention, self).__init__()
self.permitted_encodings = ["absolute", "relative"]
if pos_enc is not None:
pos_enc = pos_enc.lower()
assert pos_enc in self.permitted_encodings, f"Supported encodings: {*self.permitted_encodings,}"
self.input_size = input_size
self.output_size = output_size
self.heads = heads
self.pos_enc = pos_enc
self.freq = freq
self.Wk, self.Wq, self.Wv = nn.ModuleList(), nn.ModuleList(), nn.ModuleList()
for _ in range(self.heads):
self.Wk.append(nn.Linear(in_features=input_size, out_features=output_size//heads, bias=False))
self.Wq.append(nn.Linear(in_features=input_size, out_features=output_size//heads, bias=False))
self.Wv.append(nn.Linear(in_features=input_size, out_features=output_size//heads, bias=False))
self.out = nn.Linear(in_features=output_size, out_features=input_size, bias=False)
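        # Each head projects the input to output_size // heads features (output_size is expected to be
        # divisible by heads); the concatenated head outputs are mapped back to input_size by self.out.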
self.softmax = nn.Softmax(dim=-1)
self.drop = nn.Dropout(p=0.5)
def getAbsolutePosition(self, T):
"""Calculate the sinusoidal positional encoding based on the absolute position of each considered frame.
Based on 'Attention is all you need' paper (https://arxiv.org/abs/1706.03762)
:param int T: Number of frames contained in Q, K and V
:return: Tensor with shape [T, T]
"""
freq = self.freq
d = self.input_size
        pos = torch.arange(T, device=self.out.weight.device)
        i = torch.arange(T // 2, device=self.out.weight.device)
        # Expand the tensors so that every position pos_k is paired with every index i
        pos = pos.reshape(pos.shape[0], 1)
        pos = pos.repeat_interleave(i.shape[0], dim=1)
        i = i.repeat(pos.shape[0], 1)
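        # Sine on the even columns, cosine on the odd columns, as in the Transformer formulation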
AP = torch.zeros(T, T, device=self.out.weight.device)
AP[pos, 2*i] = torch.sin(pos / freq ** ((2 * i) / d))
AP[pos, 2*i+1] = torch.cos(pos / freq ** ((2 * i) / d))
return AP
def getRelativePosition(self, T):
"""Calculate the sinusoidal positional encoding based on the relative position of each considered frame.
        r_pos is computed as described here: https://theaisummer.com/positional-embeddings/
:param int T: Number of frames contained in Q, K and V
:return: Tensor with shape [T, T]
"""
freq = self.freq
d = 2 * T
min_rpos = -(T - 1)
        i = torch.arange(T, device=self.out.weight.device)
        j = torch.arange(T, device=self.out.weight.device)
        # Expand the tensors so that every index i is paired with every index j
        i = i.reshape(i.shape[0], 1)
        i = i.repeat_interleave(i.shape[0], dim=1)
        j = j.repeat(i.shape[0], 1)
# Calculate the relative positions
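        # r_pos = j - i + (T - 1), i.e. the relative distance shifted to lie in [0, 2T - 2]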
r_pos = j - i - min_rpos
RP = torch.zeros(T, T, device=self.out.weight.device)
        idx = torch.arange(T // 2, device=self.out.weight.device)
RP[:, 2*idx] = torch.sin(r_pos[:, 2*idx] / freq ** ((i[:, 2*idx] + j[:, 2*idx]) / d))
RP[:, 2*idx+1] = torch.cos(r_pos[:, 2*idx+1] / freq ** ((i[:, 2*idx+1] + j[:, 2*idx+1]) / d))
return RP
def forward(self, x):
""" Compute the weighted frame features, based on either the global or local (multi-head) attention mechanism.
        :param torch.Tensor x: Frame features with shape [T, input_size].
        :return: A tuple of:
            y: The weighted features based on the attention weights, with shape [T, input_size].
            att_weights: The attention weights (before dropout), with shape [T, T].
"""
outputs = []
for head in range(self.heads):
K = self.Wk[head](x)
Q = self.Wq[head](x)
V = self.Wv[head](x)
            # Q *= 0.06  # scale factor used in VASNet
            # Q /= np.sqrt(self.output_size)  # scale factor (i.e., 1 / sqrt(d_k))
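            # Q and K have shape [T, output_size // heads], so the energies form a [T, T] frame-to-frame similarity matrix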
energies = torch.matmul(Q, K.transpose(1, 0))
if self.pos_enc is not None:
if self.pos_enc == "absolute":
AP = self.getAbsolutePosition(T=energies.shape[0])
energies = energies + AP
elif self.pos_enc == "relative":
RP = self.getRelativePosition(T=energies.shape[0])
energies = energies + RP
att_weights = self.softmax(energies)
_att_weights = self.drop(att_weights)
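            # Weighted sum of the values for this head; y has shape [T, output_size // heads]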
y = torch.matmul(_att_weights, V)
# Save the current head output
outputs.append(y)
y = self.out(torch.cat(outputs, dim=1))
        return y, att_weights.clone()  # NOTE: only the last head's attention weights are returned; combining them (e.g. max or avg pooling) is not handled yet
if __name__ == '__main__':
pass
"""Uncomment for a quick proof of concept
model = SelfAttention(input_size=256, output_size=256, pos_enc="absolute").cuda()
    _input = torch.randn(500, 256).cuda()  # [seq_len, input_size]
output, weights = model(_input)
print(f"Output shape: {output.shape}\tattention shape: {weights.shape}")
"""