import numpy as np
import torch
from rank_bm25 import BM25Okapi
from sentence_transformers import SentenceTransformer

from chunker_final import chunk_documents_to_dict


class Retriever:
    """Retrieve document chunks by BM25, SBERT cosine similarity, or a hybrid of both.

    At construction the documents are split into chunks with
    ``chunk_documents_to_dict``; every chunk is indexed for BM25 and embedded
    once with SBERT so queries only need to embed the query text.
    """

    def __init__(self, docs: dict) -> None:
        # chunk id -> chunk text; list position is the index used by all score vectors.
        self.chunked_docs = chunk_documents_to_dict(docs)
        self.chunk_ids = list(self.chunked_docs.keys())
        self.chunk_texts = list(self.chunked_docs.values())

        tokenized_chunks = [self._tokenize(text) for text in self.chunk_texts]
        self.bm25 = BM25Okapi(tokenized_chunks)

        self.sbert = SentenceTransformer('sentence-transformers/all-distilroberta-v1')
        self.doc_embeddings = self.sbert.encode(self.chunk_texts)

    def get_docs(self, query, method, n=15) -> dict:
        """Return the top ``n`` chunks for ``query``.

        Args:
            query: Free-text query.
            method: One of ``"BM25"``, ``"semantic"`` or ``"combined search"``.
            n: Maximum number of chunks to return.

        Returns:
            Dict mapping chunk id -> chunk text, ordered by descending score.

        Raises:
            ValueError: If ``method`` is not one of the supported names.
        """
        if method == "BM25":
            scores = self._get_bm25_scores(query)
        elif method == "semantic":
            scores = self._get_semantic_scores(query)
        elif method == "combined search":
            # BM25 scores are unbounded while cosine similarity lies in [-1, 1];
            # put both on a common [0, 1] scale so the 0.3 / 0.7 weights are meaningful
            # instead of letting raw BM25 magnitudes dominate the sum.
            bm25_scores = self._min_max_normalize(self._get_bm25_scores(query))
            semantic_scores = self._min_max_normalize(self._get_semantic_scores(query))
            scores = 0.3 * bm25_scores + 0.7 * semantic_scores
        else:
            raise ValueError(f"Invalid search method: {method}")

        # Convert to Python ints so the lists are indexed with plain integers.
        top_indices = scores.argsort(descending=True)[:n].tolist()
        return {self.chunk_ids[i]: self.chunk_texts[i] for i in top_indices}

    def rerank(self, query, retrieved_docs: dict) -> dict:
        """Reorder ``retrieved_docs`` by SBERT cosine similarity to ``query``.

        Args:
            query: Free-text query.
            retrieved_docs: Dict mapping chunk id -> chunk text (e.g. from ``get_docs``).

        Returns:
            A new dict with the same items, ordered by descending similarity.
            Ties keep their input order.
        """
        if not retrieved_docs:
            return {}

        chunk_ids = list(retrieved_docs)
        query_embedding = self.sbert.encode(query)
        # One batched encode call instead of one model call per chunk.
        chunk_embeddings = self.sbert.encode([retrieved_docs[cid] for cid in chunk_ids])
        similarities = self._cosine_similarity(chunk_embeddings, query_embedding)

        ranked = sorted(zip(chunk_ids, similarities), key=lambda pair: pair[1], reverse=True)
        return {chunk_id: retrieved_docs[chunk_id] for chunk_id, _ in ranked}

    def _get_bm25_scores(self, query):
        """Return BM25 scores of ``query`` against every chunk, as a 1-D tensor."""
        return torch.tensor(self.bm25.get_scores(self._tokenize(query)))

    def _get_semantic_scores(self, query):
        """Return cosine similarity of ``query`` to every chunk embedding, as a 1-D tensor."""
        query_embedding = self.sbert.encode(query)
        return torch.tensor(self._cosine_similarity(self.doc_embeddings, query_embedding))

    @staticmethod
    def _tokenize(text: str) -> list:
        """Lower-case ``text`` and split it into BM25 terms.

        ``split()`` with no argument splits on any whitespace run and never
        yields empty strings, unlike ``split(" ")``, which produced "" tokens on
        repeated spaces and glued words across newlines/tabs.
        """
        return text.lower().split()

    @staticmethod
    def _cosine_similarity(matrix: np.ndarray, vector: np.ndarray) -> np.ndarray:
        """Cosine similarity between each row of 2-D ``matrix`` and 1-D ``vector``."""
        return np.dot(matrix, vector) / (
            np.linalg.norm(matrix, axis=1) * np.linalg.norm(vector)
        )

    @staticmethod
    def _min_max_normalize(scores: torch.Tensor) -> torch.Tensor:
        """Linearly rescale ``scores`` to [0, 1]; a constant vector maps to all zeros."""
        if scores.numel() == 0:
            return scores
        low, high = scores.min(), scores.max()
        span = high - low
        if span == 0:
            return torch.zeros_like(scores)
        return (scores - low) / span