File size: 17,389 Bytes
8520a55
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
import math
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor

from .nn_future import (FNNSwiGLU, MistralTransformer, ModelArgs,
                        RotatingBufferCache, SinePositionalEmbedding)
from .utils import construct_padding_mask, length_to_mask

# Epsilon handed to the transformer layers (`layer_norm_eps=`) and to the final
# stack-level `nn.LayerNorm(dim, eps=...)` modules built in this file.
# The per-level output heads in ResidualTransformer use nn.LayerNorm's default eps instead.
LAYERNORM_EPS = 4e-5

# ------------------------
# Code adapted from OpenAI guided diffusion repo

def timestep_embedding(timesteps, dim, max_period=10000, dtype=torch.float32):
    """
    Create sinusoidal timestep embeddings.

    The first `dim // 2` columns are cosines and the next `dim // 2` columns are
    sines of `timesteps * freq`, where the frequencies decay geometrically from 1
    towards `1 / max_period`. When `dim` is odd, one zero column is appended so the
    output always has exactly `dim` columns (this includes `dim == 1`).

    :param timesteps: a 1-D Tensor of N indices, one per batch element.
                      These may be fractional.
    :param dim: the dimension of the output.
    :param max_period: controls the minimum frequency of the embeddings.
    :param dtype: dtype the returned embedding is cast to.
    :return: an [N x dim] Tensor of positional embeddings.
    """
    half = dim // 2
    freqs = torch.exp(
        -math.log(max_period) * torch.arange(start=0, end=half) / half
    ).to(device=timesteps.device)
    args = timesteps[:, None].float() * freqs[None]
    embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1).to(dtype)
    if dim % 2:
        # Build the pad column explicitly: slicing `embedding[:, :1]` yields an empty
        # tensor when half == 0 (dim == 1), which would leave the output at [N x 0].
        pad = embedding.new_zeros(embedding.shape[0], 1)
        embedding = torch.cat([embedding, pad], dim=-1)
    return embedding


# --------------------------------
# autoregressive codec language model


class CodecLM(nn.Module):
    """ Autoregressive codec language model.

    Token ids are embedded and passed through a `MistralTransformer`. When a speaker
    reference is supplied, a small transformer encoder summarises it into a single
    vector that is prepended to the embedded sequence as an extra leading token.
    """

    def __init__(self, n_vocab, dim=1536, nhead=24, n_layers=26, n_spk_layers=2, dim_ff_scale=None, sliding_window=3000) -> None:
        """ Builds the autoregressive transformer and the speaker embedding network.

        - `n_vocab`: size of the token embedding table (also forwarded to `ModelArgs`).
        - `dim`, `nhead`, `n_layers`: width, attention heads and depth of the main transformer.
        - `n_spk_layers`: depth of the speaker reference encoder.
        - `dim_ff_scale`: feed-forward hidden size as a multiple of `dim`; `None` means 3 * dim.
        - `sliding_window`: forwarded to `ModelArgs`.
        """
        super().__init__()

        if dim_ff_scale is None: hidden_dim = int(dim*4*(3/4))
        else: hidden_dim = int(dim*dim_ff_scale)

        self.cfg = ModelArgs(n_vocab, dim=dim, n_layers=n_layers, n_heads=nhead, n_kv_heads=nhead, hidden_dim=hidden_dim, sliding_window=sliding_window)
        self.ar = MistralTransformer(self.cfg)

        self.embed = nn.Embedding(n_vocab, dim)

        # --- spk embedding network
        dim_ff = int(dim*4*(3/4))
        self.pos_embedding = SinePositionalEmbedding(dim, scale=False, alpha=True)
        self.ref_chunked_emb = ChunkedEmbedding(1024 + 1, 8, dim) # add 1 for pad idx
        # single learned vector, prepended to every reference sequence; its output slot is the speaker embedding.
        self.spk_identity_emb = nn.Embedding(1, dim)
        # define custom encoder
        encoder_layer = nn.TransformerEncoderLayer(dim, nhead, dim_ff,
                                                activation=FNNSwiGLU(dim, dim_ff), dropout=0,
                                                batch_first=True, norm_first=True, layer_norm_eps=LAYERNORM_EPS)
        # the activation module is given (dim, dim_ff) itself, so the layer's first linear is turned into a no-op.
        encoder_layer.linear1 = nn.Identity()
        self.spk_encoder = nn.TransformerEncoder(encoder_layer, n_spk_layers, norm=nn.LayerNorm(dim, eps=LAYERNORM_EPS))
        # monkeypatch for broken copy.deepcopy of nn.Modules in nn.TransformerEncoder:
        # give every cloned layer its own freshly constructed activation module.
        for l in self.spk_encoder.layers: l.activation = FNNSwiGLU(dim, dim_ff)


    def _encode_speaker(self, spk_reference: Tensor) -> Tensor:
        """ Encodes `spk_reference` codes of shape (bs, seq_len, n_codebooks) into one vector per item.

        Returns a tensor of shape (bs, 1, dim): the encoder output at the prepended identity token.
        """
        bs = spk_reference.shape[0]
        spk_seq = self.ref_chunked_emb(spk_reference) # (bs, sl, dim)
        spk_ref_emb = self.spk_identity_emb.weight[None].expand(bs, -1, -1) # (bs, 1, dim)

        spk_seq = torch.cat([spk_ref_emb, spk_seq], dim=1) # (bs, 1+sl, dim)
        # add pos encoding
        spk_seq = self.pos_embedding(spk_seq)
        # codebook goes from indices 0->1023, padding is idx 1024 (the 1025th entry)
        src_key_padding_mask = construct_padding_mask(spk_reference[:, :, 0], 1024)
        # prepend a False column since we DO want to attend to the initial (identity) position.
        keep_first = torch.zeros(src_key_padding_mask.shape[0], 1, dtype=bool, device=src_key_padding_mask.device)
        src_key_padding_mask = torch.cat((keep_first, src_key_padding_mask), dim=1)
        # pass through transformer and select first element -> (bs, 1, dim).
        return self.spk_encoder(spk_seq, is_causal=False, src_key_padding_mask=src_key_padding_mask)[:, :1]


    @torch.inference_mode()
    def get_spk_embedding(self, spk_reference, c_codes_lengths=None) -> Tensor:
        """ Gets speaker reference embeddings using `spk_reference` codes of shape (bs, seq_len, n_codebooks).

        `c_codes_lengths` is accepted but not used; padding is detected from pad index 1024 in the first codebook.
        Returns a tensor of shape (bs, dim). Raises `AssertionError` when bs != 1.
        """
        bs = spk_reference.shape[0]
        if bs != 1:
            raise AssertionError("Speaker embedding extraction only implemented using for bs=1 currently.")
        return self._encode_speaker(spk_reference).squeeze(1)


    def forward(self, x: Tensor, x_padding_mask: Optional[Tensor] = None, spk_reference: Optional[Tensor] = None,
                cache: Optional[RotatingBufferCache] = None, counter: int = 0) -> Tensor:
        """ Inputs:
            - `x`: (bs, seq_len) integer token ids; they are looked up in the embedding table.
            - `x_padding_mask`: (bs, seq_len) mask for each input, True for positions to *ignore*, False otherwise.
                Note that since this is an autoregressive model, this doesn't actually matter for inference,
                and it is not read anywhere in this method.
            - `spk_reference`: (bs, seq_len, n_codebooks) corresponding to the speaker reference to clone from.
            - `cache` and `counter`: used for kv caching, optional. With a cache, the call with `counter == 1`
                processes the full sequence; every other call only feeds the last token.

            Returns the output of the autoregressive transformer with the speaker token's slot removed
            (presumably (bs, seq_len, n_vocab) logits -- this depends on `MistralTransformer`; confirm there).
        """
        x = self.embed(x)

        # cached decoding step: only the last token is fed to the transformer.
        incremental = cache is not None and counter != 1
        has_spk = spk_reference is not None

        # --- speaker reference/embedding
        if has_spk and not (incremental and x.shape[1] > 0):
            x = torch.cat([self._encode_speaker(spk_reference), x], dim=1)
            total_len = x.shape[1]
        else:
            # Either there is no reference, or this is a cached step where the speaker token
            # would be sliced away again below. In the latter case skip the encoder entirely,
            # but still count the prefix slot so that positions are unchanged.
            total_len = x.shape[1] + int(has_spk)

        positions = torch.arange(0, total_len, device=x.device, dtype=torch.long)
        if incremental:
            # using only the last token to predict the next one
            x = x[:,-1,:].unsqueeze(1)
            positions = positions[-1:]

        x = self.ar(x, positions, cache) # (bs, seq_len, vocab)
        if has_spk and not incremental:
            x = x[:, 1:] # strip out the first output token corresponding to the speaker embedding token.

        return x


# -------------------------
# residual discrete diffusion model

class ChunkedEmbedding(nn.Module):
    """ One embedding table per quantizer level, each producing an equal slice of the output vector. """

    def __init__(self, codebook_size: int, n_quantizer: int, dim: int) -> None:
        """ Creates `n_quantizer` tables of `codebook_size` rows, each `dim // n_quantizer` wide. """
        super().__init__()
        assert dim % n_quantizer == 0, f"ChunkedEmbedding output dim ({dim}) must be divisible by n_quant {n_quantizer}"
        chunk_dim = dim // n_quantizer
        tables = []
        for _ in range(n_quantizer):
            tables.append(nn.Embedding(codebook_size, chunk_dim))
        self.embs = nn.ModuleList(tables)

    def forward(self, x: Tensor) -> Tensor:
        """ Looks up every codebook index of `x` (bs, seq_len, n_quantizer) in the table of its level.

        The per-level vectors are joined along the last axis, giving (bs, seq_len, dim).
        """
        per_level = []
        for level, codes in enumerate(x.unbind(dim=-1)):
            per_level.append(self.embs[level](codes))
        return torch.cat(per_level, dim=-1)



class ResidualTransformer(nn.Module):
    """ Encoder-decoder transformer that predicts logits for `pred_quant_levels` codebook levels.

    The encoder consumes [speaker embedding, text embeddings]; the decoder consumes the embedded
    residual codes. Both sides receive an additive timestep embedding. The speaker embedding is
    produced from reference codes by a small transformer encoder.
    """

    def __init__(self, n_text_vocab, n_quant=1024, dim=1024, nhead=16, 
                 enc_layers=8, dec_layers=16, n_spk_layers=3,
                 c_quant_levels=8, pred_quant_levels=8, 
                 t_emb_dim=1024, norm_first=True, p_cond_drop=0.1, dropout=0) -> None:
        """ Builds all sub-networks.

        - `n_text_vocab`: size of the text embedding table.
        - `n_quant`: number of rows in each code embedding table and number of logits per output head.
        - `dim`, `nhead`: model width and attention heads.
        - `enc_layers`, `dec_layers`, `n_spk_layers`: depths of the encoder, decoder and speaker encoder.
        - `c_quant_levels`: number of codebook levels embedded by the reference and residual embedders.
        - `pred_quant_levels`: number of output heads (one per predicted codebook level).
        - `t_emb_dim`: size of the sinusoidal timestep embedding fed to the timestep MLPs.
        - `norm_first`: pre-norm vs post-norm for the main encoder/decoder (the speaker encoder is always pre-norm).
        - `p_cond_drop`: probability of dropping the speaker conditioning per item while training.
        - `dropout`: dropout rate of the transformer layers.
        """
        super().__init__()

        self.cond_pos_embedding = SinePositionalEmbedding(dim, scale=False, alpha=True)
        self.pos_embedding = SinePositionalEmbedding(dim, scale=False, alpha=True)

        # *4 from heuristic, then rescaled for swiglu since there are 3 linear matrices not 2.
        # NOTE: an earlier comment mentioned a 2/3 factor, but the factor actually applied
        # here is 3/4, i.e. dim_ff == 3 * dim.
        dim_ff = int(dim*4*(3/4))

        # define custom encoder
        encoder_layer = nn.TransformerEncoderLayer(dim, nhead, dim_ff,
                            activation=FNNSwiGLU(dim, dim_ff), dropout=dropout,
                            batch_first=True, norm_first=norm_first, layer_norm_eps=LAYERNORM_EPS)
        # the activation module is given (dim, dim_ff) itself, so the layer's first linear is turned into a no-op.
        encoder_layer.linear1 = nn.Identity()
        encoder = nn.TransformerEncoder(encoder_layer, enc_layers, norm=nn.LayerNorm(dim, eps=LAYERNORM_EPS) if norm_first else None)

        # define custom decoder
        decoder_layer = nn.TransformerDecoderLayer(dim, nhead, dim_ff,
                                                activation=FNNSwiGLU(dim, dim_ff), dropout=dropout,
                                                batch_first=True, norm_first=norm_first, layer_norm_eps=LAYERNORM_EPS)
        decoder_layer.linear1 = nn.Identity()
        decoder = nn.TransformerDecoder(decoder_layer, dec_layers, norm=nn.LayerNorm(dim, eps=LAYERNORM_EPS) if norm_first else None)

        # monkeypatch for broken copy.deepcopy of nn.Modules in nn.TransformerDecoder
        for l in decoder.layers: l.activation = FNNSwiGLU(dim, dim_ff)

        self.tfm = nn.Transformer(dim, nhead, dim_feedforward=dim_ff, batch_first=True, 
            norm_first=norm_first,
            num_encoder_layers=enc_layers,
            num_decoder_layers=dec_layers,
            custom_encoder=encoder,
            custom_decoder=decoder,
            layer_norm_eps=LAYERNORM_EPS,
            dropout=dropout
        )
        # Timestep embedding network (separate MLPs for the encoder and decoder inputs)
        self.t_emb_dim = t_emb_dim
        self.timestep_encoder_emb = nn.Sequential(
            nn.Linear(t_emb_dim, dim),
            nn.SiLU(),
            nn.Linear(dim, dim)
        )
        self.timestep_decoder_emb = nn.Sequential(
            nn.Linear(t_emb_dim, dim),
            nn.SiLU(),
            nn.Linear(dim, dim)
        )

        self.text_embed = nn.Embedding(n_text_vocab, dim)

        ## ----> reference / conditioning encoder:
        self.ref_embedder = ChunkedEmbedding(n_quant, c_quant_levels, dim)
        self.ref_pos_embedding = SinePositionalEmbedding(dim, scale=False, alpha=True)
        # single learned vector, prepended to every reference sequence; its output slot is the speaker embedding.
        self.spk_identity_emb = nn.Embedding(1, dim)
        spk_encoder_layer = nn.TransformerEncoderLayer(dim, nhead, dim_ff,
                                                activation=FNNSwiGLU(dim, dim_ff), dropout=dropout,
                                                batch_first=True, norm_first=True, layer_norm_eps=LAYERNORM_EPS)
        spk_encoder_layer.linear1 = nn.Identity()
        self.spk_encoder = nn.TransformerEncoder(spk_encoder_layer, n_spk_layers, norm=nn.LayerNorm(dim, eps=LAYERNORM_EPS))
        # monkeypatch for broken copy.deepcopy of nn.Modules in nn.TransformerEncoder
        for l in self.spk_encoder.layers: l.activation = FNNSwiGLU(dim, dim_ff)
        # ----> end speaker encoder network

        # embeds all `c_quant_levels` levels of the decoder input codes.
        self.residual_encoder = ChunkedEmbedding(n_quant, c_quant_levels, dim)

        # one (LayerNorm -> Linear) head per predicted quantization level.
        self.residual_decoder = nn.ModuleList([
            nn.Sequential(
                nn.LayerNorm(dim),
                nn.Linear(dim, n_quant)
            ) for _ in range(pred_quant_levels)
        ])
        self.n_quantizer = pred_quant_levels
        self.p_cond_drop = p_cond_drop


    def _encode_reference(self, c_codes: Tensor, c_codes_length: Tensor) -> Tensor:
        """ Runs the speaker encoder over reference codes `c_codes` (bs, sl, n_levels).

        `c_codes_length` (bs,) gives the valid length of each reference sequence.
        Returns the encoder output at the prepended identity token, shape (bs, 1, dim).
        """
        bs = c_codes.shape[0]
        spk_seq = self.ref_embedder(c_codes) # (bs, sl, dim)
        spk_ref_emb = self.spk_identity_emb.weight[None].expand(bs, -1, -1) # (bs, 1, dim)
        spk_seq = torch.cat([spk_ref_emb, spk_seq], dim=1) # (bs, 1+sl, dim)
        # add pos encoding
        spk_seq = self.ref_pos_embedding(spk_seq)

        # add 1 to c_codes_length to account for the fact that we concatenate the spk_ref_emb to it. 
        src_key_padding_mask = length_to_mask(c_codes_length+1, torch.zeros_like(c_codes_length), max_len=spk_seq.shape[1])
        src_key_padding_mask = src_key_padding_mask.to(dtype=torch.bool, device=spk_seq.device)

        # pass through transformer and select first element -> (bs, 1, dim).
        return self.spk_encoder(spk_seq, is_causal=False, src_key_padding_mask=src_key_padding_mask)[:, :1]


    @torch.inference_mode()
    def get_spk_embedding(self, c_codes, c_codes_length) -> Tensor:
        """ Obtain speaker embedding vectors using `c_codes` from reference encodec sequences, and `c_codes_length` of lengths for each sequence.

        Returns a tensor of shape (bs, dim).
        """
        return self._encode_reference(c_codes, c_codes_length).squeeze(1)


    def forward(self, c_text: Tensor, c_codes: Tensor, c_texts_length: Tensor, c_codes_length: Tensor, 
                x: Tensor, x_padding_mask: Tensor, t: Tensor, drop_cond=False):
        """ Input:
            - `c_text`: (bs, seq_len1) the prompt text (BPE encoded)
            - `c_codes`: (bs, seq_len2, n_quant) the full tokenized codes of the reference speech
            - `c_texts_length`: (bs, ) the length of the codes in the text prompt
            - `c_codes_length`: (bs, ) the length of the prompt acoustic token codes in `c_codes`.
            - `x`: (bs, seq_len3, n_levels) residual codes, one index per codebook level
            - `x_padding_mask`: (bs, seq_len3) masking for residual codes
            - `t`: (bs) timestep
            - `drop_cond`: bool, whether or not to forcibly drop the conditioning information
                (only consulted in eval mode; in training mode items are dropped at random with
                probability `p_cond_drop`).
        Returns:
            - outs: (bs, seq_len3, codebook_size, n_quantizer) -- per-level logits stacked along the last axis.

        The caller's `c_codes` and `c_codes_length` tensors are left unmodified.
        """
        
        c_text = self.text_embed(c_text) # (bs, seq_len1, dim)

        ## ----> reference / conditioning encoder:
        if self.training:
            zero_cond_inds = torch.rand_like(t, dtype=c_text.dtype) < self.p_cond_drop
        else:
            # never randomly zero when in eval mode
            zero_cond_inds = torch.zeros_like(t, dtype=torch.bool)
            if drop_cond:
                # force drop conditioning
                zero_cond_inds = torch.ones_like(t, dtype=torch.bool)
        
        # Work on private copies: writing into the arguments would silently destroy the caller's
        # reference codes/lengths (e.g. when the same tensors are reused for a later conditional call).
        c_codes_length = c_codes_length.clone()
        c_codes = c_codes.clone()
        c_codes_length[zero_cond_inds] = 0
        # NOTE(review): 1024 looks like the pad index; it must be a valid row of `ref_embedder`
        # (i.e. n_quant > 1024) -- confirm against how the model is constructed.
        c_codes[zero_cond_inds] = 1024

        c_codes = self._encode_reference(c_codes, c_codes_length) # (bs, 1, dim)
        c_codes_lengths_extract = torch.ones_like(c_codes_length) # manually override all the code lengths to equal 1, since we only have 1 spk embedding. 
        ## ----> end reference / conditioning encoder:

        ## ----> timestep embeddings and parsing
        t_emb = timestep_embedding(t, self.t_emb_dim, dtype=c_text.dtype)
        t_emb_encoder = self.timestep_encoder_emb(t_emb) # (bs, dim)
        t_emb_decoder = self.timestep_decoder_emb(t_emb)
        
        ## ----> concatenating text/phone inputs and implicit speaker embedding. 
        c_phones_unpacked = nn.utils.rnn.unpad_sequence(c_text, c_texts_length.cpu(), batch_first=True)
        c_codes_unpacked = nn.utils.rnn.unpad_sequence(c_codes, c_codes_lengths_extract.cpu(), batch_first=True)
        # >>> Concat [speaker codes, text codes]
        assert all(b.shape[0] == 1 for b in c_codes_unpacked)
        c_joined = [torch.cat((b, a), dim=0) for a, b in zip(c_phones_unpacked, c_codes_unpacked)]

        c = nn.utils.rnn.pad_sequence(c_joined, batch_first=True)
        c_joined_lengths = torch.tensor([p.shape[0] for p in c_joined], device=c.device, dtype=torch.long)
        c_padding_mask = length_to_mask(c_joined_lengths, torch.zeros_like(c_joined_lengths))
        c = self.cond_pos_embedding(c)

        ## Format input:
        x = self.residual_encoder(x) # (bs, seq_len3, dim)

        x = self.pos_embedding(x)

        # the timestep embedding is broadcast over the sequence axis on both sides.
        x = x + t_emb_decoder[:, None]
        c = c + t_emb_encoder[:, None]
        ## Perform prediction:
        output = self.tfm(c, x, src_key_padding_mask=c_padding_mask, 
                          tgt_key_padding_mask=x_padding_mask,
                          memory_key_padding_mask=c_padding_mask) # (bs, seq_len, dim)
        outs = torch.stack([self.residual_decoder[i](output) for i in range(self.n_quantizer)], dim=-1) # (bs, seq_len, logit_dim, n_quant)
        return outs