# DocsBot — app/services/vector_store.py
# Last change by BabaK07, commit d197c9d: "Polish retrieval workflow and UI"
import json
from typing import Any
import requests
from langchain_groq import ChatGroq
from langchain_text_splitters import RecursiveCharacterTextSplitter
from sqlalchemy import delete, func, select
from sqlalchemy.orm import Session
from app.config import get_settings
from app.models import DocumentChunk
class JinaEmbeddings:
    """Thin HTTP client for the Jina embeddings API.

    Produces float vectors for passages (indexing) and queries (search);
    every returned vector is validated against the configured ``dimensions``.
    """

    def __init__(self, *, api_key: str, base_url: str, model: str, dimensions: int) -> None:
        self.api_key = api_key
        self.base_url = base_url
        self.model = model
        self.dimensions = dimensions

    def embed_documents(self, texts: list[str]) -> list[list[float]]:
        """Embed passages for storage/indexing."""
        return self._embed(texts=texts, task="retrieval.passage")

    def embed_query(self, text: str) -> list[float]:
        """Embed one search query; fall back to a zero vector if the API returns nothing."""
        vectors = self._embed(texts=[text], task="retrieval.query")
        if vectors:
            return vectors[0]
        return [0.0] * self.dimensions

    def _embed(self, *, texts: list[str], task: str) -> list[list[float]]:
        """POST ``texts`` to the embeddings endpoint and return validated vectors.

        Returns ``[]`` for empty input without calling the API.  Raises
        ``requests.HTTPError`` on a non-2xx response and ``ValueError`` when
        any vector's length disagrees with ``self.dimensions``.
        """
        if not texts:
            return []
        headers = {
            "Content-Type": "application/json",
            "Authorization": f"Bearer {self.api_key}",
        }
        payload = {
            "model": self.model,
            "task": task,
            "embedding_type": "float",
            "normalized": True,
            "input": texts,
        }
        response = requests.post(self.base_url, headers=headers, json=payload, timeout=60)
        response.raise_for_status()
        rows = response.json().get("data", [])
        vectors: list[list[float]] = []
        for row in rows:
            vector = row.get("embedding", [])
            if len(vector) != self.dimensions:
                raise ValueError(
                    f"Jina embedding dimension mismatch: got {len(vector)}, expected {self.dimensions}. "
                    "Adjust EMBEDDING_DIMENSIONS or switch embedding model."
                )
            vectors.append(vector)
        return vectors
class JinaReranker:
    """Thin HTTP client for the Jina reranking API."""

    def __init__(self, *, api_key: str, base_url: str, model: str) -> None:
        self.api_key = api_key
        self.base_url = base_url
        self.model = model

    def rerank(self, *, query: str, documents: list[str], top_n: int) -> list[dict[str, Any]]:
        """Score ``documents`` against ``query`` and return the API's result rows.

        Returns ``[]`` for empty input without calling the API.  Result rows
        carry ``index`` and ``relevance_score``; the documents themselves are
        not echoed back (``return_documents`` is false).
        """
        if not documents:
            return []
        body = {
            "model": self.model,
            "query": query,
            "top_n": top_n,
            "documents": documents,
            "return_documents": False,
        }
        response = requests.post(
            self.base_url,
            headers={
                "Content-Type": "application/json",
                "Authorization": f"Bearer {self.api_key}",
            },
            json=body,
            timeout=60,
        )
        response.raise_for_status()
        return response.json().get("results", [])
class VectorStoreService:
    """Chunk, embed, store, and retrieve document text (pgvector-backed).

    Retrieval is two-stage: a cosine-distance vector search selects
    ``candidate_k`` chunk rows, then the Jina reranker keeps the best
    ``final_k``.  Both sizes are chosen per query by an LLM "retrieval
    planner"; if the planner's output cannot be parsed, deterministic
    configured defaults are used instead of failing the search.
    """

    def __init__(self) -> None:
        self.settings = get_settings()
        if not self.settings.jina_api_key:
            raise RuntimeError("JINA_API_KEY is required for document embedding and retrieval.")
        # Sentence-aware splitting: try paragraph and sentence boundaries
        # before falling back to words and raw characters.
        self.splitter = RecursiveCharacterTextSplitter(
            chunk_size=1000,
            chunk_overlap=150,
            separators=[
                "\n\n",
                "\n",
                ". ",
                "? ",
                "! ",
                "; ",
                ", ",
                " ",
                "",
            ],
            keep_separator=True,
        )
        self.embeddings = JinaEmbeddings(
            api_key=self.settings.jina_api_key,
            base_url=self.settings.jina_api_base,
            model=self.settings.jina_embedding_model,
            dimensions=self.settings.embedding_dimensions,
        )
        # The planner LLM is optional at construction time; retrieval paths
        # that need it raise RuntimeError when it is absent.
        self.retrieval_router = (
            ChatGroq(
                api_key=self.settings.groq_api_key,
                model=self.settings.model_name,
                temperature=0,
            )
            if self.settings.groq_api_key
            else None
        )
        self.reranker = JinaReranker(
            api_key=self.settings.jina_api_key,
            base_url=self.settings.jina_reranker_api_base,
            model=self.settings.jina_reranker_model,
        )

    def _get_embeddings(self) -> Any:
        """Return the embeddings client (single indirection point for overrides)."""
        return self.embeddings

    def _fallback_retrieval_sizes(self, *, available_chunks: int, requested_k: int) -> tuple[int, int]:
        """Deterministic (final_k, candidate_k) used when the planner output is unusable.

        Applies the same clamps as the planner path: final_k in [1, 8] and
        candidate_k in [final_k, 30], both bounded by ``available_chunks``.
        """
        final_k = max(1, min(max(requested_k, self.settings.retrieval_k), available_chunks, 8))
        candidate_k = min(max(final_k, self.settings.rerank_candidate_k), available_chunks, 30)
        return final_k, candidate_k

    def _choose_retrieval_sizes(
        self,
        *,
        db: Session,
        query: str,
        file_hashes: list[str],
        requested_k: int,
    ) -> tuple[int, int]:
        """Pick (final_k, candidate_k) for this query via the planner LLM.

        Returns (0, 0) when the selected documents have no stored chunks.
        Raises RuntimeError when no Groq API key was configured.  Malformed
        planner output (non-JSON, missing keys, non-integer values) falls back
        to configured sizes instead of aborting the search; errors from the
        LLM invocation itself still propagate unchanged.
        """
        available_chunks = db.scalar(
            select(func.count())
            .select_from(DocumentChunk)
            .where(DocumentChunk.file_hash.in_(file_hashes))
        ) or 0
        if available_chunks <= 0:
            return 0, 0
        if self.retrieval_router is None:
            raise RuntimeError("GROQ_API_KEY is required for LLM-based retrieval size selection.")
        prompt = (
            "You are a retrieval planner for a RAG system.\n"
            "Choose how many chunks to keep after reranking and how many vector candidates to send to the reranker.\n"
            "Return only valid JSON with this exact schema:\n"
            '{"final_k": 4, "candidate_k": 12}\n\n'
            "Rules:\n"
            f"- final_k must be between 1 and {min(8, available_chunks)}\n"
            f"- candidate_k must be between final_k and {min(30, available_chunks)}\n"
            "- candidate_k should usually be around 2x to 4x final_k\n"
            "- Use larger values for broad, comparative, or synthesis-heavy queries\n"
            "- Use smaller values for narrow fact lookup queries\n\n"
            f"Query: {query}\n"
            f"Selected documents: {len(file_hashes)}\n"
            f"Available chunks: {available_chunks}\n"
            f"Requested final_k hint: {requested_k}\n"
            f"Configured minimum final_k: {self.settings.retrieval_k}\n"
            f"Configured minimum candidate_k: {self.settings.rerank_candidate_k}\n"
        )
        response = self.retrieval_router.invoke(prompt)
        content = response.content if isinstance(response.content, str) else str(response.content)
        # Strip an optional markdown code fence around the JSON payload.
        if "```json" in content:
            content = content.split("```json", 1)[1].split("```", 1)[0].strip()
        elif "```" in content:
            content = content.split("```", 1)[1].split("```", 1)[0].strip()
        try:
            data = json.loads(content)
            final_k = int(data["final_k"])
            candidate_k = int(data["candidate_k"])
        except (json.JSONDecodeError, KeyError, TypeError, ValueError):
            # The planner is advisory: a bad reply must not break retrieval.
            return self._fallback_retrieval_sizes(
                available_chunks=available_chunks,
                requested_k=requested_k,
            )
        # Clamp whatever the model proposed into safe bounds.
        final_k = max(1, min(final_k, available_chunks, 8))
        candidate_floor = max(final_k, self.settings.rerank_candidate_k)
        candidate_k = max(final_k, candidate_k)
        candidate_k = min(max(candidate_floor, candidate_k), available_chunks, 30)
        return final_k, candidate_k

    def _rerank_matches(self, *, query: str, matches: list[dict[str, Any]], top_n: int) -> list[dict[str, Any]]:
        """Rerank vector-search ``matches`` and keep the best ``top_n``.

        Best-effort: when no reranker is configured, the API call fails, or
        the API returns no usable rows, fall back to the first ``top_n``
        matches in vector-distance order.
        """
        if self.reranker is None or not matches:
            return matches[:top_n]
        try:
            results = self.reranker.rerank(
                query=query,
                documents=[match["content"] for match in matches],
                top_n=min(top_n, len(matches)),
            )
        except requests.RequestException:
            return matches[:top_n]
        reranked: list[dict[str, Any]] = []
        for item in results:
            index = item.get("index")
            # Defensive: drop rows whose index does not point back into matches.
            if not isinstance(index, int) or index < 0 or index >= len(matches):
                continue
            match = dict(matches[index])
            score = item.get("relevance_score")
            if isinstance(score, (int, float)):
                match["rerank_score"] = float(score)
            reranked.append(match)
        return reranked or matches[:top_n]

    def add_document(self, *, db: Session, document_id: int, file_hash: str, filename: str, pages: list[tuple[int, str]]) -> None:
        """Split ``pages`` into chunks, embed them, and (re)store the rows.

        Existing chunks for ``document_id`` are deleted first, so re-ingesting
        a document is idempotent.  Pages and chunks that are blank after
        stripping are skipped; the method is a no-op when nothing remains.
        """
        chunk_rows: list[tuple[int | None, str]] = []
        for page_number, page_text in pages:
            if not page_text.strip():
                continue
            page_chunks = self.splitter.split_text(page_text)
            chunk_rows.extend((page_number, chunk) for chunk in page_chunks if chunk.strip())
        chunks = [chunk for _, chunk in chunk_rows]
        if not chunks:
            return
        embeddings_client = self._get_embeddings()
        embeddings = embeddings_client.embed_documents(chunks)
        db.execute(delete(DocumentChunk).where(DocumentChunk.document_id == document_id))
        # NOTE(review): strict=False silently truncates if the API returns a
        # different number of embeddings than chunks — confirm this tolerance
        # is intended rather than raising on mismatch.
        rows = [
            DocumentChunk(
                document_id=document_id,
                file_hash=file_hash,
                filename=filename,
                chunk_index=index,
                page_number=page_number,
                content=chunk,
                embedding=embedding,
            )
            for index, ((page_number, chunk), embedding) in enumerate(zip(chunk_rows, embeddings, strict=False))
        ]
        db.add_all(rows)
        db.flush()

    def similarity_search(self, *, db: Session, query: str, file_hashes: list[str], k: int = 4) -> list[dict[str, Any]]:
        """Retrieve the most relevant chunks for ``query`` from ``file_hashes``.

        Runs the two-stage pipeline: plan sizes, fetch ``candidate_k`` rows by
        cosine distance, then rerank down to ``final_k``.  ``k`` is only a
        hint to the planner.  Returns dicts with ``content``, ``metadata``,
        ``distance``, and (when reranking succeeded) ``rerank_score``.
        """
        if not file_hashes:
            return []
        final_k, candidate_k = self._choose_retrieval_sizes(
            db=db,
            query=query,
            file_hashes=file_hashes,
            requested_k=k,
        )
        if final_k == 0:
            return []
        query_embedding = self._get_embeddings().embed_query(query)
        stmt = (
            select(
                DocumentChunk.document_id,
                DocumentChunk.content,
                DocumentChunk.filename,
                DocumentChunk.file_hash,
                DocumentChunk.chunk_index,
                DocumentChunk.page_number,
                DocumentChunk.embedding.cosine_distance(query_embedding).label("distance"),
            )
            .where(DocumentChunk.file_hash.in_(file_hashes))
            .order_by(DocumentChunk.embedding.cosine_distance(query_embedding))
            .limit(candidate_k)
        )
        results = db.execute(stmt).all()
        matches: list[dict[str, Any]] = []
        for row in results:
            matches.append(
                {
                    "content": row.content,
                    "metadata": {
                        "document_id": row.document_id,
                        "filename": row.filename,
                        "file_hash": row.file_hash,
                        "chunk_index": row.chunk_index,
                        "page_number": row.page_number,
                    },
                    "distance": row.distance,
                }
            )
        return self._rerank_matches(query=query, matches=matches, top_n=final_k)