genaitiwari committed
Commit 3f4dbc7 · 1 Parent(s): 40fd4e5

sql spider
README.md CHANGED
@@ -1,2 +1,4 @@
 # AutogenMultiAgent
 Autogen Multiagent
+
+
app.py CHANGED
@@ -4,6 +4,7 @@ from configfile import Config
 from src.streamlitui.loadui import LoadStreamlitUI
 from src.usecases.multiagentschat import MultiAgentChat
 from src.usecases.withllamaIndex import WithLlamaIndexMultiAgentChat
+from src.usecases.agentchatsqlspider import AgentChatSqlSpider
 from src.LLMS.groqllm import GroqLLM
 
 
@@ -38,3 +39,12 @@ if __name__ == "__main__":
                                                                            problem=problem,user_input=user_input)
         obj_usecases_with_llamaIndex_multichat.run()
 
+
+    elif user_input['selected_usecase'] == "AgentChat Sql Spider":
+        obj_sql_spider = AgentChatSqlSpider(assistant_name="Assistant", user_proxy_name='Userproxy',
+                                            llm_config=llm_config,
+                                            problem=problem)
+
+        obj_sql_spider.run()
+
+
configfile.ini CHANGED
@@ -1,6 +1,6 @@
 [DEFAULT]
 PAGE_TITLE = AUTOGEN IN ACTION
 LLM_OPTIONS = Groq, Huggingface
-USECASE_OPTIONS = MultiAgent Chat, RAG Chat, With LLamaIndex Tool, Teachable Agent, With Langchain
+USECASE_OPTIONS = MultiAgent Chat, RAG Chat, With LLamaIndex Tool, AgentChat Sql Spider
 GROQ_MODEL_OPTIONS = mixtral-8x7b-32768, llama3-8b-8192, llama3-70b-8192, gemma-7b-i
 
requirements.txt CHANGED
@@ -5,4 +5,5 @@ llama-index
 llama-index-tools-wikipedia
 llama-index-readers-wikipedia
 wikipedia
-llama-index-llms-groq
+llama-index-llms-groq
+spider-env
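The new spider-env dependency provides the SpiderEnv gym that the SQLExec helper added below builds on. A minimal sketch of the calls this commit relies on (field names taken from src/utils/sqlexecutor.py; anything beyond those is an assumption about the library):

```python
from spider_env import SpiderEnv

# Downloads and caches the Spider text-to-SQL dataset on first use.
gym = SpiderEnv(cache_dir="./.cache")

# reset() samples a random Spider task: a natural-language question plus the
# schema of the database it should be answered against.
observation, info = gym.reset()
question = observation["instruction"]
schema = info["schema"]

# step() executes a candidate query; reward == 1 indicates the result matched
# the gold answer, and execution problems surface in feedback["error"].
observation, reward, _, _, info = gym.step("SELECT 1;")
print(observation["feedback"]["result"], observation["feedback"]["error"], info["gold_result"])
```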
src/agents/sqlspideragent.py ADDED
@@ -0,0 +1,9 @@
+from autogen import ConversableAgent
+import streamlit as st
+
+
+class TrackableSpiderConversableAgent(ConversableAgent):
+    def _process_received_message(self, message, sender, silent):
+        with st.chat_message(sender.name):
+            st.write(message)
+        return super()._process_received_message(message, sender, silent)
src/usecases/agentchatsqlspider.py ADDED
@@ -0,0 +1,59 @@
+import asyncio
+import json
+from typing import Dict
+
+from autogen import ConversableAgent
+
+from src.agents.sqlspideragent import TrackableSpiderConversableAgent
+from src.agents.userproxyagent import TrackableUserProxyAgent
+import streamlit as st
+import os
+
+from src.utils.sqlexecutor import SQLExec
+
+os.environ["AUTOGEN_USE_DOCKER"] = "False"
+
+
+class AgentChatSqlSpider:
+    def __init__(self, assistant_name, user_proxy_name, llm_config, problem):
+        self.schema = None
+        self.question = None
+        self.sql_writer = TrackableSpiderConversableAgent(
+            "sql_writer",
+            llm_config=llm_config,
+            system_message="You are good at writing SQL queries. Always respond with a function call to execute_sql().",
+            is_termination_msg=self.check_termination,
+        )
+        self.user_proxy = TrackableUserProxyAgent(name=user_proxy_name,
+                                                  system_message="You are Admin",
+                                                  human_input_mode="NEVER",
+                                                  llm_config=llm_config,
+                                                  code_execution_config=False,
+                                                  is_termination_msg=lambda x: x.get("content", "").strip().endswith(
+                                                      "TERMINATE"))
+
+        self.problem = problem
+        self.loop = asyncio.new_event_loop()
+        asyncio.set_event_loop(self.loop)
+
+    async def initiate_chat(self):
+        message = f"""Below is the schema for a SQL database:
+{self.schema}
+Generate a SQL query to answer the following question:
+{self.question}
+"""
+
+        obj_SQLExec = SQLExec(self.sql_writer, self.user_proxy)
+
+        await self.user_proxy.a_initiate_chat(self.sql_writer, message=message,
+                                              clear_history=st.session_state["chat_with_history"])
+
+    def run(self):
+        self.loop.run_until_complete(self.initiate_chat())
+
+    def check_termination(self, msg: Dict):  # stop once execute_sql reports a correct result
+        if "tool_responses" not in msg:
+            return False
+        json_str = msg["tool_responses"][0]["content"]
+        obj = json.loads(json_str)
+        return "error" not in obj or (obj["error"] is None and obj["reward"] == 1)
src/usecases/withllamaIndex.py CHANGED
@@ -64,7 +64,7 @@ class WithLlamaIndexMultiAgentChat:
         llm = Groq(model=self.user_input['selected_groq_model'], api_key=st.session_state["GROQ_API_KEY"])
         llm_70b = Groq(model="llama3-70b-8192")
 
-        location_specialist = ReActAgent.from_tools(tools=[wikipedia_tool], llm=llm, max_iterations=3,
+        location_specialist = ReActAgent.from_tools(tools=[wikipedia_tool], llm=llm, max_iterations=1,
                                                     verbose=True)
 
         return location_specialist
src/utils/sqlexecutor.py ADDED
@@ -0,0 +1,47 @@
+import json
+import os
+from typing import Annotated, Dict
+
+from spider_env import SpiderEnv
+
+
+class SQLExec:
+    def __init__(self, sql_writer, user_proxy):
+        self.gym = SpiderEnv(cache_dir='./.cache')
+        self.sql_writer = sql_writer
+        self.user_proxy = user_proxy
+
+    def sql_spider(self):
+        # %pip install spider-env
+
+        # gym = SpiderEnv()
+
+        # Randomly select a question from Spider
+        observation, info = self.gym.reset()
+        # The natural language question
+        question = observation["instruction"]
+        print(question)
+        # The schema of the corresponding database
+        schema = info["schema"]
+        print(schema)
+
+    def sql_exec(self):
+        @self.sql_writer.register_for_llm(description="Function for executing SQL query and returning a response")
+        @self.user_proxy.register_for_execution()
+        def execute_sql(
+            reflection: Annotated[str, "Think about what to do"], sql: Annotated[str, "SQL query"]
+        ) -> Annotated[Dict[str, str], "Dictionary with keys 'result' and 'error'"]:
+            observation, reward, _, _, info = self.gym.step(sql)
+            error = observation["feedback"]["error"]
+            if not error and reward == 0:
+                error = "The SQL query returned an incorrect result"
+            if error:
+                return {
+                    "error": error,
+                    "wrong_result": observation["feedback"]["result"],
+                    "correct_result": info["gold_result"],
+                }
+            else:
+                return {
+                    "result": observation["feedback"]["result"],
+                }
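A note on the pattern in sql_exec(): register_for_llm on the sql_writer only advertises execute_sql's signature to the model, while register_for_execution on the user proxy is what actually runs the call against SpiderEnv. The dictionary it returns comes back to the writer as a tool response, which is the payload check_termination in agentchatsqlspider.py inspects; a response without an "error" key ends the chat.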