cstixx committed on
Commit b63562f
1 Parent(s): c77d99b

Upload app.py with huggingface_hub

Files changed (1)
  1. app.py +312 -0
app.py ADDED
@@ -0,0 +1,312 @@
+ import os
+ from typing import Dict, Callable, List, Union, TypedDict
+
+ import gradio as gr
+ from langchain_google_genai import ChatGoogleGenerativeAI
+ from langchain_community.tools import DuckDuckGoSearchRun
+ from langchain_core.messages import AIMessage, BaseMessage, ToolMessage
+ from langchain_core.pydantic_v1 import BaseModel
+
+ class Node:
+     def __init__(self, id: str, function: Callable):
+         """
+         Initialize a Node with an ID and a function to execute.
+
+         Args:
+             id (str): The unique identifier for the node.
+             function (Callable): The function to execute for this node.
+         """
+         self.id = id
+         self.function = function
+
+     def execute(self, state: Dict) -> Dict:
+         """
+         Execute the node's function with the given state.
+
+         Args:
+             state (Dict): The current state of the chatbot.
+
+         Returns:
+             Dict: The updated state after executing the node's function.
+         """
+         return self.function(state)
+
+ class Edge:
+     def __init__(self, source: str, target: str, condition: Callable[[Dict], bool] = None):
+         """
+         Initialize an Edge with a source node, target node, and an optional condition.
+
+         Args:
+             source (str): The ID of the source node.
+             target (str): The ID of the target node.
+             condition (Callable[[Dict], bool], optional): A condition function that determines if the edge should be traversed.
+         """
+         self.source = source
+         self.target = target
+         self.condition = condition
+
+     def is_active(self, state: Dict) -> bool:
+         """
+         Check if the edge is active based on the given state.
+
+         Args:
+             state (Dict): The current state of the chatbot.
+
+         Returns:
+             bool: True if the edge is active, False otherwise.
+         """
+         if self.condition:
+             return self.condition(state)
+         return True
+
+ class Graph:
+     def __init__(self):
+         """
+         Initialize an empty Graph with dictionaries to hold nodes and edges.
+         """
+         self.nodes = {}
+         self.edges = {}
+
+     def add_node(self, node: Node):
+         """
+         Add a node to the graph.
+
+         Args:
+             node (Node): The node to add.
+         """
+         self.nodes[node.id] = node
+
+     def add_edge(self, edge: Edge):
+         """
+         Add an edge to the graph.
+
+         Args:
+             edge (Edge): The edge to add.
+         """
+         if edge.source not in self.edges:
+             self.edges[edge.source] = []
+         self.edges[edge.source].append(edge)
+
+     def get_next_node(self, current_node_id: str, state: Dict) -> Union[Node, None]:
+         """
+         Get the next node to traverse to based on the current state.
+
+         Args:
+             current_node_id (str): The ID of the current node.
+             state (Dict): The current state of the chatbot.
+
+         Returns:
+             Union[Node, None]: The next node to traverse to, or None if no valid edge is found.
+         """
+         if current_node_id in self.edges:
+             for edge in self.edges[current_node_id]:
+                 if edge.is_active(state):
+                     return self.nodes[edge.target]
+         return None
+
+     def execute(self, start_node_id: str, state: Dict) -> Dict:
+         """
+         Execute the graph starting from the specified node.
+
+         Args:
+             start_node_id (str): The ID of the starting node.
+             state (Dict): The initial state of the chatbot.
+
+         Returns:
+             Dict: The final state after traversing the graph.
+         """
+         current_node = self.nodes.get(start_node_id)
+         while current_node:
+             state = current_node.execute(state)
+             next_node = self.get_next_node(current_node.id, state)
+             if next_node is None:
+                 break
+             current_node = next_node
+         return state
+
+ class State(TypedDict):
+     """
+     Define the State type using TypedDict to specify the structure of the state dictionary.
+     """
+     messages: List[Union[Dict, BaseMessage, ToolMessage]]
+     ask_human: bool
+
+ class RequestAssistance(BaseModel):
+     """
+     Escalate the conversation to a human operator. This docstring doubles as the
+     tool description when the schema is bound to the model.
+     """
+     request: str
+
+ def chatbot_function(state: State) -> State:
+     """
+     Chatbot function definition which processes the current state and generates a response.
+
+     Args:
+         state (State): The current state of the chatbot including messages and ask_human flag.
+
+     Returns:
+         State: The updated state after processing the response.
+     """
+     response = llm_with_tools.invoke(state["messages"])
+     ask_human = False
+
+     if response.tool_calls:
+         tool_name = response.tool_calls[0].get("name")
+         if tool_name == "RequestAssistance":
+             # Route to the human node; any other tool call is handled by the tool node.
+             ask_human = True
+
+     new_state = {"messages": state["messages"] + [response], "ask_human": ask_human}
+     return new_state
+
+ def create_response(response: str, ai_message: AIMessage) -> ToolMessage:
+     """
+     Create a ToolMessage from a given response and AI message.
+
+     Args:
+         response (str): The response content to be included in the ToolMessage.
+         ai_message (AIMessage): The original AI message containing tool call information.
+
+     Returns:
+         ToolMessage: The created ToolMessage.
+     """
+     return ToolMessage(content=response, tool_call_id=ai_message.tool_calls[0].get("id"))
+
+ def human_node_function(state: State) -> State:
+     """
+     Process the state if human assistance is required.
+
+     Args:
+         state (State): The current state of the chatbot including messages and ask_human flag.
+
+     Returns:
+         State: The updated state after processing human assistance.
+     """
+     new_messages = list(state["messages"])
+
+     if new_messages and not isinstance(new_messages[-1], ToolMessage):
+         # No human is wired in, so answer the pending tool call with a placeholder.
+         new_response = create_response("No response from human.", new_messages[-1])
+         new_messages.append(new_response)
+
+     new_state = {"messages": new_messages, "ask_human": False}
+     return new_state
+
+ def tools_condition(state: State) -> str:
+     """
+     Determine the next node in the state graph based on the current state.
+
+     Args:
+         state (State): The current state of the chatbot including messages and ask_human flag.
+
+     Returns:
+         str: The identifier of the next node to process.
+     """
+     # Define your condition to choose the next node here
+     # Example: Check if the state contains a specific tool call
+     for message in state["messages"]:
+         if isinstance(message, AIMessage) and message.tool_calls:
+             return "tools"
+     return "chatbot"
+
+ def tool_node_function(state: State) -> State:
+     """
+     Process the state by executing the appropriate tool function.
+
+     Args:
+         state (State): The current state of the chatbot including messages and ask_human flag.
+
+     Returns:
+         State: The updated state after processing the tool function.
+     """
+     new_messages = list(state["messages"])
+
+     # Only answer the most recent tool call; earlier ones already have ToolMessages.
+     last_message = new_messages[-1] if new_messages else None
+     if isinstance(last_message, AIMessage) and last_message.tool_calls:
+         tool_response = DuckDuck_tool.run(last_message.tool_calls[0]["args"]["query"])
+         new_response = create_response(tool_response, last_message)
+         new_messages.append(new_response)
+
+     new_state = {"messages": new_messages, "ask_human": False}
+     return new_state
+
+ def format_message(msg: Union[Dict, BaseMessage, ToolMessage]) -> Dict[str, str]:
+     """
+     Format a message for display in the chat.
+
+     Args:
+         msg (Union[Dict, BaseMessage, ToolMessage]): The message to be formatted.
+
+     Returns:
+         Dict[str, str]: The formatted message as a dictionary with role and content.
+     """
+     if isinstance(msg, dict):
+         formatted_msg = {"role": msg["role"], "content": msg["content"]}
+     else:
+         # Show model replies and tool results on the assistant side of the chat.
+         role = "assistant" if isinstance(msg, (AIMessage, ToolMessage)) else "user"
+         formatted_msg = {"role": role, "content": msg.content}
+     return formatted_msg
+
+ def update_chat(message: str, chatbot_state: Dict) -> List[List[str]]:
+     """
+     Update the chat with a new user message and process it through the chatbot.
+
+     Args:
+         message (str): The user's message to be added to the chat.
+         chatbot_state (Dict): The current state of the chatbot.
+
+     Returns:
+         List[List[str]]: The conversation as [user, assistant] pairs for gr.Chatbot.
+     """
+     chatbot_state["messages"].append({"role": "user", "content": message})
+
+     new_state = graph.execute("chatbot", chatbot_state)
+     chatbot_state["messages"] = new_state["messages"]
+     chatbot_state["ask_human"] = new_state["ask_human"]
+
+     formatted_messages = [format_message(msg) for msg in chatbot_state["messages"]]
+
+     # gr.Chatbot expects [user, assistant] pairs, so fold consecutive assistant
+     # messages (model reply, tool result) into the current user turn.
+     pairs: List[List[str]] = []
+     for msg in formatted_messages:
+         content = msg["content"]
+         if not content:
+             continue  # skip tool-call messages with empty text content
+         if msg["role"] == "user":
+             pairs.append([content, None])
+         elif pairs and pairs[-1][1] is None:
+             pairs[-1][1] = content
+         elif pairs:
+             pairs[-1][1] += "\n\n" + content
+         else:
+             pairs.append([None, content])
+     return pairs
+
+ def init_chatbot() -> Dict:
+     """
+     Initialize the chatbot with an empty state.
+
+     Returns:
+         Dict: The initial state of the chatbot.
+     """
+     initial_state = {"messages": [], "ask_human": False}
+     return initial_state
+
+ # Initialize the tools and the chatbot model, binding the tools so the model
+ # can emit structured tool calls (search or RequestAssistance).
+ DuckDuck_tool = DuckDuckGoSearchRun()
+ toolset = [DuckDuck_tool]
+ llm = ChatGoogleGenerativeAI(model="gemini-1.5-flash", google_api_key=os.getenv("GOOGLE_API_KEY"))
+ llm_with_tools = llm.bind_tools(toolset + [RequestAssistance])
+
+ # Initialize the graph and add nodes and edges. The chatbot hands off to the
+ # human node when it asked for assistance, to the tool node when the model made
+ # any other tool call, and the run ends when neither edge is active.
+ graph = Graph()
+ graph.add_node(Node("chatbot", chatbot_function))
+ graph.add_node(Node("toolset", tool_node_function))
+ graph.add_node(Node("human", human_node_function))
+
+ graph.add_edge(Edge("chatbot", "human", lambda state: state.get("ask_human", False)))
+ graph.add_edge(Edge("chatbot", "toolset",
+                     lambda state: not state.get("ask_human", False)
+                     and isinstance(state["messages"][-1], AIMessage)
+                     and bool(state["messages"][-1].tool_calls)))
+ graph.add_edge(Edge("toolset", "chatbot"))
+ graph.add_edge(Edge("human", "chatbot"))
+
+ # Initialize the Gradio interface
+ with gr.Blocks() as iface:
+     chatbot_state = gr.State(init_chatbot())
+
+     with gr.Row():
+         with gr.Column():
+             user_input = gr.Textbox(label="Your message")
+             send_button = gr.Button("Send")
+             chat_output = gr.Chatbot(label="Chatbot conversation")
+
+     send_button.click(update_chat, inputs=[user_input, chatbot_state], outputs=[chat_output])
+
+ iface.launch()
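For a quick local check outside the Gradio UI, the graph can be exercised directly with the helpers defined in app.py. This is a minimal sketch and not part of the commit; it assumes the dependencies are installed, that GOOGLE_API_KEY is set in the environment, and the question text is only an example.

# Minimal smoke test (illustrative only, not part of app.py); run after the
# definitions above with GOOGLE_API_KEY exported.
state = init_chatbot()
state["messages"].append({"role": "user", "content": "What is LangGraph?"})  # example question
final_state = graph.execute("chatbot", state)
for msg in final_state["messages"]:
    print(format_message(msg))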