import logging
from functools import lru_cache
from typing import Dict, List, Tuple
from collections import deque

from transformers_gad.mapping import get_mapping

logger = logging.getLogger(__name__)

class TrieNode:
    def __init__(self):
        # Children are keyed by byte value; a node that ends a token's byte sequence
        # stores that token's id.
        self.children = {}
        self.is_end_of_word = False
        self.token_id = None

class ByteTrie:
    def __init__(self):
        self.root = TrieNode()

    def insert(self, word, token_id=None):
        node = self.root
        for char in word:
            if char not in node.children:
                node.children[char] = TrieNode()
            node = node.children[char]
        node.is_end_of_word = True
        node.token_id = token_id

    def search(self, word):
        node = self.root
        for char in word:
            if char not in node.children:
                return False
            node = node.children[char]
        return node.is_end_of_word

    def start_with_prefix(self, prefix):
        node = self.root
        for char in prefix:
            if char not in node.children:
                return False
            node = node.children[char]
        return True

    @classmethod
    def from_tokenizer(cls, tokenizer, unicode=True):
        vocab: Dict[str, int] = tokenizer.get_vocab()
        trie = cls()
        mapping = get_mapping(tokenizer, unicode=unicode)
        for token_id in vocab.values():
            byte_repr = mapping.map(token_id)
            trie.insert(byte_repr, token_id)
        return trie
    @lru_cache(maxsize=128)
    def __len__(self):
        # Cached: assumes the trie is not modified after construction.
        return len(self.dfs(verbose=False))

    def dfs(self, accept=lambda x: True, verbose=False) -> List[Tuple[List[int], int]]:
        result = []
        counter = {"visited": 0, "pruned": 0}
        _dfs(self.root, [], result, accept, counter)
        return result

    def bfs(
        self, predicate=lambda x: True, verbose=False
    ) -> List[Tuple[List[int], int]]:
        queue = deque([(self.root, [])])
        valid_byte_seqs: List[Tuple[List[int], int]] = []
        counter = {"visited": 0, "pruned": 0}
        while queue:
            counter["visited"] += 1
            node, byte_seq = queue.popleft()
            if predicate(byte_seq):
                if node.is_end_of_word:
                    valid_byte_seqs.append((byte_seq, node.token_id))
                for char, next_node in node.children.items():
                    new_byte_seq: List[int] = byte_seq.copy()
                    new_byte_seq.append(char)
                    queue.append((next_node, new_byte_seq))
            else:
                counter["pruned"] += 1
        return valid_byte_seqs

    def get_token_acceptance(
        self, accept=lambda x: True, accept_eos=True, eos_token_id=None
    ) -> List[bool]:
        valid_byte_seqs: List[Tuple[List[int], int]] = self.bfs(accept, verbose=True)
        valid_token_ids: List[int] = [token_id for _, token_id in valid_byte_seqs]
        token_acceptance: List[bool] = [False] * len(self)
        for token_id in valid_token_ids:
            token_acceptance[token_id] = True
        if not accept_eos:
            # The eos token is mapped to an empty string, so it is always accepted regardless
            # of the accept function. This can be undesirable, so set it to False to ignore it.
            token_acceptance[eos_token_id] = False
        return token_acceptance
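
# A minimal usage sketch (kept as comments so nothing runs at import time) of how a
# constrained-decoding loop might consume the acceptance mask. `recognizer_accepts` is
# a hypothetical predicate over partial byte sequences; the surrounding project plugs in
# a grammar recognizer instead, and the torch masking line assumes 1-D logits of vocab size.
#
# trie = ByteTrie.from_tokenizer(tokenizer)
# mask = trie.get_token_acceptance(accept=recognizer_accepts)
# logits[~torch.tensor(mask)] = float("-inf")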
def _dfs(
    node,
    cur_byte_seq: List[int],
    result: List[Tuple[List[int], int]],
    accept: callable,
    counter: Dict[str, int],
):
    counter["visited"] += 1
    if accept(cur_byte_seq):
        if node.is_end_of_word:
            result.append((cur_byte_seq, node.token_id))
        for char, next_node in node.children.items():
            new_byte_seq: List[int] = cur_byte_seq.copy()
            new_byte_seq.append(char)
            _dfs(next_node, new_byte_seq, result, accept, counter)
    else:
        # Skip the entire subtree if the accept function returns False.
        counter["pruned"] += 1
        return
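
# Sketch: using dfs with a prefix-style predicate to enumerate every token whose byte
# sequence extends (or is a prefix of) b"Hello". `trie` is assumed to have been built
# with ByteTrie.from_tokenizer; kept as comments so it does not execute at import time.
#
# target = list(b"Hello")
# hits = trie.dfs(accept=lambda seq: seq == target[: len(seq)] or seq[: len(target)] == target)
# for byte_seq, token_id in hits:
#     print(token_id, bytes(byte_seq))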
def starts_with_prefix(prefix, target):
    """
    Check whether the given prefix is a valid start of the target word, or the target
    word is a valid start of the given prefix.

    Args:
        prefix (str): The string prefix to be checked.
        target (str): The target word to compare the prefix against.

    Returns:
        bool: True if prefix is a valid start of target, or if target is a valid start
        of prefix; False otherwise.
    """
    # Check if the target word starts with the given prefix.
    # This covers the case where the prefix is shorter than the target word.
    if target.startswith(prefix):
        return True

    # Check if the given prefix starts with the target word.
    # This covers the case where the prefix is longer than or equal to the target word.
    if prefix.startswith(target):
        return True

    # If neither of the above conditions is true, return False.
    return False
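
# For illustration (plain string logic, verifiable by hand):
#   starts_with_prefix("hel", "hello")    -> True   ("hello" starts with "hel")
#   starts_with_prefix("hello!", "hello") -> True   (the prefix extends past the target)
#   starts_with_prefix("help", "hello")   -> False  (the strings diverge at the fourth character)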
if __name__ == "__main__":
    # Configure logging
    logging.basicConfig(level=logging.INFO)

    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("gpt2", use_fast=True)
    trie = ByteTrie.from_tokenizer(tokenizer, unicode=True)
    print(f"length of trie: {len(trie)}=={len(tokenizer.vocab.items())}")
    # print(trie.search("hello"))  # Example, replace with actual words from the vocab
    # print(trie.start_with_prefix("hell"))

    # # Example Usage
    # words = trie.dfs(accept=lambda x: len(x) > 0 and x[0] == 65 or len(x) == 0)
    # for word in words:
    #     print(bytes(word[0]).decode("utf-8"))

    # # Example Usage
    # words = trie.bfs(predicate=lambda x: len(x) > 0 and x[0] == 65 or len(x) == 0)
    # for word in words:
    #     print(bytes(word[0]).decode("utf-8"))

    # token_acceptance = trie.get_token_acceptance(accept=lambda x: len(x) > 0 and x[0] == 65 or len(x) == 0)
    # print(sum(token_acceptance))
    # assert sum(token_acceptance) == len(words)
    ########################
    # UTF-8
    ########################
    # from transformers import AutoTokenizer
    #
    # japanese = "こんにちは世界"
    # with open("examples/grammars/japanese.ebnf", "r") as file:
    #     input_text = file.read()
    # parsed_grammar = parse_ebnf(input_text)
    #
    # start_rule_id = parsed_grammar.symbol_table["root"]
    #
    # recognizer = GrammarRecognizer(parsed_grammar.grammar_encoding, start_rule_id)
    # accept_state = recognizer.init_accept_state()
    # token_acc = trie.get_token_acceptance(accept=lambda x: recognizer._probe_bytes_partial_match(x, accept_state=accept_state))