Spaces:
Runtime error
Runtime error
File size: 1,301 Bytes
998b155 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 |
import torch
from torch.autograd import Variable
from torch import nn
from tacotron_pytorch.attention import BahdanauAttention, AttentionWrapper
from tacotron_pytorch.attention import get_mask_from_lengths
def test_attention_wrapper():
    """Smoke-test one step of ``AttentionWrapper`` driven by ``BahdanauAttention``.

    Builds random encoder outputs for a batch of two sequences of up to 100
    time steps, with memory lengths 100 and 50. It projects them through a
    linear "memory layer" and runs a single call of the attention-wrapped
    GRU cell, starting from all-zero attention and cell state. It then prints
    the sizes of the returned tensors.

    Finally it checks that every alignment row sums to 1 along its last
    dimension. The check uses a small tolerance because the sums are
    floating-point results.
    """
    batch_size = 2
    max_time = 100      # time steps in the random encoder outputs
    memory_dim = 256    # feature size of encoder outputs / processed memory
    query_dim = 128     # feature size of the query fed to the wrapper
    rnn_dim = 256       # hidden size of the wrapped GRU cell
    # Row sums are computed in floating point, so they are only ~1; an exact
    # ``== 1`` comparison can fail from rounding alone.
    tolerance = 1e-5

    encoder_outputs = Variable(torch.rand(batch_size, max_time, memory_dim))
    # First sequence uses all time steps, second one only half of them.
    memory_lengths = [max_time, max_time // 2]

    mask = get_mask_from_lengths(encoder_outputs, memory_lengths)
    print("Mask size:", mask.size())

    memory_layer = nn.Linear(memory_dim, memory_dim)
    query = Variable(torch.rand(batch_size, query_dim))

    # NOTE(review): assumes the constructor argument is the memory/attention
    # feature size (256 either way here) -- TODO confirm.
    attention_mechanism = BahdanauAttention(memory_dim)

    # Attention context + input
    rnn = nn.GRUCell(memory_dim + query_dim, rnn_dim)
    attention_rnn = AttentionWrapper(rnn, attention_mechanism)

    initial_attention = Variable(torch.zeros(batch_size, memory_dim))
    cell_state = Variable(torch.zeros(batch_size, rnn_dim))
    processed_memory = memory_layer(encoder_outputs)

    # NOTE(review): ``mask`` above is only printed; this call passes
    # mask=None together with memory_lengths. Confirm AttentionWrapper derives
    # the mask from memory_lengths itself; otherwise pass mask=mask here.
    cell_output, attention, alignment = attention_rnn(
        query, initial_attention, cell_state, encoder_outputs,
        processed_memory=processed_memory,
        mask=None, memory_lengths=memory_lengths)

    print("Cell output size:", cell_output.size())
    print("Attention output size:", attention.size())
    print("Alignment size:", alignment.size())

    row_sums = alignment.sum(-1)
    within_tolerance = ((row_sums - 1).abs() < tolerance).data.all()
    assert within_tolerance, "alignment rows must sum to 1"
# Run the smoke test as a top-level statement: it executes whenever this
# module is run or imported.
# NOTE(review): consider wrapping this in ``if __name__ == "__main__":`` so a
# test runner that collects ``test_attention_wrapper`` does not also run it
# at import time.
test_attention_wrapper()
|