import re
from pathlib import Path
from typing import List
import chainlit as cl
from dotenv import load_dotenv
from langchain.pydantic_v1 import BaseModel, Field
from langchain.tools import StructuredTool
from langchain.indexes import SQLRecordManager, index
from langchain.schema import Document, SystemMessage
from langchain.agents import AgentType, initialize_agent, AgentExecutor
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.vectorstores.chroma import Chroma
from langchain_community.document_loaders import CSVLoader
from langchain_openai import ChatOpenAI, OpenAIEmbeddings
from openai import AsyncOpenAI
# from modules.database.database import PostgresDB
from modules.database.sqlitedatabase import Database
"""
Here we define some environment variables and the tools that the agent will use.
Along with some configuration for the app to start.
"""
load_dotenv()
chunk_size = 512
chunk_overlap = 50
embeddings_model = OpenAIEmbeddings()
openai_client = AsyncOpenAI()
CSV_STORAGE_PATH = "./data"
def remove_triple_backticks(text: str) -> str:
    # Strip markdown code fences (``` or ```sql) that the model sometimes
    # wraps around generated queries, leaving only the raw SQL.
    return re.sub(r"```(?:sql)?", "", text)
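# For example (hypothetical model output):
#   remove_triple_backticks("```sql\nSELECT 1;\n```")  ->  "\nSELECT 1;\n"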
def process_csvs(csv_storage_path: str):
    csv_directory = Path(csv_storage_path)
    docs: List[Document] = []
    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=chunk_size, chunk_overlap=chunk_overlap)

    # Load every CSV in the directory and split it into chunks.
    for csv_path in csv_directory.glob("*.csv"):
        loader = CSVLoader(file_path=str(csv_path))
        documents = loader.load()
        docs += text_splitter.split_documents(documents)

    documents_search = Chroma.from_documents(docs, embeddings_model)

    # The record manager tracks what has already been indexed, so re-runs
    # only write new or changed documents into the vector store.
    namespace = "chromadb/my_documents"
    record_manager = SQLRecordManager(
        namespace, db_url="sqlite:///record_manager_cache.sql"
    )
    record_manager.create_schema()

    index_result = index(
        docs,
        record_manager,
        documents_search,
        cleanup="incremental",
        source_id_key="source",
    )
    print(f"Indexing stats: {index_result}")

    return documents_search
doc_search = process_csvs(CSV_STORAGE_PATH)
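# index() returns a summary dict, e.g.
#   {'num_added': 42, 'num_updated': 0, 'num_skipped': 0, 'num_deleted': 0}
# (counts here are illustrative). With cleanup="incremental", restarting the
# app skips chunks that were already indexed under the same namespace.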
"""
Execute SQL query tool definition along schemas.
"""
def execute_sql(query: str) -> str:
    """
    Execute a SQLite query against the database. Markdown code fences and
    backticks are stripped from the query before it runs.
    """
    db = Database("./db/mydatabase.db")
    db.connect()
    cleaned_query = remove_triple_backticks(query)
    results = db.execute_query(cleaned_query)
    return f"{results}\nQuery used:\n```sql\n{cleaned_query}\n```"
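# A minimal sketch of a direct call (the table name is hypothetical):
#   execute_sql("```sql\nSELECT * FROM customers LIMIT 10;\n```")
# returns the rows plus the cleaned query echoed back as a markdown block.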
class ExecuteSqlToolInput(BaseModel):
    query: str = Field(
        description="A SQLite query to be executed against the database")


execute_sql_tool = StructuredTool(
    func=execute_sql,
    name="Execute SQL",
    description="Useful when you need to execute SQL queries against the database. Always use a LIMIT 10 clause.",
    args_schema=ExecuteSqlToolInput
)
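# StructuredTool validates its input against the schema, so the tool can be
# exercised outside the agent as well, e.g.:
#   execute_sql_tool.run({"query": "SELECT 1;"})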
"""
Research database tool definition along schemas.
"""
def research_database(user_request: str) -> str:
    """
    Searches for table definitions matching the user request.
    """
    search_kwargs = {"k": 30}
    retriever = doc_search.as_retriever(search_kwargs=search_kwargs)

    def format_docs(docs):
        # Log the retrieved chunks for debugging, then join them for the agent.
        for i, doc in enumerate(docs):
            print(f"{i+1}. {doc.page_content}")
        return "\n\n".join([d.page_content for d in docs])

    results = retriever.invoke(user_request)
    return format_docs(results)
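# E.g. research_database("monthly sales per region") returns the 30 most
# similar table-definition chunks, which the agent then uses to write SQL.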
class ResearchDatabaseToolInput(BaseModel):
    user_request: str = Field(
        description="The user query to match against the table definitions.")


research_database_tool = StructuredTool(
    func=research_database,
    name="Search db info",
    description="Search for database table definitions to gather context for building SQL queries. The queries need to be SQLite-compatible.",
    args_schema=ResearchDatabaseToolInput
)
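# Same pattern as execute_sql_tool; a direct invocation sketch:
#   research_database_tool.run({"user_request": "customer orders tables"})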
@cl.on_chat_start
def start():
    tools = [execute_sql_tool, research_database_tool]
    llm = ChatOpenAI(model="gpt-4", temperature=0, verbose=True)
    # initialize_agent has no `prompt` parameter; for an OpenAI-functions
    # agent the instructions are passed as a system message instead, and the
    # user input is supplied separately on each call.
    system_message = SystemMessage(
        content="""
        You are a world-class SQLite data scientist; based on the user query,
        use your tools to do the job. Usually you would start by analyzing
        which SQL queries the user wants to build based on your knowledge base.
        Remember, your tools are:
        - execute_sql (brings back the results of running the query against the database)
        - research_database (searches for table definitions so you can build a SQLite query)
        Remember, you are building SQLite-compatible queries. If you don't know the
        answer, don't make anything up. Always ask for feedback. One last detail:
        always run the queries with LIMIT 10, and add the SQL query as markdown to
        the final answer so the user knows what SQL query was used for the job and
        can copy it for further use.
        REMEMBER TO ALWAYS GENERATE SQLITE-COMPATIBLE QUERIES.
        """
    )
    agent = initialize_agent(
        tools=tools,
        llm=llm,
        agent=AgentType.OPENAI_FUNCTIONS,
        agent_kwargs={"system_message": system_message},
        handle_parsing_errors=True,
    )
    cl.user_session.set("agent", agent)
@cl.on_message
async def main(message: cl.Message):
    agent = cl.user_session.get("agent")  # type: AgentExecutor
    res = await agent.arun(
        message.content, callbacks=[cl.AsyncLangchainCallbackHandler()]
    )
    await cl.Message(content=res).send()
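# To start the app locally (assuming this file is named app.py):
#   chainlit run app.py -w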