"""Context retrieval with reranking capabilities."""

import hashlib
import os
import traceback
from typing import List, Optional, Tuple, Dict, Any

import numpy as np
import torch
from langchain.schema import Document
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.vectorstores import Qdrant
from qdrant_client.http import models as rest
from sentence_transformers import CrossEncoder

from .filter import create_filter

class ContextRetriever:
    """
    Context retriever for hybrid search with optional filtering and reranking.
    """

    def __init__(self, vectorstore: Qdrant, config: dict = None):
        """
        Initialize the context retriever.

        Args:
            vectorstore: Qdrant vector store instance
            config: Configuration dictionary
        """
        self.vectorstore = vectorstore
        self.config = config or {}
        self.reranker = None

        # BM25 components (populated lazily if hybrid search is enabled).
        self.bm25_vectorizer = None
        self.bm25_matrix = None
        self.bm25_documents = None

        # Resolve the reranker model name from the config, checking the most
        # specific keys first and falling back to a default.
        self.reranker_model_name = (
            self.config.get('retrieval', {}).get('reranker_model') or
            self.config.get('ranker', {}).get('model') or
            self.config.get('reranker_model') or
            'BAAI/bge-reranker-v2-m3'
        )
        self.reranker_type = self._detect_reranker_type(self.reranker_model_name)

        try:
            if self.reranker_type == 'colbert':
                from colbert.infra import ColBERTConfig
                from colbert.modeling.checkpoint import Checkpoint

                print(f"✅ RERANKER: ColBERT model detected ({self.reranker_model_name})")
                print(f"🔍 INTERACTION TYPE: Late interaction (token-level embeddings)")

                colbert_config = ColBERTConfig(
                    doc_maxlen=300,
                    query_maxlen=32,
                    nbits=2,
                    kmeans_niters=4,
                    root="./colbert_data"
                )

                self.colbert_checkpoint = Checkpoint(self.reranker_model_name, colbert_config=colbert_config)
                self.colbert_model = self.colbert_checkpoint.model
                self.colbert_tokenizer = self.colbert_checkpoint.raw_tokenizer
                self.reranker = self._colbert_rerank
                print(f"✅ COLBERT: Model and tokenizer loaded successfully")
            else:
                self.reranker = CrossEncoder(self.reranker_model_name)
                print(f"✅ RERANKER: Initialized {self.reranker_model_name}")
                print(f"🔍 INTERACTION TYPE: Cross-encoder (single relevance score)")
        except Exception as e:
            print(f"⚠️ Reranker initialization failed: {e}")
            self.reranker = None

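    # A minimal configuration sketch (illustrative; only the keys checked in
    # __init__ matter, and the model name shown is the module default, not a
    # requirement):
    #
    #   config = {"retrieval": {"reranker_model": "BAAI/bge-reranker-v2-m3"}}
    #   retriever = ContextRetriever(vectorstore, config)
    #   # retriever.reranker_type -> 'crossencoder'
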
    def _detect_reranker_type(self, model_name: str) -> str:
        """
        Detect the type of reranker based on model name.

        Args:
            model_name: Name of the reranker model

        Returns:
            'colbert' for ColBERT models, 'crossencoder' for others
        """
        # All known ColBERT checkpoint names ('colbert-ir', 'colbertv2',
        # 'colbert-v2', ...) contain the substring 'colbert', so a single
        # substring check suffices.
        if 'colbert' in model_name.lower():
            return 'colbert'
        return 'crossencoder'

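    # Illustrative detection results (hypothetical model names):
    #
    #   _detect_reranker_type("colbert-ir/colbertv2.0")   -> 'colbert'
    #   _detect_reranker_type("BAAI/bge-reranker-v2-m3")  -> 'crossencoder'
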
    def _similarity_search_with_colbert_embeddings(self, query: str, k: int = 5, **kwargs) -> List[Tuple[Document, float]]:
        """
        Perform similarity search and fetch ColBERT embeddings for documents.

        Args:
            query: Search query
            k: Number of documents to retrieve
            **kwargs: Additional search parameters (filter, etc.)

        Returns:
            List of (Document, score) tuples with ColBERT embeddings in metadata
        """
        try:
            print(f"🔍 COLBERT RETRIEVAL: Fetching documents with ColBERT embeddings")

            # Run the initial dense retrieval, applying the filter if one was supplied.
            if kwargs.get('filter'):
                result = self.vectorstore.similarity_search_with_score(
                    query,
                    k=k,
                    filter=kwargs['filter']
                )
            else:
                result = self.vectorstore.similarity_search_with_score(query, k=k)

            # Normalize the result into parallel document/score lists.
            if isinstance(result, tuple) and len(result) == 2:
                documents, scores = result
            elif isinstance(result, list):
                documents = []
                scores = []
                for item in result:
                    if isinstance(item, tuple) and len(item) == 2:
                        doc, score = item
                        documents.append(doc)
                        scores.append(score)
                    else:
                        documents.append(item)
                        scores.append(0.0)
            else:
                documents = []
                scores = []

            collection_name = self.vectorstore.collection_name

            # Collect an ID for each document. Documents without an explicit ID
            # in their metadata fall back to an MD5 hash of their content; such
            # hashes only resolve if the collection was indexed with the same scheme.
            doc_ids = []
            for doc in documents:
                doc_id = doc.metadata.get('id') or doc.metadata.get('_id')
                if not doc_id:
                    doc_id = hashlib.md5(doc.page_content.encode()).hexdigest()
                doc_ids.append(doc_id)

            # Re-fetch the points so the payload (including any stored ColBERT
            # embeddings) is available.
            search_result = self.vectorstore.client.retrieve(
                collection_name=collection_name,
                ids=doc_ids,
                with_payload=True,
                with_vectors=False
            )

            enhanced_documents = []
            enhanced_scores = []

            # Map document IDs back to their original similarity scores.
            doc_id_to_score = {}
            for i, doc in enumerate(documents):
                doc_id = doc.metadata.get('id') or doc.metadata.get('_id')
                if not doc_id:
                    doc_id = hashlib.md5(doc.page_content.encode()).hexdigest()
                doc_id_to_score[doc_id] = scores[i]

            for point in search_result:
                payload = point.payload

                doc_id = str(point.id)
                original_score = doc_id_to_score.get(doc_id, 0.0)

                # Rebuild the document with the ColBERT fields lifted into metadata.
                doc = Document(
                    page_content=payload.get('page_content', ''),
                    metadata={
                        **payload.get('metadata', {}),
                        'colbert_embedding': payload.get('colbert_embedding'),
                        'colbert_model': payload.get('colbert_model'),
                        'colbert_calculated_at': payload.get('colbert_calculated_at')
                    }
                )

                enhanced_documents.append(doc)
                enhanced_scores.append(original_score)

            print(f"✅ COLBERT RETRIEVAL: Retrieved {len(enhanced_documents)} documents with ColBERT embeddings")

            return list(zip(enhanced_documents, enhanced_scores))

        except Exception as e:
            print(f"❌ COLBERT RETRIEVAL ERROR: {e}")
            print(f"❌ Falling back to regular similarity search")

            if kwargs.get('filter'):
                return self.vectorstore.similarity_search_with_score(query, k=k, filter=kwargs['filter'])
            else:
                return self.vectorstore.similarity_search_with_score(query, k=k)

    def retrieve_context(
        self,
        query: str,
        k: int = 5,
        reports: Optional[List[str]] = None,
        sources: Optional[List[str]] = None,
        subtype: Optional[str] = None,
        year: Optional[str] = None,
        district: Optional[List[str]] = None,
        filenames: Optional[List[str]] = None,
        use_reranking: bool = False,
        qdrant_filter: Optional[rest.Filter] = None
    ) -> List[Document]:
        """
        Retrieve context documents using hybrid search with optional filtering and reranking.

        Args:
            query: User query
            k: Number of documents to retrieve
            reports: List of report names to filter by
            sources: List of sources to filter by
            subtype: Document subtype to filter by
            year: Year to filter by
            district: List of districts to filter by
            filenames: List of filenames to filter by
            use_reranking: Whether to apply reranking
            qdrant_filter: Pre-built Qdrant filter to use

        Returns:
            List of retrieved documents
        """
        try:
            retrieve_k = k

            # Prefer a pre-built Qdrant filter; otherwise build one from the
            # individual filter arguments.
            search_kwargs = {}
            if qdrant_filter:
                search_kwargs = {"filter": qdrant_filter}
                print(f"✅ FILTERS APPLIED: Using inferred Qdrant filter")
            else:
                filter_obj = create_filter(
                    reports=reports,
                    sources=sources,
                    subtype=subtype,
                    year=year,
                    district=district,
                    filenames=filenames
                )

                if filter_obj:
                    search_kwargs = {"filter": filter_obj}
                    print(f"✅ FILTERS APPLIED: Using built filter")
                else:
                    print(f"⚠️ NO FILTERS APPLIED: All documents will be searched")

            try:
                # For ColBERT reranking, fetch documents together with their
                # stored ColBERT embeddings so reranking can reuse them.
                if use_reranking and self.reranker_type == 'colbert':
                    result = self._similarity_search_with_colbert_embeddings(
                        query,
                        k=retrieve_k,
                        **search_kwargs
                    )
                else:
                    result = self.vectorstore.similarity_search_with_score(
                        query,
                        k=retrieve_k,
                        **search_kwargs
                    )

                # Normalize the result into parallel document/score lists.
                if isinstance(result, tuple) and len(result) == 2:
                    documents, scores = result
                elif isinstance(result, list) and len(result) > 0:
                    documents = []
                    scores = []
                    for item in result:
                        if isinstance(item, tuple) and len(item) == 2:
                            doc, score = item
                            documents.append(doc)
                            scores.append(score)
                        else:
                            documents.append(item)
                            scores.append(0.0)
                else:
                    documents = []
                    scores = []

                print(f"✅ RETRIEVAL SUCCESS: Retrieved {len(documents)} documents (requested: {retrieve_k})")

                # If the filters were too restrictive, retry once without them.
                if len(documents) < retrieve_k and search_kwargs.get('filter'):
                    print(f"⚠️ RETRIEVAL: Got {len(documents)} docs with filters, trying without filters...")
                    try:
                        result_no_filter = self.vectorstore.similarity_search_with_score(
                            query,
                            k=retrieve_k
                        )

                        if isinstance(result_no_filter, tuple) and len(result_no_filter) == 2:
                            documents_no_filter, scores_no_filter = result_no_filter
                        elif isinstance(result_no_filter, list):
                            documents_no_filter = []
                            scores_no_filter = []
                            for item in result_no_filter:
                                if isinstance(item, tuple) and len(item) == 2:
                                    doc, score = item
                                    documents_no_filter.append(doc)
                                    scores_no_filter.append(score)
                                else:
                                    documents_no_filter.append(item)
                                    scores_no_filter.append(0.0)
                        else:
                            documents_no_filter = []
                            scores_no_filter = []

                        if len(documents_no_filter) > len(documents):
                            print(f"✅ RETRIEVAL: Got {len(documents_no_filter)} docs without filters")
                            documents = documents_no_filter
                            scores = scores_no_filter
                    except Exception as e:
                        print(f"⚠️ RETRIEVAL: Fallback search failed: {e}")

            except Exception as e:
                print(f"❌ RETRIEVAL ERROR: {str(e)}")
                return []

            # Optionally rerank before truncating to the requested k.
            reranking_applied = False
            if use_reranking and len(documents) > 1:
                print(f"🔄 RERANKING: Applying {self.reranker_model_name} to {len(documents)} documents...")
                try:
                    original_docs = documents.copy()
                    original_scores = scores.copy()

                    reranked_docs = self._apply_reranking(query, documents, scores)
                    reranking_applied = len(reranked_docs) > 0

                    if reranking_applied:
                        print(f"✅ RERANKING APPLIED: {self.reranker_model_name}")
                        documents = reranked_docs
                    else:
                        print(f"⚠️ RERANKING FAILED: Using original order")
                        documents = original_docs
                        scores = original_scores
                except Exception as e:
                    print(f"❌ RERANKING ERROR: {str(e)}")
                    print(f"⚠️ RERANKING FAILED: Using original order")
                    reranking_applied = False
            elif use_reranking:
                print(f"ℹ️ RERANKING: Skipped (only {len(documents)} document(s) retrieved)")
            else:
                print(f"ℹ️ RERANKING: Skipped (reranking disabled)")

            # Truncate to the requested number of documents.
            documents = documents[:k]
            scores = scores[:k] if scores else [0.0] * len(documents)

            # Stamp retrieval metadata. When reranking was applied, the reranker
            # has already recorded rank and score fields, so don't overwrite them.
            for i, (doc, score) in enumerate(zip(documents, scores)):
                if hasattr(doc, 'metadata') and not reranking_applied:
                    doc.metadata.update({
                        'reranking_applied': False,
                        'reranker_model': None,
                        'original_rank': i + 1,
                        'final_rank': i + 1,
                        'original_score': float(score) if score is not None else 0.0
                    })

            return documents

        except Exception as e:
            print(f"❌ CONTEXT RETRIEVAL ERROR: {str(e)}")
            return []

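    # A minimal call sketch (illustrative; the filter values are placeholders
    # and assume matching fields exist in the indexed metadata):
    #
    #   docs = retriever.retrieve_context(
    #       "flood mitigation measures",
    #       k=5,
    #       year="2021",
    #       sources=["reports"],
    #       use_reranking=True,
    #   )
    #   # docs[0].metadata then carries 'original_score', 'final_rank', etc.
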
    def _apply_reranking(self, query: str, documents: List[Document], scores: List[float]) -> List[Document]:
        """
        Apply reranking to documents using the appropriate reranker.

        Args:
            query: User query
            documents: List of documents to rerank
            scores: Original scores

        Returns:
            Reranked list of documents
        """
        if not self.reranker or len(documents) == 0:
            return documents

        try:
            print(f"🔍 RERANKING METHOD: Starting reranking with {len(documents)} documents")
            print(f"🔍 RERANKING TYPE: {self.reranker_type.upper()}")

            if self.reranker_type == 'colbert':
                return self._apply_colbert_reranking(query, documents, scores)
            else:
                return self._apply_crossencoder_reranking(query, documents, scores)

        except Exception as e:
            print(f"❌ RERANKING ERROR: {str(e)}")
            return documents

    def _apply_crossencoder_reranking(self, query: str, documents: List[Document], scores: List[float]) -> List[Document]:
        """
        Apply reranking using CrossEncoder (BGE and other models).

        Args:
            query: User query
            documents: List of documents to rerank
            scores: Original scores

        Returns:
            Reranked list of documents
        """
        # Build (query, passage) pairs for the cross-encoder.
        pairs = [[query, doc.page_content] for doc in documents]
        print(f"🔍 CROSS-ENCODER: Prepared {len(pairs)} pairs for reranking")

        rerank_scores = self.reranker.predict(pairs)

        # A single pair yields a scalar; normalize to a sequence.
        if not isinstance(rerank_scores, (list, np.ndarray)):
            rerank_scores = [rerank_scores]

        if len(rerank_scores) != len(documents):
            print(f"⚠️ RERANKING WARNING: Expected {len(documents)} scores, got {len(rerank_scores)}")
            return documents

        print(f"🔍 CROSS-ENCODER: Got {len(rerank_scores)} rerank scores")
        print(f"🔍 CROSS-ENCODER SCORES: {rerank_scores[:5]}...")

        # Track each document's original position explicitly so duplicates are
        # attributed correctly (documents.index() would return the first match).
        doc_scores = list(zip(documents, rerank_scores, range(len(documents))))

        # Sort by rerank score, best first.
        doc_scores.sort(key=lambda x: x[1], reverse=True)

        reranked_docs = []
        for i, (doc, rerank_score, original_idx) in enumerate(doc_scores):
            original_score = scores[original_idx] if original_idx < len(scores) else 0.0

            # Copy the document so reranking metadata doesn't mutate the input.
            new_doc = Document(
                page_content=doc.page_content,
                metadata={
                    **doc.metadata,
                    'reranking_applied': True,
                    'reranker_model': self.reranker_model_name,
                    'reranker_type': self.reranker_type,
                    'original_rank': original_idx + 1,
                    'final_rank': i + 1,
                    'original_score': float(original_score),
                    'reranked_score': float(rerank_score)
                }
            )
            reranked_docs.append(new_doc)

        print(f"✅ CROSS-ENCODER: Reranked {len(reranked_docs)} documents")

        return reranked_docs

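    # Cross-encoder scoring in isolation (a minimal sketch; the model name is
    # the module default, not a requirement):
    #
    #   from sentence_transformers import CrossEncoder
    #   model = CrossEncoder("BAAI/bge-reranker-v2-m3")
    #   pair_scores = model.predict([["query", "passage A"], ["query", "passage B"]])
    #   # one relevance score per (query, passage) pair, higher = more relevant
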
    def _apply_colbert_reranking(self, query: str, documents: List[Document], scores: List[float]) -> List[Document]:
        """
        Apply reranking using ColBERT late interaction.

        Args:
            query: User query
            documents: List of documents to rerank
            scores: Original scores

        Returns:
            Reranked list of documents
        """
        return self._colbert_rerank(query, documents, scores)

    def _colbert_rerank(self, query: str, documents: List[Document], scores: List[float]) -> List[Document]:
        """
        ColBERT reranking using late interaction with pre-calculated embeddings support.

        Args:
            query: User query
            documents: List of documents to rerank
            scores: Original scores

        Returns:
            Reranked list of documents
        """
        try:
            print(f"🔍 COLBERT: Starting late interaction reranking with {len(documents)} documents")

            # Split documents into those with stored ColBERT embeddings and
            # those whose embeddings must be computed on the fly.
            pre_calculated_embeddings = []
            documents_without_embeddings = []
            documents_without_indices = []

            for i, doc in enumerate(documents):
                if (hasattr(doc, 'metadata') and
                        doc.metadata.get('colbert_embedding') is not None):
                    colbert_embedding = doc.metadata['colbert_embedding']
                    if isinstance(colbert_embedding, list):
                        colbert_embedding = torch.tensor(colbert_embedding)
                    pre_calculated_embeddings.append(colbert_embedding)
                else:
                    documents_without_embeddings.append(doc)
                    documents_without_indices.append(i)

            # Encode the query once into token-level embeddings.
            query_embeddings = self.colbert_checkpoint.queryFromText([query])

            if documents_without_embeddings:
                print(f"🔄 COLBERT: Calculating embeddings for {len(documents_without_embeddings)} documents without pre-calculated embeddings")
                doc_texts = [doc.page_content for doc in documents_without_embeddings]
                doc_embeddings = self.colbert_checkpoint.docFromText(doc_texts)

                # Re-insert the freshly computed embeddings at their original
                # positions; inserting in ascending index order keeps the list
                # aligned with `documents`.
                for i, embedding in enumerate(doc_embeddings):
                    idx = documents_without_indices[i]
                    pre_calculated_embeddings.insert(idx, embedding)
            else:
                print(f"✅ COLBERT: Using pre-calculated embeddings for all {len(documents)} documents")

            # Late interaction (MaxSim): for each query token take its best
            # match among the document tokens, then sum over query tokens.
            colbert_scores = []
            for doc_embedding in pre_calculated_embeddings:
                # (query_tokens, doc_tokens) similarity matrix.
                sim_matrix = torch.matmul(query_embeddings[0], doc_embedding.transpose(-1, -2))

                # Best document token per query token.
                max_sim_per_query_token = torch.max(sim_matrix, dim=-1)[0]

                # Sum of per-token maxima is the document's relevance score.
                final_score = torch.sum(max_sim_per_query_token).item()
                colbert_scores.append(final_score)

            # Track original positions explicitly so duplicate documents are
            # attributed correctly after sorting.
            doc_scores = list(zip(documents, colbert_scores, range(len(documents))))
            doc_scores.sort(key=lambda x: x[1], reverse=True)

            reranked_docs = []
            for i, (doc, colbert_score, original_idx) in enumerate(doc_scores):
                original_score = scores[original_idx] if original_idx < len(scores) else 0.0

                new_doc = Document(
                    page_content=doc.page_content,
                    metadata={
                        **doc.metadata,
                        'reranking_applied': True,
                        'reranker_model': self.reranker_model_name,
                        'reranker_type': self.reranker_type,
                        'original_rank': original_idx + 1,
                        'final_rank': i + 1,
                        'original_score': float(original_score),
                        'reranked_score': float(colbert_score),
                        'colbert_score': float(colbert_score),
                        'colbert_embedding_pre_calculated': doc.metadata.get('colbert_embedding') is not None
                    }
                )
                reranked_docs.append(new_doc)

            print(f"✅ COLBERT: Reranked {len(reranked_docs)} documents using late interaction")
            print(f"🔍 COLBERT SCORES: {[f'{score:.4f}' for score in colbert_scores[:5]]}...")

            return reranked_docs

        except Exception as e:
            print(f"❌ COLBERT RERANKING ERROR: {str(e)}")
            print(f"❌ COLBERT TRACEBACK: {traceback.format_exc()}")
            return documents

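    # The MaxSim computation above, in isolation (a minimal sketch with dummy
    # tensors; the shapes are illustrative):
    #
    #   import torch
    #   q = torch.randn(32, 128)    # 32 query tokens, 128-dim embeddings
    #   d = torch.randn(300, 128)   # 300 document tokens
    #   sim = q @ d.T               # (32, 300) token-token similarities
    #   score = sim.max(dim=-1).values.sum().item()
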
    def retrieve_with_scores(self, query: str, vectorstore=None, k: int = 5, reports: List[str] = None,
                             sources: List[str] = None, subtype: List[str] = None,
                             year: List[str] = None, use_reranking: bool = False,
                             qdrant_filter: Optional[rest.Filter] = None) -> Tuple[List[Document], List[float]]:
        """
        Retrieve context documents with scores using hybrid search with optional reranking.

        Args:
            query: User query
            vectorstore: Optional vectorstore instance (for compatibility)
            k: Number of documents to retrieve
            reports: List of report names to filter by
            sources: List of sources to filter by
            subtype: List of document subtypes to filter by
            year: List of years to filter by
            use_reranking: Whether to apply reranking
            qdrant_filter: Pre-built Qdrant filter

        Returns:
            Tuple of (documents, scores)
        """
        try:
            if vectorstore:
                self.vectorstore = vectorstore

            # Only the 'vector_only' strategy is currently implemented.
            search_strategy = self.config.get('retrieval', {}).get('search_strategy', 'vector_only')

            if search_strategy == 'vector_only':
                print(f"🔄 VECTOR SEARCH: Retrieving {k} documents...")

                if qdrant_filter:
                    print(f"✅ QDRANT FILTER APPLIED: Using inferred Qdrant filter")
                    results = self.vectorstore.similarity_search_with_score(
                        query,
                        k=k,
                        filter=qdrant_filter
                    )
                else:
                    filter_conditions = self._build_filter_conditions(reports, sources, subtype, year)
                    if filter_conditions:
                        print(f"✅ FILTER APPLIED: {filter_conditions}")
                        results = self.vectorstore.similarity_search_with_score(
                            query,
                            k=k,
                            filter=filter_conditions
                        )
                    else:
                        print(f"ℹ️ NO FILTERS APPLIED: All documents will be searched")
                        results = self.vectorstore.similarity_search_with_score(query, k=k)

                print(f"🔍 SEARCH DEBUG: Raw result type: {type(results)}")
                print(f"🔍 SEARCH DEBUG: Raw result length: {len(results)}")

                # Unpack (document, score) tuples if scores are present.
                if results and isinstance(results[0], tuple):
                    documents = [doc for doc, score in results]
                    scores = [score for doc, score in results]
                    print(f"🔍 SEARCH DEBUG: After unpacking - documents: {len(documents)}, scores: {len(scores)}")
                else:
                    documents = results
                    scores = [0.0] * len(documents)
                    print(f"🔍 SEARCH DEBUG: No scores available, using default")

                print(f"🔧 CONVERTING: Converting {len(documents)} documents")

                # Copy each document so score metadata doesn't mutate the originals.
                final_documents = []
                for i, (doc, score) in enumerate(zip(documents, scores)):
                    if hasattr(doc, 'page_content'):
                        new_doc = Document(
                            page_content=doc.page_content,
                            metadata=doc.metadata.copy()
                        )
                        new_doc.metadata['original_score'] = float(score) if score is not None else 0.0
                        final_documents.append(new_doc)
                    else:
                        print(f"⚠️ WARNING: Document {i} has no page_content")

                print(f"✅ RETRIEVAL SUCCESS: Retrieved {len(final_documents)} documents")

                if use_reranking and len(final_documents) > 1:
                    print(f"🔄 RERANKING: Applying {self.reranker_model_name} to {len(final_documents)} documents...")
                    final_documents = self._apply_reranking(query, final_documents, scores)
                    # Keep the returned scores aligned with the reranked order.
                    scores = [doc.metadata.get('reranked_score', doc.metadata.get('original_score', 0.0))
                              for doc in final_documents]
                    print(f"✅ RERANKING APPLIED: {self.reranker_model_name}")
                else:
                    print(f"ℹ️ RERANKING: Skipped (disabled or no documents)")

                return final_documents, scores

            else:
                print(f"❌ UNSUPPORTED STRATEGY: {search_strategy}")
                return [], []

        except Exception as e:
            print(f"❌ RETRIEVAL ERROR: {e}")
            print(f"❌ RETRIEVAL TRACEBACK: {traceback.format_exc()}")
            return [], []

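    # Strategy selection sketch (illustrative): retrieve_with_scores reads
    # config["retrieval"]["search_strategy"] and currently only supports
    # 'vector_only'.
    #
    #   config = {"retrieval": {"search_strategy": "vector_only"}}
    #   docs, scores = ContextRetriever(vectorstore, config).retrieve_with_scores("query", k=5)
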
    def _build_filter_conditions(self, reports: List[str] = None, sources: List[str] = None,
                                 subtype: List[str] = None, year: List[str] = None) -> Optional[rest.Filter]:
        """
        Build Qdrant filter conditions from individual parameters.

        Args:
            reports: List of report names
            sources: List of sources
            subtype: List of document subtypes
            year: List of years

        Returns:
            Qdrant filter or None
        """
        conditions = []

        if reports:
            conditions.append(rest.FieldCondition(
                key="metadata.filename",
                match=rest.MatchAny(any=reports)
            ))

        if sources:
            conditions.append(rest.FieldCondition(
                key="metadata.source",
                match=rest.MatchAny(any=sources)
            ))

        if subtype:
            conditions.append(rest.FieldCondition(
                key="metadata.subtype",
                match=rest.MatchAny(any=subtype)
            ))

        if year:
            conditions.append(rest.FieldCondition(
                key="metadata.year",
                match=rest.MatchAny(any=year)
            ))

        if conditions:
            return rest.Filter(must=conditions)

        return None

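    # For illustration, reports=["a.pdf"] and year=["2021"] build a filter
    # equivalent to (values are placeholders):
    #
    #   rest.Filter(must=[
    #       rest.FieldCondition(key="metadata.filename", match=rest.MatchAny(any=["a.pdf"])),
    #       rest.FieldCondition(key="metadata.year", match=rest.MatchAny(any=["2021"])),
    #   ])
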
def get_context(
    query: str,
    vectorstore: Qdrant,
    k: int = 5,
    reports: Optional[List[str]] = None,
    sources: Optional[List[str]] = None,
    subtype: Optional[str] = None,
    year: Optional[str] = None,
    use_reranking: bool = False,
    qdrant_filter: Optional[rest.Filter] = None
) -> List[Document]:
    """
    Convenience function to get context documents.

    Args:
        query: User query
        vectorstore: Qdrant vector store instance
        k: Number of documents to retrieve
        reports: Optional list of report names to filter by
        sources: Optional list of source categories to filter by
        subtype: Optional subtype to filter by
        year: Optional year to filter by
        use_reranking: Whether to apply reranking
        qdrant_filter: Optional pre-built Qdrant filter

    Returns:
        List of retrieved documents
    """
    retriever = ContextRetriever(vectorstore)
    return retriever.retrieve_context(
        query=query,
        k=k,
        reports=reports,
        sources=sources,
        subtype=subtype,
        year=year,
        use_reranking=use_reranking,
        qdrant_filter=qdrant_filter
    )

def format_context_for_llm(documents: List[Document]) -> str:
    """
    Format retrieved documents for LLM input.

    Args:
        documents: List of Document objects

    Returns:
        Formatted string for LLM
    """
    if not documents:
        return ""

    formatted_parts = []
    for i, doc in enumerate(documents, 1):
        content = doc.page_content.strip()
        source = doc.metadata.get('filename', 'Unknown')
        formatted_parts.append(f"Document {i} (Source: {source}):\n{content}")

    return "\n\n".join(formatted_parts)

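# Example output shape (illustrative source names):
#
#   Document 1 (Source: report_2021.pdf):
#   <first chunk text>
#
#   Document 2 (Source: survey.pdf):
#   <second chunk text>
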
def get_context_metadata(documents: List[Document]) -> Dict[str, Any]:
    """
    Extract metadata summary from retrieved documents.

    Args:
        documents: List of Document objects

    Returns:
        Dictionary with metadata summary
    """
    if not documents:
        return {}

    sources = set()
    years = set()
    doc_types = set()

    for doc in documents:
        metadata = doc.metadata
        if 'filename' in metadata:
            sources.add(metadata['filename'])
        if 'year' in metadata:
            years.add(metadata['year'])
        if 'source' in metadata:
            doc_types.add(metadata['source'])

    return {
        "num_documents": len(documents),
        "sources": list(sources),
        "years": list(years),
        "document_types": list(doc_types)
    }
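
# A minimal end-to-end sketch (assumes a running Qdrant instance and an existing
# collection; all names and URLs are placeholders):
#
#   from qdrant_client import QdrantClient
#   client = QdrantClient(url="http://localhost:6333")
#   embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
#   vectorstore = Qdrant(client=client, collection_name="reports", embeddings=embeddings)
#   docs = get_context("flood mitigation measures", vectorstore, k=5, use_reranking=True)
#   print(format_context_for_llm(docs))
#   print(get_context_metadata(docs))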