Spaces:
Runtime error
Runtime error
File size: 4,135 Bytes
129cd69 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 |
from __future__ import annotations
import concurrent.futures
from typing import Any, Iterable, List, Optional
import numpy as np
from langchain_core.documents import Document
from langchain_core.embeddings import Embeddings
from langchain_core.retrievers import BaseRetriever
from langchain.callbacks.manager import CallbackManagerForRetrieverRun
def create_index(contexts: List[str], embeddings: Embeddings) -> np.ndarray:
    """Embed every context with ``embeddings.embed_query`` and stack the results.

    The per-text calls are fanned out over a thread pool. ``Executor.map``
    yields results in input order, so row ``i`` of the returned array is the
    embedding of ``contexts[i]``.

    Args:
        contexts: Texts to embed.
        embeddings: Embeddings model whose ``embed_query`` is called once per text.

    Returns:
        A NumPy array built from the per-text embedding vectors.
    """
    with concurrent.futures.ThreadPoolExecutor() as pool:
        vectors = list(pool.map(embeddings.embed_query, contexts))
    return np.array(vectors)
class SVMRetriever(BaseRetriever):
    """`SVM` retriever.

    Ranks indexed texts by fitting a linear SVM that separates the query
    embedding (the single positive example) from every indexed embedding
    (the negatives). All rows are then scored with the SVM's decision function.

    Largely based on
    https://github.com/karpathy/randomfun/blob/master/knn_vs_svm.ipynb
    """

    embeddings: Embeddings
    """Embeddings model to use."""
    index: Any
    """Index of embeddings."""
    texts: List[str]
    """List of texts to index."""
    metadatas: Optional[List[dict]] = None
    """List of metadatas corresponding with each text."""
    k: int = 4
    """Number of results to return."""
    relevancy_threshold: Optional[float] = None
    """Threshold for relevancy."""

    class Config:
        """Configuration for this pydantic object."""

        arbitrary_types_allowed = True

    @classmethod
    def from_texts(
        cls,
        texts: List[str],
        embeddings: Embeddings,
        metadatas: Optional[List[dict]] = None,
        **kwargs: Any,
    ) -> SVMRetriever:
        """Build a retriever by embedding ``texts`` with ``embeddings``.

        Args:
            texts: Texts to index.
            embeddings: Embeddings model used for the index and for queries.
            metadatas: Optional metadata dicts, one per text, in the same order.
            **kwargs: Extra fields forwarded to the constructor (e.g. ``k``,
                ``relevancy_threshold``).

        Returns:
            A new ``SVMRetriever``.
        """
        index = create_index(texts, embeddings)
        return cls(
            embeddings=embeddings,
            index=index,
            texts=texts,
            metadatas=metadatas,
            **kwargs,
        )

    @classmethod
    def from_documents(
        cls,
        documents: Iterable[Document],
        embeddings: Embeddings,
        **kwargs: Any,
    ) -> SVMRetriever:
        """Build a retriever from documents' ``page_content`` and ``metadata``.

        Args:
            documents: Documents to index. May be empty.
            embeddings: Embeddings model used for the index and for queries.
            **kwargs: Extra fields forwarded to ``from_texts``.

        Returns:
            A new ``SVMRetriever``.
        """
        # Materialize once so a one-shot iterator can be read twice. Build the
        # lists explicitly because ``zip(*...)`` cannot unpack an empty iterable.
        docs = list(documents)
        texts = [doc.page_content for doc in docs]
        metadatas = [doc.metadata for doc in docs]
        return cls.from_texts(
            texts=texts, embeddings=embeddings, metadatas=metadatas, **kwargs
        )

    def _get_relevant_documents(
        self, query: str, *, run_manager: CallbackManagerForRetrieverRun
    ) -> List[Document]:
        """Return up to ``k`` documents ranked by the SVM decision function.

        Args:
            query: Query text to embed and rank against the index.
            run_manager: Callback manager for the retriever run (not used here).

        Returns:
            At most ``k`` documents, best first. When ``relevancy_threshold`` is
            set, documents among those top ``k`` whose min-max normalized score
            falls below it are dropped, so fewer may be returned. An empty index
            yields an empty list.

        Raises:
            ImportError: If scikit-learn is not installed.
        """
        try:
            from sklearn import svm
        except ImportError as e:
            raise ImportError(
                "Could not import scikit-learn, please install with `pip install "
                "scikit-learn`."
            ) from e

        # Nothing to rank. With only the query row, ``y`` would hold a single
        # class and ``LinearSVC.fit`` would raise.
        if len(self.index) == 0:
            return []

        query_embeds = np.array(self.embeddings.embed_query(query))
        # Row 0 is the query (the positive label); rows 1..n are the indexed texts.
        x = np.concatenate([query_embeds[None, ...], self.index])
        y = np.zeros(x.shape[0])
        y[0] = 1

        clf = svm.LinearSVC(
            class_weight="balanced", verbose=False, max_iter=10000, tol=1e-6, C=0.1
        )
        clf.fit(x, y)

        similarities = clf.decision_function(x)
        sorted_ix = np.argsort(-similarities)

        # svm.LinearSVC in scikit-learn is non-deterministic. If a text is the
        # same as the query, the query row is not guaranteed to rank first, so
        # swap it into slot 0. This relies on the rows ranked ahead of the query
        # scoring (near-)equally to it — presumably duplicates of the query.
        zero_index = np.where(sorted_ix == 0)[0][0]
        if zero_index != 0:
            sorted_ix[0], sorted_ix[zero_index] = sorted_ix[zero_index], sorted_ix[0]

        # Min-max normalize scores into [0, 1). The epsilon prevents division by
        # zero when every score is equal.
        denominator = np.max(similarities) - np.min(similarities) + 1e-6
        normalized_similarities = (similarities - np.min(similarities)) / denominator

        top_k_results = []
        # Skip slot 0 (the query itself). Row ``r`` of ``x`` is text ``r - 1``.
        for row in sorted_ix[1 : self.k + 1]:
            if (
                self.relevancy_threshold is None
                or normalized_similarities[row] >= self.relevancy_threshold
            ):
                metadata = self.metadatas[row - 1] if self.metadatas else {}
                doc = Document(page_content=self.texts[row - 1], metadata=metadata)
                top_k_results.append(doc)
        return top_k_results
|