|
from typing import Optional

import torch
from torch import nn
import torch.nn.functional as F

from poetry_diacritizer.options import AttentionType
|
|
|
|
|
class BahdanauAttention(nn.Module):
    """Additive (content-based) attention in the style of Bahdanau et al."""

    def __init__(self, dim):
        super().__init__()
        self.query_layer = nn.Linear(dim, dim, bias=False)
        self.tanh = nn.Tanh()
        self.v = nn.Linear(dim, 1, bias=False)

    def forward(self, query: torch.Tensor, keys: torch.Tensor):
        """
        Args:
            query: (batch, 1, dim) or (batch, dim)
            keys: processed memory, (batch, max_time, dim)

        Returns:
            Unnormalized alignment scores, (batch, max_time)
        """
        if query.dim() == 2:
            # Insert a time axis so the query broadcasts against the keys.
            query = query.unsqueeze(1)
        query = self.query_layer(query)

        # Additive score: v^T tanh(W q + keys).
        alignment = self.v(self.tanh(query + keys))
        return alignment.squeeze(-1)
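

# Illustrative usage sketch (added for documentation; not part of the original
# module). Shapes are assumptions chosen for the example only; the keys are the
# already-projected ("processed") encoder states.
#
#     attention = BahdanauAttention(dim=256)
#     query = torch.randn(4, 256)       # (batch, dim) decoder state
#     keys = torch.randn(4, 50, 256)    # (batch, max_time, dim) processed memory
#     scores = attention(query, keys)   # (batch, max_time) unnormalized scores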
|
|
|
|
|
class LocationSensitive(nn.Module):
    """Location-sensitive attention: the score also depends on the alignments
    from the previous decoder step, summarized by a 1-D convolution."""

    def __init__(self, dim):
        super().__init__()
        self.query_layer = nn.Linear(dim, dim, bias=False)
        self.v = nn.Linear(dim, 1, bias=True)
        self.location_layer = nn.Linear(32, dim, bias=False)
        # "Same" padding for a kernel of size 31 keeps the time dimension unchanged.
        padding = (31 - 1) // 2
        self.location_conv = torch.nn.Conv1d(
            1, 32, kernel_size=31, stride=1, padding=padding, dilation=1, bias=False
        )

        # Unused here; masking is applied by AttentionWrapper.
        self.score_mask_value = -float("inf")

    def forward(
        self,
        query: torch.Tensor,
        keys: torch.Tensor,
        prev_alignments: torch.Tensor,
    ):
        """
        Args:
            query: (batch, 1, dim) or (batch, dim)
            keys: processed memory, (batch, max_time, dim)
            prev_alignments: alignments from the previous step, (batch, max_time)

        Returns:
            Unnormalized alignment scores, (batch, max_time)
        """
        query = self.query_layer(query)
        if query.dim() == 2:
            query = query.unsqueeze(1)

        # Summarize the previous alignments with a 1-D convolution over time.
        alignments = prev_alignments.unsqueeze(1)    # (batch, 1, max_time)
        filters = self.location_conv(alignments)     # (batch, 32, max_time)
        location_features = self.location_layer(
            filters.transpose(1, 2)
        )                                            # (batch, max_time, dim)

        alignments = self.v(torch.tanh(query + location_features + keys))
        return alignments.squeeze(-1)
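

# Illustrative usage sketch (added for documentation; not part of the original
# module). Compared to BahdanauAttention, the previous-step alignments are an
# extra input; shapes are assumptions chosen for the example only.
#
#     attention = LocationSensitive(dim=256)
#     query = torch.randn(4, 256)            # (batch, dim) decoder state
#     keys = torch.randn(4, 50, 256)         # (batch, max_time, dim) processed memory
#     prev_alignments = torch.zeros(4, 50)   # (batch, max_time), e.g. zeros at step 0
#     scores = attention(query, keys, prev_alignments)   # (batch, max_time)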
|
|
|
|
|
class AttentionWrapper(nn.Module):
    def __init__(
        self,
        attention_type: AttentionType = AttentionType.LocationSensitive,
        attention_units: int = 256,
        score_mask_value=-float("inf"),
    ):
        super().__init__()
        self.score_mask_value = score_mask_value
        self.attention_type = attention_type

        if attention_type == AttentionType.LocationSensitive:
            self.attention_mechanism = LocationSensitive(attention_units)
        elif attention_type == AttentionType.Content_Based:
            self.attention_mechanism = BahdanauAttention(attention_units)
        else:
            raise ValueError(f"Unknown attention type: {attention_type}")

    def forward(
        self,
        query: torch.Tensor,
        keys: torch.Tensor,
        values: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
        prev_alignment: Optional[torch.Tensor] = None,
    ):
        # Content-based scores depend only on the query and the keys;
        # location-sensitive scores also need the previous alignments.
        if self.attention_type == AttentionType.Content_Based:
            alignment = self.attention_mechanism(query, keys)
        else:
            alignment = self.attention_mechanism(query, keys, prev_alignment)

        # Mask out padded positions before the softmax so they receive zero weight.
        if mask is not None:
            alignment = alignment.masked_fill(mask, self.score_mask_value)

        # Normalize over time and form the context vector.
        alignment = F.softmax(alignment, dim=1)                 # (batch, max_time)
        attention = torch.bmm(alignment.unsqueeze(1), values)   # (batch, 1, dim)
        attention = attention.squeeze(1)                        # (batch, dim)

        return attention, alignment
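

# Illustrative usage sketch (added for documentation; not part of the original
# module): the content-based path with an optional padding mask. Shapes and the
# mask layout are assumptions for the example only; True marks positions that
# are filled with score_mask_value before the softmax.
#
#     wrapper = AttentionWrapper(AttentionType.Content_Based, attention_units=256)
#     query = torch.randn(4, 256)                  # (batch, dim) decoder state
#     keys = torch.randn(4, 50, 256)               # (batch, max_time, dim) processed memory
#     values = torch.randn(4, 50, 256)             # (batch, max_time, dim) encoder outputs
#     mask = torch.zeros(4, 50, dtype=torch.bool)  # True = padded position
#     context, alignment = wrapper(query, keys, values, mask=mask)
#     # context: (batch, dim), alignment: (batch, max_time)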
|
|
|
|
|
class MultiHeadAttentionLayer(nn.Module):
    def __init__(self, hid_dim: int, n_heads: int, dropout: float = 0.0):
        super().__init__()

        assert hid_dim % n_heads == 0, "hid_dim must be divisible by n_heads"

        self.hid_dim = hid_dim
        self.n_heads = n_heads
        self.head_dim = hid_dim // n_heads

        self.fc_q = nn.Linear(hid_dim, hid_dim)
        self.fc_k = nn.Linear(hid_dim, hid_dim)
        self.fc_v = nn.Linear(hid_dim, hid_dim)

        # The output projection takes the query concatenated with the context.
        self.fc_o = nn.Linear(hid_dim * 2, hid_dim)

        self.use_dropout = dropout != 0.0
        if self.use_dropout:
            self.dropout = nn.Dropout(dropout)

        # Register the scaling factor as a non-persistent buffer so it follows
        # the module across devices instead of staying on the init-time device.
        self.register_buffer(
            "scale", torch.sqrt(torch.tensor(float(self.head_dim))), persistent=False
        )
|
    def forward(self, query, key, value, mask=None):
        # query/key/value: (batch, len, hid_dim)
        batch_size = query.shape[0]

        # Project the inputs.
        Q = self.fc_q(query)   # (batch, query_len, hid_dim)
        K = self.fc_k(key)     # (batch, key_len, hid_dim)
        V = self.fc_v(value)   # (batch, value_len, hid_dim)

        # Split the hidden dimension into heads: (batch, n_heads, len, head_dim).
        Q = Q.view(batch_size, -1, self.n_heads, self.head_dim).permute(0, 2, 1, 3)
        K = K.view(batch_size, -1, self.n_heads, self.head_dim).permute(0, 2, 1, 3)
        V = V.view(batch_size, -1, self.n_heads, self.head_dim).permute(0, 2, 1, 3)

        # Scaled dot-product scores: (batch, n_heads, query_len, key_len).
        energy = torch.matmul(Q, K.permute(0, 1, 3, 2)) / self.scale

        # Mask out padded positions before the softmax.
        if mask is not None:
            energy = energy.masked_fill(mask == 0, -float("inf"))

        attention = torch.softmax(energy, dim=-1)

        # Weighted sum of the values: (batch, n_heads, query_len, head_dim).
        if self.use_dropout:
            context_vector = torch.matmul(self.dropout(attention), V)
        else:
            context_vector = torch.matmul(attention, V)

        # Merge the heads back: (batch, query_len, hid_dim).
        context_vector = context_vector.permute(0, 2, 1, 3).contiguous()
        context_vector = context_vector.view(batch_size, -1, self.hid_dim)

        # Concatenate the query with the context and project back to hid_dim.
        x = torch.cat((query, context_vector), dim=-1)
        x = self.fc_o(x)

        return x, attention
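

# Minimal smoke test: an illustrative sketch added for documentation, not part
# of the original module. Sizes are arbitrary assumptions; it only runs when
# the file is executed directly.
if __name__ == "__main__":
    batch, max_time, dim, n_heads = 2, 16, 256, 8

    wrapper = AttentionWrapper(AttentionType.LocationSensitive, attention_units=dim)
    query = torch.randn(batch, dim)
    keys = torch.randn(batch, max_time, dim)
    values = torch.randn(batch, max_time, dim)
    prev_alignment = torch.zeros(batch, max_time)
    context, alignment = wrapper(query, keys, values, prev_alignment=prev_alignment)
    print("wrapper:", context.shape, alignment.shape)   # (batch, dim), (batch, max_time)

    mha = MultiHeadAttentionLayer(hid_dim=dim, n_heads=n_heads, dropout=0.1)
    sequence = torch.randn(batch, max_time, dim)
    output, attention = mha(sequence, sequence, sequence)
    print("multi-head:", output.shape, attention.shape)
    # (batch, max_time, dim), (batch, n_heads, max_time, max_time)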
|
|