aryan195a committed
Commit 9792cab · 1 Parent(s): 40ca667

Added grok-4-fast model

Files changed (3)
  1. app.py +3 -3
  2. graph.py +23 -21
  3. requirements.txt +1 -0
app.py CHANGED
@@ -27,7 +27,7 @@ def load_uploaded_file(uploaded_file):
     return raw_text
 
 # --- Cached Vectorstore with Persistent Cache ---
-@st.cache_data(show_spinner=False)
+@st.cache_resource(show_spinner=False)
 def cached_vectorstore_from_text(raw_text_hash: str, text: str):
     """
     Cache vectorstore based on hash of raw text.
@@ -47,7 +47,7 @@ with st.sidebar:
     model_type = st.radio(
         "Select LLM Backend:",
         options=["groq", "gemini"],
-        format_func=lambda x: "⚡ Groq (Mixtral-8x7B)" if x == "groq" else "🌐 Google Gemini"
+        format_func=lambda x: "⚡ Groq API" if x == "groq" else "🌐 Google Gemini"
     )
 
     uploaded_file = st.file_uploader("Upload a file (optional)", type=["txt", "pdf"])
@@ -145,7 +145,7 @@ if send_triggered:
         st.code(content.strip(), language="markdown")
 
         # Clear the input field by rerunning widget with empty value
-        st.experimental_rerun()
+        st.rerun()
 
     except Exception as e:
         st.error(f"Query failed: {e}")
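Both app.py fixes track Streamlit's current API: st.cache_data pickles and copies return values, which is wrong for a live object like a vectorstore, so st.cache_resource is the appropriate decorator here, and st.experimental_rerun has been removed in recent Streamlit releases in favor of st.rerun(). A minimal sketch of the cached pattern; the dict-based "index" is a hypothetical stand-in for the app's real chunk/embed/index step:

import streamlit as st

@st.cache_resource(show_spinner=False)  # keeps one live object per argument tuple
def cached_vectorstore_from_text(raw_text_hash: str, text: str):
    # raw_text_hash keys the cache so identical uploads reuse one index.
    # Hypothetical stand-in for the app's actual vectorstore construction:
    return {"hash": raw_text_hash, "chunks": text.split("\n\n")}

store = cached_vectorstore_from_text("demo-hash", "para one\n\npara two")
st.write(len(store["chunks"]), "chunks cached")

if st.button("Reset"):
    st.rerun()  # replacement for the removed st.experimental_rerun()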
graph.py CHANGED
@@ -5,11 +5,12 @@ from typing import TypedDict, Optional, List
 
 from llama_index.core.schema import Document
 from langchain_google_genai import ChatGoogleGenerativeAI
-from langgraph.graph import StateGraph, END
+from langchain_openai import ChatOpenAI
 from llama_index.core import VectorStoreIndex
 from llama_index.core.retrievers import BaseRetriever
 from langchain_groq import ChatGroq
 from transformers import pipeline, BartForConditionalGeneration, BartTokenizer
+from langgraph.graph import StateGraph, END
 
 # --- 1. Define the State for the Graph ---
 class GraphState(TypedDict):
@@ -175,26 +176,26 @@ def build_graph(model_type: str = "groq", retriever=None, summarizer=None):
     If summarizer not provided, initializes a default HuggingFace summarizer.
     """
 
-    if model_type == "groq":
-        api_key = os.getenv("GROQ_API_KEY", "").strip()
-        if not api_key:
-            raise ValueError("GROQ_API_KEY environment variable not set.")
-        llm = ChatGroq(
-            model="x-ai/grok-4-fast:free",
-            api_key=api_key,
-            temperature=0.7,
-        )
-    elif model_type == "gemini":
-        api_key = os.getenv("GEMINI_API_KEY", "").strip()
-        if not api_key:
-            raise ValueError("GEMINI_API_KEY environment variable not set.")
-        llm = ChatGoogleGenerativeAI(
-            model="gemini-2.0-flash",
-            api_key=api_key,
-            temperature=0.7
-        )
-    else:
-        raise ValueError("Invalid model_type. Choose 'groq' or 'gemini'.")
+    def make_llm(temp: float):
+        if model_type == "groq":
+            api_key = os.getenv("OPENROUTER_API_KEY", "").strip()
+            if not api_key:
+                raise ValueError("OPENROUTER_API_KEY environment variable not set.")
+            return ChatOpenAI(
+                model="x-ai/grok-4-fast:free",
+                base_url="https://openrouter.ai/api/v1",
+                api_key=api_key,
+                temperature=temp,
+            )
+        elif model_type == "gemini":
+            api_key = os.getenv("GEMINI_API_KEY", "").strip()
+            if not api_key:
+                raise ValueError("GEMINI_API_KEY environment variable not set.")
+            return ChatGoogleGenerativeAI(
+                model="gemini-2.0-flash",
+                api_key=api_key,
+                temperature=temp,
+            )
 
     def get_default_summarizer():
         tokenizer = BartTokenizer.from_pretrained("facebook/bart-large-cnn")
@@ -231,6 +232,7 @@ def build_graph(model_type: str = "groq", retriever=None, summarizer=None):
 
     def graph_wrapper(query: str, temperature: float = 0.7, raw_text: str = None, history=None, retriever_override=None):
         active_retriever = retriever_override or retriever
+        llm = make_llm(temperature)
         return compiled_graph.invoke({
             "query": query,
             "retriever": active_retriever.as_retriever() if active_retriever else None,
requirements.txt CHANGED
@@ -6,6 +6,7 @@ accelerate>=0.30.0
 # LangChain + LangGraph
 langchain>=0.2.1
 langgraph>=0.0.45
+langchain-openai>=0.1.0
 langchain-groq>=0.1.0
 langchain-google-genai>=1.0.5
 
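The new pin matches the import added in graph.py: since LangChain split its provider integrations into separate packages, ChatOpenAI lives in langchain-openai rather than in langchain itself. A quick smoke test after installing, as a sketch:

# Verifies the new dependency resolves; the import fails loudly if the pin is missing.
from langchain_openai import ChatOpenAI

print(ChatOpenAI.__module__)  # expect something like "langchain_openai.chat_models.base"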