# Source: Anole / chameleon/inference/logits_processor.py (revision 7362797, "update", by xuefengli; 11.9 kB)
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Chameleon License found in the
# LICENSE file in the root directory of this source tree.
import math
import torch
from transformers import LogitsProcessor
class TopPProbabilityProcessor(LogitsProcessor):
    """Top-p (nucleus) filtering that acts on probabilities instead of logits.

    Adapted from ``TopPLogitsWarper`` with these differences:
      * filtered entries are set to 0 instead of -inf
      * no softmax is applied to the input
      * the surviving entries are renormalized by their sum (L1)
    """

    def __init__(self, top_p: float, min_tokens_to_keep: int = 1):
        """Validate and store the filtering parameters.

        Args:
            top_p: Cumulative-probability threshold; values below 0 or above 1
                are rejected.
            min_tokens_to_keep: Number of highest-probability entries per row
                that are never filtered; must be an ``int`` >= 1.

        Raises:
            ValueError: If either argument fails the checks above.
        """
        top_p = float(top_p)
        if top_p > 1.0 or top_p < 0:
            raise ValueError(f"`top_p` has to be a float > 0 and < 1, but is {top_p}")

        keep_count_ok = isinstance(min_tokens_to_keep, int) and not (
            min_tokens_to_keep < 1
        )
        if not keep_count_ok:
            raise ValueError(
                f"`min_tokens_to_keep` has to be a positive integer, but is {min_tokens_to_keep}"
            )

        self.min_tokens_to_keep = min_tokens_to_keep
        self.top_p = top_p

    def __call__(
        self, input_ids: torch.LongTensor, probs: torch.FloatTensor
    ) -> torch.FloatTensor:
        """Zero out the low-probability tail of each row and renormalize.

        Args:
            input_ids: Token ids generated so far, [batch, seq-len] (unused).
            probs: Probabilities over the vocabulary, [batch, vocab].

        Returns:
            A new tensor shaped like ``probs`` with the filtered entries set to
            0 and each row divided by its remaining sum.
        """
        # Ascending order: the tail to discard sits at the front of each row.
        ascending_probs, ascending_indices = torch.sort(probs, descending=False)
        running_mass = torch.cumsum(ascending_probs, dim=-1)

        # Drop entries whose cumulative mass still fits inside the discarded
        # (1 - top_p) tail.
        drop_sorted = running_mass <= (1 - self.top_p)
        # The last min_tokens_to_keep entries (the most probable) always survive.
        drop_sorted[..., -self.min_tokens_to_keep :] = 0

        # Map the decision from sorted order back to vocabulary order.
        drop = drop_sorted.scatter(1, ascending_indices, drop_sorted)

        kept = probs.masked_fill(drop, 0.0)
        return kept / kept.sum(dim=-1, keepdim=True)
class DisallowTokensInIndexRangeLogitsProcessor(LogitsProcessor):
    """Set the logits of given tokens to -inf while the sequence length is in a range.

    The ban is active when ``start_index <= input_ids.shape[1] < end_index``.
    """

    def __init__(
        self, token_ids: list[int], start_index: int, end_index: int | None = None
    ):
        """Store the banned tokens and the half-open index range.

        Args:
            token_ids: Vocabulary ids to disallow; may be empty.
            start_index: First sequence length at which the ban applies.
            end_index: Sequence length at which the ban stops applying
                (exclusive). ``None`` means there is no upper bound.
        """
        # Force an integer dtype: torch.tensor([]) infers float32, and a float
        # tensor is rejected as an index in __call__.
        self.token_ids = torch.tensor(token_ids, dtype=torch.long)
        self.start_index = start_index
        self.end_index = end_index if end_index is not None else math.inf

    def __call__(
        self, input_ids: torch.LongTensor, logits: torch.FloatTensor
    ) -> torch.FloatTensor:
        """Mask the banned tokens in place when the current index is in range.

        Args:
            input_ids: Token ids generated so far; only ``shape[1]`` is read.
            logits: Scores indexed as [batch, vocab]; modified in place.

        Returns:
            The same ``logits`` tensor.
        """
        current_index = input_ids.shape[1]
        if self.start_index <= current_index < self.end_index:
            logits[:, self.token_ids] = -math.inf
        return logits
class DisallowTokensLogitsProcessor(DisallowTokensInIndexRangeLogitsProcessor):
    """Disallow ``token_ids`` from index 0 onward, with no upper bound."""

    def __init__(self, token_ids: list[int]):
        super().__init__(token_ids, start_index=0)
class DisallowTokensAtIndexLogitsProcessor(DisallowTokensInIndexRangeLogitsProcessor):
    """Disallow ``token_ids`` only for the range ``[index, index + 1)``."""

    def __init__(self, token_ids: list[int], index: int):
        one_past = index + 1
        super().__init__(token_ids, start_index=index, end_index=one_past)
class DisallowTokensAfterIndexLogitsProcessor(DisallowTokensInIndexRangeLogitsProcessor):
    """Disallow ``token_ids`` strictly after ``index`` (from ``index + 1`` on)."""

    def __init__(self, token_ids: list[int], index: int):
        first_banned = index + 1
        super().__init__(token_ids, start_index=first_banned)
class DisallowTokensAtOrAfterIndexLogitsProcessor(DisallowTokensInIndexRangeLogitsProcessor):
    """Disallow ``token_ids`` from ``index`` onward, with no upper bound."""

    def __init__(self, token_ids: list[int], index: int):
        super().__init__(token_ids, start_index=index)
class DisallowTokensInBatchIndexRangeLogitsProcessor(LogitsProcessor):
    """Set given tokens to -inf, with a separate index range for each batch row.

    Row ``b`` is masked while
    ``start_indices[b] <= input_ids.shape[1] < end_indices[b]``.
    """

    def __init__(
        self,
        token_ids: list[int],
        start_indices: list[int],
        end_indices: list[int] | None = None,
    ):
        """Store the banned tokens and the per-row half-open ranges.

        Args:
            token_ids: Vocabulary ids to disallow; may be empty.
            start_indices: One start index per batch row.
            end_indices: One exclusive end index per batch row. ``None`` means
                no row has an upper bound.
        """
        # Force an integer dtype: torch.tensor([]) infers float32, and a float
        # tensor is rejected as an index in __call__.
        self.token_ids = torch.tensor(token_ids, dtype=torch.long)
        self.start_indices = torch.tensor(start_indices)
        self.end_indices = (
            torch.tensor(end_indices)
            if end_indices is not None
            # No upper bound: +inf needs a floating dtype.
            else torch.full_like(self.start_indices, math.inf, dtype=torch.float)
        )

    def __call__(
        self, input_ids: torch.LongTensor, logits: torch.FloatTensor
    ) -> torch.FloatTensor:
        """Mask the banned tokens in place for every row whose range is active.

        Args:
            input_ids: Token ids generated so far, [batch, seq_len]; only
                ``shape[1]`` is read.
            logits: Scores, [batch, vocab]; modified in place.

        Returns:
            The same ``logits`` tensor.
        """
        current_index = input_ids.shape[1]
        mask = (self.start_indices <= current_index) & (
            current_index < self.end_indices
        )
        # Index with explicit row numbers ([rows, 1] broadcast against
        # [tokens]) instead of the boolean mask: per the original note,
        # `logits[mask, self.token_ids] = -math.inf` fails when the mask is
        # all False.
        active_rows = torch.where(mask)[0].unsqueeze(1)
        logits[active_rows, self.token_ids] = -math.inf
        return logits
class DisallowTokensAtBatchIndexLogitsProcessor(DisallowTokensInBatchIndexRangeLogitsProcessor):
    """Disallow ``token_ids`` for row ``b`` only in ``[batch_index[b], batch_index[b] + 1)``."""

    def __init__(self, token_ids: list[int], batch_index: list[int]):
        # Each row's range covers exactly one index: [start, start + 1).
        end_indices = [start + 1 for start in batch_index]
        super().__init__(token_ids, batch_index, end_indices)
class AllowOnlyTokensInIndexRangeLogitsProcessor(LogitsProcessor):
    """Keep only the given tokens while the sequence length is in a range.

    While ``start_index <= input_ids.shape[1] < end_index``, every logit
    outside ``token_ids`` is set to -inf.
    """

    def __init__(
        self, token_ids: list[int], start_index: int, end_index: int | None = None
    ):
        """Store the allowed tokens and the half-open index range.

        Args:
            token_ids: Vocabulary ids that stay allowed.
            start_index: First sequence length at which the restriction applies.
            end_index: Exclusive upper bound; ``None`` means unbounded.
        """
        self.token_ids = torch.tensor(token_ids)
        self.start_index = start_index
        if end_index is None:
            self.end_index = math.inf
        else:
            self.end_index = end_index

    def __call__(
        self, input_ids: torch.LongTensor, logits: torch.FloatTensor
    ) -> torch.FloatTensor:
        """Apply the restriction in place when the current index is in range.

        Args:
            input_ids: Token ids generated so far; only ``shape[1]`` is read.
            logits: Scores indexed as [batch, vocab]; modified in place.

        Returns:
            The same ``logits`` tensor.
        """
        position = input_ids.shape[1]
        if not (self.start_index <= position < self.end_index):
            return logits

        # Start from all -inf and copy back only the allowed columns.
        only_allowed = torch.full_like(logits, -math.inf)
        only_allowed[:, self.token_ids] = logits[:, self.token_ids]
        # Write into the caller's tensor instead of rebinding the name.
        logits[:] = only_allowed
        return logits
class AllowOnlyTokensLogitsProcessor(AllowOnlyTokensInIndexRangeLogitsProcessor):
    """Allow only ``token_ids`` from index 0 onward, with no upper bound."""

    def __init__(self, token_ids: list[int]):
        super().__init__(token_ids, start_index=0)
class AllowOnlyTokensAtIndexLogitsProcessor(AllowOnlyTokensInIndexRangeLogitsProcessor):
    """Allow only ``token_ids`` for the range ``[index, index + 1)``."""

    def __init__(self, token_ids: list[int], index: int):
        one_past = index + 1
        super().__init__(token_ids, start_index=index, end_index=one_past)
class AllowOnlyTokensAfterIndexLogitsProcessor(AllowOnlyTokensInIndexRangeLogitsProcessor):
    """Allow only ``token_ids`` strictly after ``index`` (from ``index + 1`` on)."""

    def __init__(self, token_ids: list[int], index: int):
        first_restricted = index + 1
        super().__init__(token_ids, start_index=first_restricted)
class AllowOnlyTokensAtOrAfterIndexLogitsProcessor(AllowOnlyTokensInIndexRangeLogitsProcessor):
    """Allow only ``token_ids`` from ``index`` onward, with no upper bound."""

    def __init__(self, token_ids: list[int], index: int):
        super().__init__(token_ids, start_index=index)
class AllowOnlyTokensInBatchIndexRangeLogitsProcessor(LogitsProcessor):
    """Keep only given tokens, with a separate index range for each batch row.

    For every row ``b`` with
    ``start_indices[b] <= input_ids.shape[1] < end_indices[b]``, all logits
    outside ``token_ids`` are set to -inf. Rows outside their range are left
    untouched.
    """

    def __init__(
        self,
        token_ids: list[int],
        start_indices: list[int],
        end_indices: list[int] | None = None,
    ):
        """Store the allowed tokens and the per-row half-open ranges.

        Args:
            token_ids: Vocabulary ids that stay allowed.
            start_indices: One start index per batch row.
            end_indices: One exclusive end index per batch row. ``None`` means
                no row has an upper bound.
        """
        # Force an integer dtype: torch.tensor([]) infers float32, and a float
        # tensor is rejected as an index in __call__.
        self.token_ids = torch.tensor(token_ids, dtype=torch.long)
        self.start_indices = torch.tensor(start_indices)
        self.end_indices = (
            torch.tensor(end_indices)
            if end_indices is not None
            # No upper bound: +inf needs a floating dtype.
            else torch.full_like(self.start_indices, math.inf, dtype=torch.float)
        )

    def __call__(
        self, input_ids: torch.LongTensor, logits: torch.FloatTensor
    ) -> torch.FloatTensor:
        """Apply the restriction in place to every row whose range is active.

        Args:
            input_ids: Token ids generated so far, [batch, seq_len]; only
                ``shape[1]`` is read.
            logits: Scores, [batch, vocab]; modified in place.

        Returns:
            The same ``logits`` tensor.
        """
        current_index = input_ids.shape[1]
        active_rows = (self.start_indices <= current_index) & (
            current_index < self.end_indices
        )

        # allowed[v] is True for vocabulary entries that must stay untouched.
        allowed = torch.zeros(
            logits.shape[-1], dtype=torch.bool, device=logits.device
        )
        allowed[self.token_ids] = True

        # Mask only (active row, disallowed token) pairs. Merging a -inf
        # template back with the original logits via torch.where() would hand
        # back the original logits everywhere and restrict nothing.
        # [batch, 1] & [1, vocab] broadcasts to [batch, vocab].
        to_mask = active_rows.to(logits.device).unsqueeze(1) & ~allowed.unsqueeze(0)
        logits.masked_fill_(to_mask, -math.inf)
        return logits
class AllowOnlyTokensAtRelativeOffsetLogitsProcessor(LogitsProcessor):
    """Restrict the next token when a trigger token sits at a fixed offset from the end.

    For each batch row whose token at position ``-offset`` equals
    ``trigger_token_id``, every logit outside ``subsequent_token_ids`` is set
    to -inf. Other rows are left unchanged.
    """

    def __init__(
        self, trigger_token_id: int, subsequent_token_ids: list[int], offset: int
    ):
        """Store the trigger token, the allowed follow-up tokens and the offset.

        Args:
            trigger_token_id: Token id that activates the restriction.
            subsequent_token_ids: Vocabulary ids allowed once triggered.
            offset: Distance from the end of the sequence at which the trigger
                is looked up (``input_ids[:, -offset]``).
        """
        self.offset = offset
        self.trigger_token_id = trigger_token_id
        self.subsequent_token_ids = torch.tensor(subsequent_token_ids)

    def __call__(
        self, input_ids: torch.LongTensor, logits: torch.FloatTensor
    ) -> torch.FloatTensor:
        """Mask disallowed tokens in place for the triggered rows.

        Args:
            input_ids: Token ids generated so far, [batch, seq_len].
            logits: Scores, [batch, vocab]; modified in place.

        Returns:
            The same ``logits`` tensor.
        """
        seq_len = input_ids.shape[1]
        if seq_len < self.offset:
            # Sequence too short to contain the lookup position.
            return logits

        # [batch] -> [batch, 1] so it broadcasts over the vocabulary axis.
        row_triggered = input_ids[:, -self.offset] == self.trigger_token_id
        row_triggered = row_triggered.unsqueeze(-1)

        # Everything is blocked by default; clear the allowed columns.
        blocked = torch.ones_like(logits, dtype=bool)
        blocked[:, self.subsequent_token_ids] = False

        return logits.masked_fill_(blocked & row_triggered, -math.inf)
class AllowOnlyTokensInRelativeWindowLogitsProcessor(LogitsProcessor):
    """Restrict the next token when a trigger token appears in a trailing window.

    For each batch row in which ``trigger_token_id`` occurs among the last
    ``width`` tokens, every logit outside ``allowed_token_ids`` is set to
    -inf. Other rows are left unchanged.
    """

    def __init__(self, trigger_token_id: int, allowed_token_ids: list[int], width: int):
        """Store the trigger token, the allowed tokens and the window width.

        Args:
            trigger_token_id: Token id that activates the restriction.
            allowed_token_ids: Vocabulary ids allowed once triggered.
            width: Number of trailing tokens searched for the trigger.
        """
        self.trigger_token_id = trigger_token_id
        # Stored with shape [1, num_allowed_tokens].
        self.allowed_token_ids = torch.tensor(allowed_token_ids).unsqueeze(0)
        self.width = width

    def __call__(
        self, input_ids: torch.LongTensor, logits: torch.FloatTensor
    ) -> torch.FloatTensor:
        """Mask disallowed tokens in place for the triggered rows.

        Args:
            input_ids: Token ids generated so far, [batch, seq_len].
            logits: Scores, [batch, vocab]; modified in place.

        Returns:
            The same ``logits`` tensor.
        """
        # Clamp the window to the tokens that actually exist.
        # NOTE: a resulting width of 0 makes `-width:` select the whole sequence.
        window = min(self.width, input_ids.shape[1])
        recent_tokens = input_ids[:, -window:]

        # [batch] -> [batch, 1] so it broadcasts over the vocabulary axis.
        row_triggered = (recent_tokens == self.trigger_token_id).any(dim=1)
        row_triggered = row_triggered.unsqueeze(-1)

        # Everything is blocked by default; clear the allowed columns.
        blocked = torch.ones_like(logits, dtype=bool)
        blocked[:, self.allowed_token_ids] = False

        return logits.masked_fill_(blocked & row_triggered, -math.inf)
class CFGLogitsProcessor(LogitsProcessor):
    """Classifier-free guidance using a separate unconditional forward pass.

    The processor keeps its own unconditional token sequence. On every call it
    appends the latest token from ``input_ids``, runs ``model`` on that
    sequence, and mixes the resulting logits with the conditioned ones.
    """

    def __init__(
        self,
        guidance_scale: float,
        unconditional_ids: torch.LongTensor,
        model,
    ):
        """Store the guidance scale, the starting unconditional ids and the model.

        Args:
            guidance_scale: Weight applied to (conditioned - unconditioned).
            unconditional_ids: Initial unconditional sequence; it is
                concatenated along dim 1 on each call.
            model: Callable invoked as ``model(unconditional_ids)``.
        """
        self.model = model
        self.unconditional_ids = unconditional_ids
        self.guidance_scale = guidance_scale

    def __call__(
        self, input_ids: torch.LongTensor, logits: torch.FloatTensor
    ) -> torch.FloatTensor:
        """Return the guided logits for the current step.

        Args:
            input_ids: Conditioned token ids so far; only the last column is read.
            logits: Conditioned logits for the current step.

        Returns:
            ``guidance_scale * (conditioned - unconditioned) + unconditioned``.
        """
        # Grow the stored unconditional sequence by the newest token.
        latest_token = input_ids[:, -1:]
        self.unconditional_ids = torch.cat(
            (self.unconditional_ids, latest_token), dim=1
        )

        # The output is indexed as a 3-D tensor and the last position kept;
        # presumably [batch, seq, vocab] — confirm against the model.
        uncond_logits = self.model(self.unconditional_ids)[:, -1, :]

        guided_delta = self.guidance_scale * (logits - uncond_logits)
        return guided_delta + uncond_logits
class InBatchCFGLogitsProcessor(LogitsProcessor):
    """Classifier-free guidance where both variants share one batch.

    The first half of the batch holds the conditioned rows and the second half
    the unconditioned rows.
    """

    def __init__(self, guidance_scale: float):
        """Store the weight applied to (conditioned - unconditioned)."""
        self.guidance_scale = guidance_scale

    def __call__(
        self, input_ids: torch.LongTensor, logits: torch.FloatTensor
    ) -> torch.FloatTensor:
        """Mix the two halves and return the result repeated for both halves.

        Args:
            input_ids: Token ids, [2*batch, seq-len] (unused).
            logits: Scores, [2*batch, vocab].

        Returns:
            The guided logits, repeated twice along dim 0.
        """
        cond, uncond = logits.chunk(2, dim=0)
        guidance_delta = cond - uncond
        guided = uncond + self.guidance_scale * guidance_delta
        # Both halves of the batch receive the same guided scores.
        return guided.repeat(2, 1)
class InBatchInstructCFGLogitsProcessor(LogitsProcessor):
    """Two-scale classifier-free guidance with all three variants in one batch.

    See https://arxiv.org/abs/2211.09800

    The batch is split into thirds along dim 0, in this order: fully
    conditioned, image-conditioned, unconditioned.
    """

    def __init__(self, guidance_scale_text: float, guidance_scale_image: float):
        """Store the text and image guidance weights."""
        self.guidance_scale_image = guidance_scale_image
        self.guidance_scale_text = guidance_scale_text

    def __call__(
        self, input_ids: torch.LongTensor, logits: torch.FloatTensor
    ) -> torch.FloatTensor:
        """Mix the three thirds and return the result repeated for each third.

        Args:
            input_ids: Token ids, [3*batch, seq-len] (unused).
            logits: Scores, [3*batch, vocab].

        Returns:
            The guided logits, repeated three times along dim 0.
        """
        full_cond, image_cond, uncond = torch.chunk(logits, 3, dim=0)

        # Image guidance moves from unconditioned towards image-conditioned.
        image_term = self.guidance_scale_image * (image_cond - uncond)
        # Text guidance moves from image-conditioned towards fully conditioned.
        text_term = self.guidance_scale_text * (full_cond - image_cond)

        guided = uncond + image_term + text_term
        # Every third of the batch receives the same guided scores.
        return guided.repeat(3, 1)