import torch
from torch import nn
from torch.nn import functional as F
class BahdanauAttention(nn.Module):
    def __init__(self, dim):
        super(BahdanauAttention, self).__init__()
        self.query_layer = nn.Linear(dim, dim, bias=False)
        self.tanh = nn.Tanh()
        self.v = nn.Linear(dim, 1, bias=False)

    def forward(self, query, processed_memory):
        """
        Args:
            query: (batch, 1, dim) or (batch, dim)
            processed_memory: (batch, max_time, dim)
        """
        if query.dim() == 2:
            # Insert a time axis for broadcasting
            query = query.unsqueeze(1)
        # (batch, 1, dim)
        processed_query = self.query_layer(query)
        # (batch, max_time, 1)
        alignment = self.v(self.tanh(processed_query + processed_memory))
        # (batch, max_time)
        return alignment.squeeze(-1)
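
# Shape sketch (illustrative values, not from the original code): with
# dim=16, batch=2, max_time=5,
#     attn = BahdanauAttention(16)
#     scores = attn(torch.randn(2, 16), torch.randn(2, 5, 16))
# the (batch, dim) query is broadcast against the processed memory, and
# `scores` has shape (2, 5): one unnormalized score per memory time step.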

def get_mask_from_lengths(memory, memory_lengths):
    """Get a boolean mask tensor from a list of sequence lengths.

    Args:
        memory: (batch, max_time, dim)
        memory_lengths: array-like of per-example valid lengths

    Returns:
        (batch, max_time) bool tensor that is True at padded positions.
    """
    mask = torch.zeros(memory.size(0), memory.size(1),
                       dtype=torch.bool, device=memory.device)
    for idx, length in enumerate(memory_lengths):
        mask[idx, :length] = True
    return ~mask
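
# Mask sketch (hypothetical lengths, not from the original code): for a
# memory of shape (2, 4, dim) and memory_lengths=[4, 2], the returned mask is
#     [[False, False, False, False],
#      [False, False, True,  True ]]
# i.e. True exactly at the padded positions that the wrapper should ignore.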

class AttentionWrapper(nn.Module):
    def __init__(self, rnn_cell, attention_mechanism,
                 score_mask_value=-float("inf")):
        super(AttentionWrapper, self).__init__()
        self.rnn_cell = rnn_cell
        self.attention_mechanism = attention_mechanism
        self.score_mask_value = score_mask_value

    def forward(self, query, attention, cell_state, memory,
                processed_memory=None, mask=None, memory_lengths=None):
        if processed_memory is None:
            processed_memory = memory
        if memory_lengths is not None and mask is None:
            mask = get_mask_from_lengths(memory, memory_lengths)

        # Concatenate the input query and the previous attention context
        cell_input = torch.cat((query, attention), -1)

        # Feed it to the RNN cell
        cell_output = self.rnn_cell(cell_input, cell_state)

        # Alignment scores
        # (batch, max_time)
        alignment = self.attention_mechanism(cell_output, processed_memory)
        if mask is not None:
            mask = mask.view(query.size(0), -1)
            alignment = alignment.masked_fill(mask, self.score_mask_value)

        # Normalize attention weights over the time axis
        alignment = F.softmax(alignment, dim=-1)
        # Attention context vector
        # (batch, 1, dim)
        attention = torch.bmm(alignment.unsqueeze(1), memory)

        # (batch, dim)
        attention = attention.squeeze(1)

        return cell_output, attention, alignment
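
if __name__ == "__main__":
    # Minimal usage sketch. The wrapped cell type, dimensions, and lengths
    # below are assumptions for illustration (a GRUCell with dim=16, batch=2,
    # max_time=5); they are not prescribed by the modules above.
    dim, batch, max_time = 16, 2, 5

    attention_mechanism = BahdanauAttention(dim)
    # The cell input is the concatenation of the query and the previous
    # attention context, hence input size 2 * dim.
    rnn_cell = nn.GRUCell(2 * dim, dim)
    wrapper = AttentionWrapper(rnn_cell, attention_mechanism)

    query = torch.randn(batch, dim)             # decoder input for one step
    prev_attention = torch.zeros(batch, dim)    # initial attention context
    cell_state = torch.zeros(batch, dim)        # initial GRU hidden state
    memory = torch.randn(batch, max_time, dim)  # encoder outputs

    # processed_memory is left as None, so the wrapper uses memory directly.
    cell_output, attention, alignment = wrapper(
        query, prev_attention, cell_state, memory,
        memory_lengths=[max_time, 3])

    print(cell_output.shape)  # torch.Size([2, 16])
    print(attention.shape)    # torch.Size([2, 16])
    print(alignment.shape)    # torch.Size([2, 5])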