# Spaces: Runtime error
import torch | |
from torch.autograd import Variable | |
from torch import nn | |
from tacotron_pytorch.attention import BahdanauAttention, AttentionWrapper | |
from tacotron_pytorch.attention import get_mask_from_lengths | |
def test_attention_wrapper():
    """Smoke-test ``AttentionWrapper`` around a ``GRUCell`` with ``BahdanauAttention``.

    Runs a single step on random inputs, prints the sizes of the mask and of
    the three returned tensors, and checks that the returned alignment sums to
    1 along its last dimension (within a small floating-point tolerance).

    Raises:
        AssertionError: If any alignment row does not sum to approximately 1.
    """
    batch_size = 2
    max_time = 100
    memory_dim = 256
    query_dim = 128
    # Absolute tolerance for the row-sum check. The sum is accumulated in
    # floating point, so exact equality with 1 is not a reliable test.
    sum_tolerance = 1e-5

    encoder_outputs = Variable(torch.rand(batch_size, max_time, memory_dim))
    # One length per batch item; the second one is shorter than max_time.
    memory_lengths = [100, 50]

    # The mask is built here only to report its size; the wrapper call below
    # receives mask=None together with memory_lengths.
    # NOTE(review): presumably AttentionWrapper derives the mask itself from
    # memory_lengths in that case -- confirm against tacotron_pytorch.attention.
    mask = get_mask_from_lengths(encoder_outputs, memory_lengths)
    print("Mask size:", mask.size())

    memory_layer = nn.Linear(memory_dim, memory_dim)
    query = Variable(torch.rand(batch_size, query_dim))

    attention_mechanism = BahdanauAttention(memory_dim)

    # Attention context + input
    rnn = nn.GRUCell(memory_dim + query_dim, memory_dim)
    attention_rnn = AttentionWrapper(rnn, attention_mechanism)

    initial_attention = Variable(torch.zeros(batch_size, memory_dim))
    cell_state = Variable(torch.zeros(batch_size, memory_dim))

    processed_memory = memory_layer(encoder_outputs)

    cell_output, attention, alignment = attention_rnn(
        query, initial_attention, cell_state, encoder_outputs,
        processed_memory=processed_memory,
        mask=None, memory_lengths=memory_lengths)

    print("Cell output size:", cell_output.size())
    print("Attention output size:", attention.size())
    print("Alignment size:", alignment.size())

    row_sums = alignment.sum(-1)
    assert ((row_sums - 1).abs() < sum_tolerance).data.all(), \
        "alignment does not sum to 1 along the last dimension"
if __name__ == "__main__":
    # Run the check only when this file is executed as a script. Importing the
    # module (e.g. during test collection) no longer triggers it as a side
    # effect.
    test_attention_wrapper()