aryan195a committed
Commit 6ed5d07 · 1 Parent(s): 9792cab

Modified document loading logic

Files changed (3):
  1. app.py +51 -57
  2. graph.py +107 -53
  3. requirements.txt +0 -2
app.py CHANGED
@@ -3,7 +3,6 @@ from graph import build_graph
  from utils import get_retriever, load_vectorstore_from_text
  from pypdf import PdfReader
  import hashlib
- from transformers import pipeline, BartForConditionalGeneration, BartTokenizer

  # --- Page Config ---
  st.set_page_config(page_title="LangGraph RAG Chatbot", layout="wide")
@@ -74,30 +73,20 @@ with st.sidebar:
      else:
          st.warning("Uploaded file is empty or could not be read.")

- # Show current mode
- if "retriever" in st.session_state and st.session_state.retriever:
+ # Show current mode with proper guard
+ if ("retriever" in st.session_state and
+         st.session_state.retriever is not None and
+         "raw_text" in st.session_state):
      st.info("📄 **RAG Mode**: Answering from uploaded document")
  else:
      st.info("💬 **General Chat Mode**: No document loaded")

- # --- Initialize Summarizer ---
- if "summarizer" not in st.session_state:
-     tokenizer = BartTokenizer.from_pretrained("facebook/bart-large-cnn")
-     model = BartForConditionalGeneration.from_pretrained("facebook/bart-large-cnn")
-     st.session_state.summarizer = pipeline(
-         "summarization",
-         model=model,
-         tokenizer=tokenizer,
-         device=-1
-     )
-
  # --- Build Graph ---
  if "graph" not in st.session_state or st.session_state.get("graph_model") != model_type:
      try:
          st.session_state.graph = build_graph(
              model_type=model_type,
-             retriever=st.session_state.get("retriever"),
-             summarizer=st.session_state.get("summarizer")
+             retriever=st.session_state.get("retriever")
          )
          st.session_state.graph_model = model_type
      except Exception as e:
@@ -108,49 +97,55 @@ if "graph" not in st.session_state or st.session_state.get("graph_model") != mod
  if "history" not in st.session_state:
      st.session_state.history = []

+ # --- Initialize current query to handle input clearing ---
+ if "current_query" not in st.session_state:
+     st.session_state.current_query = ""
+
  # --- Query Input ---
- query_input = st.text_input("💬 Ask a question:")
+ query_input = st.text_input("💬 Ask a question:", value=st.session_state.current_query, key="current_query")

  send_triggered = st.button("Send")

- if send_triggered:
-     if query_input.strip():
-         formatted_history = [(q, r) for q, r, _ in st.session_state.history]
-
-         with st.spinner("Generating response..."):
-             try:
-                 result = st.session_state.graph(
-                     query=query_input,
-                     temperature=temperature,
-                     raw_text=st.session_state.get("raw_text"),
-                     history=formatted_history,
-                     retriever_override=st.session_state.get("retriever")
-                 )
-
-                 response = result.get("response", "No response generated.")
-                 retrieved_docs = result.get("retrieved_docs", [])
-
-                 st.markdown("### 🤖 Response")
-                 st.markdown(response)
-
-                 # Save to history
-                 st.session_state.history.append((query_input, response, retrieved_docs))
-
-                 # Show retrieved docs if available
-                 if retrieved_docs:
-                     with st.expander("📄 Retrieved Chunks"):
-                         for j, doc in enumerate(retrieved_docs):
-                             content = getattr(doc, "text", str(doc))
-                             st.markdown(f"**Chunk {j+1}:**")
-                             st.code(content.strip(), language="markdown")
-
-                 # Clear the input field by rerunning widget with empty value
-                 st.rerun()
-
-             except Exception as e:
-                 st.error(f"Query failed: {e}")
-     else:
-         st.warning("Please enter a question.")
+ if send_triggered and query_input.strip():
+     formatted_history = [(q, r) for q, r, _ in st.session_state.history]
+
+     with st.spinner("Generating response..."):
+         try:
+             result = st.session_state.graph(
+                 query=query_input,
+                 temperature=temperature,
+                 raw_text=st.session_state.get("raw_text"),
+                 history=formatted_history,
+                 retriever_override=st.session_state.get("retriever")
+             )
+
+             response = result.get("response", "No response generated.")
+             retrieved_docs = result.get("retrieved_docs", [])
+
+             # Display response immediately
+             st.markdown("### 🤖 Response")
+             st.markdown(response)
+
+             # Show retrieved docs if available
+             if retrieved_docs:
+                 with st.expander("📄 Retrieved Chunks"):
+                     for j, doc in enumerate(retrieved_docs):
+                         content = getattr(doc, "text", str(doc))
+                         st.markdown(f"**Chunk {j+1}:**")
+                         st.code(content.strip(), language="markdown")
+
+             # Save to history after displaying
+             st.session_state.history.append((query_input, response, retrieved_docs))
+
+             # Clear the input field
+             st.session_state.current_query = ""
+             st.rerun()
+
+         except Exception as e:
+             st.error(f"Query failed: {e}")
+
+ elif send_triggered and not query_input.strip():
+     st.warning("Please enter a question.")

  # --- Chat History Display ---
  if st.session_state.history:
@@ -168,5 +163,4 @@ if st.session_state.history:
  # --- Clear Chat ---
  if st.sidebar.button("🗑️ Clear Chat History"):
      st.session_state.history = []
-     st.session_state.current_query = ""
-     st.rerun()
+     st.session_state.current_query = ""
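Note on the input-clearing pattern added above: recent Streamlit releases raise a StreamlitAPIException if you assign to st.session_state.current_query after the st.text_input widget keyed "current_query" has already been instantiated in the same script run. A minimal sketch of the callback-based alternative, assuming an on_change flow; the submit_query function and pending_query key are illustrative names, not part of this commit:

import streamlit as st

def submit_query():
    # Callbacks run before widgets re-instantiate on the next rerun,
    # so mutating the widget-keyed state entry here is allowed.
    st.session_state.pending_query = st.session_state.current_query
    st.session_state.current_query = ""  # clears the text box safely

st.text_input("💬 Ask a question:", key="current_query", on_change=submit_query)

# On the rerun triggered by the callback, pick up the stashed query.
query_input = st.session_state.pop("pending_query", "")
if query_input.strip():
    st.write(f"Processing: {query_input}")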
 
graph.py CHANGED
@@ -6,13 +6,9 @@ from typing import TypedDict, Optional, List
  from llama_index.core.schema import Document
  from langchain_google_genai import ChatGoogleGenerativeAI
  from langchain_openai import ChatOpenAI
- from llama_index.core import VectorStoreIndex
  from llama_index.core.retrievers import BaseRetriever
- from langchain_groq import ChatGroq
- from transformers import pipeline, BartForConditionalGeneration, BartTokenizer
  from langgraph.graph import StateGraph, END

- # --- 1. Define the State for the Graph ---
  class GraphState(TypedDict):
      query: str
      response: Optional[str]
@@ -21,9 +17,77 @@ class GraphState(TypedDict):
      history: list
      retriever: Optional[BaseRetriever]
      llm: any
-     summarizer: Optional[any]

- # --- 2. Define Graph Nodes ---
+ def _select_chunks_by_relevance_and_position(docs: List[Document], query: str, target_count: int) -> List[Document]:
+     """Select chunks based on semantic relevance + positional importance without summarization."""
+     if len(docs) <= target_count:
+         return docs
+
+     selected = []
+     if len(docs) >= 2:
+         selected = [docs[0], docs[-1]]
+         remaining_slots = target_count - 2
+         middle_docs = docs[1:-1]
+     else:
+         remaining_slots = target_count
+         middle_docs = docs[:]
+
+     if remaining_slots > 0 and middle_docs:
+         query_words = set(query.lower().split())
+         scored_docs = []
+
+         for doc in middle_docs:
+             content = doc.text.lower()
+             relevance_score = sum(content.count(word) for word in query_words)
+             length_bonus = len(doc.text) / 1000
+             total_score = relevance_score + length_bonus
+             scored_docs.append((total_score, doc))
+
+         scored_docs.sort(key=lambda x: x[0], reverse=True)
+         selected.extend([doc for _, doc in scored_docs[:remaining_slots]])
+
+     return selected[:target_count]
+
+ def _create_context_with_priorities(docs: List[Document], max_chars: int) -> str:
+     """Create context by prioritizing important chunks without summarization."""
+     if not docs:
+         return ""
+
+     contexts = []
+     total_chars = 0
+
+     priority_docs = []
+
+     if len(docs) >= 2:
+         priority_docs.append(("BOUNDARY", docs[0]))
+         priority_docs.append(("BOUNDARY", docs[-1]))
+
+         middle_docs = docs[1:-1] if len(docs) > 2 else []
+         middle_docs.sort(key=lambda d: len(d.text), reverse=True)
+         priority_docs.extend([("CONTENT", doc) for doc in middle_docs])
+     else:
+         priority_docs = [("CONTENT", doc) for doc in docs]
+
+     for priority_type, doc in priority_docs:
+         content = doc.text
+
+         if total_chars + len(content) > max_chars:
+             remaining_chars = max_chars - total_chars
+             if remaining_chars > 200:
+                 truncated = content[:remaining_chars]
+                 last_period = truncated.rfind('.')
+                 if last_period > remaining_chars * 0.8:
+                     truncated = truncated[:last_period + 1]
+                 truncated += "...[truncated]"
+                 contexts.append(f"[{priority_type}] {truncated}")
+             break
+
+         contexts.append(f"[{priority_type}] {content}")
+         total_chars += len(content)
+
+     return "\n\n---\n\n".join(contexts)
+
+
  def router_node(state: GraphState) -> GraphState:
      """
      Router that determines the next step based on available retriever.
@@ -66,13 +130,11 @@ AI:"""

  def retrieve_node(state: GraphState) -> GraphState:
      """
-     Retrieves relevant documents from the vector store based on the query.
-     Summarizes context if too long, or truncates at sentence boundaries.
+     Retrieves relevant documents and creates intelligent context without summarization.
      """
      print("---NODE: RETRIEVE---")
      query = state["query"]
      retriever = state["retriever"]
-     summarizer = state["summarizer"]
      history = state["history"]

      context = ""
@@ -80,43 +142,43 @@ def retrieve_node(state: GraphState) -> GraphState:

      try:
          q_len = len(query.split())
-         top_k = 3 if q_len < 5 else (5 if q_len < 15 else 8)
+         if q_len < 5:
+             top_k = 5
+         elif q_len < 15:
+             top_k = 8
+         else:
+             top_k = 12

          retrieved_docs = retriever.retrieve(query)
+         print(f"Retrieved {len(retrieved_docs)} documents")

          if retrieved_docs:
-             context = "\n\n---\n\n".join([doc.text for doc in retrieved_docs])
-
-             if history:
-                 history_context = "\n\n".join([f"Human: {q}\nAI: {a}" for q, a in history])
-                 context = f"{context}\n\n--- Chat History ---\n{history_context}"
-
-             MAX_CONTEXT_CHARS = 4000
-             if len(context) > MAX_CONTEXT_CHARS:
-                 try:
-                     print("---CONTEXT TOO LONG, SUMMARIZING---")
-                     summary_result = summarizer(
-                         context,
-                         max_length=500,
-                         min_length=150,
-                         do_sample=False
-                     )
-                     context = summary_result[0].get("summary_text", context[:MAX_CONTEXT_CHARS])
-                 except Exception as e:
-                     print(f"Summarizer failed: {e}")
-                     sentences = re.split(r'(?<=[.!?]) +', context)
-                     truncated = []
-                     total_len = 0
-                     for sent in sentences:
-                         if total_len + len(sent) > MAX_CONTEXT_CHARS:
-                             break
-                         truncated.append(sent)
-                         total_len += len(sent)
-                     context = " ".join(truncated)
+             max_chunks = min(len(retrieved_docs), top_k)
+             selected_docs = _select_chunks_by_relevance_and_position(
+                 retrieved_docs, query, max_chunks
+             )
+
+             print(f"Selected {len(selected_docs)} chunks for context")
+
+             MAX_CONTEXT_CHARS = 6000
+             doc_context = _create_context_with_priorities(selected_docs, MAX_CONTEXT_CHARS)
+
+             if history and len(doc_context) < MAX_CONTEXT_CHARS * 0.8:
+                 history_context = "\n\n".join([f"Human: {q}\nAI: {a}" for q, a in history[-3:]])  # Last 3 exchanges
+                 remaining_chars = MAX_CONTEXT_CHARS - len(doc_context)
+                 if len(history_context) <= remaining_chars:
+                     context = f"{doc_context}\n\n--- Recent Chat History ---\n{history_context}"
+                 else:
+                     context = doc_context
+             else:
+                 context = doc_context
+         else:
+             context = "No relevant content found in the document."

      except Exception as e:
          print(f"Error in retrieve_node: {e}")
          context = f"Retriever failed: {str(e)}"
+         retrieved_docs = []

      return {"retrieved_docs": retrieved_docs, "context": context}
@@ -157,7 +219,7 @@ Answer:"""

      return {"response": response_text.strip()}

- # --- 3. Define the Router Logic ---
+
  def route_query(state: GraphState) -> str:
      """
      Checks if a retriever is available in the state to decide the next step.
@@ -169,11 +231,10 @@ def route_query(state: GraphState) -> str:
      print("---ROUTING: No PDF, routing to general_chat_node---")
      return "general_chat"

- # --- 4. Build the Graph ---
- def build_graph(model_type: str = "groq", retriever=None, summarizer=None):
+
+ def build_graph(model_type: str = "groq", retriever=None):
      """
-     Builds the workflow graph with LLM, retriever, and optional summarizer.
-     If summarizer not provided, initializes a default HuggingFace summarizer.
+     Builds the workflow graph with LLM and retriever.
      """

      def make_llm(temp: float):
@@ -187,6 +248,7 @@ def build_graph(model_type: str = "groq", retriever=None, summarizer=None):
              api_key=api_key,
              temperature=temp,
          )
+
      elif model_type == "gemini":
          api_key = os.getenv("GEMINI_API_KEY", "").strip()
          if not api_key:
@@ -196,16 +258,9 @@ def build_graph(model_type: str = "groq", retriever=None, summarizer=None):
              api_key=api_key,
              temperature=temp,
          )
+     else:
+         raise ValueError("Invalid model_type. Choose 'groq' or 'gemini'.")

-     def get_default_summarizer():
-         tokenizer = BartTokenizer.from_pretrained("facebook/bart-large-cnn")
-         model = BartForConditionalGeneration.from_pretrained("facebook/bart-large-cnn")
-         return pipeline("summarization", model=model, tokenizer=tokenizer, device=-1)
-
-     if summarizer is None:
-         print("---NO SUMMARIZER PROVIDED, USING DEFAULT (facebook/bart-large-cnn)---")
-         summarizer = get_default_summarizer()
-
      workflow = StateGraph(GraphState)

      workflow.add_node("router", router_node)
@@ -238,7 +293,6 @@ def build_graph(model_type: str = "groq", retriever=None):
          "retriever": active_retriever.as_retriever() if active_retriever else None,
          "history": history or [],
          "llm": llm,
-         "summarizer": summarizer,
          "response": None,
          "retrieved_docs": None,
          "context": None,
requirements.txt CHANGED
@@ -1,7 +1,5 @@
  # Core LLM + Transformers
- transformers[sentencepiece]>=4.40.0
  sentence-transformers>=2.6.0
- accelerate>=0.30.0

  # LangChain + LangGraph
  langchain>=0.2.1