| |
| import torch |
|
|
|
|
| |
| |
def modify_logits_for_top_k_filtering(logits, top_k):
    """Set the logits for non-top-k values to -inf. Done in-place.

    Args:
        logits: Tensor of scores; filtering is applied along the last
            dimension and the tensor is modified in place.
        top_k: Number of highest-scoring entries to keep along the last
            dimension. Values ``<= 0`` disable filtering (no-op); values larger
            than the last dimension are clamped to it. Entries tied with the
            k-th largest value are kept too, so more than ``top_k`` entries may
            survive.
    """
    if top_k <= 0:
        # Disabled, matching the ``top_k > 0`` convention used by ``sample``.
        return
    # torch.topk raises when k exceeds the dimension size, so clamp first.
    top_k = min(top_k, logits.size(-1))
    # k-th largest value per row, kept with a trailing size-1 dim so it
    # broadcasts against ``logits`` in the comparison below.
    kth_largest = torch.topk(logits, top_k, dim=-1).values[..., -1:]
    logits.masked_fill_(logits < kth_largest, float("-inf"))
|
|
|
|
| |
| |
def modify_logits_for_top_p_filtering(logits, top_p):
    """Set the logits outside the top-p (nucleus) set to -inf. Done in-place.

    Along the last dimension, the lowest-probability entries whose combined
    softmax mass is at most ``1 - top_p`` are masked out. The single most
    likely entry is always kept.

    Args:
        logits: Tensor of scores; filtering is applied along the last
            dimension and the tensor is modified in place.
        top_p: Nucleus threshold. Values ``<= 0.0`` or ``>= 1.0`` disable
            filtering (no-op).
    """
    if top_p <= 0.0 or top_p >= 1.0:
        return
    # Ascending sort, so the entries to drop form a prefix of the sorted order.
    sorted_logits, sorted_indices = torch.sort(logits, dim=-1, descending=False)
    cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1)
    # Drop the low-probability tail whose total mass is at most 1 - top_p.
    sorted_indices_to_remove = cumulative_probs <= (1 - top_p)
    # Always keep the most likely entry (last in ascending order). Rounding in
    # the cumulative sum (notably in float16/bfloat16, or when 1 - top_p rounds
    # to 1.0) could otherwise mask the whole row, and softmax over an all -inf
    # row yields NaN.
    sorted_indices_to_remove[..., -1:] = False
    # Map the mask from sorted order back to the original positions.
    indices_to_remove = sorted_indices_to_remove.scatter(-1, sorted_indices, sorted_indices_to_remove)
    logits.masked_fill_(indices_to_remove, float("-inf"))
|
|
|
|
| |
def sample(logits, top_k=1, top_p=0.0, temperature=1.0):
    """Sample one vocabulary index per row, with optional top-k/top-p/temperature.

    Arguments:
        logits: Tensor of shape (batch_size, vocab_size). The caller's tensor is
            not modified.
        top_k: ``1`` means greedy decoding (argmax; ``top_p`` and
            ``temperature`` are ignored). A value ``> 1`` restricts sampling to
            the ``top_k`` highest logits (clamped to vocab_size); a value
            ``<= 0`` disables top-k filtering.
        top_p: Nucleus threshold. Positive values must be ``<= 1.0``; values
            in (0, 1) enable top-p filtering, while 0.0 and 1.0 disable it.
        temperature: Logits are divided by this before the softmax (skipped
            when exactly 1.0).

    Returns:
        Tensor of shape (batch_size,) holding the chosen vocabulary index for
        each row.
    """
    # Replace NaN with 0 and +/-inf with the dtype's largest/smallest finite
    # value. No infinities remain afterwards, so no further inf handling is
    # needed (and -inf must NOT be mapped to 0, which would unmask tokens).
    logits = torch.nan_to_num(logits)

    if top_k == 1:
        return logits.argmax(dim=-1)

    if top_p > 0.0:
        assert top_p <= 1.0, "top-p should be in (0, 1]."

    if top_k > 0:
        top_k = min(top_k, logits.size(-1))
        logits_top, indices = torch.topk(logits, top_k, dim=-1)
        if temperature != 1.0:
            logits_top /= temperature
        modify_logits_for_top_p_filtering(logits_top, top_p)
        # The sampled position indexes into the top-k subset; map it back to
        # vocabulary ids via ``indices``.
        return indices[
            torch.arange(indices.shape[0], device=indices.device),
            torch.multinomial(torch.softmax(logits_top, dim=-1), num_samples=1).squeeze(dim=-1),
        ]

    # No top-k: filter a copy so the in-place top-p step never touches the
    # tensor the caller passed in.
    logits_top = logits / temperature if temperature != 1.0 else logits.clone()
    modify_logits_for_top_p_filtering(logits_top, top_p)
    return torch.multinomial(torch.softmax(logits_top, dim=-1), num_samples=1).squeeze(dim=-1)
|
|