# LLM_DataGen/fsm.py
from copy import copy
from functools import partial

from outlines.fsm.guide import RegexGuide
from pydantic import BaseModel
from transformers import PreTrainedTokenizerBase
def merge_successive_transitions(
    states_to_token_maps: dict[int, dict[int, int]], i: int, j: int
) -> dict[int, dict[int, int]]:
    """Merge transitions on token `j` into the preceding transitions on token `i`.

    Every path that reads token `i` and then follows one or more transitions on
    token `j` is rewired so that reading `i` alone jumps straight to the end of
    the `j` chain. The input map is not mutated; modified entries are copied.
    """
    states_to_token_maps = dict(states_to_token_maps)
    transitions_i = {(s1, states_to_token_maps[s1][i]) for s1 in states_to_token_maps if i in states_to_token_maps[s1]}
    transitions_j = {(s1, states_to_token_maps[s1][j]) for s1 in states_to_token_maps if j in states_to_token_maps[s1]}
    # Drop the transitions shared by both tokens so that only the distinct ones are rewired
    transitions_i, transitions_j = dict(transitions_i - transitions_j), dict(transitions_j - transitions_i)
    for s1, s2 in transitions_i.items():
        # Follow the chain of `j` transitions starting from the target of the `i` transition
        while s2 in transitions_j:
            s2 = transitions_j[s2]
        if s2 != transitions_i[s1]:
            states_to_token_maps[s1] = dict(states_to_token_maps[s1])  # copy-on-write
            states_to_token_maps[s1][i] = s2
    return states_to_token_maps
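

# A minimal sketch of what merging does (states and token ids below are made
# up for illustration): if state 0 reads token 10 to reach state 1, and state 1
# reads token 11 to reach state 2, merging tokens 10 and 11 lets state 0 jump
# straight to state 2 on token 10 alone:
#
#   >>> merge_successive_transitions({0: {10: 1}, 1: {11: 2}}, 10, 11)
#   {0: {10: 2}, 1: {11: 2}}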
def replace_transitions(
    states_to_token_maps: dict[int, dict[int, int]], i: int, j: int
) -> dict[int, dict[int, int]]:
    """Relabel every transition on token `i` as a transition on token `j`.

    When a state has distinct transitions on both tokens, their targets are
    swapped. The input map is not mutated; modified entries are copied.
    """
    states_to_token_maps = dict(states_to_token_maps)
    transitions_i = {(s1, states_to_token_maps[s1][i]) for s1 in states_to_token_maps if i in states_to_token_maps[s1]}
    transitions_j = {(s1, states_to_token_maps[s1][j]) for s1 in states_to_token_maps if j in states_to_token_maps[s1]}
    # Drop the transitions shared by both tokens so that only the distinct ones are relabeled
    transitions_i, transitions_j = dict(transitions_i - transitions_j), dict(transitions_j - transitions_i)
    for s1, s2 in transitions_i.items():
        if s2 != transitions_j.get(s1):
            states_to_token_maps[s1] = dict(states_to_token_maps[s1])  # copy-on-write
            if s1 in transitions_j:
                # Both tokens leave this state: swap their targets
                states_to_token_maps[s1][i] = transitions_j[s1]
            else:
                # Only token `i` leaves this state: relabel it as `j`
                states_to_token_maps[s1].pop(i)
            states_to_token_maps[s1][j] = s2
    return states_to_token_maps
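

# A minimal sketch (made-up states and token ids): the transition on token 10
# is relabeled to token 20; if both tokens already leave the same state, their
# targets are swapped:
#
#   >>> replace_transitions({0: {10: 1}}, 10, 20)
#   {0: {20: 1}}
#   >>> replace_transitions({0: {10: 1, 20: 2}}, 10, 20)
#   {0: {10: 2, 20: 1}}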
def find_paths_with_transitions(states_to_token_maps: dict[int, dict[int, int]], transitions: list[int]) -> list[list[int]]:
    """Find every state path that reads the full token sequence `transitions`.

    Each returned path has len(transitions) + 1 states. Paths may only start
    at states that cannot be reached in the middle of the sequence.
    """
    possible_s0 = {s0 for s0 in states_to_token_maps if transitions[0] in states_to_token_maps[s0]}
    # Exclude states that also accept the first token, so that paths start at
    # the actual beginning of the sequence rather than partway through it
    possible_s1 = {s1 for s1 in states_to_token_maps if transitions[1] in states_to_token_maps[s1]} - possible_s0
    starts = sorted(
        s0 for s0 in possible_s0
        if states_to_token_maps[s0][transitions[0]] in possible_s1
    )
    paths = [[start] for start in starts]
    for path in paths:
        # Follow the token sequence as far as the FSM allows
        for i in transitions:
            if i in states_to_token_maps[path[-1]]:
                path.append(states_to_token_maps[path[-1]][i])
            else:
                break
    # Keep only the paths that consumed the full token sequence
    return [path for path in paths if len(path) == len(transitions) + 1]
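

# A minimal sketch (made-up states and token ids): reading tokens 10 then 11
# from state 0 visits states 0 -> 1 -> 2, and that is the only complete path:
#
#   >>> find_paths_with_transitions({0: {10: 1}, 1: {11: 2}, 2: {12: 3}}, [10, 11])
#   [[0, 1, 2]]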
def replace_fields(fsm: RegexGuide, model: type[BaseModel], new_fields: list[str], tokenizer: PreTrainedTokenizerBase, make_infinite_loop: bool = False) -> RegexGuide:
    """Return a copy of `fsm` whose JSON field names are replaced by `new_fields`.

    Each placeholder field of `model` must tokenize to at least as many tokens
    as its replacement, since transitions can be merged but not split.
    """
    assert len(new_fields) <= len(model.model_fields)
    sttm = dict(fsm.states_to_token_maps)
    encode = partial(tokenizer.encode, add_special_tokens=False)
    quote = encode('"')[0]
    # Replace the placeholder fields from the model in the finite-state machine with the new fields
    for orig_field, new_field in zip(model.model_fields, new_fields):
        orig_field_tokens = [encode(orig_field_char)[0] for orig_field_char in orig_field]
        new_field_tokens = encode(new_field)
        assert len(new_field_tokens) <= len(orig_field_tokens)
        # Merge transitions until the number of transitions equals the number of tokens in the new field name
        for k in reversed(range(len(new_field_tokens), len(orig_field_tokens))):
            sttm = merge_successive_transitions(sttm, orig_field_tokens[k - 1], orig_field_tokens[k])
        # Replace the token ids in the transitions with those of the new field name
        for k in range(len(new_field_tokens)):
            sttm = replace_transitions(sttm, orig_field_tokens[k], new_field_tokens[k])
    if len(new_fields) < len(model.model_fields) or make_infinite_loop:
        # Rewire the last new field to the final state of the original last field,
        # so the FSM generates fewer fields than the model defines.
        # This must be done for every possible path, e.g. multiple paths are
        # used to count items when setting a min/max length.
        orig_last_field = list(model.model_fields)[-1]
        new_last_field = new_fields[-1]
        orig_last_field_paths = find_paths_with_transitions(sttm, [quote] + [encode(c)[0] for c in orig_last_field])
        new_last_field_paths = find_paths_with_transitions(sttm, [quote] + encode(new_last_field))
        if make_infinite_loop:
            # Hack: reuse the first path everywhere so the FSM loops over the same states forever
            orig_last_field_paths = [orig_last_field_paths[0]] * len(orig_last_field_paths)
        for orig_last_field_path, new_last_field_path in zip(
            orig_last_field_paths,
            new_last_field_paths
        ):
            orig_last_field_last_state = orig_last_field_path[-1]
            new_last_field_second_last_state = new_last_field_path[-2]
            sttm[new_last_field_second_last_state] = dict(sttm[new_last_field_second_last_state])  # copy-on-write
            sttm[new_last_field_second_last_state][encode(new_last_field)[-1]] = orig_last_field_last_state
    fsm = copy(fsm)
    fsm.states_to_token_maps = sttm
    return fsm
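

# Example usage (a sketch, not part of the original file). It assumes an
# outlines version where `RegexGuide(regex, tokenizer)` builds the FSM from a
# regex and exposes `states_to_token_maps`, a `build_regex_from_schema` helper
# (the name varies across outlines versions), and a hypothetical placeholder
# model whose field names tokenize one character per token and are at least as
# long as the replacement names. `outlines_tokenizer` (outlines' tokenizer
# wrapper) and `hf_tokenizer` (a transformers tokenizer) are assumed to exist:
#
#   import json
#   from pydantic import create_model
#   from outlines.fsm.json_schema import build_regex_from_schema
#
#   Placeholder = create_model("Placeholder", aaaaaaaaaa=(str, ...), bbbbbbbbbb=(str, ...))
#   regex = build_regex_from_schema(json.dumps(Placeholder.model_json_schema()))
#   guide = RegexGuide(regex, outlines_tokenizer)
#   new_guide = replace_fields(guide, Placeholder, ["title", "text"], hf_tokenizer)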