# Streamlit RAG demo: retrieve the closest dataset context with a FAISS index
# and ask a Groq-hosted chat model to answer the user's question from it.
import os
import pickle

import streamlit as st
import faiss
from datasets import load_dataset
from sentence_transformers import SentenceTransformer
from groq import Groq

# Constants
DATASET_NAME = "neural-bridge/rag-dataset-1200"
MODEL_NAME = "all-MiniLM-L6-v2"
INDEX_FILE = "faiss_index.pkl"
DOCS_FILE = "contexts.pkl"

# Groq API client (the key is read from the MY_KEY environment variable)
client = Groq(api_key=os.environ.get("MY_KEY"))

# Streamlit page setup
st.set_page_config(page_title="RAG App", layout="wide")
st.title("🧠 Retrieval-Augmented Generation (RAG) with Groq")


def _dump_atomically(obj, path):
    """Pickle *obj* to *path* without ever exposing a half-written file.

    The data is written to a sibling temporary file first and then moved into
    place, so an interrupted run cannot leave a truncated cache file that
    would make every later start-up fail while unpickling.
    """
    tmp_path = f"{path}.tmp"
    with open(tmp_path, "wb") as f:
        pickle.dump(obj, f)
    os.replace(tmp_path, path)


# Function to load or create database
@st.cache_resource
def setup_database():
    """Build the vector database from the dataset and persist it to disk.

    Loads the "train" split of DATASET_NAME, embeds every "context" entry
    with the MODEL_NAME sentence-transformer, adds the embeddings to a flat
    L2 FAISS index, and pickles the index and the contexts to INDEX_FILE and
    DOCS_FILE.

    Returns:
        A ``(faiss_index, contexts)`` tuple; row ``i`` of the index is the
        embedding of ``contexts[i]``.

    Raises:
        ValueError: If the dataset yields no contexts to index.
    """
    st.info("Setting up vector database...")
    progress = st.progress(0)

    # Step 1: Load dataset
    dataset = load_dataset(DATASET_NAME, split="train")
    contexts = [entry["context"] for entry in dataset]
    if not contexts:
        # Fail with a clear message instead of an opaque indexing error below.
        raise ValueError(f"Dataset {DATASET_NAME!r} contains no contexts to index.")
    progress.progress(25)

    # Step 2: Compute embeddings
    embedder = SentenceTransformer(MODEL_NAME)
    embeddings = embedder.encode(contexts, show_progress_bar=True)
    progress.progress(50)

    # Step 3: Build FAISS index
    dimension = embeddings[0].shape[0]
    faiss_index = faiss.IndexFlatL2(dimension)
    faiss_index.add(embeddings)
    progress.progress(75)

    # Step 4: Save index and contexts for future use
    _dump_atomically(faiss_index, INDEX_FILE)
    _dump_atomically(contexts, DOCS_FILE)
    progress.progress(100)

    st.success("Database setup complete!")
    return faiss_index, contexts


@st.cache_resource
def _load_saved_database():
    """Load the pickled index and contexts once per server process.

    Streamlit re-executes the whole script on every interaction; caching this
    keeps the two files from being re-read and unpickled on each rerun.

    Returns:
        A ``(faiss_index, contexts)`` tuple read from INDEX_FILE and DOCS_FILE.
    """
    # SECURITY NOTE: pickle.load executes arbitrary code from the file. These
    # files must only ever be the ones written by setup_database(); never
    # point INDEX_FILE / DOCS_FILE at data from an untrusted source.
    with open(INDEX_FILE, "rb") as f:
        saved_index = pickle.load(f)
    with open(DOCS_FILE, "rb") as f:
        saved_contexts = pickle.load(f)
    return saved_index, saved_contexts


# Check if the index and contexts are saved, otherwise set up
if os.path.exists(INDEX_FILE) and os.path.exists(DOCS_FILE):
    faiss_index, all_contexts = _load_saved_database()
    st.info("Loaded existing database.")
else:
    faiss_index, all_contexts = setup_database()

# UI for sample questions
sample_questions = [
    "What is the purpose of the RAG dataset?",
    "How does Falcon RefinedWeb contribute to this dataset?",
    "What are the benefits of using retrieval-augmented generation?",
    "Explain the structure of the RAG-1200 dataset.",
]
@st.cache_resource
def _load_embedder():
    """Return the sentence-transformer used to embed queries.

    Cached so the model is constructed once per server process instead of
    being reloaded every time the "Ask" button is pressed.
    """
    return SentenceTransformer(MODEL_NAME)


st.subheader("Ask a question based on the dataset:")
question = st.text_input("Enter your question:", value=sample_questions[0])

if st.button("Ask"):
    if not question.strip():
        st.warning("Please enter a question.")
    else:
        context = None
        answer = None
        with st.spinner("Retrieving and generating answer..."):
            # Embed user query
            query_embedding = _load_embedder().encode([question])
            _distances, neighbor_ids = faiss_index.search(query_embedding, k=1)

            # Get closest context. FAISS reports -1 when it has no neighbour
            # to return; indexing with -1 would silently pick the last
            # context, so treat it as "nothing found" instead.
            best_id = int(neighbor_ids[0][0])
            if best_id >= 0:
                context = all_contexts[best_id]
                prompt = f"Context: {context}\n\nQuestion: {question}\n\nAnswer:"

                # Call Groq model
                response = client.chat.completions.create(
                    messages=[{"role": "user", "content": prompt}],
                    model="llama3-70b-8192",
                )
                answer = response.choices[0].message.content

        if context is None:
            st.warning("No matching context was found in the database.")
        else:
            st.success("Answer:")
            st.markdown(answer)

            with st.expander("🔍 Retrieved Context"):
                st.markdown(context)