import os
from typing import TypedDict, List

from datasets import load_dataset
from dotenv import load_dotenv
from huggingface_hub import login
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.vectorstores import Chroma
from langchain_core.documents import Document
from langchain_core.prompts import ChatPromptTemplate
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_text_splitters import CharacterTextSplitter
from langgraph.checkpoint.memory import MemorySaver
from langgraph.graph import START, StateGraph
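
# Load configuration from a local .env file (expects GOOGLE_API_KEY and HF_TOKEN).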
load_dotenv()

GOOGLE_API_KEY = os.getenv("GOOGLE_API_KEY")
HF_TOKEN = os.getenv("HF_TOKEN")

if HF_TOKEN:
    try:
        login(token=HF_TOKEN)
        print("✅ Logged in to Hugging Face using HF_TOKEN.")
    except Exception as e:
        print(f"⚠️ Hugging Face login failed: {e}")
else:
    print("⚠️ No HF_TOKEN found in .env file. Using public mode.")
class RAGState(TypedDict):
    question: str
    context: str
    answer: str
    chat_history: List[str]
    source_documents: List[Document]


def build_rag_pipeline():
    """Builds a LangGraph-based RAG pipeline compatible with LangChain 1.x."""
    try:
        dataset = load_dataset("fadodr/mental_health_therapy", split="train[:300]")
        print("✅ Loaded dataset: fadodr/mental_health_therapy")
    except Exception as e:
        print(f"⚠️ Could not load dataset: {e}")
        # Retry the same repo with an authenticated request.
        dataset = load_dataset("fadodr/mental_health_therapy", split="train[:300]", token=HF_TOKEN)
    texts = [
        f"Q: {d['instruction']}\nA: {d['input']}"
        for d in dataset
        if d.get("input", "").strip()
    ]
    if not texts:
        raise ValueError("No valid text found in dataset to create embeddings!")
    splitter = CharacterTextSplitter(chunk_size=500, chunk_overlap=100)
    docs = [Document(page_content=t) for t in texts]
    split_docs = splitter.split_documents(docs)
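
    # Embed the chunks with a small local sentence-transformers model and
    # persist the Chroma index under ./chroma_db.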
    embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
    vector_db = Chroma.from_documents(split_docs, embeddings, persist_directory="chroma_db")
    retriever = vector_db.as_retriever(search_kwargs={"k": 3})
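
    # Gemini chat model used for the generation step.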
    llm = ChatGoogleGenerativeAI(model="models/gemini-2.5-flash", google_api_key=GOOGLE_API_KEY)

    prompt = ChatPromptTemplate.from_template(
        """
You are a helpful assistant. Use the following retrieved context to answer the user's question.
If the context doesn't contain the answer, say so politely.

Context:
{context}

Question:
{question}

Answer:
"""
    )
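
    # Graph nodes: each receives the current RAGState and returns a partial
    # state update for LangGraph to merge.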
    def retrieve_docs(state: RAGState):
        query = state["question"]
        docs = retriever.invoke(query)
        context = "\n\n".join([d.page_content for d in docs])
        return {"context": context, "source_documents": docs}

    def generate_answer(state: RAGState):
        prompt_text = prompt.format(context=state["context"], question=state["question"])
        response = llm.invoke(prompt_text)
        return {"answer": response.content}
    graph_builder = StateGraph(RAGState)
    graph_builder.add_node("retrieve", retrieve_docs)
    graph_builder.add_node("generate", generate_answer)
    graph_builder.add_edge(START, "retrieve")
    graph_builder.add_edge("retrieve", "generate")
    memory = MemorySaver()
    graph = graph_builder.compile(checkpointer=memory)
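
    # Thin callable wrapper so callers can use the compiled graph like a
    # classic chain: rag_chain({"question": ...}) -> {"answer", "source_documents"}.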
    class RAGChainWrapper:
        def __init__(self, graph):
            self.graph = graph

        def __call__(self, inputs: dict):
            question = inputs.get("question", "")
            state = {"question": question, "chat_history": []}
            result = self.graph.invoke(
                state,
                config={"configurable": {"thread_id": "default"}},
            )
            return {
                "answer": result.get("answer", ""),
                "source_documents": result.get("source_documents", []),
            }

    rag_chain = RAGChainWrapper(graph)

    return llm, retriever, rag_chain
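

# Minimal usage sketch (illustrative; the sample question is made up): build
# the pipeline once, then call the chain-style wrapper per question.
if __name__ == "__main__":
    llm, retriever, rag_chain = build_rag_pipeline()
    result = rag_chain({"question": "How can I cope with everyday anxiety?"})
    print(result["answer"])
    print(f"({len(result['source_documents'])} source chunks retrieved)")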