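"""Unit tests for the SpeedySpeech layers and model.

Covers the Encoder, Decoder, and DurationPredictor layers as well as the
full SpeedySpeech forward pass, with and without speaker conditioning.
Layer inputs follow a [batch, channels, time] layout.
"""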
import torch

from TTS.tts.layers.speedy_speech.encoder import Encoder
from TTS.tts.layers.speedy_speech.decoder import Decoder
from TTS.tts.layers.speedy_speech.duration_predictor import DurationPredictor
from TTS.tts.utils.generic_utils import sequence_mask
from TTS.tts.models.speedy_speech import SpeedySpeech


device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")


def test_encoder():
    input_dummy = torch.rand(8, 14, 37).to(device)
    input_lengths = torch.randint(31, 37, (8, )).long().to(device)
    input_lengths[-1] = 37
    input_mask = torch.unsqueeze(
        sequence_mask(input_lengths, input_dummy.size(2)), 1).to(device)

    # residual bn conv encoder
    layer = Encoder(out_channels=11,
                    in_hidden_channels=14,
                    encoder_type='residual_conv_bn').to(device)
    output = layer(input_dummy, input_mask)
    assert list(output.shape) == [8, 11, 37]

    # transformer encoder
    layer = Encoder(out_channels=11,
                    in_hidden_channels=14,
                    encoder_type='transformer',
                    encoder_params={
                        'hidden_channels_ffn': 768,
                        'num_heads': 2,
                        "kernel_size": 3,
                        "dropout_p": 0.1,
                        "num_layers": 6,
                        "rel_attn_window_size": 4,
                        "input_length": None
                    }).to(device)
    output = layer(input_dummy, input_mask)
    assert list(output.shape) == [8, 11, 37]


def test_decoder():
    input_dummy = torch.rand(8, 128, 37).to(device)
    input_lengths = torch.randint(31, 37, (8, )).long().to(device)
    input_lengths[-1] = 37

    input_mask = torch.unsqueeze(
        sequence_mask(input_lengths, input_dummy.size(2)), 1).to(device)

    # residual bn conv decoder
    layer = Decoder(out_channels=11, in_hidden_channels=128).to(device)
    output = layer(input_dummy, input_mask)
    assert list(output.shape) == [8, 11, 37]

    # transformer decoder
    layer = Decoder(out_channels=11,
                    in_hidden_channels=128,
                    decoder_type='transformer',
                    decoder_params={
                        'hidden_channels_ffn': 128,
                        'num_heads': 2,
                        "kernel_size": 3,
                        "dropout_p": 0.1,
                        "num_layers": 8,
                        "rel_attn_window_size": 4,
                        "input_length": None
                    }).to(device)
    output = layer(input_dummy, input_mask)
    assert list(output.shape) == [8, 11, 37]

    # wavenet decoder
    layer = Decoder(out_channels=11,
                    in_hidden_channels=128,
                    decoder_type='wavenet',
                    decoder_params={
                        "num_blocks": 12,
                        "hidden_channels": 192,
                        "kernel_size": 5,
                        "dilation_rate": 1,
                        "num_layers": 4,
                        "dropout_p": 0.05
                    }).to(device)
    output = layer(input_dummy, input_mask)
    assert list(output.shape) == [8, 11, 37]


def test_duration_predictor():
    input_dummy = torch.rand(8, 128, 27).to(device)
    input_lengths = torch.randint(20, 27, (8, )).long().to(device)
    input_lengths[-1] = 27

    x_mask = torch.unsqueeze(sequence_mask(input_lengths, input_dummy.size(2)),
                             1).to(device)

    layer = DurationPredictor(hidden_channels=128).to(device)

    output = layer(input_dummy, x_mask)
    assert list(output.shape) == [8, 1, 27]


def test_speedy_speech():
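    # B: batch size, T_en: input (character) sequence length,
    # T_de: output (decoder frame) sequence length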
    num_chars = 7
    B = 8
    T_en = 37
    T_de = 74

    x_dummy = torch.randint(0, num_chars, (B, T_en)).long().to(device)
    x_lengths = torch.randint(31, T_en, (B, )).long().to(device)
    x_lengths[-1] = T_en

    # set random per-character durations and rescale them so that the
    # longest total duration matches the decoder length T_de
    durations = torch.randint(1, 4, (B, T_en))
    durations = durations * (T_de / durations.sum(1)).unsqueeze(1)
    durations = durations.to(torch.long).to(device)
    # pad the first character's duration so the max total hits T_de exactly
    max_dur = durations.sum(1).max()
    durations[:, 0] += T_de - max_dur if T_de > max_dur else 0

    y_lengths = durations.sum(1)

    model = SpeedySpeech(num_chars, out_channels=80, hidden_channels=128).to(device)

    # forward pass
    o_de, o_dr, attn = model(x_dummy, x_lengths, y_lengths, durations)

    assert list(o_de.shape) == [B, 80, T_de], f"{list(o_de.shape)}"
    assert list(attn.shape) == [B, T_de, T_en]
    assert list(o_dr.shape) == [B, T_en]

    # with speaker embedding
    model = SpeedySpeech(num_chars,
                         out_channels=80,
                         hidden_channels=128,
                         num_speakers=10,
                         c_in_channels=256).to(device)
    o_de, o_dr, attn = model.forward(x_dummy,
                                     x_lengths,
                                     y_lengths,
                                     durations,
                                     g=torch.randint(0, 10, (B,)).to(device))

    assert list(o_de.shape) == [B, 80, T_de], f"{list(o_de.shape)}"
    assert list(attn.shape) == [B, T_de, T_en]
    assert list(o_dr.shape) == [B, T_en]

    # with speaker external embedding
    model = SpeedySpeech(num_chars,
                         out_channels=80,
                         hidden_channels=128,
                         num_speakers=10,
                         external_c=True,
                         c_in_channels=256).to(device)
    o_de, o_dr, attn = model.forward(x_dummy,
                                     x_lengths,
                                     y_lengths,
                                     durations,
                                     g=torch.rand((B, 256)).to(device))

    assert list(o_de.shape) == [B, 80, T_de], f"{list(o_de.shape)}"
    assert list(attn.shape) == [B, T_de, T_en]
    assert list(o_dr.shape) == [B, T_en]
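

# Convenience entry point (not part of the original suite, which is assumed
# to run under a standard test runner such as pytest or nose): execute all
# tests directly with `python <this_file>.py`.
if __name__ == "__main__":
    test_encoder()
    test_decoder()
    test_duration_predictor()
    test_speedy_speech()
    print("all SpeedySpeech layer tests passed")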