lhoestq's picture
lhoestq HF staff
lower columns temperature
72a89db
raw
history blame
3.31 kB
import logging
from typing import Tuple
import torch
from outlines.samplers import MultinomialSampler
# Module-level logger (not referenced by the sampler class defined in this file).
logger = logging.getLogger(__name__)
class PenalizedMultinomialSampler(MultinomialSampler):
    """Multinomial sampler that caps how many times groups of tokens may be sampled.

    Token groups are registered with :meth:`set_max_repeats`. Each time a
    sampled token belongs to a group, that group's counter is incremented;
    once a group's counter reaches its limit, every token of the group gets a
    logit of ``-inf`` so it can no longer be sampled.

    Counters are reset whenever every entry of ``sequence_weights`` is 0,
    which this class treats as the start of a new generation.
    NOTE(review): the weights may presumably also stay at 0 mid-generation if
    every token sampled so far had probability 1 (e.g. a single allowed
    token); that would reset the counters too — confirm against the caller.

    Counters are shared by all sequences of a batch: every sampled id in the
    batch increments the same counters and ``ancestors`` is not consulted, so
    with ``n_seqs > 1`` the limit applies to the batch as a whole.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # Parallel lists, all indexed by group id:
        # token ids of each group (used to index the logits' vocab axis),
        self.penalized_tokens_group: list[torch.IntTensor] = []
        # number of sampled occurrences allowed for each group,
        self.max_repeats_per_token_group: list[int] = []
        # number of sampled occurrences seen so far for each group.
        self.repeats_per_token_group: list[int] = []
        # Reverse mapping: token id -> ids of the groups containing that token.
        self.token_id_to_tokens_groups: list[list[int]] = []

    def set_max_repeats(self, token_ids: list[int], max_repeats: int) -> None:
        """Register a group of tokens that may be sampled at most ``max_repeats`` times.

        Parameters
        ----------
        token_ids
            Ids of the tokens forming the group. Duplicates are ignored; an
            empty group is accepted and never penalizes anything.
        max_repeats
            Number of sampled occurrences (summed over all tokens of the
            group) after which the whole group is masked out. ``0`` bans the
            group from the very first step.

        Raises
        ------
        ValueError
            If any token id is negative.
        """
        # Deduplicate while keeping order: a repeated id would otherwise
        # register this group twice for that token and double-count it.
        unique_ids = list(dict.fromkeys(token_ids))
        negative_ids = [token_id for token_id in unique_ids if token_id < 0]
        if negative_ids:
            raise ValueError(f"token ids must be non-negative, got {negative_ids}")

        group_id = len(self.penalized_tokens_group)
        # Grow the reverse mapping so every id of the group has a slot.
        missing = max(unique_ids, default=-1) + 1 - len(self.token_id_to_tokens_groups)
        self.token_id_to_tokens_groups.extend([] for _ in range(missing))
        for token_id in unique_ids:
            self.token_id_to_tokens_groups[token_id].append(group_id)

        self.penalized_tokens_group.append(torch.tensor(unique_ids, dtype=torch.int32))
        self.max_repeats_per_token_group.append(max_repeats)
        self.repeats_per_token_group.append(0)

    def __call__(
        self,
        next_token_logits: torch.DoubleTensor,
        sequence_weights: torch.DoubleTensor,
        rng: torch.Generator,
    ) -> Tuple[torch.DoubleTensor, torch.DoubleTensor, torch.DoubleTensor]:
        """Mask exhausted token groups, sample, then update the repeat counters.

        Parameters
        ----------
        next_token_logits
            A tensor of shape ``(n_seqs, vocab_size,)`` holding the logits of
            the next token. It is never modified in place.
        sequence_weights
            A tensor of shape ``(n_seqs,)`` that represents the cumulative
            weight of each sequence. When it is all zeros, the repeat
            counters are reset before sampling.
        rng
            A random number generator.

        Returns
        -------
        A tuple with an array that contains the ids of the sampled tokens of
        shape ``(n_seqs, 1)``, an array that contains the ancestors of each
        sampled id of shape ``(n_seqs,)`` and an array that contains the updated
        cumulative weights of each sequence of shape ``(n_seqs,)``.
        """
        if sequence_weights.min() == sequence_weights.max() == 0:
            self.repeats_per_token_group = [0] * len(self.repeats_per_token_group)

        # Apply the penalty on every step, including right after a reset, so a
        # group registered with ``max_repeats == 0`` is banned from step one.
        exhausted_groups = [
            group
            for group, max_repeats, repeats in zip(
                self.penalized_tokens_group,
                self.max_repeats_per_token_group,
                self.repeats_per_token_group,
            )
            if repeats >= max_repeats
        ]
        if exhausted_groups:
            # Copy once and write -inf directly: no full-size temporary per
            # group, the caller's tensor stays untouched, and a logit of +inf
            # becomes -inf instead of NaN.
            next_token_logits = next_token_logits.clone()
            for group in exhausted_groups:
                next_token_logits[:, group] = -torch.inf

        next_token_ids, ancestors, weights = super().__call__(
            next_token_logits=next_token_logits,
            sequence_weights=sequence_weights,
            rng=rng,
        )

        # Flatten the sampled ids to plain Python ints (one host transfer),
        # then bump the counter of every group containing each sampled id.
        for token_id in next_token_ids.reshape(-1).tolist():
            if 0 <= token_id < len(self.token_id_to_tokens_groups):
                for group_id in self.token_id_to_tokens_groups[token_id]:
                    self.repeats_per_token_group[group_id] += 1
        return next_token_ids, ancestors, weights