|
import os |
|
import uuid |
|
import logging |
|
from dotenv import load_dotenv |
|
import json |
|
import gradio as gr |
|
from langchain_community.tools.tavily_search import TavilySearchResults |
|
from langchain_core.messages import AIMessage, HumanMessage |
|
from typing_extensions import TypedDict |
|
from typing import Annotated |
|
from langchain_core.messages import ToolMessage |
|
from langgraph.graph import StateGraph, START, END |
|
from langgraph.graph.message import add_messages |
|
from langchain_openai import ChatOpenAI as Chat |
|
|
|
from uuid import uuid4 |
|
|
|
logging.basicConfig( |
|
level=logging.INFO, |
|
format="%(asctime)s [%(levelname)s] %(message)s", |
|
) |
|
logger = logging.getLogger(__name__) |
|
|
|
|
|
load_dotenv() |
|
|
|
|
|
openai_api_key = os.getenv("OPENAI_API_KEY") |
|
model = os.getenv("OPENAI_MODEL", "gpt-4") |
|
temperature = float(os.getenv("OPENAI_TEMPERATURE", 0)) |
|
|
|
web_search = TavilySearchResults(max_results=2) |
|
tools = [web_search] |
|
|
|
|
|
class State(TypedDict): |
|
messages: Annotated[list, add_messages] |
|
|
|
|
|
graph_builder = StateGraph(State) |
|
|
|
|
|
llm = Chat( |
|
openai_api_key=openai_api_key, |
|
model=model, |
|
temperature=temperature |
|
) |
|
llm_with_tools = llm.bind_tools(tools) |
|
|
|
|
|
def chatbot(state: State): |
|
return {"messages": [llm_with_tools.invoke(state["messages"])]} |
|
|
|
|
|
graph_builder.add_node("chatbot", chatbot) |
|
|
|
|
|
class BasicToolNode: |
|
"""A node that runs the tools requested in the last AIMessage.""" |
|
|
|
def __init__(self, tools: list) -> None: |
|
self.tools_by_name = {tool.name: tool for tool in tools} |
|
|
|
def __call__(self, inputs: dict): |
|
if messages := inputs.get("messages", []): |
|
message = messages[-1] |
|
else: |
|
raise ValueError("No message found in input") |
|
outputs = [] |
|
for tool_call in message.tool_calls: |
|
tool_result = self.tools_by_name[tool_call["name"]].invoke( |
|
tool_call["args"] |
|
) |
|
outputs.append( |
|
ToolMessage( |
|
content=json.dumps(tool_result), |
|
name=tool_call["name"], |
|
tool_call_id=tool_call["id"], |
|
) |
|
) |
|
return {"messages": outputs} |
|
|
|
|
|
def route_tools( |
|
state: State, |
|
): |
|
""" |
|
Use in the conditional_edge to route to the ToolNode if the last message |
|
has tool calls. Otherwise, route to the end. |
|
""" |
|
if isinstance(state, list): |
|
ai_message = state[-1] |
|
elif messages := state.get("messages", []): |
|
ai_message = messages[-1] |
|
else: |
|
raise ValueError( |
|
f"No messages found in input state to tool_edge: {state}") |
|
if hasattr(ai_message, "tool_calls") and len(ai_message.tool_calls) > 0: |
|
return "tools" |
|
return END |
|
|
|
|
|
tool_node = BasicToolNode(tools=[web_search]) |
|
graph_builder.add_node("tools", tool_node) |
|
graph_builder.add_conditional_edges( |
|
"chatbot", |
|
route_tools, |
|
|
|
|
|
|
|
|
|
|
|
{"tools": "tools", END: END}, |
|
) |
|
|
|
graph_builder.add_edge("tools", "chatbot") |
|
graph_builder.add_edge(START, "chatbot") |
|
|
|
|
|
def chatbot(state: State): |
|
if not state["messages"]: |
|
logger.info( |
|
"Received an empty message list. Returning default response.") |
|
return {"messages": [AIMessage(content="Hello! How can I assist you today?")]} |
|
|
|
|
|
last_message = state["messages"][-1] |
|
if not getattr(last_message, "tool_calls", None): |
|
logger.info( |
|
"No tool call in the last message. Proceeding without tool invocation.") |
|
response = llm.invoke(state["messages"]) |
|
else: |
|
logger.info( |
|
"Tool call detected in the last message. Invoking tool response.") |
|
response = llm_with_tools.invoke(state["messages"]) |
|
|
|
|
|
if not isinstance(response, AIMessage): |
|
response = AIMessage(content=response.content) |
|
|
|
return {"messages": [response]} |
|
|
|
|
|
graph = graph_builder.compile() |
|
|
|
|
|
def gradio_chat(message, history): |
|
try: |
|
if not isinstance(message, str): |
|
message = str(message) |
|
|
|
config = { |
|
"configurable": {"thread_id": "1"}, |
|
"checkpoint_id": str(uuid4()), |
|
"recursion_limit": 300 |
|
} |
|
|
|
|
|
formatted_message = [HumanMessage(content=message)] |
|
response = graph.invoke( |
|
{ |
|
"messages": formatted_message |
|
}, |
|
config=config, |
|
stream_mode="values" |
|
) |
|
|
|
|
|
assistant_messages = [ |
|
msg for msg in response["messages"] if isinstance(msg, AIMessage) |
|
] |
|
last_message = assistant_messages[-1] if assistant_messages else AIMessage( |
|
content="No response generated.") |
|
|
|
logger.info("Sending response back to Gradio interface.") |
|
return last_message.content |
|
except Exception as e: |
|
logger.error(f"Error encountered in gradio_chat: {e}") |
|
return "Sorry, I encountered an error. Please try again." |
|
|
|
|
|
with gr.Blocks(theme=gr.themes.Default()) as demo: |
|
chatbot = gr.ChatInterface( |
|
chatbot=gr.Chatbot(height=800, render=False), |
|
fn=gradio_chat, |
|
multimodal=False, |
|
title="LangGraph Agentic Chatbot", |
|
examples=[ |
|
"What is the capital of France?", |
|
"Show me the Movie Trailer for Doctor Strange.", |
|
"Give me the latest news on the COVID-19 pandemic.", |
|
"What are the latest updates on NVIDIA's new GPU?", |
|
], |
|
) |
|
|
|
if __name__ == "__main__": |
|
demo.launch() |
|
|