# Copyright (c) 2024 Amphion.
#
# This code is modified from https://github.com/imdanboy/jets/blob/main/espnet2/gan_tts/jets/length_regulator.py
# Licensed under Apache License 2.0

import torch


class GaussianUpsampling(torch.nn.Module):
    """
    Gaussian upsampling with fixed temperature as in:
    https://arxiv.org/abs/2010.04301
    """

    def __init__(self, delta=0.1):
        super().__init__()
        self.delta = delta

    def forward(self, hs, ds, h_masks=None, d_masks=None):
        """
        Args:
            hs (Tensor): Batched hidden state to be expanded (B, T_text, adim).
            ds (Tensor): Batched token duration in frames (B, T_text).
            h_masks (Tensor): Mask tensor over output frames (B, T_feats).
            d_masks (Tensor): Mask tensor over input tokens (B, T_text).
        Returns:
            Tensor: Expanded hidden state (B, T_feats, adim).
        """
        B = ds.size(0)
        device = ds.device

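        # Number of output frames: taken from the feature mask if provided,
        # otherwise from the summed durations.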
        if h_masks is None:
            T_feats = ds.sum().int()
        else:
            T_feats = h_masks.size(-1)
        t = torch.arange(0, T_feats).unsqueeze(0).repeat(B, 1).to(device).float()
        if h_masks is not None:
            t = t * h_masks.float()

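        # Center of each token on the output frame axis: cumulative duration
        # minus half of the token's own duration.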
        c = ds.cumsum(dim=-1) - ds / 2
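        # Unnormalized Gaussian log-weights between every output frame t and
        # every token center c, with fixed temperature delta.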
        energy = -1 * self.delta * (t.unsqueeze(-1) - c.unsqueeze(1)) ** 2
        if d_masks is not None:
            energy = energy.masked_fill(
                ~(d_masks.unsqueeze(1).repeat(1, T_feats, 1)), -float("inf")
            )

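        # Normalize over the token axis and aggregate hidden states per frame.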
        p_attn = torch.softmax(energy, dim=2)  # (B, T_feats, T_text)
        hs = torch.matmul(p_attn, hs)
        return hs
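

if __name__ == "__main__":
    # Minimal usage sketch (illustrative only, not part of the upstream module):
    # expand two token-level hidden states with durations of 2 and 3 frames into
    # a 5-frame sequence. Shapes follow the forward() docstring above.
    upsampler = GaussianUpsampling(delta=0.1)
    hs = torch.randn(1, 2, 4)        # (B, T_text, adim)
    ds = torch.tensor([[2.0, 3.0]])  # (B, T_text), per-token durations in frames
    out = upsampler(hs, ds)          # (B, T_feats, adim) with T_feats = 2 + 3 = 5
    print(out.shape)                 # torch.Size([1, 5, 4])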