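"""LangGraph-based RAG pipeline over the fadodr/mental_health_therapy dataset.

Builds a Chroma vector store from the dataset, retrieves relevant chunks for a
question, and answers with Gemini via langchain_google_genai. Exposed through
build_rag_pipeline(), which returns (llm, retriever, rag_chain) for app.py.
"""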
import os
from datasets import load_dataset
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.vectorstores import Chroma
from langchain_text_splitters import CharacterTextSplitter
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_core.documents import Document
from langgraph.graph import START, END, StateGraph
from langgraph.checkpoint.memory import MemorySaver
from langchain_core.prompts import ChatPromptTemplate
from huggingface_hub import login
from dotenv import load_dotenv
from typing import TypedDict, List

# Load environment variables
load_dotenv()

GOOGLE_API_KEY = os.getenv("GOOGLE_API_KEY")
HF_TOKEN = os.getenv("HF_TOKEN")

# Authenticate Hugging Face
if HF_TOKEN:
    try:
        login(token=HF_TOKEN)
        print("✅ Logged in to Hugging Face using HF_TOKEN.")
    except Exception as e:
        print(f"⚠️ Hugging Face login failed: {e}")
else:
    print("⚠️ No HF_TOKEN found in .env file. Using public mode.")


# --- STATE DEFINITION ---
class RAGState(TypedDict):
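    """Shared state passed between the graph nodes.

    chat_history is reserved for multi-turn use; the nodes below do not
    populate it, and turn-to-turn memory is handled by the checkpointer.
    """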
    question: str
    context: str
    answer: str
    chat_history: List[str]
    source_documents: List[Document]


def build_rag_pipeline():
    """Builds a LangGraph-based RAG pipeline compatible with LangChain 1.x."""

    # --- Load dataset ---
    try:
        dataset = load_dataset("fadodr/mental_health_therapy", split="train[:300]")
        print("✅ Loaded dataset: fadodr/mental_health_therapy")
    except Exception as e:
        print(f"⚠️ Could not load dataset: {e}")
        # Retry the same dataset, passing the HF token explicitly
        dataset = load_dataset("fadodr/mental_health_therapy", split="train[:300]", token=HF_TOKEN)

    # --- Prepare documents ---
    texts = [f"Q: {d['instruction']}\nA: {d['input']}" for d in dataset if d.get("input", "").strip()]
    if not texts:
        raise ValueError("No valid text found in dataset to create embeddings!")

    splitter = CharacterTextSplitter(chunk_size=500, chunk_overlap=100)
    docs = [Document(page_content=t) for t in texts]
    split_docs = splitter.split_documents(docs)

    # --- Embeddings + Chroma DB ---
    embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
    vector_db = Chroma.from_documents(split_docs, embeddings, persist_directory="chroma_db")
    retriever = vector_db.as_retriever(search_kwargs={"k": 3})

    # --- LLM ---
    llm = ChatGoogleGenerativeAI(model="models/gemini-2.5-flash", google_api_key=GOOGLE_API_KEY)

    # --- PROMPT TEMPLATE ---
    prompt = ChatPromptTemplate.from_template(
        """You are a helpful assistant. Use the following retrieved context to answer the user's question.
If the context doesn't contain the answer, say so politely.

Context:
{context}

Question:
{question}

Answer:"""
    )

    # --- NODES (GRAPH FUNCTIONS) ---
    def retrieve_docs(state: RAGState):
        query = state["question"]
        docs = retriever.invoke(query)
        context = "\n\n".join([d.page_content for d in docs])
        return {"context": context, "source_documents": docs}

    def generate_answer(state: RAGState):
        prompt_text = prompt.format(context=state["context"], question=state["question"])
        response = llm.invoke(prompt_text)
        return {"answer": response.content}

    # --- BUILD THE GRAPH ---
    graph_builder = StateGraph(RAGState)
    graph_builder.add_node("retrieve", retrieve_docs)
    graph_builder.add_node("generate", generate_answer)
    graph_builder.add_edge(START, "retrieve")
    graph_builder.add_edge("retrieve", "generate")

    # Add in-memory checkpointing (conversation memory)
    memory = MemorySaver()

    graph = graph_builder.compile(checkpointer=memory)

    # Wrap in a callable interface so app.py still works
    class RAGChainWrapper:
        def __init__(self, graph):
            self.graph = graph

        def __call__(self, inputs: dict):
            question = inputs.get("question", "")
            state = {"question": question, "chat_history": []}
            result = self.graph.invoke(
                state,
                config={"configurable": {"thread_id": "default"}}
            )
            return {
                "answer": result.get("answer", ""),
                "source_documents": result.get("source_documents", [])
            }

    rag_chain = RAGChainWrapper(graph)

    return llm, retriever, rag_chain
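

if __name__ == "__main__":
    # Minimal smoke test (illustrative sketch only; app.py is the intended
    # entry point). Requires GOOGLE_API_KEY in .env plus network access to
    # Hugging Face and the Gemini API. The question below is just an example.
    _, _, chain = build_rag_pipeline()
    result = chain({"question": "How can I cope with exam anxiety?"})
    print("Answer:", result["answer"])
    print("Retrieved sources:", len(result["source_documents"]))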