|
import os |
|
from typing import Dict, List, Any |
|
|
|
import uuid |
|
from copy import deepcopy |
|
from langchain.embeddings import OpenAIEmbeddings |
|
|
|
from chromadb import Client as ChromaClient |
|
from aiflows.messages import FlowMessage |
|
from aiflows.base_flows import AtomicFlow |
|
|
|
import hydra |
|
|
|
import os |
|
from typing import Dict, List, Any |
|
|
|
import uuid |
|
from copy import deepcopy |
|
from langchain.embeddings import OpenAIEmbeddings |
|
|
|
from aiflows.messages import FlowMessage |
|
from aiflows.base_flows import AtomicFlow |
|
from langchain.text_splitter import CharacterTextSplitter |
|
from langchain.document_loaders import TextLoader |
|
from langchain.vectorstores import Chroma |
|
import hydra |
|
|
|
class ChromaDBFlow(AtomicFlow):
    """ A flow that uses the ChromaDB model to write and read memories stored in a database

    *Configuration Parameters*:

    - `name` (str): The name of the flow. Default: "chroma_db"
    - `description` (str): A description of the flow. This description is used to generate the help message of the flow.
        Default: "ChromaDB is a document store that uses vector embeddings to store and retrieve documents."
    - `backend` (Dict[str, Any]): The configuration of the backend which is used to fetch api keys. Default: LiteLLMBackend with the
        default parameters of LiteLLMBackend (see aiflows.backends.LiteLLMBackend). Except for the following parameter whose default value is overwritten:
        - `api_infos` (List[Dict[str, Any]]): The list of api infos. Default: No default value, this parameter is required.
        - `model_name` (str): The name of the model. Default: "". In the current implementation, this parameter is not used.
    - `similarity_search_kwargs` (Dict[str, Any]): The parameters to pass to the similarity search method of the ChromaDB. Default:
        - `k` (int): The number of documents to retrieve. Default: 2
        - `filter` (str): The filter to apply to the documents. Default: null
    - `paths_to_data` (List[str]): The paths to the data to store in the database at instantiation. Default: []
    - `chunk_size` (int): The size of the chunks to split the documents into. Default: 700
    - `separator` (str): The separator to use to split the documents. Default: "\n"
    - `chunk_overlap` (int): The overlap between the chunks. Default: 0
    - `persist_directory` (str): The directory to persist the database. Default: "./demo_db_dir"

    - Other parameters are inherited from the default configuration of AtomicFlow (see AtomicFlow)

    *Input Interface*:

    - `operation` (str): The operation to perform. It can be "write" or "read".
    - `content` (str or List[str]): The content to write or read. If operation is "write", it must be a string or a list of strings. If operation is "read", it must be a string.

    *Output Interface*:

    - `retrieved` (str or List[str]): The retrieved content. If operation is "write", it is an empty string. If operation is "read", it is a string or a list of strings.

    :param backend: The backend of the flow (used to retrieve the API key)
    :type backend: LiteLLMBackend
    :param \**kwargs: Additional arguments to pass to the flow.
    """

    def __init__(self, backend, **kwargs):
        super().__init__(**kwargs)
        # Backend is only used to fetch API credentials for the embeddings model.
        self.backend = backend

    def set_up_flow_state(self):
        """ Initializes the flow state.

        `db_created` tracks whether the vector store has already been built
        from `paths_to_data`, so documents are only ingested once.
        """
        super().set_up_flow_state()
        self.flow_state["db_created"] = False

    @classmethod
    def _set_up_backend(cls, config):
        """ This instantiates the backend of the flow from a configuration file.

        :param config: The configuration of the backend.
        :type config: Dict[str, Any]
        :return: The backend of the flow.
        :rtype: Dict[str, LiteLLMBackend]
        """
        kwargs = {}
        kwargs["backend"] = \
            hydra.utils.instantiate(config['backend'], _convert_="partial")
        return kwargs

    @classmethod
    def instantiate_from_config(cls, config):
        """ This method instantiates the flow from a configuration file

        :param config: The configuration of the flow.
        :type config: Dict[str, Any]
        :return: The instantiated flow.
        :rtype: ChromaDBFlow
        """
        # Deep-copy so the instantiated flow cannot mutate the caller's config.
        flow_config = deepcopy(config)

        kwargs = {"flow_config": flow_config}
        kwargs.update(cls._set_up_backend(flow_config))

        return cls(**kwargs)

    def get_embeddings_model(self):
        """ Builds the OpenAI embeddings model used to embed documents and queries.

        Uses the API key supplied by the backend when it targets OpenAI;
        otherwise falls back to the `OPENAI_API_KEY` environment variable.

        :return: The embeddings model.
        :rtype: OpenAIEmbeddings
        """
        api_information = self.backend.get_key()
        if api_information.backend_used == "openai":
            embeddings = OpenAIEmbeddings(openai_api_key=api_information.api_key)
        else:
            # Backend key is for a different provider; fall back to the env var.
            embeddings = OpenAIEmbeddings(openai_api_key=os.getenv("OPENAI_API_KEY"))
        return embeddings

    def get_db(self):
        """ Returns the Chroma vector store, creating it on first use.

        On the first call with `paths_to_data` configured, loads each file,
        splits it into chunks and builds a persisted Chroma DB from the chunks.
        Subsequent calls (or an empty `paths_to_data`) reopen the persisted DB.

        :return: The Chroma vector store.
        :rtype: Chroma
        """
        db_created = self.flow_state["db_created"]

        if hasattr(self, 'db'):
            # `run` caches the store on the instance after the first call.
            db = self.db
        elif db_created or len(self.flow_config["paths_to_data"]) == 0:
            # Nothing (more) to ingest: just reopen the persisted store.
            db = Chroma(
                persist_directory=self.flow_config["persist_directory"],
                embedding_function=self.get_embeddings_model()
            )
        else:
            # First call with data to ingest: chunk every configured file
            # and build the store from the resulting documents.
            full_docs = []
            text_splitter = CharacterTextSplitter(
                chunk_size=self.flow_config["chunk_size"],
                chunk_overlap=self.flow_config["chunk_overlap"],
                separator=self.flow_config["separator"]
            )

            for path in self.flow_config["paths_to_data"]:
                loader = TextLoader(path)
                documents = loader.load()
                docs = text_splitter.split_documents(documents)
                full_docs.extend(docs)

            db = Chroma.from_documents(
                full_docs,
                self.get_embeddings_model(),
                persist_directory=self.flow_config["persist_directory"]
            )

        self.flow_state["db_created"] = True
        return db

    def run(self, input_message: FlowMessage):
        """ This method runs the flow. It runs the ChromaDBFlow. It either writes or reads memories from the database.

        :param input_message: The input message of the flow.
        :type input_message: FlowMessage
        :raises ValueError: If `operation` is not "write" or "read", or if a
            read is attempted with a non-string `content`.
        """
        # Cache the store on the instance so later calls skip re-ingestion.
        self.db = self.get_db()

        input_data = input_message.data

        response = {}

        operation = input_data["operation"]
        if operation not in ["write", "read"]:
            raise ValueError(f"Operation '{operation}' not supported")

        content = input_data["content"]

        if operation == "read":
            if not isinstance(content, str):
                raise ValueError(f"content(query) must be a string during read, got {type(content)}: {content}")
            if content == "":
                # Preserve the historical empty-query sentinel shape.
                response["retrieved"] = [[""]]
            else:
                query = content
                query_result = self.db.similarity_search(query, **self.flow_config["similarity_search_kwargs"])
                response["retrieved"] = [doc.page_content for doc in query_result]

        elif operation == "write":
            if content != "":
                if not isinstance(content, list):
                    content = [content]
                documents = content
                # Only build the embeddings model when actually writing;
                # reads embed through the store's own embedding function.
                embeddings = self.get_embeddings_model()
                # NOTE(review): `_collection` is a private chromadb attribute;
                # consider `self.db.add_texts(...)` if the API allows.
                self.db._collection.add(
                    ids=[str(uuid.uuid4()) for _ in range(len(documents))],
                    embeddings=embeddings.embed_documents(documents),
                    documents=documents
                )
            response["retrieved"] = ""

        reply = self.package_output_message(
            input_message=input_message,
            response=response
        )
        self.send_message(reply)