import copy
import logging
from abc import ABC
from functools import lru_cache
from typing import List

import torch

from transformers_gad.recognizer import StringRecognizer, AcceptState
from transformers_gad.parser import parse_ebnf
from transformers_gad.trie import ByteTrie
from transformers_gad.utf8_utils import PartialUTF8
from .vocab_struct import LEAF, TokenTrie
from transformers_gad.mapping import get_mapping

logger = logging.getLogger(__name__)


class AbsTokenRecognizer(ABC):
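    """Base recognizer that checks token sequences against an EBNF grammar.

    Wraps a character/byte-level StringRecognizer and maps tokenizer token ids
    to their byte (or codepoint) sequences so acceptance can be tested per token.
    """
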
    def __init__(self, grammar_str, tokenizer, start_rule_name="root", unicode=False):
        parsed_grammar = parse_ebnf(grammar_str)
        grammar_encoding = parsed_grammar.grammar_encoding
        self.start_rule_id = parsed_grammar.symbol_table.get(start_rule_name)
        self.byte_encoding = unicode

        if unicode and not tokenizer.__class__.__name__.lower().startswith("gpt2"):
            raise ValueError(
                "Constrained decoding with unicode is only supported for GPT2 tokenizers; "
                "support for other models is coming soon. "
                "Alternatively, use constraints that contain only ASCII characters."
            )

        self.eos_token_id = tokenizer.eos_token_id
        self.token_trie = TokenTrie(tokenizer)
        self.tokenizer = tokenizer
        self.string_recognizer = StringRecognizer(grammar_encoding, self.start_rule_id)
        self.unicode_trie = ByteTrie.from_tokenizer(tokenizer, unicode=unicode)
        self.mapping = get_mapping(tokenizer, unicode=unicode)
        assert len(self.mapping) == len(
            self.token_trie
        ), f"{len(self.mapping)}, {len(self.token_trie)}"

    def _consume_token_id(
        self, token_id: int, accept_state: AcceptState
    ) -> AcceptState:
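        """Advance the accept state by one token id, raising if the token is invalid."""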
        if self.string_recognizer._must_stop(accept_state.stacks):
            if token_id == self.eos_token_id:
                return self.string_recognizer.get_termination_accept_state()
            else:
                raise ValueError(
                    f"All stacks are empty, so the only token accepted is EOS({self.eos_token_id}), but got {token_id}"
                )
        if token_id == self.eos_token_id:
            if self.string_recognizer._can_stop(accept_state.stacks):
                # The grammar allows termination here, so EOS is valid.
                return self.string_recognizer.get_termination_accept_state()
            else:
                raise ValueError(
                    f"At least one of the stacks should be empty when EOS is reached. However, "
                    f"the stacks are {accept_state.stacks}"
                )

        bytes_or_codepoints = self.mapping.map(token_id)
        accept_state = self.string_recognizer._consume_bytes(
            bytes_or_codepoints, accept_state
        )
        return accept_state

    def probe_token_id(self, token_id: int, accept_state: AcceptState) -> bool:
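        """Return True if consuming token_id from accept_state leaves a valid parse."""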
        stacks = accept_state.stacks
        if self.string_recognizer._must_stop(stacks):
            # All stacks are exhausted; only EOS can follow.
            return token_id == self.eos_token_id
        if token_id == self.eos_token_id:
            # EOS is only acceptable if the grammar can terminate here.
            return self.string_recognizer._can_stop(stacks)

        bytes_or_codepoints = self.mapping.map(token_id, verbose=False)
        new_acc_state = self.string_recognizer._consume_bytes(
            bytes_or_codepoints, accept_state, verbose=False
        )
        return len(new_acc_state.stacks) > 0

    def advance_token_ids(self, *args, **kwargs):
        """Process a list of tokens according to the grammar rules."""
        raise NotImplementedError

    def batch_filter_vocab(self, batch_accept_states, device) -> torch.Tensor:
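        """Stack per-sequence vocabulary masks into a single batch tensor."""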
        batch_acceptance = []
        for accept_state in batch_accept_states:
            batch_acceptance.append(self.filter_vocab(accept_state, device))
        return torch.stack(batch_acceptance)

    def filter_vocab(self, accept_state, device) -> torch.Tensor:
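        """Return a boolean mask over the vocabulary of tokens acceptable from accept_state."""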
        if not accept_state.stacks:
            # No active stacks: the grammar has terminated, so only EOS is acceptable.
            vocab_size = len(self.mapping)
            logger.debug("Empty stacks; accepting only EOS")

            accepts = [False] * vocab_size
            accepts[self.eos_token_id] = True
            return torch.tensor(accepts, dtype=torch.bool, device=device)

        acceptance = self.get_token_acceptance(accept_state, device)

        return acceptance

    def get_token_acceptance(self, accept_state, device) -> torch.Tensor:
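        """Union the per-stack acceptance vectors: a token is acceptable if any stack accepts it."""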
        acceptance_matrix = torch.cat(
            [
                self.get_token_acceptance_array_for_stack(
                    tuple(stack), accept_state.partial_utf8, device
                )
                for stack in accept_state.stacks
            ]
        )
        acceptance = acceptance_matrix.reshape(len(accept_state.stacks), -1).any(dim=0)
        return acceptance

    @lru_cache(maxsize=32768)
    def get_token_acceptance_array_for_stack(self, stack, partial_utf8, device):
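        """Compute a boolean acceptance vector over the vocabulary for a single stack.

        The stack is passed as a tuple so the result can be memoized via lru_cache.
        """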
        assert isinstance(stack, tuple)
        stack = list(stack)

        if self.byte_encoding:
            # Probe the byte trie: a token is accepted if its byte sequence can be
            # consumed from this stack, taking partial UTF-8 state into account.
            accept_f = lambda x: self.string_recognizer._probe_bytes(
                x, [stack], partial_utf8=partial_utf8
            )
            token_acceptance = self.unicode_trie.get_token_acceptance(
                accept=accept_f, accept_eos=False, eos_token_id=self.eos_token_id
            )
        else:
            accepts = [False] * len(self.mapping)
            token_acceptance = check_token_acceptance_in_trie(
                self.token_trie.trie,
                [stack],
                self.string_recognizer,
                self.eos_token_id,
                accepts,
            )
        x = torch.tensor(token_acceptance, dtype=torch.bool, device=device)
        x_eos = self.validate_and_set_eos_acceptance(x)
        return x_eos

    def validate_and_set_eos_acceptance(self, acceptance: torch.Tensor) -> torch.Tensor:
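        """Force EOS on when nothing else is acceptable, and off otherwise."""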
        if not torch.any(acceptance):
            # No other token is acceptable, so EOS becomes the only valid token.
            acceptance[self.eos_token_id] = True
        else:
            if acceptance[self.eos_token_id]:
                raise ValueError(
                    "EOS should not be accepted while other tokens are still acceptable"
                )
            acceptance[self.eos_token_id] = False
        return acceptance


class IncrementalTokenRecognizer(AbsTokenRecognizer):
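    """Token recognizer that advances accept states one token at a time.

    Tracks the length of the last processed input (last_size) so that each call
    to advance_token_ids only needs to consume the newly generated token.
    """
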
    def __init__(self, grammar_str, start_rule_name, tokenizer, unicode=False):
        super().__init__(grammar_str, tokenizer, start_rule_name, unicode)
        self.last_size = None
        self.is_incremental = True

    def advance_token_ids(self, input_ids, batch_accept_states, parse_start_index=None):
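        """Advance the per-sequence accept states with the newly generated token ids."""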
        if self.last_size is None:
            # First call: parse everything after parse_start_index (e.g. to skip
            # the prompt), or nothing if no start index is given.
            prefix_to_parse = [
                single_input_ids[parse_start_index:]
                if parse_start_index is not None
                else []
                for single_input_ids in input_ids
            ]
            batch_accept_states = [
                self._consume_token_ids(prefix, accept_state)
                for prefix, accept_state in zip(prefix_to_parse, batch_accept_states)
            ]
        elif len(input_ids[0]) == self.last_size + 1:
            # Incremental call: exactly one new token per sequence.
            batch_accept_states = [
                self._consume_token_id(single_input_ids[-1], accept_state)
                for single_input_ids, accept_state in zip(
                    input_ids, batch_accept_states
                )
            ]
        else:
            raise RuntimeError(
                "The length of the input IDs is inconsistent with the current state of "
                "the GrammarConstrainedLogitsProcessor. If you want to process "
                "another input sequence, please instantiate a new "
                "GrammarConstrainedLogitsProcessor "
                "or call the reset_parser method of GrammarAlignedOracleLogitsProcessor."
            )
        self.last_size = len(input_ids[0])

        return batch_accept_states

    def _consume_token_ids(
        self, token_ids: List[int], accept_state: AcceptState = None, as_string=True
    ):
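        """Consume a full token sequence, either as a decoded string or token by token."""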
        if accept_state is None:
            accept_state = self.string_recognizer.get_initial_accept_state()
        if as_string:
            string = self.tokenizer.decode(token_ids)
            accept_state = self.string_recognizer._consume_string(string, accept_state)
        else:
            for i, token_id in enumerate(token_ids):
                accept_state = self._consume_token_id(token_id, accept_state)
                if len(accept_state.stacks) > 0:
                    cur_token_ids = token_ids[: i + 1]
                    logger.debug(f"{cur_token_ids} is accepted")
                    decoded_string = self.tokenizer.decode(cur_token_ids)
                    logger.debug(f"The decoded string is {decoded_string}")
        return accept_state

    def reset(self):
        self.last_size = None


def check_token_acceptance_in_trie(trie, stacks, grammar, eos_token_id, accepts):
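    """Walk the token trie and mark token ids whose byte sequences the grammar accepts.

    Mutates and returns `accepts`, a boolean list indexed by token id.
    """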
    for byte, next_trie in trie.items():
        if byte == LEAF:
            # A LEAF entry stores the token id whose bytes end at this node.
            token_id = next_trie
            if token_id != eos_token_id:
                # The token is accepted iff at least one stack survived its bytes.
                accepts[token_id] = bool(stacks)
            continue

        new_stacks = []
        for stk in stacks:
            if not stk:
                continue

            next_element_offset = stk[-1]
            num_chars = grammar.grammar_encoding[next_element_offset]

            if not grammar.char_acceptance_at_element(next_element_offset).get(
                byte, False
            ):
                # This byte cannot extend the current element; drop the stack.
                continue

            next_element_offset += num_chars + 1
            new_stack = stk[:-1]
            if grammar.grammar_encoding[next_element_offset]:
                new_stack.append(next_element_offset)
            new_stacks.extend(grammar.advance_stack(tuple(new_stack)))

        if new_stacks:
            check_token_acceptance_in_trie(
                next_trie, new_stacks, grammar, eos_token_id, accepts
            )

    return accepts


if __name__ == "__main__":
    from transformers import AutoTokenizer

    with open("examples/grammars/japanese.ebnf", "r") as file:
        input_text = file.read()
    parsed_grammar = parse_ebnf(input_text)
    parsed_grammar.print()

    tokenizer = AutoTokenizer.from_pretrained("gpt2")

    tokenRecognizer = IncrementalTokenRecognizer(
        grammar_str=input_text, start_rule_name="root", tokenizer=tokenizer
    )

    japanese = "ドリーム"  # katakana for "dream"
    token_ids = tokenizer.encode(japanese)

    accept_state = tokenRecognizer._consume_token_ids(token_ids, as_string=False)

    if accept_state.stacks:
        print("The Japanese input is accepted")
    else:
        print("The Japanese input is not accepted")

    korean = "안녕하세요"  # "hello"; expected to be rejected by the Japanese grammar
    token_ids = tokenizer.encode(korean)

    try:
        accept_state = tokenRecognizer._consume_token_ids(token_ids, as_string=False)
        if accept_state.stacks:
            print("The Korean input is accepted")
        else:
            print("The Korean input is not accepted")
    except ValueError:
        print("The Korean input is not accepted")