Spaces:
Sleeping
Sleeping
| from typing import TypedDict, Annotated | |
| from langchain_core.messages import BaseMessage, SystemMessage | |
| from langgraph.checkpoint.memory import MemorySaver | |
| from langgraph.graph import StateGraph, START, END | |
| from langgraph.graph.message import add_messages | |
| from langgraph.prebuilt import ToolNode, tools_condition | |
| from langchain_openai import ChatOpenAI | |
| from dotenv import load_dotenv | |
| from tools import ( | |
| create_rag_tool, | |
| arxiv_search, | |
| wikipedia_search, | |
| tavily_search, | |
| ) | |
# Load variables from a .env file into the process environment
# (python-dotenv).
# NOTE(review): this runs *after* `from tools import ...` above; if that
# module reads API keys at import time, those keys will not come from .env —
# confirm, and move this call above the tools import if so.
load_dotenv()

# ===============================
# SYSTEM PROMPT
# ===============================
# Raw instruction text; kept separate from the message wrapper so the wording
# is easy to find and edit.
_SYSTEM_PROMPT_TEXT = """
You are an AI assistant using Retrieval-Augmented Generation.
If a document is uploaded, you MUST answer using it.
If no relevant info exists, clearly say so.
Never hallucinate document content.
"""

# System message prepended to every model call by the chat node.
SYSTEM_PROMPT = SystemMessage(content=_SYSTEM_PROMPT_TEXT)
# ===============================
# STATE
# ===============================
# Message list whose updates are merged by LangGraph's `add_messages` reducer
# rather than overwriting the previous value.
_MessageHistory = Annotated[list[BaseMessage], add_messages]


class ChatState(TypedDict):
    """State carried between graph nodes.

    Attributes:
        messages: The conversation so far. Nodes return only the *new*
            messages; the ``add_messages`` reducer merges them into this list.
    """

    messages: _MessageHistory
# ===============================
# LLM
# ===============================
# Base chat model. Note: the TOOLS section below rebinds `llm` to a
# tool-bound version of this model.
llm = ChatOpenAI(
    model="gpt-4.1-nano",
    streaming=True,  # request token streaming from the API
    temperature=0.3,  # fairly low sampling randomness
)
# ===============================
# TOOLS
# ===============================
# Built once, at import time.
# NOTE(review): if documents can be uploaded after startup, confirm that the
# tool returned by create_rag_tool() sees them — it is not recreated here.
rag_tool = create_rag_tool()

# Search tools imported from the local `tools` module.
_search_tools = (wikipedia_search, arxiv_search, tavily_search)
tools = [rag_tool, *_search_tools]

# The same list drives both halves of tool use: the model is told about
# these tools, and ToolNode executes the tool calls the model emits.
# `llm` is rebound here, so from this point on it is the tool-bound runnable.
llm = llm.bind_tools(tools)
tool_node = ToolNode(tools)
# ===============================
# CHAT NODE
# ===============================
def chatbot(state: ChatState):
    """Graph node: ask the model for the next assistant message.

    The system prompt is prepended on every call instead of being stored in
    the state, so it never accumulates in the checkpointed history.

    Args:
        state: Current graph state; ``state["messages"]`` holds the
            conversation so far.

    Returns:
        A partial state update ``{"messages": [reply]}``. The ``add_messages``
        reducer on ``ChatState`` merges ``reply`` into the history. The reply
        may carry tool calls, which the graph's routing sends to the tool node.
    """
    conversation = [SYSTEM_PROMPT, *state["messages"]]
    reply = llm.invoke(conversation)
    return {"messages": [reply]}
# ===============================
# GRAPH
# ===============================
# In-memory checkpointer: conversation state is kept per `thread_id` from the
# run config. Per the LangGraph docs it lives in process memory only, so
# nothing survives a restart.
memory = MemorySaver()

graph = StateGraph(ChatState)
graph.add_node("chat", chatbot)
# The node name "tools" matches what tools_condition routes to by default.
graph.add_node("tools", tool_node)

graph.add_edge(START, "chat")
# tools_condition returns "tools" when the latest AI message contains tool
# calls, and END otherwise, so it alone decides when a run finishes.
# There is deliberately no unconditional "chat" -> END edge: it would be
# redundant with this routing, and it makes the drawn graph suggest that
# "chat" always terminates.
graph.add_conditional_edges("chat", tools_condition)
# After the tools run, hand their results back to the model (agent loop).
graph.add_edge("tools", "chat")

app = graph.compile(checkpointer=memory)