aryan195a commited on
Commit
40ca667
Β·
1 Parent(s): dd83a42

Modified grok model

Browse files
Files changed (2) hide show
  1. app.py +44 -43
  2. graph.py +8 -3
app.py CHANGED
@@ -3,7 +3,7 @@ from graph import build_graph
3
  from utils import get_retriever, load_vectorstore_from_text
4
  from pypdf import PdfReader
5
  import hashlib
6
- from transformers import pipeline
7
 
8
  # --- Page Config ---
9
  st.set_page_config(page_title="LangGraph RAG Chatbot", layout="wide")
@@ -82,9 +82,12 @@ with st.sidebar:
82
 
83
  # --- Initialize Summarizer ---
84
  if "summarizer" not in st.session_state:
 
 
85
  st.session_state.summarizer = pipeline(
86
  "summarization",
87
- model="facebook/bart-large-cnn",
 
88
  device=-1
89
  )
90
 
@@ -106,50 +109,48 @@ if "history" not in st.session_state:
106
  st.session_state.history = []
107
 
108
  # --- Query Input ---
109
- if "current_query" not in st.session_state:
110
- st.session_state.current_query = ""
111
 
112
- query = st.text_input("πŸ’¬ Ask a question:", key="current_query")
113
  send_triggered = st.button("Send")
114
 
115
- # --- Send Query ---
116
- if send_triggered and query.strip():
117
- formatted_history = [(q, r) for q, r, _ in st.session_state.history]
118
-
119
- with st.spinner("Generating response..."):
120
- try:
121
- result = st.session_state.graph(
122
- query=query,
123
- temperature=temperature,
124
- raw_text=st.session_state.get("raw_text"),
125
- history=formatted_history,
126
- retriever_override=st.session_state.get("retriever")
127
- )
128
-
129
- response = result.get("response", "No response generated.")
130
- retrieved_docs = result.get("retrieved_docs", [])
131
-
132
- st.markdown("### πŸ€– Response")
133
- st.markdown(response)
134
-
135
- # Save to history
136
- st.session_state.history.append((query, response, retrieved_docs))
137
-
138
- # Show retrieved docs
139
- if retrieved_docs:
140
- with st.expander("πŸ“„ Retrieved Chunks"):
141
- for j, doc in enumerate(retrieved_docs):
142
- content = getattr(doc, "text", str(doc))
143
- st.markdown(f"**Chunk {j+1}:**")
144
- st.code(content.strip(), language="markdown")
145
-
146
- # Clear input
147
- st.session_state.current_query = ""
148
-
149
- except Exception as e:
150
- st.error(f"Query failed: {e}")
151
- elif send_triggered:
152
- st.warning("Please enter a question.")
153
 
154
  # --- Chat History Display ---
155
  if st.session_state.history:
 
3
  from utils import get_retriever, load_vectorstore_from_text
4
  from pypdf import PdfReader
5
  import hashlib
6
+ from transformers import pipeline, BartForConditionalGeneration, BartTokenizer
7
 
8
  # --- Page Config ---
9
  st.set_page_config(page_title="LangGraph RAG Chatbot", layout="wide")
 
82
 
83
  # --- Initialize Summarizer ---
84
  if "summarizer" not in st.session_state:
85
+ tokenizer = BartTokenizer.from_pretrained("facebook/bart-large-cnn")
86
+ model = BartForConditionalGeneration.from_pretrained("facebook/bart-large-cnn")
87
  st.session_state.summarizer = pipeline(
88
  "summarization",
89
+ model=model,
90
+ tokenizer=tokenizer,
91
  device=-1
92
  )
93
 
 
109
  st.session_state.history = []
110
 
111
  # --- Query Input ---
112
+ query_input = st.text_input("πŸ’¬ Ask a question:")
 
113
 
 
114
  send_triggered = st.button("Send")
115
 
116
+ if send_triggered:
117
+ if query_input.strip():
118
+ formatted_history = [(q, r) for q, r, _ in st.session_state.history]
119
+
120
+ with st.spinner("Generating response..."):
121
+ try:
122
+ result = st.session_state.graph(
123
+ query=query_input,
124
+ temperature=temperature,
125
+ raw_text=st.session_state.get("raw_text"),
126
+ history=formatted_history,
127
+ retriever_override=st.session_state.get("retriever")
128
+ )
129
+
130
+ response = result.get("response", "No response generated.")
131
+ retrieved_docs = result.get("retrieved_docs", [])
132
+
133
+ st.markdown("### πŸ€– Response")
134
+ st.markdown(response)
135
+
136
+ # Save to history
137
+ st.session_state.history.append((query_input, response, retrieved_docs))
138
+
139
+ # Show retrieved docs if available
140
+ if retrieved_docs:
141
+ with st.expander("πŸ“„ Retrieved Chunks"):
142
+ for j, doc in enumerate(retrieved_docs):
143
+ content = getattr(doc, "text", str(doc))
144
+ st.markdown(f"**Chunk {j+1}:**")
145
+ st.code(content.strip(), language="markdown")
146
+
147
+ # Clear the input field by rerunning widget with empty value
148
+ st.experimental_rerun()
149
+
150
+ except Exception as e:
151
+ st.error(f"Query failed: {e}")
152
+ else:
153
+ st.warning("Please enter a question.")
154
 
155
  # --- Chat History Display ---
156
  if st.session_state.history:
graph.py CHANGED
@@ -9,7 +9,7 @@ from langgraph.graph import StateGraph, END
9
  from llama_index.core import VectorStoreIndex
10
  from llama_index.core.retrievers import BaseRetriever
11
  from langchain_groq import ChatGroq
12
- from transformers import pipeline as hf_pipeline
13
 
14
  # --- 1. Define the State for the Graph ---
15
  class GraphState(TypedDict):
@@ -180,7 +180,7 @@ def build_graph(model_type: str = "groq", retriever=None, summarizer=None):
180
  if not api_key:
181
  raise ValueError("GROQ_API_KEY environment variable not set.")
182
  llm = ChatGroq(
183
- model="mixtral-8x7b-32768",
184
  api_key=api_key,
185
  temperature=0.7,
186
  )
@@ -195,10 +195,15 @@ def build_graph(model_type: str = "groq", retriever=None, summarizer=None):
195
  )
196
  else:
197
  raise ValueError("Invalid model_type. Choose 'groq' or 'gemini'.")
 
 
 
 
 
198
 
199
  if summarizer is None:
200
  print("---NO SUMMARIZER PROVIDED, USING DEFAULT (facebook/bart-large-cnn)---")
201
- summarizer = hf_pipeline("summarization", model="facebook/bart-large-cnn")
202
 
203
  workflow = StateGraph(GraphState)
204
 
 
9
  from llama_index.core import VectorStoreIndex
10
  from llama_index.core.retrievers import BaseRetriever
11
  from langchain_groq import ChatGroq
12
+ from transformers import pipeline, BartForConditionalGeneration, BartTokenizer
13
 
14
  # --- 1. Define the State for the Graph ---
15
  class GraphState(TypedDict):
 
180
  if not api_key:
181
  raise ValueError("GROQ_API_KEY environment variable not set.")
182
  llm = ChatGroq(
183
+ model="x-ai/grok-4-fast:free",
184
  api_key=api_key,
185
  temperature=0.7,
186
  )
 
195
  )
196
  else:
197
  raise ValueError("Invalid model_type. Choose 'groq' or 'gemini'.")
198
+
199
+ def get_default_summarizer():
200
+ tokenizer = BartTokenizer.from_pretrained("facebook/bart-large-cnn")
201
+ model = BartForConditionalGeneration.from_pretrained("facebook/bart-large-cnn")
202
+ return pipeline("summarization", model=model, tokenizer=tokenizer, device=-1)
203
 
204
  if summarizer is None:
205
  print("---NO SUMMARIZER PROVIDED, USING DEFAULT (facebook/bart-large-cnn)---")
206
+ summarizer = get_default_summarizer()
207
 
208
  workflow = StateGraph(GraphState)
209