Spaces:
Sleeping
Sleeping
File size: 8,448 Bytes
68f18b5 fd4f795 68f18b5 291bc70 68f18b5 291bc70 68f18b5 291bc70 68f18b5 291bc70 68f18b5 291bc70 68f18b5 291bc70 68f18b5 291bc70 68f18b5 291bc70 68f18b5 291bc70 68f18b5 291bc70 68f18b5 291bc70 68f18b5 291bc70 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 |
import logging
import json
import os
from langchain_community.vectorstores import FAISS
from langchain_core.example_selectors import SemanticSimilarityExampleSelector
from langchain_openai import OpenAIEmbeddings, ChatOpenAI
from langchain_community.agent_toolkits import create_sql_agent
from langchain_core.prompts import (
ChatPromptTemplate,
FewShotPromptTemplate,
MessagesPlaceholder,
PromptTemplate,
SystemMessagePromptTemplate,
)
from langchain_community.utilities import SQLDatabase
from dotenv import load_dotenv
# Load settings (API keys, tracing switches, database path) from a local .env file.
load_dotenv(".env")

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# Save the log to a file. The handler has to be attached to the logger,
# otherwise nothing is ever written to extractor.log.
handler = logging.FileHandler('extractor.log')
logger.addHandler(handler)

# os.environ only accepts strings, so assigning the result of os.getenv() for
# an unset variable raises TypeError at import time. Warn instead and let the
# OpenAI client report the missing key when it is actually needed.
if not os.getenv('OPENAI_API_KEY'):
    logger.warning("OPENAI_API_KEY is not set")


def _configure_langsmith() -> None:
    """Map the LANGSMITH_* settings onto the LANGCHAIN_* variables.

    Does nothing unless the LANGSMITH environment variable is set to a
    non-empty value. Settings that are missing are skipped with a warning
    instead of crashing on a None assignment.
    """
    if not os.getenv('LANGSMITH'):
        return
    os.environ['LANGCHAIN_TRACING_V2'] = 'true'
    os.environ['LANGCHAIN_ENDPOINT'] = 'https://api.smith.langchain.com'
    mapping = (
        ('LANGCHAIN_API_KEY', 'LANGSMITH_API_KEY'),
        ('LANGCHAIN_PROJECT', 'LANGSMITH_PROJECT'),
    )
    for target, source in mapping:
        value = os.getenv(source)
        if value is None:
            logger.warning("%s is not set; %s left unchanged", source, target)
        else:
            os.environ[target] = value


_configure_langsmith()
def load_json(file_path: str) -> dict:
    """Read and parse the JSON document stored at *file_path*.

    The file is decoded as UTF-8 explicitly so the result does not depend on
    the platform's default encoding.

    Note: the value returned is whatever the JSON document contains; for a
    file holding a JSON array this is a list, not a dict.

    Raises:
        OSError: if the file cannot be opened.
        json.JSONDecodeError: if the content is not valid JSON.
    """
    with open(file_path, 'r', encoding='utf-8') as file:
        return json.load(file)
class SqlChain:
    """Answer natural-language questions against a SQL database.

    Combines a chat model, a SQL database connection, a semantic-similarity
    selector over few-shot "question -> SQL" examples, and an agent created
    with ``create_sql_agent``.
    """

    # Final line of the few-shot system prefix; the selected examples follow it.
    _EXAMPLES_START = "Use them as a guide when creating your own queries:\n\n"
    # Start of the prompt suffix; everything from here on is not an example.
    _SUFFIX_START = "\n\nBegin!"
    # Label that precedes every rendered example.
    _EXAMPLE_LABEL = "User input: "

    def __init__(self, few_shot_prompts: str, llm_model="gpt-3.5-turbo", db_uri="sqlite:///data/games.db",
                 few_shot_k=2):
        """Set up the model, database, example selector and agent.

        Args:
            few_shot_prompts: Path to a JSON file with the few-shot examples.
            llm_model: Name of the OpenAI chat model to use.
            db_uri: Database URI handed to ``SQLDatabase.from_uri``.
            few_shot_k: Number of examples the selector picks per question.
        """
        self.llm = ChatOpenAI(model=llm_model, temperature=0)
        self.db = SQLDatabase.from_uri(db_uri)
        self.few_shot_k = few_shot_k
        self.few_shot = self._set_up_few_shot_prompts(load_json(few_shot_prompts))
        self.full_prompt = None
        # NOTE(review): full_prompt is still None here, so the agent is built
        # with prompt=None (presumably falling back to the library's built-in
        # prompt - confirm). The few-shot prompt rendered later is passed to
        # the agent as its *input* in ask(), not as its prompt template.
        self.agent = create_sql_agent(
            llm=self.llm,
            db=self.db,
            prompt=self.full_prompt,
            max_iterations=10,
            verbose=True,
            agent_type="openai-tools",
            # Forwarded to create_sql_agent. NOTE(review): confirm how this
            # interacts with the top_k rendered into the few-shot prompt.
            top_k=30,
        )

    def _set_up_few_shot_prompts(self, few_shot_prompts: list) -> "SemanticSimilarityExampleSelector":
        """Build the example selector from the loaded few-shot examples.

        The examples are embedded with ``OpenAIEmbeddings`` and stored in a
        FAISS index; ``self.few_shot_k`` examples are selected per question,
        matched on the ``input`` key.
        """
        return SemanticSimilarityExampleSelector.from_examples(
            few_shot_prompts,
            OpenAIEmbeddings(),
            FAISS,
            k=self.few_shot_k,
            input_keys=["input"],
        )

    def few_prompt_construct(self, query: str, top_k=5, dialect="SQLite") -> None:
        """Render the few-shot prompt for *query* into ``self.full_prompt``.

        Args:
            query: The user's question; also used to select the examples.
            top_k: Row limit mentioned in the instructions.
            dialect: SQL dialect mentioned in the instructions.
        """
        system_prefix = (
            "You are an agent designed to interact with a SQL database.\n"
            "Given an input question, create a syntactically correct {dialect} query to run, then look at the results of the query and return the answer.\n"
            "ALWAYS query the database before returning an answer.\n"
            "Unless the user specifies a specific number of examples they wish to obtain, always limit your query to at most {top_k} results.\n"
            "You can order the results by a relevant column to return the most interesting examples in the database.\n"
            "Never query for all the columns from a specific table, only ask for the relevant columns given the question.\n"
            "You have access to tools for interacting with the database.\n"
            "Only use the given tools. Only use the information returned by the tools to construct your final answer.\n"
            "You MUST double check your query before executing it. If you get an error while executing a query, rewrite the query and try again.\n"
            "DO NOT make any DML statements (INSERT, UPDATE, DELETE, DROP etc.) to the database.\n"
            "If the question does not seem related to the database, just return 'I don't know' as the answer.\n"
            "DO NOT include information that is not present in the database in your answer.\n"
            "Here are some examples of user inputs and their corresponding SQL queries. They are tested and works.\n"
            "Use them as a guide when creating your own queries:"
        )
        suffix = (
            "Begin!\n"
            "Question: {input}\n"
            "Thought: I should look at the examples provided and see if I can use them to identify tables and how to build the query.\n"
            "Then I should query the schema of the most relevant tables.\n"
            "I will not stop until I query the database and return the answer.\n"
            "{agent_scratchpad}"
        )
        few_shot_prompt = FewShotPromptTemplate(
            example_selector=self.few_shot,
            example_prompt=PromptTemplate.from_template(
                "User input: {input}\nSQL query: {query}"
            ),
            input_variables=["input", "dialect", "top_k"],
            prefix=system_prefix,
            suffix=suffix,
        )
        full_prompt = ChatPromptTemplate.from_messages(
            [
                SystemMessagePromptTemplate(prompt=few_shot_prompt),
                ("human", "{input}"),
                MessagesPlaceholder("agent_scratchpad"),
            ]
        )
        self.full_prompt = full_prompt.invoke(
            {
                "input": query,
                "top_k": top_k,
                "dialect": dialect,
                "agent_scratchpad": [],
            }
        )

    def prompt_no_few_shot(self, query: str, dialect="SQLite", top_k=5) -> str:
        """Return the plain instruction prompt followed by *query*.

        The ``{dialect}`` and ``{top_k}`` placeholders in the instructions are
        filled in here; previously they were sent to the model verbatim. Only
        the instructions are formatted, so braces inside *query* are safe.

        Args:
            query: The user's question.
            dialect: SQL dialect mentioned in the instructions.
            top_k: Row limit mentioned in the instructions.
        """
        system_prefix = (
            "You are an agent designed to interact with a SQL database.\n"
            "Given an input question, create a syntactically correct {dialect} query to run, then look at the results of the query and return the answer.\n"
            "Unless the user specifies a specific number of examples they wish to obtain, always limit your query to at most {top_k} results.\n"
            "You can order the results by a relevant column to return the most interesting examples in the database.\n"
            "Never query for all the columns from a specific table, only ask for the relevant columns given the question.\n"
            "You have access to tools for interacting with the database.\n"
            "Only use the given tools. Only use the information returned by the tools to construct your final answer.\n"
            "You MUST double check your query before executing it. If you get an error while executing a query, rewrite the query and try again.\n"
            "DO NOT make any DML statements (INSERT, UPDATE, DELETE, DROP etc.) to the database.\n"
            "If the question does not seem related to the database, just return 'I don't know' as the answer.\n"
            "DO NOT include information that is not present in the database in your answer."
        )
        instructions = system_prefix.format(dialect=dialect, top_k=top_k)
        return f"{instructions}\n{query}"

    @staticmethod
    def _extract_examples(rendered: str) -> list:
        """Pull the few-shot examples out of a rendered system prompt.

        Keeps the text between the end of the instruction prefix and the
        start of the ``Begin!`` suffix, then splits it on the example label.
        Each returned item has the form ``"<input>\\nSQL query: <query>"``
        (all but the last keep their trailing blank line). Returns an empty
        list when no examples are found.
        """
        _, _, after_prefix = rendered.partition(SqlChain._EXAMPLES_START)
        # Cut at "\n\nBegin!" rather than "\n\nBegin!\n\n": the suffix defined
        # in few_prompt_construct has a single newline after "Begin!".
        examples_text, _, _ = after_prefix.partition(SqlChain._SUFFIX_START)
        # The chunk before the first label is empty and is dropped.
        return examples_text.split(SqlChain._EXAMPLE_LABEL)[1:]

    def ask(self, query: str, few_prompt: bool = True, rag_test=False):
        """Answer *query*, or return only the retrieved examples.

        Args:
            query: The user's question.
            few_prompt: Use the few-shot prompt (True) or the plain
                instruction prompt (False).
            rag_test: When True, skip the agent and return just the list of
                example strings the selector added to the prompt.

        Returns:
            A list of example strings when ``rag_test`` is True; otherwise a
            ``(agent_result, prompt)`` tuple.
        """
        if rag_test:
            self.few_prompt_construct(query)
            system_message = self.full_prompt.messages[0].content
            return self._extract_examples(system_message)
        if few_prompt:
            self.few_prompt_construct(query)
            return self.agent.invoke({"input": self.full_prompt}), self.full_prompt
        # Build the prompt once and reuse it for both the call and the result.
        prompt = self.prompt_no_few_shot(query)
        return self.agent.invoke(prompt), prompt
def create_agent(few_shot_prompts: str = "src/conf/sqls.json", llm_model="gpt-3.5-turbo-0125",
                 db_uri="config", few_shot_k=2):
    """Create a SqlChain with the given few_shot_prompts, llm_model and db_uri.

    Call the result with ``agent.ask(prompt)``.

    Args:
        few_shot_prompts: Path to the JSON file with the few-shot examples.
        llm_model: Name of the OpenAI chat model to use.
        db_uri: Database URI, or the literal string ``"config"`` to build a
            SQLite URI from the DATABASE_PATH environment variable.
        few_shot_k: Number of few-shot examples to select per question.

    Raises:
        ValueError: if ``db_uri`` is ``"config"`` and DATABASE_PATH is unset
            or empty (previously this silently produced ``sqlite:///None``).
    """
    if db_uri == "config":
        db_path = os.getenv('DATABASE_PATH')
        if not db_path:
            raise ValueError(
                "DATABASE_PATH is not set; set it or pass db_uri explicitly"
            )
        db_uri = f"sqlite:///{db_path}"
    return SqlChain(few_shot_prompts, llm_model, db_uri, few_shot_k)
if __name__ == "__main__":
    # Smoke test for the retrieval step only: rag_test=True makes ask() return
    # the few-shot examples selected for the question without running the agent.
    chain = SqlChain("src/conf/sqls.json")
    examples = chain.ask("Is Manchester United in the database?", rag_test=True)
    # Show what was retrieved; the result used to be discarded silently.
    for example in examples:
        print(f"User input: {example.strip()}\n")
|