| |
| |
| |
| |
|
|
|
|
| import torch |
| import torch.nn.functional as F |
|
|
|
|
| |
def top_k_top_p_filtering(
    logits, top_k=0, top_p=1.0, filter_value=-float("Inf"), min_tokens_to_keep=1
):
    """
    Filter a distribution of logits using top-k and/or nucleus (top-p) filtering.

    Nucleus filtering is described in Holtzman et al. (http://arxiv.org/abs/1904.09751).
    Adapted from: https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317

    Args:
        logits (torch.Tensor): Logits whose last dimension is the vocabulary, e.g. shape
            (batch size, vocabulary size). Filtering is applied independently along the
            last dimension.
        top_k (int, optional): Keep only the k tokens with the highest logits (top-k
            filtering). Tokens tied with the k-th largest logit are also kept. Set to 0
            to disable. Defaults to 0.
        top_p (float, optional): Keep the smallest set of most-probable tokens whose
            cumulative probability reaches top_p (nucleus filtering). The token that
            crosses the threshold is kept. Set to 1.0 to disable. Defaults to 1.0.
        filter_value (float, optional): The value assigned to filtered logits.
            Defaults to -float('Inf').
        min_tokens_to_keep (int, optional): Keep at least this many tokens per row,
            regardless of top_k / top_p. Defaults to 1.

    Returns:
        torch.Tensor: A new tensor with the same shape as ``logits``, in which filtered
        positions are set to ``filter_value``. The input tensor is not modified.
    """
    if top_k > 0:
        # Never keep fewer than min_tokens_to_keep and never request more than the vocab size.
        top_k = min(max(top_k, min_tokens_to_keep), logits.size(-1))
        # The k-th largest logit per row, kept as a trailing singleton dim for broadcasting.
        kth_largest = torch.topk(logits, top_k, dim=-1).values[..., -1:]
        # masked_fill (not in-place assignment) so the caller's tensor is left untouched.
        logits = logits.masked_fill(logits < kth_largest, filter_value)

    if top_p < 1.0:
        sorted_logits, sorted_indices = torch.sort(logits, dim=-1, descending=True)
        cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)

        # In sorted order: mark tokens whose cumulative probability exceeds top_p.
        sorted_indices_to_remove = cumulative_probs > top_p
        if min_tokens_to_keep > 1:
            sorted_indices_to_remove[..., :min_tokens_to_keep] = False
        # Shift right by one so the token that crosses the threshold is also kept,
        # and always keep the most probable token.
        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
        sorted_indices_to_remove[..., 0] = False

        # Map the mask from sorted order back to vocabulary order. Scattering onto the
        # boolean mask itself keeps the result a bool tensor usable for masking
        # (scattering into the int64 index tensor is a dtype error / wrong indexing).
        indices_to_remove = sorted_indices_to_remove.scatter(
            -1, sorted_indices, sorted_indices_to_remove
        )
        logits = logits.masked_fill(indices_to_remove, filter_value)

    return logits
|
|
|
|
def topk_sampling(logits, top_k=50, top_p=1.0, temperature=1.0):
    """
    Sample one token after temperature scaling and top-k / top-p filtering.

    Args:
        logits (torch.Tensor): Unnormalized scores whose last dimension is the
            vocabulary. ``torch.multinomial`` accepts 1-D or 2-D input, e.g.
            (batch size, vocabulary size).
        top_k (int, optional): The number of highest-scoring tokens to keep for top-k
            filtering; 0 disables top-k filtering. Defaults to 50.
        top_p (float, optional): The cumulative probability threshold for nucleus
            sampling, between 0 and 1; 1.0 disables it. Defaults to 1.0.
        temperature (float, optional): Divisor applied to the logits before
            filtering. It must be strictly positive. Defaults to 1.0.

    Returns:
        torch.Tensor: The sampled token indices from ``torch.multinomial`` with
        ``num_samples=1``. For 2-D input of shape (batch, vocab) this is (batch, 1).

    Raises:
        ValueError: If ``temperature`` is not strictly positive (including NaN).
    """
    # Fail fast: a zero/negative/NaN temperature would otherwise surface later as an
    # opaque "probability tensor contains inf/nan" error from torch.multinomial.
    if not temperature > 0:
        raise ValueError(f"temperature must be strictly positive, got {temperature}")

    # Always divide, even for temperature == 1.0 (x / 1.0 is exact). This produces a
    # fresh tensor, so filtering that writes into its input can never clobber the
    # caller's logits.
    scaled = logits / temperature

    filtered = top_k_top_p_filtering(scaled, top_k=top_k, top_p=top_p)

    probs = F.softmax(filtered, dim=-1)
    return torch.multinomial(probs, num_samples=1)
|
|