Spaces:
Sleeping
Sleeping
import time | |
from typing import List | |
import logfire | |
from llama_index.core import QueryBundle | |
from llama_index.core.retrievers import BaseRetriever, VectorIndexRetriever | |
from llama_index.core.schema import NodeWithScore, TextNode | |
from llama_index.postprocessor.cohere_rerank import CohereRerank | |
class CustomRetriever(BaseRetriever): | |
"""Custom retriever that performs both semantic search and hybrid search.""" | |
def __init__( | |
self, | |
vector_retriever: VectorIndexRetriever, | |
document_dict: dict, | |
) -> None: | |
"""Init params.""" | |
self._vector_retriever = vector_retriever | |
self._document_dict = document_dict | |
super().__init__() | |
def _retrieve(self, query_bundle: QueryBundle) -> List[NodeWithScore]: | |
"""Retrieve nodes given query.""" | |
# LlamaIndex adds "\ninput is " to the query string | |
query_bundle.query_str = query_bundle.query_str.replace("\ninput is ", "") | |
query_bundle.query_str = query_bundle.query_str.rstrip() | |
logfire.info(f"Retrieving 10 nodes with string: '{query_bundle}'") | |
start = time.time() | |
nodes = self._vector_retriever.retrieve(query_bundle) | |
duration = time.time() - start | |
logfire.info(f"Retrieving nodes took {duration:.2f}s") | |
# Filter out nodes with the same ref_doc_id | |
def filter_nodes_by_unique_doc_id(nodes): | |
unique_nodes = {} | |
for node in nodes: | |
doc_id = node.node.ref_doc_id | |
if doc_id is not None and doc_id not in unique_nodes: | |
unique_nodes[doc_id] = node | |
return list(unique_nodes.values()) | |
nodes = filter_nodes_by_unique_doc_id(nodes) | |
logfire.info( | |
f"Number of nodes after filtering the ones with same ref_doc_id: {len(nodes)}" | |
) | |
logfire.info(f"Nodes retrieved: {nodes}") | |
nodes_context = [] | |
for node in nodes: | |
# print("Node ID\t", node.node_id) | |
# print("Title\t", node.metadata["title"]) | |
# print("Text\t", node.text) | |
# print("Score\t", node.score) | |
# print("Metadata\t", node.metadata) | |
# print("-_" * 20) | |
if node.score < 0.2: | |
continue | |
if node.metadata["retrieve_doc"] == True: | |
# print("This node will be replaced by the document") | |
doc = self._document_dict[node.node.ref_doc_id] | |
# print(doc.text) | |
new_node = NodeWithScore( | |
node=TextNode(text=doc.text, metadata=node.metadata), # type: ignore | |
score=node.score, | |
) | |
nodes_context.append(new_node) | |
else: | |
nodes_context.append(node) | |
reranker = CohereRerank(top_n=5, model="rerank-english-v3.0") | |
nodes_context = reranker.postprocess_nodes(nodes_context, query_bundle) | |
nodes_filtered = [] | |
for node in nodes_context: | |
if node.score < 0.15: | |
continue | |
else: | |
nodes_filtered.append(node) | |
logfire.info(f"Cohere raranking to {len(nodes_filtered)} nodes") | |
return nodes_filtered | |