bstraehle commited on
Commit
d3f3fad
1 Parent(s): dd92614

Update rag_langgraph.py

Browse files
Files changed (1) hide show
  1. rag_langgraph.py +84 -109
rag_langgraph.py CHANGED
@@ -1,12 +1,20 @@
1
- import os
2
-
3
  from langchain_core.messages import (
 
4
  BaseMessage,
5
  ToolMessage,
6
  HumanMessage,
7
  )
8
  from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
 
 
 
 
9
  from langgraph.graph import END, StateGraph
 
 
 
10
 
11
  def create_agent(llm, tools, system_message: str):
12
  """Create an agent."""
@@ -29,21 +37,8 @@ def create_agent(llm, tools, system_message: str):
29
  prompt = prompt.partial(tool_names=", ".join([tool.name for tool in tools]))
30
  return prompt | llm.bind_tools(tools)
31
 
32
- from langchain_core.tools import tool
33
- from typing import Annotated
34
- from langchain_experimental.utilities import PythonREPL
35
- from langchain_community.tools.tavily_search import TavilySearchResults
36
-
37
- tavily_tool = TavilySearchResults(max_results=5)
38
-
39
- # Warning: This executes code locally, which can be unsafe when not sandboxed
40
-
41
- repl = PythonREPL()
42
-
43
  @tool
44
- def python_repl(
45
- code: Annotated[str, "The python code to execute to generate your chart."]
46
- ):
47
  """Use this to execute python code. If you want to see the output of a value,
48
  you should print it out with `print(...)`. This is visible to the user."""
49
  try:
@@ -55,23 +50,12 @@ def python_repl(
55
  result_str + "\n\nIf you have completed all tasks, respond with FINAL ANSWER."
56
  )
57
 
58
- import operator
59
- from typing import Annotated, Sequence, TypedDict
60
-
61
- from langchain_openai import ChatOpenAI
62
- from typing_extensions import TypedDict
63
-
64
-
65
  # This defines the object that is passed between each node
66
  # in the graph. We will create different nodes for each agent and tool
67
  class AgentState(TypedDict):
68
  messages: Annotated[Sequence[BaseMessage], operator.add]
69
  sender: str
70
 
71
- import functools
72
- from langchain_core.messages import AIMessage
73
-
74
-
75
  # Helper function to create a node for a given agent
76
  def agent_node(state, agent, name):
77
  result = agent.invoke(state)
@@ -87,32 +71,6 @@ def agent_node(state, agent, name):
87
  "sender": name,
88
  }
89
 
90
- llm = ChatOpenAI(model="gpt-4-1106-preview")
91
-
92
- # Research agent and node
93
- research_agent = create_agent(
94
- llm,
95
- [tavily_tool],
96
- system_message="You should provide accurate data for the chart_generator to use.",
97
- )
98
- research_node = functools.partial(agent_node, agent=research_agent, name="Researcher")
99
-
100
- # chart_generator
101
- chart_agent = create_agent(
102
- llm,
103
- [python_repl],
104
- system_message="Any charts you display will be visible by the user.",
105
- )
106
- chart_node = functools.partial(agent_node, agent=chart_agent, name="chart_generator")
107
-
108
- from langgraph.prebuilt import ToolNode
109
-
110
- tools = [tavily_tool, python_repl]
111
- tool_node = ToolNode(tools)
112
-
113
- # Either agent can decide to end
114
- from typing import Literal
115
-
116
  def router(state) -> Literal["call_tool", "__end__", "continue"]:
117
  # This is the router
118
  messages = state["messages"]
@@ -125,62 +83,79 @@ def router(state) -> Literal["call_tool", "__end__", "continue"]:
125
  return "__end__"
126
  return "continue"
127
 
128
- workflow = StateGraph(AgentState)
129
-
130
- workflow.add_node("Researcher", research_node)
131
- workflow.add_node("chart_generator", chart_node)
132
- workflow.add_node("call_tool", tool_node)
133
-
134
- workflow.add_conditional_edges(
135
- "Researcher",
136
- router,
137
- {"continue": "chart_generator", "call_tool": "call_tool", "__end__": END},
138
- )
139
- workflow.add_conditional_edges(
140
- "chart_generator",
141
- router,
142
- {"continue": "Researcher", "call_tool": "call_tool", "__end__": END},
143
- )
144
-
145
- workflow.add_conditional_edges(
146
- "call_tool",
147
- # Each agent node updates the 'sender' field
148
- # the tool calling node does not, meaning
149
- # this edge will route back to the original agent
150
- # who invoked the tool
151
- lambda x: x["sender"],
152
- {
153
- "Researcher": "Researcher",
154
- "chart_generator": "chart_generator",
155
- },
156
- )
157
- workflow.set_entry_point("Researcher")
158
- graph = workflow.compile()
159
-
160
- from IPython.display import Image, display
161
 
162
- try:
163
- display(Image(graph.get_graph(xray=True).draw_mermaid_png()))
164
- except:
165
- # This requires some extra dependencies and is optional
166
- pass
167
 
168
- events = graph.stream(
169
- {
170
- "messages": [
171
- HumanMessage(
172
- content="Fetch the UK's GDP over the past 5 years,"
173
- " then draw a line graph of it."
174
- " Once you code it up, finish."
175
- )
176
- ],
177
- },
178
- # Maximum number of steps to take in the graph
179
- {"recursion_limit": 150},
180
- )
181
- for s in events:
182
- print(s)
183
- print("----")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
184
 
185
- def run_multi_agent():
186
- return "DONE"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import functools, operator
2
+ from IPython.display import Image, display
3
  from langchain_core.messages import (
4
+ AIMessage,
5
  BaseMessage,
6
  ToolMessage,
7
  HumanMessage,
8
  )
9
  from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
10
+ from langchain_core.tools import tool
11
+ from langchain_community.tools.tavily_search import TavilySearchResults
12
+ from langchain_experimental.utilities import PythonREPL
13
+ from langchain_openai import ChatOpenAI
14
  from langgraph.graph import END, StateGraph
15
+ from langgraph.prebuilt import ToolNode
16
+ from typing import Annotated, Literal, Sequence, TypedDict
17
+ from typing_extensions import TypedDict
18
 
19
  def create_agent(llm, tools, system_message: str):
20
  """Create an agent."""
 
37
  prompt = prompt.partial(tool_names=", ".join([tool.name for tool in tools]))
38
  return prompt | llm.bind_tools(tools)
39
 
 
 
 
 
 
 
 
 
 
 
 
40
  @tool
41
+ def python_repl(code: Annotated[str, "The python code to execute to generate your chart."]):
 
 
42
  """Use this to execute python code. If you want to see the output of a value,
43
  you should print it out with `print(...)`. This is visible to the user."""
44
  try:
 
50
  result_str + "\n\nIf you have completed all tasks, respond with FINAL ANSWER."
51
  )
52
 
 
 
 
 
 
 
 
53
  # This defines the object that is passed between each node
54
  # in the graph. We will create different nodes for each agent and tool
55
  class AgentState(TypedDict):
56
  messages: Annotated[Sequence[BaseMessage], operator.add]
57
  sender: str
58
 
 
 
 
 
59
  # Helper function to create a node for a given agent
60
  def agent_node(state, agent, name):
61
  result = agent.invoke(state)
 
71
  "sender": name,
72
  }
73
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
74
  def router(state) -> Literal["call_tool", "__end__", "continue"]:
75
  # This is the router
76
  messages = state["messages"]
 
83
  return "__end__"
84
  return "continue"
85
 
86
+ def run_multi_agent(prompt):
87
+ tavily_tool = TavilySearchResults(max_results=5)
88
+ repl = PythonREPL()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
89
 
90
+ llm = ChatOpenAI(model="gpt-4o")
 
 
 
 
91
 
92
+ # Research agent and node
93
+ research_agent = create_agent(
94
+ llm,
95
+ [tavily_tool],
96
+ system_message="You should provide accurate data for the chart_generator to use.",
97
+ )
98
+ research_node = functools.partial(agent_node, agent=research_agent, name="Researcher")
99
+
100
+ # chart_generator
101
+ chart_agent = create_agent(
102
+ llm,
103
+ [python_repl],
104
+ system_message="Any charts you display will be visible by the user.",
105
+ )
106
+ chart_node = functools.partial(agent_node, agent=chart_agent, name="chart_generator")
107
+
108
+ tools = [tavily_tool, python_repl]
109
+ tool_node = ToolNode(tools)
110
+
111
+ workflow = StateGraph(AgentState)
112
+
113
+ workflow.add_node("Researcher", research_node)
114
+ workflow.add_node("chart_generator", chart_node)
115
+ workflow.add_node("call_tool", tool_node)
116
+
117
+ workflow.add_conditional_edges(
118
+ "Researcher",
119
+ router,
120
+ {"continue": "chart_generator", "call_tool": "call_tool", "__end__": END},
121
+ )
122
+ workflow.add_conditional_edges(
123
+ "chart_generator",
124
+ router,
125
+ {"continue": "Researcher", "call_tool": "call_tool", "__end__": END},
126
+ )
127
+
128
+ workflow.add_conditional_edges(
129
+ "call_tool",
130
+ # Each agent node updates the 'sender' field
131
+ # the tool calling node does not, meaning
132
+ # this edge will route back to the original agent
133
+ # who invoked the tool
134
+ lambda x: x["sender"],
135
+ {
136
+ "Researcher": "Researcher",
137
+ "chart_generator": "chart_generator",
138
+ },
139
+ )
140
+ workflow.set_entry_point("Researcher")
141
+ graph = workflow.compile()
142
 
143
+ try:
144
+ display(Image(graph.get_graph(xray=True).draw_mermaid_png()))
145
+ except:
146
+ # This requires some extra dependencies and is optional
147
+ pass
148
+
149
+ events = graph.stream(
150
+ {
151
+ "messages": [
152
+ HumanMessage(
153
+ content=prompt
154
+ )
155
+ ],
156
+ },
157
+ # Maximum number of steps to take in the graph
158
+ {"recursion_limit": 150},
159
+ )
160
+ for s in events:
161
+ return s