| """LangGraph nodes for RAG workflow + ReAct Agent inside generate_content""" | |
| from typing import List | |
| from langchain_core.documents import Document | |
| from langchain_core.prompts import ChatPromptTemplate | |
class RAGNodes:
    """Graph nodes for a LangGraph-based RAG workflow.

    Each public method is a node: it receives the graph state as a dict and
    returns a dict holding only the keys it produces (``context`` for
    :meth:`retrieve`, ``answer`` for :meth:`generate`).

    Args:
        vector_store: Object exposing ``as_retriever(search_type=...,
            search_kwargs=...)`` whose result has ``invoke(query)``.
        llm: Runnable that can be piped after a ``ChatPromptTemplate``.
        search_type: Retrieval strategy forwarded to ``as_retriever``.
            Defaults to ``"mmr"``.
        search_kwargs: Extra retriever settings forwarded to
            ``as_retriever``. Defaults to ``{"k": 5, "lambda_mult": 0.25}``.
    """

    # Answer returned when there is no context to ground a reply on; this is
    # the same wording the prompt instructs the model to use in that case.
    FALLBACK_ANSWER = "I don't know"

    def __init__(self, vector_store, llm, *, search_type: str = "mmr",
                 search_kwargs=None):
        self.vector_store = vector_store
        self.llm = llm
        self.search_type = search_type
        # Copy so later mutation of the caller's dict cannot change retrieval.
        self.search_kwargs = (
            {"k": 5, "lambda_mult": 0.25}
            if search_kwargs is None
            else dict(search_kwargs)
        )

    # -------------------------
    # RETRIEVE NODE
    # -------------------------
    def retrieve(self, state: dict) -> dict:
        """Node: fetch documents for ``state["question"]`` from the vector store.

        Args:
            state: Graph state; must contain the key ``"question"``.

        Returns:
            ``{"context": documents}`` where ``documents`` is whatever the
            retriever's ``invoke`` returned.

        Raises:
            KeyError: If ``state`` has no ``"question"`` key.
        """
        print("--- RETRIEVING ---")
        retriever = self.vector_store.as_retriever(
            search_type=self.search_type,
            # Hand over a copy so the retriever cannot mutate our settings.
            search_kwargs=dict(self.search_kwargs),
        )
        documents: List[Document] = retriever.invoke(state["question"])
        return {"context": documents}

    # -------------------------
    # GENERATE NODE
    # -------------------------
    @staticmethod
    def _format_context(documents) -> str:
        """Join documents into numbered passages: ``[1] text``, blank line, ``[2] text``..."""
        return "\n\n".join(
            f"[{position}] {doc.page_content}"
            for position, doc in enumerate(documents, start=1)
        )

    def generate(self, state: dict) -> dict:
        """Node: generate an answer using the LLM strictly from the context.

        Args:
            state: Graph state; must contain ``"question"``. ``"context"`` is
                the list of retrieved documents (objects with a
                ``page_content`` attribute).

        Returns:
            ``{"answer": text}``. When ``"context"`` is missing or empty the
            LLM is not called and ``FALLBACK_ANSWER`` is returned, since the
            prompt would require that answer anyway.

        Raises:
            KeyError: If context is present but ``state`` has no
                ``"question"`` key.
        """
        print("--- GENERATING ---")

        documents = state.get("context") or []
        if not documents:
            # Nothing to ground an answer on: skip the model call entirely.
            return {"answer": self.FALLBACK_ANSWER}

        prompt = ChatPromptTemplate.from_template("""
        You are a professional Project Analyst.
        Use ONLY the following context to answer the question.
        If the answer is not in the context, say "I don't know".
        Context:
        {context}
        Question:
        {question}
        Answer (cite sources if possible):
        """)

        chain = prompt | self.llm
        response = chain.invoke({
            "context": self._format_context(documents),
            "question": state["question"],
        })
        # Chat models return a message object exposing ``.content``; plain
        # text-completion LLMs return a bare string. Accept both.
        return {"answer": getattr(response, "content", response)}
| # from typing import List, Optional | |
| # from src.state.rag_state import RAGState | |
| # from langchain_core.documents import Document | |
| # from langchain_core.tools import Tool | |
| # from langchain_core.messages import HumanMessage | |
| # from langgraph.prebuilt import create_react_agent | |
| # # Wikipedia tool | |
| # from langchain_community.utilities import WikipediaAPIWrapper | |
| # from langchain_community.tools.wikipedia.tool import WikipediaQueryRun | |
| # class RAGNodes: | |
| # """Contains node functions for RAG workflow""" | |
| # def __init__(self, retriever, llm): | |
| # self.retriever = retriever | |
| # self.llm = llm | |
| # self._agent = None # lazy-init agent | |
| # def retrieve_docs(self, state: RAGState) -> RAGState: | |
| # """Classic retriever node""" | |
| # docs = self.retriever.invoke(state.question) | |
| # return RAGState( | |
| # question=state.question, | |
| # retrieved_docs=docs | |
| # ) | |
| # def _build_tools(self) -> List[Tool]: | |
| # """Build retriever + wikipedia tools""" | |
| # def retriever_tool_fn(query: str) -> str: | |
| # docs: List[Document] = self.retriever.invoke(query) | |
| # if not docs: | |
| # return "No documents found." | |
| # merged = [] | |
| # for i, d in enumerate(docs[:8], start=1): | |
| # meta = d.metadata if hasattr(d, "metadata") else {} | |
| # title = meta.get("title") or meta.get("source") or f"doc_{i}" | |
| # merged.append(f"[{i}] {title}\n{d.page_content}") | |
| # return "\n\n".join(merged) | |
| # retriever_tool = Tool( | |
| # name="retriever", | |
| # description="Fetch passages from indexed corpus.", | |
| # func=retriever_tool_fn, | |
| # ) | |
| # wiki = WikipediaQueryRun( | |
| # api_wrapper=WikipediaAPIWrapper(top_k_results=3, lang="en") | |
| # ) | |
| # wikipedia_tool = Tool( | |
| # name="wikipedia", | |
| # description="Search Wikipedia for general knowledge.", | |
| # func=wiki.run, | |
| # ) | |
| # return [retriever_tool, wikipedia_tool] | |
| # def _build_agent(self): | |
| # """ReAct agent with tools""" | |
| # tools = self._build_tools() | |
| # system_prompt = ( | |
| # "You are a helpful RAG agent. " | |
| # "Prefer 'retriever' for user-provided docs; use 'wikipedia' for general knowledge. " | |
| # "Return only the final useful answer." | |
| # ) | |
| # self._agent = create_react_agent(self.llm, tools=tools,prompt=system_prompt) | |
| # def generate_answer(self, state: RAGState) -> RAGState: | |
| # """ | |
| # Generate answer using ReAct agent with retriever + wikipedia. | |
| # """ | |
| # if self._agent is None: | |
| # self._build_agent() | |
| # result = self._agent.invoke({"messages": [HumanMessage(content=state.question)]}) | |
| # messages = result.get("messages", []) | |
| # answer: Optional[str] = None | |
| # if messages: | |
| # answer_msg = messages[-1] | |
| # answer = getattr(answer_msg, "content", None) | |
| # return RAGState( | |
| # question=state.question, | |
| # retrieved_docs=state.retrieved_docs, | |
| # answer=answer or "Could not generate answer." | |
| # ) | |