Added grok-4-fast model
- app.py +3 -3
- graph.py +23 -21
- requirements.txt +1 -0
app.py
CHANGED
@@ -27,7 +27,7 @@ def load_uploaded_file(uploaded_file):
     return raw_text
 
 # --- Cached Vectorstore with Persistent Cache ---
-@st.
+@st.cache_resource(show_spinner=False)
 def cached_vectorstore_from_text(raw_text_hash: str, text: str):
     """
     Cache vectorstore based on hash of raw text.
@@ -47,7 +47,7 @@ with st.sidebar:
     model_type = st.radio(
         "Select LLM Backend:",
         options=["groq", "gemini"],
-        format_func=lambda x: "⚡ Groq
+        format_func=lambda x: "⚡ Groq API" if x == "groq" else "🌐 Google Gemini"
     )
 
     uploaded_file = st.file_uploader("Upload a file (optional)", type=["txt", "pdf"])
@@ -145,7 +145,7 @@ if send_triggered:
             st.code(content.strip(), language="markdown")
 
         # Clear the input field by rerunning widget with empty value
-        st.
+        st.rerun()
 
     except Exception as e:
         st.error(f"Query failed: {e}")
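The caching change keys `cached_vectorstore_from_text` on a hash of the raw text: identical uploads hit `@st.cache_resource` and reuse the already-built vectorstore instead of re-embedding the document. A minimal sketch of the pattern, with a hypothetical `build_index` standing in for the app's real embedding/indexing step:

import hashlib

import streamlit as st

@st.cache_resource(show_spinner=False)
def cached_vectorstore_from_text(raw_text_hash: str, text: str):
    # Identical text -> identical hash -> Streamlit returns the cached
    # index instead of recomputing embeddings.
    return build_index(text)  # hypothetical helper, not from this repo

raw_text = load_uploaded_file(uploaded_file)  # defined in app.py per the hunk header
key = hashlib.sha256(raw_text.encode("utf-8")).hexdigest()
vectorstore = cached_vectorstore_from_text(key, raw_text)

Since `st.cache_resource` already hashes every argument, passing the digest explicitly is redundant for correctness; it mainly documents what the cache key is meant to be.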
graph.py
CHANGED
@@ -5,11 +5,12 @@ from typing import TypedDict, Optional, List
 
 from llama_index.core.schema import Document
 from langchain_google_genai import ChatGoogleGenerativeAI
-from
+from langchain_openai import ChatOpenAI
 from llama_index.core import VectorStoreIndex
 from llama_index.core.retrievers import BaseRetriever
 from langchain_groq import ChatGroq
 from transformers import pipeline, BartForConditionalGeneration, BartTokenizer
+from langgraph.graph import StateGraph, END
 
 # --- 1. Define the State for the Graph ---
 class GraphState(TypedDict):
@@ -175,26 +176,26 @@ def build_graph(model_type: str = "groq", retriever=None, summarizer=None):
     If summarizer not provided, initializes a default HuggingFace summarizer.
     """
 
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+    def make_llm(temp: float):
+        if model_type == "groq":
+            api_key = os.getenv("OPENROUTER_API_KEY", "").strip()
+            if not api_key:
+                raise ValueError("OPENROUTER_API_KEY environment variable not set.")
+            return ChatOpenAI(
+                model="x-ai/grok-4-fast:free",
+                base_url="https://openrouter.ai/api/v1",
+                api_key=api_key,
+                temperature=temp,
+            )
+        elif model_type == "gemini":
+            api_key = os.getenv("GEMINI_API_KEY", "").strip()
+            if not api_key:
+                raise ValueError("GEMINI_API_KEY environment variable not set.")
+            return ChatGoogleGenerativeAI(
+                model="gemini-2.0-flash",
+                api_key=api_key,
+                temperature=temp,
+            )
 
     def get_default_summarizer():
         tokenizer = BartTokenizer.from_pretrained("facebook/bart-large-cnn")
@@ -231,6 +232,7 @@ def build_graph(model_type: str = "groq", retriever=None, summarizer=None):
 
     def graph_wrapper(query: str, temperature: float = 0.7, raw_text: str = None, history=None, retriever_override=None):
         active_retriever = retriever_override or retriever
+        llm = make_llm(temperature)
         return compiled_graph.invoke({
             "query": query,
             "retriever": active_retriever.as_retriever() if active_retriever else None,
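The core of the commit is the new `make_llm` factory inside `build_graph`: the existing "groq" backend label now routes to xAI's grok-4-fast through OpenRouter's OpenAI-compatible endpoint (hence the `langchain_openai` import), "gemini" stays on `ChatGoogleGenerativeAI`, and `graph_wrapper` builds a fresh LLM per call so the requested temperature is honored. A standalone smoke test of the OpenRouter path, reusing the model ID, base URL, and env var from the diff (only the test prompt is invented):

import os

from langchain_openai import ChatOpenAI

# OpenRouter speaks the OpenAI wire protocol, so ChatOpenAI only needs
# base_url redirected and an OpenRouter key.
llm = ChatOpenAI(
    model="x-ai/grok-4-fast:free",
    base_url="https://openrouter.ai/api/v1",
    api_key=os.environ["OPENROUTER_API_KEY"],
    temperature=0.7,
)
print(llm.invoke("Reply with one word: ready").content)

Note that the sidebar option is still labeled "⚡ Groq API" and `ChatGroq` is still imported, even though this branch now talks to Grok via OpenRouter rather than Groq.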
requirements.txt
CHANGED
@@ -6,6 +6,7 @@ accelerate>=0.30.0
 # LangChain + LangGraph
 langchain>=0.2.1
 langgraph>=0.0.45
+langchain-openai>=0.1.0
 langchain-groq>=0.1.0
 langchain-google-genai>=1.0.5
 
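`langchain-openai` is the only new dependency, matching the `from langchain_openai import ChatOpenAI` added in graph.py. A quick sanity check after installing (a sketch; it only exercises the imports this commit touches):

# Run after `pip install -r requirements.txt`
from langchain_openai import ChatOpenAI      # new: langchain-openai>=0.1.0
from langgraph.graph import StateGraph, END  # existing: langgraph>=0.0.45
from langchain_google_genai import ChatGoogleGenerativeAI  # existing pin
print("imports OK")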