# Hugging Face Spaces status: Sleeping
import streamlit as st | |
import os | |
import tempfile | |
import pickle | |
import faiss | |
import numpy as np | |
from helper import extract_text_from_pdf, chunk_text, embedding_function, embedding_model, query_llm_with_context | |
import logging | |
# Configure logging: INFO level on the root logger, plus a logger named after
# this module for the app's own messages.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Set page configuration (browser tab title/icon and wide layout).
st.set_page_config(
    page_title="PDF RAG System",
    page_icon="📄",
    layout="wide",
)

# Title and short description shown at the top of the page.
st.title("📄 PDF RAG System")
st.markdown("""
This application allows you to upload a PDF file, ask questions about its content, and get AI-generated answers based on the document.
""")
# File upload section
st.header("1. Upload PDF")
uploaded_file = st.file_uploader("Choose a PDF file", type="pdf", key="pdf_uploader")

# Seed the per-session state. Keys that already exist are left untouched, so
# only the first run of a session writes these defaults.
_SESSION_DEFAULTS = {
    'pdf_processed': False,  # True once the current upload has been indexed
    'index': None,           # FAISS index built from the document chunks
    'chunks': None,          # text chunks, position-aligned with the index rows
    'pdf_path': None,        # temporary on-disk copy of the uploaded PDF
}
for _state_key, _state_default in _SESSION_DEFAULTS.items():
    if _state_key not in st.session_state:
        st.session_state[_state_key] = _state_default
# Process the uploaded PDF. Skipped once pdf_processed is set, so an upload is
# indexed a single time until the reset button clears the flag.
if uploaded_file is not None and not st.session_state.pdf_processed:
    with st.spinner("Processing PDF..."):
        try:
            # Write the upload to a temporary .pdf file because
            # extract_text_from_pdf is given a path. delete=False keeps the
            # file after the handle closes; the reset button removes it later.
            with tempfile.NamedTemporaryFile(delete=False, suffix='.pdf') as tmp_file:
                tmp_file.write(uploaded_file.getvalue())
                st.session_state.pdf_path = tmp_file.name

            # Read the file only after the handle is closed so every byte has
            # been flushed to disk.
            pdf_text = extract_text_from_pdf(st.session_state.pdf_path)
            chunks = chunk_text(pdf_text, chunk_size=1000, chunk_overlap=100)

            if len(chunks) == 0:
                # Nothing to embed: building an index would fail on the
                # missing embedding dimension, so report it instead.
                st.error("No text could be extracted from this PDF, so there is nothing to index.")
            else:
                # FAISS needs a contiguous float32 matrix; convert
                # unconditionally so float64 arrays are handled as well.
                embeddings = np.ascontiguousarray(embedding_function(chunks), dtype='float32')
                if embeddings.ndim != 2 or embeddings.shape[0] != len(chunks):
                    raise ValueError(
                        f"Expected one embedding vector per chunk, got array of shape {embeddings.shape}"
                    )

                # Exact L2 index whose row i corresponds to chunks[i].
                index = faiss.IndexFlatL2(embeddings.shape[1])
                index.add(embeddings)

                # Save the index and chunks to disk.
                # NOTE(review): these fixed paths are shared by every session
                # and are not read back anywhere in the visible code — confirm
                # whether they are still needed.
                faiss.write_index(index, "./faiss_index")
                with open("./document_chunks.pkl", 'wb') as f:
                    pickle.dump(chunks, f)

                # Publish to session state only after everything succeeded so
                # a failed run never leaves a half-initialised state behind.
                st.session_state.chunks = chunks
                st.session_state.index = index
                st.session_state.pdf_processed = True
                st.success(f"PDF processed successfully! {len(chunks)} chunks created.")
        except Exception as e:
            # Top-level UI boundary: show the problem to the user and keep the
            # full traceback in the log.
            logger.exception("Error during PDF processing")
            st.error(f"Failed to process PDF: {e}")
# Query section
st.header("2. Ask a Question")
query = st.text_input("Enter your question about the PDF content:", key="query_input")

# On click: embed the question, retrieve the nearest chunks from the FAISS
# index, ask the LLM with those chunks as context, and show the result.
if st.button("Get Answer", key="get_answer_button"):
    if not st.session_state.pdf_processed:
        st.warning("Please upload a PDF before asking a question.")
    elif not query.strip():
        st.warning("Please enter a question.")
    else:
        with st.spinner("Retrieving relevant information and generating answer..."):
            try:
                # Generate embedding for the query (float32, as FAISS expects).
                query_embedding = embedding_model.encode([query], convert_to_numpy=True).astype('float32')

                # Never ask for more neighbours than the index holds: FAISS
                # pads missing results with label -1, which would otherwise
                # index chunks[-1] and silently return the last chunk.
                index = st.session_state.index
                n_results = max(1, min(5, index.ntotal))
                distances, indices = index.search(query_embedding, n_results)

                # Keep only real hits as (chunk position, L2 distance) pairs.
                hits = [
                    (int(idx), float(dist))
                    for idx, dist in zip(indices[0], distances[0])
                    if idx >= 0
                ]

                if not hits:
                    st.warning("No relevant passages were found in the document.")
                else:
                    documents = [st.session_state.chunks[idx] for idx, _ in hits]

                    # Convert distances to similarity scores in [0, 1], where
                    # 1 is most similar (L2 distance: lower is better). When
                    # every hit is at distance 0 they are all exact matches;
                    # score them 1.0 rather than dividing by zero.
                    max_distance = max(dist for _, dist in hits)
                    if max_distance > 0:
                        similarity_scores = [1 - (dist / max_distance) for _, dist in hits]
                    else:
                        similarity_scores = [1.0] * len(hits)

                    # Create context from retrieved documents
                    context = (documents, similarity_scores)

                    # Query the LLM with context
                    answer = query_llm_with_context(query, context, top_n=3)

                    # Display the answer
                    st.header("3. Answer")
                    st.write(answer)

                    # Display the retrieved documents, truncated to 500 characters.
                    with st.expander("View Retrieved Documents", expanded=False):
                        for i, (doc, score) in enumerate(zip(documents, similarity_scores)):
                            st.markdown(f"**Document {i+1}** (Relevance: {score:.4f})")
                            preview = doc[:500] + "..." if len(doc) > 500 else doc
                            st.text(preview)
                            st.markdown("---")
            except Exception as e:
                # Top-level UI boundary: surface the error and log the traceback.
                st.error(f"An error occurred: {str(e)}")
                logger.exception("Error during query processing")
# Reset button: delete the temporary PDF, clear the session state and rerun.
# NOTE(review): the uploader widget (key "pdf_uploader") presumably still holds
# the previous file after the rerun, which would re-index the same PDF
# immediately — confirm, and consider rotating the uploader key on reset.
if st.button("Reset and Upload New PDF", key="reset_button"):
    # Clean up the temporary file. Try the removal directly instead of
    # checking for existence first; a file that is already gone is fine.
    pdf_path = st.session_state.pdf_path
    if pdf_path:
        try:
            os.unlink(pdf_path)
        except FileNotFoundError:
            pass
        except OSError:
            # Best-effort cleanup: a leftover temp file must not block the reset.
            logger.warning("Could not remove temporary file %s", pdf_path, exc_info=True)

    # Reset session state
    st.session_state.pdf_processed = False
    st.session_state.index = None
    st.session_state.chunks = None
    st.session_state.pdf_path = None

    # Reload the page. st.experimental_rerun has been removed from recent
    # Streamlit releases in favour of st.rerun; fall back to the old name only
    # when the new one is unavailable.
    rerun = getattr(st, "rerun", None) or st.experimental_rerun
    rerun()

# Footer
st.markdown("---")
st.markdown("Built with Streamlit, FAISS, and Hugging Face API")