import torch
import torch.nn as nn

# from tts_encode import tts_encode


def calculate_remaining_lengths(mel_lengths):
    """For each frame position, return how many frames remain after it (0 for padding)."""
    B = mel_lengths.shape[0]
    max_L = mel_lengths.max().item()  # Maximum mel length in the batch

    # Create a range tensor: shape (max_L,), [0, 1, 2, ..., max_L-1]
    range_tensor = torch.arange(max_L, device=mel_lengths.device).expand(B, max_L)

    # Compute targets using broadcasting: (L-1) - range_tensor
    remain_lengths = (mel_lengths[:, None] - 1 - range_tensor).clamp(min=0)

    return remain_lengths
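
# Illustrative example (hypothetical values, not from the original file): for
# mel_lengths = torch.tensor([3, 5]) the function returns
#   tensor([[2, 1, 0, 0, 0],
#           [4, 3, 2, 1, 0]])
# i.e. the count of frames remaining after each position, clamped to 0 for
# positions past the end of each sequence.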


class PositionalEncoding(nn.Module):
    """Standard sinusoidal positional encoding added to the input embeddings."""

    def __init__(self, hidden_dim, max_len=4096):
        super().__init__()
        pe = torch.zeros(max_len, hidden_dim)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, hidden_dim, 2).float()
            * (-torch.log(torch.tensor(10000.0)) / hidden_dim)
        )
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        # Register as a non-persistent buffer so the table follows the module
        # across devices (.to/.cuda) without being written to checkpoints.
        self.register_buffer("pe", pe.unsqueeze(0), persistent=False)  # (1, max_len, hidden_dim)

    def forward(self, x):
        # x: (B, L, hidden_dim); add the encoding for the first L positions.
        x = x + self.pe[:, : x.size(1)]
        return x

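# The table above implements the standard Transformer sinusoid:
#   PE(pos, 2i)   = sin(pos / 10000^(2i / hidden_dim))
#   PE(pos, 2i+1) = cos(pos / 10000^(2i / hidden_dim))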

class SpeechLengthPredictor(nn.Module):
    """Predict, for each mel frame, how many frames remain in the utterance.

    A Transformer text encoder provides the memory for a causally masked mel
    decoder; a linear head then emits one length logit per mel frame.
    """

    def __init__(
        self,
        vocab_size=2545,
        n_mel=100,
        hidden_dim=256,
        n_text_layer=4,
        n_cross_layer=4,
        n_head=8,
        output_dim=1,
    ):
        super().__init__()

        # Text Encoder: Embedding + Transformer Layers
        self.text_embedder = nn.Embedding(
            vocab_size + 1, hidden_dim, padding_idx=vocab_size
        )
        self.text_pe = PositionalEncoding(hidden_dim)
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=hidden_dim,
            nhead=n_head,
            dim_feedforward=hidden_dim * 2,
            batch_first=True,
        )
        self.text_encoder = nn.TransformerEncoder(
            encoder_layer, num_layers=n_text_layer
        )

        # Mel Spectrogram Embedder
        self.mel_embedder = nn.Linear(n_mel, hidden_dim)
        self.mel_pe = PositionalEncoding(hidden_dim)

        # Transformer Decoder Layers with Cross-Attention in Every Layer
        decoder_layer = nn.TransformerDecoderLayer(
            d_model=hidden_dim,
            nhead=n_head,
            dim_feedforward=hidden_dim * 2,
            batch_first=True,
        )
        self.decoder = nn.TransformerDecoder(decoder_layer, num_layers=n_cross_layer)

        # Final Classification Layer
        self.predictor = nn.Linear(hidden_dim, output_dim)

    def forward(self, text_ids, mel):
        # Encode text
        text_embedded = self.text_pe(self.text_embedder(text_ids))
        text_features = self.text_encoder(text_embedded)  # (B, L_text, D)

        # Encode Mel spectrogram
        mel_features = self.mel_pe(self.mel_embedder(mel))  # (B, L_mel, D)

        # Causal mask for the decoder: position i may only attend to mel frames <= i
        seq_len = mel_features.size(1)
        causal_mask = torch.triu(
            torch.ones(seq_len, seq_len, dtype=torch.bool, device=mel.device),
            diagonal=1,
        )
        # causal_mask = torch.triu(
        #     torch.full((seq_len, seq_len), float('-inf'), device=mel.device), diagonal=1
        # )

        # Transformer Decoder with Cross-Attention in Each Layer
        decoder_out = self.decoder(mel_features, text_features, tgt_mask=causal_mask)

        # Per-frame length prediction: one scalar logit per mel frame
        length_logits = self.predictor(decoder_out).squeeze(-1)  # (B, L_mel)
        return length_logits
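

if __name__ == "__main__":
    # Minimal smoke test (illustrative sketch; the batch size, sequence
    # lengths, and mel_lengths below are arbitrary, not from the original file).
    torch.manual_seed(0)
    B, L_text, L_mel, n_mel, vocab_size = 2, 12, 40, 100, 2545

    model = SpeechLengthPredictor(vocab_size=vocab_size, n_mel=n_mel)
    text_ids = torch.randint(0, vocab_size, (B, L_text))  # (B, L_text) token ids
    mel = torch.randn(B, L_mel, n_mel)                    # (B, L_mel, n_mel) mel frames

    length_logits = model(text_ids, mel)
    print(length_logits.shape)  # torch.Size([2, 40]): one logit per mel frame

    # Regression targets: frames remaining after each position.
    mel_lengths = torch.tensor([L_mel, L_mel // 2])
    targets = calculate_remaining_lengths(mel_lengths)
    print(targets.shape)  # torch.Size([2, 40])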