import os
import re
from pathlib import Path
from typing import List

import chainlit as cl
from dotenv import load_dotenv
from langchain.pydantic_v1 import BaseModel, Field
from langchain.tools import StructuredTool
from langchain.indexes import SQLRecordManager, index
from langchain.schema import Document
from langchain.agents import initialize_agent, AgentExecutor
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.vectorstores.chroma import Chroma
from langchain_community.document_loaders import CSVLoader
from langchain_core.prompts import ChatPromptTemplate
from langchain_openai import ChatOpenAI, OpenAIEmbeddings
from openai import AsyncOpenAI

# from modules.database.database import PostgresDB
from modules.database.sqlitedatabase import Database
""" | |
Here we define some environment variables and the tools that the agent will use. | |
Along with some configuration for the app to start. | |
""" | |

load_dotenv()

chunk_size = 512
chunk_overlap = 50

embeddings_model = OpenAIEmbeddings()
openai_client = AsyncOpenAI()

CSV_STORAGE_PATH = "./data"


def remove_triple_backticks(text):
    # Use a regular expression to replace all occurrences of triple backticks
    # with an empty string.
    cleaned_text = re.sub(r"```", "", text)
    return cleaned_text
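
# Illustrative example (mine, not from the original source):
#   remove_triple_backticks("```sql\nSELECT 1\n```") -> "sql\nSELECT 1\n"
# Note that a leading markdown language tag such as "sql" survives, since only
# the backticks themselves are removed.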


def process_csvs(csv_storage_path: str):
    csv_directory = Path(csv_storage_path)
    docs = []  # type: List[Document]
    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=chunk_size, chunk_overlap=chunk_overlap
    )

    # Split every CSV in the storage path into chunks and embed them in Chroma.
    for csv_path in csv_directory.glob("*.csv"):
        loader = CSVLoader(file_path=str(csv_path))
        documents = loader.load()
        docs += text_splitter.split_documents(documents)

    documents_search = Chroma.from_documents(docs, embeddings_model)

    # The record manager tracks what has already been indexed, so re-runs only
    # upsert documents that changed.
    namespace = "chromadb/my_documents"
    record_manager = SQLRecordManager(
        namespace, db_url="sqlite:///record_manager_cache.sql"
    )
    record_manager.create_schema()

    index_result = index(
        docs,
        record_manager,
        documents_search,
        cleanup="incremental",
        source_id_key="source",
    )

    print(f"Indexing stats: {index_result}")

    return documents_search


doc_search = process_csvs(CSV_STORAGE_PATH)
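
# Quick sanity check (illustrative, mine): the vector store can also be
# queried directly, e.g.
#   doc_search.similarity_search("orders table", k=2)
# returns the two indexed chunks closest to the request.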
""" | |
Execute SQL query tool definition along schemas. | |
""" | |


def execute_sql(query: str) -> str:
    """
    Execute a SQLite query against the database. All markdown code fences and
    backticks are removed from the query before it is run.
    """
    db = Database("./db/mydatabase.db")
    db.connect()
    # results = db.run_sql_to_markdown(query)
    cleaned_query = remove_triple_backticks(query)
    results = db.execute_query(cleaned_query)
    return results + f"\nQuery used:\n```sql\n{cleaned_query}\n```"


class ExecuteSqlToolInput(BaseModel):
    query: str = Field(
        description="A SQLite query to be executed against the database")


execute_sql_tool = StructuredTool(
    func=execute_sql,
    name="Execute SQL",
    description="Useful when you need to execute SQL queries against the database. Always include a LIMIT 10 clause.",
    args_schema=ExecuteSqlToolInput,
)
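
# Illustrative invocation (mine; the agent normally calls the tool itself):
#   execute_sql_tool.run({"query": "SELECT * FROM sqlite_master LIMIT 10"})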
""" | |
Research database tool definition along schemas. | |
""" | |


def research_database(user_request: str) -> str:
    """
    Searches for table definitions matching the user request.
    """
    search_kwargs = {"k": 30}
    retriever = doc_search.as_retriever(search_kwargs=search_kwargs)

    def format_docs(docs):
        for i, doc in enumerate(docs):
            print(f"{i + 1}. {doc.page_content}")
        return "\n\n".join([d.page_content for d in docs])

    results = retriever.invoke(user_request)
    return format_docs(results)


class ResearchDatabaseToolInput(BaseModel):
    user_request: str = Field(
        description="The user query to match against the table definitions.")


research_database_tool = StructuredTool(
    func=research_database,
    name="Search db info",
    description="Search for database table definitions so you have context for building SQL queries. The queries need to be SQLite compatible.",
    args_schema=ResearchDatabaseToolInput,
)
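
# Illustrative invocation (mine): returns the raw page content of up to 30
# matching table-definition chunks.
#   research_database_tool.run({"user_request": "total sales per customer"})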


@cl.on_chat_start
def start():
    tools = [execute_sql_tool, research_database_tool]
    llm = ChatOpenAI(model="gpt-4", temperature=0, verbose=True)
    prompt = ChatPromptTemplate.from_template(
        """
        You are a world-class SQLite data scientist; based on the user query,
        use your tools to do the job. Usually you would start by analyzing
        your knowledge base for the possible SQL queries the user wants to build.
        Remember, your tools are:
        - execute_sql (brings back the results of running the query against the database)
        - research_database (searches for table definitions so you can build a SQLite query)
        Remember, you are building SQLite-compatible queries. If you don't know
        the answer, don't make anything up. Always ask for feedback. One last
        detail: always run the queries with LIMIT 10, and add the SQL query as
        markdown to the final answer so the user knows what SQL query was used
        for the job and can copy it for further use.
        REMEMBER TO ALWAYS GENERATE SQLITE-COMPATIBLE QUERIES.
        User query: {input}
        """
    )
    agent = initialize_agent(tools=tools, prompt=prompt,
                             llm=llm, handle_parsing_errors=True)
    cl.user_session.set("agent", agent)


@cl.on_message
async def main(message: cl.Message):
    agent = cl.user_session.get("agent")  # type: AgentExecutor
    res = await agent.arun(
        message.content, callbacks=[cl.AsyncLangchainCallbackHandler()]
    )
    await cl.Message(content=res).send()
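
# To try the app locally (assuming this module is saved as app.py):
#   chainlit run app.py -w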