import os
import tempfile

import streamlit as st
import fitz  # PyMuPDF
from typing import List, Dict, Any, Optional

from langchain_community.llms import HuggingFaceEndpoint
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.vectorstores import Chroma
from langchain.chains import ConversationalRetrievalChain
from langchain.memory import ConversationBufferMemory
from langchain.prompts import PromptTemplate

st.set_page_config(
    page_title="PDF Q&A Assistant",
    page_icon="📄",
    layout="wide"
)

# Initialize session state so values survive Streamlit reruns.
if "chat_history" not in st.session_state:
    st.session_state.chat_history = []
if "conversation_chain" not in st.session_state:
    st.session_state.conversation_chain = None
if "document_processed" not in st.session_state:
    st.session_state.document_processed = False
if "file_names" not in st.session_state:
    st.session_state.file_names = []


class PDFQAAssistant:
    def __init__(self,
                 hf_token: Optional[str] = None,
                 model_name: str = "google/flan-t5-base",
                 embedding_model_name: str = "sentence-transformers/all-MiniLM-L6-v2"):
        """
        Initialize the PDF Q&A Assistant with Hugging Face models.

        Args:
            hf_token: Hugging Face API token
            model_name: HF model to use for Q&A
            embedding_model_name: HF model to use for embeddings
        """
        self.model_name = model_name
        self.embedding_model_name = embedding_model_name
        self.hf_token = hf_token
        self.conversation_chain = None  # set once a document has been processed

        self.persist_directory = os.path.join(tempfile.gettempdir(), "pdf_qa_vectorstore")
        os.makedirs(self.persist_directory, exist_ok=True)

        self.llm = HuggingFaceEndpoint(
            repo_id=model_name,
            huggingfacehub_api_token=hf_token,
            max_new_tokens=512,  # HuggingFaceEndpoint takes max_new_tokens, not max_length
            temperature=0.5
        )

        self.embeddings = HuggingFaceEmbeddings(
            model_name=embedding_model_name,
            model_kwargs={'device': 'cpu'}
        )

        self.text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=800,
            chunk_overlap=150,
            length_function=len
        )

        self.vectorstore = None
        # output_key is required here: the chain returns both "answer" and
        # "source_documents", and the memory must know which one to store.
        self.memory = ConversationBufferMemory(
            memory_key="chat_history",
            output_key="answer",
            return_messages=True
        )
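
    # Minimal usage sketch (illustrative; process_pdf and ask call st.status
    # and st.error, so this assumes a running Streamlit session):
    #
    #     assistant = PDFQAAssistant(hf_token=os.environ.get("HF_TOKEN"))
    #     assistant.process_pdf(uploaded_file, uploaded_file.name)
    #     result = assistant.ask("What is this document about?")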

    def extract_text_from_pdf(self, pdf_file) -> str:
        """
        Extract text from a PDF file using PyMuPDF.

        Args:
            pdf_file: Uploaded PDF file

        Returns:
            Extracted text as a string
        """
        try:
            with tempfile.NamedTemporaryFile(delete=False, suffix='.pdf') as tmp_file:
                tmp_file.write(pdf_file.getvalue())
                tmp_path = tmp_file.name

            try:
                doc = fitz.open(tmp_path)
                text = "".join(page.get_text() for page in doc)
                doc.close()
            finally:
                # Remove the temp file even if PyMuPDF fails to open it.
                os.unlink(tmp_path)

            return text

        except Exception as e:
            st.error(f"Error extracting text from PDF: {e}")
            raise
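
    # This method flattens the PDF into one string, so the prompt's request to
    # cite page numbers cannot actually be honored. A hedged sketch of a
    # page-aware variant (same PyMuPDF API; opening from bytes skips the
    # temp file entirely):
    #
    #     def extract_pages(self, pdf_file) -> List[Dict[str, Any]]:
    #         with fitz.open(stream=pdf_file.getvalue(), filetype="pdf") as doc:
    #             return [{"text": page.get_text(), "page": i + 1}
    #                     for i, page in enumerate(doc)]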

    def process_pdf(self, pdf_file, document_name: str) -> None:
        """
        Process a PDF file and prepare it for question answering.

        Args:
            pdf_file: Uploaded PDF file
            document_name: Name to identify the document
        """
        with st.status("Extracting text from PDF..."):
            text = self.extract_text_from_pdf(pdf_file)
            st.write(f"Extracted {len(text)} characters")

        with st.status("Splitting document into chunks..."):
            chunks = self.text_splitter.split_text(text)
            st.write(f"Document split into {len(chunks)} chunks")

        with st.status("Creating vector embeddings..."):
            # Tag each chunk with its source file so answers can cite it later.
            metadatas = [{"source": document_name, "chunk": i} for i in range(len(chunks))]

            if self.vectorstore is None:
                self.vectorstore = Chroma.from_texts(
                    texts=chunks,
                    embedding=self.embeddings,
                    metadatas=metadatas,
                    persist_directory=self.persist_directory
                )
            else:
                self.vectorstore.add_texts(texts=chunks, metadatas=metadatas)

            # Recent Chroma integrations persist automatically; the guard keeps
            # this compatible with older versions that still expose persist().
            if hasattr(self.vectorstore, 'persist'):
                self.vectorstore.persist()
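
        # To sanity-check retrieval before wiring up the chain, the store can
        # be queried directly (similarity_search is Chroma's standard search
        # method; the query string here is illustrative):
        #
        #     docs = self.vectorstore.similarity_search("main findings", k=4)
        #     for d in docs:
        #         print(d.metadata, d.page_content[:80])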

        with st.status("Setting up Q&A system..."):
            retriever = self.vectorstore.as_retriever(
                search_kwargs={"k": 4}
            )

            qa_prompt = PromptTemplate(
                input_variables=["context", "question", "chat_history"],
                template="""
You are an AI assistant specializing in answering questions about documents.
Use the following pieces of context to answer the question at the end.
If you don't know the answer, just say you don't know. Don't try to make up an answer.
Always cite the specific source or page number when possible.

Context:
{context}

Chat History:
{chat_history}

Question:
{question}

Answer:
"""
            )

            self.conversation_chain = ConversationalRetrievalChain.from_llm(
                llm=self.llm,
                retriever=retriever,
                memory=self.memory,
                combine_docs_chain_kwargs={"prompt": qa_prompt},
                return_source_documents=True
            )

            # Keep the chain in session state so it survives Streamlit reruns.
            st.session_state.conversation_chain = self.conversation_chain

        st.success(f"Successfully processed {document_name}")
        st.session_state.document_processed = True
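
    # A previously persisted store can be reloaded instead of re-embedding
    # (a sketch using the Chroma constructor from langchain_community, assuming
    # persist_directory was populated by an earlier run):
    #
    #     self.vectorstore = Chroma(
    #         persist_directory=self.persist_directory,
    #         embedding_function=self.embeddings,
    #     )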

    def ask(self, question: str) -> Dict[str, Any]:
        """
        Ask a question about the loaded documents.

        Args:
            question: The question to ask

        Returns:
            Dictionary with the answer and source documents
        """
        if self.conversation_chain is None:
            return {"answer": "Please load a document first before asking questions.", "sources": []}

        try:
            # invoke() is the current Chain entry point; calling the chain
            # object directly is deprecated in recent LangChain releases.
            result = self.conversation_chain.invoke({"question": question})

            # Deduplicate by source file so each document is cited once.
            sources = []
            if "source_documents" in result:
                for doc in result["source_documents"]:
                    source = doc.metadata.get("source", "Unknown")
                    chunk = doc.metadata.get("chunk", "Unknown")
                    if source not in [s["source"] for s in sources]:
                        sources.append({"source": source, "chunk": chunk})

            return {
                "answer": result["answer"],
                "sources": sources
            }

        except Exception as e:
            st.error(f"Error processing question: {e}")
            return {"answer": f"Error processing your question: {e}", "sources": []}
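
    # Illustrative return shape:
    #     {"answer": "...", "sources": [{"source": "report.pdf", "chunk": 3}]}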

    def clear_memory(self) -> None:
        """Clear the conversation memory."""
        self.memory.clear()


def get_document_summary(assistant, document_name):
    """Display a question-driven summary of the loaded document."""
    st.subheader(f"Document Summary: {document_name}")

    with st.status("Generating document summary..."):
        questions = [
            "What is the main topic of this document?",
            "What are the key points from this document?",
            "Could you provide a summary of this document in 3-5 bullet points?"
        ]

        for question in questions:
            result = assistant.ask(question)
            st.write(f"**{question}**")
            st.write(result["answer"])
            st.divider()
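
# The assistant is rebuilt on every "Process Documents" click. A hedged
# alternative is to cache construction across reruns with st.cache_resource
# (Streamlit's documented cache for non-data resources such as models):
#
#     @st.cache_resource
#     def load_assistant(token: str, model: str, embed: str) -> PDFQAAssistant:
#         return PDFQAAssistant(hf_token=token, model_name=model,
#                               embedding_model_name=embed)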


def main():
    st.title("📄 AI-Powered PDF Reader & Q&A Assistant")

    with st.sidebar:
        st.header("Settings")

        # Resolve the HF token: app secrets first, then the environment,
        # then manual entry below.
        if "HF_TOKEN" in st.secrets:
            hf_token = st.secrets["HF_TOKEN"]
            token_source = "Using HF_TOKEN from app secrets"
        elif os.environ.get("HF_TOKEN"):
            hf_token = os.environ.get("HF_TOKEN")
            token_source = "Using HF_TOKEN from environment variables"
        else:
            hf_token = None
            token_source = "No HF_TOKEN found"

        st.info(token_source)

        use_manual_token = st.checkbox("Enter token manually", value=not hf_token)

        if use_manual_token:
            hf_token = st.text_input("Enter Hugging Face API Token:", type="password")
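
        # For deployment, the token can live in .streamlit/secrets.toml using
        # Streamlit's standard secrets format, e.g.:
        #
        #     HF_TOKEN = "hf_..."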

        st.subheader("Model Settings")
        # Only generation-capable models can answer questions here; an
        # encoder-only model such as distilbert-base-uncased would fail.
        model_name = st.selectbox(
            "Select LLM model:",
            [
                "google/flan-t5-base",
                "google/flan-t5-small",
                "facebook/bart-large-cnn"
            ],
            index=0
        )

        embedding_model = st.selectbox(
            "Select Embedding model:",
            [
                "sentence-transformers/all-MiniLM-L6-v2",
                "sentence-transformers/paraphrase-MiniLM-L3-v2"
            ],
            index=0
        )

        st.subheader("Upload Documents")
        uploaded_files = st.file_uploader("Upload PDF documents",
                                          type="pdf",
                                          accept_multiple_files=True)

        if uploaded_files:
            process_btn = st.button("Process Documents")
            if process_btn:
                if not hf_token:
                    st.error("Please provide a valid Hugging Face API token.")
                else:
                    try:
                        assistant = PDFQAAssistant(
                            hf_token=hf_token,
                            model_name=model_name,
                            embedding_model_name=embedding_model
                        )

                        # Skip files already processed in this session.
                        for pdf_file in uploaded_files:
                            file_name = pdf_file.name
                            if file_name not in st.session_state.file_names:
                                st.session_state.file_names.append(file_name)
                                assistant.process_pdf(pdf_file, file_name)

                        st.session_state.assistant = assistant
                    except Exception as e:
                        st.error(f"Error initializing assistant: {e}")
                        st.error("Try selecting a different model or check your token permissions.")

        if st.session_state.get("document_processed", False):
            st.subheader("Document Management")

            if st.button("Clear Chat History"):
                if "assistant" in st.session_state:
                    st.session_state.assistant.clear_memory()
                st.session_state.chat_history = []
                st.success("Chat history cleared!")

            if st.button("Generate Document Summary"):
                if "assistant" in st.session_state and len(st.session_state.file_names) > 0:
                    get_document_summary(st.session_state.assistant,
                                         st.session_state.file_names[0])

    if not st.session_state.get("document_processed", False):
        st.info("📄 Please upload and process a PDF document to get started.")

        st.header("How It Works")
        col1, col2, col3 = st.columns(3)

        with col1:
            st.subheader("1. Upload PDF")
            st.markdown("Upload any PDF document you want to query.")

        with col2:
            st.subheader("2. Process Document")
            st.markdown("The AI will extract text and create searchable embeddings.")

        with col3:
            st.subheader("3. Ask Questions")
            st.markdown("Ask any question about your document and get accurate answers.")
    else:
        st.header("Ask Questions About Your Documents")

        st.caption(f"Processed Files: {', '.join(st.session_state.file_names)}")

        # Replay the conversation so far.
        for message in st.session_state.chat_history:
            if message["role"] == "user":
                st.chat_message("user").write(message["content"])
            else:
                st.chat_message("assistant").write(message["content"])
                if message.get("sources"):
                    with st.expander("View Sources"):
                        for source in message["sources"]:
                            st.write(f"- {source['source']} (chunk {source['chunk']})")

        if question := st.chat_input("Ask a question about your documents..."):
            st.session_state.chat_history.append({
                "role": "user",
                "content": question
            })

            st.chat_message("user").write(question)

            with st.chat_message("assistant"):
                with st.spinner("Thinking..."):
                    try:
                        result = st.session_state.assistant.ask(question)

                        st.write(result["answer"])

                        if result.get("sources"):
                            with st.expander("View Sources"):
                                for source in result["sources"]:
                                    st.write(f"- {source['source']} (chunk {source['chunk']})")

                        st.session_state.chat_history.append({
                            "role": "assistant",
                            "content": result["answer"],
                            "sources": result.get("sources", [])
                        })
                    except Exception as e:
                        st.error(f"Error getting response: {e}")
                        st.error("Please try a different question or model.")


if __name__ == "__main__":
    main()