File size: 5,667 Bytes
13ebe63
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
import re
from pathlib import Path
from typing import List


import chainlit as cl
from dotenv import load_dotenv
from langchain.pydantic_v1 import BaseModel, Field
from langchain.tools import StructuredTool
from langchain.indexes import SQLRecordManager, index
from langchain.schema import Document
from langchain.agents import initialize_agent, AgentExecutor
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.vectorstores.chroma import Chroma
from langchain_community.document_loaders import CSVLoader
from langchain_core.prompts import ChatPromptTemplate
from langchain_openai import ChatOpenAI, OpenAIEmbeddings
from openai import AsyncOpenAI

# from modules.database.database import PostgresDB
from modules.database.sqlitedatabase import Database

"""
Here we define some environment variables and the tools that the agent will use.
Along with some configuration for the app to start.
"""
load_dotenv()

chunk_size = 512
chunk_overlap = 50

embeddings_model = OpenAIEmbeddings()
openai_client = AsyncOpenAI()

CSV_STORAGE_PATH = "./data"


# Compiled once at import time; matches each non-overlapping occurrence of
# three consecutive backticks, scanning left to right.
_TRIPLE_BACKTICKS = re.compile(r"```")


def remove_triple_backticks(text):
    """Return *text* with every occurrence of three consecutive backticks deleted.

    Only complete triple-backtick sequences are removed (non-overlapping, left
    to right). Any leftover single or double backticks are kept, as is the
    language tag that follows an opening Markdown fence (e.g. ``sql``).
    """
    return _TRIPLE_BACKTICKS.sub("", text)


def process_pdfs(pdf_storage_path: str):
    """Load every CSV file in a directory, chunk it, and index it in Chroma.

    Despite the name, only ``*.csv`` files are read (via ``CSVLoader``). The
    name is kept so existing callers keep working.

    Args:
        pdf_storage_path: Directory scanned (non-recursively) for ``*.csv`` files.

    Returns:
        The ``Chroma`` vector store built from the chunked CSV documents.
    """
    csv_directory = Path(pdf_storage_path)
    docs: List[Document] = []
    # Take both values from the module-level settings; the overlap used to be
    # hard-coded here, silently ignoring ``chunk_overlap``.
    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=chunk_size, chunk_overlap=chunk_overlap
    )

    # Sort so files (and therefore chunks) are processed in a stable order
    # regardless of the filesystem's directory-listing order.
    for csv_path in sorted(csv_directory.glob("*.csv")):
        loader = CSVLoader(file_path=str(csv_path))
        docs.extend(text_splitter.split_documents(loader.load()))

    # NOTE(review): ``from_documents`` already embeds and stores every chunk.
    # ``index`` below then adds each chunk the record manager has not seen yet
    # (e.g. on the first run) a second time, under its own hash-based id, so
    # the store can end up holding duplicates. The store is also in-memory
    # while the record manager's SQLite file persists, so on later runs
    # ``index`` skips everything and saves no work. Consider a persistent
    # Chroma that is populated only through ``index``.
    documents_search = Chroma.from_documents(docs, embeddings_model)

    # Namespace under which this store's entries are tracked in the record manager.
    namespace = "chromadb/my_documents"
    record_manager = SQLRecordManager(
        namespace, db_url="sqlite:///record_manager_cache.sql"
    )
    record_manager.create_schema()

    # "incremental" cleanup groups chunks by their ``source`` metadata
    # (presumably the CSV file path set by CSVLoader) and removes stale chunks
    # of sources that were re-indexed in this run.
    index_result = index(
        docs,
        record_manager,
        documents_search,
        cleanup="incremental",
        source_id_key="source",
    )

    print(f"Indexing stats: {index_result}")

    return documents_search


# Build the vector store once, at import time, from the CSV files in
# CSV_STORAGE_PATH; research_database queries this module-level object.
doc_search = process_pdfs(CSV_STORAGE_PATH)

"""
Execute SQL query tool definition along schemas.
"""


def execute_sql(query: str) -> str:
    """
    Execute a SQLite query against the database and report the query that was run.

    Markdown code fences are removed from ``query`` before execution. Every
    triple backtick is deleted, and a leading ``sql``/``sqlite`` fence language
    tag is dropped, so a fenced model reply runs as plain SQL.

    Args:
        query: SQLite statement, possibly wrapped in a Markdown code fence.

    Returns:
        The result of ``Database.execute_query`` rendered as text, followed by
        the executed query in a fenced ``sql`` Markdown block.
    """
    db = Database("./db/mydatabase.db")
    db.connect()
    # NOTE(review): nothing here closes the connection opened above; if
    # ``Database`` exposes a close/disconnect method, call it in a ``finally``.
    # NOTE(review): the query is LLM-generated and executed verbatim — consider
    # a read-only connection so it cannot modify the database.

    # Removing only the backticks from "```sql\nSELECT ...\n```" leaves
    # "sql\nSELECT ..." behind, which SQLite rejects, so strip the fence's
    # language tag too. No SQLite statement starts with the word sql/sqlite.
    cleaned_query = remove_triple_backticks(query).strip()
    cleaned_query = re.sub(
        r"^sql(?:ite)?\s+", "", cleaned_query, flags=re.IGNORECASE
    )

    results = db.execute_query(cleaned_query)

    # Interpolate rather than use ``+`` so a non-string result cannot raise
    # TypeError. Put the query on its own lines so the Markdown fence renders
    # as a code block tagged ``sql`` instead of "```sqlSELECT ...```".
    return f"{results}\nQuery used:\n```sql\n{cleaned_query}\n```"


class ExecuteSqlToolInput(BaseModel):
    # Argument schema for the "Execute SQL" tool. There is deliberately no class
    # docstring, because pydantic would add it to the generated JSON schema.
    # The field description is LLM-facing text; it previously read "agains".
    query: str = Field(
        description="A SQLite query to be executed against the database")


# Tool wrapping execute_sql. The name and description below are the text the
# agent sees when deciding which tool to call.
execute_sql_tool = StructuredTool(
    name="Execute SQL",
    description=(
        "useful for when you need to execute SQL queries against the database. "
        "Always use a clause LIMIT 10"
    ),
    func=execute_sql,
    args_schema=ExecuteSqlToolInput,
)

"""
Research database tool definition along schemas.
"""


def research_database(user_request: str) -> str:
    """Return the stored chunks most similar to *user_request*.

    Up to 30 chunks are fetched from the module-level ``doc_search`` vector
    store. Each one is echoed to stdout with a 1-based index (debug output),
    and their texts are returned joined by blank lines.
    """
    hits = doc_search.as_retriever(search_kwargs={"k": 30}).invoke(user_request)

    for rank, hit in enumerate(hits, start=1):
        print(f"{rank}. {hit.page_content}")

    return "\n\n".join(hit.page_content for hit in hits)


class ResearchDatabaseToolInput(BaseModel):
    # Argument schema for the "Search db info" tool. This is kept docstring-free
    # on purpose, since a class docstring would change the pydantic JSON schema.
    user_request: str = Field(
        description=(
            "The user query to search against "
            "the table definitions for matches."
        )
    )


# Tool wrapping research_database. The description is LLM-facing text shown
# to the agent; its grammar is fixed here ("The queries needs" -> "need").
research_database_tool = StructuredTool(
    func=research_database,
    name="Search db info",
    description=(
        "Search for database table definitions so you can have context for "
        "building SQL queries. The queries need to be SQLite compatible."
    ),
    args_schema=ResearchDatabaseToolInput,
)


@cl.on_chat_start
def start():
    """Build the SQL agent for a new chat session and store it in the user session.

    The agent is a zero-shot ReAct agent over two tools: one that searches the
    indexed table definitions and one that executes SQLite queries. It is
    stored under the ``"agent"`` key for the message handler to use.
    """
    from langchain.agents import AgentType

    tools = [execute_sql_tool, research_database_tool]

    llm = ChatOpenAI(model="gpt-4", temperature=0, verbose=True)

    # ``initialize_agent`` has no ``prompt`` parameter. Extra keyword arguments
    # are forwarded to ``AgentExecutor``, which does not use them, so the
    # previous ``prompt=ChatPromptTemplate(...)`` never reached the model.
    # For the zero-shot ReAct agent, custom instructions belong in
    # ``agent_kwargs["prefix"]``; the agent itself appends the tool list, the
    # ReAct format instructions and the user question. The prefix becomes part
    # of a format-string template, so it must not contain literal braces.
    # Tools are referenced by their real names; the old text used
    # "execute_sql"/"research_database", which are not valid tool names.
    prefix = (
        "You are a SQLite world class data scientist. Based on the user query, "
        "use your tools to do the job. Usually you would start by using the "
        f"'{research_database_tool.name}' tool to find the table definitions "
        "relevant to the query, and then build the SQL query and run it with "
        f"the '{execute_sql_tool.name}' tool.\n\n"
        "Remember, you are building SQLite compatible queries. If you don't know "
        "the answer don't make anything up. Always ask for feedback. One last "
        "detail: always run the queries with LIMIT 10 and add the SQL query as "
        "markdown to the final answer so the user knows what SQL query was used "
        "for the job and can copy it for further use.\n\n"
        "REMEMBER TO ALWAYS GENERATE SQLITE COMPATIBLE QUERIES.\n\n"
        "You have access to the following tools:"
    )

    agent = initialize_agent(
        tools=tools,
        llm=llm,
        agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION,
        agent_kwargs={"prefix": prefix},
        handle_parsing_errors=True,
    )

    cl.user_session.set("agent", agent)


@cl.on_message
async def main(message: cl.Message):
    """Answer an incoming chat message with the session's SQL agent.

    The agent stored by ``start`` is run asynchronously on the message text.
    Chainlit's LangChain callback handler is attached, and the agent's final
    answer is sent back as a new message.
    """
    executor: AgentExecutor = cl.user_session.get("agent")
    answer = await executor.arun(
        message.content, callbacks=[cl.AsyncLangchainCallbackHandler()]
    )

    reply = cl.Message(content=answer)
    await reply.send()