File size: 8,874 Bytes
8520a55 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 |
import torch
import torch.nn.functional as F
import torchaudio
import copy
from torch import Tensor, nn
import logging
from .model import length_to_mask
from .samplers import (apply_typical_p, early_eos_penalty,
top_k_top_p_filtering, freq_rep_penalty)
from .nn_future import RotatingBufferCache
from .minbpe.codebook import CodebookTokenizer
from .minbpe.regex import RegexTokenizer
@torch.inference_mode()
def ar_generate(texttok: RegexTokenizer, speechtok: CodebookTokenizer,
                codeclm: nn.Module, xx: Tensor, ss_gen: Tensor, first_codex_idx: int,
                max_len: int = 1500, fp16: bool = True, temperature: float = 1.0, topk: int = None,
                top_p=1.0, alpha_frequency=0, alpha_presence=0, penalty_window=100,
                typical_p=1.0, eos_penalty_factor=1.0, eos_penalty_decay=0, n_phones_gen=None, vocode=True,
                beam_width: int = 1, beam_length_penalty=2, use_kv_cache: bool = True) -> Tensor:
    """Autoregressively extend the 1-D prompt `xx` with tokens sampled from `codeclm`.

    The first `first_codex_idx`-1 indices of `xx` correspond to the input phones. Decoding
    keeps `beam_width` hypotheses (only `beam_width == 1` is currently accepted) and stops
    when every hypothesis samples the end-of-speech token, or when the running sequence
    (prompt plus generated tokens) reaches `max_len` tokens.

    Args:
        texttok: text tokenizer; only `len(texttok.vocab)` is read here.
        speechtok: speech tokenizer; `len(speechtok.vocab)` and
            `speechtok.special_tokens['<|endofspeech|>']` are read here.
        codeclm: language model, called as
            `codeclm(x_inp, padding_mask, spk_reference=ss_gen, cache=cache, counter=counter)`.
        xx: 1-D prompt token tensor.
        ss_gen: 2-D speaker reference tensor; a leading beam dimension is added internally.
        first_codex_idx: index in `xx` where the acoustic tokens start.
        max_len: upper bound on the total sequence length held in the decoding buffer.
        fp16: enable CUDA autocast around the model call.
        temperature, topk, top_p, typical_p: sampling controls forwarded to the samplers.
        alpha_frequency, alpha_presence, penalty_window: forwarded to `freq_rep_penalty`.
        eos_penalty_factor, eos_penalty_decay, n_phones_gen: forwarded to
            `early_eos_penalty`; the penalty is skipped when `n_phones_gen` is None.
        vocode: must be False; vocoding is not implemented in this function.
        beam_width: number of hypotheses (must be 1).
        beam_length_penalty: exponent applied to the generated length when ranking
            finished hypotheses.
        use_kv_cache: allocate and use a `RotatingBufferCache` for the model.
        See `InferenceConfig` for other inference docs.

    Returns:
        A 1-D tensor holding the selected hypothesis: the prompt followed by the generated
        tokens. When generation stops on end-of-speech, that final token is not appended.

    Raises:
        AssertionError: if `xx` is not 1-D, if `beam_width != 1`, or if `vocode` is True.
    """
    assert xx.dim() == 1, "Only batch size of 1 is currently supported."
    assert beam_width == 1, "Only beam size of 1 is currently supported."
    if vocode:
        # Fail before running the (expensive) generation loop; the exception type is
        # unchanged from the previous post-generation `raise AssertionError()`.
        raise AssertionError("Vocoding is not implemented in ar_generate; "
                             "call it with vocode=False to obtain the generated tokens.")

    # internally our batch size will be the beam width
    bs = beam_width
    x_inp = xx[None].repeat(bs, 1)  # (bs, seq_len)
    ss_gen = ss_gen[None].repeat(bs, 1, 1)
    # We must subtract 1 in the line below so that we match the train-time conditions of having a
    # False padding value for the <bos> token position. This is needed so that we correctly use the
    # _acoustic_ and not the linguistic language embedding for the <bos> token.
    offsets = torch.tensor([first_codex_idx - 1] * bs, dtype=torch.long, device=xx.device)

    n_text_vocab = len(texttok.vocab)
    # Logits outside [valid_logit_idx_start - 1, valid_logit_idx_end) are masked to -inf below.
    # NOTE(review): the `- 1` on the lower bound and the `+ 1` on the upper bound are kept
    # from the original implementation -- confirm both are intentional.
    valid_logit_idx_start = n_text_vocab
    valid_logit_idx_end = n_text_vocab + len(speechtok.vocab) + 1
    eos_idx = n_text_vocab + speechtok.special_tokens['<|endofspeech|>']

    cum_logprobs = torch.zeros(bs, dtype=torch.float, device=x_inp.device)
    logging.info("Starting beam decoding with beam_width=%s", beam_width)

    prev_ids = [[] for _ in range(bs)]  # per-beam history of generated token ids

    cache = None
    if use_kv_cache:
        # Initialise kv cache
        ar_args = codeclm.ar.args
        cache_window = min(ar_args.sliding_window, x_inp.shape[-1] + max_len)
        cache = RotatingBufferCache(ar_args.n_layers, bs, cache_window, ar_args.n_kv_heads, ar_args.head_dim)
        # NOTE(review): the cache is always float16, even when fp16=False -- confirm.
        cache.to(device=x_inp.device, dtype=torch.float16)

    best_beam = None
    counter = 0
    while x_inp.shape[-1] < max_len:
        counter += 1
        cur_lengths = torch.full((bs,), x_inp.shape[-1], dtype=torch.long, device=xx.device)
        padding_mask = length_to_mask(cur_lengths, offsets)

        with torch.autocast('cuda', enabled=fp16):
            logits: Tensor = codeclm(x_inp, padding_mask, spk_reference=ss_gen, cache=cache, counter=counter)
        logits = logits.float()
        logits = logits[:, -1]  # select last index, now (bs, logit_dim)

        # <---------------------- logit filtering ---------------------->
        filtered_logits = logits.clone()

        # apply repetition penalty before the logit mask once the first beam has more than 1 prior token.
        if len(prev_ids[0]) > 1:
            filtered_logits = freq_rep_penalty(filtered_logits, previous=torch.tensor(prev_ids, dtype=torch.long),
                                               alpha_frequency=alpha_frequency, alpha_presence=alpha_presence,
                                               penalty_window=penalty_window)

        filtered_logits[..., :valid_logit_idx_start-1] = float('-inf')
        filtered_logits[..., valid_logit_idx_end:] = float('-inf')

        if n_phones_gen is not None:
            # apply eos penalty
            filtered_logits = early_eos_penalty(filtered_logits, len(prev_ids[0]), n_phones_gen,
                                                eos_penalty_decay, eos_penalty_factor,
                                                eos_index=eos_idx)

        filtered_logits = filtered_logits / temperature
        filtered_logits = top_k_top_p_filtering(filtered_logits, top_k=topk, top_p=top_p)
        filtered_logits = apply_typical_p(filtered_logits, mass=typical_p)

        # mask again after sampling filters so nothing outside the valid range can be drawn
        filtered_logits[..., :valid_logit_idx_start-1] = float('-inf')
        filtered_logits[..., valid_logit_idx_end:] = float('-inf')
        logits = filtered_logits

        # <---------------------- next frame prediction --------------------->
        logprobs = logits.log_softmax(dim=-1)

        # If a beam ended in <eos> last step, it MUST also end in <eos> this step: give <eos>
        # probability 1 (log-prob 0) and everything else probability 0 (log-prob -inf).
        for j in range(bs):
            if x_inp[j, -1] == eos_idx:
                logprobs[j] = float('-inf')
                logprobs[j, eos_idx] = 0

        candidate_cum_logprobs = cum_logprobs[:, None] + logprobs  # (bs, 1) + (bs, vocab) -> (bs, vocab)

        candidates = torch.multinomial(logprobs.flatten().exp(), num_samples=beam_width, replacement=False)  # (bs,)
        # Unravel the flat index using the actual row width of `logprobs`, which is what
        # `flatten()` used; this stays correct even if the logit dimension differs from the
        # combined tokenizer vocabulary size.
        logit_dim = logprobs.shape[-1]
        beam_idxs = candidates // logit_dim  # (bs,)
        tok_inds_in_each_beam = candidates % logit_dim  # (bs,)

        # check for breaks
        if torch.all(tok_inds_in_each_beam == eos_idx):
            # apply length penalty:
            non_eos_toks = (x_inp != eos_idx).sum(dim=-1)  # (bs,) number of non eos toks
            # clamp so an immediate <eos> (zero generated tokens) cannot divide by zero
            n_generated = (non_eos_toks - first_codex_idx).clamp(min=1)
            penalties = (n_generated**beam_length_penalty)
            penalized_cum_tok_logp = candidate_cum_logprobs / penalties[:, None]

            eos_avg_logps = penalized_cum_tok_logp[:, eos_idx]
            best_beam_idx = eos_avg_logps.argmax()
            best_avg_logp = eos_avg_logps[best_beam_idx]
            best_beam = x_inp[best_beam_idx]
            logging.info((f"best beam = {best_beam_idx} @ penalized_cum_tok_logp = {best_avg_logp.item():.3f} |\n num toks: {non_eos_toks.cpu().tolist()}. "
                          f"Candidates: {eos_avg_logps.cpu()} |\n non-eos toks: {non_eos_toks.cpu().tolist()} |\n penalties: {penalties.cpu().tolist()} | "
                          f"raw cumulative probs: {candidate_cum_logprobs[:, eos_idx].cpu().tolist()}"))
            break

        # update beam histories:
        x_inp = x_inp[beam_idxs]
        # update next token
        next_sample = tok_inds_in_each_beam
        # update cum logprob
        cum_logprobs = cum_logprobs[beam_idxs] + logprobs[beam_idxs, tok_inds_in_each_beam]
        # re-point each history at the beam it was sampled from (shallow copy is enough for
        # flat lists of ints), then append the newly sampled token
        prev_ids = [list(prev_ids[src_beam]) for src_beam in beam_idxs.tolist()]
        for history, new_tok in zip(prev_ids, tok_inds_in_each_beam.tolist()):
            history.append(new_tok)

        logging.debug("L%d | next sample: %s | beam: %s | cum_logp: %s", len(x_inp[0]), next_sample.cpu().tolist(), beam_idxs.cpu().tolist(), cum_logprobs.cpu())

        # update cache with beam indexes
        if cache is not None:
            cache.cache_k = cache.cache_k[:, beam_idxs]
            cache.cache_v = cache.cache_v[:, beam_idxs]

        # add 1 None below to make (bs,) -> (bs, 1) so we can concat along seq len dim.
        x_inp = torch.cat([x_inp, next_sample[:, None]], dim=-1)

    if x_inp.shape[-1] >= max_len - 1:
        logging.warning("[autoregressive generation] output length = %s -- inference likely failed or input too long!", x_inp.shape[-1])

    if best_beam is None:
        # the loop ended on the length limit without every beam emitting <eos>
        best_beam = x_inp[0]

    return best_beam  # (seq_len,)
|