Seif-aber committed
Commit · edac567
1 Parent(s): 46d0d8e
implemented pdf chat assistant with gemini and RAG
Files changed:
- Dockerfile +0 -20
- README.md +60 -13
- config/__init__.py +0 -0
- config/settings.py +58 -0
- requirements.txt +7 -3
- src/app.py +167 -0
- src/components/__init__.py +1 -0
- src/components/chat_interface.py +120 -0
- src/components/file_uploader.py +27 -0
- src/components/pdf_viewer.py +57 -0
- src/models/__init__.py +0 -0
- src/models/chat_models.py +21 -0
- src/services/__init__.py +0 -0
- src/services/embedding_service.py +140 -0
- src/services/gemini_client.py +77 -0
- src/services/pdf_processor.py +94 -0
- src/services/rag_service.py +105 -0
- src/streamlit_app.py +0 -40
- src/utils/__init__.py +0 -0
- src/utils/chunking.py +37 -0
- src/utils/vector_store.py +101 -0
Dockerfile
DELETED
@@ -1,20 +0,0 @@
-FROM python:3.13.5-slim
-
-WORKDIR /app
-
-RUN apt-get update && apt-get install -y \
-    build-essential \
-    curl \
-    git \
-    && rm -rf /var/lib/apt/lists/*
-
-COPY requirements.txt ./
-COPY src/ ./src/
-
-RUN pip3 install -r requirements.txt
-
-EXPOSE 8501
-
-HEALTHCHECK CMD curl --fail http://localhost:8501/_stcore/health
-
-ENTRYPOINT ["streamlit", "run", "src/streamlit_app.py", "--server.port=8501", "--server.address=0.0.0.0"]
README.md
CHANGED
|
@@ -1,19 +1,66 @@
|
|
| 1 |
---
|
| 2 |
-
title:
|
| 3 |
-
emoji:
|
| 4 |
-
colorFrom:
|
| 5 |
-
colorTo:
|
| 6 |
-
sdk:
|
| 7 |
-
|
| 8 |
-
tags:
|
| 9 |
-
- streamlit
|
| 10 |
pinned: false
|
| 11 |
-
short_description: A Streamlit web app that lets you chat with your PDF
|
| 12 |
---
|
| 13 |
|
| 14 |
-
#
|
| 15 |
|
| 16 |
-
|
|
|
|
| 17 |
|
| 18 |
-
|
| 19 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
---
|
| 2 |
+
title: PDF Chat Assistant
|
| 3 |
+
emoji: 📄
|
| 4 |
+
colorFrom: blue
|
| 5 |
+
colorTo: indigo
|
| 6 |
+
sdk: streamlit
|
| 7 |
+
app_file: src/app.py
|
|
|
|
|
|
|
| 8 |
pinned: false
|
|
|
|
| 9 |
---
|
| 10 |
|
| 11 |
+
# PDF Chat Assistant
|
| 12 |
|
| 13 |
+
Interact with your PDF using Retrieval-Augmented Generation (RAG) + Gemini.
|
| 14 |
+
Upload a PDF, it is chunked, embedded, and you can ask questions with contextual, streamed answers.
|
| 15 |
|
| 16 |
+
## Features
|
| 17 |
+
- PDF upload & inline preview
|
| 18 |
+
- Automatic text extraction, cleaning, chunking
|
| 19 |
+
- Embedding storage (pickle vector store)
|
| 20 |
+
- Similarity-based context retrieval
|
| 21 |
+
- Gemini response generation (streaming)
|
| 22 |
+
- Scrollable chat UI
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
## Conda Setup
|
| 26 |
+
|
| 27 |
+
```bash
|
| 28 |
+
git clone https://github.com/Seif-aber/pdf_chat_assistant
|
| 29 |
+
cd pdf-chat-assistant
|
| 30 |
+
|
| 31 |
+
# Create environment
|
| 32 |
+
conda create -n pdfchat python=3.12 -y
|
| 33 |
+
conda activate pdfchat
|
| 34 |
+
|
| 35 |
+
# Install dependencies
|
| 36 |
+
pip install -r requirements.txt
|
| 37 |
+
```
|
| 38 |
+
|
| 39 |
+
## Environment Variables
|
| 40 |
+
|
| 41 |
+
Create a `.env` file in project root:
|
| 42 |
+
|
| 43 |
+
```
|
| 44 |
+
GEMINI_API_KEY=your_key_here
|
| 45 |
+
GEMINI_MODEL=gemini-2.5-flash
|
| 46 |
+
EMBEDDING_MODEL=models/embedding-001
|
| 47 |
+
STREAMLIT_PORT=8501
|
| 48 |
+
MAX_PDF_SIZE_MB=10
|
| 49 |
+
CHUNK_SIZE=1000
|
| 50 |
+
CHUNK_OVERLAP=200
|
| 51 |
+
UPLOAD_FOLDER=data/uploads
|
| 52 |
+
EMBEDDINGS_FOLDER=data/embeddings
|
| 53 |
+
```
|
| 54 |
+
|
| 55 |
+
Then:
|
| 56 |
+
|
| 57 |
+
```bash
|
| 58 |
+
streamlit run src/app.py --server.port $STREAMLIT_PORT
|
| 59 |
+
```
|
| 60 |
+
|
| 61 |
+
## How It Works
|
| 62 |
+
1. Upload PDF → saved to a temp file.
|
| 63 |
+
2. Text extracted (PyPDF2 / pypdf fallback) and chunked with overlap.
|
| 64 |
+
3. Each chunk embedded via Gemini Embeddings API.
|
| 65 |
+
4. On question: create query embedding → cosine similarity → top chunks form context.
|
| 66 |
+
5. Gemini model generates constrained to context.
|
config/__init__.py
ADDED
File without changes
config/settings.py
ADDED
@@ -0,0 +1,58 @@
+"""Central application configuration"""
+import os
+from pathlib import Path
+from dotenv import load_dotenv
+
+_PROJECT_ROOT = Path(__file__).resolve().parent.parent
+_DOTENV_PATH = _PROJECT_ROOT / ".env"
+
+load_dotenv(dotenv_path=_DOTENV_PATH, override=False)
+
+class Config:
+    """Holds application configuration values loaded from .env only (no silent fallbacks)."""
+
+    GEMINI_API_KEY: str | None = os.getenv("GEMINI_API_KEY")
+    GEMINI_MODEL: str | None = os.getenv("GEMINI_MODEL")
+    EMBEDDING_MODEL: str | None = os.getenv("EMBEDDING_MODEL")
+
+    STREAMLIT_PORT: str | None = os.getenv("STREAMLIT_PORT")
+    MAX_PDF_SIZE_MB: str | None = os.getenv("MAX_PDF_SIZE_MB")
+    CHUNK_SIZE: str | None = os.getenv("CHUNK_SIZE")
+    CHUNK_OVERLAP: str | None = os.getenv("CHUNK_OVERLAP")
+
+    UPLOAD_FOLDER: str | None = os.getenv("UPLOAD_FOLDER")
+    EMBEDDINGS_FOLDER: str | None = os.getenv("EMBEDDINGS_FOLDER")
+
+    EMBEDDING_STORAGE_PATH: str | None = None
+
+    @classmethod
+    def validate(cls) -> None:
+        """Validate required variables & finalize derived values."""
+        required = {
+            "GEMINI_API_KEY": cls.GEMINI_API_KEY,
+            "GEMINI_MODEL": cls.GEMINI_MODEL,
+            "EMBEDDING_MODEL": cls.EMBEDDING_MODEL,
+            "STREAMLIT_PORT": cls.STREAMLIT_PORT,
+            "MAX_PDF_SIZE_MB": cls.MAX_PDF_SIZE_MB,
+            "CHUNK_SIZE": cls.CHUNK_SIZE,
+            "CHUNK_OVERLAP": cls.CHUNK_OVERLAP,
+            "UPLOAD_FOLDER": cls.UPLOAD_FOLDER,
+            "EMBEDDINGS_FOLDER": cls.EMBEDDINGS_FOLDER,
+        }
+        missing = [k for k, v in required.items() if not v]
+        if missing:
+            raise ValueError(f"Missing required environment variables in .env: {', '.join(missing)}")
+
+        cls.STREAMLIT_PORT = int(cls.STREAMLIT_PORT)
+        cls.MAX_PDF_SIZE_MB = int(cls.MAX_PDF_SIZE_MB)
+        cls.CHUNK_SIZE = int(cls.CHUNK_SIZE)
+        cls.CHUNK_OVERLAP = int(cls.CHUNK_OVERLAP)
+
+        os.makedirs(cls.UPLOAD_FOLDER, exist_ok=True)
+        os.makedirs(cls.EMBEDDINGS_FOLDER, exist_ok=True)
+
+        cls.EMBEDDING_STORAGE_PATH = os.path.join(cls.EMBEDDINGS_FOLDER, "pdf_embeddings.pkl")  # type: ignore
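A minimal usage sketch of the fail-fast config above (not part of the commit; assumes it runs from the project root):

```python
# validate() either raises with the list of missing .env keys, or coerces
# the numeric fields and derives the pickle storage path.
from config.settings import Config

try:
    Config.validate()
    print(type(Config.CHUNK_SIZE))         # <class 'int'> after coercion
    print(Config.EMBEDDING_STORAGE_PATH)   # e.g. data/embeddings/pdf_embeddings.pkl
except ValueError as e:
    print(e)  # "Missing required environment variables in .env: ..."
```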
requirements.txt
CHANGED
@@ -1,3 +1,7 @@
-
-
-
+streamlit
+PyPDF2
+pypdf
+numpy
+google-generativeai>=0.7.0
+python-dotenv
+pydantic
src/app.py
ADDED
@@ -0,0 +1,167 @@
+import streamlit as st
+import os
+import tempfile
+import hashlib
+from components.file_uploader import FileUploader
+from components.pdf_viewer import PdfViewer
+from components.chat_interface import ChatInterface
+from services.pdf_processor import PDFProcessor
+from services.embedding_service import EmbeddingService
+from services.gemini_client import GeminiClient
+from services.rag_service import RAGService
+
+def initialize_session_state():
+    defaults = {
+        "chat_history": [],
+        "pdf_processed": False,
+        "pdf_id": None,
+        "pdf_chunks": [],
+        "uploaded_file_path": None,
+        "current_file_name": None,
+        "current_file_hash": None,
+        "processing": False,
+        "streaming": False,
+        "chat_input": "",
+        "clear_chat_input": False,  # flag: clear the text input on the next rerun
+    }
+    for k, v in defaults.items():
+        if k not in st.session_state:
+            st.session_state[k] = v
+
+def reset_app_state(embedding_service: EmbeddingService) -> None:
+    old_path = st.session_state.get("uploaded_file_path")
+    if old_path and os.path.exists(old_path):
+        try: os.unlink(old_path)
+        except Exception: pass
+    try:
+        embedding_service.vector_store.clear_embeddings()
+    except Exception:
+        pass
+    st.session_state.chat_history = []
+    st.session_state.pdf_processed = False
+    st.session_state.pdf_id = None
+    st.session_state.pdf_chunks = []
+    st.session_state.uploaded_file_path = None
+    st.session_state.current_file_name = None
+    st.session_state.current_file_hash = None
+    st.session_state.processing = False
+    st.session_state.streaming = False
+
+def _file_hash(uploaded_file) -> str:
+    return hashlib.md5(uploaded_file.getvalue()).hexdigest()
+
+def auto_process_pdf(uploaded_file, tmp_file_path, embedding_service: EmbeddingService, force: bool = False):
+    if st.session_state.processing:
+        return
+    if st.session_state.pdf_processed and not force:
+        return
+    st.session_state.processing = True
+    status = st.empty()
+    try:
+        status.markdown("⏳ Processing PDF... 10%")
+        pdf_processor = PDFProcessor()
+        chunks = pdf_processor.process_pdf(tmp_file_path)
+        if not chunks:
+            status.error("Failed to extract text.")
+            return
+        pdf_id = uploaded_file.name.replace(".pdf", "").replace(" ", "_").replace(".", "_")
+        st.session_state.pdf_id = pdf_id
+        st.session_state.pdf_chunks = chunks
+        status.markdown("⏳ Processing PDF... 50%")
+        embedding_service.store_pdf_embeddings(pdf_id, chunks)
+        status.markdown("⏳ Processing PDF... 90%")
+        st.session_state.pdf_processed = True
+        status.success(f"✅ Processing complete (100%). {len(chunks)} chunks ready.")
+    except Exception as e:
+        status.error(f"❌ Error: {e}")
+    finally:
+        st.session_state.processing = False
+
+def main():
+    st.set_page_config(page_title="PDF Chat Assistant", page_icon="📄", layout="wide")
+    st.title("📄 PDF Chat Assistant")
+    initialize_session_state()
+
+    embedding_service = EmbeddingService()
+    gemini_client = GeminiClient()
+    rag_service = RAGService(embedding_service, gemini_client)
+
+    col1, col2 = st.columns([1, 1])
+
+    with col1:
+        st.header("📁 Upload & Preview PDF")
+        uploaded_file = FileUploader().upload_file()
+        if uploaded_file:
+            new_hash = _file_hash(uploaded_file)
+            if st.session_state.current_file_hash and st.session_state.current_file_hash != new_hash:
+                reset_app_state(embedding_service)
+            if st.session_state.current_file_hash != new_hash:
+                st.session_state.current_file_name = uploaded_file.name
+                st.session_state.current_file_hash = new_hash
+                st.session_state.pdf_processed = False
+            with tempfile.NamedTemporaryFile(delete=False, suffix=".pdf") as tmp:
+                tmp.write(uploaded_file.getvalue())
+                tmp_path = tmp.name
+            st.session_state.uploaded_file_path = tmp_path
+            PdfViewer().display_pdf(tmp_path)
+            auto_process_pdf(uploaded_file, tmp_path, embedding_service)
+            if st.session_state.pdf_processed and not st.session_state.processing:
+                if st.button("🔄 Reprocess PDF"):
+                    st.session_state.pdf_processed = False
+                    auto_process_pdf(uploaded_file, tmp_path, embedding_service, force=True)
+        else:
+            st.info("Upload a PDF to begin.")
+
+    with col2:
+        st.header("💬 Chat with your PDF")
+        if st.session_state.processing:
+            st.info("⏳ Processing... Please wait.")
+            return
+
+        chat_ui = ChatInterface()
+
+        if st.session_state.pdf_processed and st.session_state.pdf_id:
+            if st.session_state.clear_chat_input:
+                st.session_state.chat_input = ""
+                st.session_state.clear_chat_input = False
+
+            chat_ui.render(st.session_state.chat_history)
+            disabled = st.session_state.streaming
+            user_input = st.text_input(
+                "Ask a question:",
+                key="chat_input",
+                placeholder="Type your question...",
+                disabled=disabled,
+                label_visibility="collapsed"
+            )
+            send = st.button("Send", disabled=disabled or not user_input.strip(), use_container_width=True)
+
+            if send and user_input.strip():
+                query = user_input.strip()
+                st.session_state.chat_history.append({"role": "user", "content": query})
+                st.session_state.streaming = True
+
+                st.session_state.clear_chat_input = True
+                stream_iter = rag_service.stream_response(
+                    query,
+                    st.session_state.pdf_id,
+                    st.session_state.chat_history
+                )
+                assistant_text = chat_ui.stream_assistant(st.session_state.chat_history, stream_iter)
+                st.session_state.chat_history.append({"role": "assistant", "content": assistant_text})
+                st.session_state.streaming = False
+                st.rerun()
+
+            col_a, col_b = st.columns([1, 1])
+            with col_a:
+                if st.button("Clear Chat", disabled=st.session_state.streaming):
+                    st.session_state.chat_history = []
+                    st.session_state.clear_chat_input = True
+                    st.rerun()
+            with col_b:
+                pass
+        else:
+            st.info("Upload and wait for processing to chat.")
+
+if __name__ == "__main__":
+    main()
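A note on the `clear_chat_input` flag above: Streamlit raises an exception if a widget's session-state key is assigned after that widget has been instantiated in the same script run, so the input is blanked at the top of the *next* run instead. A standalone sketch of the same pattern (illustration only, not from the commit):

```python
# Deferred-clear pattern: set a flag when a message is sent, then blank the
# widget's state key at the start of the next run, before st.text_input exists.
import streamlit as st

if "clear_chat_input" not in st.session_state:
    st.session_state.clear_chat_input = False

if st.session_state.clear_chat_input:   # runs before the widget is created
    st.session_state.chat_input = ""
    st.session_state.clear_chat_input = False

text = st.text_input("Message", key="chat_input")
if st.button("Send") and text:
    st.session_state.clear_chat_input = True
    st.rerun()
```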
src/components/__init__.py
ADDED
@@ -0,0 +1 @@
+# This file is intentionally left blank.
src/components/chat_interface.py
ADDED
@@ -0,0 +1,120 @@
+"""Scrollable + streaming chat interface."""
+import streamlit as st
+from typing import List, Dict
+import html
+import time
+
+_CHAT_CSS = """
+<style>
+#chat-container {
+    height: 520px;
+    overflow-y: auto;
+    padding: 0.5rem 0.75rem 0.25rem 0.75rem;
+    border: 1px solid #e3e3e3;
+    border-radius: 10px;
+    background: #fafafa;
+    scroll-behavior: smooth;
+}
+.chat-msg { margin: 0 0 14px 0; max-width: 85%; }
+.chat-row-user { display:flex; justify-content:flex-end; }
+.chat-row-assistant { display:flex; justify-content:flex-start; }
+.bubble {
+    padding:10px 14px;
+    border-radius:14px;
+    line-height:1.35;
+    font-size:0.93rem;
+    box-shadow:0 1px 2px rgba(0,0,0,0.08);
+    word-wrap:break-word;
+    white-space:pre-wrap;
+}
+.bubble-user {
+    background:linear-gradient(135deg,#4b8df8,#2563eb);
+    color:#fff;
+    border-bottom-right-radius:4px;
+}
+.bubble-assistant {
+    background:#ffffff;
+    border:1px solid #ddd;
+    border-bottom-left-radius:4px;
+}
+.meta {
+    font-size:0.6rem;
+    opacity:0.55;
+    margin-top:4px;
+    text-align:right;
+    user-select:none;
+}
+</style>
+<script>
+function scrollChat(){
+    const el = window.parent.document.querySelector('#chat-container');
+    if(el){ el.scrollTop = el.scrollHeight; }
+}
+</script>
+"""
+
+class ChatInterface:
+    """Renders scrollable chat and supports streaming assistant output."""
+
+    def __init__(self):
+        self.chat_history = []
+
+    def render(self, chat_history: List[Dict]) -> None:
+        st.markdown(_CHAT_CSS, unsafe_allow_html=True)
+        if not chat_history:
+            st.info("No messages yet. Ask something about the PDF.")
+            return
+        st.markdown(self._history_to_html(chat_history), unsafe_allow_html=True)
+
+    def stream_assistant(self, chat_history: List[Dict], stream_iter) -> str:
+        """
+        Render existing messages then stream new assistant message.
+        Returns final assistant text.
+        """
+        st.markdown(_CHAT_CSS, unsafe_allow_html=True)
+        placeholder = st.empty()
+        assistant_text = ""
+        # Re-render on each chunk for smooth streaming
+        for chunk in stream_iter:
+            assistant_text += chunk
+            merged = chat_history + [{"role": "assistant", "content": assistant_text}]
+            placeholder.markdown(self._history_to_html(merged), unsafe_allow_html=True)
+            time.sleep(0.03)
+        return assistant_text
+
+    def input_box(self, key: str = "chat_input") -> str:
+        return st.text_input(
+            "Ask a question:",
+            key=key,
+            placeholder="Type your question and press Enter...",
+            label_visibility="collapsed",
+        )
+
+    def add_message(self, role: str, content: str):
+        """
+        Add a message to chat history
+        """
+        self.chat_history.append({
+            "role": role,
+            "content": content
+        })
+
+    def clear_chat(self):
+        """Clear the chat history"""
+        self.chat_history = []
+
+    def _history_to_html(self, history: List[Dict]) -> str:
+        rows = []
+        for m in history:
+            role = m.get("role", "user")
+            safe = html.escape(m.get("content", ""))
+            row_cls = "chat-row-user" if role == "user" else "chat-row-assistant"
+            bub_cls = "bubble bubble-user" if role == "user" else "bubble bubble-assistant"
+            label = "You" if role == "user" else "Assistant"
+            rows.append(
+                f'<div class="{row_cls}"><div class="chat-msg">'
+                f'<div class="{bub_cls}">{safe}</div>'
+                f'<div class="meta">{label}</div>'
+                f'</div></div>'
+            )
+        return f'<div id="chat-container">{"".join(rows)}</div><script>scrollChat();</script>'
src/components/file_uploader.py
ADDED
@@ -0,0 +1,27 @@
+"""Streamlit component: PDF file uploader."""
+import streamlit as st
+from typing import Optional
+
+UploadedFile = "UploadedFile"
+
+class FileUploader:
+    """Encapsulates upload widget usage."""
+
+    def __init__(self) -> None:
+        """Initialize with no uploaded file."""
+        self.uploaded_file: Optional[st.runtime.uploaded_file_manager.UploadedFile] = None
+
+    def upload_file(self) -> Optional[st.runtime.uploaded_file_manager.UploadedFile]:
+        """
+        Render uploader and return uploaded file.
+        """
+        self.uploaded_file = st.file_uploader("Choose a PDF file", type=["pdf"])
+        if self.uploaded_file:
+            return self.uploaded_file
+        return None
+
+    def get_file_content(self) -> Optional[bytes]:
+        """
+        Return raw bytes of uploaded file.
+        """
+        return self.uploaded_file.getvalue() if self.uploaded_file else None
src/components/pdf_viewer.py
ADDED
@@ -0,0 +1,57 @@
+"""Embed a PDF file in the Streamlit UI (base64 iframe fallback)."""
+import streamlit as st
+import base64
+import os
+from PyPDF2 import PdfReader
+
+class PdfViewer:
+    """Display a PDF document inside the app."""
+
+    def display_pdf(self, pdf_path: str) -> None:
+        """
+        Render only the PDF iframe (metrics removed).
+
+        Args:
+            pdf_path: Path to local PDF file.
+        """
+        try:
+            self._iframe(pdf_path)
+        except Exception as e:
+            st.error(f"PDF preview error: {e}")
+
+    def _iframe(self, pdf_path: str) -> None:
+        """
+        Create a base64 iframe embed.
+
+        Args:
+            pdf_path: Path to PDF.
+        """
+        try:
+            with open(pdf_path, "rb") as f:
+                data = f.read()
+            b64 = base64.b64encode(data).decode("utf-8")
+            html = f"""
+            <div style="width:100%; height:600px; border:1px solid #ddd; border-radius:4px; overflow:hidden;">
+                <iframe src="data:application/pdf;base64,{b64}" width="100%" height="100%" style="border:none;"></iframe>
+            </div>
+            """
+            st.markdown(html, unsafe_allow_html=True)
+            st.download_button("📥 Download PDF", data, file_name=os.path.basename(pdf_path), mime="application/pdf")
+        except Exception as e:
+            st.warning(f"Inline PDF display failed: {e}")
+
+    def _info(self, pdf_path: str) -> dict:
+        """
+        Collect minimal PDF info (retained for potential future use).
+
+        Args:
+            pdf_path: Path to PDF.
+
+        Returns:
+            Dict with num_pages & encrypted flag.
+        """
+        try:
+            reader = PdfReader(pdf_path)
+            return {"num_pages": len(reader.pages), "encrypted": reader.is_encrypted}
+        except Exception:
+            return {"num_pages": 0, "encrypted": False}
src/models/__init__.py
ADDED
File without changes
src/models/chat_models.py
ADDED
@@ -0,0 +1,21 @@
+from pydantic import BaseModel
+from typing import List, Optional
+
+class Message(BaseModel):
+    user_id: str
+    content: str
+    timestamp: str
+
+class ChatContext(BaseModel):
+    messages: List[Message]
+    pdf_id: Optional[str] = None
+
+class UserPrompt(BaseModel):
+    user_id: str
+    prompt: str
+    context: ChatContext
+
+class AssistantResponse(BaseModel):
+    response: str
+    context: ChatContext
+    pdf_id: Optional[str] = None
src/services/__init__.py
ADDED
File without changes
src/services/embedding_service.py
ADDED
@@ -0,0 +1,140 @@
+"""Generate, store, and query embeddings via Gemini API."""
+import numpy as np
+import google.generativeai as genai
+from typing import List, Dict, Optional
+from config.settings import Config
+from src.utils.vector_store import VectorStore
+
+class EmbeddingService:
+    """Handles embedding generation, storage, and similarity search."""
+
+    def __init__(self) -> None:
+        """Configure Gemini and initialize vector store."""
+        Config.validate()
+        genai.configure(api_key=Config.GEMINI_API_KEY)
+        self.embedding_model = Config.EMBEDDING_MODEL
+        self.vector_store = VectorStore(storage_path=Config.EMBEDDING_STORAGE_PATH)
+
+    def generate_embeddings(self, texts: List[str]) -> List[np.ndarray]:
+        """
+        Embed a list of document texts.
+
+        Args:
+            texts: List of strings.
+
+        Returns:
+            List of embedding vectors (np.ndarray).
+        """
+        embeddings: List[np.ndarray] = []
+        for i, text in enumerate(texts):
+            try:
+                result = genai.embed_content(
+                    model=self.embedding_model,
+                    content=text,
+                    task_type="retrieval_document",
+                )
+                embeddings.append(np.array(result["embedding"]))
+            except Exception as e:
+                print(f"[EmbeddingService] Doc embed error idx {i}: {e}")
+                embeddings.append(np.zeros(768))
+        return embeddings
+
+    def generate_query_embedding(self, query: str) -> np.ndarray:
+        """
+        Create an embedding for a query.
+
+        Args:
+            query: User query text.
+
+        Returns:
+            Query embedding vector.
+        """
+        try:
+            result = genai.embed_content(
+                model=self.embedding_model,
+                content=query,
+                task_type="retrieval_query",
+            )
+            return np.array(result["embedding"])
+        except Exception as e:
+            print(f"[EmbeddingService] Query embed error: {e}")
+            return np.zeros(768)
+
+    def store_pdf_embeddings(self, pdf_id: str, chunks: List[str]) -> None:
+        """
+        Embed and store all chunks for a PDF (replacing previous).
+
+        Args:
+            pdf_id: Unique PDF identifier.
+            chunks: List of chunk strings.
+        """
+        self.clear_pdf_embeddings(pdf_id)
+        for idx, (chunk, vec) in enumerate(zip(chunks, self.generate_embeddings(chunks))):
+            key = f"{pdf_id}_chunk_{idx}"
+            self.vector_store.add_embedding(
+                key=key,
+                vector=vec.tolist(),
+                metadata={"pdf_id": pdf_id, "chunk_index": idx, "text": chunk},
+            )
+
+    def find_similar_chunks(self, query: str, pdf_id: Optional[str] = None, top_k: int = 3) -> List[Dict]:
+        """
+        Retrieve top_k most similar stored chunks.
+
+        Args:
+            query: User query string.
+            pdf_id: Restrict to given PDF id if set.
+            top_k: Number of results.
+
+        Returns:
+            List of similarity result dicts.
+        """
+        q_vec = self.generate_query_embedding(query)
+        results = []
+        for key in self.vector_store.get_all_embeddings():
+            if pdf_id and not key.startswith(f"{pdf_id}_"):
+                continue
+            data = self.vector_store.get_embedding_data(key)
+            if not data:
+                continue
+            vec = np.array(data["vector"])
+            sim = self._cosine_similarity(q_vec, vec)
+            md = data.get("metadata", {})
+            results.append(
+                {
+                    "key": key,
+                    "similarity": sim,
+                    "text": md.get("text", ""),
+                    "chunk_index": md.get("chunk_index", 0),
+                    "pdf_id": md.get("pdf_id", ""),
+                }
+            )
+        results.sort(key=lambda r: r["similarity"], reverse=True)
+        return results[:top_k]
+
+    def clear_pdf_embeddings(self, pdf_id: str) -> None:
+        """
+        Remove all embeddings tied to a PDF.
+
+        Args:
+            pdf_id: Identifier.
+        """
+        self.vector_store.remove_embeddings_by_prefix(f"{pdf_id}_")
+
+    def _cosine_similarity(self, a: np.ndarray, b: np.ndarray) -> float:
+        """
+        Compute cosine similarity.
+
+        Args:
+            a: Vector A
+            b: Vector B
+
+        Returns:
+            Cosine similarity or 0.0 on failure.
+        """
+        if not a.any() or not b.any():
+            return 0.0
+        denom = (np.linalg.norm(a) * np.linalg.norm(b))
+        if denom == 0:
+            return 0.0
+        return float(np.dot(a, b) / denom)
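As a worked check of the `_cosine_similarity` math above (toy vectors, not real embeddings):

```python
import numpy as np

a = np.array([1.0, 0.0, 1.0])
b = np.array([1.0, 1.0, 0.0])
# dot = 1, |a| = |b| = sqrt(2), so similarity = 1 / 2
print(float(np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b))))  # 0.5
```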
src/services/gemini_client.py
ADDED
@@ -0,0 +1,77 @@
+import sys
+import os
+import google.generativeai as genai
+from typing import List, Dict, Optional, Iterator
+
+sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '../../')))
+
+from config.settings import Config
+
+class GeminiClient:
+    """Generate responses (full or streaming) using Gemini with optional context & history."""
+
+    def __init__(self) -> None:
+        """Configure model instance."""
+        Config.validate()
+        genai.configure(api_key=Config.GEMINI_API_KEY)
+        self.model = genai.GenerativeModel(Config.GEMINI_MODEL)
+
+    def generate_response(self, prompt: str, context: str = "", chat_history: Optional[List[Dict]] = None) -> str:
+        """
+        Produce a model response.
+
+        Args:
+            prompt: User question.
+            context: Retrieved PDF context.
+            chat_history: Prior messages list.
+
+        Returns:
+            Response string (or error message).
+        """
+        try:
+            full_prompt = self._build_prompt(prompt, context, chat_history)
+            resp = self.model.generate_content(full_prompt)
+            return getattr(resp, "text", "").strip() or "No response generated."
+        except Exception as e:
+            return f"Error generating response: {e}"
+
+    def stream_response(self, prompt: str, context: str = "", chat_history: Optional[List[Dict]] = None) -> Iterator[str]:
+        """
+        Stream model tokens/chunks. Yields incremental text fragments.
+        """
+        try:
+            full_prompt = self._build_prompt(prompt, context, chat_history)
+            for chunk in self.model.generate_content(full_prompt, stream=True):
+                txt = getattr(chunk, "text", "")
+                if txt:
+                    yield txt
+        except Exception as e:
+            yield f"[Error] {e}"
+
+    def _build_prompt(self, user_prompt: str, context: str, chat_history: Optional[List[Dict]]) -> str:
+        """
+        Construct final prompt sent to LLM.
+
+        Args:
+            user_prompt: Current question.
+            context: Retrieved context text.
+            chat_history: List of previous user/assistant dicts.
+
+        Returns:
+            Combined prompt string.
+        """
+        system = (
+            "You are an assistant answering questions about an uploaded PDF. "
+            "Base answers only on provided context. If unknown, say you lack the info."
+        )
+        parts = [system]
+        if context:
+            parts.append(f"\nContext:\n{context}")
+        if chat_history:
+            parts.append("\nRecent conversation:")
+            for m in chat_history[-5:]:
+                role = m.get("role", "user")
+                content = m.get("content", "")
+                parts.append(f"{role}: {content}")
+        parts.append(f"\nQuestion: {user_prompt}\nAnswer:")
+        return "\n".join(parts)
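To make `_build_prompt` concrete, here is roughly the string it assembles for a short history (hypothetical values; `_build_prompt` is private and called here only for illustration, and constructing the client requires a valid `.env`):

```python
from services.gemini_client import GeminiClient

client = GeminiClient()
prompt = client._build_prompt(
    "What is the warranty period?",
    context="[Chunk 1 sim=0.82]\nThe warranty lasts 24 months.",
    chat_history=[{"role": "user", "content": "Hi"},
                  {"role": "assistant", "content": "Hello!"}],
)
print(prompt)
# You are an assistant answering questions about an uploaded PDF. Base answers only on provided context. If unknown, say you lack the info.
#
# Context:
# [Chunk 1 sim=0.82]
# The warranty lasts 24 months.
#
# Recent conversation:
# user: Hi
# assistant: Hello!
#
# Question: What is the warranty period?
# Answer:
```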
src/services/pdf_processor.py
ADDED
@@ -0,0 +1,94 @@
+import sys
+import os
+from typing import List, Dict, Optional
+
+sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '../../')))
+
+try:
+    from PyPDF2 import PdfReader
+except ImportError:
+    try:
+        from pypdf import PdfReader
+    except ImportError:
+        print("Error: PDF reading library not found. Please install PyPDF2 or pypdf.")
+        PdfReader = None
+
+from src.utils.chunking import chunk_pdf_text, clean_text
+from config.settings import Config
+
+class PDFProcessor:
+    """Process PDFs into cleaned text chunks."""
+
+    def __init__(self, chunk_size: Optional[int] = None, overlap: Optional[int] = None) -> None:
+        """
+        Initialize processor with chunk parameters.
+
+        Args:
+            chunk_size: Characters per chunk (defaults to config).
+            overlap: Overlap between chunks (defaults to config).
+        """
+        self.chunk_size = chunk_size or Config.CHUNK_SIZE
+        self.overlap = overlap or Config.CHUNK_OVERLAP
+
+    def process_pdf(self, file_path: str) -> List[str]:
+        """
+        Read PDF, extract text, clean, and chunk.
+
+        Args:
+            file_path: Path to PDF.
+
+        Returns:
+            List of chunk strings.
+        """
+        raw = self._extract_text(file_path)
+        if not raw.strip():
+            return []
+        cleaned = clean_text(raw)
+        chunks = chunk_pdf_text(cleaned, self.chunk_size, self.overlap)
+        return [c for c in chunks if len(c.strip()) > 50]
+
+    def get_pdf_info(self, file_path: str) -> Dict:
+        """
+        Retrieve simple info (pages, metadata, encryption).
+
+        Args:
+            file_path: Path to PDF.
+
+        Returns:
+            Dict of info.
+        """
+        try:
+            reader = PdfReader(file_path)
+            return {
+                "num_pages": len(reader.pages),
+                "metadata": reader.metadata,
+                "encrypted": reader.is_encrypted,
+            }
+        except Exception as e:
+            print(f"[PDFProcessor] Info error: {e}")
+            return {}
+
+    def _extract_text(self, file_path: str) -> str:
+        """
+        Extract text from all pages.
+
+        Args:
+            file_path: Path to PDF.
+
+        Returns:
+            Concatenated text with page separators.
+        """
+        try:
+            reader = PdfReader(file_path)
+            out: List[str] = []
+            for idx, page in enumerate(reader.pages):
+                try:
+                    text = page.extract_text() or ""
+                    if text.strip():
+                        out.append(f"\n--- Page {idx+1} ---\n{text}")
+                except Exception as pe:
+                    print(f"[PDFProcessor] Page {idx+1} extraction failed: {pe}")
+            return "".join(out)
+        except Exception as e:
+            print(f"[PDFProcessor] Read error: {e}")
+            return ""
src/services/rag_service.py
ADDED
@@ -0,0 +1,105 @@
+import sys
+import os
+from typing import List, Dict, Optional, Iterator
+
+sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '../../')))
+
+from src.services.embedding_service import EmbeddingService
+from src.services.gemini_client import GeminiClient
+
+class RAGService:
+    """Combine retrieval + generation workflow."""
+
+    def __init__(self, embedding_service: EmbeddingService, gemini_client: GeminiClient) -> None:
+        """
+        Init RAG service.
+
+        Args:
+            embedding_service: EmbeddingService instance.
+            gemini_client: GeminiClient instance.
+        """
+        self.embedding_service = embedding_service
+        self.gemini_client = gemini_client
+
+    def get_response(self, user_query: str, pdf_id: str, chat_history: Optional[List[Dict]] = None) -> str:
+        """
+        Retrieve context & generate answer.
+
+        Args:
+            user_query: User question.
+            pdf_id: PDF identifier.
+            chat_history: Prior messages.
+
+        Returns:
+            Assistant answer text.
+        """
+        chunks = self.embedding_service.find_similar_chunks(user_query, pdf_id=pdf_id, top_k=3)
+        context = self._format_context(chunks)
+        return self.gemini_client.generate_response(user_query, context=context, chat_history=chat_history)
+
+    def stream_response(self, user_query: str, pdf_id: str, chat_history: Optional[List[Dict]] = None) -> Iterator[str]:
+        """
+        Retrieve context then stream model output.
+        """
+        chunks = self.embedding_service.find_similar_chunks(user_query, pdf_id=pdf_id, top_k=3)
+        context = self._format_context(chunks)
+        return self.gemini_client.stream_response(user_query, context=context, chat_history=chat_history)
+
+    def _format_context(self, chunks: List[Dict]) -> str:
+        """
+        Format retrieved chunks for prompt.
+
+        Args:
+            chunks: Retrieval result list.
+
+        Returns:
+            Joined context string.
+        """
+        if not chunks:
+            return ""
+        lines: List[str] = []
+        for idx, c in enumerate(chunks, start=1):
+            if c.get("similarity", 0) > 0.05:
+                lines.append(f"[Chunk {idx} sim={c['similarity']:.2f}]\n{c.get('text','')}")
+        return "\n\n".join(lines)
+
+    def retrieve_relevant_chunks(self, user_prompt: str, pdf_id: str, top_k: int = 3) -> List[Dict]:
+        """
+        Retrieve relevant chunks based on user prompt
+        """
+        return self.embedding_service.find_similar_chunks(
+            query=user_prompt,
+            pdf_id=pdf_id,
+            top_k=top_k
+        )
+
+    def generate_response_with_sources(self, user_query: str, pdf_id: str, chat_history: Optional[List[Dict]] = None) -> Dict:
+        """
+        Generate response with source information
+        """
+        try:
+            # Retrieve relevant chunks
+            relevant_chunks = self.retrieve_relevant_chunks(user_query, pdf_id)
+
+            # Prepare context
+            context = self._format_context(relevant_chunks)
+
+            # Generate response
+            response = self.gemini_client.generate_response(
+                prompt=user_query,
+                context=context,
+                chat_history=chat_history
+            )
+
+            return {
+                "response": response,
+                "sources": relevant_chunks,
+                "context_used": context
+            }
+
+        except Exception as e:
+            return {
+                "response": f"Sorry, I encountered an error: {str(e)}",
+                "sources": [],
+                "context_used": ""
+            }
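A hedged sketch of the sources-aware path above (assumes a valid `.env`, the project root on `sys.path`, and a PDF already embedded under the hypothetical id `sample`):

```python
from src.services.embedding_service import EmbeddingService
from src.services.gemini_client import GeminiClient
from src.services.rag_service import RAGService

rag = RAGService(EmbeddingService(), GeminiClient())
result = rag.generate_response_with_sources("Summarize the introduction", pdf_id="sample")
print(result["response"])
for source in result["sources"]:               # up to top-3 retrieved chunks
    print(source["chunk_index"], round(source["similarity"], 2))
```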
src/streamlit_app.py
DELETED
@@ -1,40 +0,0 @@
-import altair as alt
-import numpy as np
-import pandas as pd
-import streamlit as st
-
-"""
-# Welcome to Streamlit!
-
-Edit `/streamlit_app.py` to customize this app to your heart's desire :heart:.
-If you have any questions, checkout our [documentation](https://docs.streamlit.io) and [community
-forums](https://discuss.streamlit.io).
-
-In the meantime, below is an example of what you can do with just a few lines of code:
-"""
-
-num_points = st.slider("Number of points in spiral", 1, 10000, 1100)
-num_turns = st.slider("Number of turns in spiral", 1, 300, 31)
-
-indices = np.linspace(0, 1, num_points)
-theta = 2 * np.pi * num_turns * indices
-radius = indices
-
-x = radius * np.cos(theta)
-y = radius * np.sin(theta)
-
-df = pd.DataFrame({
-    "x": x,
-    "y": y,
-    "idx": indices,
-    "rand": np.random.randn(num_points),
-})
-
-st.altair_chart(alt.Chart(df, height=700, width=700)
-    .mark_point(filled=True)
-    .encode(
-        x=alt.X("x", axis=None),
-        y=alt.Y("y", axis=None),
-        color=alt.Color("idx", legend=None, scale=alt.Scale()),
-        size=alt.Size("rand", legend=None, scale=alt.Scale(range=[1, 150])),
-    ))
src/utils/__init__.py
ADDED
File without changes
src/utils/chunking.py
ADDED
@@ -0,0 +1,37 @@
+"""Basic text cleaning and fixed-size overlapping chunking utilities."""
+from typing import List
+
+def clean_text(text: str) -> str:
+    """
+    Normalize whitespace in text.
+
+    Args:
+        text: Raw text.
+
+    Returns:
+        Cleaned single-spaced text.
+    """
+    return " ".join(text.split())
+
+def chunk_pdf_text(pdf_text: str, chunk_size: int = 1000, overlap: int = 100) -> List[str]:
+    """
+    Split text into overlapping chunks.
+
+    Args:
+        pdf_text: Full text.
+        chunk_size: Max chars per chunk.
+        overlap: Overlapping chars between chunks.
+
+    Returns:
+        List of chunk strings.
+    """
+    if chunk_size <= overlap:
+        raise ValueError("chunk_size must be greater than overlap")
+    chunks: List[str] = []
+    start = 0
+    length = len(pdf_text)
+    while start < length:
+        end = min(start + chunk_size, length)
+        chunks.append(pdf_text[start:end])
+        start += chunk_size - overlap
+    return chunks
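With the README defaults (`CHUNK_SIZE=1000`, `CHUNK_OVERLAP=200`), consecutive chunks start 800 characters apart. A quick check of the boundary arithmetic on toy values (assumes the project root on `sys.path`):

```python
from src.utils.chunking import chunk_pdf_text

text = "abcdefghij" * 5                       # 50 characters
chunks = chunk_pdf_text(text, chunk_size=20, overlap=5)
# starts at 0, 15, 30, 45 -> chunk lengths 20, 20, 20, 5
print([len(c) for c in chunks])               # [20, 20, 20, 5]
```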
src/utils/vector_store.py
ADDED
@@ -0,0 +1,101 @@
+"""Lightweight on-disk key → embedding store (pickle-based)."""
+from __future__ import annotations
+from typing import List, Dict, Optional, Any
+import pickle
+import os
+
+class VectorStore:
+    """Persist simple embedding entries (vector + metadata) to a pickle file."""
+
+    def __init__(self, storage_path: str) -> None:
+        """
+        Initialize the vector store.
+
+        Args:
+            storage_path: Path to pickle file used for persistence.
+        """
+        self.storage_path = storage_path
+        self.embeddings: Dict[str, Dict[str, Any]] = {}
+        os.makedirs(os.path.dirname(storage_path), exist_ok=True)
+        self.load_embeddings()
+
+    def load_embeddings(self) -> None:
+        """Load embeddings from disk if file exists."""
+        if os.path.exists(self.storage_path):
+            try:
+                with open(self.storage_path, "rb") as f:
+                    self.embeddings = pickle.load(f)
+            except Exception as e:
+                print(f"[VectorStore] Error loading embeddings: {e}")
+                self.embeddings = {}
+
+    def save_embeddings(self) -> None:
+        """Persist current embeddings to disk."""
+        try:
+            with open(self.storage_path, "wb") as f:
+                pickle.dump(self.embeddings, f)
+        except Exception as e:
+            print(f"[VectorStore] Error saving embeddings: {e}")
+
+    def add_embedding(self, key: str, vector: List[float], metadata: Optional[Dict] = None) -> None:
+        """
+        Add or overwrite an embedding entry.
+
+        Args:
+            key: Unique identifier (e.g. 'pdf1_chunk_0')
+            vector: Embedding vector as list of floats
+            metadata: Optional metadata dictionary
+        """
+        self.embeddings[key] = {"vector": vector, "metadata": metadata or {}}
+        self.save_embeddings()
+
+    def get_embedding_data(self, key: str) -> Optional[Dict]:
+        """
+        Retrieve full embedding entry.
+
+        Args:
+            key: Embedding key
+
+        Returns:
+            Dict with 'vector' and 'metadata' or None.
+        """
+        return self.embeddings.get(key)
+
+    def get_embedding_vector(self, key: str) -> Optional[List[float]]:
+        """
+        Retrieve only the vector.
+
+        Args:
+            key: Embedding key
+
+        Returns:
+            Vector list or None.
+        """
+        entry = self.embeddings.get(key)
+        return entry["vector"] if entry else None
+
+    def get_all_embeddings(self) -> List[str]:
+        """
+        List all embedding keys.
+
+        Returns:
+            List of keys.
+        """
+        return list(self.embeddings.keys())
+
+    def clear_embeddings(self) -> None:
+        """Remove all embeddings."""
+        self.embeddings = {}
+        self.save_embeddings()
+
+    def remove_embeddings_by_prefix(self, prefix: str) -> None:
+        """
+        Remove embeddings whose keys start with prefix.
+
+        Args:
+            prefix: Key prefix filter.
+        """
+        to_remove = [k for k in self.embeddings if k.startswith(prefix)]
+        for k in to_remove:
+            del self.embeddings[k]
+        self.save_embeddings()
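Finally, a minimal round-trip with the store above (hypothetical path and values; not part of the commit):

```python
from src.utils.vector_store import VectorStore

store = VectorStore("data/embeddings/demo.pkl")    # creates the folder if needed
store.add_embedding("doc1_chunk_0", [0.1, 0.2, 0.3], {"pdf_id": "doc1", "text": "hello"})
print(store.get_all_embeddings())                  # ['doc1_chunk_0']
print(store.get_embedding_vector("doc1_chunk_0"))  # [0.1, 0.2, 0.3]
store.remove_embeddings_by_prefix("doc1_")         # removal is persisted immediately
```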