import os

import fitz  # PyMuPDF, for PDF text extraction
import faiss
import numpy as np
import streamlit as st
from dotenv import load_dotenv
from groq import Groq
from sentence_transformers import SentenceTransformer
# Function to extract text from a PDF
def extract_text_from_pdf(file):
    try:
        with fitz.open(stream=file.read(), filetype="pdf") as doc:
            return "".join(page.get_text() for page in doc)
    except Exception as e:
        st.error(f"Error extracting text: {e}")
        return ""
# Function to chunk the text into roughly sentence-aligned pieces
def chunk_text(text, chunk_size=500):
    sentences = text.split(". ")
    chunks = []
    current_chunk = ""
    for sentence in sentences:
        if len(current_chunk) + len(sentence) <= chunk_size:
            current_chunk += sentence + ". "
        else:
            chunks.append(current_chunk.strip())
            current_chunk = sentence + ". "
    if current_chunk:
        chunks.append(current_chunk.strip())
    return chunks
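# Illustrative example of the greedy packing above: with chunk_size=20,
# chunk_text("One. Two words here. Three") returns
# ["One. Two words here.", "Three."] -- sentences are appended until adding
# the next one would exceed the character budget.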
# Load the embedding model
embedding_model = SentenceTransformer('all-MiniLM-L6-v2')

# Function to generate embeddings
def generate_embeddings(chunks):
    return embedding_model.encode(chunks)
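# For 'all-MiniLM-L6-v2', encode() returns a float32 NumPy array of shape
# (len(chunks), 384), i.e. one 384-dimensional vector per chunk.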
# Function to store embeddings in FAISS
def store_embeddings_in_faiss(embeddings):
    try:
        embeddings = np.asarray(embeddings, dtype="float32")  # FAISS expects float32
        index = faiss.IndexFlatL2(embeddings.shape[1])
        index.add(embeddings)
        return index
    except Exception as e:
        st.error(f"Error with FAISS: {e}")
        return None
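# IndexFlatL2 does exact (brute-force) Euclidean search. A minimal sketch of a
# cosine-similarity variant, should that metric be preferred (hypothetical
# helper, not used by the app as written):
#
#     def store_embeddings_cosine(embeddings):
#         embeddings = np.asarray(embeddings, dtype="float32")
#         faiss.normalize_L2(embeddings)                  # in-place unit normalization
#         index = faiss.IndexFlatIP(embeddings.shape[1])  # inner product == cosine on unit vectors
#         index.add(embeddings)
#         return index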
# Function to retrieve the chunks most similar to a query
def retrieve_similar_chunks(query, index, chunks, model, k=5):
    try:
        query_embedding = model.encode([query]).astype("float32")
        distances, indices = index.search(query_embedding, min(k, len(chunks)))
        return [chunks[i] for i in indices[0] if i != -1]
    except Exception as e:
        st.error(f"Error retrieving similar chunks: {e}")
        return []
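# Note: for IndexFlatL2 the returned distances are squared L2 values (smaller
# means more similar); they could be surfaced alongside the chunks if a
# relevance score is wanted.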
# Load environment variables and initialize the Groq client
load_dotenv()
groq_api_key = os.getenv("GROQ_API_KEY")
if not groq_api_key:
    st.error("The GROQ_API_KEY is not set.")
    st.stop()

groq_client = Groq(api_key=groq_api_key)
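# The key is expected in the environment or in a local .env file (picked up by
# load_dotenv above), e.g. a .env line of the form: GROQ_API_KEY=<your key>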
# Function to query the LLM via the Groq chat-completions API
def query_llm(prompt, model="llama3-8b-8192"):
    try:
        response = groq_client.chat.completions.create(
            messages=[
                {"role": "system", "content": "You are a helpful assistant."},
                {"role": "user", "content": prompt},
            ],
            model=model,
        )
        return response.choices[0].message.content
    except Exception as e:
        st.error(f"Error querying LLM: {e}")
        return "Error in LLM response."
# Streamlit application
def main():
    st.title("RAG Application with Groq API")

    # File upload
    uploaded_file = st.file_uploader("Upload a PDF", type="pdf")
    if uploaded_file:
        # Extract text
        pdf_text = extract_text_from_pdf(uploaded_file)
        if not pdf_text:
            return
        st.write("PDF Text Extracted:")
        st.write(pdf_text[:500])  # Show a preview

        # Chunk the text
        chunks = chunk_text(pdf_text)
        st.write(f"Text split into {len(chunks)} chunks.")

        # Generate embeddings and build the FAISS index
        embeddings = np.array(generate_embeddings(chunks))
        index = store_embeddings_in_faiss(embeddings)
        if index is None:
            return

        # Query handling
        query = st.text_input("Enter your query:")
        if query:
            similar_chunks = retrieve_similar_chunks(query, index, chunks, embedding_model)
            st.write("Relevant Chunks:")
            for i, chunk in enumerate(similar_chunks, start=1):
                st.write(f"Chunk {i}: {chunk}")

            # Query the LLM with the top retrieved chunks as context
            combined_context = " ".join(similar_chunks[:3])
            llm_prompt = f"Context: {combined_context}\n\nQuery: {query}"
            llm_response = query_llm(llm_prompt)
            st.write("LLM Response:")
            st.write(llm_response)
if __name__ == "__main__":
    main()
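# Performance note: Streamlit re-runs this script on every interaction, so the
# model load and PDF embedding repeat on each rerun. A minimal caching sketch,
# assuming Streamlit >= 1.18 (which provides these decorators):
#
#     @st.cache_resource
#     def load_embedding_model():
#         return SentenceTransformer('all-MiniLM-L6-v2')
#
#     @st.cache_data
#     def embed_chunks(chunks):
#         return load_embedding_model().encode(chunks)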