import torch
import torch.nn.functional as F
import torchaudio
import copy
from torch import Tensor, nn
import logging
from .model import length_to_mask
from .samplers import (apply_typical_p, early_eos_penalty,
                      top_k_top_p_filtering, freq_rep_penalty)
from .nn_future import RotatingBufferCache
from .minbpe.codebook import CodebookTokenizer
from .minbpe.regex import RegexTokenizer


@torch.inference_mode()
def ar_generate(texttok: RegexTokenizer, speechtok: CodebookTokenizer, 
                codeclm: nn.Module, xx: Tensor, ss_gen: Tensor, first_codex_idx: int, 
                max_len: int = 1500, fp16: bool = True, temperature: float = 1.0, topk: int | None = None,
                top_p=1.0, alpha_frequency=0, alpha_presence=0, penalty_window=100,
                typical_p=1.0, eos_penalty_factor=1.0, eos_penalty_decay=0, n_phones_gen=None, vocode=True,
                beam_width: int = 1, beam_length_penalty=2, use_kv_cache: bool = True) -> Tensor:
    """ Use the `codeclm` language model to autoregressively generate a completion of `xx` (seq_len), where the first `first_codex_idx`-1
    indices correspond to the input phones. The output generation is limited to at most `max_len` (measured as num latent codes).
    Returns both output first quantizer codes and synthesized audio using `codec`. Use decoding with `beam_width` to keep 
    track of top `beam_width` outcomes, selecting the top one among them. 

    - Optionally vocode if `vocode` (default True).
    - See `InferenceConfig` for other inference docs. 
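
    Example (an illustrative sketch only; the tokenizers, model, and tensors shown here are assumed to be
    constructed elsewhere in the codebase and are not defined in this module):

        codes = ar_generate(
            texttok, speechtok, codeclm,
            xx=prompt_ids,                    # (seq_len,) prompt: phone tokens followed by any prompt speech codes
            ss_gen=spk_reference,             # speaker / acoustic reference tensor
            first_codex_idx=first_codex_idx,  # index where the speech codes start in `xx`
            temperature=0.7, topk=100,
            vocode=False,                     # vocoding is not implemented inside this function
        )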
    """
    assert xx.dim() == 1, "Only batch size of 1 is currently supported."
    assert beam_width == 1, "Only beam size of 1 is currently supported."
    # internally our batch size will be the beam width
    bs = beam_width
    x_inp = xx[None].repeat(bs, 1) # (bs, seq_len)
    ss_gen = ss_gen[None].repeat(bs, 1, 1)
    # We must subtract 1 in the line below so that we match the train-time conditions of having a
    # False padding value for the <bos> token position. This is needed so that we correctly use the
    # _acoustic_ and not the linguistic language embedding for the <bos> token.
    offsets = torch.tensor([first_codex_idx - 1 for _ in range(bs)], dtype=torch.long, device=xx.device)
    valid_logit_idx_start = len(texttok.vocab) # vocab['s2i']['quant0-0000']
    valid_logit_idx_end = len(texttok.vocab) + len(speechtok.vocab) + 1 # vocab['s2i']['quant1-0000']
    # Make mask that is True where we have valid outputs, False otherwise (where we have text outputs). 
    # logit_mask = torch.zeros(n_vocab, dtype=bool, device=x_inp.device)
    # logit_mask[valid_logit_idx_start:valid_logit_idx_end] = True
    # logit_mask[vocab['s2i']['<eos>']] = True
    cum_logprobs = torch.zeros(bs, dtype=torch.float, device=x_inp.device)
    eos_idx = len(texttok.vocab) + speechtok.special_tokens['<|endofspeech|>']
    n_vocab = len(texttok.vocab) + len(speechtok.vocab)

    logging.info(f"Starting beam decoding with beam_width={beam_width}")

    prev_ids = [[] for _ in range(bs)]

    cache = None
    if use_kv_cache:
        # Initialise kv cache
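        # The rotating buffer only needs to cover the model's sliding-attention window, and never more than
        # the prompt length plus the maximum number of tokens we can generate.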
        cache_window = min(codeclm.ar.args.sliding_window, x_inp.shape[-1] + max_len)
        cache = RotatingBufferCache(codeclm.ar.args.n_layers, bs, cache_window, codeclm.ar.args.n_kv_heads, codeclm.ar.args.head_dim)
        cache.to(device=x_inp.device, dtype=torch.float16)

    counter = 0
    while x_inp.shape[-1] < max_len:
        counter += 1
        gen_length = torch.tensor([x_inp.shape[-1] for _ in range(bs)], dtype=torch.long, device=xx.device)
        padding_mask = length_to_mask(gen_length, offsets)
        
        with torch.autocast('cuda', enabled=fp16):
            logits: Tensor = codeclm(x_inp, padding_mask, spk_reference=ss_gen, cache=cache, counter=counter)
        logits = logits.float()

        logits = logits[:, -1] # select last index, now (bs, logit_dim)

        # <---------------------- logit filtering ---------------------->
        filtered_logits = logits.clone()

        # apply repetition penalty before logit mask if any item in the beam has more than 1 prior token.
        if len(prev_ids[0]) > 1: 
            filtered_logits = freq_rep_penalty(filtered_logits, previous=torch.tensor(prev_ids, dtype=torch.long), 
                                             alpha_frequency=alpha_frequency, alpha_presence=alpha_presence, 
                                             penalty_window=penalty_window)

        filtered_logits[..., :valid_logit_idx_start-1] = float('-inf')
        filtered_logits[..., valid_logit_idx_end:] = float('-inf')

        if n_phones_gen is not None:
            # apply eos penalty
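            # (discourages emitting <eos> while the generated sequence is still short relative to the
            # `n_phones_gen` phones being synthesized; see `early_eos_penalty` for the exact schedule)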
            filtered_logits = early_eos_penalty(filtered_logits, len(prev_ids[0]), n_phones_gen, 
                                                eos_penalty_decay, eos_penalty_factor, 
                                                eos_index=eos_idx)

        filtered_logits = filtered_logits / temperature
        filtered_logits = top_k_top_p_filtering(filtered_logits, top_k=topk, top_p=top_p)
        filtered_logits = apply_typical_p(filtered_logits, mass=typical_p)

        # mask out anything that isn't first quantizer output codes
        filtered_logits[..., :valid_logit_idx_start-1] = float('-inf')
        filtered_logits[..., valid_logit_idx_end:] = float('-inf')
        logits = filtered_logits

        # <---------------------- next frame prediction --------------------->

        logprobs = logits.log_softmax(dim=-1)

        # Update assignments: if a beam ended with <eos> on the previous step, it MUST also emit <eos> this step.
        # Below we enforce this by setting that beam's log-probabilities to -inf everywhere except at the <eos> index.
        for j in range(bs):
            if x_inp[j, -1] == eos_idx:
                # do not add any additional probability to it, keeping it the same for all vocab idxs
                logprobs[j] = float('-inf') # zero probability of anything non-eos after 1 eos
                logprobs[j, eos_idx] = 0 # probability=1 of <eos> after <eos>

        candidate_cum_logprobs = cum_logprobs[:, None] + logprobs # (bs, 1) + (bs, vocab) -> (bs, vocab)

        logp_flat = logprobs.flatten()
        candidates = torch.multinomial(logp_flat.exp(), num_samples=beam_width, replacement=False) # (bs,)
        # Unravel the flat indices back into (beam index, token index) pairs:
        beam_idxs = candidates // n_vocab # (bs,)
        tok_inds_in_each_beam = candidates % n_vocab # (bs,)

        # check for breaks
        if torch.all(tok_inds_in_each_beam == eos_idx):
            # apply length penalty:
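            # Each beam's cumulative log-probability is normalised by gen_length**beam_length_penalty,
            # so that longer hypotheses are not unduly penalised when picking the best beam.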
            non_eos_toks = (x_inp != eos_idx).sum(dim=-1) # (bs,) number of non eos toks
            gen_length = non_eos_toks - first_codex_idx
            penalties = (gen_length**beam_length_penalty)
            penalized_cum_tok_logp = candidate_cum_logprobs / penalties[:, None] 

            eos_avg_logps = penalized_cum_tok_logp[:, eos_idx]
            best_beam_idx = eos_avg_logps.argmax()
            best_avg_logp = eos_avg_logps[best_beam_idx]
            best_beam = x_inp[best_beam_idx]
            logging.info((f"best beam = {best_beam_idx} @ penalized_cum_tok_logp = {best_avg_logp.item():.3f} |\n num toks: {non_eos_toks.cpu().tolist()}. "
                         f"Candidates: {eos_avg_logps.cpu()} |\n non-eos toks: {non_eos_toks.cpu().tolist()} |\n penalties: {penalties.cpu().tolist()} | "
                         f"raw cumulative probs: {candidate_cum_logprobs[:, eos_idx].cpu().tolist()}"))
            break

        # update beam histories:
        x_inp = x_inp[beam_idxs]
        # update next token
        next_sample = tok_inds_in_each_beam
        # update cum logprob
        cum_logprobs = cum_logprobs[beam_idxs] + logprobs[beam_idxs, tok_inds_in_each_beam]
        # update prior inds to point to correct beam
        prev_ids = [copy.deepcopy(prev_ids[beam_idx.item()]) for beam_idx in beam_idxs]
        # add new tokens to previous ids
        for j in range(bs):
            prev_ids[j].append(tok_inds_in_each_beam[j].item())

        logging.debug("L%d | next sample: %s | beam: %s | cum_logp: %s", len(x_inp[0]), next_sample.cpu().tolist(), beam_idxs.cpu().tolist(), cum_logprobs.cpu())

        # update cache with beam indexes
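        # (the cache tensors carry the batch/beam dimension at dim 1, so we reorder them to follow the surviving beams)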
        if cache is not None:
            cache.cache_k = cache.cache_k[:, beam_idxs]
            cache.cache_v = cache.cache_v[:, beam_idxs]
        
        # add 1 None below to make (bs,) -> (bs, 1) so we can concat along seq len dim.
        x_inp = torch.cat([x_inp, next_sample[:, None]], dim=-1)
        

    if x_inp.shape[-1] >= max_len - 1:
        logging.warning(f"[autoregressive generation] output length = {x_inp.shape[-1]} -- inference likely failed or input too long!")
        best_beam = x_inp[0]

    if not vocode:
        return best_beam # (seq_len,)
    else:
        raise NotImplementedError("Vocoding inside `ar_generate` is not implemented; call with `vocode=False` "
                                  "and vocode the returned codes separately.")