# file: retrieval.py

import asyncio
import numpy as np
from groq import AsyncGroq
from rank_bm25 import BM25Okapi
from sentence_transformers import CrossEncoder
from sklearn.preprocessing import MinMaxScaler
from torch.nn.functional import cosine_similarity
from typing import List, Dict, Tuple

from embedding import EmbeddingClient
from langchain_core.documents import Document
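
# NOTE: `EmbeddingClient` is assumed (from its use below) to expose a `.device`
# attribute and a `create_embeddings(texts: List[str]) -> torch.Tensor` method
# returning an (n_texts, dim) tensor on that device; the cosine-similarity step
# in `_hybrid_search` relies on this tensor contract.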

# --- Configuration ---
HYDE_MODEL = "llama3-8b-8192"
RERANKER_MODEL = "cross-encoder/ms-marco-MiniLM-L6-v2"
INITIAL_K_CANDIDATES = 20
TOP_K_CHUNKS = 10

async def generate_hypothetical_document(query: str, groq_api_key: str) -> str:
    """Generates a hypothetical document (HyDE) to enhance search."""
    if not groq_api_key:
        print("Groq API key not set. Skipping HyDE generation.")
        return ""

    print(f"Starting HyDE generation for query: '{query}'...")
    client = AsyncGroq(api_key=groq_api_key)
    prompt = (
        f"Write a brief, formal passage that answers the following question. "
        f"Use specific terminology as if it were from a larger document. "
        f"Do not include the question or conversational text.\n\n"
        f"Question: {query}\n\n"
        f"Hypothetical Passage:"
    )

    try:
        chat_completion = await client.chat.completions.create(
            messages=[{"role": "user", "content": prompt}],
            model=HYDE_MODEL,
            temperature=0.7,
            max_tokens=500,
        )
        return chat_completion.choices[0].message.content
    except Exception as e:
        print(f"An error occurred during HyDE generation: {e}")
        return ""

class Retriever:
    """Manages hybrid search, combining BM25, dense search, and a reranker."""

    def __init__(self, embedding_client: EmbeddingClient):
        self.embedding_client = embedding_client
        self.reranker = CrossEncoder(RERANKER_MODEL, device=self.embedding_client.device)
        self.bm25 = None
        self.document_chunks = []
        self.chunk_embeddings = None
        print(f"Retriever initialized with reranker '{RERANKER_MODEL}'.")

    def index(self, documents: List[Document]):
        """Builds the search index from document chunks."""
        self.document_chunks = documents
        corpus = [doc.page_content for doc in documents]
        if not corpus:
            print("No documents to index.")
            return

        print("Indexing documents for retrieval...")
        # 1. Initialize BM25 model (naive whitespace tokenization; swap in a
        #    proper tokenizer if casing or punctuation matter for the corpus)
        tokenized_corpus = [doc.split(" ") for doc in corpus]
        self.bm25 = BM25Okapi(tokenized_corpus)
        # 2. Compute and store dense embeddings
        self.chunk_embeddings = self.embedding_client.create_embeddings(corpus)
        print("Indexing complete.")

    def _hybrid_search(self, query: str, hyde_doc: str) -> List[Tuple[int, float]]:
        """Performs the initial hybrid search to get candidate chunks."""
        if self.bm25 is None or self.chunk_embeddings is None:
            raise ValueError("Retriever has not been indexed. Call index() first.")

        # Enhance the query with the hypothetical document. Only the dense leg
        # uses it; BM25 keeps matching the user's literal keywords.
        enhanced_query = f"{query}\n\n{hyde_doc}" if hyde_doc else query

        # BM25 (keyword) search on the raw query
        tokenized_query = query.split(" ")
        bm25_scores = self.bm25.get_scores(tokenized_query)

        # Dense (semantic) search
        query_embedding = self.embedding_client.create_embeddings([enhanced_query])
        dense_scores = cosine_similarity(query_embedding, self.chunk_embeddings).cpu().numpy().flatten()

        # Normalize and combine scores
        scaler = MinMaxScaler()
        norm_bm25 = scaler.fit_transform(bm25_scores.reshape(-1, 1)).flatten()
        norm_dense = scaler.fit_transform(dense_scores.reshape(-1, 1)).flatten()
        combined_scores = 0.5 * norm_bm25 + 0.5 * norm_dense
        
        # Get top initial candidates, casting numpy types to plain Python
        top_indices = np.argsort(combined_scores)[::-1][:INITIAL_K_CANDIDATES]
        return [(int(idx), float(combined_scores[idx])) for idx in top_indices]

    async def _rerank(self, query: str, candidates: List[Dict]) -> List[Dict]:
        """Reranks the candidate chunks using a CrossEncoder model."""
        if not candidates:
            return []

        print(f"Reranking {len(candidates)} candidates...")
        rerank_input = [[query, chunk["content"]] for chunk in candidates]
        
        # Run synchronous prediction in a separate thread
        rerank_scores = await asyncio.to_thread(
            self.reranker.predict, rerank_input, show_progress_bar=False
        )

        # Attach the new scores (as plain floats, for easy serialization)
        for candidate, score in zip(candidates, rerank_scores):
            candidate['rerank_score'] = float(score)
        
        candidates.sort(key=lambda x: x['rerank_score'], reverse=True)
        return candidates[:TOP_K_CHUNKS]

    async def retrieve(self, query: str, hyde_doc: str) -> List[Dict]:
        """Executes the full retrieval pipeline: hybrid search followed by reranking."""
        print(f"Retrieving documents for query: '{query}'")
        # 1. Get initial candidates from hybrid search
        initial_candidates_info = self._hybrid_search(query, hyde_doc)
        
        retrieved_candidates = [{
            "content": self.document_chunks[idx].page_content,
            "metadata": self.document_chunks[idx].metadata,
            "initial_score": score
        } for idx, score in initial_candidates_info]

        # 2. Rerank the candidates to get the final list
        final_chunks = await self._rerank(query, retrieved_candidates)
        print(f"Retrieved and reranked {len(final_chunks)} final chunks.")
        return final_chunks
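
# --- Usage sketch (illustrative, not part of the pipeline) ---
# A minimal end-to-end run. It assumes `EmbeddingClient` can be constructed
# with no arguments and that a GROQ_API_KEY environment variable is set;
# both are assumptions about the surrounding project, not guarantees.
if __name__ == "__main__":
    import os

    async def _demo():
        retriever = Retriever(EmbeddingClient())
        retriever.index([
            Document(page_content="BM25 ranks documents by keyword overlap with the query.",
                     metadata={"source": "notes.md"}),
            Document(page_content="Dense retrieval compares query and chunk embeddings in vector space.",
                     metadata={"source": "notes.md"}),
        ])
        query = "How does hybrid search combine keyword and semantic scores?"
        hyde_doc = await generate_hypothetical_document(query, os.environ.get("GROQ_API_KEY", ""))
        for chunk in await retriever.retrieve(query, hyde_doc):
            print(f"{chunk['rerank_score']:.3f}  {chunk['content'][:70]}")

    asyncio.run(_demo())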