from typing import Optional

import torch
from torch import nn
import torch.nn.functional as F

from poetry_diacritizer.options import AttentionType


class BahdanauAttention(nn.Module):
    def __init__(self, dim):
        super(BahdanauAttention, self).__init__()
        self.query_layer = nn.Linear(dim, dim, bias=False)
        self.tanh = nn.Tanh()
        self.v = nn.Linear(dim, 1, bias=False)

    def forward(self, query: torch.Tensor, keys: torch.Tensor):
        """
        Args:
            query: (batch, 1, dim) or (batch, dim)
            keys: (batch, max_time, dim)
        """
        if query.dim() == 2:
            # insert time-axis for broadcasting
            query = query.unsqueeze(1)  # (batch, 1, dim)

        query = self.query_layer(query)  # (batch, 1, dim)

        # (batch, max_time, 1) -> (batch, max_time)
        alignment = self.v(self.tanh(query + keys))
        return alignment.squeeze(-1)


class LocationSensitive(nn.Module):
    def __init__(self, dim):
        super(LocationSensitive, self).__init__()
        self.query_layer = nn.Linear(dim, dim, bias=False)
        self.v = nn.Linear(dim, 1, bias=True)
        self.location_layer = nn.Linear(32, dim, bias=False)
        padding = (31 - 1) // 2  # "same" padding for kernel_size=31
        self.location_conv = torch.nn.Conv1d(
            1, 32, kernel_size=31, stride=1, padding=padding, dilation=1, bias=False
        )
        self.score_mask_value = -float("inf")

    def forward(
        self,
        query: torch.Tensor,
        keys: torch.Tensor,
        prev_alignments: torch.Tensor,
    ):
        query = self.query_layer(query)
        if query.dim() == 2:
            # insert time-axis for broadcasting -> (batch, 1, dim)
            query = query.unsqueeze(1)

        alignments = prev_alignments.unsqueeze(1)  # (batch, 1, max_time)

        # location features: (batch, 32, max_time) -> (batch, max_time, dim)
        filters = self.location_conv(alignments)
        location_features = self.location_layer(filters.transpose(1, 2))

        # (batch, max_time, 1) -> (batch, max_time)
        alignments = self.v(torch.tanh(query + location_features + keys))
        return alignments.squeeze(-1)


class AttentionWrapper(nn.Module):
    def __init__(
        self,
        attention_type: AttentionType = AttentionType.LocationSensitive,
        attention_units: int = 256,
        score_mask_value=-float("inf"),
    ):
        super().__init__()
        self.score_mask_value = score_mask_value
        self.attention_type = attention_type

        if attention_type == AttentionType.LocationSensitive:
            self.attention_mechanism = LocationSensitive(attention_units)
        elif attention_type == AttentionType.Content_Based:
            self.attention_mechanism = BahdanauAttention(attention_units)
        else:
            raise ValueError(f"Unknown attention type: {attention_type}")

    def forward(
        self,
        query: torch.Tensor,
        keys: torch.Tensor,
        values: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
        prev_alignment: Optional[torch.Tensor] = None,
    ):
        # Alignment scores: (batch, max_time)
        if self.attention_type == AttentionType.Content_Based:
            alignment = self.attention_mechanism(query, keys)
        else:
            alignment = self.attention_mechanism(query, keys, prev_alignment)

        # Mask padded positions before normalizing
        if mask is not None:
            alignment = alignment.masked_fill(mask, self.score_mask_value)

        # Normalize scores and compute the attention context vector
        alignment = F.softmax(alignment, dim=1)
        attention = torch.bmm(alignment.unsqueeze(1), values)
        attention = attention.squeeze(1)

        return attention, alignment


class MultiHeadAttentionLayer(nn.Module):
    def __init__(self, hid_dim: int, n_heads: int, dropout: float = 0.0):
        super().__init__()

        assert hid_dim % n_heads == 0

        self.hid_dim = hid_dim
        self.n_heads = n_heads
        self.head_dim = hid_dim // n_heads

        self.fc_q = nn.Linear(hid_dim, hid_dim)
        self.fc_k = nn.Linear(hid_dim, hid_dim)
        self.fc_v = nn.Linear(hid_dim, hid_dim)

        # The context vector is concatenated with the query, hence hid_dim * 2
        self.fc_o = nn.Linear(hid_dim * 2, hid_dim)

        self.use_dropout = dropout != 0.0
        if self.use_dropout:
            self.dropout = nn.Dropout(dropout)

        # Register the scale as a buffer so it follows the module across devices
        self.register_buffer(
            "scale", torch.sqrt(torch.tensor(self.head_dim, dtype=torch.float32))
        )
    def forward(self, query, key, value, mask=None):
        # query = [batch size, query len, hid dim]
        # key = [batch size, key len, hid dim]
        # value = [batch size, value len, hid dim]
        batch_size = query.shape[0]

        Q = self.fc_q(query)
        K = self.fc_k(key)
        V = self.fc_v(value)
        # Q = [batch size, query len, hid dim]
        # K = [batch size, key len, hid dim]
        # V = [batch size, value len, hid dim]

        # Split the projections into heads
        Q = Q.view(batch_size, -1, self.n_heads, self.head_dim).permute(0, 2, 1, 3)
        K = K.view(batch_size, -1, self.n_heads, self.head_dim).permute(0, 2, 1, 3)
        V = V.view(batch_size, -1, self.n_heads, self.head_dim).permute(0, 2, 1, 3)
        # Q = [batch size, n heads, query len, head dim]
        # K = [batch size, n heads, key len, head dim]
        # V = [batch size, n heads, value len, head dim]

        # Scaled dot-product scores
        energy = torch.matmul(Q, K.permute(0, 1, 3, 2)) / self.scale
        # energy = [batch size, n heads, query len, key len]

        if mask is not None:
            energy = energy.masked_fill(mask == 0, -float("inf"))

        attention = torch.softmax(energy, dim=-1)
        # attention = [batch size, n heads, query len, key len]

        if self.use_dropout:
            context_vector = torch.matmul(self.dropout(attention), V)
        else:
            context_vector = torch.matmul(attention, V)
        # context_vector = [batch size, n heads, query len, head dim]

        context_vector = context_vector.permute(0, 2, 1, 3).contiguous()
        # context_vector = [batch size, query len, n heads, head dim]

        context_vector = context_vector.view(batch_size, -1, self.hid_dim)
        # context_vector = [batch size, query len, hid dim]

        # Concatenate the query with the attended context, then project back
        x = torch.cat((query, context_vector), dim=-1)
        # x = [batch size, query len, hid dim * 2]

        x = self.fc_o(x)
        # x = [batch size, query len, hid dim]

        return x, attention
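

# Minimal smoke test sketching the expected shapes of the modules above. The
# batch size, sequence length, and dimensions below are illustrative
# assumptions, not values taken from any training configuration.
if __name__ == "__main__":
    batch, max_time, dim = 2, 7, 256

    query = torch.randn(batch, dim)
    keys = torch.randn(batch, max_time, dim)
    values = torch.randn(batch, max_time, dim)
    prev_alignment = torch.zeros(batch, max_time)

    wrapper = AttentionWrapper(AttentionType.LocationSensitive, attention_units=dim)
    context, alignment = wrapper(query, keys, values, prev_alignment=prev_alignment)
    print(context.shape, alignment.shape)  # (2, 256) (2, 7)

    mha = MultiHeadAttentionLayer(hid_dim=dim, n_heads=8, dropout=0.1)
    x, attention = mha(values, values, values)  # self-attention over the values
    print(x.shape, attention.shape)  # (2, 7, 256) (2, 8, 7, 7)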