File size: 5,219 Bytes
4f83ec0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
from copy import copy
from functools import partial

from outlines.fsm.guide import RegexGuide
from pydantic import BaseModel
from transformers import PreTrainedTokenizerBase


def merge_successive_transitions(states_to_token_maps: dict[int, dict[int, int]], i, j):
    """Return a copy of the transition map where each token-``i`` edge skips over following token-``j`` edges.

    For every state ``s1`` with an ``i`` transition to ``s2`` (excluding ``(state, target)``
    pairs that both tokens share), the chain of ``j`` transitions starting at ``s2`` is
    followed as far as it goes, and ``s1``'s ``i`` transition is pointed at the state where
    that chain ends. The ``j`` transitions themselves are left in place.

    The outer dict is shallow-copied and every inner dict is copied before it is modified,
    so the caller's mapping is never mutated.

    Args:
        states_to_token_maps: ``{state: {token_id: next_state}}`` transition map.
        i: token id whose transitions are redirected.
        j: token id whose transitions are skipped over.

    Returns:
        The new transition map.

    Raises:
        ValueError: if following the ``j`` transitions from an ``i`` target revisits a
            state (a cycle). Without this check the walk would loop forever.
    """
    states_to_token_maps = dict(states_to_token_maps)
    transitions_i = {(s1, states_to_token_maps[s1][i]) for s1 in states_to_token_maps if i in states_to_token_maps[s1]}
    transitions_j = {(s1, states_to_token_maps[s1][j]) for s1 in states_to_token_maps if j in states_to_token_maps[s1]}
    # Pairs present for both tokens (for instance a state looping to itself on either
    # token) are dropped, so only transitions specific to each token are rewired.
    # Each state appears at most once per set, so converting to a dict is lossless.
    transitions_i, transitions_j = dict(transitions_i - transitions_j), dict(transitions_j - transitions_i)
    for s1, s2 in transitions_i.items():
        visited = {s2}
        while s2 in transitions_j:
            s2 = transitions_j[s2]
            if s2 in visited:
                raise ValueError(
                    f"Transitions on token {j} form a cycle reachable from state "
                    f"{transitions_i[s1]}; cannot merge tokens {i} and {j}"
                )
            visited.add(s2)
        if s2 != transitions_i[s1]:
            # Copy-on-write so the caller's inner dict is left untouched.
            states_to_token_maps[s1] = dict(states_to_token_maps[s1])
            states_to_token_maps[s1][i] = s2
    return states_to_token_maps


def replace_transitions(states_to_token_maps: dict[int, dict[int, int]], i, j):
    """Return a copy of the transition map where token ``i`` is replaced by token ``j``.

    ``(state, target)`` pairs reachable by both tokens are left alone. For every other
    state that has an ``i`` transition:

    * its ``i`` target is moved onto token ``j``;
    * if the state also had its own ``j`` transition, that old ``j`` target is moved onto
      token ``i`` (the two tokens are swapped); otherwise token ``i`` is removed.

    States that only have a ``j`` transition are not touched. The outer dict is
    shallow-copied and each rewritten inner dict is a fresh copy, so the caller's mapping
    is not mutated.
    """
    result = dict(states_to_token_maps)
    pairs_i = {(state, edges[i]) for state, edges in result.items() if i in edges}
    pairs_j = {(state, edges[j]) for state, edges in result.items() if j in edges}
    only_i = dict(pairs_i - pairs_j)
    only_j = dict(pairs_j - pairs_i)
    for state, target in only_i.items():
        old_j_target = only_j.get(state)
        if target == old_j_target:
            continue
        edges = dict(result[state])
        if state in only_j:
            # Swap: the old j-destination becomes reachable through token i.
            edges[i] = old_j_target
        else:
            del edges[i]
        edges[j] = target
        result[state] = edges
    return result


def find_paths_with_transitions(states_to_token_maps: dict[int, dict[int, int]], transitions: list[int]) -> list[list[int]]:
    """Return every state path that spells out ``transitions`` one token after another.

    A candidate start is a state with a ``transitions[0]`` edge that leads into a state
    having a ``transitions[1]`` edge but no ``transitions[0]`` edge of its own. Starting
    from each candidate, in ascending state order, the tokens are followed one by one.
    Only candidates from which every token can be followed are kept.

    Each returned path holds ``len(transitions) + 1`` states: the start, then the state
    reached after each token.

    Raises ``IndexError`` when the map is non-empty and ``transitions`` has fewer than two
    tokens. Raises ``KeyError`` if a path reaches a state missing from the map.
    """
    with_first = {state for state, edges in states_to_token_maps.items() if transitions[0] in edges}
    with_second = {state for state, edges in states_to_token_maps.items() if transitions[1] in edges}
    with_second -= with_first
    candidates = sorted(
        state for state in with_first
        if states_to_token_maps[state][transitions[0]] in with_second
    )
    complete = []
    for start in candidates:
        path = [start]
        current = start
        for token in transitions:
            edges = states_to_token_maps[current]
            if token not in edges:
                break
            current = edges[token]
            path.append(current)
        else:
            # Only reached when every token was followed (no break).
            complete.append(path)
    return complete


def replace_fields(fsm: RegexGuide, model: BaseModel, new_fields: list[str], tokenizer: PreTrainedTokenizerBase, make_infinite_loop: bool = False) -> RegexGuide:
    """Return a copy of ``fsm`` whose placeholder field names are renamed to ``new_fields``.

    The field names of ``model`` (in ``model_fields`` order) act as placeholders in the
    transition map. NOTE(review): presumably ``fsm`` was compiled from ``model``'s JSON
    schema; confirm against callers. Each placeholder is located one character token at a
    time (``tokenizer.encode(char)[0]`` per character). Its transitions are first merged
    down to as many as the new name has tokens, then relabelled with the new name's token
    ids.

    When fewer new fields than model fields are given, or ``make_infinite_loop`` is set,
    the last token of every path spelling ``"`` + last new field name is redirected. It
    now points at the end state of the paired path spelling ``"`` + the model's last
    placeholder name; paths are paired positionally. With ``make_infinite_loop``, every
    path is paired with the first placeholder path found.

    Args:
        fsm: guide whose ``states_to_token_maps`` is rewritten. The input guide and its
            transition dicts are not mutated.
        model: model whose ``model_fields`` names are the placeholders.
        new_fields: replacement names, paired in order with the model's fields.
        tokenizer: tokenizer used to encode names, without special tokens.
        make_infinite_loop: redirect all last-field paths to the first placeholder path.

    Returns:
        A shallow copy of ``fsm`` carrying the rewritten ``states_to_token_maps``.

    Raises:
        ValueError: in any of these cases:
            - ``new_fields`` is longer than ``model_fields``;
            - a new name encodes to zero tokens;
            - a new name encodes to more tokens than its placeholder has characters;
            - the last-field rewiring is needed but ``new_fields`` is empty;
            - ``make_infinite_loop`` is set but no path for the model's last placeholder
              is found.
    """
    model_field_names = list(model.model_fields)
    if len(new_fields) > len(model_field_names):
        raise ValueError(
            f"Got {len(new_fields)} new fields but the model only has "
            f"{len(model_field_names)} placeholder fields"
        )
    sttm = dict(fsm.states_to_token_maps)
    encode = partial(tokenizer.encode, add_special_tokens=False)
    quote = encode('"')[0]
    # Let's replace the placeholder fields from the model in the finite state model by the new fields
    for orig_field, new_field in zip(model_field_names, new_fields):
        orig_field_tokens = [encode(orig_field_char)[0] for orig_field_char in orig_field]
        new_field_tokens = encode(new_field)
        if not new_field_tokens:
            # With zero new tokens the merge loop below would reach k == 0 and
            # wrap around to orig_field_tokens[-1], corrupting the FSM.
            raise ValueError(f"New field name {new_field!r} encodes to no tokens")
        if len(new_field_tokens) > len(orig_field_tokens):
            raise ValueError(
                f"New field name {new_field!r} encodes to {len(new_field_tokens)} tokens, "
                f"more than the {len(orig_field_tokens)} characters of placeholder {orig_field!r}"
            )
        # Merge transitions until we have number of transitions = number of tokens in the field name
        for k in reversed(range(len(new_field_tokens), len(orig_field_tokens))):
            sttm = merge_successive_transitions(sttm, orig_field_tokens[k - 1], orig_field_tokens[k])
        # Replace the token ids in the transitions with the ones of the new field name
        for orig_token, new_token in zip(orig_field_tokens, new_field_tokens):
            sttm = replace_transitions(sttm, orig_token, new_token)
    if len(new_fields) < len(model_field_names) or make_infinite_loop:
        # Set the last field last state to generate less than the number of fields in the model
        # We need to do this for every possible path
        # e.g. multiple paths are used to count items when setting a min/max length
        if not new_fields:
            raise ValueError("new_fields must not be empty when rewiring the last field")
        orig_last_field = model_field_names[-1]
        new_last_field = new_fields[-1]
        # Non-empty: every new field was validated in the loop above.
        new_last_field_tokens = encode(new_last_field)
        orig_last_field_paths = find_paths_with_transitions(sttm, [quote] + [encode(c)[0] for c in orig_last_field])
        new_last_field_paths = find_paths_with_transitions(sttm, [quote] + new_last_field_tokens)
        if make_infinite_loop:  # this is a hack to loop on the same states over and over again
            if not orig_last_field_paths:
                raise ValueError(
                    f"make_infinite_loop requires a path for the model's last field "
                    f"{orig_last_field!r}, but none was found in the FSM"
                )
            orig_last_field_paths = [orig_last_field_paths[0]] * len(orig_last_field_paths)
        last_token = new_last_field_tokens[-1]
        for orig_last_field_path, new_last_field_path in zip(
            orig_last_field_paths,
            new_last_field_paths
        ):
            orig_last_field_last_state = orig_last_field_path[-1]
            new_last_field_second_last_state = new_last_field_path[-2]
            # Copy-on-write so the original guide's inner dict is not mutated.
            sttm[new_last_field_second_last_state] = dict(sttm[new_last_field_second_last_state])
            sttm[new_last_field_second_last_state][last_token] = orig_last_field_last_state
    fsm = copy(fsm)
    fsm.states_to_token_maps = sttm
    return fsm