import torch
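# NOTE: Variable is the pre-0.4 PyTorch autograd wrapper; newer PyTorch keeps
# it for backwards compatibility, where it simply returns a plain Tensor.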
from torch.autograd import Variable
from torch import nn

from tacotron_pytorch.attention import (
    AttentionWrapper, BahdanauAttention, get_mask_from_lengths)


def test_attention_wrapper():
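    # Smoke test: run a single step of an attention-wrapped GRU cell on random
    # inputs and check the shapes of its outputs.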
    B = 2  # batch size

    # Fake encoder outputs: (batch, max_time, encoder_dim)
    encoder_outputs = Variable(torch.rand(B, 100, 256))
    # Number of valid (non-padded) frames for each batch element
    memory_lengths = [100, 50]

    # Padding mask derived from the lengths; only its shape is inspected here,
    # since the attention call below is made with mask=None.
    mask = get_mask_from_lengths(encoder_outputs, memory_lengths)
    print("Mask size:", mask.size())

    # Linear layer that projects the memory (encoder outputs) for the attention
    memory_layer = nn.Linear(256, 256)
    # Decoder-side query vector: (batch, query_dim)
    query = Variable(torch.rand(B, 128))

    attention_mechanism = BahdanauAttention(256)

    # GRU cell input: 256-dim attention context concatenated with the
    # 128-dim query/input
    rnn = nn.GRUCell(256 + 128, 256)

    attention_rnn = AttentionWrapper(rnn, attention_mechanism)
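    # Zero-initialized attention context and GRU cell state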
    initial_attention = Variable(torch.zeros(B, 256))
    cell_state = Variable(torch.zeros(B, 256))

    # Pre-compute the projected memory used by the attention mechanism
    processed_memory = memory_layer(encoder_outputs)

    # One attention + RNN step: returns the new cell state, the attention
    # context vector and the alignment over encoder time steps.
    cell_output, attention, alignment = attention_rnn(
        query, initial_attention, cell_state, encoder_outputs,
        processed_memory=processed_memory,
        mask=None, memory_lengths=memory_lengths)

    print("Cell output size:", cell_output.size())
    print("Attention output size:", attention.size())
    print("Alignment size:", alignment.size())

    # Alignments come from a softmax, so each row should sum to ~1; exact
    # float equality would be brittle.
    assert ((alignment.sum(-1) - 1.0).abs() < 1e-4).data.all()


test_attention_wrapper()
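

# The test above builds a padding mask but calls the wrapper with mask=None.
# The variant below is a hypothetical sketch (not part of the original test)
# that passes the mask explicitly, assuming AttentionWrapper uses it to mask
# the attention scores for padded encoder frames. It only checks the alignment
# shape and that the alignments still form a probability distribution; one
# could additionally check that padded positions receive (near-)zero weight.
def test_attention_wrapper_with_mask():
    B = 2
    encoder_outputs = Variable(torch.rand(B, 100, 256))
    memory_lengths = [100, 50]
    # Padding mask derived from the sequence lengths
    mask = get_mask_from_lengths(encoder_outputs, memory_lengths)

    memory_layer = nn.Linear(256, 256)
    query = Variable(torch.rand(B, 128))
    attention_mechanism = BahdanauAttention(256)
    rnn = nn.GRUCell(256 + 128, 256)
    attention_rnn = AttentionWrapper(rnn, attention_mechanism)

    initial_attention = Variable(torch.zeros(B, 256))
    cell_state = Variable(torch.zeros(B, 256))
    processed_memory = memory_layer(encoder_outputs)

    cell_output, attention, alignment = attention_rnn(
        query, initial_attention, cell_state, encoder_outputs,
        processed_memory=processed_memory,
        mask=mask, memory_lengths=memory_lengths)

    # One alignment weight per encoder time step
    assert alignment.size() == (B, 100)
    # A masked softmax should still normalize over the un-padded frames
    assert ((alignment.sum(-1) - 1.0).abs() < 1e-4).data.all()


test_attention_wrapper_with_mask()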