akileshjayakumar commited on
Commit
9a4dc5f
1 Parent(s): e4124e6

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +197 -0
app.py ADDED
@@ -0,0 +1,197 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import uuid
3
+ import logging
4
+ from dotenv import load_dotenv
5
+ import json
6
+ import gradio as gr
7
+ from langchain_community.tools.tavily_search import TavilySearchResults
8
+ from langchain_core.messages import AIMessage, HumanMessage
9
+ from typing_extensions import TypedDict
10
+ from typing import Annotated
11
+ from langchain_core.messages import ToolMessage
12
+ from langgraph.graph import StateGraph, START, END
13
+ from langgraph.graph.message import add_messages
14
+ from langchain_openai import ChatOpenAI as Chat
15
+
16
+ from uuid import uuid4
17
+
18
# Configure root logging once at import time; every module logger inherits this.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s [%(levelname)s] %(message)s",
)
logger = logging.getLogger(__name__)

# Load environment variables
load_dotenv()

# LangGraph setup
# OPENAI_API_KEY has no fallback: Chat() below will fail if it is unset.
openai_api_key = os.getenv("OPENAI_API_KEY")
model = os.getenv("OPENAI_MODEL", "gpt-4")  # model name, overridable via env
temperature = float(os.getenv("OPENAI_TEMPERATURE", 0))  # 0.0 => most deterministic output

# Tavily web-search tool, capped at 2 results per query; the only tool exposed to the LLM.
web_search = TavilySearchResults(max_results=2)
tools = [web_search]
34
+
35
+
36
class State(TypedDict):
    """Graph state schema: a single list of conversation messages."""
    # add_messages is a LangGraph reducer: new messages are appended/merged
    # into the existing list instead of replacing it.
    messages: Annotated[list, add_messages]
38
+
39
+
40
# Graph builder over the State schema defined above.
graph_builder = StateGraph(State)


# Chat model configured from environment; binding the tools lets the LLM
# emit structured tool calls that the tool node can execute.
llm = Chat(
    openai_api_key=openai_api_key,
    model=model,
    temperature=temperature
)
llm_with_tools = llm.bind_tools(tools)
49
+
50
+
51
def chatbot(state: State):
    """Run the tool-bound LLM over the conversation and emit its reply."""
    reply = llm_with_tools.invoke(state["messages"])
    return {"messages": [reply]}


graph_builder.add_node("chatbot", chatbot)
56
+
57
+
58
class BasicToolNode:
    """A node that runs the tools requested in the last AIMessage."""

    def __init__(self, tools: list) -> None:
        # Index tools by their registered name for O(1) dispatch in __call__.
        self.tools_by_name = {t.name: t for t in tools}

    def __call__(self, inputs: dict):
        """Execute each tool call on the newest message; return ToolMessages."""
        messages = inputs.get("messages", [])
        if not messages:
            raise ValueError("No message found in input")
        latest = messages[-1]

        results = []
        for call in latest.tool_calls:
            tool = self.tools_by_name[call["name"]]
            payload = tool.invoke(call["args"])
            # JSON-encode the raw tool output so it travels as message content.
            results.append(
                ToolMessage(
                    content=json.dumps(payload),
                    name=call["name"],
                    tool_call_id=call["id"],
                )
            )
        return {"messages": results}
82
+
83
+
84
def route_tools(
    state: State,
):
    """Conditional-edge router: go to the tool node when the newest message
    carries tool calls, otherwise terminate the graph run."""
    # State may arrive either as a bare message list or as a State mapping.
    if isinstance(state, list):
        ai_message = state[-1]
    else:
        messages = state.get("messages", [])
        if not messages:
            raise ValueError(
                f"No messages found in input state to tool_edge: {state}")
        ai_message = messages[-1]

    pending_calls = getattr(ai_message, "tool_calls", None)
    return "tools" if pending_calls else END
101
+
102
+
103
# Register the tool-execution node and wire the graph:
# START -> chatbot -> (tools | END), with tools always looping back to chatbot.
tool_node = BasicToolNode(tools=[web_search])
graph_builder.add_node("tools", tool_node)
graph_builder.add_conditional_edges(
    "chatbot",
    route_tools,
    # The following dictionary lets you tell the graph to interpret the condition's outputs as a specific node
    # It defaults to the identity function, but if you
    # want to use a node named something else apart from "tools",
    # You can update the value of the dictionary to something else
    # e.g., "tools": "my_tools"
    {"tools": "tools", END: END},
)
# Any time a tool is called, we return to the chatbot to decide the next step
graph_builder.add_edge("tools", "chatbot")
graph_builder.add_edge(START, "chatbot")
118
+
119
+
120
# NOTE(review): this REDEFINES `chatbot`, but graph_builder.add_node("chatbot", ...)
# above already captured the earlier function object, so the compiled graph keeps
# calling the first implementation — this richer version is never invoked by the
# graph. Consider keeping a single chatbot definition placed before add_node.
def chatbot(state: State):
    """Produce the next AI reply for the current conversation state.

    Returns a canned greeting for an empty message list; otherwise invokes
    the plain LLM, or the tool-bound LLM when the last message already
    carries tool calls, and wraps the result as an AIMessage.
    """
    if not state["messages"]:
        logger.info(
            "Received an empty message list. Returning default response.")
        return {"messages": [AIMessage(content="Hello! How can I assist you today?")]}

    # Check for tool call in the last message
    last_message = state["messages"][-1]
    if not getattr(last_message, "tool_calls", None):
        logger.info(
            "No tool call in the last message. Proceeding without tool invocation.")
        response = llm.invoke(state["messages"])
    else:
        logger.info(
            "Tool call detected in the last message. Invoking tool response.")
        response = llm_with_tools.invoke(state["messages"])

    # Ensure the response is wrapped as AIMessage if it's not already
    # (assumes any non-AIMessage response still exposes `.content` — TODO confirm)
    if not isinstance(response, AIMessage):
        response = AIMessage(content=response.content)

    return {"messages": [response]}
142
+
143
+
144
# Compile the wired builder into the runnable graph used by gradio_chat.
graph = graph_builder.compile()
145
+
146
+
147
def gradio_chat(message, history):
    """Gradio callback: run one user turn through the compiled graph.

    Args:
        message: The user's input; coerced to str if Gradio hands us
            something else.
        history: Chat history supplied by Gradio (unused here — each call
            sends only the new message into the graph).

    Returns:
        The content of the last AIMessage produced by the graph, or a
        friendly error string if anything fails.
    """
    try:
        if not isinstance(message, str):
            message = str(message)

        config = {
            "configurable": {"thread_id": "1"},
            "checkpoint_id": str(uuid4()),
            "recursion_limit": 300
        }

        # Format the user message correctly as a HumanMessage
        formatted_message = [HumanMessage(content=message)]
        response = graph.invoke(
            {
                "messages": formatted_message
            },
            config=config,
            stream_mode="values"
        )

        # Extract assistant messages and ensure they are AIMessage type
        assistant_messages = [
            msg for msg in response["messages"] if isinstance(msg, AIMessage)
        ]
        last_message = assistant_messages[-1] if assistant_messages else AIMessage(
            content="No response generated.")

        logger.info("Sending response back to Gradio interface.")
        return last_message.content
    except Exception as e:
        # FIX: logger.error(f"...") discarded the traceback and formatted
        # eagerly; logger.exception records the stack trace and uses lazy
        # %-style args, which is the logging-module convention.
        logger.exception("Error encountered in gradio_chat: %s", e)
        return "Sorry, I encountered an error. Please try again."
180
+
181
+
182
# Build the Gradio UI. NOTE: the module-level name `chatbot` is reused here
# for the ChatInterface widget (it previously named a graph-node function).
with gr.Blocks(theme=gr.themes.Default()) as demo:
    chatbot = gr.ChatInterface(
        chatbot=gr.Chatbot(height=800, render=False),
        fn=gradio_chat,
        multimodal=False,
        title="LangGraph Agentic Chatbot",
        examples=[
            "What's the weather like today?",
            "Show me the Movie Trailer for Doctor Strange.",
            "Give me the latest news on the COVID-19 pandemic.",
            # FIX: typo "updtaes" -> "updates" in the user-facing example.
            "What are the latest updates on NVIDIA's new GPU?",
        ],
    )

if __name__ == "__main__":
    demo.launch()