import copy
import logging

import torch
import torch.nn.functional as F
import torchaudio
from torch import Tensor, nn

from .minbpe.codebook import CodebookTokenizer
from .minbpe.regex import RegexTokenizer
from .model import length_to_mask
from .nn_future import RotatingBufferCache
from .samplers import (apply_typical_p, early_eos_penalty,
                       freq_rep_penalty, top_k_top_p_filtering)


@torch.inference_mode()
def ar_generate(texttok: RegexTokenizer, speechtok: CodebookTokenizer,
                codeclm: nn.Module, xx: Tensor, ss_gen: Tensor, first_codex_idx: int,
                max_len: int = 1500, fp16: bool = True, temperature: float = 1.0, topk: int = None,
                top_p=1.0, alpha_frequency=0, alpha_presence=0, penalty_window=100,
                typical_p=1.0, eos_penalty_factor=1.0, eos_penalty_decay=0, n_phones_gen=None,
                vocode=True, beam_width: int = 1, beam_length_penalty=2,
                use_kv_cache: bool = True) -> tuple[Tensor, Tensor]:
""" Use the `codeclm` language model to autoregressively generate a completion of `xx` (seq_len), where the first `first_codex_idx`-1 |
|
indices correspond to the input phones. The output generation is limited to at most `max_len` (measured as num latent codes). |
|
Returns both output first quantizer codes and synthesized audio using `codec`. Use decoding with `beam_width` to keep |
|
track of top `beam_width` outcomes, selecting the top one among them. |
|
|
|
- Optionally vocode if `vocode` (default True). |
|
- See `InferenceConfig` for other inference docs. |
|
""" |
    assert xx.dim() == 1, "Only batch size of 1 is currently supported."
    assert beam_width == 1, "Only beam size of 1 is currently supported."

    bs = beam_width
    x_inp = xx[None].repeat(bs, 1)
    ss_gen = ss_gen[None].repeat(bs, 1, 1)
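    # The prompt occupies the first `first_codex_idx` - 1 positions; `offsets` records where
    # generation starts so `length_to_mask` can build the padding mask at every step. The LM
    # predicts over the concatenated text+speech vocabulary, so everything outside the
    # speech-token slice is masked to -inf before sampling (see the filtering below).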
    offsets = torch.tensor([first_codex_idx - 1 for _ in range(bs)], dtype=torch.long, device=xx.device)
    valid_logit_idx_start = len(texttok.vocab)
    valid_logit_idx_end = len(texttok.vocab) + len(speechtok.vocab) + 1

    cum_logprobs = torch.zeros(bs, dtype=torch.float, device=x_inp.device)
    eos_idx = len(texttok.vocab) + speechtok.special_tokens['<|endofspeech|>']
    n_vocab = len(texttok.vocab) + len(speechtok.vocab)
    logging.info(f"Starting beam decoding with beam_width={beam_width}")

    # Per-beam history of sampled token ids, used by the repetition penalty below.
    prev_ids = [[] for _ in range(bs)]
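    # KV-cache setup: the cache window is capped at the model's sliding-window size, so long
    # generations reuse (rotate) the oldest buffer slots instead of growing without bound.
    # The cache is stored in fp16.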
    cache = None
    if use_kv_cache:
        cache_window = min(codeclm.ar.args.sliding_window, x_inp.shape[-1] + max_len)
        cache = RotatingBufferCache(codeclm.ar.args.n_layers, bs, cache_window,
                                    codeclm.ar.args.n_kv_heads, codeclm.ar.args.head_dim)
        cache.to(device=x_inp.device, dtype=torch.float16)

    counter = 0
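    # Main decoding loop: one token per iteration, until every beam proposes <|endofspeech|>
    # or the sequence reaches `max_len`.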
    while x_inp.shape[-1] < max_len:
        counter += 1
        gen_length = torch.tensor([x_inp.shape[-1] for _ in range(bs)], dtype=torch.long, device=xx.device)
        padding_mask = length_to_mask(gen_length, offsets)

        with torch.autocast('cuda', enabled=fp16):
            logits: Tensor = codeclm(x_inp, padding_mask, spk_reference=ss_gen, cache=cache, counter=counter)
            logits = logits.float()

        # Only the distribution at the final position is needed to sample the next token.
        logits = logits[:, -1]
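        # Logit post-processing pipeline: repetition penalty -> restrict to the speech vocabulary
        # -> early-EOS penalty -> temperature -> top-k/top-p -> typical filtering.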
        filtered_logits = logits.clone()

        if len(prev_ids[0]) > 1:
            filtered_logits = freq_rep_penalty(filtered_logits, previous=torch.tensor(prev_ids, dtype=torch.long),
                                               alpha_frequency=alpha_frequency, alpha_presence=alpha_presence,
                                               penalty_window=penalty_window)
        filtered_logits[..., :valid_logit_idx_start-1] = float('-inf')
        filtered_logits[..., valid_logit_idx_end:] = float('-inf')

        if n_phones_gen is not None:
            filtered_logits = early_eos_penalty(filtered_logits, len(prev_ids[0]), n_phones_gen,
                                                eos_penalty_decay, eos_penalty_factor,
                                                eos_index=eos_idx)

        filtered_logits = filtered_logits / temperature
        filtered_logits = top_k_top_p_filtering(filtered_logits, top_k=topk, top_p=top_p)
        filtered_logits = apply_typical_p(filtered_logits, mass=typical_p)

        # Re-apply the vocabulary mask after filtering so out-of-range tokens stay impossible.
        filtered_logits[..., :valid_logit_idx_start-1] = float('-inf')
        filtered_logits[..., valid_logit_idx_end:] = float('-inf')
        logits = filtered_logits
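        # Convert to log-probabilities; beams that already ended are pinned to <|endofspeech|>
        # (logprob 0 on EOS, -inf elsewhere) so they no longer change the running scores.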
        logprobs = logits.log_softmax(dim=-1)

        for j in range(bs):
            if x_inp[j, -1] == eos_idx:
                logprobs[j] = float('-inf')
                logprobs[j, eos_idx] = 0
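        # Sample `beam_width` candidates from the flattened (beam, token) probabilities; with
        # beam_width == 1 this is plain sampling from the single beam's distribution.
        # `candidate_cum_logprobs` holds the running score of every possible continuation.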
        candidate_cum_logprobs = cum_logprobs[:, None] + logprobs

        logp_flat = logprobs.flatten()
        candidates = torch.multinomial(logp_flat.exp(), num_samples=beam_width, replacement=False)

        beam_idxs = candidates // n_vocab
        tok_inds_in_each_beam = candidates % n_vocab
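        # Termination: once every beam proposes <|endofspeech|>, pick the winner by the
        # length-normalized score cum_logprob / n_generated**beam_length_penalty, so the
        # cumulative log-probability (which only decreases with length) does not bias the
        # choice toward the shortest hypothesis.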
        if torch.all(tok_inds_in_each_beam == eos_idx):
            non_eos_toks = (x_inp != eos_idx).sum(dim=-1)
            gen_length = non_eos_toks - first_codex_idx
            penalties = gen_length**beam_length_penalty
            penalized_cum_tok_logp = candidate_cum_logprobs / penalties[:, None]

            eos_avg_logps = penalized_cum_tok_logp[:, eos_idx]
            best_beam_idx = eos_avg_logps.argmax()
            best_avg_logp = eos_avg_logps[best_beam_idx]
            best_beam = x_inp[best_beam_idx]
            logging.info((f"best beam = {best_beam_idx} @ penalized_cum_tok_logp = {best_avg_logp.item():.3f} |\n num toks: {non_eos_toks.cpu().tolist()}. "
                          f"Candidates: {eos_avg_logps.cpu()} |\n non-eos toks: {non_eos_toks.cpu().tolist()} |\n penalties: {penalties.cpu().tolist()} | "
                          f"raw cumulative probs: {candidate_cum_logprobs[:, eos_idx].cpu().tolist()}"))
            break
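        # Beam bookkeeping: reorder every per-beam structure (sequences, running scores, token
        # histories and the KV cache) to follow the surviving beam indices, then append the
        # newly sampled token to each beam.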
        x_inp = x_inp[beam_idxs]
        next_sample = tok_inds_in_each_beam

        cum_logprobs = cum_logprobs[beam_idxs] + logprobs[beam_idxs, tok_inds_in_each_beam]
        prev_ids = [copy.deepcopy(prev_ids[beam_idx.item()]) for beam_idx in beam_idxs]
        for j in range(bs):
            prev_ids[j].append(tok_inds_in_each_beam[j].item())

        logging.debug("L%d | next sample: %s | beam: %s | cum_logp: %s",
                      len(x_inp[0]), next_sample.cpu().tolist(), beam_idxs.cpu().tolist(), cum_logprobs.cpu())
        if cache is not None:
            cache.cache_k = cache.cache_k[:, beam_idxs]
            cache.cache_v = cache.cache_v[:, beam_idxs]

        x_inp = torch.cat([x_inp, next_sample[:, None]], dim=-1)
    if x_inp.shape[-1] >= max_len - 1:
        logging.warning(f"[autoregressive generation] output length = {x_inp.shape[-1]} -- inference likely failed or input too long!")
        # Fall back to the first beam when the loop exhausted `max_len` without emitting EOS.
        best_beam = x_inp[0]

    if not vocode:
        return best_beam
    else:
        raise AssertionError()