File size: 4,438 Bytes
c6a7187
0a83688
 
 
c6a7187
 
 
 
 
 
 
c24cad9
c6a7187
 
 
 
 
 
 
 
 
 
 
 
 
232ffed
c6a7187
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c24cad9
c6a7187
 
c24cad9
c6a7187
 
 
 
 
 
 
 
 
c24cad9
c6a7187
 
 
 
 
 
 
 
232ffed
c24cad9
 
232ffed
 
c24cad9
c6a7187
55e25b1
b3a8383
c24cad9
21fac93
 
c6a7187
c24cad9
232ffed
 
c6a7187
 
 
 
 
 
 
 
 
 
82c4620
 
c6a7187
08d3fb3
 
 
 
 
 
 
 
c24cad9
c6a7187
 
232ffed
c6a7187
 
 
 
c038587
 
833794f
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
import os
import streamlit as st
import pandas as pd
import matplotlib.pyplot as plt
from langchain_community.tools.tavily_search import TavilySearchResults
from langchain_openai import ChatOpenAI
from langgraph.graph import MessagesState
from langgraph.graph import START, StateGraph
from langgraph.prebuilt import tools_condition
from langgraph.prebuilt import ToolNode
from langchain_core.messages import HumanMessage, SystemMessage
import tempfile

# ------------------- Environment Variable Setup -------------------
# Fetch API keys from environment variables
def _require_env(name: str) -> str:
    """Return environment variable *name*; raise ValueError if it is unset or empty."""
    value = os.getenv(name)
    if not value:
        raise ValueError(f"Missing required environment variable: {name}")
    return value


# Fail fast at startup: OpenAI is checked first, then Tavily.
openai_api_key = _require_env("OPENAI_API_KEY")
tavily_api_key = _require_env("TAVILY_API_KEY")

# ------------------- Tool Definitions -------------------
# Tavily Search Tool: web search configured for at most five results per query.
TAVILY_MAX_RESULTS = 5
tavily_tool = TavilySearchResults(max_results=TAVILY_MAX_RESULTS)

def multiply(a: float, b: float) -> float:
    """Multiply two numbers.

    The annotations are ``float`` (not ``int``) because they become the tool's
    argument schema. Chained steps such as "divide, then multiply" must be
    able to pass fractional intermediate results. Plain ints are still accepted.

    Args:
        a: First factor.
        b: Second factor.

    Returns:
        The product ``a * b``.
    """
    return a * b

def add(a: float, b: float) -> float:
    """Add two numbers.

    The annotations are ``float`` (not ``int``) because they become the tool's
    argument schema. Intermediate results from ``divide`` can be fractional
    and must be accepted here. Plain ints are still accepted.

    Args:
        a: First addend.
        b: Second addend.

    Returns:
        The sum ``a + b``.
    """
    return a + b

def divide(a: float, b: float) -> float:
    """Divide two numbers.

    The annotations are ``float`` (not ``int``) because they become the tool's
    argument schema. The dividend may itself be a fractional result of an
    earlier step. Plain ints are still accepted.

    Args:
        a: Dividend.
        b: Divisor; must be non-zero.

    Returns:
        The true-division quotient ``a / b``.

    Raises:
        ValueError: If ``b`` is zero (``0`` or ``0.0``).
    """
    if b == 0:
        raise ValueError("Division by zero is not allowed.")
    return a / b

# Combine tools: the arithmetic helpers followed by web search. This single
# list is what gets bound to the LLM and handed to the ToolNode.
arithmetic_tools = [add, multiply, divide]
tools = arithmetic_tools + [tavily_tool]

# ------------------- LLM and System Message Setup -------------------
# Chat model with the tool schemas attached. parallel_tool_calls=False asks for
# at most one tool call per model turn. Presumably this is so chained
# arithmetic steps can build on the previous result.
llm = ChatOpenAI(model="gpt-4o-mini")
llm_with_tools = llm.bind_tools(tools, parallel_tool_calls=False)

# System prompt placed in front of the conversation on every assistant call.
sys_msg = SystemMessage(
    content=(
        "You are a helpful assistant tasked with performing arithmetic "
        "and search on a set of inputs."
    )
)

# ------------------- LangGraph Workflow -------------------
def assistant(state: MessagesState):
    """Graph node: ask the tool-enabled LLM for its next reply.

    The system prompt is placed ahead of the accumulated conversation. The
    single model reply is returned as a ``{"messages": [...]}`` state update.
    """
    conversation = [sys_msg, *state["messages"]]
    reply = llm_with_tools.invoke(conversation)
    return {"messages": [reply]}

# Define the graph as a ReAct loop:
#   START -> assistant; assistant -> (tools_condition decides the next hop);
#   tools -> assistant.
tool_node = ToolNode(tools)

app_graph = StateGraph(MessagesState)
app_graph.add_node("assistant", assistant)
app_graph.add_node("tools", tool_node)

app_graph.add_edge(START, "assistant")
app_graph.add_conditional_edges("assistant", tools_condition)
app_graph.add_edge("tools", "assistant")

react_graph = app_graph.compile()

# Render the workflow graph to PNG bytes in memory. st.image accepts raw bytes,
# so no temporary file is written. Streamlit re-executes this script on every
# interaction, and a NamedTemporaryFile(delete=False) would leave a new .png
# on disk each time.
graph_png = None
graph_render_error = None
try:
    graph_png = react_graph.get_graph(xray=True).draw_mermaid_png()
except Exception as err:  # NOTE(review): rendering may rely on a remote Mermaid service by default — confirm for the installed langchain_core
    # A missing diagram should not take down the whole app; report it instead.
    graph_render_error = err

# ------------------- Streamlit Interface -------------------
st.title("ReAct Agent for Arithmetic Ops & Web Search")

# Display the workflow graph, or a non-fatal warning if it could not be rendered.
if graph_png is not None:
    st.image(graph_png, caption="Workflow Visualization")
else:
    st.warning(f"Workflow visualization unavailable: {graph_render_error}")

def _render_messages(messages) -> None:
    """Write each message of a finished agent run to the page, labelled by role.

    Human input, AI replies, the tool calls an AI reply requested, and tool
    results are shown in conversation order.
    """
    for m in messages:
        msg_type = getattr(m, "type", None)
        content = getattr(m, "content", None)
        if msg_type == "tool":
            # Tool results arrive as their own messages (type "tool"). The
            # tool_call dicts on AI messages carry no output.
            st.write("**Tool Output:**", content)
            continue
        if content:
            label = "**User Message:**" if msg_type == "human" else "**AI Message:**"
            st.write(label, content)
        for tool_call in getattr(m, "tool_calls", None) or []:
            st.write(f"**Tool Call:** `{tool_call['name']}`")
            st.json(tool_call["args"])  # Display tool arguments in JSON


# Prompt user for inputs
user_question = st.text_area("Enter your question:",
                             placeholder="Example: 'Add 3 and 4. Multiply the result by 2. Divide it by 5.'")

if st.button("Submit"):
    # Branch with if/else rather than st.stop(): stopping would also skip every
    # statement after this one in the script, including the sidebar.
    if not user_question.strip():
        st.error("Please enter a valid question.")
    else:
        st.info("Processing your question...")
        try:
            response = react_graph.invoke(
                {"messages": [HumanMessage(content=user_question)]}
            )
        except Exception as err:  # UI boundary: show API/tool failures instead of a traceback
            st.error(f"Failed to process your question: {err}")
        else:
            # Display results step-by-step
            st.subheader("Response:")
            _render_messages(response["messages"])
            st.success("Processing complete!")

# Example Placeholder Suggestions: sample prompts covering chained arithmetic
# and web search, listed as bullet lines in the sidebar.
EXAMPLE_QUESTIONS = (
    "Add 3 and 4. Multiply the result by 2. Divide it by 5.",
    "Tell me how many centuries Virat Kohli scored.",
    "Search for the tallest building in the world.",
)
st.sidebar.subheader("Example Questions")
for example_question in EXAMPLE_QUESTIONS:
    st.sidebar.write(f"- {example_question}")

# Link to the notebook this agent follows.
st.sidebar.title("References")
st.sidebar.markdown(
    "1. [LangGraph ReAct Agents]"
    "(https://github.com/aritrasen87/LLM_RAG_Model_Deployment/blob/main/LangGraph_9_ReAct_Agents.ipynb)"
)