import logging
from collections import deque
from functools import lru_cache
from typing import Callable, Dict, List, Tuple

from transformers_gad.mapping import get_mapping

logger = logging.getLogger(__name__)


class TrieNode:
    """A single trie node; children are keyed by the next byte/character of a token."""

    def __init__(self):
        self.children = {}
        self.is_end_of_word = False
        self.token_id = None


class ByteTrie:
    """Trie over the byte representations of a tokenizer's vocabulary."""

    def __init__(self):
        self.root = TrieNode()

    def insert(self, word, token_id=None):
        """Insert a byte sequence, marking its final node with the token id."""
        node = self.root
        for char in word:
            if char not in node.children:
                node.children[char] = TrieNode()
            node = node.children[char]
        node.is_end_of_word = True
        node.token_id = token_id

    def search(self, word):
        """Return True if `word` was inserted as a complete token."""
        node = self.root
        for char in word:
            if char not in node.children:
                return False
            node = node.children[char]
        return node.is_end_of_word

    def start_with_prefix(self, prefix):
        """Return True if at least one inserted token starts with `prefix`."""
        node = self.root
        for char in prefix:
            if char not in node.children:
                return False
            node = node.children[char]
        return True

    @classmethod
    def from_tokenizer(cls, tokenizer, unicode=True):
        """Build a trie over the byte representation of every vocabulary token."""
        vocab: Dict[str, int] = tokenizer.get_vocab()
        trie = cls()
        mapping = get_mapping(tokenizer, unicode=unicode)
        for token_id in vocab.values():
            byte_repr = mapping.map(token_id)
            trie.insert(byte_repr, token_id)
        return trie

    @lru_cache(maxsize=128)
    def __len__(self):
        # Cached: assumes the trie is not modified after construction.
        return len(self.dfs(verbose=False))

    def dfs(self, accept=lambda x: True, verbose=False) -> List[Tuple[List[int], int]]:
        """Depth-first walk collecting (byte_seq, token_id) pairs allowed by `accept`."""
        result = []
        counter = {"visited": 0, "pruned": 0}
        _dfs(self.root, [], result, accept, counter)
        if verbose:
            logger.debug(
                "DFS visited %d nodes, pruned %d", counter["visited"], counter["pruned"]
            )
        return result

    def bfs(
        self, predicate=lambda x: True, verbose=False
    ) -> List[Tuple[List[int], int]]:
        """Breadth-first walk collecting (byte_seq, token_id) pairs allowed by `predicate`."""
        queue = deque([(self.root, [])])
        valid_byte_seqs: List[Tuple[List[int], int]] = []
        counter = {"visited": 0, "pruned": 0}

        while queue:
            counter["visited"] += 1
            node, byte_seq = queue.popleft()
            if predicate(byte_seq):
                # Record complete tokens and keep expanding this branch.
                if node.is_end_of_word:
                    valid_byte_seqs.append((byte_seq, node.token_id))
                for char, next_node in node.children.items():
                    new_byte_seq: List[int] = byte_seq.copy()
                    new_byte_seq.append(char)
                    queue.append((next_node, new_byte_seq))
            else:
                # Prune the whole subtree once a prefix is rejected.
                counter["pruned"] += 1
        if verbose:
            logger.debug(
                "BFS visited %d nodes, pruned %d", counter["visited"], counter["pruned"]
            )
        return valid_byte_seqs

    def get_token_acceptance(
        self, accept=lambda x: True, accept_eos=True, eos_token_id=None
    ) -> List[bool]:
        """Return a boolean mask over token ids whose byte sequences satisfy `accept`."""
        valid_byte_seqs: List[Tuple[List[int], int]] = self.bfs(accept, verbose=True)
        valid_token_ids: List[int] = [token_id for _, token_id in valid_byte_seqs]
        token_acceptance: List[bool] = [False] * len(self)
        for token_id in valid_token_ids:
            token_acceptance[token_id] = True
        if not accept_eos:
            # Explicitly mask out the EOS token when the caller does not accept it.
            if eos_token_id is None:
                raise ValueError("eos_token_id must be provided when accept_eos is False")
            token_acceptance[eos_token_id] = False
        return token_acceptance


def _dfs(
    node,
    cur_byte_seq: List[int],
    result: List[Tuple[List[int], int]],
    accept: Callable[[List[int]], bool],
    counter: Dict[str, int],
):
    """Recursive helper for ByteTrie.dfs; extend the sequence while `accept` approves it."""
    counter["visited"] += 1
    if accept(cur_byte_seq):
        if node.is_end_of_word:
            result.append((cur_byte_seq, node.token_id))
        for char, next_node in node.children.items():
            new_byte_seq: List[int] = cur_byte_seq.copy()
            new_byte_seq.append(char)
            _dfs(next_node, new_byte_seq, result, accept, counter)
    else:
        # Prune: no extension of a rejected prefix can be accepted.
        counter["pruned"] += 1


def starts_with_prefix(prefix, target):
    """
    Check if the given prefix is a valid start of the target word or if the
    target word is a valid start of the given prefix.

    Args:
        prefix (str): The string prefix to be checked.
        target (str): The target word to compare the prefix against.

    Returns:
        bool: True if prefix is a valid start of target or if target is a valid
        start of prefix, False otherwise.
    """
    return target.startswith(prefix) or prefix.startswith(target)


if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("gpt2", use_fast=True)

    trie = ByteTrie.from_tokenizer(tokenizer, unicode=True)
    print(f"length of trie: {len(trie)}=={len(tokenizer.get_vocab())}")