import os
from typing import Callable, Dict, List, Union, TypedDict

import gradio as gr
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_community.tools import DuckDuckGoSearchRun
from langchain_core.messages import AIMessage, BaseMessage, ToolMessage
from langchain_core.pydantic_v1 import BaseModel
from langgraph.checkpoint.sqlite import SqliteSaver  # noqa: F401 -- currently unused; retained for checkpointing support


class Node:
    """A named, executable step in the conversation graph."""

    def __init__(self, id: str, function: Callable):
        """
        Args:
            id: Unique identifier for the node.
            function: State-transforming callable run when the node executes.
        """
        self.id = id
        self.function = function

    def execute(self, state: Dict) -> Dict:
        """Run this node's function against *state* and return the new state."""
        return self.function(state)


class Edge:
    """A directed, optionally conditional link between two graph nodes."""

    def __init__(self, source: str, target: str,
                 condition: Callable[[Dict], bool] = None):
        """
        Args:
            source: ID of the node the edge leaves.
            target: ID of the node the edge enters.
            condition: Optional predicate over the state; when omitted the
                edge is unconditionally traversable.
        """
        self.source = source
        self.target = target
        self.condition = condition

    def is_active(self, state: Dict) -> bool:
        """Return True when this edge may be traversed for *state*."""
        return self.condition(state) if self.condition else True


class Graph:
    """A minimal directed-graph executor: nodes transform state, edges route it."""

    def __init__(self):
        """Create an empty graph."""
        self.nodes: Dict[str, Node] = {}
        # Out-edges keyed by source node id, in insertion order.
        self.edges: Dict[str, List[Edge]] = {}

    def add_node(self, node: Node):
        """Register *node* under its id."""
        self.nodes[node.id] = node

    def add_edge(self, edge: Edge):
        """Register *edge* under its source node id."""
        self.edges.setdefault(edge.source, []).append(edge)

    def get_next_node(self, current_node_id: str, state: Dict) -> Union[Node, None]:
        """Return the target of the first active out-edge, or None to stop."""
        for edge in self.edges.get(current_node_id, []):
            if edge.is_active(state):
                return self.nodes[edge.target]
        return None

    def execute(self, start_node_id: str, state: Dict) -> Dict:
        """Run nodes starting at *start_node_id* until no out-edge is active.

        NOTE(review): there is no cycle guard here -- the edge wiring at the
        bottom of this module must guarantee that traversal terminates.
        """
        current_node = self.nodes.get(start_node_id)
        while current_node:
            state = current_node.execute(state)
            next_node = self.get_next_node(current_node.id, state)
            if next_node is None:
                break
            current_node = next_node
        return state


class State(TypedDict):
    """Conversation state threaded through the graph."""

    # Full chat transcript: raw role/content dicts plus LangChain messages.
    messages: List[Union[Dict, BaseMessage, ToolMessage]]
    # True when the model has requested escalation to a human.
    ask_human: bool


class RequestAssistance(BaseModel):
    """Escalate the conversation to a human operator.

    Bound to the LLM as a tool; this docstring doubles as the tool
    description the model sees.
    """

    request: str


def chatbot_function(state: State) -> State:
    """Invoke the LLM on the transcript and resolve any tool call inline.

    A ``RequestAssistance`` call flips ``ask_human``; any other tool call is
    treated as a web search and its result replaces the response content.

    Args:
        state: Current conversation state.

    Returns:
        A new state with the model's response appended.
    """
    response = llm_with_tools.invoke(state["messages"])
    ask_human = False
    if response.tool_calls:
        # Only the first tool call is honoured.
        first_call = response.tool_calls[0]
        if first_call.get("name") == "RequestAssistance":
            ask_human = True
        else:
            # Assumes the call args carry a "query" key (DuckDuckGo search
            # schema). The tool output is surfaced directly as the reply.
            search_result = DuckDuck_tool.run(first_call["args"]["query"])
            response.content = search_result
    return {"messages": state["messages"] + [response], "ask_human": ask_human}


def create_response(response: str, ai_message: AIMessage) -> ToolMessage:
    """Wrap *response* as a ToolMessage answering *ai_message*'s first tool call.

    Args:
        response: Text content for the tool message.
        ai_message: The AI message whose first tool call is being answered.

    Returns:
        The constructed ToolMessage.
    """
    return ToolMessage(
        content=response,
        tool_call_id=ai_message.tool_calls[0].get("id"),
    )


def human_node_function(state: State) -> State:
    """Handle the human-assistance step.

    If no human reply (a ToolMessage) was appended, fabricate a placeholder
    so the pending tool call is still answered, then clear ``ask_human``.

    Args:
        state: Current conversation state.

    Returns:
        The updated state with ``ask_human`` cleared.
    """
    # FIX: copy before appending -- the original appended to the caller's
    # shared ``state["messages"]`` list in place.
    new_messages = list(state["messages"])
    if new_messages and not isinstance(new_messages[-1], ToolMessage):
        new_messages.append(
            create_response("No response from human.", new_messages[-1])
        )
    return {"messages": new_messages, "ask_human": False}


def tools_condition(state: State) -> str:
    """Return "tools" if any AI message carries tool calls, else "chatbot".

    NOTE(review): not referenced by the wiring below; retained as a routing
    helper for callers that want conditional dispatch.
    """
    for message in state["messages"]:
        if isinstance(message, AIMessage) and message.tool_calls:
            return "tools"
    return "chatbot"


def tool_node_function(state: State) -> State:
    """Execute the search tool for the most recent pending AI tool call.

    FIX: the original scanned the whole transcript and re-ran *every*
    historical tool call on each visit, duplicating ToolMessages; only the
    latest AI message with tool calls is executed here.

    Args:
        state: Current conversation state.

    Returns:
        The updated state with the tool's ToolMessage appended.
    """
    # Copy to avoid mutating the caller's list in place.
    new_messages = list(state["messages"])
    for message in reversed(new_messages):
        if isinstance(message, AIMessage) and message.tool_calls:
            tool_response = DuckDuck_tool.run(message.tool_calls[0]["args"]["query"])
            new_messages.append(create_response(tool_response, message))
            break
    return {"messages": new_messages, "ask_human": False}


def format_message(msg: Union[Dict, BaseMessage, ToolMessage]) -> Dict[str, str]:
    """Normalise a message into a ``{"role", "content"}`` dict for display.

    AIMessages map to "assistant"; every other message object maps to "user".
    """
    if isinstance(msg, dict):
        return {"role": msg["role"], "content": msg["content"]}
    role = "assistant" if isinstance(msg, AIMessage) else "user"
    return {"role": role, "content": msg.content}


def update_chat(message: str, chatbot_state: Dict) -> List[List[str]]:
    """Append the user's message, run the graph, and return display pairs.

    Args:
        message: The user's new chat message.
        chatbot_state: Mutable per-session state held by Gradio.

    Returns:
        ``[role, content]`` pairs for the Gradio Chatbot component.
    """
    chatbot_state["messages"].append({"role": "user", "content": message})
    new_state = graph.execute("chatbot", chatbot_state)
    chatbot_state["messages"] = new_state["messages"]
    chatbot_state["ask_human"] = new_state["ask_human"]
    formatted = [format_message(msg) for msg in chatbot_state["messages"]]
    return [[m["role"], m["content"]] for m in formatted]


def init_chatbot() -> Dict:
    """Return a fresh, empty chatbot state."""
    return {"messages": [], "ask_human": False}


# --- Model and tools --------------------------------------------------------
DuckDuck_tool = DuckDuckGoSearchRun()
toolset = [DuckDuck_tool]

# FIX: the original created a bare model under the name ``llm_with_tools``
# without ever binding tools, so ``response.tool_calls`` could never be
# populated. Bind the search tool plus the RequestAssistance schema.
llm_with_tools = ChatGoogleGenerativeAI(
    model="gemini-1.5-flash",
    google_api_key=os.getenv("GOOGLE_API_KEY"),
).bind_tools(toolset + [RequestAssistance])

# --- Graph wiring ------------------------------------------------------------
graph = Graph()
graph.add_node(Node("chatbot", chatbot_function))
graph.add_node(Node("toolset", tool_node_function))
graph.add_node(Node("human", human_node_function))
# FIX: the original wired chatbot -> toolset whenever ask_human was False and
# toolset -> chatbot unconditionally, so Graph.execute() ping-ponged forever;
# the "human" node was also unreachable. chatbot_function resolves search
# calls inline, so the only onward route needed is escalation to the human
# node when the model asked for assistance.
graph.add_edge(Edge("chatbot", "human", lambda state: state.get("ask_human", False)))
graph.add_edge(Edge("human", "chatbot"))
graph.add_edge(Edge("toolset", "chatbot"))

# --- UI ----------------------------------------------------------------------
with gr.Blocks() as iface:
    chatbot_state = gr.State(init_chatbot())
    with gr.Row():
        with gr.Column():
            user_input = gr.Textbox(label="Your message")
            send_button = gr.Button("Send")
            chat_output = gr.Chatbot(label="Chatbot conversation")
    send_button.click(
        update_chat,
        inputs=[user_input, chatbot_state],
        outputs=[chat_output],
    )

if __name__ == "__main__":
    # FIX: guard launch so importing this module no longer starts a server.
    iface.launch()