# LLM_DataGen / fsm.py

from copy import copy
from functools import partial

from outlines.fsm.guide import RegexGuide
from pydantic import BaseModel
from transformers import PreTrainedTokenizerBase


def merge_successive_transitions(states_to_token_maps: dict[int, dict[int, int]], i: int, j: int) -> dict[int, dict[int, int]]:
    """Merge each transition on token ``i`` with the chain of transitions on token ``j`` that follows it."""
    states_to_token_maps = dict(states_to_token_maps)
    transitions_i = {(s1, states_to_token_maps[s1][i]) for s1 in states_to_token_maps if i in states_to_token_maps[s1]}
    transitions_j = {(s1, states_to_token_maps[s1][j]) for s1 in states_to_token_maps if j in states_to_token_maps[s1]}
    # Drop the (state, target) pairs common to both tokens, keep the rest as dicts
    transitions_i, transitions_j = dict(transitions_i - transitions_j), dict(transitions_j - transitions_i)
    for s1, s2 in transitions_i.items():
        # Follow the chain of transitions on j and reroute i straight to its end
        while s2 in transitions_j:
            s2 = transitions_j[s2]
        if s2 != transitions_i[s1]:
            states_to_token_maps[s1] = dict(states_to_token_maps[s1])  # copy before mutating
            states_to_token_maps[s1][i] = s2
    return states_to_token_maps
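

# A minimal sketch of the effect, using a hand-built transition map (token ids
# 5 and 7 are made up for the example):
#
#     merge_successive_transitions({0: {5: 1}, 1: {7: 2}}, 5, 7)
#     # -> {0: {5: 2}, 1: {7: 2}}: reading token 5 now jumps straight to state 2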


def replace_transitions(states_to_token_maps: dict[int, dict[int, int]], i: int, j: int) -> dict[int, dict[int, int]]:
    """Replace each transition on token ``i`` by a transition on token ``j``, swapping targets when both tokens leave the same state."""
    states_to_token_maps = dict(states_to_token_maps)
    transitions_i = {(s1, states_to_token_maps[s1][i]) for s1 in states_to_token_maps if i in states_to_token_maps[s1]}
    transitions_j = {(s1, states_to_token_maps[s1][j]) for s1 in states_to_token_maps if j in states_to_token_maps[s1]}
    # Drop the (state, target) pairs common to both tokens, keep the rest as dicts
    transitions_i, transitions_j = dict(transitions_i - transitions_j), dict(transitions_j - transitions_i)
    for s1, s2 in transitions_i.items():
        if s2 != transitions_j.get(s1):
            states_to_token_maps[s1] = dict(states_to_token_maps[s1])  # copy before mutating
            if s1 in transitions_j:
                # Both tokens leave s1: swap their targets
                states_to_token_maps[s1][i] = transitions_j[s1]
            else:
                states_to_token_maps[s1].pop(i)
            states_to_token_maps[s1][j] = s2
    return states_to_token_maps
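

# For instance (token ids again made up):
#
#     replace_transitions({0: {5: 1}}, 5, 9)
#     # -> {0: {9: 1}}
#     replace_transitions({0: {5: 1, 9: 3}}, 5, 9)
#     # -> {0: {5: 3, 9: 1}}  (the two targets are swapped)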


def find_paths_with_transitions(states_to_token_maps: dict[int, dict[int, int]], transitions: list[int]) -> list[list[int]]:
    """Return every path of states that reads exactly the given sequence of tokens."""
    # Candidate start states: they accept transitions[0] and lead to a state that
    # accepts transitions[1] without itself accepting transitions[0]
    possible_s0 = {s0 for s0 in states_to_token_maps if transitions[0] in states_to_token_maps[s0]}
    possible_s1 = {s1 for s1 in states_to_token_maps if transitions[1] in states_to_token_maps[s1]} - possible_s0
    starts = sorted(
        s0 for s0 in possible_s0
        if states_to_token_maps[s0][transitions[0]] in possible_s1
    )
    paths = [[start] for start in starts]
    for path in paths:
        for i in transitions:
            if i in states_to_token_maps[path[-1]]:
                path.append(states_to_token_maps[path[-1]][i])
            else:
                break
    # Keep only the paths that read the full token sequence
    return [path for path in paths if len(path) == len(transitions) + 1]
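

# E.g. with a made-up map where only one chain reads tokens [5, 7] end to end:
#
#     find_paths_with_transitions({0: {5: 1}, 1: {7: 2}, 2: {5: 3}}, [5, 7])
#     # -> [[0, 1, 2]]  (state 2 also accepts token 5 but cannot read [5, 7])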


def replace_fields(fsm: RegexGuide, model: BaseModel, new_fields: list[str], tokenizer: PreTrainedTokenizerBase, make_infinite_loop: bool = False) -> RegexGuide:
    """Rewrite an FSM built for ``model`` so that it generates JSON with ``new_fields`` as keys instead.

    The transitions of ``fsm.states_to_token_maps`` are edited directly, so the
    guide does not have to be recompiled from scratch. Each new field name must
    not be longer (in tokens) than the placeholder field it replaces.
    """
    assert len(new_fields) <= len(model.model_fields)
    sttm = dict(fsm.states_to_token_maps)
    encode = partial(tokenizer.encode, add_special_tokens=False)
    quote = encode('"')[0]
    # Replace the placeholder field names of the model with the new field names
    # in the finite-state machine
    for orig_field, new_field in zip(model.model_fields, new_fields):
        # Each placeholder character is assumed to encode to exactly one token
        orig_field_tokens = [encode(orig_field_char)[0] for orig_field_char in orig_field]
        new_field_tokens = encode(new_field)
        assert len(new_field_tokens) <= len(orig_field_tokens)
        # Merge successive transitions until their number equals the number of
        # tokens in the new field name
        for k in reversed(range(len(new_field_tokens), len(orig_field_tokens))):
            sttm = merge_successive_transitions(sttm, orig_field_tokens[k - 1], orig_field_tokens[k])
        # Then replace the token ids in the transitions with those of the new field name
        for k in range(len(new_field_tokens)):
            sttm = replace_transitions(sttm, orig_field_tokens[k], new_field_tokens[k])
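    # For example, if the placeholder field is "col_0" (five single-character
    # tokens) and the new field encodes to, say, two tokens (this depends on the
    # tokenizer), three merges first collapse the five transitions into two,
    # then the two token ids are swapped for those of the new field name.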
    if len(new_fields) < len(model.model_fields) or make_infinite_loop:
        # Reroute the end of the last new field to the last state of the model's
        # last field, so that generation stops after fewer fields than the model
        # defines. This must be done for every possible path: e.g. multiple paths
        # are used to count items when a min/max length is set.
        orig_last_field = list(model.model_fields)[-1]
        new_last_field = new_fields[-1]
        orig_last_field_paths = find_paths_with_transitions(sttm, [quote] + [encode(c)[0] for c in orig_last_field])
        new_last_field_paths = find_paths_with_transitions(sttm, [quote] + encode(new_last_field))
        if make_infinite_loop:  # hack: reuse the first path everywhere to loop on the same states over and over again
            orig_last_field_paths = [orig_last_field_paths[0]] * len(orig_last_field_paths)
        for orig_last_field_path, new_last_field_path in zip(
            orig_last_field_paths,
            new_last_field_paths
        ):
            orig_last_field_last_state = orig_last_field_path[-1]
            new_last_field_second_last_state = new_last_field_path[-2]
            sttm[new_last_field_second_last_state] = dict(sttm[new_last_field_second_last_state])  # copy before mutating
            sttm[new_last_field_second_last_state][encode(new_last_field)[-1]] = orig_last_field_last_state
    # Return a shallow copy of the guide with the edited transition map attached
    fsm = copy(fsm)
    fsm.states_to_token_maps = sttm
    return fsm
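

# Usage sketch (hedged): assumes an outlines version where RegexGuide exposes a
# mutable `states_to_token_maps`, as this module expects; the `Columns` model
# and the guide construction below are illustrative, not taken from this repo.
#
#     class Columns(BaseModel):  # placeholder model the guide was compiled for
#         col_0: str
#         col_1: str
#         col_2: str
#
#     guide = ...  # RegexGuide compiled from the JSON-schema regex of Columns
#     new_guide = replace_fields(guide, Columns, ["name", "age"], tokenizer)
#     # new_guide now constrains generation to a JSON object with the keys
#     # "name" and "age", without recompiling the regex FSM.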