"""Streamlit RAG demo: FAISS retrieval over neural-bridge/rag-dataset-1200, answered via Groq."""
import os
import streamlit as st
import faiss
import pickle
from datasets import load_dataset
from sentence_transformers import SentenceTransformer
from groq import Groq
# Constants
DATASET_NAME = "neural-bridge/rag-dataset-1200"  # Hugging Face dataset; only the "context" field is used below
MODEL_NAME = "all-MiniLM-L6-v2"  # sentence-transformers model used for both indexing and query embedding
INDEX_FILE = "faiss_index.pkl"  # on-disk pickle of the FAISS index (written/read by the blocks below)
DOCS_FILE = "contexts.pkl"  # on-disk pickle of the raw context strings, row-aligned with the index
# Groq API client
# NOTE(review): if MY_KEY is unset, api_key is None and the client fails only at request time —
# consider failing fast here with a clear message.
client = Groq(api_key=os.environ.get("MY_KEY"))
# Streamlit page setup
st.set_page_config(page_title="RAG App", layout="wide")
st.title("🧠 Retrieval-Augmented Generation (RAG) with Groq")
# Function to load or create database
@st.cache_resource
def setup_database():
    """Build the FAISS vector store from the dataset and persist it to disk.

    Downloads the dataset's train split, embeds every context with the
    sentence-transformer model, builds an exact L2 index, and pickles both
    the index and the contexts so later runs can skip the rebuild.

    Returns:
        tuple: (faiss index, list of context strings) — row i of the index
        corresponds to contexts[i].
    """
    st.info("Setting up vector database...")
    bar = st.progress(0)

    # 1) Pull the raw context strings out of the training split.
    contexts = [row["context"] for row in load_dataset(DATASET_NAME, split="train")]
    bar.progress(25)

    # 2) Embed every context; encode() returns one vector per string.
    vectors = SentenceTransformer(MODEL_NAME).encode(contexts, show_progress_bar=True)
    bar.progress(50)

    # 3) Exact (brute-force) L2 index sized to the embedding dimension.
    index = faiss.IndexFlatL2(vectors[0].shape[0])
    index.add(vectors)
    bar.progress(75)

    # 4) Persist both artifacts for reuse across app restarts.
    for path, payload in ((INDEX_FILE, index), (DOCS_FILE, contexts)):
        with open(path, "wb") as sink:
            pickle.dump(payload, sink)
    bar.progress(100)

    st.success("Database setup complete!")
    return index, contexts
# Check if the index and contexts are saved, otherwise set up
# Reuse the persisted artifacts when both files exist; otherwise build from scratch.
if not (os.path.exists(INDEX_FILE) and os.path.exists(DOCS_FILE)):
    faiss_index, all_contexts = setup_database()
else:
    def _unpickle(path):
        # Helper: load one pickled artifact (trusted local files written by this app).
        with open(path, "rb") as fh:
            return pickle.load(fh)

    faiss_index = _unpickle(INDEX_FILE)
    all_contexts = _unpickle(DOCS_FILE)
    st.info("Loaded existing database.")
# UI for sample questions
# Canned prompts about the dataset; the first one pre-fills the input box.
sample_questions = [
    "What is the purpose of the RAG dataset?",
    "How does Falcon RefinedWeb contribute to this dataset?",
    "What are the benefits of using retrieval-augmented generation?",
    "Explain the structure of the RAG-1200 dataset.",
]
st.subheader("Ask a question based on the dataset:")
# Free-text question; defaults to the first sample so the demo works on first click.
question = st.text_input("Enter your question:", value=sample_questions[0])
@st.cache_resource
def _load_embedder():
    """Load the query embedder once per session.

    Fix: the original instantiated a new SentenceTransformer on every
    button click, re-loading the model weights per question. Caching with
    st.cache_resource makes each query after the first near-instant.
    """
    return SentenceTransformer(MODEL_NAME)

if st.button("Ask"):
    if not question.strip():
        st.warning("Please enter a question.")
    else:
        with st.spinner("Retrieving and generating answer..."):
            # Embed the query with the same (cached) model used to build the index.
            query_embedding = _load_embedder().encode([question])
            # Nearest neighbour in L2 space; k=1 keeps the prompt small.
            D, I = faiss_index.search(query_embedding, k=1)
            # Get closest context
            context = all_contexts[I[0][0]]
            prompt = f"Context: {context}\n\nQuestion: {question}\n\nAnswer:"
            # Call Groq model, grounding the LLM on the retrieved context.
            response = client.chat.completions.create(
                messages=[{"role": "user", "content": prompt}],
                model="llama3-70b-8192"
            )
            answer = response.choices[0].message.content
            st.success("Answer:")
            st.markdown(answer)
            # Show the evidence so users can judge the answer.
            with st.expander("🔍 Retrieved Context"):
                st.markdown(context)