"""Service and schema for retrieving the context chunks most relevant to a query."""
from typing import TYPE_CHECKING, Literal

from injector import inject, singleton
from llama_index.core.indices import VectorStoreIndex
from llama_index.core.schema import NodeWithScore
from llama_index.core.storage import StorageContext
from pydantic import BaseModel, Field

from private_gpt.components.embedding.embedding_component import EmbeddingComponent
from private_gpt.components.llm.llm_component import LLMComponent
from private_gpt.components.node_store.node_store_component import NodeStoreComponent
from private_gpt.components.vector_store.vector_store_component import (
    VectorStoreComponent,
)
from private_gpt.open_ai.extensions.context_filter import ContextFilter
from private_gpt.server.ingest.model import IngestedDoc

if TYPE_CHECKING:
    from llama_index.core.schema import RelatedNodeInfo
class Chunk(BaseModel):
    """A scored piece of context retrieved for a query.

    Wraps the chunk text, its retrieval score, the document it came from,
    and (optionally) the texts of neighboring chunks in the same document.
    """

    # Discriminator for API responses, always the literal "context.chunk".
    object: Literal["context.chunk"]
    score: float = Field(examples=[0.023])
    document: IngestedDoc
    text: str = Field(examples=["Outbound sales increased 20%, driven by new leads."])
    # Texts of the chunks preceding this one, filled in by the caller.
    previous_texts: list[str] | None = Field(
        default=None,
        examples=[["SALES REPORT 2023", "Inbound didn't show major changes."]],
    )
    # Texts of the chunks following this one, filled in by the caller.
    next_texts: list[str] | None = Field(
        default=None,
        examples=[
            [
                "New leads came from Google Ads campaign.",
                "The campaign was run by the Marketing Department",
            ]
        ],
    )

    @classmethod
    def from_node(cls: type["Chunk"], node: NodeWithScore) -> "Chunk":
        """Build a Chunk from a retrieved llama-index node.

        Fix: the original lacked the `@classmethod` decorator, so
        `Chunk.from_node(node)` bound the node to `cls` and raised a
        missing-argument TypeError.
        """
        # Nodes detached from a source document get the placeholder id "-".
        doc_id = node.node.ref_doc_id if node.node.ref_doc_id is not None else "-"
        return cls(
            object="context.chunk",
            score=node.score or 0.0,  # score may be None on some retrievers
            document=IngestedDoc(
                object="ingest.document",
                doc_id=doc_id,
                doc_metadata=node.metadata,
            ),
            text=node.get_content(),
        )
class ChunksService:
    """Retrieves the chunks most relevant to a text query from the vector store.

    NOTE(review): the file imports `inject`/`singleton` but this class carries
    no DI decorators — presumably stripped during extraction; confirm against
    the original project before relying on injector registration.
    """

    def __init__(
        self,
        llm_component: LLMComponent,
        vector_store_component: VectorStoreComponent,
        embedding_component: EmbeddingComponent,
        node_store_component: NodeStoreComponent,
    ) -> None:
        self.llm_component = llm_component
        self.vector_store_component = vector_store_component
        self.embedding_component = embedding_component
        # Shared storage context: vectors, raw nodes, and index metadata.
        self.storage_context = StorageContext.from_defaults(
            vector_store=vector_store_component.vector_store,
            docstore=node_store_component.doc_store,
            index_store=node_store_component.index_store,
        )

    def _get_sibling_nodes_text(
        self, node_with_score: NodeWithScore, related_number: int, forward: bool = True
    ) -> list[str]:
        """Collect the texts of up to `related_number` sibling chunks.

        Follows next-node links when `forward` is True, previous-node links
        otherwise, stopping early when the chain ends.
        """
        texts: list[str] = []
        node = node_with_score.node
        while len(texts) < related_number:
            link: RelatedNodeInfo | None = node.next_node if forward else node.prev_node
            if link is None:
                break  # reached the document boundary
            node = self.storage_context.docstore.get_node(link.node_id)
            texts.append(node.get_content())
        return texts

    def retrieve_relevant(
        self,
        text: str,
        context_filter: ContextFilter | None = None,
        limit: int = 10,
        prev_next_chunks: int = 0,
    ) -> list[Chunk]:
        """Return up to `limit` chunks relevant to `text`, best score first.

        Each chunk is enriched with the texts of `prev_next_chunks` neighbors
        on either side.
        """
        index = VectorStoreIndex.from_vector_store(
            self.vector_store_component.vector_store,
            storage_context=self.storage_context,
            llm=self.llm_component.llm,
            embed_model=self.embedding_component.embedding_model,
            show_progress=True,
        )
        retriever = self.vector_store_component.get_retriever(
            index=index, context_filter=context_filter, similarity_top_k=limit
        )
        nodes = retriever.retrieve(text)
        # Treat a missing score as 0.0 so sorting never compares None.
        nodes.sort(key=lambda n: n.score or 0.0, reverse=True)

        def _build_chunk(node: NodeWithScore) -> Chunk:
            chunk = Chunk.from_node(node)
            chunk.previous_texts = self._get_sibling_nodes_text(
                node, prev_next_chunks, forward=False
            )
            chunk.next_texts = self._get_sibling_nodes_text(node, prev_next_chunks)
            return chunk

        return [_build_chunk(node) for node in nodes]