# File: demo2/src/node/reactnode.py
# Revision: f68c145 ("Update src/node/reactnode.py", by Dinesh310)
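"""LangGraph nodes for a RAG workflow: retrieve documents, then generate an answer from them.

A ReAct-agent variant of these nodes is kept, commented out, at the end of this file.
"""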
from typing import List
from langchain_core.documents import Document
from langchain_core.prompts import ChatPromptTemplate
class RAGNodes:
    """Graph nodes for a LangGraph-based RAG workflow.

    Each public method takes the graph state (a dict) and returns a dict
    holding only the keys that node produces: ``retrieve`` produces
    ``"context"`` and ``generate`` produces ``"answer"``.

    Args:
        vector_store: Object exposing ``as_retriever(search_type=...,
            search_kwargs=...)``; the returned retriever must expose
            ``invoke(query)``.
        llm: Model that can be piped after a prompt (``prompt | llm``).
        search_type: Retrieval strategy forwarded to ``as_retriever``.
            Defaults to ``"mmr"``.
        search_kwargs: Extra retrieval options forwarded to
            ``as_retriever``. Defaults to ``{"k": 5, "lambda_mult": 0.25}``.
    """

    DEFAULT_SEARCH_TYPE = "mmr"
    # Never handed out directly; copied before use so it cannot be mutated.
    DEFAULT_SEARCH_KWARGS = {"k": 5, "lambda_mult": 0.25}

    # Built from explicit lines so the text sent to the model does not depend
    # on how this source file happens to be indented.
    PROMPT_TEMPLATE = (
        "\n"
        "You are a professional Project Analyst.\n"
        "Use ONLY the following context to answer the question.\n"
        "If the answer is not in the context, say \"I don't know\".\n"
        "Context:\n"
        "{context}\n"
        "Question:\n"
        "{question}\n"
        "Answer (cite sources if possible):\n"
    )

    def __init__(self, vector_store, llm, *, search_type=None, search_kwargs=None):
        self.vector_store = vector_store
        self.llm = llm
        self.search_type = (
            self.DEFAULT_SEARCH_TYPE if search_type is None else search_type
        )
        self.search_kwargs = dict(
            self.DEFAULT_SEARCH_KWARGS if search_kwargs is None else search_kwargs
        )

    # -------------------------
    # RETRIEVE NODE
    # -------------------------
    def retrieve(self, state: dict) -> dict:
        """Node: fetch documents for ``state["question"]`` from the vector store.

        Args:
            state: Graph state; must contain the key ``"question"``.

        Returns:
            ``{"context": documents}`` where ``documents`` is whatever the
            retriever's ``invoke`` returned.

        Raises:
            KeyError: If ``state`` has no ``"question"`` key.
        """
        print("--- RETRIEVING ---")
        retriever = self.vector_store.as_retriever(
            search_type=self.search_type,
            # Pass a copy so the retriever cannot alter this node's settings.
            search_kwargs=dict(self.search_kwargs),
        )
        documents: "List[Document]" = retriever.invoke(state["question"])
        return {"context": documents}

    # -------------------------
    # GENERATE NODE
    # -------------------------
    def generate(self, state: dict) -> dict:
        """Node: answer ``state["question"]`` using only the retrieved context.

        Args:
            state: Graph state; must contain ``"question"`` and ``"context"``
                (an iterable of objects with a ``page_content`` attribute).

        Returns:
            ``{"answer": ...}`` with the model's reply.

        Raises:
            KeyError: If ``"question"`` or ``"context"`` is missing.
        """
        print("--- GENERATING ---")
        prompt = ChatPromptTemplate.from_template(self.PROMPT_TEMPLATE)
        chain = prompt | self.llm
        response = chain.invoke({
            "context": self._format_context(state["context"]),
            "question": state["question"],
        })
        # Chat models reply with a message object carrying `.content`; plain
        # text-completion models reply with a bare string, so fall back to the
        # raw result instead of failing with AttributeError.
        return {"answer": getattr(response, "content", response)}

    @staticmethod
    def _format_context(documents: "List[Document]") -> str:
        """Join document texts into numbered ``[n] text`` blocks separated by blank lines."""
        return "\n\n".join(
            f"[{position}] {doc.page_content}"
            for position, doc in enumerate(documents, start=1)
        )
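# NOTE: Everything below is an alternative ReAct-agent implementation of
# RAGNodes (retriever + Wikipedia tools). It is commented out and never runs.
# from typing import List, Optional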
# from src.state.rag_state import RAGState
# from langchain_core.documents import Document
# from langchain_core.tools import Tool
# from langchain_core.messages import HumanMessage
# from langgraph.prebuilt import create_react_agent
# # Wikipedia tool
# from langchain_community.utilities import WikipediaAPIWrapper
# from langchain_community.tools.wikipedia.tool import WikipediaQueryRun
# class RAGNodes:
# """Contains node functions for RAG workflow"""
# def __init__(self, retriever, llm):
# self.retriever = retriever
# self.llm = llm
# self._agent = None # lazy-init agent
# def retrieve_docs(self, state: RAGState) -> RAGState:
# """Classic retriever node"""
# docs = self.retriever.invoke(state.question)
# return RAGState(
# question=state.question,
# retrieved_docs=docs
# )
# def _build_tools(self) -> List[Tool]:
# """Build retriever + wikipedia tools"""
# def retriever_tool_fn(query: str) -> str:
# docs: List[Document] = self.retriever.invoke(query)
# if not docs:
# return "No documents found."
# merged = []
# for i, d in enumerate(docs[:8], start=1):
# meta = d.metadata if hasattr(d, "metadata") else {}
# title = meta.get("title") or meta.get("source") or f"doc_{i}"
# merged.append(f"[{i}] {title}\n{d.page_content}")
# return "\n\n".join(merged)
# retriever_tool = Tool(
# name="retriever",
# description="Fetch passages from indexed corpus.",
# func=retriever_tool_fn,
# )
# wiki = WikipediaQueryRun(
# api_wrapper=WikipediaAPIWrapper(top_k_results=3, lang="en")
# )
# wikipedia_tool = Tool(
# name="wikipedia",
# description="Search Wikipedia for general knowledge.",
# func=wiki.run,
# )
# return [retriever_tool, wikipedia_tool]
# def _build_agent(self):
# """ReAct agent with tools"""
# tools = self._build_tools()
# system_prompt = (
# "You are a helpful RAG agent. "
# "Prefer 'retriever' for user-provided docs; use 'wikipedia' for general knowledge. "
# "Return only the final useful answer."
# )
# self._agent = create_react_agent(self.llm, tools=tools,prompt=system_prompt)
# def generate_answer(self, state: RAGState) -> RAGState:
# """
# Generate answer using ReAct agent with retriever + wikipedia.
# """
# if self._agent is None:
# self._build_agent()
# result = self._agent.invoke({"messages": [HumanMessage(content=state.question)]})
# messages = result.get("messages", [])
# answer: Optional[str] = None
# if messages:
# answer_msg = messages[-1]
# answer = getattr(answer_msg, "content", None)
# return RAGState(
# question=state.question,
# retrieved_docs=state.retrieved_docs,
# answer=answer or "Could not generate answer."
# )