Spaces:
Runtime error
Runtime error
File size: 6,299 Bytes
129cd69 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 |
import datetime
from copy import deepcopy
from typing import Any, Dict, List, Optional, Tuple
from langchain_core.documents import Document
from langchain_core.pydantic_v1 import Field
from langchain_core.retrievers import BaseRetriever
from langchain_core.vectorstores import VectorStore
from langchain.callbacks.manager import CallbackManagerForRetrieverRun
def _get_hours_passed(time: datetime.datetime, ref_time: datetime.datetime) -> float:
"""Get the hours passed between two datetimes."""
return (time - ref_time).total_seconds() / 3600
class TimeWeightedVectorStoreRetriever(BaseRetriever):
    """Retriever that combines embedding similarity with
    recency in retrieving values."""

    vectorstore: VectorStore
    """The vectorstore to store documents and determine salience."""

    search_kwargs: dict = Field(default_factory=lambda: dict(k=100))
    """Keyword arguments to pass to the vectorstore similarity search."""

    # TODO: abstract as a queue
    memory_stream: List[Document] = Field(default_factory=list)
    """The memory_stream of documents to search through."""

    decay_rate: float = Field(default=0.01)
    """The exponential decay factor used as (1.0-decay_rate)**(hrs_passed)."""

    k: int = 4
    """The maximum number of documents to retrieve in a given call."""

    other_score_keys: List[str] = []
    """Other keys in the metadata to factor into the score, e.g. 'importance'."""

    default_salience: Optional[float] = None
    """The salience to assign memories not retrieved from the vector store.

    None assigns no salience to documents not fetched from the vector store.
    """

    class Config:
        """Configuration for this pydantic object."""

        arbitrary_types_allowed = True

    def _document_get_date(self, field: str, document: Document) -> datetime.datetime:
        """Return the value of the date field of a document.

        Args:
            field: Metadata key holding the date (e.g. ``"last_accessed_at"``).
            document: Document whose metadata is inspected.

        Returns:
            The stored value; numeric values are interpreted as POSIX
            timestamps and converted to a datetime. The current time is
            returned when the key is absent.
        """
        if field not in document.metadata:
            return datetime.datetime.now()
        value = document.metadata[field]
        # Accept ints as well as floats: an integral timestamp would otherwise
        # be returned raw and break the datetime subtraction done by callers.
        if isinstance(value, (int, float)):
            return datetime.datetime.fromtimestamp(value)
        return value

    def _get_combined_score(
        self,
        document: Document,
        vector_relevance: Optional[float],
        current_time: datetime.datetime,
    ) -> float:
        """Return the combined score for a document.

        The score is the sum of a recency term
        ``(1.0 - decay_rate) ** hours_since_last_access``, every metadata
        value listed in ``other_score_keys`` that is present on the document,
        and ``vector_relevance`` when it is not None.
        """
        hours_passed = _get_hours_passed(
            current_time,
            self._document_get_date("last_accessed_at", document),
        )
        score = (1.0 - self.decay_rate) ** hours_passed
        for key in self.other_score_keys:
            if key in document.metadata:
                score += document.metadata[key]
        if vector_relevance is not None:
            score += vector_relevance
        return score

    def get_salient_docs(self, query: str) -> Dict[int, Tuple[Document, float]]:
        """Return documents that are salient to the query.

        Runs a relevance-scored similarity search on the vector store using
        ``search_kwargs`` and maps each hit back to its entry in
        ``memory_stream`` through the ``buffer_idx`` metadata key.

        Returns:
            Mapping of ``buffer_idx`` to ``(memory_stream document, relevance)``.
            Hits whose metadata lacks ``buffer_idx`` are skipped.
        """
        docs_and_scores: List[Tuple[Document, float]]
        docs_and_scores = self.vectorstore.similarity_search_with_relevance_scores(
            query, **self.search_kwargs
        )
        results = {}
        for fetched_doc, relevance in docs_and_scores:
            if "buffer_idx" in fetched_doc.metadata:
                buffer_idx = fetched_doc.metadata["buffer_idx"]
                doc = self.memory_stream[buffer_idx]
                results[buffer_idx] = (doc, relevance)
        return results

    def _get_relevant_documents(
        self, query: str, *, run_manager: CallbackManagerForRetrieverRun
    ) -> List[Document]:
        """Return documents that are relevant to the query.

        Candidates are the last ``k`` entries of ``memory_stream`` (scored
        with ``default_salience``) merged with the salient vector store hits.
        The ``k`` highest combined scores are returned, and each returned
        document has its ``last_accessed_at`` metadata set to the current time.
        """
        current_time = datetime.datetime.now()
        # Seed with the most recent memories so they are always considered.
        docs_and_scores = {
            doc.metadata["buffer_idx"]: (doc, self.default_salience)
            for doc in self.memory_stream[-self.k :]
        }
        # If a doc is considered salient, update the salience score
        docs_and_scores.update(self.get_salient_docs(query))
        rescored_docs = [
            (doc, self._get_combined_score(doc, relevance, current_time))
            for doc, relevance in docs_and_scores.values()
        ]
        rescored_docs.sort(key=lambda x: x[1], reverse=True)
        result = []
        # Ensure frequently accessed memories aren't forgotten
        for doc, _ in rescored_docs[: self.k]:
            # TODO: Update vector store doc once `update` method is exposed.
            buffered_doc = self.memory_stream[doc.metadata["buffer_idx"]]
            buffered_doc.metadata["last_accessed_at"] = current_time
            result.append(buffered_doc)
        return result

    def _prepare_documents(
        self,
        documents: List[Document],
        current_time: Optional[datetime.datetime],
    ) -> List[Document]:
        """Copy, timestamp and index documents, then append them to memory_stream.

        Shared by ``add_documents`` and ``aadd_documents``.

        Args:
            documents: Input documents; they are deep-copied, not mutated.
            current_time: Timestamp used for missing ``last_accessed_at`` /
                ``created_at`` metadata. The current time is used when None.

        Returns:
            The copies that were appended to ``memory_stream``.
        """
        if current_time is None:
            current_time = datetime.datetime.now()
        # Avoid mutating input documents
        dup_docs = [deepcopy(d) for d in documents]
        # buffer_idx is the position each copy will occupy in memory_stream.
        for buffer_idx, doc in enumerate(dup_docs, start=len(self.memory_stream)):
            doc.metadata.setdefault("last_accessed_at", current_time)
            doc.metadata.setdefault("created_at", current_time)
            doc.metadata["buffer_idx"] = buffer_idx
        self.memory_stream.extend(dup_docs)
        return dup_docs

    def add_documents(self, documents: List[Document], **kwargs: Any) -> List[str]:
        """Add documents to vectorstore.

        Args:
            documents: Documents to store; copies are buffered in
                ``memory_stream`` and sent to the vector store.
            **kwargs: Forwarded unchanged to ``vectorstore.add_documents``.
                An optional ``current_time`` entry is also used as the
                timestamp for missing date metadata.

        Returns:
            Whatever ``vectorstore.add_documents`` returns.
        """
        dup_docs = self._prepare_documents(documents, kwargs.get("current_time"))
        return self.vectorstore.add_documents(dup_docs, **kwargs)

    async def aadd_documents(
        self, documents: List[Document], **kwargs: Any
    ) -> List[str]:
        """Add documents to vectorstore.

        Asynchronous counterpart of ``add_documents``.

        Args:
            documents: Documents to store; copies are buffered in
                ``memory_stream`` and sent to the vector store.
            **kwargs: Forwarded unchanged to ``vectorstore.aadd_documents``.
                An optional ``current_time`` entry is also used as the
                timestamp for missing date metadata.

        Returns:
            Whatever ``vectorstore.aadd_documents`` returns.
        """
        dup_docs = self._prepare_documents(documents, kwargs.get("current_time"))
        return await self.vectorstore.aadd_documents(dup_docs, **kwargs)
|