# Source snapshot: commit 252375c (6,078 bytes).
import os
import uuid
import logging
from dotenv import load_dotenv
import json
import gradio as gr
from langchain_community.tools.tavily_search import TavilySearchResults
from langchain_core.messages import AIMessage, HumanMessage
from typing_extensions import TypedDict
from typing import Annotated
from langchain_core.messages import ToolMessage
from langgraph.graph import StateGraph, START, END
from langgraph.graph.message import add_messages
from langchain_openai import ChatOpenAI as Chat
from uuid import uuid4
_LOG_FORMAT = "%(asctime)s [%(levelname)s] %(message)s"

# Root logging: INFO and above, each line prefixed with a timestamp and level.
logging.basicConfig(format=_LOG_FORMAT, level=logging.INFO)

# Module-level logger used by the functions defined below.
logger = logging.getLogger(__name__)
# Pull variables from a .env file into the process environment before any of
# the settings below are read.
load_dotenv()

# OpenAI chat-model settings. The model name falls back to "gpt-4" and the
# sampling temperature to 0 when the corresponding variables are unset.
openai_api_key = os.environ.get("OPENAI_API_KEY")
model = os.environ.get("OPENAI_MODEL", "gpt-4")
temperature = float(os.environ.get("OPENAI_TEMPERATURE", "0"))

# Tavily web search, limited to two results per query, is the only tool the
# model is given.
web_search = TavilySearchResults(max_results=2)
tools = [web_search]
class State(TypedDict):
    """State carried between the nodes of the LangGraph graph.

    Attributes:
        messages: The conversation so far. It is annotated with the
            ``add_messages`` reducer, so message lists returned by nodes are
            merged into (appended to) this list rather than replacing it.
    """

    messages: Annotated[list, add_messages]
# Builder for the agent graph; nodes and edges are registered on it below.
graph_builder = StateGraph(State)

# Chat model built from the environment settings, plus a copy that has the
# tool schemas bound so the model can request tool calls.
llm = Chat(
    model=model,
    temperature=temperature,
    openai_api_key=openai_api_key,
)
llm_with_tools = llm.bind_tools(tools)
def chatbot(state: State):
    """Graph node: ask the tool-enabled model for the next assistant message.

    The reply, which may include tool-call requests, is returned under
    ``"messages"`` so the state's reducer adds it to the conversation.
    """
    conversation = state["messages"]
    reply = llm_with_tools.invoke(conversation)
    return {"messages": [reply]}


# Register the function above as the "chatbot" node of the graph.
graph_builder.add_node("chatbot", chatbot)
class BasicToolNode:
    """Graph node that executes every tool call on the most recent message.

    Each requested call is dispatched by name, and its result is JSON-encoded
    into a ``ToolMessage`` that carries the originating ``tool_call_id``.
    """

    def __init__(self, tools: list) -> None:
        # Dispatch table keyed by each tool's ``name``; later duplicates win.
        self.tools_by_name = {}
        for tool in tools:
            self.tools_by_name[tool.name] = tool

    def __call__(self, inputs: dict):
        """Run the tool calls of ``inputs["messages"][-1]``.

        Raises:
            ValueError: if ``inputs`` has no (or an empty) ``"messages"`` list.
            KeyError: if a call names a tool that was not registered.
        """
        history = inputs.get("messages", [])
        if not history:
            raise ValueError("No message found in input")
        latest = history[-1]
        results = [self._run_tool_call(call) for call in latest.tool_calls]
        return {"messages": results}

    def _run_tool_call(self, call):
        """Invoke the tool named by ``call`` and wrap its result as a ToolMessage."""
        selected = self.tools_by_name[call["name"]]
        result = selected.invoke(call["args"])
        return ToolMessage(
            content=json.dumps(result),
            name=call["name"],
            tool_call_id=call["id"],
        )
def route_tools(state: State):
    """Conditional-edge router after the "chatbot" node.

    Returns ``"tools"`` when the latest message carries at least one tool
    call, and ``END`` otherwise. ``state`` may be a bare list of messages or
    a mapping with a ``"messages"`` list.

    Raises:
        ValueError: if a mapping state has no (or an empty) ``"messages"``.
    """
    if isinstance(state, list):
        last = state[-1]
    else:
        messages = state.get("messages", [])
        if not messages:
            raise ValueError(
                f"No messages found in input state to tool_edge: {state}")
        last = messages[-1]

    # Messages without a ``tool_calls`` attribute (e.g. user input) never
    # route to the tools node.
    if not hasattr(last, "tool_calls"):
        return END
    return "tools" if len(last.tool_calls) > 0 else END
# Node that executes whatever tool calls the model requested.
tool_node = BasicToolNode(tools=tools)
graph_builder.add_node("tools", tool_node)

# After each "chatbot" step, route_tools picks the next hop: "tools" when the
# reply requested tool calls, END otherwise. The path map ties those return
# values to node names; edit a value here to route to a differently named node.
graph_builder.add_conditional_edges(
    "chatbot",
    route_tools,
    {"tools": "tools", END: END},
)

# Tool output always flows back to the chatbot, which decides the next step.
graph_builder.add_edge("tools", "chatbot")
# Every run starts at the chatbot.
graph_builder.add_edge(START, "chatbot")
def chatbot(state: State):
    """Produce the assistant's next message for the current conversation.

    Returns a fixed greeting when the message list is empty. Otherwise it
    invokes the tool-enabled model, so the model itself decides whether to
    request a web search; ``route_tools`` then sends any tool calls to the
    "tools" node.

    NOTE(review): this definition runs after
    ``graph_builder.add_node("chatbot", chatbot)`` has already registered the
    earlier ``chatbot`` function, so the compiled graph never calls this one.
    Either move it above that registration (replacing the earlier definition)
    or delete it.
    """
    if not state["messages"]:
        logger.info(
            "Received an empty message list. Returning default response.")
        return {"messages": [AIMessage(content="Hello! How can I assist you today?")]}

    # Always offer the tools. The message arriving here is the user's input or
    # a tool result, and neither carries tool_calls. The previous code bound
    # tools only when that incoming message had tool_calls, so the model could
    # never request a search.
    logger.info("Invoking the model with tools bound.")
    response = llm_with_tools.invoke(state["messages"])
    # Defensive: normalise to AIMessage. Bound chat models already return one,
    # which keeps any tool_calls intact.
    if not isinstance(response, AIMessage):
        response = AIMessage(content=response.content)
    return {"messages": [response]}


# Compile the builder into a runnable graph (no checkpointer is passed).
graph = graph_builder.compile()
def _history_to_messages(history):
    """Convert Gradio chat history into LangChain messages.

    Gradio's ChatInterface passes prior turns either as ``[user, assistant]``
    pairs ("tuples" format) or as ``{"role": ..., "content": ...}`` dicts
    ("messages" format); both are accepted. Entries whose text is not a
    non-empty string (pending replies, file attachments, unknown shapes) and
    roles other than "user"/"assistant" are skipped.

    Args:
        history: Prior turns as supplied by Gradio, or ``None``.

    Returns:
        A list of ``HumanMessage``/``AIMessage`` objects in turn order.
    """
    converted = []
    for entry in history or []:
        if isinstance(entry, dict):
            content = entry.get("content")
            if not isinstance(content, str) or not content:
                continue
            role = entry.get("role")
            if role == "user":
                converted.append(HumanMessage(content=content))
            elif role == "assistant":
                converted.append(AIMessage(content=content))
        elif isinstance(entry, (list, tuple)):
            user_text = entry[0] if len(entry) > 0 else None
            bot_text = entry[1] if len(entry) > 1 else None
            if isinstance(user_text, str) and user_text:
                converted.append(HumanMessage(content=user_text))
            if isinstance(bot_text, str) and bot_text:
                converted.append(AIMessage(content=bot_text))
    return converted


def gradio_chat(message, history):
    """Gradio ChatInterface callback: run the graph on the conversation so far.

    Args:
        message: The latest user input; coerced to ``str`` if it is not one.
        history: Prior turns supplied by Gradio. They are replayed to the
            graph so the model sees the whole conversation.

    Returns:
        The content of the last ``AIMessage`` the graph produced for this
        turn, ``"No response generated."`` if it produced none, or an apology
        string if anything raised.
    """
    try:
        if not isinstance(message, str):
            message = str(message)

        # The graph is compiled without a checkpointer, so it keeps no state
        # between calls; thread/checkpoint ids in the config had no effect.
        # The conversation context is rebuilt from Gradio's history each turn.
        config = {"recursion_limit": 300}
        input_messages = _history_to_messages(history)
        input_messages.append(HumanMessage(content=message))

        response = graph.invoke(
            {"messages": input_messages},
            config=config,
            stream_mode="values",
        )

        # Only consider messages produced in this run, not replayed history.
        new_messages = response["messages"][len(input_messages):]
        assistant_messages = [
            msg for msg in new_messages if isinstance(msg, AIMessage)
        ]
        last_message = (
            assistant_messages[-1]
            if assistant_messages
            else AIMessage(content="No response generated.")
        )
        logger.info("Sending response back to Gradio interface.")
        return last_message.content
    except Exception:
        # UI boundary: keep the traceback in the log, show a friendly reply.
        logger.exception("Error encountered in gradio_chat")
        return "Sorry, I encountered an error. Please try again."
# Gradio UI: a chat interface whose callback is gradio_chat.
with gr.Blocks(theme=gr.themes.Default()) as demo:
    # NOTE(review): this rebinds the module-level name ``chatbot`` (previously
    # the graph node function) to the ChatInterface component; consider a
    # distinct name such as ``chat_interface``.
    chatbot = gr.ChatInterface(
        chatbot=gr.Chatbot(height=800, render=False),
        fn=gradio_chat,
        multimodal=False,
        title="LangGraph Agentic Chatbot",
        # Example prompts offered in the UI; several need live web search.
        examples=[
            "What's the weather like today?",
            "Show me the Movie Trailer for Doctor Strange.",
            "Give me the latest news on the COVID-19 pandemic.",
            "What are the latest updates on NVIDIA's new GPU?",
        ],
    )

if __name__ == "__main__":
    demo.launch()
|