Modified all files to replace flan-t5 with the Groq API

- Dockerfile +5 -10
- app.py +60 -44
- graph.py +46 -53
- requirements.txt +3 -4
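At its core, the commit drops the local flan-t5 HuggingFace pipeline and routes both chat nodes through Groq's hosted API via langchain-groq. A minimal sketch of the new call path, assuming GROQ_API_KEY is exported; the model name and invocation style are taken from graph.py below:

import os
from langchain_groq import ChatGroq

# Assumes GROQ_API_KEY is set in the environment, as graph.py requires.
api_key = os.getenv("GROQ_API_KEY", "").strip()
if not api_key:
    raise ValueError("GROQ_API_KEY environment variable not set.")

llm = ChatGroq(model="mixtral-8x7b-32768", api_key=api_key, temperature=0.7)

# Both general_chat_node and generate_rag_node reduce to this pattern:
response_obj = llm.invoke("Say hello in one sentence.")
print(getattr(response_obj, "content", str(response_obj)))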
Dockerfile
CHANGED
@@ -3,27 +3,22 @@ FROM python:3.10-slim
WORKDIR /app

RUN apt-get update && apt-get install -y --no-install-recommends \
-    build-essential \
    git \
    curl \
-    libopenblas-dev \
-    libomp-dev \
-    python3-dev \
    && apt-get clean \
    && rm -rf /var/lib/apt/lists/*

-COPY . /app
-
ENV PYTHONUNBUFFERED=1 \
    PYTHONDONTWRITEBYTECODE=1 \
    LANG=C.UTF-8

-RUN pip install --no-cache-dir
-
+COPY requirements.txt .
+
+RUN pip install --no-cache-dir --upgrade pip setuptools wheel \
+    && pip install --no-cache-dir -r requirements.txt
+
+COPY . .
+
EXPOSE 7860

-CMD ["streamlit", "run", "app.py", "--server.port=7860", "--server.address=0.0.0.0"]
+CMD ["streamlit", "run", "app.py", "--server.port=7860", "--server.address=0.0.0.0"]
app.py
CHANGED
@@ -5,19 +5,11 @@ from pypdf import PdfReader
import hashlib
from transformers import pipeline

+# --- Page Config ---
st.set_page_config(page_title="LangGraph RAG Chatbot", layout="wide")
st.title("📚 LangGraph RAG Chatbot")

-#
-@st.cache_resource(show_spinner=False)
-def cached_vectorstore_from_text(text):
-    try:
-        return load_vectorstore_from_text(text=text)
-    except Exception as e:
-        st.warning(f"Failed to load vectorstore: {e}")
-        return get_retriever(text)
-
-# Helpers
+# --- Helpers ---
def compute_file_hash(raw_text):
    return hashlib.md5(raw_text.encode("utf-8")).hexdigest() if raw_text else None

@@ -34,7 +26,20 @@ def load_uploaded_file(uploaded_file):
        st.error(f"Error reading file: {e}")
    return raw_text

-#
+# --- Cached Vectorstore with Persistent Cache ---
+@st.cache_data(show_spinner=False)
+def cached_vectorstore_from_text(raw_text_hash: str, text: str):
+    """
+    Cache vectorstore based on hash of raw text.
+    If the same text is uploaded again, returns cached retriever.
+    """
+    try:
+        return load_vectorstore_from_text(text=text)
+    except Exception as e:
+        st.warning(f"Failed to load vectorstore: {e}")
+        return get_retriever(text)
+
+# --- Sidebar ---
with st.sidebar:
    st.header("🔧 Settings")
    temperature = st.slider("LLM Temperature", 0.0, 1.0, 0.7)
@@ -47,32 +52,35 @@ with st.sidebar:

    uploaded_file = st.file_uploader("Upload a file (optional)", type=["txt", "pdf"])

-    # Handle file upload
    if uploaded_file:
        raw_text = load_uploaded_file(uploaded_file)
        if raw_text:
-            st.session_state.history = []
-            for key in ["retriever", "file_hash"]:
-                st.session_state.pop(key, None)
-
-            st.session_state.
+            file_hash = compute_file_hash(raw_text)
+
+            # Reset session if new file
+            if st.session_state.get("file_hash") != file_hash:
+                st.session_state.history = []
+                for key in ["retriever", "file_hash"]:
+                    st.session_state.pop(key, None)
+
+            st.session_state.raw_text = raw_text
+            st.session_state.file_hash = file_hash
+            # Persistent cached vectorstore
+            st.session_state.retriever = cached_vectorstore_from_text(file_hash, raw_text)

            st.markdown("**📄 Uploaded File Preview:**")
            st.text_area("Contents", raw_text, height=200)
-            st.success("✅ Document loaded!
+            st.success("✅ Document loaded!")
        else:
            st.warning("Uploaded file is empty or could not be read.")

    # Show current mode
    if "retriever" in st.session_state and st.session_state.retriever:
        st.info("📄 **RAG Mode**: Answering from uploaded document")
    else:
        st.info("💬 **General Chat Mode**: No document loaded")

-# Initialize
+# --- Initialize Summarizer ---
if "summarizer" not in st.session_state:
    st.session_state.summarizer = pipeline(
        "summarization",
@@ -80,7 +88,7 @@ if "summarizer" not in st.session_state:
        device=-1
    )

-# Build Graph
+# --- Build Graph ---
if "graph" not in st.session_state or st.session_state.get("graph_model") != model_type:
    try:
        st.session_state.graph = build_graph(
@@ -93,50 +101,57 @@ if "graph" not in st.session_state or st.session_state.get("graph_model") != model_type:
        st.error(f"Failed to build graph: {e}")
        st.stop()

-#
+# --- Initialize History ---
if "history" not in st.session_state:
    st.session_state.history = []

-# Query Input
+# --- Query Input ---
+if "current_query" not in st.session_state:
+    st.session_state.current_query = ""
+
+query = st.text_input("💬 Ask a question:", key="current_query")
+send_triggered = st.button("Send")
+
+# --- Send Query ---
+if send_triggered and query.strip():
+    formatted_history = [(q, r) for q, r, _ in st.session_state.history]
+
+    with st.spinner("Generating response..."):
        try:
-            # Prepare history in the format expected by the graph (tuples of (query, response))
-            formatted_history = [(q, r) for q, r, _ in st.session_state.history]
-
            result = st.session_state.graph(
                query=query,
                temperature=temperature,
-                raw_text=st.session_state.get("raw_text"
+                raw_text=st.session_state.get("raw_text"),
                history=formatted_history,
                retriever_override=st.session_state.get("retriever")
            )

            response = result.get("response", "No response generated.")
            retrieved_docs = result.get("retrieved_docs", [])

            st.markdown("### 🤖 Response")
            st.markdown(response)

-            #
+            # Save to history
            st.session_state.history.append((query, response, retrieved_docs))

-            # Show retrieved
+            # Show retrieved docs
            if retrieved_docs:
                with st.expander("📄 Retrieved Chunks"):
                    for j, doc in enumerate(retrieved_docs):
                        content = getattr(doc, "text", str(doc))
                        st.markdown(f"**Chunk {j+1}:**")
                        st.code(content.strip(), language="markdown")

+            # Clear input
+            st.session_state.current_query = ""
+
        except Exception as e:
            st.error(f"Query failed: {e}")
+elif send_triggered:
+    st.warning("Please enter a question.")

-#
+# --- Chat History Display ---
if st.session_state.history:
    st.markdown("### 💬 Chat History")
    for i, (q, r, docs) in enumerate(reversed(st.session_state.history)):
@@ -149,7 +164,8 @@ if st.session_state.history:
                content = getattr(doc, "text", str(doc))
                st.code(content.strip()[:200] + "...", language="markdown")

-# Clear
+# --- Clear Chat ---
if st.sidebar.button("🗑️ Clear Chat History"):
    st.session_state.history = []
-    st.
+    st.session_state.current_query = ""
+    st.rerun()
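For context on the upload flow above: the vectorstore is now cached with st.cache_data keyed on an MD5 hash of the uploaded text, so re-uploading the same document reuses the cached retriever instead of re-embedding it. A stripped-down sketch of that pattern, where build_toy_retriever is a hypothetical stand-in for load_vectorstore_from_text / get_retriever:

import hashlib
import streamlit as st

def compute_file_hash(raw_text):
    # Same helper as in app.py: hash the raw text to detect re-uploads.
    return hashlib.md5(raw_text.encode("utf-8")).hexdigest() if raw_text else None

def build_toy_retriever(text):
    # Hypothetical stand-in for load_vectorstore_from_text / get_retriever.
    return [chunk for chunk in text.split("\n\n") if chunk.strip()]

@st.cache_data(show_spinner=False)
def cached_vectorstore_from_text(raw_text_hash, text):
    # st.cache_data keys the cache on its arguments, so an identical upload
    # (same hash, same text) returns the cached result without rebuilding.
    return build_toy_retriever(text)

raw_text = "Chunk one.\n\nChunk two."
file_hash = compute_file_hash(raw_text)
if st.session_state.get("file_hash") != file_hash:
    st.session_state.file_hash = file_hash
    st.session_state.retriever = cached_vectorstore_from_text(file_hash, raw_text)
st.write(st.session_state.retriever)

Note that st.cache_data serializes return values, so this pattern suits picklable retrievers; run the snippet with streamlit run.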
graph.py
CHANGED
@@ -1,14 +1,15 @@
import datetime
import os
+import re
from typing import TypedDict, Optional, List

from llama_index.core.schema import Document
-from langchain.llms.huggingface_pipeline import HuggingFacePipeline
from langchain_google_genai import ChatGoogleGenerativeAI
-from transformers import pipeline as hf_pipeline
from langgraph.graph import StateGraph, END
from llama_index.core import VectorStoreIndex
from llama_index.core.retrievers import BaseRetriever
+from langchain_groq import ChatGroq
+from transformers import pipeline as hf_pipeline

# --- 1. Define the State for the Graph ---
class GraphState(TypedDict):
@@ -22,17 +23,13 @@ class GraphState(TypedDict):
    summarizer: Optional[any]

# --- 2. Define Graph Nodes ---
-
-# Router node to decide the flow
def router_node(state: GraphState) -> GraphState:
    """
    Router that determines the next step based on available retriever.
    """
    print("---NODE: ROUTER---")
-    # This node just passes through the state - routing logic is in conditional edges
    return state

-# Node for handling general conversation when no PDF is loaded
def general_chat_node(state: GraphState) -> GraphState:
    """
    Generates a response for general conversation using the LLM.
@@ -41,8 +38,6 @@ def general_chat_node(state: GraphState) -> GraphState:
    llm = state["llm"]
    query = state["query"]
    history = state["history"]
-
-    # Format history for the prompt
    history_context = "\n".join([f"Human: {q}\nAI: {a}" for q, a in history])
    current_time = datetime.datetime.now().strftime("%Y-%m-%d %I:%M %p")
    prompt = f"""You are Sarathi, a friendly and knowledgeable AI assistant.
@@ -61,23 +56,17 @@ Human: {query}
AI:"""

    try:
-        elif isinstance(llm, ChatGoogleGenerativeAI):
-            response_obj = llm.invoke(prompt)
-            response_text = getattr(response_obj, "content", str(response_obj))
-        else:
-            response_text = "Unsupported LLM type provided."
-
+        response_obj = llm.invoke(prompt)
+        response_text = getattr(response_obj, "content", str(response_obj))
    except Exception as e:
-        response_text = f"Model inference failed
+        response_text = f"Model inference failed: {str(e)}"

    return {"response": response_text.strip()}

-# Node for retrieving information from a PDF
def retrieve_node(state: GraphState) -> GraphState:
    """
    Retrieves relevant documents from the vector store based on the query.
+    Summarizes context if too long, or truncates at sentence boundaries.
    """
    print("---NODE: RETRIEVE---")
    query = state["query"]
@@ -89,7 +78,6 @@ def retrieve_node(state: GraphState) -> GraphState:
    retrieved_docs = []

    try:
-        # Dynamic top_k based on query length
        q_len = len(query.split())
        top_k = 3 if q_len < 5 else (5 if q_len < 15 else 8)

@@ -98,19 +86,32 @@
        if retrieved_docs:
            context = "\n\n---\n\n".join([doc.text for doc in retrieved_docs])

-            # Add chat history to context
            if history:
                history_context = "\n\n".join([f"Human: {q}\nAI: {a}" for q, a in history])
                context = f"{context}\n\n--- Chat History ---\n{history_context}"

-            # Summarize if context is too long
            MAX_CONTEXT_CHARS = 4000
-            if len(context) > MAX_CONTEXT_CHARS
+            if len(context) > MAX_CONTEXT_CHARS:
+                try:
+                    print("---CONTEXT TOO LONG, SUMMARIZING---")
+                    summary_result = summarizer(
+                        context,
+                        max_length=500,
+                        min_length=150,
+                        do_sample=False
+                    )
+                    context = summary_result[0].get("summary_text", context[:MAX_CONTEXT_CHARS])
+                except Exception as e:
+                    print(f"Summarizer failed: {e}")
+                    sentences = re.split(r'(?<=[.!?]) +', context)
+                    truncated = []
+                    total_len = 0
+                    for sent in sentences:
+                        if total_len + len(sent) > MAX_CONTEXT_CHARS:
+                            break
+                        truncated.append(sent)
+                        total_len += len(sent)
+                    context = " ".join(truncated)

    except Exception as e:
        print(f"Error in retrieve_node: {e}")
@@ -118,7 +119,6 @@ def retrieve_node(state: GraphState) -> GraphState:

    return {"retrieved_docs": retrieved_docs, "context": context}

-# Node for generating a response from RAG context
def generate_rag_node(state: GraphState) -> GraphState:
    """
    Generates an answer using the retrieved context from the PDF.
@@ -148,16 +148,11 @@ Instructions:
Answer:"""

    try:
-            response_obj = llm.invoke(prompt)
-            response_text = getattr(response_obj, "content", str(response_obj))
-        else:
-            response_text = "Unsupported LLM type provided."
-
+        response_obj = llm.invoke(prompt)
+        response_text = getattr(response_obj, "content", str(response_obj))
    except Exception as e:
-        response_text = f"Model inference failed
+        response_text = f"Model inference failed: {str(e)}"

    return {"response": response_text.strip()}

@@ -174,18 +169,18 @@ def route_query(state: GraphState) -> str:
    return "general_chat"

# --- 4. Build the Graph ---
-def build_graph(model_type: str = "huggingface", retriever=None, summarizer=None):
+def build_graph(model_type: str = "groq", retriever=None, summarizer=None):
    """
-    Builds the
+    Builds the workflow graph with LLM, retriever, and optional summarizer.
+    If summarizer not provided, initializes a default HuggingFace summarizer.
    """

    if model_type == "groq":
-        from langchain_groq import ChatGroq
        api_key = os.getenv("GROQ_API_KEY", "").strip()
        if not api_key:
            raise ValueError("GROQ_API_KEY environment variable not set.")
        llm = ChatGroq(
-            model="mixtral-8x7b-32768",
+            model="mixtral-8x7b-32768",
            api_key=api_key,
            temperature=0.7,
        )
@@ -193,23 +188,27 @@ def build_graph(model_type: str = "huggingface", retriever=None, summarizer=None):
        api_key = os.getenv("GEMINI_API_KEY", "").strip()
        if not api_key:
            raise ValueError("GEMINI_API_KEY environment variable not set.")
-        llm = ChatGoogleGenerativeAI(
+        llm = ChatGoogleGenerativeAI(
+            model="gemini-2.0-flash",
+            api_key=api_key,
+            temperature=0.7
+        )
    else:
        raise ValueError("Invalid model_type. Choose 'groq' or 'gemini'.")

+    if summarizer is None:
+        print("---NO SUMMARIZER PROVIDED, USING DEFAULT (facebook/bart-large-cnn)---")
+        summarizer = hf_pipeline("summarization", model="facebook/bart-large-cnn")
+
    workflow = StateGraph(GraphState)

-    # Add all the nodes to the graph
    workflow.add_node("router", router_node)
    workflow.add_node("general_chat", general_chat_node)
    workflow.add_node("retrieve", retrieve_node)
    workflow.add_node("generate", generate_rag_node)

-    # Set the router as the entry point
    workflow.set_entry_point("router")

-    # Add the conditional edge from the router
    workflow.add_conditional_edges(
        "router",
        route_query,
@@ -219,19 +218,13 @@ def build_graph(model_type: str = "huggingface", retriever=None, summarizer=None):
        },
    )

-    # Define the standard path for the RAG pipeline
    workflow.add_edge("retrieve", "generate")
-
-    # Define the end points for the graph
    workflow.add_edge("generate", END)
    workflow.add_edge("general_chat", END)

-    # Compile the graph
    compiled_graph = workflow.compile()

-    # Return a function that wraps the graph invocation
    def graph_wrapper(query: str, temperature: float = 0.7, raw_text: str = None, history=None, retriever_override=None):
-        # Use retriever_override if provided, otherwise use the build-time retriever
        active_retriever = retriever_override or retriever
        return compiled_graph.invoke({
            "query": query,
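Since build_graph now returns the graph_wrapper callable rather than the compiled graph, app.py invokes it with keyword arguments. A hedged usage sketch, assuming GROQ_API_KEY is exported and no document is loaded (so the router falls through to general chat); note that passing summarizer=None makes build_graph download facebook/bart-large-cnn on first run:

from graph import build_graph

# Assumes GROQ_API_KEY is set; with no retriever, route_query picks general_chat.
graph = build_graph(model_type="groq", retriever=None, summarizer=None)

result = graph(
    query="What can you help me with?",
    temperature=0.7,
    raw_text=None,
    history=[],
    retriever_override=None,
)

print(result.get("response", "No response generated."))
print(result.get("retrieved_docs", []))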
requirements.txt
CHANGED
@@ -6,14 +6,13 @@ accelerate>=0.30.0
# LangChain + LangGraph
langchain>=0.2.1
langgraph>=0.0.45
-langchain-
-langchain-
-langchain-google-genai>=1.0.5 # wrapper for Gemini
+langchain-groq>=0.1.0
+langchain-google-genai>=1.0.5

# Retrieval + Embeddings
llama-index>=0.13.5
llama-index-embeddings-huggingface>=0.1.3
-chromadb>=0.5.3
+chromadb>=0.5.3

# Hugging Face + Deployment
huggingface_hub>=0.23.4