# -*- coding: utf-8 -*-
""" Define the Attention Layer of the model.
"""

from __future__ import print_function, division

import torch
from torch.autograd import Variable
from torch.nn import Module
from torch.nn.parameter import Parameter


class Attention(Module):
    """
    Computes a weighted average of the different channels across timesteps.
    Uses 1 parameter per channel to compute the attention value for a single timestep.
    """

    def __init__(self, attention_size, return_attention=False):
        """ Initialize the attention layer

        # Arguments:
            attention_size: Size of the attention vector.
            return_attention: If True, the output will include the attention weight
                              for each input token used for the prediction.
        """
        super(Attention, self).__init__()
        self.return_attention = return_attention
        self.attention_size = attention_size
        self.attention_vector = Parameter(torch.FloatTensor(attention_size))
        self.attention_vector.data.normal_(std=0.05)  # Initialize attention vector

    def __repr__(self):
        s = '{name}({attention_size}, return attention={return_attention})'
        return s.format(name=self.__class__.__name__, **self.__dict__)

    def forward(self, inputs, input_lengths):
        """ Forward pass.

        # Arguments:
            inputs (torch.Tensor): Tensor of input sequences
                with shape (batch_size, max_len, attention_size).
            input_lengths (torch.LongTensor): Lengths of the sequences.

        # Return:
            Tuple of (representations, attentions if self.return_attention else None).
        """
        # Unnormalized attention scores: dot product of each timestep with the
        # attention vector. Subtracting the max before exponentiating keeps the
        # softmax numerically stable.
        logits = inputs.matmul(self.attention_vector)
        unnorm_ai = (logits - logits.max()).exp()

        # Compute a mask for the attention on the padded sequences
        # See e.g. https://discuss.pytorch.org/t/self-attention-on-words-and-masking/5671/5
        max_len = unnorm_ai.size(1)
        idxes = torch.arange(0, max_len, out=torch.LongTensor(max_len)).unsqueeze(0)
        mask = Variable((idxes < input_lengths.unsqueeze(1)).float())

        # Apply the mask and renormalize the attention scores (weights),
        # moving the mask to the GPU when the module's parameters live there
        if self.attention_vector.device.type == "cuda":
            masked_weights = unnorm_ai * mask.cuda()
        else:
            masked_weights = unnorm_ai * mask
        att_sums = masked_weights.sum(dim=1, keepdim=True)  # sums per sequence
        attentions = masked_weights.div(att_sums)

        # Apply the attention weights to every channel of every timestep
        weighted = torch.mul(inputs, attentions.unsqueeze(-1).expand_as(inputs))

        # Get the final fixed vector representations of the sentences
        representations = weighted.sum(dim=1)

        return (representations, attentions if self.return_attention else None)
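

# ---------------------------------------------------------------------------
# Usage sketch (not part of the original module): a minimal example of how the
# Attention layer above could be applied to a padded batch of sequence
# encodings, e.g. the hidden states of an LSTM. The tensor shapes, the dummy
# data, and the __main__ guard are illustrative assumptions only.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    batch_size, max_len, hidden_size = 4, 10, 8

    # Padded batch of per-timestep encodings (hypothetical random data).
    encodings = torch.randn(batch_size, max_len, hidden_size)
    # True (unpadded) length of each sequence in the batch.
    lengths = torch.LongTensor([10, 7, 5, 3])

    layer = Attention(attention_size=hidden_size, return_attention=True)
    representations, attentions = layer(encodings, lengths)

    print(representations.shape)  # (batch_size, hidden_size)
    print(attentions.shape)       # (batch_size, max_len); zero weight on padding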