import logging
from functools import lru_cache
from typing import Callable, Dict, List, Tuple
from collections import deque
from transformers_gad.mapping import get_mapping
logger = logging.getLogger(__name__)
class TrieNode:
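    """A single trie node keyed by byte value; end-of-word nodes store the matching token id."""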
def __init__(self):
self.children = {}
self.is_end_of_word = False
self.token_id = None
class ByteTrie:
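    """Trie over the byte representations of a tokenizer vocabulary, mapping byte sequences to token ids."""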
def __init__(self):
self.root = TrieNode()
def insert(self, word, token_id=None):
node = self.root
for char in word:
if char not in node.children:
node.children[char] = TrieNode()
node = node.children[char]
node.is_end_of_word = True
node.token_id = token_id
def search(self, word):
node = self.root
for char in word:
if char not in node.children:
return False
node = node.children[char]
return node.is_end_of_word
def start_with_prefix(self, prefix):
node = self.root
for char in prefix:
if char not in node.children:
return False
node = node.children[char]
return True
@classmethod
def from_tokenizer(cls, tokenizer, unicode=True):
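        """Build a trie over the tokenizer vocabulary: each token id's byte
        representation (via get_mapping) becomes a path ending at that id."""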
vocab: Dict[str, int] = tokenizer.get_vocab()
trie = cls()
mapping = get_mapping(tokenizer, unicode=unicode)
for token_id in vocab.values():
byte_repr = mapping.map(token_id)
trie.insert(byte_repr, token_id)
return trie
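
    # NOTE: lru_cache is keyed on `self`; the cached length assumes the trie is
    # not mutated after the first len() call.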
@lru_cache(maxsize=128)
def __len__(self):
return len(self.dfs(verbose=False))
def dfs(self, accept=lambda x: True, verbose=False) -> List[Tuple[List[int], int]]:
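        """Depth-first traversal; returns (byte_sequence, token_id) pairs whose
        every byte prefix satisfies `accept`."""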
result = []
counter = {"visited": 0, "pruned": 0}
_dfs(self.root, [], result, accept, counter)
        if verbose:
            logger.debug(f"visited: {counter['visited']}, pruned: {counter['pruned']}")
        return result
def bfs(
self, predicate=lambda x: True, verbose=False
) -> List[Tuple[List[int], int]]:
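        """Breadth-first traversal; returns (byte_sequence, token_id) pairs,
        pruning any subtree whose byte prefix fails `predicate`."""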
queue = deque([(self.root, [])])
valid_byte_seqs: List[Tuple[List[int], int]] = []
counter = {"visited": 0, "pruned": 0}
while queue:
counter["visited"] += 1
node, byte_seq = queue.popleft()
if predicate(byte_seq):
if node.is_end_of_word:
valid_byte_seqs.append((byte_seq, node.token_id))
for char, next_node in node.children.items():
new_byte_seq: List[int] = byte_seq.copy()
new_byte_seq.append(char)
queue.append((next_node, new_byte_seq))
else:
counter["pruned"] += 1
        if verbose:
            logger.debug(f"visited: {counter['visited']}, pruned: {counter['pruned']}")
        return valid_byte_seqs
def get_token_acceptance(
self, accept=lambda x: True, accept_eos=True, eos_token_id=None
) -> List[bool]:
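        """Return a boolean mask over token ids marking tokens whose byte
        sequences satisfy `accept`."""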
valid_byte_seqs: List[Tuple[List[int], int]] = self.bfs(accept, verbose=True)
valid_token_ids: List[int] = [token_id for _, token_id in valid_byte_seqs]
        token_acceptance: List[bool] = [False] * len(self)
for token_id in valid_token_ids:
token_acceptance[token_id] = True
        if not accept_eos:
            # The EOS token is mapped to an empty string, so it is accepted by any
            # accept function; set it to False explicitly when that is undesirable.
            if eos_token_id is None:
                raise ValueError("eos_token_id is required when accept_eos=False")
            token_acceptance[eos_token_id] = False
return token_acceptance
def _dfs(
node,
cur_byte_seq: List[int],
result: List[Tuple[List[int], int]],
    accept: Callable,
counter: Dict[str, int],
):
counter["visited"] += 1
if accept(cur_byte_seq):
if node.is_end_of_word:
result.append((cur_byte_seq, node.token_id))
for char, next_node in node.children.items():
new_byte_seq: List[int] = cur_byte_seq.copy()
new_byte_seq.append(char)
_dfs(next_node, new_byte_seq, result, accept, counter)
    else:
        # Skip the entire subtree if the accept function rejects the current prefix
        counter["pruned"] += 1
return
def starts_with_prefix(prefix, target):
"""
Check if the given prefix is a valid start of the target word or if the target word is a valid start of the given prefix.
Args:
prefix (str): The string prefix to be checked.
target (str): The target word to compare the prefix against.
Returns:
bool: True if prefix is a valid start of target or if target is a valid start of prefix, False otherwise.
"""
# Check if the target word starts with the given prefix.
# This covers the case where the prefix is shorter than the target word.
if target.startswith(prefix):
return True
# Check if the given prefix starts with the target word.
# This covers the case where the prefix is longer than or equal to the target word.
if prefix.startswith(target):
return True
# If neither of the above conditions are true, return False.
return False
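
# For example:
#   starts_with_prefix("he", "hello")  -> True   (prefix of the target)
#   starts_with_prefix("hello", "he")  -> True   (target is a prefix of the prefix)
#   starts_with_prefix("hex", "hello") -> False  (diverges at the third character)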
if __name__ == "__main__":
    # Configure logging for the example run (logging is already imported above)
    logging.basicConfig(level=logging.INFO)
from transformers import AutoTokenizer
    tokenizer = AutoTokenizer.from_pretrained("gpt2", use_fast=True)
trie = ByteTrie.from_tokenizer(tokenizer, unicode=True)
print(f"length of trie: {len(trie)}=={len(tokenizer.vocab.items())}")
#
# print(trie.search("hello")) # Example, replace with actual words from the vocab
# print(trie.start_with_prefix("hell"))
#
    # # Example: accept sequences that are empty or start with byte 65 ("A")
    # words = trie.dfs(accept=lambda x: (len(x) > 0 and x[0] == 65) or len(x) == 0)
    # for word in words:
    #     print(bytes(word[0]).decode("utf-8"))
    #
    # # Same filter, breadth-first
    # words = trie.bfs(predicate=lambda x: (len(x) > 0 and x[0] == 65) or len(x) == 0)
    # for word in words:
    #     print(bytes(word[0]).decode("utf-8"))
    #
    # token_acceptance = trie.get_token_acceptance(accept=lambda x: (len(x) > 0 and x[0] == 65) or len(x) == 0)
    # print(sum(token_acceptance))
    # assert sum(token_acceptance) == len(words)
########################
# UTF-8
########################
# from transformers import AutoTokenizer
#
# japanese = "こんにちは世界"
# with open("examples/grammars/japanese.ebnf", "r") as file:
# input_text = file.read()
# parsed_grammar = parse_ebnf(input_text)
#
# start_rule_id = parsed_grammar.symbol_table["root"]
#
# recognizer = GrammarRecognizer(parsed_grammar.grammar_encoding, start_rule_id)
# accept_state = recognizer.init_accept_state()
# token_acc = trie.get_token_acceptance(accept=lambda x: recognizer._probe_bytes_partial_match(x, accept_state=accept_state))
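
    ########################
    # Applying the mask (sketch)
    ########################
    # A minimal sketch (not part of this module's API) of how the boolean mask
    # from get_token_acceptance could constrain decoding, assuming torch is
    # available; `logits` here is a hypothetical 1-D tensor of next-token scores:
    #
    # import torch
    #
    # mask = torch.tensor(token_acceptance, dtype=torch.bool)
    # masked_logits = logits.masked_fill(~mask, float("-inf"))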