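"""Streamlit RAG demo.

Embeds the context passages of the neural-bridge/rag-dataset-1200 dataset
with a SentenceTransformer, indexes them with FAISS, retrieves the closest
passage for a user question, and asks a Groq-hosted Llama 3 model to answer
from it.

Run with (assuming this file is saved as app.py):
    streamlit run app.py
"""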
import os
import streamlit as st
import faiss
import pickle
from datasets import load_dataset
from sentence_transformers import SentenceTransformer
from groq import Groq

# Constants
DATASET_NAME = "neural-bridge/rag-dataset-1200"
MODEL_NAME = "all-MiniLM-L6-v2"
INDEX_FILE = "faiss_index.pkl"
DOCS_FILE = "contexts.pkl"

# Groq API client (expects the API key in the MY_KEY environment variable)
client = Groq(api_key=os.environ.get("MY_KEY"))

# Streamlit page setup
st.set_page_config(page_title="RAG App", layout="wide")
st.title("🧠 Retrieval-Augmented Generation (RAG) with Groq")
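
# Cache the embedding model: Streamlit reruns this script on every
# interaction, so load the SentenceTransformer once and reuse it.
@st.cache_resource
def load_embedder():
    return SentenceTransformer(MODEL_NAME)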

# Function to load or create database
@st.cache_resource
def setup_database():
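    """Embed all dataset contexts, build a FAISS index, and persist both.

    Returns the index and the list of context passages it was built from.
    """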
    st.info("Setting up vector database...")
    progress = st.progress(0)

    # Step 1: Load dataset
    dataset = load_dataset(DATASET_NAME, split="train")
    contexts = [entry["context"] for entry in dataset]
    progress.progress(25)

    # Step 2: Compute embeddings
    embedder = load_embedder()
    embeddings = embedder.encode(contexts, show_progress_bar=True)
    progress.progress(50)

    # Step 3: Build an exact (brute-force) L2 index over the embeddings.
    # encode() returns a float32 NumPy array, which is what FAISS expects.
    dimension = embeddings.shape[1]
    faiss_index = faiss.IndexFlatL2(dimension)
    faiss_index.add(embeddings)
    progress.progress(75)

    # Step 4: Save index and contexts for future use. A FAISS index is a
    # SWIG-wrapped C++ object and cannot be pickled directly, so serialize
    # it to a byte array first.
    with open(INDEX_FILE, "wb") as f:
        pickle.dump(faiss.serialize_index(faiss_index), f)
    with open(DOCS_FILE, "wb") as f:
        pickle.dump(contexts, f)

    progress.progress(100)
    st.success("Database setup complete!")
    return faiss_index, contexts

# Load the saved index and contexts if they exist; otherwise build them
if os.path.exists(INDEX_FILE) and os.path.exists(DOCS_FILE):
    with open(INDEX_FILE, "rb") as f:
        faiss_index = faiss.deserialize_index(pickle.load(f))
    with open(DOCS_FILE, "rb") as f:
        all_contexts = pickle.load(f)
    st.info("Loaded existing database.")
else:
    faiss_index, all_contexts = setup_database()

# Sample questions (the first one is pre-filled in the input box)
sample_questions = [
    "What is the purpose of the RAG dataset?",
    "How does Falcon RefinedWeb contribute to this dataset?",
    "What are the benefits of using retrieval-augmented generation?",
    "Explain the structure of the RAG-1200 dataset.",
]

st.subheader("Ask a question based on the dataset:")
question = st.text_input("Enter your question:", value=sample_questions[0])

if st.button("Ask"):
    if not question.strip():
        st.warning("Please enter a question.")
    else:
        with st.spinner("Retrieving and generating answer..."):
            # Embed the user query and retrieve the single closest context.
            embedder = load_embedder()
            query_embedding = embedder.encode([question])
            # D holds distances, I the row indices of the nearest contexts.
            D, I = faiss_index.search(query_embedding, k=1)

            # Get closest context
            context = all_contexts[I[0][0]]
            prompt = f"Context: {context}\n\nQuestion: {question}\n\nAnswer:"

            # Call Groq model
            response = client.chat.completions.create(
                messages=[{"role": "user", "content": prompt}],
                model="llama3-70b-8192"
            )

            answer = response.choices[0].message.content
            st.success("Answer:")
            st.markdown(answer)

            with st.expander("🔍 Retrieved Context"):
                st.markdown(context)