import uuid

import weaviate
from weaviate import Client
from weaviate.embedded import EmbeddedOptions
from weaviate.util import generate_uuid5

from autogpt.config import Config
from autogpt.memory.base import MemoryProviderSingleton, get_ada_embedding


def default_schema(weaviate_index):
    """Return the Weaviate class definition used for a memory index.

    The class is named ``weaviate_index`` and has a single ``raw_text``
    property of data type ``text`` that holds the original text.
    """
    return {
        "class": weaviate_index,
        "properties": [
            {
                "name": "raw_text",
                "dataType": ["text"],
                "description": "original text for the embedding",
            }
        ],
    }


class WeaviateMemory(MemoryProviderSingleton):
    """Memory provider that keeps text and its embedding vector in Weaviate.

    Each stored item is an object of the class named by ``self.index`` whose
    ``raw_text`` property holds the text and whose vector comes from
    ``get_ada_embedding``.
    """

    def __init__(self, cfg):
        """Connect to Weaviate and make sure the memory class exists.

        Args:
            cfg: Configuration object. Reads ``weaviate_protocol``,
                ``weaviate_host``, ``weaviate_port``, ``use_weaviate_embedded``,
                ``weaviate_embedded_path``, ``memory_index`` and the
                credential fields used by ``_build_auth_credentials``.
        """
        auth_credentials = self._build_auth_credentials(cfg)

        url = f"{cfg.weaviate_protocol}://{cfg.weaviate_host}:{cfg.weaviate_port}"

        if cfg.use_weaviate_embedded:
            # Embedded mode: the client is given the host/port/data path
            # instead of a URL; no auth credentials are passed.
            self.client = Client(
                embedded_options=EmbeddedOptions(
                    hostname=cfg.weaviate_host,
                    port=int(cfg.weaviate_port),
                    persistence_data_path=cfg.weaviate_embedded_path,
                )
            )

            print(
                f"Weaviate Embedded running on: {url} with persistence path: {cfg.weaviate_embedded_path}"
            )
        else:
            self.client = Client(url, auth_client_secret=auth_credentials)

        self.index = WeaviateMemory.format_classname(cfg.memory_index)
        self._create_schema()

    @staticmethod
    def format_classname(index):
        """Return ``index`` with its first character capitalised.

        The remainder of the name is left untouched. An empty ``index``
        raises ``IndexError``.
        """
        # weaviate uses capitalised index names
        # The python client uses the following code to format
        # index names before the corresponding class is created
        # (for a one-character name ``index[1:]`` is simply empty).
        return index[0].capitalize() + index[1:]

    def _create_schema(self):
        """Create the class for ``self.index`` unless the schema has it."""
        schema = default_schema(self.index)
        if not self.client.schema.contains(schema):
            self.client.schema.create_class(schema)

    def _build_auth_credentials(self, cfg):
        """Pick the client credentials described by ``cfg``.

        Username and password take precedence over an API key.

        Returns:
            ``weaviate.AuthClientPassword`` when both username and password
            are set, ``weaviate.AuthApiKey`` when only an API key is set,
            otherwise ``None``.
        """
        if cfg.weaviate_username and cfg.weaviate_password:
            return weaviate.AuthClientPassword(
                cfg.weaviate_username, cfg.weaviate_password
            )
        if cfg.weaviate_api_key:
            return weaviate.AuthApiKey(api_key=cfg.weaviate_api_key)
        return None

    def add(self, data):
        """Store ``data`` together with its embedding vector.

        The object id is produced by ``generate_uuid5(data, self.index)``,
        i.e. it is derived from the text and the class name.

        Returns:
            A human-readable message containing the uuid and the data.
        """
        vector = get_ada_embedding(data)

        doc_uuid = generate_uuid5(data, self.index)
        data_object = {"raw_text": data}

        # The batch context manager is used for the single insert.
        with self.client.batch as batch:
            batch.add_data_object(
                uuid=doc_uuid,
                data_object=data_object,
                class_name=self.index,
                vector=vector,
            )

        return f"Inserting data into memory at uuid: {doc_uuid}:\n data: {data}"

    def get(self, data):
        """Return the single most relevant stored text for ``data``.

        Returns:
            A list with at most one string (see ``get_relevant``).
        """
        return self.get_relevant(data, 1)

    def clear(self):
        """Remove everything stored in this memory index.

        Returns:
            The string ``"Obliterated"``.
        """
        # Only this provider's class is dropped. Deleting the whole schema
        # (``delete_all``) would also wipe every other class living in the
        # same Weaviate instance.
        self.client.schema.delete_class(self.index)

        # Deleting the class removes its definition along with its objects,
        # therefore the empty class has to be re-created afterwards.
        self._create_schema()

        return "Obliterated"

    def get_relevant(self, data, num_relevant=5):
        """Return stored texts whose vectors are near the embedding of ``data``.

        Args:
            data: Text to embed and search with.
            num_relevant: Maximum number of results requested.

        Returns:
            A list of ``raw_text`` strings; empty when nothing matches or
            when the query fails.
        """
        query_embedding = get_ada_embedding(data)
        try:
            results = (
                self.client.query.get(self.index, ["raw_text"])
                .with_near_vector({"vector": query_embedding, "certainty": 0.7})
                .with_limit(num_relevant)
                .do()
            )

            # An empty hit list naturally yields an empty result list.
            return [
                str(item["raw_text"]) for item in results["data"]["Get"][self.index]
            ]
        except Exception as err:
            # Best-effort lookup: report the problem and behave as if
            # nothing relevant was found.
            print(f"Unexpected error {err=}, {type(err)=}")
            return []

    def get_stats(self):
        """Return the ``meta`` aggregate (object count) for this index.

        Returns:
            The ``meta`` dict of the first aggregate entry, or ``{}`` when
            the aggregate result for the class is empty.
        """
        result = self.client.query.aggregate(self.index).with_meta_count().do()
        class_data = result["data"]["Aggregate"][self.index]

        return class_data[0]["meta"] if class_data else {}