|
'''
The GrammarConstrainedLogitsProcessor in this file comes from this PR to the
Transformers library:

https://github.com/huggingface/transformers/pull/27557

Author: Saibo-creator
Author GitHub: https://github.com/Saibo-creator

All credits go to the author.

It implements a ``LogitsProcessor`` that masks next-token logits so that
generation can only produce sequences accepted by a given grammar constraint.
'''
|
import math

import torch
from transformers.generation.logits_process import LogitsProcessor
from transformers.utils import add_start_docstrings
|
LOGITS_PROCESSOR_INPUTS_DOCSTRING = r""" |
|
Args: |
|
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): |
|
Indices of input sequence tokens in the vocabulary. [What are input IDs?](../glossary#input-ids) |
|
scores (`torch.FloatTensor` of shape `(batch_size, config.vocab_size)`): |
|
Prediction scores of a language modeling head. These can be logits for each vocabulary when not using beam |
|
search or log softmax for each vocabulary token when using beam search |
|
|
|
Return: |
|
`torch.FloatTensor` of shape `(batch_size, config.vocab_size)`: The processed prediction scores. |
|
|
|
""" |
|
class GrammarConstrainedLogitsProcessor(LogitsProcessor):
    def __init__(self, grammar_constraint):
        # Sequence length seen on the previous call; ``None`` until the first
        # call, which lets us tell the initial prefix parse apart from the
        # incremental one-token-per-step updates.
        self.last_size = None
        self.grammar_constraint = grammar_constraint
        # One parser-stack state per sequence in the batch; lazily initialized
        # on the first call to ``process_logits``.
        self.batch_stacks = None

    def filter_logits(self, logits, device):
        # Boolean mask of shape (batch_size, vocab_size): True where the grammar
        # accepts the token given the current parser stacks.
        acceptance = self.grammar_constraint.batch_filter_vocab(self.batch_stacks, device)
        # Mask disallowed tokens in place so they can never be sampled.
        logits[~acceptance] = -math.inf
|
    def process_logits(self, input_ids, scores, parse_start_index=None):
        """
        :param input_ids: token ids of the sequences generated so far
        :param scores: logits for the next token, of shape ``(batch_size, vocab_size)``
        :param parse_start_index: index from which ``input_ids`` should be parsed by
            the grammar. Defaults to ``None``, which means generation starts from
            scratch and the prompt is not parsed; set it to 0 to parse all of
            ``input_ids``.
        :return: the filtered ``scores``
        """
        # Lazily create one parser-stack state per sequence in the batch.
        if self.batch_stacks is None:
            self.batch_stacks = [self.grammar_constraint.init_stacks() for _ in range(len(input_ids))]

        if self.last_size is None:
            # First call: parse the requested prefix of the prompt (empty when
            # ``parse_start_index`` is ``None``) so the parser stacks reflect it
            # before constrained decoding starts.
            prefix_to_parse = [
                single_input_ids[parse_start_index:] if parse_start_index is not None else []
                for single_input_ids in input_ids
            ]
            self.batch_stacks = [
                self.grammar_constraint.accept_token_ids(prefix, stack)
                for prefix, stack in zip(prefix_to_parse, self.batch_stacks)
            ]
        elif len(input_ids[0]) == self.last_size + 1:
            # Subsequent calls: exactly one token has been generated since the
            # previous call, so advance each parser stack by that token.
            self.batch_stacks = [
                self.grammar_constraint.accept_token_id(single_input_ids[-1], stack)
                for single_input_ids, stack in zip(input_ids, self.batch_stacks)
            ]
        else:
            # The sequence length changed by something other than one token, so
            # the cached parser state no longer matches ``input_ids``.
            raise RuntimeError(
                "The length of input_ids is inconsistent with the current state of "
                "the GrammarConstrainedLogitsProcessor. If you want to process "
                "another input sequence, please instantiate a new "
                "GrammarConstrainedLogitsProcessor."
            )

        # Mask the scores in place according to the grammar, then remember the
        # current sequence length for the next incremental call.
        self.filter_logits(scores, scores.device)
        self.last_size = len(input_ids[0])
        return scores
|
    @add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        return self.process_logits(input_ids, scores)
|
|
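# ---------------------------------------------------------------------------
# Minimal runnable demo (an illustrative addition, NOT part of the original
# PR). It wires the processor into ``model.generate`` through the standard
# ``logits_processor`` argument, using a toy constraint that only allows digit
# tokens. ``DigitsOnlyConstraint`` and the choice of the "gpt2" checkpoint are
# assumptions made for this sketch; a real use case would pass an incremental
# grammar parser implementing the ``GrammarConstraintLike`` interface above.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from transformers import AutoModelForCausalLM, AutoTokenizer
    from transformers.generation.logits_process import LogitsProcessorList

    class DigitsOnlyConstraint:
        """Toy grammar: every generated token must decode to digits only."""

        def __init__(self, tokenizer):
            vocab_size = len(tokenizer)
            mask = torch.zeros(vocab_size, dtype=torch.bool)
            for token_id in range(vocab_size):
                if tokenizer.decode([token_id]).strip().isdigit():
                    mask[token_id] = True
            self.mask = mask

        def init_stacks(self):
            # This toy grammar is stateless, so there is no parser stack to track.
            return None

        def accept_token_id(self, token_id, stacks):
            return stacks

        def accept_token_ids(self, token_ids, stacks):
            return stacks

        def batch_filter_vocab(self, batch_stacks, device):
            # Same static mask for every sequence in the batch.
            return self.mask.to(device).unsqueeze(0).expand(len(batch_stacks), -1)

    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    model = AutoModelForCausalLM.from_pretrained("gpt2")

    processor = GrammarConstrainedLogitsProcessor(DigitsOnlyConstraint(tokenizer))
    inputs = tokenizer("The answer is ", return_tensors="pt")
    output_ids = model.generate(
        **inputs,
        max_new_tokens=8,
        do_sample=False,
        logits_processor=LogitsProcessorList([processor]),
    )
    print(tokenizer.decode(output_ids[0], skip_special_tokens=True))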