"""Hybrid retriever combining vector and keyword search with Cohere reranking."""

import asyncio
import time
import traceback
from typing import List, Optional

import logfire
import tiktoken
from cohere import AsyncClient
from llama_index.core import QueryBundle
from llama_index.core.callbacks import CBEventType, EventPayload
from llama_index.core.retrievers import BaseRetriever, VectorIndexRetriever
from llama_index.core.schema import MetadataMode, NodeWithScore, TextNode
from llama_index.postprocessor.cohere_rerank import CohereRerank


class AsyncCohereRerank(CohereRerank):
    """CohereRerank subclass that adds an async `apostprocess_nodes` method."""

    def __init__(
        self,
        top_n: int = 5,
        model: str = "rerank-english-v3.0",
        api_key: Optional[str] = None,
    ) -> None:
        super().__init__(top_n=top_n, model=model, api_key=api_key)
        self._api_key = api_key
        self._model = model
        self._top_n = top_n

    async def apostprocess_nodes(
        self,
        nodes: List[NodeWithScore],
        query_bundle: Optional[QueryBundle] = None,
    ) -> List[NodeWithScore]:
        """Asynchronously rerank `nodes` against the query via the Cohere API."""
        if query_bundle is None:
            raise ValueError("Query bundle must be provided.")

        if len(nodes) == 0:
            return []

        async_client = AsyncClient(api_key=self._api_key)

        with self.callback_manager.event(
            CBEventType.RERANKING,
            payload={
                EventPayload.NODES: nodes,
                EventPayload.MODEL_NAME: self._model,
                EventPayload.QUERY_STR: query_bundle.query_str,
                EventPayload.TOP_K: self._top_n,
            },
        ) as event:
            texts = [
                node.node.get_content(metadata_mode=MetadataMode.EMBED)
                for node in nodes
            ]

            results = await async_client.rerank(
                model=self._model,
                top_n=self._top_n,
                query=query_bundle.query_str,
                documents=texts,
            )

            # Map the reranker's results back onto the original nodes, attaching
            # the relevance scores it assigned.
            new_nodes = []
            for result in results.results:
                new_node_with_score = NodeWithScore(
                    node=nodes[result.index].node, score=result.relevance_score
                )
                new_nodes.append(new_node_with_score)
            event.on_end(payload={EventPayload.NODES: new_nodes})

        return new_nodes
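

# Usage sketch for AsyncCohereRerank (illustrative only; assumes `nodes` is a
# populated List[NodeWithScore] and that a Cohere API key is available in the
# environment):
#
#     reranker = AsyncCohereRerank(top_n=3, api_key=os.environ["COHERE_API_KEY"])
#     reranked = await reranker.apostprocess_nodes(
#         nodes, QueryBundle(query_str="What is a vector index?")
#     )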


class CustomRetriever(BaseRetriever):
    """Custom retriever that combines semantic (vector) search and keyword
    search into a single hybrid retriever, then reranks the merged results."""

    def __init__(
        self,
        vector_retriever: VectorIndexRetriever,
        document_dict: dict,
        keyword_retriever: BaseRetriever,
        mode: str = "AND",
    ) -> None:
        """Init params."""
        self._vector_retriever = vector_retriever
        self._document_dict = document_dict
        self._keyword_retriever = keyword_retriever
        if mode not in ("AND", "OR"):
            raise ValueError("Invalid mode.")
        self._mode = mode

        super().__init__()

    async def _aretrieve(self, query_bundle: QueryBundle) -> List[NodeWithScore]:
        """Retrieve nodes given query."""
        query_bundle.query_str = query_bundle.query_str.replace("\ninput is ", "")
        query_bundle.query_str = query_bundle.query_str.rstrip()
        logfire.info(f"Retrieving nodes with string: '{query_bundle.query_str}'")

        start = time.time()
        # The two retrievers are independent, so run them concurrently.
        nodes, keyword_nodes = await asyncio.gather(
            self._vector_retriever.aretrieve(query_bundle),
            self._keyword_retriever.aretrieve(query_bundle),
        )

        logfire.info(f"Number of vector nodes: {len(nodes)}")
        logfire.info(f"Number of keyword nodes: {len(keyword_nodes)}")

        vector_ids = {n.node.node_id for n in nodes}
        keyword_ids = {n.node.node_id for n in keyword_nodes}

        combined_dict = {n.node.node_id: n for n in nodes}
        combined_dict.update({n.node.node_id: n for n in keyword_nodes})

        # "AND" keeps only nodes found by both retrievers; "OR" keeps the union.
        if self._mode == "AND":
            retrieve_ids = vector_ids.intersection(keyword_ids)
        else:
            retrieve_ids = vector_ids.union(keyword_ids)

        nodes = [combined_dict[rid] for rid in retrieve_ids]

        def filter_nodes_by_unique_doc_id(nodes):
            """Keep a single chunk per source document."""
            unique_nodes = {}
            for node in nodes:
                doc_id = node.node.source_node.node_id
                if doc_id is not None and doc_id not in unique_nodes:
                    unique_nodes[doc_id] = node
            return list(unique_nodes.values())

        nodes = filter_nodes_by_unique_doc_id(nodes)
        logfire.info(
            f"Number of nodes after filtering the ones with same ref_doc_id: {len(nodes)}"
        )

        nodes_context = []
        for node in nodes:
            doc_id = node.node.source_node.node_id
            # If this chunk's metadata flags full-document retrieval, swap in the
            # whole source document; otherwise keep the chunk itself.
            if node.metadata.get("retrieve_doc", False):
                doc = self._document_dict[doc_id]
                new_node = NodeWithScore(
                    node=TextNode(text=doc.text, metadata=node.metadata, id_=doc_id),
                    score=node.score,
                )
                nodes_context.append(new_node)
            else:
                node.node.node_id = doc_id
                nodes_context.append(node)

        try:
            reranker = AsyncCohereRerank(top_n=3, model="rerank-english-v3.0")
            nodes_context = await reranker.apostprocess_nodes(
                nodes_context, query_bundle
            )
        except Exception as e:
            error_msg = f"Error during reranking: {type(e).__name__}: {e}\n"
            error_msg += "Traceback:\n"
            error_msg += traceback.format_exc()
            logfire.error(error_msg)

        nodes_filtered = []
        total_tokens = 0
        enc = tiktoken.encoding_for_model("gpt-4o-mini")
        for node in nodes_context:
            # Drop low-relevance nodes (score below 0.10).
            if node.score < 0.10:
                continue

            # Prefer a precomputed token count from metadata; fall back to encoding.
            if "tokens" in node.node.metadata:
                node_tokens = node.node.metadata["tokens"]
            else:
                node_tokens = len(enc.encode(node.node.text))

            if total_tokens + node_tokens > 100_000:
                logfire.info("Skipping node due to token count exceeding 100k")
                break

            total_tokens += node_tokens
            nodes_filtered.append(node)

        logfire.info(f"Final nodes to context {len(nodes_filtered)} nodes")
        logfire.info(f"Total tokens: {total_tokens}")

        duration = time.time() - start
        logfire.info(f"Retrieving nodes took {duration:.2f}s")

        return nodes_filtered[:3]

    def _retrieve(self, query_bundle: QueryBundle) -> List[NodeWithScore]:
        """Retrieve nodes given query (synchronous version of `_aretrieve`)."""
        query_bundle.query_str = query_bundle.query_str.replace("\ninput is ", "")
        query_bundle.query_str = query_bundle.query_str.rstrip()
        logfire.info(f"Retrieving nodes with string: '{query_bundle.query_str}'")

        start = time.time()
        nodes = self._vector_retriever.retrieve(query_bundle)
        keyword_nodes = self._keyword_retriever.retrieve(query_bundle)
logfire.info(f"Number of vector nodes: {len(nodes)}") |
|
logfire.info(f"Number of keyword nodes: {len(keyword_nodes)}") |
|
|
|
vector_ids = {n.node.node_id for n in nodes} |
|
keyword_ids = {n.node.node_id for n in keyword_nodes} |
|
|
|
combined_dict = {n.node.node_id: n for n in nodes} |
|
combined_dict.update({n.node.node_id: n for n in keyword_nodes}) |
|
|
|
if self._mode == "AND": |
|
retrieve_ids = vector_ids.intersection(keyword_ids) |
|
else: |
|
retrieve_ids = vector_ids.union(keyword_ids) |
|
|
|
nodes = [combined_dict[rid] for rid in retrieve_ids] |
|
|
|
def filter_nodes_by_unique_doc_id(nodes): |
|
unique_nodes = {} |
|
for node in nodes: |
|
|
|
doc_id = node.node.source_node.node_id |
|
if doc_id is not None and doc_id not in unique_nodes: |
|
unique_nodes[doc_id] = node |
|
return list(unique_nodes.values()) |
|
|
|
nodes = filter_nodes_by_unique_doc_id(nodes) |
|
logfire.info( |
|
f"Number of nodes after filtering the ones with same ref_doc_id: {len(nodes)}" |
|
) |
|
logfire.info(f"Nodes retrieved: {nodes}") |
|
|
|

        nodes_context = []
        for node in nodes:
            doc_id = node.node.source_node.node_id
            if node.metadata.get("retrieve_doc", False):
                doc = self._document_dict[doc_id]
                new_node = NodeWithScore(
                    node=TextNode(text=doc.text, metadata=node.metadata, id_=doc_id),
                    score=node.score,
                )
                nodes_context.append(new_node)
            else:
                node.node.node_id = doc_id
                nodes_context.append(node)

        try:
            reranker = CohereRerank(top_n=3, model="rerank-english-v3.0")
            nodes_context = reranker.postprocess_nodes(nodes_context, query_bundle)
        except Exception as e:
            error_msg = f"Error during reranking: {type(e).__name__}: {e}\n"
            error_msg += "Traceback:\n"
            error_msg += traceback.format_exc()
            logfire.error(error_msg)

        nodes_filtered = []
        total_tokens = 0
        enc = tiktoken.encoding_for_model("gpt-4o-mini")
        for node in nodes_context:
            if node.score < 0.10:
                continue
            if "tokens" in node.node.metadata:
                node_tokens = node.node.metadata["tokens"]
            else:
                node_tokens = len(enc.encode(node.node.text))

            if total_tokens + node_tokens > 100_000:
                logfire.info("Skipping node due to token count exceeding 100k")
                break

            total_tokens += node_tokens
            nodes_filtered.append(node)

        logfire.info(f"Final nodes to context {len(nodes_filtered)} nodes")
        logfire.info(f"Total tokens: {total_tokens}")

        duration = time.time() - start
        logfire.info(f"Retrieving nodes took {duration:.2f}s")

        return nodes_filtered[:3]
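

# Usage sketch for CustomRetriever (illustrative only; assumes `vector_retriever`,
# `keyword_retriever`, and `document_dict` were built elsewhere from an existing
# index):
#
#     retriever = CustomRetriever(
#         vector_retriever=vector_retriever,
#         document_dict=document_dict,
#         keyword_retriever=keyword_retriever,
#         mode="OR",
#     )
#     top_nodes = retriever.retrieve("How do I build a vector index?")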