# BANAO-Task2-Text-to-speech/tests/test_attention.py

import torch
from torch import nn

from tacotron_pytorch.attention import (
    AttentionWrapper,
    BahdanauAttention,
    get_mask_from_lengths,
)


def test_attention_wrapper():
    """Smoke test: check output shapes and that the alignments form a
    valid probability distribution over encoder timesteps."""
    B = 2  # batch size
    # Fake encoder outputs: (batch, max_time, encoder_dim).
    encoder_outputs = torch.rand(B, 100, 256)
    # The second sequence is only 50 frames long; the rest is padding.
    memory_lengths = [100, 50]
    mask = get_mask_from_lengths(encoder_outputs, memory_lengths)
    print("Mask size:", mask.size())
    # The memory is projected once up front and reused for every query.
    memory_layer = nn.Linear(256, 256)
    query = torch.rand(B, 128)
    attention_mechanism = BahdanauAttention(256)
    # The GRU consumes the previous attention context (256) concatenated
    # with the current input (128).
    rnn = nn.GRUCell(256 + 128, 256)
    attention_rnn = AttentionWrapper(rnn, attention_mechanism)
    initial_attention = torch.zeros(B, 256)
    cell_state = torch.zeros(B, 256)
    processed_memory = memory_layer(encoder_outputs)
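    # Added check (not in the original test): the projection keeps the
    # (batch, max_time, dim) layout of the encoder outputs.
    assert processed_memory.size() == encoder_outputs.size()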
    # Passing memory_lengths lets the wrapper build the padding mask
    # internally, so mask=None here.
    cell_output, attention, alignment = attention_rnn(
        query, initial_attention, cell_state, encoder_outputs,
        processed_memory=processed_memory,
        mask=None, memory_lengths=memory_lengths)
print("Cell output size:", cell_output.size())
print("Attention output size:", attention.size())
print("Alignment size:", alignment.size())
    # Softmax alignments sum to 1 only up to floating-point error, so use
    # allclose rather than the original exact `== 1` comparison, which can
    # fail spuriously.
    assert torch.allclose(alignment.sum(-1), torch.ones(B), atol=1e-5)
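

def test_attention_wrapper_multiple_steps():
    """A sketch added for illustration (not part of the original file):
    drive the wrapper for a few decoder steps, feeding the attention
    context and cell state back in, roughly the way Tacotron's decoder
    would. It assumes only the forward signature exercised above."""
    B = 2
    encoder_outputs = torch.rand(B, 100, 256)
    memory_lengths = [100, 50]
    memory_layer = nn.Linear(256, 256)
    processed_memory = memory_layer(encoder_outputs)
    attention_rnn = AttentionWrapper(
        nn.GRUCell(256 + 128, 256), BahdanauAttention(256))
    attention = torch.zeros(B, 256)
    cell_state = torch.zeros(B, 256)
    for _ in range(3):
        query = torch.rand(B, 128)
        cell_state, attention, alignment = attention_rnn(
            query, attention, cell_state, encoder_outputs,
            processed_memory=processed_memory,
            mask=None, memory_lengths=memory_lengths)
        # Alignments should stay a valid distribution at every step.
        assert torch.allclose(alignment.sum(-1), torch.ones(B), atol=1e-5)
        # Assuming the wrapper masks scores past memory_lengths (it is
        # given memory_lengths above), the second batch item should put
        # no probability mass beyond its true length of 50.
        assert alignment[1, 50:].abs().max().item() < 1e-4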


if __name__ == "__main__":
    test_attention_wrapper()
    test_attention_wrapper_multiple_steps()