File size: 5,496 Bytes
291bc70
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
import os
from src.extractor import create_extractor
from src.sql_chain import create_agent
from dotenv import load_dotenv
import chainlit as cl
import json
# Load environment variables from the local .env file; this has to happen
# before any of the environment lookups below.
load_dotenv(".env")

# Module-level extractor and SQL agent shared by every chat message handler.
# NOTE(review): if OPENAI_MODEL is unset, llm_model=None is passed to
# create_agent (an earlier hard-coded default model was disabled here).
model = os.environ.get('OPENAI_MODEL')
ex = create_extractor()
ag = create_agent(llm_model=model)
# Read from the environment; not referenced by the visible code in this file.
openai_api_key = os.environ.get('OPENAI_API_KEY')




def extract_func(user_prompt: str):
    """
    Run the property extractor on a raw user prompt.

    Parameters
    ----------
    user_prompt: str
        The text the user typed into the chat.

    Returns
    -------
    Whatever ``ex.extract_chainlit`` returns for the prompt; ``main`` treats
    it as a JSON-serializable dictionary of extracted properties.
    """
    return ex.extract_chainlit(user_prompt)
def validate_func(properties: dict):  # Auto validate as much as possible
    """
    Validate extracted properties automatically via the extractor.

    Parameters
    ----------
    properties: dict
        Properties produced by ``extract_func``.

    Returns
    -------
    tuple
        ``(validated, need_input)`` as produced by ``ex.validate_chainlit``:
        the properties that could be validated without help, and the ones
        that still need a human decision.
    """
    outcome = ex.validate_chainlit(properties)
    validated, need_input = outcome
    return validated, need_input

def human_validate_func(human, validated, user_prompt):
    """
    Merge the user's validation choices into the validated properties and
    rebuild the prompt.

    Parameters
    ----------
    human - List of dictionaries mapping a property name to the value the
        user selected. An empty string means "No Update" and is skipped.
    validated - Dictionary of automatically validated properties. Each value
        is expected to be a list; the user's picks are appended to it. The
        dictionary and its lists are NOT modified: the merge happens on a copy.
    user_prompt - The user prompt

    Returns
    -------
    The cleaned prompt returned by ``ex.build_prompt_chainlit``.
    """
    # Work on a shallow copy and build new lists for touched keys so the
    # caller's dict (and any lists it shares with the extractor) stay intact.
    merged = dict(validated)
    for item in human:
        for key, value in item.items():
            if value == "":
                continue  # "No Update" was chosen for this property
            if key in merged:
                merged[key] = merged[key] + [value]
            else:
                merged[key] = [value]

    return ex.build_prompt_chainlit([merged], user_prompt)

def no_human(validated, user_prompt):
    """
    Rebuild the prompt when no human validation is needed.

    Parameters
    ----------
    validated
        Dictionary of validated properties from ``validate_func``.
    user_prompt
        The original user prompt.

    Returns
    -------
    The cleaned prompt produced by ``ex.build_prompt_chainlit``.
    """
    # The prompt builder is called with a single-element list of property dicts.
    property_sets = [validated]
    return ex.build_prompt_chainlit(property_sets, user_prompt)


def ask(text):
    """
    Call the SQL agent to get the final answer.

    Parameters
    ----------
    text
        The cleaned prompt to send to the agent.

    Returns
    -------
    tuple
        ``({"output": answer_text}, const)``. ``answer_text`` is the agent's
        ``"output"`` field, and ``const`` is the second value returned by
        ``ag.ask``, passed through unchanged.
    """
    ans, const = ag.ask(text)
    # Return the agent's own second value rather than a hard-coded placeholder.
    return {"output": ans["output"]}, const


@cl.step
async def Cleaner(text):  # just for printing
    """Chainlit step that returns *text* unchanged so it is displayed as a step."""
    echoed = text
    return echoed


@cl.step
async def LLM(cleaned_prompt):  # just for printing
    """
    Chainlit step that sends the cleaned prompt to the SQL agent.

    ``ask`` is synchronous (it calls ``ag.ask`` directly). It is wrapped with
    ``cl.make_async`` so it runs in a worker thread; calling it inline would
    block the event loop, and every other chat session, until the agent
    answers.

    Returns the ``(answer_dict, const)`` tuple produced by ``ask``.
    """
    ans, const = await cl.make_async(ask)(cleaned_prompt)
    return ans, const


@cl.step
async def Choice(text):
    """Chainlit step that returns the user's selected value unchanged, for logging."""
    chosen = text
    return chosen

@cl.step
async def Extractor(user_prompt):
    """
    Chainlit step that extracts properties from the user prompt.

    ``extract_func`` is synchronous. It is wrapped with ``cl.make_async`` so
    it runs in a worker thread instead of blocking the event loop.

    Returns whatever ``extract_func`` returns for the prompt.
    """
    extracted_values = await cl.make_async(extract_func)(user_prompt)
    return extracted_values


@cl.on_message  # this function will be called every time a user inputs a message in the UI
async def main(message: cl.Message):
    """
    Handle one chat message end to end.

    The handler works in four steps:

    1. Extract properties from the prompt.
    2. Validate them, asking the user to choose among ``top_matches`` when
       automatic validation was not enough.
    3. Rebuild the prompt from the validated values.
    4. Send the rebuilt prompt to the SQL agent.

    Parameters
    ----------
    message: cl.Message
        The incoming chat message; only ``message.content`` is used.
    """
    user_prompt = message.content  # Get the user prompt
    extracted_values = await Extractor(user_prompt)
    json_formatted = json.dumps(extracted_values, indent=4)
    # Print the extracted values in json format
    await cl.Message(author="Extractor", content=f"Extracted properties:\n```json\n{json_formatted}\n```").send()
    # Announce validation before running it, so the message describes what happens next.
    await cl.Message(author="Validator", content="Extracted properties will now be validated against the database.").send()
    # validate_func is synchronous; run it in a worker thread so a slow
    # validation does not block the event loop for other sessions.
    validated, need_input = await cl.make_async(validate_func)(extracted_values)
    if need_input:
        # If we need validation, we will ask the user to select the correct value
        for element in need_input:
            # NOTE(review): assumes the property being validated is the first key
            # and its candidates are under "top_matches"; confirm with validate_chainlit.
            key = next(iter(element))
            # Present the candidate values, plus a "No Update" option
            actions = [
                cl.Action(name=value, value=value, description=str(value))
                for value in element['top_matches']
            ]
            actions.append(cl.Action(name="No Update", value="", description="No Update"))
            res = await cl.AskActionMessage(
                author="Validator",
                content=f"Select the correct value for {element[key]}",
                actions=actions
            ).send()
            # A missing response is treated the same as "No Update" (empty string).
            selected_value = res.get("value", "") if res else ""
            element[key] = selected_value
            element.pop("top_matches")
            await Choice(selected_value)  # Logging choice
        # Get the cleaned prompt
        cleaned_prompt = human_validate_func(need_input, validated, user_prompt)
    else:
        cleaned_prompt = no_human(validated, user_prompt)
    # Print the cleaned prompt
    cleaner_message = cl.Message(author="Cleaner", content=f"New prompt is as follows:\n{cleaned_prompt}")
    await cleaner_message.send()

    # Call the SQL agent to get the final answer
    await cl.Message(content="I will now query the database for information.").send()
    ans, _const = await LLM(cleaned_prompt)
    await cl.Message(content=f"This is the final answer: \n\n{ans['output']}").send()