from typing import TYPE_CHECKING, Literal

from injector import inject, singleton
from llama_index.core.indices import VectorStoreIndex
from llama_index.core.schema import NodeWithScore
from llama_index.core.storage import StorageContext
from pydantic import BaseModel, Field

from private_gpt.components.embedding.embedding_component import EmbeddingComponent
from private_gpt.components.llm.llm_component import LLMComponent
from private_gpt.components.node_store.node_store_component import NodeStoreComponent
from private_gpt.components.vector_store.vector_store_component import (
    VectorStoreComponent,
)
from private_gpt.open_ai.extensions.context_filter import ContextFilter
from private_gpt.server.ingest.model import IngestedDoc

if TYPE_CHECKING:
    from llama_index.core.schema import RelatedNodeInfo


class Chunk(BaseModel):
    """A retrieved chunk of context, optionally enriched with sibling texts."""

    object: Literal["context.chunk"]
    score: float = Field(examples=[0.023])
    document: IngestedDoc
    text: str = Field(examples=["Outbound sales increased 20%, driven by new leads."])
    previous_texts: list[str] | None = Field(
        default=None,
        examples=[["SALES REPORT 2023", "Inbound didn't show major changes."]],
    )
    next_texts: list[str] | None = Field(
        default=None,
        examples=[
            [
                "New leads came from Google Ads campaign.",
                "The campaign was run by the Marketing Department",
            ]
        ],
    )

    @classmethod
    def from_node(cls: type["Chunk"], node: NodeWithScore) -> "Chunk":
        # Fall back to "-" when the node is not linked to a source document.
        doc_id = node.node.ref_doc_id if node.node.ref_doc_id is not None else "-"
        return cls(
            object="context.chunk",
            score=node.score or 0.0,
            document=IngestedDoc(
                object="ingest.document",
                doc_id=doc_id,
                doc_metadata=node.metadata,
            ),
            text=node.get_content(),
        )


@singleton
class ChunksService:
    @inject
    def __init__(
        self,
        llm_component: LLMComponent,
        vector_store_component: VectorStoreComponent,
        embedding_component: EmbeddingComponent,
        node_store_component: NodeStoreComponent,
    ) -> None:
        self.vector_store_component = vector_store_component
        self.llm_component = llm_component
        self.embedding_component = embedding_component
        self.storage_context = StorageContext.from_defaults(
            vector_store=vector_store_component.vector_store,
            docstore=node_store_component.doc_store,
            index_store=node_store_component.index_store,
        )

    def _get_sibling_nodes_text(
        self, node_with_score: NodeWithScore, related_number: int, forward: bool = True
    ) -> list[str]:
        # Walk the node's next/prev relationships up to `related_number` hops,
        # collecting the text of each sibling found along the way.
        explored_nodes_texts = []
        current_node = node_with_score.node
        for _ in range(related_number):
            explored_node_info: RelatedNodeInfo | None = (
                current_node.next_node if forward else current_node.prev_node
            )
            if explored_node_info is None:
                break

            explored_node = self.storage_context.docstore.get_node(
                explored_node_info.node_id
            )

            explored_nodes_texts.append(explored_node.get_content())
            current_node = explored_node

        return explored_nodes_texts

    def retrieve_relevant(
        self,
        text: str,
        context_filter: ContextFilter | None = None,
        limit: int = 10,
        prev_next_chunks: int = 0,
    ) -> list[Chunk]:
        """Return the `limit` most relevant chunks for `text`, each enriched
        with up to `prev_next_chunks` preceding and following sibling texts."""
        index = VectorStoreIndex.from_vector_store(
            self.vector_store_component.vector_store,
            storage_context=self.storage_context,
            llm=self.llm_component.llm,
            embed_model=self.embedding_component.embedding_model,
            show_progress=True,
        )
        vector_index_retriever = self.vector_store_component.get_retriever(
            index=index, context_filter=context_filter, similarity_top_k=limit
        )
        nodes = vector_index_retriever.retrieve(text)
        nodes.sort(key=lambda n: n.score or 0.0, reverse=True)

        retrieved_nodes = []
        for node in nodes:
            chunk = Chunk.from_node(node)
            chunk.previous_texts = self._get_sibling_nodes_text(
                node, prev_next_chunks, False
            )
            chunk.next_texts = self._get_sibling_nodes_text(node, prev_next_chunks)
            retrieved_nodes.append(chunk)

        return retrieved_nodes
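
# Usage sketch (illustrative, not part of this module): assuming a fully
# configured private-gpt application, the service is resolved through the
# app's injector rather than constructed directly. `global_injector` refers
# to the instance defined in `private_gpt.di`; adapt to however your app
# wires its dependencies.
#
#     from private_gpt.di import global_injector
#
#     chunks_service = global_injector.get(ChunksService)
#     chunks = chunks_service.retrieve_relevant(
#         "What drove outbound sales growth?", limit=5, prev_next_chunks=1
#     )
#     for chunk in chunks:
#         print(f"{chunk.score:.3f} {chunk.text[:80]}")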