File size: 2,695 Bytes
1547a56
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
# --------------------------------------------------------
# ArTST: Arabic Text and Speech Transformer (https://arxiv.org/abs/2310.16621)
# Github source: https://github.com/mbzuai-nlp/ArTST

# Based on speecht5, fairseq and espnet code bases
# https://github.com/microsoft/SpeechT5/tree/main/SpeechT5; https://github.com/pytorch/fairseq; https://github.com/espnet/espnet
# --------------------------------------------------------

import contextlib
import torch
import torch.nn as nn

from espnet.nets.pytorch_backend.tacotron2.decoder import Postnet


class SpeechDecoderPostnet(nn.Module):
    """Project decoder hidden states to output feature frames and per-frame logits.

    A linear layer maps each decoder step to ``reduction_factor`` frames of
    ``odim`` features. A second linear layer produces one logit per output
    frame (presumably stop-token logits; the consumer is not visible here).
    An optional ESPnet ``Postnet`` adds a residual refinement to the
    projected features.

    Args:
        odim (int): dimensionality of each output feature frame.
        args: configuration namespace. The attributes read are
            ``decoder_embed_dim``, ``reduction_factor``, ``postnet_layers``
            (0 disables the postnet), ``postnet_chans``, ``postnet_filts``,
            ``use_batch_norm``, ``postnet_dropout_rate`` and
            ``freeze_decoder_updates``.
    """

    def __init__(
        self,
        odim,
        args,
    ):
        super().__init__()
        # Final projections: per decoder step, r frames of odim features and r logits.
        self.feat_out = torch.nn.Linear(args.decoder_embed_dim, odim * args.reduction_factor)
        self.prob_out = torch.nn.Linear(args.decoder_embed_dim, args.reduction_factor)

        # Optional postnet; postnet_layers == 0 disables it entirely.
        self.postnet = (
            None
            if args.postnet_layers == 0
            else Postnet(
                # NOTE(review): idim=0 appears unused by espnet's Postnet — confirm.
                idim=0,
                odim=odim,
                n_layers=args.postnet_layers,
                n_chans=args.postnet_chans,
                n_filts=args.postnet_filts,
                use_batch_norm=args.use_batch_norm,
                dropout_rate=args.postnet_dropout_rate,
            )
        )

        self.odim = odim
        # Advanced externally through set_num_updates().
        self.num_updates = 0
        # forward() runs under torch.no_grad() until num_updates reaches this value.
        self.freeze_decoder_updates = args.freeze_decoder_updates

    def forward(self, zs):
        """Compute feature frames before/after the postnet and per-frame logits.

        Args:
            zs (Tensor): decoder hidden states, presumably of shape
                (B, Lmax//r, decoder_embed_dim).

        Returns:
            tuple: ``(before_outs, after_outs, logits)``.
                ``before_outs`` and ``after_outs`` have shape
                (B, Lmax//r * r, odim). ``logits`` has shape (B, Lmax//r * r).
                ``after_outs`` equals ``before_outs`` when the postnet is disabled.
        """
        # While frozen, build no autograd graph. Otherwise use nullcontext so
        # any grad mode set by the caller (e.g. an outer no_grad) is left intact.
        fine_tune = self.freeze_decoder_updates <= self.num_updates
        grad_ctx = contextlib.nullcontext() if fine_tune else torch.no_grad()
        with grad_ctx:
            batch_size = zs.size(0)
            # (B, Lmax//r, odim * r) -> (B, Lmax//r * r, odim)
            before_outs = self.feat_out(zs).view(batch_size, -1, self.odim)
            # (B, Lmax//r, r) -> (B, Lmax//r * r)
            logits = self.prob_out(zs).view(batch_size, -1)
            if self.postnet is None:
                after_outs = before_outs
            else:
                # Postnet consumes channel-first input (B, odim, T); its output is added as a residual.
                residual = self.postnet(before_outs.transpose(1, 2)).transpose(1, 2)
                after_outs = before_outs + residual

        return before_outs, after_outs, logits

    def set_num_updates(self, num_updates):
        """Record the number of parameter updates performed so far.

        forward() compares this value against ``freeze_decoder_updates`` to
        decide whether gradients flow through this module.
        """
        self.num_updates = num_updates