import torch
from torch import nn
from torch.nn import functional as F
class BahdanauAttention(nn.Module):
    """Bahdanau-style additive attention: score(q, m) = v^T tanh(W q + m)."""

    def __init__(self, dim):
        super(BahdanauAttention, self).__init__()
        self.query_layer = nn.Linear(dim, dim, bias=False)
        self.tanh = nn.Tanh()
        self.v = nn.Linear(dim, 1, bias=False)

    def forward(self, query, processed_memory):
        """
        Args:
            query: (batch, 1, dim) or (batch, dim)
            processed_memory: (batch, max_time, dim)

        Returns:
            alignment: (batch, max_time) unnormalized attention energies
        """
        if query.dim() == 2:
            # Insert a time axis for broadcasting against processed_memory
            query = query.unsqueeze(1)
        # (batch, 1, dim)
        processed_query = self.query_layer(query)
        # (batch, max_time, 1)
        alignment = self.v(self.tanh(processed_query + processed_memory))
        # (batch, max_time)
        return alignment.squeeze(-1)
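
# Illustrative sketch (not part of the original module): a quick shape check for
# BahdanauAttention. The helper name `_demo_bahdanau_shapes` and the sizes used
# here are assumptions made for the example only.
def _demo_bahdanau_shapes():
    batch, max_time, dim = 2, 7, 256
    attn = BahdanauAttention(dim)
    query = torch.randn(batch, dim)           # decoder state, no time axis
    memory = torch.randn(batch, max_time, dim)
    # BahdanauAttention expects memory that has already been projected to `dim`;
    # here the raw memory is used directly, which is also what AttentionWrapper
    # falls back to when `processed_memory` is not given.
    alignment = attn(query, memory)
    assert alignment.shape == (batch, max_time)
    return alignment
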
def get_mask_from_lengths(memory, memory_lengths):
    """Get a padding mask from a list of sequence lengths.

    Args:
        memory: (batch, max_time, dim)
        memory_lengths: array-like of int, valid length of each memory sequence

    Returns:
        mask: (batch, max_time) bool tensor, True at padded (invalid) positions
    """
    mask = memory.new_zeros(memory.size(0), memory.size(1), dtype=torch.bool)
    for idx, length in enumerate(memory_lengths):
        mask[idx, :length] = True
    return ~mask
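
# Illustrative sketch (assumption, not in the original file): what the padding
# mask looks like for a small batch. True marks padded frames that should be
# excluded from the attention softmax.
def _demo_padding_mask():
    memory = torch.zeros(2, 5, 4)                   # (batch=2, max_time=5, dim=4)
    mask = get_mask_from_lengths(memory, [3, 5])
    # mask[0]: [False, False, False, True, True]    -> last two frames are padding
    # mask[1]: [False, False, False, False, False]  -> full-length sequence
    return mask
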
class AttentionWrapper(nn.Module):
    def __init__(self, rnn_cell, attention_mechanism,
                 score_mask_value=-float("inf")):
        super(AttentionWrapper, self).__init__()
        self.rnn_cell = rnn_cell
        self.attention_mechanism = attention_mechanism
        self.score_mask_value = score_mask_value

    def forward(self, query, attention, cell_state, memory,
                processed_memory=None, mask=None, memory_lengths=None):
        if processed_memory is None:
            processed_memory = memory
        if memory_lengths is not None and mask is None:
            mask = get_mask_from_lengths(memory, memory_lengths)

        # Concatenate the input query and the previous attention context
        cell_input = torch.cat((query, attention), -1)

        # Feed it to the RNN cell
        cell_output = self.rnn_cell(cell_input, cell_state)

        # Alignment energies
        # (batch, max_time)
        alignment = self.attention_mechanism(cell_output, processed_memory)

        if mask is not None:
            mask = mask.view(query.size(0), -1)
            # Mask out padded positions before the softmax
            alignment = alignment.masked_fill(mask, self.score_mask_value)

        # Normalize attention weights
        alignment = F.softmax(alignment, dim=-1)

        # Attention context vector
        # (batch, 1, dim)
        attention = torch.bmm(alignment.unsqueeze(1), memory)

        # (batch, dim)
        attention = attention.squeeze(1)

        return cell_output, attention, alignment
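
# Illustrative usage sketch (assumption, not in the original file): one decoder
# step through AttentionWrapper with a GRUCell. The sizes and the GRUCell choice
# are placeholders for the example only; the wrapper accepts any RNN cell whose
# input size matches the concatenation of the query and the previous context.
def _demo_attention_wrapper_step():
    batch, max_time, dim = 2, 7, 256
    wrapper = AttentionWrapper(nn.GRUCell(2 * dim, dim), BahdanauAttention(dim))
    query = torch.randn(batch, dim)           # current decoder input
    attention = torch.zeros(batch, dim)       # previous attention context
    cell_state = torch.zeros(batch, dim)      # previous GRU hidden state
    memory = torch.randn(batch, max_time, dim)
    cell_output, attention, alignment = wrapper(
        query, attention, cell_state, memory, memory_lengths=[5, 7])
    # cell_output: (batch, dim), attention: (batch, dim), alignment: (batch, max_time)
    return cell_output, attention, alignment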