import logging
from typing import Tuple

import torch
from outlines.samplers import MultinomialSampler

logger = logging.getLogger(__name__)


class PenalizedMultinomialSampler(MultinomialSampler):
    """Multinomial sampler that masks out groups of tokens once they have
    been sampled a maximum number of times within a single generation."""

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # One entry per registered group: the group's token ids, the maximum
        # number of times the group may be sampled, and the running count.
        self.penalized_tokens_group: list[torch.IntTensor] = []
        self.max_repeats_per_token_group: list[int] = []
        self.repeats_per_token_group: list[int] = []
        # Maps a token id to the indices of the groups it belongs to.
        self.token_id_to_tokens_groups: list[list[int]] = []

    def set_max_repeats(self, token_ids: list[int], max_repeats: int) -> None:
        """Register a group of token ids that may be sampled at most
        ``max_repeats`` times per generation."""
        max_token_ids = max(token_ids)
        # Grow the token id -> groups lookup table so that every token id in
        # this group has a slot.
        if max_token_ids >= len(self.token_id_to_tokens_groups):
            self.token_id_to_tokens_groups += [
                [] for _ in range(len(self.token_id_to_tokens_groups), max_token_ids + 1)
            ]
        for token_id in token_ids:
            self.token_id_to_tokens_groups[token_id].append(len(self.penalized_tokens_group))
        self.penalized_tokens_group.append(torch.tensor(token_ids, dtype=torch.int32))
        self.max_repeats_per_token_group.append(max_repeats)
        self.repeats_per_token_group.append(0)

    def __call__(
        self,
        next_token_logits: torch.DoubleTensor,
        sequence_weights: torch.DoubleTensor,
        rng: torch.Generator,
    ) -> Tuple[torch.DoubleTensor, torch.DoubleTensor, torch.DoubleTensor]:
        """Call the multinomial sampler.

        Parameters
        ----------
        next_token_logits
            A tensor of shape ``(n_seqs, vocab_size,)`` that contains the
            logits of the next-token distribution over the vocabulary.
        sequence_weights
            A tensor of shape ``(n_seqs,)`` that represents the cumulative
            weight of each sequence.
        rng
            A random number generator.

        Returns
        -------
        A tuple with an array that contains the ids of the sampled tokens of
        shape ``(n_seqs, 1)``, an array that contains the ancestors of each
        sampled id of shape ``(n_seqs,)`` and an array that contains the
        updated cumulative weights of each sequence of shape ``(n_seqs,)``.

        """
        # All-zero cumulative weights mark the start of a new generation:
        # reset the per-group repeat counters.
        if sequence_weights.min() == sequence_weights.max() == 0:
            self.repeats_per_token_group = [0] * len(self.repeats_per_token_group)
        else:
            # Mask out every group that has reached its cap by pushing the
            # logits of its tokens to -inf.
            for penalized_tokens_group, max_repeats_per_token_group, repeats_per_token_group in zip(
                self.penalized_tokens_group,
                self.max_repeats_per_token_group,
                self.repeats_per_token_group,
            ):
                if repeats_per_token_group >= max_repeats_per_token_group:
                    penalty = torch.zeros_like(next_token_logits)
                    penalty[:, penalized_tokens_group] = -torch.inf
                    next_token_logits = next_token_logits + penalty
        next_token_ids, ancestors, weights = super().__call__(
            next_token_logits=next_token_logits,
            sequence_weights=sequence_weights,
            rng=rng,
        )
        # Update the counters of every group the sampled tokens belong to.
        for next_token_id in next_token_ids.cpu():
            if next_token_id < len(self.token_id_to_tokens_groups):
                for token_group in self.token_id_to_tokens_groups[next_token_id]:
                    self.repeats_per_token_group[token_group] += 1
        return next_token_ids, ancestors, weights
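

# ---------------------------------------------------------------------------
# Usage sketch (added for illustration; not part of the original class).
# Rather than loading a full model, this drives the sampler directly with
# synthetic logits to show the penalty kicking in. It assumes the
# ``MultinomialSampler.__call__`` contract documented above: it returns the
# sampled token ids, the ancestors, and the updated cumulative weights.
# Token 5 is capped at 2 occurrences; once the cap is reached, its logit is
# forced to -inf and it can no longer be sampled.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    sampler = PenalizedMultinomialSampler()
    sampler.set_max_repeats([5], max_repeats=2)

    rng = torch.Generator()
    rng.manual_seed(0)

    vocab_size = 10
    # Heavily favor token 5 so it would be sampled at every step if the
    # penalty never applied.
    logits = torch.full((1, vocab_size), -10.0, dtype=torch.float64)
    logits[0, 5] = 10.0

    # All-zero weights on the first call trigger the counter reset.
    weights = torch.zeros(1, dtype=torch.float64)
    for step in range(4):
        token_ids, _ancestors, weights = sampler(logits, weights, rng)
        print(f"step {step}: sampled token {token_ids.item()}")
    # Expected: token 5 on the first two steps, then other tokens once the
    # cap of 2 repeats is reached.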