import copy
import logging
from functools import lru_cache
from typing import List, Tuple, Dict

from transformers_gad.parser import (
    END_OF_RULE_MARKER,
    END_OF_ALTERNATE_MARKER,
    parse_ebnf,
    REF_RULE_MARKER,
)
from transformers_gad.utf8_utils import PartialUTF8, decode_utf8
from transformers_gad.utils import intervals_intersect


class AcceptState:
    def __init__(self, stacks, partial_utf8):
        self.stacks = stacks
        self.partial_utf8 = partial_utf8

    @staticmethod
    def empty_state():
        return AcceptState([], PartialUTF8())


class StringRecognizer:
    def __init__(
        self,
        grammar_encoding: List[int],
        start_rule_id: int = None,
        rule_offsets: List[int] = None,
        stacks: List[List[int]] = None,
    ):
        # Strictly speaking, we don't need to copy grammar_encoding because we
        # don't modify it, but we do it anyway to be safe. If the grammar is
        # very large, we can consider not copying it.
        self.grammar_encoding = grammar_encoding
        if rule_offsets is not None:
            self.rule_offsets = rule_offsets
        else:
            if start_rule_id is None:
                raise ValueError("start_rule_id cannot be None if rule_offsets is None")
            self.rule_offsets = self.init_rules(start_rule_id)
        # Each stack is a list of indices into grammar_encoding; each index
        # points to an element of a rule's alternate that still has to be matched.
        if stacks is not None:
            self.stacks = stacks
        else:
            if start_rule_id is None:
                raise ValueError("start_rule_id cannot be None if stacks is None")
            self.stacks: List[List[int]] = self.init_stack(start_rule_id)
        self.start_rule_id = start_rule_id

    def init_rules(self, start_rule_id: int) -> List[int]:
        _rule_offset = 0
        rule_offsets = []
        # Build `rule_offsets` as a mapping from rule IDs to their positions
        # in `grammar_encoding`.
        while self.grammar_encoding[_rule_offset] != 0xFFFF:
            rule_id = self.grammar_encoding[_rule_offset]
            # store the offset idx
            if len(rule_offsets) <= rule_id:
                rule_offsets.extend([-1] * (rule_id - len(rule_offsets) + 1))
            rule_offsets[rule_id] = _rule_offset

            # Skip rule ID
            # _rule_offset += 1
            simple_rhs_offset = _rule_offset + 1

            # Skip rule alternates
            while self.grammar_encoding[simple_rhs_offset] != END_OF_RULE_MARKER:
                simple_rhs_offset = (
                    simple_rhs_offset + 1 + self.grammar_encoding[simple_rhs_offset]
                )

            # Skip 0 denoting end of rule
            # _rule_offset += 1
            _rule_offset = simple_rhs_offset + 1

        retrieved_start_rule_id = self.grammar_encoding[rule_offsets[start_rule_id]]
        assert retrieved_start_rule_id == start_rule_id

        return rule_offsets

    def init_stack(self, start_rule_id: int) -> List[List[int]]:
        stacks = []
        # Loop over alternates of start rule to build initial stacks
        sub_rhs_offset = self.rule_offsets[start_rule_id] + 1
        while self.grammar_encoding[sub_rhs_offset]:
            stack: List[int] = []
            # If alternate is nonempty, add to stack
            element_offset = sub_rhs_offset + 1
            if self.grammar_encoding[element_offset] != END_OF_ALTERNATE_MARKER:
                stack.append(element_offset)
            stacks.extend(self.advance_stack(tuple(stack)))
            sub_rhs_offset += 1 + self.grammar_encoding[sub_rhs_offset]
        return stacks

    def get_initial_accept_state(self) -> AcceptState:
        return AcceptState(self.init_stack(self.start_rule_id), PartialUTF8())

    def get_termination_accept_state(self) -> AcceptState:
        return AcceptState([], PartialUTF8())
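    # Illustrative sketch of the serialized layout, inferred from how
    # init_rules/init_stack walk the encoding (an assumption, not taken from
    # the parser source). A rule appears to serialize as:
    #   [rule_id,
    #    alt_size, element..., END_OF_ALTERNATE_MARKER,   # one block per alternate
    #    ...,
    #    END_OF_RULE_MARKER]
    # with the whole grammar terminated by the 0xFFFF sentinel. A
    # character-range element is [num_bounds, lo_1, hi_1, lo_2, hi_2, ...]
    # (so a terminal matching [a-z] would appear as [2, ord("a"), ord("z")]),
    # and a rule reference occupies two slots: [REF_RULE_MARKER, rule_id].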
    @lru_cache(maxsize=32768)
    def advance_stack(self, stack: Tuple[int]) -> List[List[int]]:
        stack = list(stack)
        if len(stack) == 0:
            return [stack]

        # We look at the last element of the stack, which is the element we
        # are currently processing.
        cur_element_offset = stack[-1]

        # If the element is a terminal, we don't need to advance the stack.
        if self.grammar_encoding[cur_element_offset] != REF_RULE_MARKER:
            return [stack]
        # The remaining case is that the element is a non-terminal, i.e. a
        # reference to another rule.
        else:
            ref_rule_id = self.grammar_encoding[cur_element_offset + 1]
            # Find the offset of the referenced rule.
            ref_subrule_offset = self.rule_offsets[ref_rule_id] + 1
            new_stacks: List[List[int]] = []
            # Loop over alternates of the referenced rule to build new stacks.
            while self.grammar_encoding[ref_subrule_offset] != END_OF_RULE_MARKER:
                # Copy the original stack without the last element.
                new_stack = stack[:-1]
                # If the rule ref is followed by another element, add it to the stack.
                next_element_offset = cur_element_offset + 2
                if (
                    self.grammar_encoding[next_element_offset]
                    != END_OF_ALTERNATE_MARKER
                ):
                    new_stack.append(next_element_offset)

                # If the referenced rule is not empty, add the offset of its
                # first element to the stack.
                ref_element_offset = ref_subrule_offset + 1
                if self.grammar_encoding[ref_element_offset] != END_OF_ALTERNATE_MARKER:
                    new_stack.append(ref_element_offset)

                new_stacks.extend(self.advance_stack(tuple(new_stack)))

                ref_subrule_offset += self.grammar_encoding[ref_subrule_offset] + 1

            return new_stacks

    def _consume_byte(self, byte: int, accept_state: AcceptState):
        # A single byte may be only part of a multi-byte UTF-8 character:
        # e.g. for the code point 一 (ord('一') == 19968) we need to match
        # 3 bytes, so this gets called 3 times.
        self._consume_bytes(bytes([byte]), accept_state)

    # @lru_cache(maxsize=32768)
    def _probe_bytes(
        self,
        byte_seq: bytes,
        stacks: List[List[int]],
        partial_utf8: PartialUTF8,
        verbose=True,
    ):
        if isinstance(byte_seq, list):
            byte_seq = bytes(byte_seq)
        code_points, new_partial_utf8 = decode_utf8(byte_seq, partial_utf8)
        if verbose:
            logging.debug(
                f"code_points: {code_points}; new_partial_utf8: {new_partial_utf8}"
            )
        new_stacks = self._consume_code_points(code_points, stacks)
        for stack in new_stacks:
            # An empty stack means all grammar symbols have been consumed.
            if len(stack) == 0:
                return True
            element_offset = stack[-1]
            if self.partial_utf8_accept_at_element(element_offset, new_partial_utf8):
                return True
        return False

    def _consume_bytes(
        self,
        byte_seq: bytes,
        accept_state: AcceptState = None,
        verbose=True,
    ):
        if accept_state is None:
            accept_state = self.get_initial_accept_state()
        stacks = accept_state.stacks
        partial_utf8 = accept_state.partial_utf8
        if isinstance(byte_seq, list):
            byte_seq = bytes(byte_seq)
        code_points, new_partial_utf8 = decode_utf8(byte_seq, partial_utf8)
        if verbose:
            logging.debug(
                f"code_points: {code_points}; new_partial_utf8: {new_partial_utf8}"
            )
        new_stacks = self._consume_code_points(code_points, stacks)
        # Keep only the stacks whose pending element can still accept a code
        # point that the trailing partial UTF-8 sequence may become.
        new_new_stacks = []
        for stack in new_stacks:
            if len(stack) == 0:
                continue
            element_offset = stack[-1]
            if self.partial_utf8_accept_at_element(element_offset, new_partial_utf8):
                new_new_stacks.append(stack)
        return AcceptState(new_new_stacks, new_partial_utf8)

    ##########################
    #
    # Code point recognition
    #
    ##########################

    @lru_cache(maxsize=30000)
    def _consume_code_point(
        self, code_point: int, stacks: Tuple[Tuple[int]]
    ) -> List[List[int]]:
        """
        Consume a code point across all stacks.

        code_point: a Unicode code point, including ASCII code points in the
        range [0, 127]. A code point of 0 marks an incomplete UTF-8 sequence
        and is never accepted.
        """
        new_stacks = []
        stacks: List[List[int]] = list([list(stack) for stack in stacks])
        if code_point == 0:
            return new_stacks
        for stack in stacks:
            new_stacks.extend(
                self._consume_code_point_per_stack(code_point, tuple(stack))
            )
        return new_stacks
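    # Worked example (illustrative, assuming the element layout sketched
    # earlier): suppose the top of a stack points at an element encoding the
    # range [a-z], i.e. [2, ord("a"), ord("z")]. Consuming code point
    # ord("b") == 98 succeeds because 97 <= 98 <= 122; the element is then
    # popped, its successor in the alternate (if any) is pushed, and
    # advance_stack expands whatever rule reference ends up on top.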
    @lru_cache(maxsize=30000)
    def _consume_code_point_per_stack(
        self, code_point: int, stack: Tuple[int]
    ) -> List[List[int]]:
        """
        Consume a code point on a single stack.

        code_point: a Unicode code point, including ASCII code points in the
        range [0, 127].
        """
        # TODO: the code below used to raise an error when the stack is
        # empty, but why is that happening?
        # if len(stacks) == 0:
        #     raise ValueError("Stacks don't contain any stack, meaning that no character can be consumed")

        stack = list(stack)
        new_stacks = []

        # code_point == 0 is a special case marking an incomplete UTF-8
        # sequence; we return no stacks to indicate the character is not
        # accepted.
        if code_point == 0:
            return new_stacks

        # The stack is empty: nothing left to match.
        if len(stack) == 0:
            return new_stacks

        element_offset = stack[-1]

        found = self.accept_code_point_at_element(code_point, element_offset)
        if not found:
            return new_stacks

        # Pop the matched element and push its successor, if any.
        size = self.grammar_encoding[element_offset]
        element_offset += size + 1
        new_stack = stack[:-1]
        if self.grammar_encoding[element_offset]:
            new_stack.append(element_offset)
        return self.advance_stack(tuple(new_stack))

    def _consume_code_points(
        self, code_points: List[int], stacks: List[List[int]], verbose=False
    ) -> List[List[int]]:
        for i, code_point in enumerate(code_points):
            # For lru_cache to work, we need to convert the list of stacks
            # into a tuple of stacks.
            tuple_stacks: Tuple[Tuple[int]] = tuple([tuple(stack) for stack in stacks])
            stacks = self._consume_code_point(code_point, tuple_stacks)
            if len(stacks) > 0 and verbose:
                accepted_code_points = code_points[: i + 1]
                corresponding_char = chr(code_point)
                logging.debug(
                    f"code points {accepted_code_points} ending in {corresponding_char} are accepted"
                )
        return stacks

    def _accept_code_points(
        self, code_points: List[int], stacks: List[List[int]], verbose=False
    ) -> bool:
        stacks = self._consume_code_points(code_points, stacks, verbose)
        return len(stacks) > 0

    @lru_cache(maxsize=30000)
    def accept_code_point_at_element(
        self, code_point: int, element_offset: int
    ) -> bool:
        size = self.grammar_encoding[element_offset]
        # Make element_offset point to the range_start of the first range.
        element_offset += 1
        for i in range(0, size, 2):
            if (
                self.grammar_encoding[element_offset + i]
                <= code_point
                <= self.grammar_encoding[element_offset + i + 1]
            ):
                return True
        return False

    # def _accept_code_point(self, code_point: int, stacks: List[List[int]]):
    #     # for lru_cache to work, we need to convert the list of stacks into a tuple of stacks
    #     tuple_stacks: Tuple[Tuple[int]] = tuple([tuple(stack) for stack in stacks])
    #     new_stacks: List[List[int]] = self._consume_code_point(code_point, tuple_stacks)
    #     return len(new_stacks) > 0
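    # Worked example for the partial-UTF-8 check below (the numbers follow
    # from standard UTF-8 encoding, not from this library): after seeing only
    # the first byte 0xE4 of 一 (U+4E00, code point 19968), decode_utf8 is
    # expected to leave value = 0b0100 = 4 and n_remain = 2. Whatever code
    # point the sequence completes to must lie in
    # [4 << 12, (4 << 12) | 0xFFF] = [16384, 20479], which contains 19968, so
    # any element whose ranges intersect that interval can still match.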
    #############################
    #
    # Partial UTF-8 recognition
    #
    #############################

    def partial_utf8_accept_at_element(
        self, element_offset: int, partial_utf8: PartialUTF8
    ) -> bool:
        # Extract the accumulated value and the number of remaining bytes
        # from the partial_utf8 object.
        partial_value = partial_utf8.value
        n_remain = partial_utf8.n_remain

        # Invalid UTF-8 sequence: with one remaining byte the accumulated
        # value cannot be below 2.
        if n_remain == 1 and partial_value < 2:
            return False

        # If there are no remaining bytes, we have already consumed a
        # complete UTF-8 sequence.
        if n_remain <= 0:
            return True

        # Calculate the lowest possible Unicode code point that can be formed
        # with the remaining bytes.
        low = partial_value << (n_remain * 6)
        # Calculate the highest possible Unicode code point by setting all
        # remaining bits to 1.
        high = low | ((1 << (n_remain * 6)) - 1)

        # If the low end of the range is 0, adjust low to the minimum value
        # representable with that number of remaining bytes. This accounts
        # for UTF-8's shortest-form encoding rule.
        if low == 0:
            if n_remain == 2:
                low = 1 << 11  # Minimum value representable with 2 additional bytes.
            elif n_remain == 3:
                low = 1 << 16  # Minimum value representable with 3 additional bytes.

        # Get the size of the character-range element at element_offset and
        # move to the start of its ranges.
        size = self.grammar_encoding[element_offset]
        element_offset += 1

        # Check whether the [low, high] interval overlaps any range specified
        # by the element.
        for i in range(0, size, 2):
            if intervals_intersect(
                low,
                high,
                self.grammar_encoding[element_offset + i],
                self.grammar_encoding[element_offset + i + 1],
            ):
                return True

        # No overlap with any of the ranges: no valid partial match.
        return False

    #############################
    #
    # String recognition
    #
    #############################

    def _consume_string(self, string: str, accept_state: AcceptState):
        # _bytes = bytes(string, "utf-8")
        code_points = [ord(char) for char in string]
        stacks = self._consume_code_points(code_points, accept_state.stacks)
        return AcceptState(stacks, accept_state.partial_utf8)

    def _accept_prefix(self, string: str, accept_state: AcceptState = None):
        if accept_state is None:
            accept_state = self.get_initial_accept_state()
        new_accept_state = self._consume_string(string, accept_state)
        return len(new_accept_state.stacks) > 0

    def _accept_string(self, string: str, accept_state: AcceptState = None):
        if accept_state is None:
            accept_state = self.get_initial_accept_state()
        new_accept_state = self._consume_string(string, accept_state)
        at_least_one_stack_is_empty = any(
            len(stack) == 0 for stack in new_accept_state.stacks
        )
        return at_least_one_stack_is_empty

    def _can_stop(self, stacks: List[List[int]]):
        # This happens in practice, but maybe it shouldn't? TODO
        if len(stacks) == 0:
            return True
        # If any stack is empty, generation can stop.
        for stack in stacks:
            if len(stack) == 0:
                return True
        return False

    def _must_stop(self, stacks: List[List[int]]):
        return len(stacks) == 0 or all(len(stack) == 0 for stack in stacks)
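    # Illustrative semantics of the two acceptance checks above, for a
    # hypothetical grammar root ::= "ab": _accept_prefix("a") is True because
    # a live stack remains; _accept_string("a") is False because no stack is
    # empty yet; _accept_string("ab") is True because consuming both
    # characters empties a stack.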
""" logging.debug(f"element_offset: {element_offset}") acceptance = {} num_chars = self.grammar_encoding[element_offset] element_offset += 1 for i in range(0, num_chars, 2): start = self.grammar_encoding[element_offset + i] end = self.grammar_encoding[element_offset + i + 1] for j in range(start, end + 1): acceptance[j] = True logging.debug(acceptance) return acceptance def _consume_code_points_new( self, code_points: List[int], stacks: List[List[int]], verbose=False ) -> List[List[int]]: new_stacks: List[List[int]] = [] for stack in stacks: new_stacks.extend( self._consume_code_points_per_stack( tuple(code_points), tuple(stack), verbose ) ) return new_stacks @lru_cache(maxsize=30000) def _consume_code_points_per_stack( self, code_points: Tuple[int], stack: Tuple[int], verbose=False ) -> List[List[int]]: code_points = list(code_points) stacks = (stack,) for i, code_point in enumerate(code_points): # for lru_cache to work, we need to convert the list of stacks into a tuple of stacks stacks = self._consume_code_point(code_point, stacks) stacks = tuple([tuple(stack) for stack in stacks]) return [list(stack) for stack in stacks]