# sql-agent / app.py
# Last commit: 1119b15 "Adds final changes" (by Benjamona97).
import re
from pathlib import Path
from typing import List
import chainlit as cl
from dotenv import load_dotenv
from langchain.pydantic_v1 import BaseModel, Field
from langchain.tools import StructuredTool
from langchain.indexes import SQLRecordManager, index
from langchain.schema import Document
from langchain.agents import initialize_agent, AgentExecutor
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.vectorstores.chroma import Chroma
from langchain_community.document_loaders import CSVLoader
from langchain_core.prompts import ChatPromptTemplate
from langchain_openai import ChatOpenAI, OpenAIEmbeddings
from openai import AsyncOpenAI
# from modules.database.database import PostgresDB
from modules.database.sqlitedatabase import Database
"""
Here we load the environment variables, define the tools that the agent will use,
and set up the configuration the app needs to start.
"""
# Load variables from a local .env file into the process environment. This runs
# before the OpenAI clients below are constructed — presumably they read their
# credentials (e.g. OPENAI_API_KEY) from it; confirm against the deployment.
load_dotenv()
# Chunking settings for the RecursiveCharacterTextSplitter used on the CSV data.
chunk_size: int = 512
chunk_overlap: int = 50
# Embedding model used to build the Chroma vector store in process_pdfs().
embeddings_model = OpenAIEmbeddings()
# NOTE(review): not referenced anywhere else in this file; possibly unused.
openai_client = AsyncOpenAI()
# Directory scanned for *.csv files when the vector store is built at import time.
CSV_STORAGE_PATH = "./data"
def remove_triple_backticks(text):
    """
    Return ``text`` with every occurrence of a triple backtick (```) removed.

    Used to strip markdown code-fence markers from LLM-generated SQL. Only the
    backticks themselves are removed: a language tag that followed an opening
    fence (the ``sql`` in ```sql) is left in place. Occurrences are removed
    left to right without overlap, so four backticks leave one behind.
    """
    # A literal substring removal; no regular expression is needed.
    return text.replace("```", "")
def process_pdfs(pdf_storage_path: str):
    """
    Load, chunk and index every CSV file in a directory.

    Despite the name, this reads ``*.csv`` files (not PDFs). Each file is loaded
    with ``CSVLoader`` and split with a ``RecursiveCharacterTextSplitter`` that
    uses the module-level ``chunk_size`` and ``chunk_overlap`` settings. The chunks
    are embedded into a Chroma vector store with ``embeddings_model``. They are
    also passed through LangChain's ``index`` helper, which is backed by a
    ``SQLRecordManager`` stored in ``record_manager_cache.sql``. The indexing
    statistics are printed.

    Args:
        pdf_storage_path: Directory containing the CSV files.

    Returns:
        The Chroma vector store holding the chunks.
    """
    csv_directory = Path(pdf_storage_path)
    # Use the shared module-level settings instead of a hard-coded overlap.
    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=chunk_size, chunk_overlap=chunk_overlap
    )
    docs: List[Document] = []
    # Sorted so the chunk order does not depend on the filesystem's listing order.
    for csv_path in sorted(csv_directory.glob("*.csv")):
        loader = CSVLoader(file_path=str(csv_path))
        docs += text_splitter.split_documents(loader.load())

    # NOTE(review): from_documents() already embeds and stores every chunk.
    # index() below may store the same chunks again (presumably whenever the
    # record manager has no record of them, e.g. a fresh cache file), which
    # could duplicate search results. Confirm whether both steps are needed.
    documents_search = Chroma.from_documents(docs, embeddings_model)

    # Namespace under which the SQLRecordManager keeps its records.
    namespace = "chromadb/my_documents"
    record_manager = SQLRecordManager(
        namespace, db_url="sqlite:///record_manager_cache.sql"
    )
    record_manager.create_schema()
    index_result = index(
        docs,
        record_manager,
        documents_search,
        cleanup="incremental",
        source_id_key="source",
    )
    print(f"Indexing stats: {index_result}")
    return documents_search


# Built once at import time; research_database() queries this store.
doc_search = process_pdfs(CSV_STORAGE_PATH)
"""
Execute SQL query tool: function definition along with its input schema.
"""
def execute_sql(query: str) -> str:
    """
    Execute a SQLite query against ``./db/mydatabase.db``.

    Markdown code fences around the query are removed before execution. This
    includes a language tag on an opening fence (the ``sql`` in ```sql), which
    would otherwise be left in front of the statement as invalid SQL. Leading and
    trailing whitespace is also stripped.

    Args:
        query: SQL text, possibly wrapped in a markdown code fence.

    Returns:
        The value returned by ``Database.execute_query`` (concatenated with a str
        here, so presumably a str), followed by the executed query rendered as a
        fenced ``sql`` markdown block.
    """
    db = Database("./db/mydatabase.db")
    # NOTE(review): a new Database is created and connected on every call and is
    # never explicitly closed here; confirm whether Database offers a close().
    db.connect()
    # Turn "```sql\n" (any tag directly before a line break) into a bare "```",
    # so removing the backticks leaves only the SQL statement.
    untagged = re.sub(r"```[ \t]*[A-Za-z0-9_+-]+(?=[ \t]*\r?\n)", "```", query)
    cleaned_query = remove_triple_backticks(untagged).strip()
    results = db.execute_query(cleaned_query)
    # A fenced block needs line breaks around its content: "```sql{q}```" is not a
    # fence and would render "sql" glued to the query as inline code.
    return results + f"\nQuery used:\n```sql\n{cleaned_query}\n```"
class ExecuteSqlToolInput(BaseModel):
    # Single argument the agent supplies to the tool: the SQL text that is
    # handed to execute_sql().
    query: str = Field(
        description="A SQLite query to be executed agains the database"
    )


# Structured LangChain tool that exposes execute_sql() to the agent.
execute_sql_tool = StructuredTool(
    name="Execute SQL",
    description="useful for when you need to execute SQL queries against the database. Always use a clause LIMIT 10",
    func=execute_sql,
    args_schema=ExecuteSqlToolInput,
)
"""
Research database tool: function definition along with its input schema.
"""
def research_database(user_request: str) -> str:
    """
    Return the stored chunks most similar to ``user_request``.

    Retrieves the top 30 matches from the module-level ``doc_search`` vector
    store, prints each one to stdout with a 1-based position, and returns their
    contents joined by blank lines.
    """
    retriever = doc_search.as_retriever(search_kwargs={"k": 30})
    matches = retriever.invoke(user_request)
    contents = [match.page_content for match in matches]
    for position, content in enumerate(contents, start=1):
        print(f"{position}. {content}")
    return "\n\n".join(contents)
class ResearchDatabaseToolInput(BaseModel):
    # Single argument the agent supplies to the tool: free text forwarded to
    # research_database() as its search request.
    user_request: str = Field(
        description="The user query to search against the table definitions for matches."
    )


# Structured LangChain tool that exposes research_database() to the agent.
research_database_tool = StructuredTool(
    name="Search db info",
    description="Search for database table definitions so you can have context for building SQL queries. The queries needs to be SQLite compatible.",
    func=research_database,
    args_schema=ResearchDatabaseToolInput,
)
@cl.on_chat_start
def start():
    """
    Build the SQL agent for a new chat session and store it in the user session.

    The agent is created with ``initialize_agent`` using the "Execute SQL" and
    "Search db info" tools and a GPT-4 chat model at temperature 0. It is saved
    under the ``"agent"`` key, where the message handler reads it.
    """
    tools = [execute_sql_tool, research_database_tool]
    llm = ChatOpenAI(model="gpt-4", temperature=0, verbose=True)
    # initialize_agent() has no ``prompt`` parameter. Extra keyword arguments are
    # forwarded to AgentExecutor, which has no ``prompt`` field, so the
    # instructions passed that way never reached the model. The default
    # (zero-shot ReAct) agent takes custom instructions as its prompt ``prefix``.
    # The agent appends the tool list, the format instructions and the user's
    # question itself, so the prefix must contain no ``{...}`` placeholders.
    prefix = (
        "You are a world class data scientist specialised in SQLite. Based on the user query, "
        "use your tools to do the job. Usually you would start by analyzing which SQL queries "
        "the user wants to build, using \"Search db info\" to find the relevant table "
        "definitions in your knowledge base, and then run the query with \"Execute SQL\".\n"
        "Remember, you are building SQLite compatible queries. If you don't know the answer, "
        "don't make anything up. Always ask for feedback. One last detail: always run the "
        "queries with LIMIT 10 and add the SQL query as markdown to the final answer so the "
        "user knows what SQL query was used for the job and can copy it for further use.\n"
        "REMEMBER TO ALWAYS GENERATE SQLITE COMPATIBLE QUERIES.\n\n"
        "You have access to the following tools:"
    )
    agent = initialize_agent(
        tools=tools,
        llm=llm,
        agent_kwargs={"prefix": prefix},
        handle_parsing_errors=True,
    )
    cl.user_session.set("agent", agent)
@cl.on_message
async def main(message: cl.Message):
    """Run this session's agent on the incoming chat message and send back its answer."""
    executor: AgentExecutor = cl.user_session.get("agent")
    # Stream the agent's intermediate steps to the Chainlit UI.
    ui_handler = cl.AsyncLangchainCallbackHandler()
    answer = await executor.arun(message.content, callbacks=[ui_handler])
    reply = cl.Message(content=answer)
    await reply.send()