# --------------------------------------------------------
# ArTST: Arabic Text and Speech Transformer (https://arxiv.org/abs/2310.16621)
# Github source: https://github.com/mbzuai-nlp/ArTST

# Based on speecht5, fairseq and espnet code bases
# https://github.com/microsoft/SpeechT5/tree/main/SpeechT5; https://github.com/pytorch/fairseq; https://github.com/espnet/espnet
# --------------------------------------------------------

import contextlib
import torch
import torch.nn as nn

import torch.nn.functional as F
from espnet.nets.pytorch_backend.tacotron2.decoder import Prenet as TacotronDecoderPrenet
from espnet.nets.pytorch_backend.transformer.embedding import PositionalEncoding
from espnet.nets.pytorch_backend.transformer.embedding import ScaledPositionalEncoding
from espnet.nets.pytorch_backend.nets_utils import make_non_pad_mask


class SpeechDecoderPrenet(nn.Module):
    """

    Args:
        in_channels (int): the number of input channels
        mid_channels (int): the number of intermediate channels
        out_channels (int): the number of output channels
        kernel_sizes (List[int]): the kernel size for each convolutional layer
    """

    def __init__(
        self,
        odim,
        args,
    ):
        super(SpeechDecoderPrenet, self).__init__()
        # Decoder input layer: a Tacotron2-style prenet followed by a linear
        # projection, or a plain linear layer when dprenet_layers == 0.
        if args.dprenet_layers != 0:
            decoder_input_layer = torch.nn.Sequential(
                TacotronDecoderPrenet(
                    idim=odim,
                    n_layers=args.dprenet_layers,
                    n_units=args.dprenet_units,
                    dropout_rate=args.dprenet_dropout_rate,
                ),
                torch.nn.Linear(args.dprenet_units, args.decoder_embed_dim),
            )
        else:
            decoder_input_layer = "linear"

        pos_enc_class = (
            ScaledPositionalEncoding if args.dec_use_scaled_pos_enc else PositionalEncoding
        )

        if decoder_input_layer == "linear":
            self.decoder_prenet = torch.nn.Sequential(
                torch.nn.Linear(odim, args.decoder_embed_dim),
                torch.nn.LayerNorm(args.decoder_embed_dim),
                torch.nn.Dropout(args.transformer_dec_dropout_rate),
                torch.nn.ReLU(),
                pos_enc_class(args.decoder_embed_dim, args.transformer_dec_positional_dropout_rate),
            )
        elif isinstance(decoder_input_layer, torch.nn.Module):
            self.decoder_prenet = torch.nn.Sequential(
                decoder_input_layer,
                pos_enc_class(
                    args.decoder_embed_dim,
                    args.transformer_dec_positional_dropout_rate,
                    max_len=args.max_speech_positions,
                ),
            )

        if args.spk_embed_integration_type == 'pre':
            self.spkembs_layer = torch.nn.Sequential(
                torch.nn.Linear(args.spk_embed_dim + args.decoder_embed_dim, args.decoder_embed_dim),
                torch.nn.ReLU(),
            )
        self.num_updates = 0
        self.freeze_decoder_updates = args.freeze_decoder_updates

    def forward(self, prev_output_tokens, tgt_lengths_in=None, spkembs=None):
        # Run without gradients until num_updates reaches freeze_decoder_updates;
        # contextlib.ExitStack() serves as a no-op context once training resumes.
        ft = self.freeze_decoder_updates <= self.num_updates
        with torch.no_grad() if not ft else contextlib.ExitStack():
            prev_output_tokens = self.decoder_prenet(prev_output_tokens)

            # Fuse the (L2-normalized) speaker embedding into every frame by
            # concatenation followed by a linear projection and ReLU.
            if spkembs is not None:
                spkembs = F.normalize(spkembs).unsqueeze(1).expand(-1, prev_output_tokens.size(1), -1)
                prev_output_tokens = self.spkembs_layer(torch.cat([prev_output_tokens, spkembs], dim=-1))

            # Padding mask over target frames: True marks padded positions.
            if tgt_lengths_in is not None:
                tgt_frames_mask = ~(self._source_mask(tgt_lengths_in).squeeze(1))
            else:
                tgt_frames_mask = None
            return prev_output_tokens, tgt_frames_mask

    def _source_mask(self, ilens):
        """Make masks for self-attention.
        Args:
            ilens (LongTensor or List): Batch of lengths (B,).
        Returns:
            Tensor: Mask tensor for self-attention.
                    dtype=torch.uint8 in PyTorch < 1.2
                    dtype=torch.bool in PyTorch >= 1.2
        Examples:
            >>> ilens = [5, 3]
            >>> self._source_mask(ilens)
            tensor([[[1, 1, 1, 1, 1]],
                    [[1, 1, 1, 0, 0]]], dtype=torch.uint8)
        """
        x_masks = make_non_pad_mask(ilens).to(next(self.parameters()).device)
        return x_masks.unsqueeze(-2)
    
    def set_num_updates(self, num_updates):
        """Set the number of parameters updates."""
        self.num_updates = num_updates
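

# --------------------------------------------------------
# Usage sketch (not part of the original file): a minimal example of how this
# prenet might be instantiated and called. The config fields below mirror the
# attributes read in __init__/forward; the concrete values are illustrative
# assumptions, not ArTST's released hyperparameters.
# --------------------------------------------------------
if __name__ == "__main__":
    from argparse import Namespace

    args = Namespace(
        dprenet_layers=2,              # depth of the Tacotron2-style prenet
        dprenet_units=256,
        dprenet_dropout_rate=0.5,
        decoder_embed_dim=768,
        dec_use_scaled_pos_enc=True,
        transformer_dec_dropout_rate=0.1,
        transformer_dec_positional_dropout_rate=0.1,
        max_speech_positions=4000,
        spk_embed_integration_type="pre",
        spk_embed_dim=512,
        freeze_decoder_updates=0,
    )
    prenet = SpeechDecoderPrenet(odim=80, args=args)

    # Batch of 2 utterances: 100 frames of 80-dim log-mel features each,
    # with true lengths 100 and 60 (the remainder is padding).
    feats = torch.randn(2, 100, 80)
    lengths = torch.tensor([100, 60])
    spkembs = torch.randn(2, 512)

    out, frames_mask = prenet(feats, tgt_lengths_in=lengths, spkembs=spkembs)
    print(out.shape)          # torch.Size([2, 100, 768])
    print(frames_mask.shape)  # torch.Size([2, 100]); True marks padded frames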