File size: 3,425 Bytes
2224798
37cbdf5
 
2224798
37cbdf5
 
 
a4332b3
37cbdf5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
069addf
 
 
 
2224798
 
37cbdf5
 
2224798
 
 
37cbdf5
 
 
 
 
 
 
 
 
 
a4332b3
 
 
 
37cbdf5
 
 
 
 
 
 
 
 
a4332b3
 
37cbdf5
 
 
 
 
2224798
37cbdf5
 
 
 
 
 
e1d5b50
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
import time
from typing import List

import logfire
from llama_index.core import QueryBundle
from llama_index.core.retrievers import BaseRetriever, VectorIndexRetriever
from llama_index.core.schema import NodeWithScore, TextNode
from llama_index.postprocessor.cohere_rerank import CohereRerank


class CustomRetriever(BaseRetriever):
    """Custom retriever that performs semantic search, deduplicates chunks by
    source document, optionally swaps a chunk for its full document, and
    reranks the result with Cohere."""

    def __init__(
        self,
        vector_retriever: VectorIndexRetriever,
        document_dict: dict,
        similarity_cutoff: float = 0.2,
        rerank_cutoff: float = 0.10,
        rerank_top_n: int = 5,
    ) -> None:
        """Init params.

        Args:
            vector_retriever: Underlying vector-index retriever used for the
                initial semantic search.
            document_dict: Maps ``ref_doc_id`` -> full source document; used to
                replace a chunk with its whole document when the chunk's
                metadata sets ``retrieve_doc``.
            similarity_cutoff: Nodes scoring below this after vector retrieval
                are dropped (default preserves the original hard-coded 0.2).
            rerank_cutoff: Nodes scoring below this after Cohere reranking are
                dropped (default preserves the original hard-coded 0.10).
            rerank_top_n: Number of nodes Cohere keeps (default preserves the
                original hard-coded 5).
        """

        self._vector_retriever = vector_retriever
        self._document_dict = document_dict
        self._similarity_cutoff = similarity_cutoff
        self._rerank_cutoff = rerank_cutoff
        self._rerank_top_n = rerank_top_n
        super().__init__()

    @staticmethod
    def _filter_nodes_by_unique_doc_id(
        nodes: List[NodeWithScore],
    ) -> List[NodeWithScore]:
        """Keep only the first (highest-ranked) node per ``ref_doc_id``;
        nodes without a ``ref_doc_id`` are dropped, as in the original."""
        unique_nodes: dict = {}
        for node in nodes:
            doc_id = node.node.ref_doc_id
            if doc_id is not None and doc_id not in unique_nodes:
                unique_nodes[doc_id] = node
        return list(unique_nodes.values())

    def _retrieve(self, query_bundle: QueryBundle) -> List[NodeWithScore]:
        """Retrieve nodes given query."""

        # LlamaIndex adds "\ninput is " to the query string; strip it off.
        query_bundle.query_str = query_bundle.query_str.replace("\ninput is ", "")
        query_bundle.query_str = query_bundle.query_str.rstrip()

        # FIX: log the actual query string (not the QueryBundle repr) and do
        # not hard-code "10" — the real count is the vector retriever's top-k.
        logfire.info(f"Retrieving nodes with string: '{query_bundle.query_str}'")
        start = time.time()
        nodes = self._vector_retriever.retrieve(query_bundle)

        duration = time.time() - start
        logfire.info(f"Retrieving nodes took {duration:.2f}s")

        # Collapse multiple chunks coming from the same source document.
        nodes = self._filter_nodes_by_unique_doc_id(nodes)
        logfire.info(
            f"Number of nodes after filtering the ones with same ref_doc_id: {len(nodes)}"
        )
        logfire.info(f"Nodes retrieved: {nodes}")

        nodes_context: List[NodeWithScore] = []
        for node in nodes:
            # FIX: NodeWithScore.score is Optional; the original
            # `node.score < 0.2` raised TypeError on a None score.
            if node.score is None or node.score < self._similarity_cutoff:
                continue
            # FIX: .get() avoids KeyError when "retrieve_doc" is absent, and
            # replaces the non-idiomatic `== True` comparison.
            if node.metadata.get("retrieve_doc"):
                # Replace the chunk by its full source document.
                doc = self._document_dict[node.node.ref_doc_id]
                new_node = NodeWithScore(
                    node=TextNode(text=doc.text, metadata=node.metadata),  # type: ignore
                    score=node.score,
                )
                nodes_context.append(new_node)
            else:
                nodes_context.append(node)

        try:
            reranker = CohereRerank(
                top_n=self._rerank_top_n, model="rerank-english-v3.0"
            )
            nodes_context = reranker.postprocess_nodes(nodes_context, query_bundle)
            nodes_filtered = [
                node
                for node in nodes_context
                if node.score is not None and node.score >= self._rerank_cutoff  # type: ignore
            ]
            # FIX: typo "raranking" -> "reranking" in the log message.
            logfire.info(f"Cohere reranking to {len(nodes_filtered)} nodes")

            return nodes_filtered
        except Exception as e:
            # Best-effort fallback: if Cohere fails (missing API key, network
            # error), return the un-reranked nodes rather than crashing.
            logfire.error(f"Error reranking nodes with Cohere: {e}")
            return nodes_context