aryan195a committed on
Commit dd83a42 · 1 Parent(s): 8fa7103

Modified all files, replacing flan-t5 with the Groq API
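At the call site, the switch amounts to something like the following minimal sketch (assuming langchain_groq is installed and GROQ_API_KEY is exported; the model name mirrors the one configured in graph.py below):

import os
from langchain_groq import ChatGroq

# Sketch only, not the app itself: build a Groq-backed chat model the same
# way graph.py does, assuming GROQ_API_KEY is set in the environment.
llm = ChatGroq(
    model="mixtral-8x7b-32768",
    api_key=os.getenv("GROQ_API_KEY"),
    temperature=0.7,
)

# Chat models return a message object; the app reads its .content attribute.
response_obj = llm.invoke("Hello!")
print(getattr(response_obj, "content", str(response_obj)))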

Files changed (4)
  1. Dockerfile +5 -10
  2. app.py +60 -44
  3. graph.py +46 -53
  4. requirements.txt +3 -4
Dockerfile CHANGED
@@ -3,27 +3,22 @@ FROM python:3.10-slim
  WORKDIR /app

  RUN apt-get update && apt-get install -y --no-install-recommends \
-     build-essential \
      git \
      curl \
-     libopenblas-dev \
-     libomp-dev \
-     python3-dev \
      && apt-get clean \
      && rm -rf /var/lib/apt/lists/*

- COPY . /app
-
  ENV PYTHONUNBUFFERED=1 \
      PYTHONDONTWRITEBYTECODE=1 \
      LANG=C.UTF-8

- RUN pip install --no-cache-dir --upgrade pip setuptools wheel
-
- RUN pip install --no-cache-dir torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu
-
- RUN pip install --no-cache-dir -r requirements.txt
+ COPY requirements.txt .
+
+ RUN pip install --no-cache-dir --upgrade pip setuptools wheel \
+     && pip install --no-cache-dir -r requirements.txt
+
+ COPY . .

  EXPOSE 7860

- CMD ["streamlit", "run", "app.py", "--server.port=7860", "--server.address=0.0.0.0"]
+ CMD ["streamlit", "run", "app.py", "--server.port=7860", "--server.address=0.0.0.0"]
app.py CHANGED
@@ -5,19 +5,11 @@ from pypdf import PdfReader
  import hashlib
  from transformers import pipeline

+ # --- Page Config ---
  st.set_page_config(page_title="LangGraph RAG Chatbot", layout="wide")
  st.title("📚 LangGraph RAG Chatbot")

- # Cached Vectorstore
- @st.cache_resource(show_spinner=False)
- def cached_vectorstore_from_text(text):
-     try:
-         return load_vectorstore_from_text(text=text)
-     except Exception as e:
-         st.warning(f"Failed to load vectorstore: {e}")
-         return get_retriever(text)
-
- # Helpers
+ # --- Helpers ---
  def compute_file_hash(raw_text):
      return hashlib.md5(raw_text.encode("utf-8")).hexdigest() if raw_text else None

@@ -34,7 +26,20 @@ def load_uploaded_file(uploaded_file):
          st.error(f"Error reading file: {e}")
      return raw_text

- # Sidebar: Settings & File Upload
+ # --- Cached Vectorstore with Persistent Cache ---
+ @st.cache_data(show_spinner=False)
+ def cached_vectorstore_from_text(raw_text_hash: str, text: str):
+     """
+     Cache vectorstore based on hash of raw text.
+     If the same text is uploaded again, returns cached retriever.
+     """
+     try:
+         return load_vectorstore_from_text(text=text)
+     except Exception as e:
+         st.warning(f"Failed to load vectorstore: {e}")
+         return get_retriever(text)
+
+ # --- Sidebar ---
  with st.sidebar:
      st.header("🔧 Settings")
      temperature = st.slider("LLM Temperature", 0.0, 1.0, 0.7)
@@ -47,32 +52,35 @@ with st.sidebar:

      uploaded_file = st.file_uploader("Upload a file (optional)", type=["txt", "pdf"])

- # Handle file upload
  if uploaded_file:
      raw_text = load_uploaded_file(uploaded_file)
      if raw_text:
-         # Reset session state when new file is uploaded
-         st.session_state.history = []
-         for key in ["retriever", "file_hash"]:
-             st.session_state.pop(key, None)
+         file_hash = compute_file_hash(raw_text)

-         st.session_state.raw_text = raw_text
-         st.session_state.retriever = cached_vectorstore_from_text(raw_text)
-         st.session_state.file_hash = compute_file_hash(raw_text)
+         # Reset session if new file
+         if st.session_state.get("file_hash") != file_hash:
+             st.session_state.history = []
+             for key in ["retriever", "file_hash"]:
+                 st.session_state.pop(key, None)
+
+         st.session_state.raw_text = raw_text
+         st.session_state.file_hash = file_hash
+         # Persistent cached vectorstore
+         st.session_state.retriever = cached_vectorstore_from_text(file_hash, raw_text)

          st.markdown("**📄 Uploaded File Preview:**")
          st.text_area("Contents", raw_text, height=200)
-         st.success("✅ Document loaded! You can now ask questions about it.")
+         st.success("✅ Document loaded!")
      else:
          st.warning("Uploaded file is empty or could not be read.")

  # Show current mode
  if "retriever" in st.session_state and st.session_state.retriever:
      st.info("📄 **RAG Mode**: Answering from uploaded document")
  else:
      st.info("💬 **General Chat Mode**: No document loaded")

- # Initialize summarizer
+ # --- Initialize Summarizer ---
  if "summarizer" not in st.session_state:
      st.session_state.summarizer = pipeline(
          "summarization",
@@ -80,7 +88,7 @@ if "summarizer" not in st.session_state:
          device=-1
      )

- # Build Graph
+ # --- Build Graph ---
  if "graph" not in st.session_state or st.session_state.get("graph_model") != model_type:
      try:
          st.session_state.graph = build_graph(
@@ -93,50 +101,57 @@ if "graph" not in st.session_state or st.session_state.get("graph_model") != mod
          st.error(f"Failed to build graph: {e}")
          st.stop()

- # Chat History
+ # --- Initialize History ---
  if "history" not in st.session_state:
      st.session_state.history = []

- # Query Input
- query = st.text_input("💬 Ask a question:")
+ # --- Query Input ---
+ if "current_query" not in st.session_state:
+     st.session_state.current_query = ""

- if st.button("Send") or query:
-     if query.strip():
+ query = st.text_input("💬 Ask a question:", key="current_query")
+ send_triggered = st.button("Send")
+
+ # --- Send Query ---
+ if send_triggered and query.strip():
+     formatted_history = [(q, r) for q, r, _ in st.session_state.history]
+
+     with st.spinner("Generating response..."):
          try:
-             # Prepare history in the format expected by the graph (tuples of (query, response))
-             formatted_history = [(q, r) for q, r, _ in st.session_state.history]
-
              result = st.session_state.graph(
                  query=query,
                  temperature=temperature,
-                 raw_text=st.session_state.get("raw_text", None),
+                 raw_text=st.session_state.get("raw_text"),
                  history=formatted_history,
                  retriever_override=st.session_state.get("retriever")
              )

              response = result.get("response", "No response generated.")
              retrieved_docs = result.get("retrieved_docs", [])

              st.markdown("### 🤖 Response")
              st.markdown(response)

-             # Add to history (keeping the original format for display)
+             # Save to history
              st.session_state.history.append((query, response, retrieved_docs))

-             # Show retrieved documents if available
+             # Show retrieved docs
              if retrieved_docs:
                  with st.expander("📄 Retrieved Chunks"):
                      for j, doc in enumerate(retrieved_docs):
                          content = getattr(doc, "text", str(doc))
                          st.markdown(f"**Chunk {j+1}:**")
                          st.code(content.strip(), language="markdown")

+             # Clear input
+             st.session_state.current_query = ""
+
          except Exception as e:
              st.error(f"Query failed: {e}")
-     else:
-         st.warning("Please enter a question.")
+ elif send_triggered:
+     st.warning("Please enter a question.")

- # Display Chat History
+ # --- Chat History Display ---
  if st.session_state.history:
      st.markdown("### 💬 Chat History")
      for i, (q, r, docs) in enumerate(reversed(st.session_state.history)):
@@ -149,7 +164,8 @@ if st.session_state.history:
                  content = getattr(doc, "text", str(doc))
                  st.code(content.strip()[:200] + "...", language="markdown")

- # Clear chat button
+ # --- Clear Chat ---
  if st.sidebar.button("🗑️ Clear Chat History"):
      st.session_state.history = []
-     st.rerun()
+     st.session_state.current_query = ""
+     st.rerun()
graph.py CHANGED
@@ -1,14 +1,15 @@
  import datetime
  import os
+ import re
  from typing import TypedDict, Optional, List

  from llama_index.core.schema import Document
- from langchain.llms.huggingface_pipeline import HuggingFacePipeline
  from langchain_google_genai import ChatGoogleGenerativeAI
- from transformers import pipeline as hf_pipeline
  from langgraph.graph import StateGraph, END
  from llama_index.core import VectorStoreIndex
  from llama_index.core.retrievers import BaseRetriever
+ from langchain_groq import ChatGroq
+ from transformers import pipeline as hf_pipeline

  # --- 1. Define the State for the Graph ---
  class GraphState(TypedDict):
@@ -22,17 +23,13 @@ class GraphState(TypedDict):
      summarizer: Optional[any]

  # --- 2. Define Graph Nodes ---
-
- # Router node to decide the flow
  def router_node(state: GraphState) -> GraphState:
      """
      Router that determines the next step based on available retriever.
      """
      print("---NODE: ROUTER---")
-     # This node just passes through the state - routing logic is in conditional edges
      return state

- # Node for handling general conversation when no PDF is loaded
  def general_chat_node(state: GraphState) -> GraphState:
      """
      Generates a response for general conversation using the LLM.
@@ -41,8 +38,6 @@ def general_chat_node(state: GraphState) -> GraphState:
      llm = state["llm"]
      query = state["query"]
      history = state["history"]
-
-     # Format history for the prompt
      history_context = "\n".join([f"Human: {q}\nAI: {a}" for q, a in history])
      current_time = datetime.datetime.now().strftime("%Y-%m-%d %I:%M %p")
      prompt = f"""You are Sarathi, a friendly and knowledgeable AI assistant.
@@ -61,23 +56,17 @@ Human: {query}
  AI:"""

      try:
-         if isinstance(llm, HuggingFacePipeline):
-             response_text = llm.invoke(prompt)
-         elif isinstance(llm, ChatGoogleGenerativeAI):
-             response_obj = llm.invoke(prompt)
-             response_text = getattr(response_obj, "content", str(response_obj))
-         else:
-             response_text = "Unsupported LLM type provided."
-
+         response_obj = llm.invoke(prompt)
+         response_text = getattr(response_obj, "content", str(response_obj))
      except Exception as e:
-         response_text = f"Model inference failed in general chat: {str(e)}"
+         response_text = f"Model inference failed: {str(e)}"

      return {"response": response_text.strip()}

- # Node for retrieving information from a PDF
  def retrieve_node(state: GraphState) -> GraphState:
      """
      Retrieves relevant documents from the vector store based on the query.
+     Summarizes context if too long, or truncates at sentence boundaries.
      """
      print("---NODE: RETRIEVE---")
      query = state["query"]
@@ -89,7 +78,6 @@ def retrieve_node(state: GraphState) -> GraphState:
      retrieved_docs = []

      try:
-         # Dynamic top_k based on query length
          q_len = len(query.split())
          top_k = 3 if q_len < 5 else (5 if q_len < 15 else 8)

@@ -98,19 +86,32 @@
          if retrieved_docs:
              context = "\n\n---\n\n".join([doc.text for doc in retrieved_docs])

-             # Add chat history to context
              if history:
                  history_context = "\n\n".join([f"Human: {q}\nAI: {a}" for q, a in history])
                  context = f"{context}\n\n--- Chat History ---\n{history_context}"

-             # Summarize if context is too long
              MAX_CONTEXT_CHARS = 4000
-             if len(context) > MAX_CONTEXT_CHARS and summarizer:
-                 print("---CONTEXT TOO LONG, SUMMARIZING---")
-                 summary_result = summarizer(context, max_length=500, min_length=150, do_sample=False)
-                 context = summary_result[0]['summary_text']
-             elif len(context) > MAX_CONTEXT_CHARS:
-                 context = context[:MAX_CONTEXT_CHARS]
+             if len(context) > MAX_CONTEXT_CHARS:
+                 try:
+                     print("---CONTEXT TOO LONG, SUMMARIZING---")
+                     summary_result = summarizer(
+                         context,
+                         max_length=500,
+                         min_length=150,
+                         do_sample=False
+                     )
+                     context = summary_result[0].get("summary_text", context[:MAX_CONTEXT_CHARS])
+                 except Exception as e:
+                     print(f"Summarizer failed: {e}")
+                     sentences = re.split(r'(?<=[.!?]) +', context)
+                     truncated = []
+                     total_len = 0
+                     for sent in sentences:
+                         if total_len + len(sent) > MAX_CONTEXT_CHARS:
+                             break
+                         truncated.append(sent)
+                         total_len += len(sent)
+                     context = " ".join(truncated)

      except Exception as e:
          print(f"Error in retrieve_node: {e}")
@@ -118,7 +119,6 @@ def retrieve_node(state: GraphState) -> GraphState:

      return {"retrieved_docs": retrieved_docs, "context": context}

- # Node for generating a response from RAG context
  def generate_rag_node(state: GraphState) -> GraphState:
      """
      Generates an answer using the retrieved context from the PDF.
@@ -148,16 +148,11 @@ Instructions:
  Answer:"""

      try:
-         if isinstance(llm, HuggingFacePipeline):
-             response_text = llm.invoke(prompt)
-         elif isinstance(llm, ChatGoogleGenerativeAI):
-             response_obj = llm.invoke(prompt)
-             response_text = getattr(response_obj, "content", str(response_obj))
-         else:
-             response_text = "Unsupported LLM type provided."
-
+         response_obj = llm.invoke(prompt)
+         response_text = getattr(response_obj, "content", str(response_obj))
+
      except Exception as e:
-         response_text = f"Model inference failed during RAG generation: {str(e)}"
+         response_text = f"Model inference failed: {str(e)}"

      return {"response": response_text.strip()}

@@ -174,18 +169,18 @@ def route_query(state: GraphState) -> str:
      return "general_chat"

  # --- 4. Build the Graph ---
- def build_graph(model_type: str = "huggingface", retriever=None, summarizer=None):
+ def build_graph(model_type: str = "groq", retriever=None, summarizer=None):
      """
-     Builds the conditional LangGraph workflow.
+     Builds the workflow graph with LLM, retriever, and optional summarizer.
+     If summarizer not provided, initializes a default HuggingFace summarizer.
      """
-     # Configure the LLM based on the selected model type
+
      if model_type == "groq":
-         from langchain_groq import ChatGroq
          api_key = os.getenv("GROQ_API_KEY", "").strip()
         if not api_key:
              raise ValueError("GROQ_API_KEY environment variable not set.")
          llm = ChatGroq(
-             model="mixtral-8x7b-32768",  # Fast and capable model
+             model="mixtral-8x7b-32768",
              api_key=api_key,
              temperature=0.7,
          )
@@ -193,23 +188,27 @@ def build_graph(model_type: str = "huggingface", retriever=None, summarizer=None
          api_key = os.getenv("GEMINI_API_KEY", "").strip()
          if not api_key:
              raise ValueError("GEMINI_API_KEY environment variable not set.")
-         llm = ChatGoogleGenerativeAI(model="gemini-2.0-flash", api_key=api_key, temperature=0.7)
+         llm = ChatGoogleGenerativeAI(
+             model="gemini-2.0-flash",
+             api_key=api_key,
+             temperature=0.7
+         )
      else:
          raise ValueError("Invalid model_type. Choose 'groq' or 'gemini'.")

-     # Define the graph structure
+     if summarizer is None:
+         print("---NO SUMMARIZER PROVIDED, USING DEFAULT (facebook/bart-large-cnn)---")
+         summarizer = hf_pipeline("summarization", model="facebook/bart-large-cnn")
+
      workflow = StateGraph(GraphState)

-     # Add all the nodes to the graph
      workflow.add_node("router", router_node)
      workflow.add_node("general_chat", general_chat_node)
      workflow.add_node("retrieve", retrieve_node)
      workflow.add_node("generate", generate_rag_node)

-     # Set the router as the entry point
      workflow.set_entry_point("router")

-     # Add the conditional edge from the router
      workflow.add_conditional_edges(
          "router",
          route_query,
@@ -219,19 +218,13 @@ def build_graph(model_type: str = "huggingface", retriever=None, summarizer=None
          },
      )

-     # Define the standard path for the RAG pipeline
      workflow.add_edge("retrieve", "generate")
-
-     # Define the end points for the graph
      workflow.add_edge("generate", END)
      workflow.add_edge("general_chat", END)

-     # Compile the graph
      compiled_graph = workflow.compile()

-     # Return a function that wraps the graph invocation
      def graph_wrapper(query: str, temperature: float = 0.7, raw_text: str = None, history=None, retriever_override=None):
-         # Use retriever_override if provided, otherwise use the build-time retriever
          active_retriever = retriever_override or retriever
          return compiled_graph.invoke({
              "query": query,
requirements.txt CHANGED
@@ -6,14 +6,13 @@ accelerate>=0.30.0
  # LangChain + LangGraph
  langchain>=0.2.1
  langgraph>=0.0.45
- langchain-community>=0.0.45
- langchain-huggingface>=0.1.0
- langchain-google-genai>=1.0.5  # wrapper for Gemini
+ langchain-groq>=0.1.0
+ langchain-google-genai>=1.0.5

  # Retrieval + Embeddings
  llama-index>=0.13.5
  llama-index-embeddings-huggingface>=0.1.3
- chromadb>=0.5.3  # optional if you still want hybrid search / persistence
+ chromadb>=0.5.3

  # Hugging Face + Deployment
  huggingface_hub>=0.23.4