from typing import Optional, Tuple import numpy as np import torch import tqdm from torch.nn import functional as F from IPython import embed def top_p_sample(prob_dist: torch.Tensor, top_p: float): sorted_probs, sorted_indices = torch.sort(prob_dist, descending=True, dim=-1) cum_sum_probs = torch.cumsum(sorted_probs, dim=-1) # (b, vocab_size) sorted_indices_to_remove = cum_sum_probs > top_p # Shift the indices to the right to keep also the first token above the threshold sorted_indices_to_remove[:, 1:] = sorted_indices_to_remove[:, :-1].clone() sorted_indices_to_remove[:, 0] = 0 sorted_indices_to_remove = sorted_indices_to_remove.bool() # replace probs to be removed with 0 in the sorted_probs sorted_probs[sorted_indices_to_remove] = 0 # reverse the sorting process reversed_indices = torch.argsort(sorted_indices) prob_dist = torch.gather(sorted_probs, -1, reversed_indices) # normalize prob_dist = prob_dist / prob_dist.sum(dim=-1, keepdim=True) return prob_dist class CausalInferenceMixin: """ Mixin class for performing inference in a causal language model. This mixin provides methods for predicting the next token in a sequence, sampling from the model, and applying token prediction masks. Attributes: None Methods: _sample_next_token: Predicts the next token in the sequence. _create_token_pred_mask: Creates a token prediction mask based on sequence lengths. _apply_token_pred_mask: Applies a token prediction mask to the next token predictions. _sample_batch: Samples a batch of tokens from the model. _sort_for_batching: Sorts the input sequences for efficient batching. _causal_sample: Generates a sequence of tokens using causal sampling. """ @torch.no_grad() def _sample_next_token( self, *, idx: torch.Tensor, speaker_embs: Optional[torch.Tensor], temperature: float, top_k: Optional[int], top_p: Optional[float], guidance_scale: Optional[float], ) -> torch.Tensor: """ Predict the next token in the sequence. Args: idx (torch.Tensor): Initial sequence indices of shape (batch, num_hierarchies, time). speaker_embs (Optional[torch.Tensor]): Speaker embeddings. Set to `None` if using an unconditional model. temperature (float): Sampling temperature. top_k (Optional[int]): Top-k filtering threshold. Set to `None` to disable top-k filtering. top_p (Optional[float]): Nucleus sampling threshold. Set to `None` to disable it. guidance_scale (Optional[float]): Scale factor for the guidance loss. Set to `None` to disable guidance. Returns: torch.Tensor: Next index in the sequence after sampling. Shape: (batch, num_hierarchies). """ if top_k is not None and top_p is not None: raise ValueError("Only one of top_k and top_p can be set") # if the sequence context is growing too long we must crop it at block_size idx_cond = idx if idx.size(-1) <= self.config.block_size else idx[:, :, -self.config.block_size :] # forward the model to get the logits for the index in the sequence list_logits, _ = self( idx_cond, speaker_embs=speaker_embs ) # list with len num_hierarchies of (b,1,vocab_size) tensors # print(f'{list_logits[0].shape=}, {len(list_logits)=}') # print(f'{list_logits[0][:,:,:10]}') if guidance_scale is not None: assert idx_cond.shape[0] % 2 == 0 assert list_logits[0].shape[0] % 2 == 0 for i, logits in enumerate(list_logits): logits_cond, logits_uncond = logits.split(logits.shape[0] // 2, dim=0) list_logits[i] = (guidance_scale) * logits_cond + (1 - guidance_scale) * logits_uncond assert list_logits[0].shape[0] == idx_cond.shape[0] // 2 # pluck the logits at the final step and scale by desired temperature list_logits = [ logits[:, -1, :] / temperature for logits in list_logits ] # list with len num_hierarchies of (b,vocab_size) tensors # optionally crop the logits to only the top k options if top_k is not None: for i in range(len(list_logits)): logits = list_logits[i] v, _ = torch.topk( logits, min(top_k, logits.size(-1)) ) # returns a descending sorted list of values and indices of top_k values logits[logits < v[:, [-1]]] = -float("Inf") # set all logits below the smallest top_k value to -Inf list_logits[i] = logits # apply softmax to convert logits to (normalized) probabilities # embed() probs = [ F.softmax(logits, dim=-1) for logits in list_logits ] # list of len num_hierarchies of (b,vocab_size) tensors # print(f'{probs[0].shape=}') # print(f'{probs[0][:,:,:10]}') if top_p is not None: for i in range(len(probs)): probs[i] = top_p_sample(probs[i], top_p) # sample from the distribution idx_next = [ torch.multinomial(prob, num_samples=1) for prob in probs ] # list of len num_hierarchies of (b,1) tensors idx_next = torch.cat(idx_next, dim=-1) # (b, num_hierarchies) tensor return idx_next # (b, num_hierarchies) tensor @torch.no_grad() def _create_token_pred_mask(self, idx: torch.Tensor, seq_lens: list[int]) -> torch.Tensor: """ Creates a token prediction mask based on sequence lengths. Args: idx (torch.Tensor): Initial sequence indices of shape (batch, num_hierarchies, time). seq_lens (list[int]): List of sequence lengths for each sequence in idx. Returns: torch.Tensor: Token prediction mask of shape (batch, time). """ token_pred_mask = torch.zeros((idx.shape[0], idx.shape[-1]), dtype=torch.bool, device=idx.device) for i in range(len(seq_lens)): token_pred_mask[i, : seq_lens[i]] = True assert (token_pred_mask[:, : min(seq_lens)] == 1).all() return token_pred_mask @torch.no_grad() def _apply_token_pred_mask( self, *, idx_next: torch.Tensor, orig_input_at_t: torch.Tensor, token_pred_mask_at_t: torch.Tensor ) -> torch.Tensor: """ Applies a token prediction mask to the next token predictions. Args: idx_next (torch.Tensor): Next token predictions of shape (batch, num_hierarchies). orig_input_at_t (torch.Tensor): Original input at time step t of shape (batch, num_hierarchies). token_pred_mask_at_t (torch.Tensor): Token prediction mask at time step t of shape (batch, 1). Returns: torch.Tensor: Updated next token predictions after applying the token prediction mask. """ idx_next = idx_next * (~token_pred_mask_at_t) + orig_input_at_t * token_pred_mask_at_t return idx_next @torch.no_grad() def _sample_batch( self, *, idx: torch.Tensor, max_new_tokens: int, seq_lens: list[int], temperature: float, top_k: Optional[int], top_p: Optional[float], speaker_embs: Optional[torch.Tensor], guidance_scale: Optional[float], ): """ Samples a batch of tokens from the model. Args: idx (torch.Tensor): Initial sequence indices of shape (batch, num_hierarchies, time). max_new_tokens (int): Maximum number of NEW tokens to generate (in addition to largest sequence in idx). seq_lens (list[int]): List of sequence lengths for each sequence in idx. temperature (float): Sampling temperature. top_k (Optional[int]): Top-k filtering threshold. Set to `None` to disable top-k filtering. top_p (Optional[float]): Nucleus sampling threshold. Set to `None` to disable it. speaker_embs (Optional[torch.Tensor]): Speaker embeddings. Set to `None` if using an unconditional model. guidance_scale (Optional[float]): Scale factor for the guidance loss. Set to `None` to disable guidance. Returns: torch.Tensor: Generated sequence indices of shape (batch, num_hierarchies, time). """ assert max(seq_lens) <= idx.shape[-1] token_pred_mask = self._create_token_pred_mask(idx, seq_lens) input = torch.clone(idx) min_seq_lens = min(seq_lens) idx = idx[:, :, :min_seq_lens] if guidance_scale is not None: if speaker_embs is None: raise Exception("Guidance is only supported for conditional models") # create speaker embeddings equivalent to the batch size, filling with None # for second half to do unconditional generation. speaker_embs = list(speaker_embs) + [None] * (speaker_embs.shape[0]) for timestep in tqdm.tqdm(range(min_seq_lens, min_seq_lens + max_new_tokens), desc="tokens: "): if (self.kv_cache_enabled is True) and (timestep > min_seq_lens): idx_input = idx[:, :, -1:] else: idx_input = idx if guidance_scale is not None: # TODO: fix: will cause a problem with kv-caching as it's not expecting larger batch-size. if timestep == min_seq_lens: print("[hack!!!!] Guidance is on, so we're doubling batch size!") # replicate idx in the batch dimension idx_input = ( idx_input.unsqueeze(0).repeat(2, 1, 1, 1).reshape(-1, idx_input.shape[1], idx_input.shape[2]) ) # sanity checks assert idx_input.shape[0] % 2 == 0 idx_next = self._sample_next_token( idx=idx_input, speaker_embs=speaker_embs, temperature=temperature, top_k=top_k, top_p=top_p, guidance_scale=guidance_scale, ) # (b, num_hierarchies) assert idx_next.shape[0] == idx.shape[0] if timestep < token_pred_mask.shape[-1]: idx_next = self._apply_token_pred_mask( idx_next=idx_next, orig_input_at_t=input[:, :, timestep], token_pred_mask_at_t=token_pred_mask[:, [timestep]], ) idx_next = idx_next.unsqueeze(-1) # (b, num_hierarchies, T=1) tensor # append sampled index to the running sequence and continue idx = torch.cat((idx, idx_next), dim=2) return idx @torch.no_grad() def _sort_for_batching( self, *, idx: torch.Tensor, seq_lens: list[int], speaker_embs: Optional[torch.Tensor], batch_size: int, max_new_tokens: int, ) -> Tuple[list[int], list[int], torch.Tensor, list[int], Optional[torch.Tensor], int]: """ Sorts the input sequences for efficient batching. Args: idx (torch.Tensor): Initial sequence indices of shape (batch, num_hierarchies, time). seq_lens (list[int]): List of sequence lengths for each sequence in idx. speaker_embs (Optional[torch.Tensor]): Speaker embeddings. Set to `None` if using an unconditional model. batch_size (int): Batch size for sampling. idx is split into batches of this size for sampling. max_new_tokens (int): Maximum number of NEW tokens to generate (in addition to largest sequence in idx). Returns: Tuple[list[int], list[int], torch.Tensor, list[int], Optional[torch.Tensor], int]: - sorted_indices (list[int]): List of indices of the input sequences that transform it into sorted order. - invert_sorted_indices (list[int]): List of indices to invert the sorted sequences back to the original order. - idx (torch.Tensor): Input sequence indices in sorted order. - seq_lens (list[int]): Sequence lengths in sorted order. - speaker_embs (Optional[torch.Tensor]): speaker embeddings in sorted order. - max_token_len (int): Effective maximum number of tokens to generate. """ assert len(seq_lens) == idx.shape[0] assert max(seq_lens) <= idx.shape[-1] sorted_indices = np.argsort(seq_lens) inverted_sorted_indices = np.zeros(len(seq_lens), dtype=np.int32) inverted_sorted_indices[sorted_indices] = np.arange(len(seq_lens), dtype=np.int32) idx = idx[sorted_indices] seq_lens = [seq_lens[i] for i in sorted_indices] speaker_embs = speaker_embs[sorted_indices] if speaker_embs is not None else None max_token_len = 0 # figure out effective max_tokens to generate for start_index in range(0, len(seq_lens), batch_size): end_index = min(start_index + batch_size, len(seq_lens)) batch_seq_lens = seq_lens[start_index:end_index] # random heuristic... # # TODO: fix! max_token_len = max(max_token_len, min(batch_seq_lens) + max_new_tokens) return sorted_indices, inverted_sorted_indices, idx, seq_lens, speaker_embs, max_token_len @torch.no_grad() def _causal_sample( self, *, idx: torch.Tensor, max_new_tokens: int, seq_lens: list[int], temperature: float, top_k: Optional[int], top_p: Optional[float], speaker_embs: Optional[torch.Tensor], batch_size: int, guidance_scale: Optional[float] = None, ) -> torch.Tensor: """ Generates a sequence of tokens using causal sampling. Args: idx (torch.Tensor): Initial sequence indices of shape (batch, num_hierarchies, time). max_new_tokens (int): Maximum number of NEW tokens to generate (in addition to largest sequence in idx). seq_lens (list[int]): List of sequence lengths for each sequence in idx. temperature (float): Sampling temperature. top_k (Optional[int]): Top-k filtering threshold. Set to `None` to disable top-k filtering. top_p (Optional[float]): Nucleus sampling threshold. Set to `None` to disable it. speaker_embs (Optional[torch.Tensor]): Speaker embeddings. Set to `None` if using an unconditional model. batch_size (int): Batch size for sampling. idx is split into batches of this size for sampling. guidance_scale (Optional[float]): Scale factor for the guidance loss. Set to `None` to disable guidance. Returns: torch.Tensor: Generated sequence indices of shape (batch, num_hierarchies, time). """ ( _, invert_sorted_indices, idx, seq_lens, speaker_embs, max_token_len, ) = self._sort_for_batching( idx=idx, seq_lens=seq_lens, speaker_embs=speaker_embs, batch_size=batch_size, max_new_tokens=max_new_tokens ) return_idx = torch.zeros((len(seq_lens), idx.size(1), max_token_len), dtype=torch.long, device=idx.device) for start_index in tqdm.tqdm(range(0, len(seq_lens), batch_size), desc="batch: "): end_index = min(start_index + batch_size, len(seq_lens)) kv_batch_size = end_index - start_index if guidance_scale is not None: kv_batch_size = 2 * kv_batch_size if self.kv_cache_enabled: print("!!!! USING KV-CACHING ASSUMED TORCH.BFLOAT16") self.empty_kv_cache( batch_size=kv_batch_size, kv_cache_maxlen=self.config.block_size, dtype=torch.bfloat16, ) batch_seq_lens = seq_lens[start_index:end_index] batch_max_new_tokens = max_token_len - min(batch_seq_lens) batch_idx = idx[start_index:end_index] batch_speaker_embs = speaker_embs[start_index:end_index] if speaker_embs is not None else None batch_idx = self._sample_batch( idx=batch_idx, max_new_tokens=batch_max_new_tokens, seq_lens=batch_seq_lens, temperature=temperature, top_k=top_k, top_p=top_p, speaker_embs=batch_speaker_embs, guidance_scale=guidance_scale, ) return_idx[start_index:end_index] = batch_idx return return_idx[invert_sorted_indices] def empty_kv_cache(self, *, batch_size: int, kv_cache_maxlen: int, dtype: torch.dtype): """ Empties key-value (KV) cache for causal attention. Args: batch_size (int): The batch size. kv_cache_maxlen (int): The maximum length of the KV cache. dtype (torch.dtype): The data type of the KV cache. Raises: Exception: If KV cache is enabled for non-causal attention. """ if self.kv_cache_enabled is False: raise Exception("KV cache is not enabled") if self.config.causal is False: raise Exception("KV cache is not supported for non-causal attention") self.kv_pos = 0 for block in self.transformer.h: block.attn.empty_kv_cache(batch_size=batch_size, kv_cache_maxlen=kv_cache_maxlen, dtype=dtype) def enable_kv_cache(self): """ Enables key-value (KV) cache for causal attention. Raises: Exception: If KV cache is enabled for non-causal attention. """ if self.config.causal is False: raise Exception("KV cache is not supported for non-causal attention") self.kv_cache_enabled = True for block in self.transformer.h: block.attn.kv_cache_enabled = True def disable_kv_cache(self): """ Disables the key-value cache for the transformer and all its blocks. """ self.kv_cache_enabled = False for block in self.transformer.h: block.attn.kv_cache_enabled = False block.attn.kv_cache = None block.attn.kv_cache_first_empty_index = 0 @torch.no_grad() def _slow_causal_sampling_loop( self, idx: torch.Tensor, max_new_tokens: int, temperature: float = 1.0, top_k: Optional[int] = None, top_p: Optional[float] = None, speaker_embs: Optional[torch.Tensor] = None, guidance_scale: Optional[float] = None, ): """ Old non-batched version of causal sampling. Kept for testing / reference. Take a conditioning sequence of indices idx (LongTensor of shape (b,n_head,t)) and complete the sequence max_new_tokens times, feeding the predictions back into the model each time. Most likely you'll want to make sure to be in model.eval() mode of operation for this. """ assert idx.dim() == 3, "idx must be a batch of sequences of hierarchical tokens" assert idx.size(0) == 1, "can only do one sequence at a time for now" assert top_p is None, "nucleus sampling not supported yet with _slow_causal_sampling_loop" if self.config.causal is not True: raise Exception("Causal sampling is only supported for causal models") if self.kv_cache_enabled: print("!!!! USING KV-CACHING ASSUMED TORCH.BFLOAT16") self.empty_kv_cache( batch_size=1, kv_cache_maxlen=self.config.block_size, dtype=torch.bfloat16, ) for i in range(max_new_tokens): # if the sequence context is growing too long we must crop it at block_size idx_cond = idx if idx.size(-1) <= self.config.block_size else idx[:, -self.config.block_size :] if self.kv_cache_enabled: if i > 0: idx_cond = idx_cond[:, :, -1:] # forward the model to get the logits for the index in the sequence list_logits, _ = self(idx_cond, speaker_embs=speaker_embs) if guidance_scale is not None: # we've already checked that kv-caching is not switched on # so this should be ok. list_logits_uncond, _ = self(idx_cond, speaker_embs=None) list_logits = [ (guidance_scale) * logits + (1 - guidance_scale) * logits_uncond for logits, logits_uncond in zip(list_logits, list_logits_uncond) ] # pluck the logits at the final step and scale by desired temperature list_logits = [logits[:, -1, :] / temperature for logits in list_logits] # optionally crop the logits to only the top k options if top_k is not None: for i in range(len(list_logits)): logits = list_logits[i] v, _ = torch.topk(logits, min(top_k, logits.size(-1))) logits[logits < v[:, [-1]]] = -float("Inf") list_logits[i] = logits # apply softmax to convert logits to (normalized) probabilities probs = [F.softmax(logits, dim=-1) for logits in list_logits] # sample from the distribution idx_next = torch.tensor( [torch.multinomial(prob, num_samples=1) for prob in probs], device=idx.device ) # (c, 1) # append sampled index to the running sequence and continue idx = torch.cat((idx, idx_next.unsqueeze(0).unsqueeze(-1)), dim=2) return idx