import torch
import torch.nn as nn
import torch.nn.functional as F


class MOLAttention(nn.Module):
    """Discretized mixture of logistics (MOL) attention.

    See Section 5 of "MelNet: A Generative Model for Audio in the Frequency Domain"
    and the GMMv2b model in "Location-relative attention mechanisms for robust
    long-form speech synthesis".
    """
    def __init__(
        self,
        query_dim,
        r=1,
        M=5,
    ):
        """
        Args:
            query_dim: attention_rnn_dim.
            M: number of mixtures.
        """
        super().__init__()
        if r < 1:
            self.r = float(r)
        else:
            self.r = int(r)
        self.M = M
        self.score_mask_value = 0.0  # alpha_t holds probability mass, so mask with 0 rather than -inf
        self.eps = 1e-5
        # Position array (bin edges) for encoder time steps, filled in init_states
        self.J = None
        # Query layer predicting the raw mixture parameters [w_hat, sigma_hat, Delta_hat]
        self.query_layer = torch.nn.Sequential(
            nn.Linear(query_dim, 256, bias=True),
            nn.ReLU(),
            nn.Linear(256, 3*M, bias=True)
        )
        self.mu_prev = None
        self.initialize_bias()

    def initialize_bias(self):
        """Initialize the sigma and Delta biases so that, early in training,
        sigma starts near softplus(1.0) and the step size Delta starts near r."""
        # sigma
        torch.nn.init.constant_(self.query_layer[2].bias[self.M:2*self.M], 1.0)
        # Delta: softplus(1.8545) = 2.0; softplus(3.9815) = 4.0; softplus(0.5413) = 1.0
        # softplus(-0.432) = 0.5003
        if self.r == 2:
            torch.nn.init.constant_(self.query_layer[2].bias[2*self.M:3*self.M], 1.8545)
        elif self.r == 4:
            torch.nn.init.constant_(self.query_layer[2].bias[2*self.M:3*self.M], 3.9815)
        elif self.r == 1:
            torch.nn.init.constant_(self.query_layer[2].bias[2*self.M:3*self.M], 0.5413)
        else:
            torch.nn.init.constant_(self.query_layer[2].bias[2*self.M:3*self.M], -0.432)

    
    def init_states(self, memory):
        """Initialize mu_prev and J.

        Should be called by the decoder once per batch, before decoding starts.

        Args:
            memory: (B, T_enc, D_enc) encoder output.
        """
        B, T_enc, _ = memory.size()
        device = memory.device
        # Half-integer bin edges [0.5, 1.5, ..., T_enc + 1.5] used to discretize the
        # logistic CDF along the encoder axis.
        self.J = torch.arange(0, T_enc + 2.0, device=device) + 0.5
        # All mixture means start at position 0.
        self.mu_prev = torch.zeros(B, self.M, device=device)

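    # Math sketch of forward() (notation follows MelNet Sec. 5 / GMMv2b;
    # k indexes mixtures, j indexes encoder positions):
    #   sigma_k = softplus(sigma_hat_k),   Delta_k = softplus(Delta_hat_k)
    #   mu_k    = mu_prev_k + Delta_k                      (means only move forward)
    #   F(x)    = sigmoid((x - mu_k) / sigma_k)            (logistic CDF)
    #   alpha_j = sum_k w_k * (F(j + 0.5) - F(j - 0.5))    (discretized attention weight)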
    def forward(self, att_rnn_h, memory, memory_pitch=None, mask=None):
        """Compute the attention context for one decoder step.

        Args:
            att_rnn_h: attention RNN hidden state (B, query_dim).
            memory: encoder outputs (B, T_enc, D).
            memory_pitch: optional extra encoder features attended with the same weights.
            mask: boolean mask for padded positions (B, T_enc), True where padded.

        Returns:
            context (B, D)[, context_pitch], alpha_t (B, T_enc).
        """
        # Raw mixture parameters from the query: [B, 3M]
        mixture_params = self.query_layer(att_rnn_h)

        # Split into weights, scales and step sizes: each [B, M]
        w_hat = mixture_params[:, :self.M]
        sigma_hat = mixture_params[:, self.M:2*self.M]
        Delta_hat = mixture_params[:, 2*self.M:3*self.M]

        # Dropout to de-correlate attention heads
        w_hat = F.dropout(w_hat, p=0.5, training=self.training)  # NOTE(sx): needed?
        
        # Mixture parameters
        w = torch.softmax(w_hat, dim=-1) + self.eps
        sigma = F.softplus(sigma_hat) + self.eps
        Delta = F.softplus(Delta_hat)
        mu_cur = self.mu_prev + Delta
        # Bin edges for the current encoder length: (T_enc + 1,)
        j = self.J[:memory.size(1) + 1]

        # Logistic CDF F(j) = 1 / (1 + exp((mu - j) / sigma)) = sigmoid((j - mu) / sigma),
        # evaluated at the bin edges and weighted by the mixture weights: (B, M, T_enc + 1)
        phi_t = w.unsqueeze(-1) * torch.sigmoid(
            (j - mu_cur.unsqueeze(-1)) / sigma.unsqueeze(-1))
        
        # Discretize: the weight for each encoder step is the CDF mass between
        # consecutive bin edges, summed over mixtures.  (B, T_enc)
        alpha_t = torch.sum(phi_t, dim=1)
        alpha_t = alpha_t[:, 1:] - alpha_t[:, :-1]
        alpha_t[alpha_t == 0] = self.eps

        # Zero out weights on padded encoder positions
        if mask is not None:
            alpha_t = alpha_t.masked_fill(mask, self.score_mask_value)

        # Weighted sum of encoder outputs: (B, D)
        context = torch.bmm(alpha_t.unsqueeze(1), memory).squeeze(1)
        if memory_pitch is not None:
            context_pitch = torch.bmm(alpha_t.unsqueeze(1), memory_pitch).squeeze(1)

        self.mu_prev = mu_cur
        
        if memory_pitch is not None:
            return context, context_pitch, alpha_t
        return context, alpha_t
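

if __name__ == "__main__":
    # Minimal usage sketch (shapes here are assumed for illustration only).
    # A decoder is expected to call init_states() once per batch, then forward()
    # at every decoder step with the current attention-RNN hidden state.
    B, T_enc, D_enc, query_dim = 2, 50, 512, 1024
    attn = MOLAttention(query_dim=query_dim, r=1, M=5)

    memory = torch.randn(B, T_enc, D_enc)
    # True marks padded encoder frames; here the second item is padded after frame 40.
    mask = torch.zeros(B, T_enc, dtype=torch.bool)
    mask[1, 40:] = True

    attn.init_states(memory)
    att_rnn_h = torch.randn(B, query_dim)
    for _ in range(3):  # a few decoder steps; mu_prev advances monotonically
        context, alpha_t = attn(att_rnn_h, memory, mask=mask)
    print(context.shape, alpha_t.shape)  # torch.Size([2, 512]) torch.Size([2, 50])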