Spaces:
Sleeping
Sleeping
Commit
·
3f4dbc7
1
Parent(s):
40fd4e5
sql spider
Browse files- README.md +2 -0
- app.py +10 -0
- configfile.ini +1 -1
- requirements.txt +2 -1
- src/agents/sqlspideragent.py +9 -0
- src/usecases/agentchatsqlspider.py +59 -0
- src/usecases/withllamaIndex.py +1 -1
- src/utils/sqlexecutor.py +47 -0
README.md
CHANGED
|
@@ -1,2 +1,4 @@
|
|
| 1 |
# AutogenMultiAgent
|
| 2 |
Autogen Multiagent
|
|
|
|
|
|
|
|
|
| 1 |
# AutogenMultiAgent
|
| 2 |
Autogen Multiagent
|
| 3 |
+
|
| 4 |
+
|
app.py
CHANGED
|
@@ -4,6 +4,7 @@ from configfile import Config
|
|
| 4 |
from src.streamlitui.loadui import LoadStreamlitUI
|
| 5 |
from src.usecases.multiagentschat import MultiAgentChat
|
| 6 |
from src.usecases.withllamaIndex import WithLlamaIndexMultiAgentChat
|
|
|
|
| 7 |
from src.LLMS.groqllm import GroqLLM
|
| 8 |
|
| 9 |
|
|
@@ -38,3 +39,12 @@ if __name__ == "__main__":
|
|
| 38 |
problem=problem,user_input=user_input)
|
| 39 |
obj_usecases_with_llamaIndex_multichat.run()
|
| 40 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 4 |
from src.streamlitui.loadui import LoadStreamlitUI
|
| 5 |
from src.usecases.multiagentschat import MultiAgentChat
|
| 6 |
from src.usecases.withllamaIndex import WithLlamaIndexMultiAgentChat
|
| 7 |
+
from src.usecases.agentchatsqlspider import AgentChatSqlSpider
|
| 8 |
from src.LLMS.groqllm import GroqLLM
|
| 9 |
|
| 10 |
|
|
|
|
| 39 |
problem=problem,user_input=user_input)
|
| 40 |
obj_usecases_with_llamaIndex_multichat.run()
|
| 41 |
|
| 42 |
+
|
| 43 |
+
elif user_input['selected_usecase'] == "AgentChat Sql Spider":
|
| 44 |
+
obj_sql_spider = AgentChatSqlSpider(assistant_name="Assistant", user_proxy_name='Userproxy',
|
| 45 |
+
llm_config=llm_config,
|
| 46 |
+
problem=problem)
|
| 47 |
+
|
| 48 |
+
obj_sql_spider.run()
|
| 49 |
+
|
| 50 |
+
|
configfile.ini
CHANGED
|
@@ -1,6 +1,6 @@
|
|
| 1 |
[DEFAULT]
|
| 2 |
PAGE_TITLE = AUTOGEN IN ACTION
|
| 3 |
LLM_OPTIONS = Groq, Huggingface
|
| 4 |
-
USECASE_OPTIONS = MultiAgent Chat, RAG Chat, With LLamaIndex Tool,
|
| 5 |
GROQ_MODEL_OPTIONS = mixtral-8x7b-32768, llama3-8b-8192, llama3-70b-8192, gemma-7b-i
|
| 6 |
|
|
|
|
| 1 |
[DEFAULT]
|
| 2 |
PAGE_TITLE = AUTOGEN IN ACTION
|
| 3 |
LLM_OPTIONS = Groq, Huggingface
|
| 4 |
+
USECASE_OPTIONS = MultiAgent Chat, RAG Chat, With LLamaIndex Tool, AgentChat Sql Spider
|
| 5 |
GROQ_MODEL_OPTIONS = mixtral-8x7b-32768, llama3-8b-8192, llama3-70b-8192, gemma-7b-i
|
| 6 |
|
requirements.txt
CHANGED
|
@@ -5,4 +5,5 @@ llama-index
|
|
| 5 |
llama-index-tools-wikipedia
|
| 6 |
llama-index-readers-wikipedia
|
| 7 |
wikipedia
|
| 8 |
-
llama-index-llms-groq
|
|
|
|
|
|
| 5 |
llama-index-tools-wikipedia
|
| 6 |
llama-index-readers-wikipedia
|
| 7 |
wikipedia
|
| 8 |
+
llama-index-llms-groq
|
| 9 |
+
spider-env
|
src/agents/sqlspideragent.py
ADDED
|
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from autogen import ConversableAgent
|
| 2 |
+
import streamlit as st
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
class TrackableSpiderConversableAgent(ConversableAgent):
|
| 6 |
+
def _process_received_message(self, message, sender, silent):
|
| 7 |
+
with st.chat_message(sender.name):
|
| 8 |
+
st.write(message)
|
| 9 |
+
return super()._process_received_message(message, sender, silent)
|
src/usecases/agentchatsqlspider.py
ADDED
|
@@ -0,0 +1,59 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import asyncio
|
| 2 |
+
import json
|
| 3 |
+
from typing import Dict
|
| 4 |
+
|
| 5 |
+
from autogen import ConversableAgent
|
| 6 |
+
|
| 7 |
+
from src.agents.sqlspideragent import TrackableSpiderConversableAgent
|
| 8 |
+
from src.agents.userproxyagent import TrackableUserProxyAgent
|
| 9 |
+
import streamlit as st
|
| 10 |
+
import os
|
| 11 |
+
|
| 12 |
+
from src.utils.sqlexecutor import SQLExec
|
| 13 |
+
|
| 14 |
+
os.environ["AUTOGEN_USE_DOCKER"] = "False"
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
class AgentChatSqlSpider:
|
| 18 |
+
def __init__(self, assistant_name, user_proxy_name, llm_config, problem):
|
| 19 |
+
self.schema = None
|
| 20 |
+
self.question = None
|
| 21 |
+
self.sql_writer = TrackableSpiderConversableAgent(
|
| 22 |
+
"sql_writer",
|
| 23 |
+
llm_config=llm_config,
|
| 24 |
+
system_message="You are good at writing SQL queries. Always respond with a function call to execute_sql().",
|
| 25 |
+
is_termination_msg=self.check_termination,
|
| 26 |
+
)
|
| 27 |
+
self.user_proxy = TrackableUserProxyAgent(name=user_proxy_name,
|
| 28 |
+
system_message="You are Admin",
|
| 29 |
+
human_input_mode="NEVER",
|
| 30 |
+
llm_config=llm_config,
|
| 31 |
+
code_execution_config=False,
|
| 32 |
+
is_termination_msg=lambda x: x.get("content", "").strip().endswith(
|
| 33 |
+
"TERMINATE"))
|
| 34 |
+
|
| 35 |
+
self.problem = problem
|
| 36 |
+
self.loop = asyncio.new_event_loop()
|
| 37 |
+
asyncio.set_event_loop(self.loop)
|
| 38 |
+
|
| 39 |
+
async def initiate_chat(self):
|
| 40 |
+
message = f"""Below is the schema for a SQL database:
|
| 41 |
+
{self.schema}
|
| 42 |
+
Generate a SQL query to answer the following question:
|
| 43 |
+
{self.question}
|
| 44 |
+
"""
|
| 45 |
+
|
| 46 |
+
obj_SQLExec = SQLExec(self.sql_writer,self.user_proxy)
|
| 47 |
+
|
| 48 |
+
await self.user_proxy.a_initiate_chat(self.sql_writer, message=message,
|
| 49 |
+
clear_history=st.session_state["chat_with_history"])
|
| 50 |
+
|
| 51 |
+
def run(self):
|
| 52 |
+
self.loop.run_until_complete(self.initiate_chat())
|
| 53 |
+
|
| 54 |
+
def check_termination(msg: Dict):
|
| 55 |
+
if "tool_responses" not in msg:
|
| 56 |
+
return False
|
| 57 |
+
json_str = msg["tool_responses"][0]["content"]
|
| 58 |
+
obj = json.loads(json_str)
|
| 59 |
+
return "error" not in obj or obj["error"] is None and obj["reward"] == 1
|
src/usecases/withllamaIndex.py
CHANGED
|
@@ -64,7 +64,7 @@ class WithLlamaIndexMultiAgentChat:
|
|
| 64 |
llm = Groq(model=self.user_input['selected_groq_model'], api_key=st.session_state["GROQ_API_KEY"])
|
| 65 |
llm_70b = Groq(model="llama3-70b-8192")
|
| 66 |
|
| 67 |
-
location_specialist = ReActAgent.from_tools(tools=[wikipedia_tool], llm=llm, max_iterations=
|
| 68 |
verbose=True)
|
| 69 |
|
| 70 |
return location_specialist
|
|
|
|
| 64 |
llm = Groq(model=self.user_input['selected_groq_model'], api_key=st.session_state["GROQ_API_KEY"])
|
| 65 |
llm_70b = Groq(model="llama3-70b-8192")
|
| 66 |
|
| 67 |
+
location_specialist = ReActAgent.from_tools(tools=[wikipedia_tool], llm=llm, max_iterations=1,
|
| 68 |
verbose=True)
|
| 69 |
|
| 70 |
return location_specialist
|
src/utils/sqlexecutor.py
ADDED
|
@@ -0,0 +1,47 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import json
|
| 2 |
+
import os
|
| 3 |
+
from typing import Annotated, Dict
|
| 4 |
+
|
| 5 |
+
from spider_env import SpiderEnv
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
class SQLExec:
|
| 9 |
+
def __init__(self, sql_writer, user_proxy):
|
| 10 |
+
self.gym = SpiderEnv(cache_dir='./.cache')
|
| 11 |
+
self.sql_writer = sql_writer
|
| 12 |
+
self.user_proxy = user_proxy
|
| 13 |
+
|
| 14 |
+
def sql_spider(self):
|
| 15 |
+
# %pip install spider-env
|
| 16 |
+
|
| 17 |
+
#gym = SpiderEnv()
|
| 18 |
+
|
| 19 |
+
# Randomly select a question from Spider
|
| 20 |
+
observation, info = self.gym.reset()
|
| 21 |
+
# The natural language question
|
| 22 |
+
question = observation["instruction"]
|
| 23 |
+
print(question)
|
| 24 |
+
# The schema of the corresponding database
|
| 25 |
+
schema = info["schema"]
|
| 26 |
+
print(schema)
|
| 27 |
+
|
| 28 |
+
def sql_exec(self):
|
| 29 |
+
@self.sql_writer.register_for_llm(description="Function for executing SQL query and returning a response")
|
| 30 |
+
@self.user_proxy.register_for_execution()
|
| 31 |
+
def execute_sql(
|
| 32 |
+
reflection: Annotated[str, "Think about what to do"], sql: Annotated[str, "SQL query"]
|
| 33 |
+
) -> Annotated[Dict[str, str], "Dictionary with keys 'result' and 'error'"]:
|
| 34 |
+
observation, reward, _, _, info = self.gym.step(sql)
|
| 35 |
+
error = observation["feedback"]["error"]
|
| 36 |
+
if not error and reward == 0:
|
| 37 |
+
error = "The SQL query returned an incorrect result"
|
| 38 |
+
if error:
|
| 39 |
+
return {
|
| 40 |
+
"error": error,
|
| 41 |
+
"wrong_result": observation["feedback"]["result"],
|
| 42 |
+
"correct_result": info["gold_result"],
|
| 43 |
+
}
|
| 44 |
+
else:
|
| 45 |
+
return {
|
| 46 |
+
"result": observation["feedback"]["result"],
|
| 47 |
+
}
|