File size: 4,559 Bytes
f5776d3 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 |
"""
Retriever model for chromadb
"""
from typing import Optional, List, Union
import openai
import dspy
import backoff
from dsp.utils import dotdict
try:
    # openai<1.0 exposes its exception classes in the `openai.error` submodule.
    import openai.error
    ERRORS = (openai.error.RateLimitError, openai.error.ServiceUnavailableError, openai.error.APIError)
except (ImportError, AttributeError):
    # openai>=1.0 has no `openai.error`; the exception classes live at package level.
    # Only the "module/attribute missing" failures are expected here, so nothing else is swallowed.
    ERRORS = (openai.RateLimitError, openai.APIError)
try:
    import chromadb
    from chromadb.config import Settings
    from chromadb.utils import embedding_functions
    from chromadb.api.types import (
        Embeddable,
        EmbeddingFunction,
    )
    import chromadb.utils.embedding_functions as ef
except ImportError as err:
    # Raise inside the handler and chain the original error so the real cause
    # (e.g. an older chromadb that lacks `Embeddable`) stays visible in the traceback.
    raise ImportError(
        "The chromadb library is required to use ChromadbRM. Install it with `pip install dspy-ai[chromadb]`"
    ) from err
class ChromadbRM(dspy.Retrieve):
    """
    A retrieval module that uses chromadb to return the top passages for a given query.

    Assumes that the chromadb index has been created and populated with the following metadata:

        - documents: The text of the passage

    Args:
        collection_name (str): chromadb collection name
        persist_directory (str): chromadb persist directory
        embedding_function (Optional[EmbeddingFunction[Embeddable]]): Optional function used to
            embed queries. When omitted (or ``None``), a new ``DefaultEmbeddingFunction`` is
            created for this instance.
        k (int, optional): The number of top passages to retrieve per query. Defaults to 7.

    Returns:
        List[dotdict]: The retrieved passages, each a ``dotdict`` with a ``long_text`` field.

    Examples:
        Below is a code snippet that shows how to use this as the default retriever:
        ```python
        llm = dspy.OpenAI(model="gpt-3.5-turbo")
        retriever_model = ChromadbRM('collection_name', 'db_path')
        dspy.settings.configure(lm=llm, rm=retriever_model)
        # to test the retriever with "my query"
        retriever_model("my query")
        ```

        Below is a code snippet that shows how to use this in the forward() function of a module
        ```python
        self.retrieve = ChromadbRM('collection_name', 'db_path', k=num_passages)
        ```
    """

    def __init__(
        self,
        collection_name: str,
        persist_directory: str,
        embedding_function: Optional[
            EmbeddingFunction[Embeddable]
        ] = None,
        k: int = 7,
    ):
        self._init_chromadb(collection_name, persist_directory)
        # Build the default here rather than in the signature: a default expression in
        # the signature is evaluated once at import time and shared by every instance.
        self.ef = (
            embedding_function
            if embedding_function is not None
            else ef.DefaultEmbeddingFunction()
        )
        super().__init__(k=k)

    def _init_chromadb(
        self,
        collection_name: str,
        persist_directory: str,
    ) -> None:
        """Create a persistent chromadb client and load (or create) the collection.

        Sets ``self._chromadb_client`` and ``self._chromadb_collection``.

        Args:
            collection_name (str): chromadb collection name
            persist_directory (str): chromadb persist directory
        """
        self._chromadb_client = chromadb.Client(
            Settings(
                persist_directory=persist_directory,
                is_persistent=True,
            )
        )
        self._chromadb_collection = self._chromadb_client.get_or_create_collection(
            name=collection_name,
        )

    @backoff.on_exception(
        backoff.expo,
        ERRORS,
        max_time=15,
    )
    def _get_embeddings(self, queries: List[str]) -> List[List[float]]:
        """Embed the queries with the configured embedding function ``self.ef``.

        Calls are retried with exponential backoff for up to 15 seconds on the
        OpenAI exception types in ``ERRORS``. Presumably these can only be raised
        when ``self.ef`` calls the OpenAI client.

        Args:
            queries (list): List of query strings to embed.

        Returns:
            List[List[float]]: List of embeddings corresponding to each query.
        """
        return self.ef(queries)

    def forward(
        self,
        query_or_queries: Union[str, List[str]],
        k: Optional[int] = None,
        **kwargs,
    ) -> List[dotdict]:
        """Search the collection for the top ``k`` passages for each query.

        Args:
            query_or_queries (Union[str, List[str]]): The query or queries to search for.
                Empty queries are ignored.
            k (Optional[int]): Number of passages to retrieve per query. Defaults to ``self.k``.
            **kwargs: Extra keyword arguments forwarded to the collection's ``query`` call
                (for example ``where`` or ``where_document`` filters).

        Returns:
            List[dotdict]: One ``dotdict`` with a ``long_text`` field per retrieved passage.
            Passages are grouped by query, in input order. The list is empty when every
            query is empty.
        """
        queries = (
            [query_or_queries]
            if isinstance(query_or_queries, str)
            else query_or_queries
        )
        queries = [q for q in queries if q]  # Filter empty queries
        if not queries:
            # Nothing to search for; don't call the embedder/collection with an empty batch.
            return []

        embeddings = self._get_embeddings(queries)

        k = self.k if k is None else k
        results = self._chromadb_collection.query(
            query_embeddings=embeddings, n_results=k, **kwargs,
        )
        # `results["documents"]` holds one list of documents per query embedding;
        # flatten across all of them so no query's results are silently dropped.
        return [
            dotdict({"long_text": doc})
            for docs in results["documents"]
            for doc in docs
        ]
|