Benjamona97 commited on
Commit
13ebe63
0 Parent(s):

Add application file

Browse files
.gitignore ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ .venv
2
+ .chainlit
3
+ *__pycache__
4
+ .idea
5
+ .env
6
+ record_manager_cache.sql
7
+ storage
8
+ .DS_Store
9
+ data
10
+ *.sql
Dockerfile ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
FROM python:3.11

WORKDIR /code

# Install dependencies first so this layer is cached between code changes.
COPY ./requirements.txt /code/requirements.txt

RUN pip install --no-cache-dir --upgrade -r /code/requirements.txt

COPY . .

# The application is a Chainlit script (app.py); there is no app/main.py ASGI
# module for uvicorn to import, so start it through the Chainlit CLI instead.
CMD ["chainlit", "run", "app.py", "--host", "0.0.0.0", "--port", "7860"]
app.py ADDED
@@ -0,0 +1,187 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import re
3
+ from pathlib import Path
4
+ from typing import List
5
+
6
+
7
+ import chainlit as cl
8
+ from dotenv import load_dotenv
9
+ from langchain.pydantic_v1 import BaseModel, Field
10
+ from langchain.tools import StructuredTool
11
+ from langchain.indexes import SQLRecordManager, index
12
+ from langchain.schema import Document
13
+ from langchain.agents import initialize_agent, AgentExecutor
14
+ from langchain.text_splitter import RecursiveCharacterTextSplitter
15
+ from langchain.vectorstores.chroma import Chroma
16
+ from langchain_community.document_loaders import CSVLoader
17
+ from langchain_core.prompts import ChatPromptTemplate
18
+ from langchain_openai import ChatOpenAI, OpenAIEmbeddings
19
+ from openai import AsyncOpenAI
20
+
21
+ # from modules.database.database import PostgresDB
22
+ from modules.database.sqlitedatabase import Database
23
+
24
# Environment variables, shared clients and start-up configuration used by
# the agent tools defined further down in this module.
load_dotenv()

# Settings for the text splitter applied to the loaded CSV documents.
chunk_size = 512
chunk_overlap = 50

# Clients are created after load_dotenv() has run.
embeddings_model = OpenAIEmbeddings()
openai_client = AsyncOpenAI()

# Directory scanned for *.csv files at start-up.
CSV_STORAGE_PATH = "./data"
37
+
38
+
39
def remove_triple_backticks(text):
    """Strip Markdown code fences from ``text``.

    Every run of three backticks is removed. When an opening fence is
    immediately followed by a SQL language tag on the same line (``sql``,
    ``sqlite`` or ``sqlite3``, any case) and a line break, the tag and that
    line break are removed too, so the tag is not left behind as a stray
    token in front of the query.

    Args:
        text: Text that may contain Markdown code fences.

    Returns:
        ``text`` without the fences; every other character is preserved.
    """
    # The tag is only consumed when a newline follows it, so a fenced value
    # that merely starts with "sql" (e.g. "sql_table") is left intact.
    fence = r"```(?:(?:sqlite3?|sql)[ \t]*\r?\n)?"
    return re.sub(fence, "", text, flags=re.IGNORECASE)
43
+
44
+
45
def process_pdfs(pdf_storage_path: str):
    """Load the CSV files of a directory into a Chroma vector store.

    Despite its name, this function reads ``*.csv`` files rather than PDFs;
    the name and parameter are kept unchanged for existing callers.

    Each CSV file directly inside the directory is loaded with ``CSVLoader``,
    split into chunks using the module-level ``chunk_size`` and
    ``chunk_overlap`` settings, stored in a Chroma collection and registered
    with a SQLite-backed record manager through ``index``.

    Args:
        pdf_storage_path: Directory that is scanned (non-recursively) for
            ``*.csv`` files.

    Returns:
        The Chroma vector store holding the chunked documents.
    """
    csv_directory = Path(pdf_storage_path)
    # Use the module-level overlap setting instead of a second hard-coded 50.
    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=chunk_size, chunk_overlap=chunk_overlap)

    docs = []  # type: List[Document]
    for csv_path in csv_directory.glob("*.csv"):
        documents = CSVLoader(file_path=str(csv_path)).load()
        docs += text_splitter.split_documents(documents)

    documents_search = Chroma.from_documents(docs, embeddings_model)

    namespace = "chromadb/my_documents"
    record_manager = SQLRecordManager(
        namespace, db_url="sqlite:///record_manager_cache.sql"
    )
    record_manager.create_schema()

    # NOTE(review): the chunks were already added by from_documents() above;
    # confirm whether this index() pass inserts them a second time when the
    # record-manager cache file does not know them yet.
    index_result = index(
        docs,
        record_manager,
        documents_search,
        cleanup="incremental",
        source_id_key="source",
    )

    print(f"Indexing stats: {index_result}")

    return documents_search


# Built once at import time; the research tool queries this store.
doc_search = process_pdfs(CSV_STORAGE_PATH)
78
+
79
+ """
80
+ Execute SQL query tool definition along with its input schema.
81
+ """
82
+
83
+
84
def execute_sql(query: str) -> str:
    """
    Execute a SQLite query against the database.

    Markdown code fences are stripped from the query before it is run.

    Args:
        query: The SQLite statement to execute; may still contain triple
            backticks.

    Returns:
        The query result as formatted by ``Database.execute_query``, followed
        by the executed query inside a Markdown ``sql`` code fence.
    """
    cleaned_query = remove_triple_backticks(query)

    db = Database("./db/mydatabase.db")
    db.connect()
    try:
        results = db.execute_query(cleaned_query)
    finally:
        # A new connection is opened on every call, so always release it,
        # even when the query fails.
        db.close()

    # The language tag needs its own line; otherwise it fuses with the query
    # text ("```sqlSELECT ...") and the fence does not render as SQL.
    return results + f"\nQuery used:\n```sql\n{cleaned_query.strip()}\n```"
97
+
98
+
99
class ExecuteSqlToolInput(BaseModel):
    # Argument schema for the "Execute SQL" tool: a single query string.
    query: str = Field(
        description="A SQLite query to be executed agains the database",
    )
102
+
103
+
104
# Tool wrapper that exposes execute_sql() to the agent.
execute_sql_tool = StructuredTool(
    name="Execute SQL",
    description=(
        "useful for when you need to execute SQL queries against the database. "
        "Always use a clause LIMIT 10"
    ),
    func=execute_sql,
    args_schema=ExecuteSqlToolInput,
)
110
+
111
+ """
112
+ Research database tool definition along with its input schema.
113
+ """
114
+
115
+
116
def research_database(user_request: str) -> str:
    """
    Look up table definitions that match the user request.

    Queries the module-level ``doc_search`` store for up to 30 documents,
    prints each hit with its 1-based position, and returns the page contents
    joined by blank lines.
    """
    retriever = doc_search.as_retriever(search_kwargs={"k": 30})
    matches = retriever.invoke(user_request)

    # Echo every hit to stdout before building the combined answer.
    for position, match in enumerate(matches, start=1):
        print(f"{position}. {match.page_content}")

    return "\n\n".join(match.page_content for match in matches)
132
+
133
+
134
class ResearchDatabaseToolInput(BaseModel):
    # Argument schema for the "Search db info" tool: the text to search for.
    user_request: str = Field(
        description="The user query to search against the table definitions for matches.",
    )
137
+
138
+
139
# Tool wrapper that exposes research_database() to the agent.
research_database_tool = StructuredTool(
    name="Search db info",
    description=(
        "Search for database table definitions so you can have context for building SQL queries. "
        "The queries needs to be SQLite compatible."
    ),
    func=research_database,
    args_schema=ResearchDatabaseToolInput,
)
145
+
146
+
147
@cl.on_chat_start
def start():
    """Build the agent for a new chat and store it in the user session."""
    # NOTE(review): the prompt refers to the tools as execute_sql and
    # research_database, while the tool objects are registered as
    # "Execute SQL" and "Search db info" - confirm the mismatch is intended.
    prompt = ChatPromptTemplate.from_template(
        """
        You are a SQLite world class data scientist, based on user query
        use your tools to do the job. Usually you would start by analyzing
        for possible SQL queries the user wants to build based on your knowledge base.
        Remember your tools are:

        - execute_sql (bring back the results as of running the query against the database)
        - research_database (search for table definitions so you can build a SQLite Query)

        Remember, you are building SQLite compatible queries. If you don't know the answer don't
        make anything up. Always ask for feedback. One last detail: always run the querys with LIMIT 10 and add
        the SQL query as markdown to the final answer so the user knows what SQL query was used for the job and
        can copy it for further use.

        REMEMBER TO GENERATE ALWAYS SQLITE COMPATIBLE QUERIES.

        User query: {input}
        """
    )

    llm = ChatOpenAI(model="gpt-4", temperature=0, verbose=True)

    # NOTE(review): confirm initialize_agent actually consumes the `prompt`
    # keyword; it is passed through unchanged here.
    agent = initialize_agent(
        tools=[execute_sql_tool, research_database_tool],
        prompt=prompt,
        llm=llm,
        handle_parsing_errors=True,
    )

    cl.user_session.set("agent", agent)
178
+
179
+
180
@cl.on_message
async def main(message: cl.Message):
    """Run the session's agent on the incoming message and send its answer."""
    executor = cl.user_session.get("agent")  # type: AgentExecutor
    handler = cl.AsyncLangchainCallbackHandler()

    answer = await executor.arun(message.content, callbacks=[handler])

    await cl.Message(content=answer).send()
chainlit.md ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ To start, you can try to ask a simple question:
2
+
3
+ ```"I need all customers please!"```
4
+
5
+ #### Then you can expand each cell execution to watch the agent's work and analyze each step through the process.
modules/database/__init__.py ADDED
File without changes
modules/database/database.py ADDED
@@ -0,0 +1,71 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+
3
+ import psycopg2
4
+ from tabulate import tabulate
5
+
6
+
7
class PostgresDB:
    """
    A class to manage postgres connections and queries.

    Instances can be used as context managers; leaving the ``with`` block
    closes the cursor and the connection if they were opened.
    """

    def __init__(self):
        # Both stay None until connect_with_url() is called.
        self.conn = None
        self.cur = None

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Same clean-up as an explicit close(); exceptions are not suppressed.
        self.close()

    def connect_with_url(self, url):
        """Open a connection for ``url`` and create a cursor on it."""
        self.conn = psycopg2.connect(url)
        self.cur = self.conn.cursor()

    def close(self):
        """Close the cursor and the connection, if they exist."""
        if self.cur:
            self.cur.close()
        if self.conn:
            self.conn.close()

    def _fetch_rows_as_dicts(self, sql):
        """Execute ``sql`` and return all rows as column-name -> value dicts."""
        self.cur.execute(sql)
        columns = [desc[0] for desc in self.cur.description]
        return [dict(zip(columns, row)) for row in self.cur.fetchall()]

    def run_sql(self, sql) -> str:
        """
        Run a SQL query against the postgres database.
        Returns the rows as a JSON array of objects (indented by 4 spaces).
        """
        return json.dumps(self._fetch_rows_as_dicts(sql), indent=4)

    def run_sql_to_markdown(self, sql) -> str:
        """
        Run a SQL query against the postgres database.
        Prints and returns the rows as a markdown table.
        """
        markdown_table = self.to_markdown(self._fetch_rows_as_dicts(sql))
        print(markdown_table)
        return markdown_table

    @staticmethod
    def to_markdown(data) -> str:
        """
        Convert a list of dictionaries to a markdown (pipe) table.
        """
        return tabulate(data, headers="keys", tablefmt="pipe")
modules/database/sqlitedatabase.py ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sqlite3
2
+ import tabulate
3
+
4
class Database:
    """Thin wrapper around a SQLite connection that renders results as tables."""

    def __init__(self, uri):
        # Path (or other sqlite3 database target) handed to sqlite3.connect().
        self.uri = uri
        # Stays None until connect() is called and again after close().
        self.connection = None

    def connect(self):
        """Open the SQLite connection for ``self.uri``."""
        self.connection = sqlite3.connect(self.uri)

    def execute_query(self, query):
        """Run ``query`` and return its rows as a pipe-format table string.

        Column names are used as headers. Statements that do not produce a
        result set are rendered with no headers and no rows.
        """
        cursor = self.connection.cursor()
        try:
            cursor.execute(query)
            rows = cursor.fetchall()
            # Capture the metadata while the cursor is still open; it is
            # None for statements without a result set (e.g. INSERT).
            description = cursor.description
        finally:
            cursor.close()
        headers = [column[0] for column in description or ()]
        return tabulate.tabulate(rows, headers, tablefmt="pipe")

    def close(self):
        """Close the connection if one is open; safe to call repeatedly."""
        if self.connection is not None:
            self.connection.close()
            self.connection = None
requirements.txt ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ openai==1.10.0
2
+ psycopg==3.1.14
3
+ psycopg2-binary==2.9.9
4
+ python-dotenv==1.0.0
5
+ tiktoken==0.5.2
6
+ python-dotenv==1.0.0
7
+ sqlalchemy[asyncio]
8
+ chainlit==1.0.200
9
+ langchain==0.1.4
10
+ langchain-community==0.0.16
11
+ langchain-openai==0.0.5
12
+ asyncpg
13
+ db-dtypes==1.2.0
14
+ tabulate==0.9.0
15
+ chromadb==0.4.22