| import torch | |
| import torch.nn.functional as F | |
def top_p_filtering(logits, top_p: float = 1.0):
    """Filter a distribution of logits using nucleus (top-p) filtering.

    The input logits tensor is modified in-place when ``top_p < 1.0``.

    Args:
        logits (torch.Tensor): Logits to filter, shape [..., vocab_size].
        top_p (float, optional): Cumulative probability threshold. If < 1.0,
            keep the smallest set of tokens whose cumulative probability
            reaches ``top_p``; all other positions are set to -inf.

    Returns:
        torch.Tensor: The logits with out-of-nucleus values set to -inf.
    """
    if top_p < 1.0:
        sorted_logits, sorted_idx = logits.sort(dim=-1, descending=True)
        cum_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1)
        sorted_idx_to_remove = cum_probs > top_p
        # Shift the mask right so the first token that crosses the threshold
        # is KEPT: nucleus sampling keeps the smallest set whose cumulative
        # probability is >= top_p. Without this shift, that token was
        # incorrectly dropped.
        sorted_idx_to_remove[..., 1:] = sorted_idx_to_remove[..., :-1].clone()
        # Always keep at least the single most probable token.
        sorted_idx_to_remove[..., 0] = False
        # Map the removal mask from sorted order back to vocabulary order.
        idx_to_remove = sorted_idx_to_remove.scatter(
            -1, sorted_idx, sorted_idx_to_remove
        )
        logits.masked_fill_(idx_to_remove, -torch.inf)
    return logits
def process_logits(
    logits,
    top_p: float = None,
):
    """Select the next token id from logits, optionally via nucleus sampling.

    If ``top_p`` is None, the highest-probability token is chosen (argmax,
    deterministic). Otherwise the logits are first filtered with
    ``top_p_filtering`` so that only the smallest set of tokens with
    cumulative probability >= ``top_p`` remains, then a token is sampled
    from the softmax of the filtered logits with ``torch.multinomial``.

    Args:
        logits (torch.Tensor): Logits to process, shape [..., vocab_size].
        top_p (float, optional): Cumulative probability threshold for
            nucleus sampling. None selects argmax decoding.

    Returns:
        torch.Tensor: Selected token index with a trailing dimension of 1.
    """
    if top_p is None:
        # Greedy decoding: pick the single most likely token.
        next_id = torch.argmax(logits, dim=-1, keepdim=True)
    else:
        # Bug fix: forward the caller-supplied threshold; the previous code
        # hard-coded top_p=0.9 and silently ignored the argument.
        logits = top_p_filtering(logits, top_p=top_p)
        probs = F.softmax(logits, dim=-1)
        next_id = torch.multinomial(probs, num_samples=1, replacement=True)
    return next_id