File size: 24,605 Bytes
565faca
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
import inspect
import math
from dataclasses import dataclass, field
from typing import Literal, Optional, Union

import torch
import torch.nn as nn
import tqdm
from einops import rearrange
from torch.nn import functional as F

from fam.llm.layers import Block, LayerNorm, RMSNorm
from fam.llm.mixins import CausalInferenceMixin, NonCausalInferenceMixin

from IPython import embed
# Token id compared against hierarchy 0 of the input tokens by GPT._mask_spk_emb_on_text:
# when GPTConfig.spk_emb_on_text is False, the speaker embedding is zeroed on every position
# before the first occurrence of this token. GPT.__init__ prints a warning that this default
# likely needs to be changed for a given tokenizer.
END_OF_TEXT_TOKEN = 1537


def _select_spkemb(spkemb, mask):
    """
    Gather one speaker embedding per time step.

    Args:
        spkemb: tensor of shape (batch, examples, channels) holding the candidate embeddings.
        mask: integer tensor of shape (batch, time); each value indexes the ``examples`` axis.

    Returns:
        Tensor of shape (batch, time, channels) with ``out[b, t] == spkemb[b, mask[b, t]]``.
    """
    _, num_examples, _ = spkemb.shape
    # One-hot selector of shape (batch, time, examples), cast to spkemb's dtype/device.
    selector = torch.nn.functional.one_hot(mask.long(), num_classes=num_examples).to(spkemb)
    # (b, t, ex) @ (b, ex, c) -> (b, t, c): each row picks exactly one example embedding.
    return torch.bmm(selector, spkemb)


@dataclass
class GPTConfig:
    """
    Hyper-parameters for :class:`GPT`.

    Fields can be passed positionally, so their order is part of the interface:
    only append new fields at the end.
    """

    # Maximum sequence length; also the number of rows in the positional-embedding table.
    block_size: int = 1024
    # Input vocabulary size of each token hierarchy (one embedding table per entry).
    vocab_sizes: list = field(default_factory=list)
    # Output vocabulary sizes (one head each) for non-causal models; None => heads mirror vocab_sizes.
    target_vocab_sizes: Optional[list] = None
    # Depth, attention heads per layer, and hidden width of the transformer.
    n_layer: int = 12
    n_head: int = 12
    n_embd: int = 768
    # Dropout probability (used for the embedding dropout in GPT; the blocks also receive this config).
    dropout: float = 0.0
    # Training-time probability of removing a sequence's speaker conditioning entirely.
    spkemb_dropout: float = 0.0
    # True: bias in Linears and LayerNorms, like GPT-2. False: a bit better and faster.
    bias: bool = True
    # Auto-regressive or not, i.e. whether to have an attention mask that prevents attending to future tokens.
    causal: bool = True
    # Whether to add the speaker-embedding conditioning to text tokens as well.
    spk_emb_on_text: bool = True
    # "rmsnorm" or "layernorm".
    norm_type: str = "layernorm"
    # Epsilon for RMSNorm; required when norm_type == "rmsnorm", ignored otherwise.
    rmsnorm_eps: Optional[float] = None
    # "gelu" or "swiglu".
    nonlinearity_type: str = "gelu"
    # With SwiGLU, the MLP hidden size is made a multiple of this value.
    swiglu_multiple_of: Optional[int] = None
    # Attention implementation; only "torch_attn" is currently accepted by the type
    # (earlier revisions also listed "fa2" and "hand").
    attn_kernel_type: Literal["torch_attn"] = "torch_attn"
    # Whether to use a key-value cache for attention. Note that GPT.__init__ starts with
    # its own kv_cache_enabled flag set to False regardless of this value.
    kv_cache_enabled: bool = False


def _check_speaker_emb_dims(
    speaker_embs: Union[list, torch.Tensor], expected_speaker_emb_dim: int, expected_batch_size: int
) -> Union[torch.Tensor, list]:
    """
    Check that the speaker embeddings have the expected batch size and embedding dimension,
    reshaping them where necessary.

    Two input forms are accepted:

    * a list with one entry per batch row. Each entry is either ``None`` (no speaker
      conditioning for that row) or a tensor whose last dimension is the embedding
      dimension; 1-D entries get a leading dimension of size 1.
    * a tensor of shape (batch, emb_dim) or (batch, num_examples, emb_dim), where
      ``num_examples`` is the number of utterances packed into the sequence. The 2-D
      form gets a singleton ``num_examples`` dimension for backwards compatibility.

    Args:
        speaker_embs: the speaker embeddings, in one of the forms above.
        expected_speaker_emb_dim: required size of the embedding (last) dimension.
        expected_batch_size: required number of batch rows.

    Returns:
        For list input, a new list (the caller's list is not modified) with 1-D entries
        unsqueezed; for tensor input, a 3-D tensor (batch, num_examples, emb_dim).

    Raises:
        AssertionError: if the batch size, or the embedding dimension of any
            non-None entry, does not match.
    """
    if isinstance(speaker_embs, list):
        batch_size = len(speaker_embs)
        assert batch_size == expected_batch_size, f"Batch size mismatch: {batch_size} != {expected_batch_size}"

        checked = []
        for emb in speaker_embs:
            if emb is not None:
                # Validate every entry, not just the last one; an all-None list has nothing to check.
                emb_dim = emb.shape[-1]
                assert (
                    emb_dim == expected_speaker_emb_dim
                ), f"Speaker embedding dimension mismatch: {emb_dim} != {expected_speaker_emb_dim}"
                if emb.ndim == 1:
                    emb = emb.unsqueeze(0)
            checked.append(emb)
        return checked

    if speaker_embs.ndim == 2:
        # A single speaker embedding for the whole sequence:
        # add a dummy num_examples dimension for backwards compatibility.
        speaker_embs = speaker_embs[:, None, :]

    batch_size, _num_examples, emb_dim = speaker_embs.size()
    assert batch_size == expected_batch_size, f"Batch size mismatch: {batch_size} != {expected_batch_size}"
    assert (
        emb_dim == expected_speaker_emb_dim
    ), f"Speaker embedding dimension mismatch: {emb_dim} != {expected_speaker_emb_dim}"

    return speaker_embs


class GPT(nn.Module, NonCausalInferenceMixin, CausalInferenceMixin):
    """
    GPT-style transformer over one or more parallel token hierarchies, with optional speaker conditioning.

    Inputs are token ids of shape (b, num_hierarchies, t). Every hierarchy has its own
    token-embedding table (``transformer.wtes``). The per-hierarchy embeddings are summed, and
    then a learned positional embedding and an optional projected speaker embedding are added
    before the transformer blocks. One linear head per output hierarchy produces logits.

    * Causal models (``config.target_vocab_sizes is None``) get one head per input vocabulary,
      weight-tied to the matching input embedding.
    * Non-causal models (``config.target_vocab_sizes`` set) get one untied head per target vocabulary.

    Reference configurations recorded by the original author:
      * causal: vocab_sizes=[2562], block_size=2048, n_layer=24, n_head=16, n_embd=2048,
        RMSNorm (eps 1e-5), SwiGLU (multiple_of 256), spkemb_dropout=0.1, speaker_emb_dim=256.
      * non-causal: vocab_sizes=[1538, 1025], target_vocab_sizes=[1025] * 6, block_size=1024,
        n_layer=6, n_head=6, n_embd=384, LayerNorm, GELU, speaker_emb_dim=256.

    Sampling is delegated to ``_causal_sample`` / ``_non_causal_sample``, which are not defined
    here (presumably provided by the inference mixins).
    """

    def __init__(self, config: GPTConfig, speaker_emb_dim: Optional[int] = None):
        """
        Initialize the GPT model.

        Args:
            config (GPTConfig): Configuration object for the model.
            speaker_emb_dim (Optional[int]): Dimension of the speaker embedding. Default is None,
                in which case no speaker projection layer (``speaker_cond_pos``) is created.

        Raises:
            Exception: for a non-causal model with ``spk_emb_on_text=False``, an unknown
                ``norm_type``, or RMSNorm without ``rmsnorm_eps``.
        """
        super().__init__()
        assert config.vocab_sizes is not None
        assert config.block_size is not None
        self.config = config

        # The KV cache always starts disabled (config.kv_cache_enabled is not consulted here);
        # kv_pos counts positions already consumed while the cache is enabled (see forward).
        self.kv_cache_enabled = False  # disabled by default
        self.kv_pos = 0

        self.speaker_emb_dim = speaker_emb_dim
        self.spk_emb_on_text = config.spk_emb_on_text
        if self.config.causal is True and self.spk_emb_on_text is False:
            print("!!!!!!!!!!!!!!!!!!")
            print(
                f"!!!!!!!! Using DEFAULT of {END_OF_TEXT_TOKEN} as end of text token to find speaker cond masking!! You likely need to change this."
            )
            print("!!!!!!!!!!!!!!!!!!")
        if self.config.causal is False and self.spk_emb_on_text is False:
            raise Exception(
                "Cannot use speaker embedding masking with non-causal model. This is unexpected. Check for relevant changes required in code before proceeding."
            )

        if config.norm_type == "rmsnorm":
            if config.rmsnorm_eps is None:
                raise Exception("RMSNorm requires rmsnorm_eps to be set")
            ln_f = RMSNorm(config.n_embd, eps=config.rmsnorm_eps)
        elif config.norm_type == "layernorm":
            ln_f = LayerNorm(config.n_embd, bias=config.bias)
        else:
            raise Exception(f"Unknown norm type: {config.norm_type}")

        self.transformer = nn.ModuleDict(
            dict(
                wtes=nn.ModuleList([nn.Embedding(vsize, config.n_embd) for vsize in config.vocab_sizes]),
                wpe=nn.Embedding(config.block_size, config.n_embd),
                drop=nn.Dropout(config.dropout),
                h=nn.ModuleList([Block(config) for _ in range(config.n_layer)]),
                ln_f=ln_f,
            )
        )
        if speaker_emb_dim is not None:
            # Projects speaker embeddings from speaker_emb_dim to the model width n_embd (e.g. 256 -> 2048).
            self.speaker_cond_pos = nn.Linear(speaker_emb_dim, config.n_embd, bias=False)

        self.lm_heads = nn.ModuleList()
        if config.target_vocab_sizes is not None:
            assert config.causal is False
        else:
            assert config.causal is True

        head_vocab_sizes = config.vocab_sizes if config.target_vocab_sizes is None else config.target_vocab_sizes
        for vsize in head_vocab_sizes:
            self.lm_heads.append(nn.Linear(config.n_embd, vsize, bias=False))

        if config.target_vocab_sizes is None:
            # Weight tying (https://paperswithcode.com/method/weight-tying). No transpose is needed:
            # nn.Linear stores its weight as (out_features, in_features) = (vocab, n_embd), the same
            # layout as nn.Embedding's (num_embeddings, embedding_dim) = (vocab, n_embd).
            for head, wte in zip(self.lm_heads, self.transformer.wtes):  # type: ignore
                head.weight = wte.weight  # type: ignore
            assert len(self.lm_heads) == len(
                self.transformer.wtes  # type: ignore
            ), f"Number of heads ({len(self.lm_heads)}) must match number of one-hot embedding matrics ({len(self.transformer.wtes)})."  # type: ignore

        # init all weights
        self.apply(self._init_weights)
        # apply special scaled init to the residual projections, per GPT-2 paper
        for pn, p in self.named_parameters():
            if pn.endswith("c_proj.weight"):
                torch.nn.init.normal_(p, mean=0.0, std=0.02 / math.sqrt(2 * config.n_layer))

        # report number of parameters
        print("number of parameters: %.2fM" % (self.get_num_params() / 1e6,))

    def get_num_params(self, non_embedding=True):
        """
        Return the number of parameters in the model.
        For non-embedding count (default), the position embeddings get subtracted.
        The token embeddings would too, except due to the parameter sharing these
        params are actually used as weights in the final layer, so we include them.
        """
        n_params = sum(p.numel() for p in self.parameters())
        if non_embedding:
            n_params -= self.transformer.wpe.weight.numel()
        return n_params

    def _init_weights(self, module):
        """Initialise Linear/Embedding weights from N(0, 0.02) and Linear biases to zero."""
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def _mask_spk_emb_on_text(self, idx: torch.Tensor, spk_emb: torch.Tensor) -> torch.Tensor:
        """
        Zero the speaker embedding on every position before the first END_OF_TEXT_TOKEN.

        This is in a separate function so we can test it easily.

        Args:
            idx: token ids (b, num_hierarchies, t); only hierarchy 0 is inspected.
            spk_emb: speaker embedding broadcastable to (b, t, c).

        Returns:
            ``spk_emb * mask`` where mask is 0 before the first end-of-text token and 1 from it onwards.
        """
        # Note: this does NOT mask the <end_of_text_token> token itself. This is important so that
        # the first audio token predicted has speaker information to use.

        # Check in channel dimension 0 as this is usually the first hierarchy where we put the text tokens.
        is_end_of_text = idx[:, 0, :] == END_OF_TEXT_TOKEN
        # use > 0, in case end_of_text_token is repeated for any reason.
        mask = (torch.cumsum(is_end_of_text, dim=-1) > 0).float()
        spk_emb = spk_emb * mask[:, :, None]

        return spk_emb

    def _embed_tokens(self, idx: torch.Tensor) -> torch.Tensor:
        """
        Sum the per-hierarchy token embeddings of ``idx`` (b, num_hierarchies, t) into (b, t, n_embd).

        Token id -1 marks padding: it is looked up as id 0 (keeping the lookup in range) and its
        embedding is then zeroed so that padding contributes nothing. One hierarchy is embedded per
        table in ``transformer.wtes``.
        """
        b, _, t = idx.size()
        tok_emb = torch.zeros((b, t, self.config.n_embd), device=idx.device)
        for i, wte in enumerate(self.transformer.wtes):
            tokens = idx[:, i, :]
            is_pad = tokens == -1  # positions whose token id is -1
            embedded = wte(tokens.masked_fill(is_pad, 0))
            tok_emb += embedded.masked_fill(is_pad.unsqueeze(-1), 0.0)
        return tok_emb

    def _speaker_conditioning(self, idx, speaker_embs, embedding, speaker_emb_mask, dtype):
        """
        Build the speaker-conditioning term that is added to the token and position embeddings.

        Args:
            idx: token ids (b, num_hierarchies, t).
            speaker_embs: speaker embeddings already validated by ``_check_speaker_emb_dims``, or None.
            embedding: fallback raw embedding projected through ``speaker_cond_pos`` when
                ``speaker_embs`` is None.
            speaker_emb_mask: optional (b, t) indices selecting one speaker embedding per token.
            dtype: dtype used for the zero rows of list-form ``speaker_embs``.

        Returns:
            ``0.0`` when neither ``speaker_embs`` nor ``embedding`` is given; otherwise a tensor
            that is added to the (b, t, n_embd) token embeddings.
        """
        b, _, t = idx.size()
        device = idx.device

        if speaker_embs is None:
            if embedding is not None:
                return self.speaker_cond_pos(embedding)
            return 0.0

        if isinstance(speaker_embs, list):
            # Inference-only path: one optional embedding per batch row.
            assert speaker_emb_mask is None
            assert self.training is False
            assert self.spk_emb_on_text is True
            rows = []
            for speaker_emb_row in speaker_embs:
                if speaker_emb_row is not None:
                    projected = self.speaker_cond_pos(speaker_emb_row.unsqueeze(0))
                    assert projected.shape == (1, 1, self.config.n_embd), f"spk_emb[-1].shape={projected.shape}"
                    rows.append(projected)
                else:
                    rows.append(torch.zeros((1, 1, self.config.n_embd), device=device, dtype=dtype))
            spk_emb = torch.cat(rows, dim=0)

            assert (
                spk_emb.ndim == 3 and spk_emb.shape[1] == 1 and spk_emb.shape[0] == b
            ), f"spk_emb.ndim={spk_emb.ndim}, spk_emb.shape={spk_emb.shape}, len(speaker_embs)={len(speaker_embs)}"
        else:
            speakers_embedded = self.speaker_cond_pos(speaker_embs)  # shape (b, num_examples, c)

            if speaker_emb_mask is not None:
                spk_emb = _select_spkemb(speakers_embedded, speaker_emb_mask)
                assert spk_emb.shape == (b, t, self.config.n_embd)
            else:
                # Without a mask the speaker embedding is shared by all tokens, so the
                # num_examples dimension (size 1) broadcasts over time.
                spk_emb = speakers_embedded
                assert spk_emb.ndim == 3 and spk_emb.shape[1] == 1

            if self.training and self.config.spkemb_dropout > 0.0:
                # Remove speaker conditioning for whole sequences at random. Applied to spk_emb
                # (not the raw projection) so a per-token selection via speaker_emb_mask is kept.
                keep = torch.rand(spk_emb.shape[0], 1, 1, device=device) >= self.config.spkemb_dropout
                spk_emb = torch.where(keep, spk_emb, torch.zeros_like(spk_emb))

        if self.spk_emb_on_text is False:
            assert speaker_emb_mask is None, "Not implemented for spk_emb_on_text=False"
            spk_emb = self._mask_spk_emb_on_text(idx, spk_emb)
        return spk_emb

    def forward(
        self,
        idx,
        targets=None,
        speaker_embs=None,
        embedding=None,
        speaker_emb_mask=None,
        loss_reduce: Literal["mean", "none"] = "mean",
    ):
        """
        Run the model on a batch of hierarchical token ids.

        Args:
            idx: LongTensor (b, num_hierarchies, t); -1 marks padding (zero token embedding).
            targets: optional LongTensor read as ``targets[:, i, :]`` for head i; -1 entries are
                ignored by the loss. When given, logits are computed for every position.
            speaker_embs: optional speaker embeddings: a tensor (b, emb_dim) / (b, num_examples, emb_dim),
                or (inference only) a list with one tensor-or-None per batch row.
            embedding: optional raw embedding, projected and used as the speaker term when
                ``speaker_embs`` is None.
            speaker_emb_mask: optional (b, t) indices choosing which of the num_examples speaker
                embeddings each token uses.
            loss_reduce: "mean" for a scalar loss (per-head token means, averaged over heads),
                or "none" for per-token losses of shape (b, num_heads, t).

        Returns:
            ``(list_logits, losses)``: one logits tensor per head, and the loss (None when
            ``targets`` is None). Without targets, a causal model only computes logits for the
            last position (time dimension of size 1).

        Side effects:
            When ``self.kv_cache_enabled`` is True, advances ``self.kv_pos``.
        """
        device = idx.device
        b, num_hierarchies, t = idx.size()

        if speaker_embs is not None:
            speaker_embs = _check_speaker_emb_dims(
                speaker_embs=speaker_embs, expected_speaker_emb_dim=self.speaker_emb_dim, expected_batch_size=b
            )

        assert (
            t <= self.config.block_size
        ), f"Cannot forward sequence of length {t}, block size is only {self.config.block_size}"

        if self.kv_cache_enabled and self.kv_pos != 0:
            # Incremental decoding: a single new token at the next cached position.
            assert t == 1, "KV cache is only supported for single token inputs"
            pos = torch.tensor([self.kv_pos], dtype=torch.long, device=device)  # shape (1)
            self.kv_pos += 1
        else:
            pos = torch.arange(0, t, dtype=torch.long, device=device)  # shape (t)
            if self.kv_cache_enabled:
                # First call with the cache enabled: the whole prompt occupies positions [0, t).
                self.kv_pos += t

        # embed the tokens, positional encoding, and speaker embedding
        tok_emb = self._embed_tokens(idx)  # (b, t, n_embd)
        pos_emb = self.transformer.wpe(pos)  # position embeddings of shape (t, n_embd)
        spk_emb = self._speaker_conditioning(idx, speaker_embs, embedding, speaker_emb_mask, dtype=pos_emb.dtype)

        # Shapes observed by the original author (d = n_embd): tok_emb (b, t, d), pos_emb (t, d),
        # spk_emb (b, 1, d) at sampling time.
        # NOTE(review): a training-time debug note recorded spk_emb of shape (128, 1, 1, 187), where
        # (b, 1, d) was expected; verify the shape of `embedding` passed by the training code.
        # NOTE(review): an older TODO asked for a causal attention mask here; causal masking is
        # presumably handled inside Block via config.causal — confirm.
        x = self.transformer.drop(tok_emb + pos_emb + spk_emb)
        for block in self.transformer.h:
            x = block(x)
        x = self.transformer.ln_f(x)

        if targets is not None:
            # if we are given some desired targets also calculate the loss
            list_logits = [lm_head(x) for lm_head in self.lm_heads]

            losses = torch.stack(
                [
                    F.cross_entropy(
                        logits.view(-1, logits.size(-1)),
                        targets[:, i, :].contiguous().view(-1),
                        ignore_index=-1,
                        reduction=loss_reduce,
                    )
                    for i, logits in enumerate(list_logits)
                ]
            )
            if loss_reduce == "mean":
                losses = losses.mean()
            else:
                losses = rearrange(losses, "h (b t) -> b h t", h=len(self.lm_heads), b=b, t=t)
        else:
            # inference-time mini-optimization: only forward the lm_head on the very last position
            if self.config.causal:
                # note: using list [-1] to preserve the time dim
                list_logits = [lm_head(x[:, [-1], :]) for lm_head in self.lm_heads]
            else:
                list_logits = [lm_head(x) for lm_head in self.lm_heads]
            losses = None

        return list_logits, losses

    def configure_optimizers(self, weight_decay, learning_rate, betas, device_type):
        """
        Build an AdamW optimizer with weight decay on >=2-D parameters only.

        Args:
            weight_decay: decay applied to matmul/embedding weights.
            learning_rate: AdamW learning rate.
            betas: AdamW betas.
            device_type: "cuda" enables the fused AdamW kernel when the installed torch supports it.

        Returns:
            The configured ``torch.optim.AdamW`` instance.
        """
        # start with all of the candidate parameters
        param_dict = {pn: p for pn, p in self.named_parameters()}
        # filter out those that do not require grad
        param_dict = {pn: p for pn, p in param_dict.items() if p.requires_grad}
        # create optim groups. Any parameters that is 2D will be weight decayed, otherwise no.
        # i.e. all weight tensors in matmuls + embeddings decay, all biases and layernorms don't.
        decay_params = [p for n, p in param_dict.items() if p.dim() >= 2]
        nodecay_params = [p for n, p in param_dict.items() if p.dim() < 2]
        optim_groups = [
            {"params": decay_params, "weight_decay": weight_decay},
            {"params": nodecay_params, "weight_decay": 0.0},
        ]
        num_decay_params = sum(p.numel() for p in decay_params)
        num_nodecay_params = sum(p.numel() for p in nodecay_params)
        print(f"num decayed parameter tensors: {len(decay_params)}, with {num_decay_params:,} parameters")
        print(f"num non-decayed parameter tensors: {len(nodecay_params)}, with {num_nodecay_params:,} parameters")
        # Create AdamW optimizer and use the fused version if it is available
        fused_available = "fused" in inspect.signature(torch.optim.AdamW).parameters
        use_fused = fused_available and device_type == "cuda"
        extra_args = dict(fused=True) if use_fused else dict()
        optimizer = torch.optim.AdamW(optim_groups, lr=learning_rate, betas=betas, **extra_args)
        print(f"using fused AdamW: {use_fused}")

        return optimizer

    @torch.no_grad()
    def generate(
        self,
        idx: torch.Tensor,
        max_new_tokens: int,
        seq_lens: Optional[list] = None,
        temperature: float = 1.0,
        top_k: Optional[int] = None,
        top_p: Optional[float] = None,
        speaker_embs: Optional[torch.Tensor] = None,
        batch_size: Optional[int] = None,
        guidance_scale: Optional[float] = None,
    ):
        """
        Take a conditioning sequence of indices idx (LongTensor of shape (b,num_hierarchies,t)) and complete
        the sequence max_new_tokens times, feeding the predictions back into the model each time.
        Most likely you'll want to make sure to be in model.eval() mode of operation for this.

        Causal models require ``seq_lens`` and ``batch_size`` and forward everything to
        ``_causal_sample``. Non-causal models require ``batch_size``, reject ``seq_lens``,
        ``guidance_scale`` and ``top_p``, and run ``_non_causal_sample`` over consecutive
        chunks of at most ``batch_size`` rows, concatenating the results along dim 0.

        Raises:
            Exception: when a required argument is missing or an unsupported one is given.
        """
        assert idx.dim() == 3, "idx must be a batch of sequences of hierarchical tokens"

        if self.config.causal:
            if seq_lens is None or batch_size is None:
                raise Exception("seq_lens and batch_size must be provided for causal sampling")

            return self._causal_sample(
                idx=idx,
                max_new_tokens=max_new_tokens,
                seq_lens=seq_lens,
                temperature=temperature,
                top_k=top_k,
                top_p=top_p,
                speaker_embs=speaker_embs,
                batch_size=batch_size,
                guidance_scale=guidance_scale,
            )

        if seq_lens is not None:
            raise Exception("seq_lens is not supported yet for non-causal sampling")

        if batch_size is None:
            raise Exception("batch_size must be provided for non-causal sampling")

        if guidance_scale is not None:
            raise Exception("guidance_scale is not supported for non-causal sampling")

        if top_p is not None:
            raise Exception("top_p is not supported for non-causal sampling")

        out = []
        for start_index in tqdm.tqdm(range(0, idx.shape[0], batch_size), desc="non-causal batching"):
            end_index = min(start_index + batch_size, idx.shape[0])
            out.append(
                self._non_causal_sample(
                    idx=idx[start_index:end_index],
                    speaker_embs=speaker_embs[start_index:end_index] if speaker_embs is not None else None,
                    temperature=temperature,
                    top_k=top_k,
                )
            )
        return torch.cat(out, dim=0)