DrishtiSharma commited on
Commit
1af1ebb
·
verified ·
1 Parent(s): 21fac93

Update interim.py

Browse files
Files changed (1) hide show
  1. interim.py +21 -23
interim.py CHANGED
@@ -1,9 +1,7 @@
1
- #fix workflow
2
  import os
3
  import streamlit as st
4
  import pandas as pd
5
  import matplotlib.pyplot as plt
6
- import networkx as nx
7
  from langchain_community.tools.tavily_search import TavilySearchResults
8
  from langchain_openai import ChatOpenAI
9
  from langgraph.graph import MessagesState
@@ -11,6 +9,7 @@ from langgraph.graph import START, StateGraph
11
  from langgraph.prebuilt import tools_condition
12
  from langgraph.prebuilt import ToolNode
13
  from langchain_core.messages import HumanMessage, SystemMessage
 
14
 
15
  # ------------------- Environment Variable Setup -------------------
16
  # Fetch API keys from environment variables
@@ -24,6 +23,7 @@ if not tavily_api_key:
24
  raise ValueError("Missing required environment variable: TAVILY_API_KEY")
25
 
26
  # ------------------- Tool Definitions -------------------
 
27
  tavily_tool = TavilySearchResults(max_results=5)
28
 
29
  def multiply(a: int, b: int) -> int:
@@ -40,9 +40,10 @@ def divide(a: int, b: int) -> float:
40
  raise ValueError("Division by zero is not allowed.")
41
  return a / b
42
 
 
43
  tools = [add, multiply, divide, tavily_tool]
44
 
45
- # ------------------- LLM Setup -------------------
46
  llm = ChatOpenAI(model="gpt-4o-mini")
47
  llm_with_tools = llm.bind_tools(tools, parallel_tool_calls=False)
48
  sys_msg = SystemMessage(content="You are a helpful assistant tasked with performing arithmetic and search on a set of inputs.")
@@ -52,6 +53,7 @@ def assistant(state: MessagesState):
52
  """Assistant node to invoke LLM with tools."""
53
  return {"messages": [llm_with_tools.invoke([sys_msg] + state["messages"])]}
54
 
 
55
  app_graph = StateGraph(MessagesState)
56
  app_graph.add_node("assistant", assistant)
57
  app_graph.add_node("tools", ToolNode(tools))
@@ -60,28 +62,22 @@ app_graph.add_conditional_edges("assistant", tools_condition)
60
  app_graph.add_edge("tools", "assistant")
61
  react_graph = app_graph.compile()
62
 
63
- # ------------------- Streamlit Interface -------------------
64
- st.title("ReAct Agent")
65
-
66
- # Display the workflow graph using NetworkX
67
- st.header("LangGraph Workflow Visualization")
68
 
69
- G = nx.DiGraph()
70
- G.add_edge("START", "assistant")
71
- G.add_edge("assistant", "tools", label="tools_condition")
72
- G.add_edge("tools", "assistant", label="loop back")
73
 
74
- plt.figure(figsize=(10, 6))
75
- pos = nx.spring_layout(G, seed=42)
76
- nx.draw(G, pos, with_labels=True, node_size=3000, node_color="lightblue", font_size=10, font_weight="bold")
77
- nx.draw_networkx_edge_labels(G, pos, edge_labels={
78
- ("assistant", "tools"): "tools_condition",
79
- ("tools", "assistant"): "loop back"
80
- }, font_color="red")
81
- st.pyplot(plt)
82
 
83
- # User input
84
- user_question = st.text_area("Enter your question:", placeholder="Example: 'Add 3 and 4. Multiply the result by 2. Divide it by 5.'")
 
85
 
86
  if st.button("Submit"):
87
  if not user_question.strip():
@@ -92,12 +88,14 @@ if st.button("Submit"):
92
  messages = [HumanMessage(content=user_question)]
93
  response = react_graph.invoke({"messages": messages})
94
 
 
95
  st.subheader("Responses")
96
  for m in response['messages']:
97
  st.write(m.content)
 
98
  st.success("Processing complete!")
99
 
100
- # Example Questions
101
  st.sidebar.subheader("Example Questions")
102
  st.sidebar.write("- Add 3 and 4. Multiply the result by 2. Divide it by 5.")
103
  st.sidebar.write("- Tell me how many centuries Virat Kohli scored.")
 
 
1
  import os
2
  import streamlit as st
3
  import pandas as pd
4
  import matplotlib.pyplot as plt
 
5
  from langchain_community.tools.tavily_search import TavilySearchResults
6
  from langchain_openai import ChatOpenAI
7
  from langgraph.graph import MessagesState
 
9
  from langgraph.prebuilt import tools_condition
10
  from langgraph.prebuilt import ToolNode
11
  from langchain_core.messages import HumanMessage, SystemMessage
12
+ import tempfile
13
 
14
  # ------------------- Environment Variable Setup -------------------
15
  # Fetch API keys from environment variables
 
23
  raise ValueError("Missing required environment variable: TAVILY_API_KEY")
24
 
25
  # ------------------- Tool Definitions -------------------
26
+ # Tavily Search Tool
27
  tavily_tool = TavilySearchResults(max_results=5)
28
 
29
  def multiply(a: int, b: int) -> int:
 
40
  raise ValueError("Division by zero is not allowed.")
41
  return a / b
42
 
43
+ # Combine tools
44
  tools = [add, multiply, divide, tavily_tool]
45
 
46
+ # ------------------- LLM and System Message Setup -------------------
47
  llm = ChatOpenAI(model="gpt-4o-mini")
48
  llm_with_tools = llm.bind_tools(tools, parallel_tool_calls=False)
49
  sys_msg = SystemMessage(content="You are a helpful assistant tasked with performing arithmetic and search on a set of inputs.")
 
53
  """Assistant node to invoke LLM with tools."""
54
  return {"messages": [llm_with_tools.invoke([sys_msg] + state["messages"])]}
55
 
56
+ # Define the graph
57
  app_graph = StateGraph(MessagesState)
58
  app_graph.add_node("assistant", assistant)
59
  app_graph.add_node("tools", ToolNode(tools))
 
62
  app_graph.add_edge("tools", "assistant")
63
  react_graph = app_graph.compile()
64
 
65
+ # Save graph visualization as an image
66
+ with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmpfile:
67
+ graph = react_graph.get_graph(xray=True)
68
+ tmpfile.write(graph.draw_mermaid_png()) # Write binary image data to file
69
+ graph_image_path = tmpfile.name
70
 
71
+ # ------------------- Streamlit Interface -------------------
72
+ st.title("ReAct Agent for Arithmetic Ops & Web Search")
 
 
73
 
74
+ # Display the workflow graph
75
+ #st.header("LangGraph Workflow Visualization")
76
+ st.image(graph_image_path, caption="Workflow Visualization")
 
 
 
 
 
77
 
78
+ # Prompt user for inputs
79
+ user_question = st.text_area("Enter your question:",
80
+ placeholder="Example: 'Add 3 and 4. Multiply the result by 2. Divide it by 5.'")
81
 
82
  if st.button("Submit"):
83
  if not user_question.strip():
 
88
  messages = [HumanMessage(content=user_question)]
89
  response = react_graph.invoke({"messages": messages})
90
 
91
+ # Display results
92
  st.subheader("Responses")
93
  for m in response['messages']:
94
  st.write(m.content)
95
+
96
  st.success("Processing complete!")
97
 
98
+ # Example Placeholder Suggestions
99
  st.sidebar.subheader("Example Questions")
100
  st.sidebar.write("- Add 3 and 4. Multiply the result by 2. Divide it by 5.")
101
  st.sidebar.write("- Tell me how many centuries Virat Kohli scored.")