Spaces:
Runtime error
Runtime error
import uuid | |
import weaviate | |
from weaviate import Client | |
from weaviate.embedded import EmbeddedOptions | |
from weaviate.util import generate_uuid5 | |
from autogpt.config import Config | |
from autogpt.memory.base import MemoryProviderSingleton, get_ada_embedding | |
def default_schema(weaviate_index):
    """Build the Weaviate class definition used for stored memories.

    Args:
        weaviate_index: Name to place under the schema's ``"class"`` key.

    Returns:
        A freshly built dict describing a class with a single ``raw_text``
        property of data type ``text``.
    """
    # The only stored property is the original text; the embedding itself is
    # supplied separately as the object's vector when data is added.
    raw_text_property = {
        "name": "raw_text",
        "dataType": ["text"],
        "description": "original text for the embedding",
    }
    return {"class": weaviate_index, "properties": [raw_text_property]}
class WeaviateMemory(MemoryProviderSingleton):
    """Memory provider that stores raw text plus its embedding in Weaviate."""

    def __init__(self, cfg):
        """Connect to Weaviate and make sure the memory class exists.

        Args:
            cfg: Configuration object. Reads ``weaviate_protocol``,
                ``weaviate_host``, ``weaviate_port``,
                ``use_weaviate_embedded``, ``weaviate_embedded_path``,
                ``weaviate_username``, ``weaviate_password``,
                ``weaviate_api_key`` and ``memory_index``.
        """
        auth_credentials = self._build_auth_credentials(cfg)

        url = f"{cfg.weaviate_protocol}://{cfg.weaviate_host}:{cfg.weaviate_port}"

        if cfg.use_weaviate_embedded:
            # Embedded mode: the client is configured from host/port/path
            # instead of a URL; no credentials are passed in this branch.
            self.client = Client(
                embedded_options=EmbeddedOptions(
                    hostname=cfg.weaviate_host,
                    port=int(cfg.weaviate_port),
                    persistence_data_path=cfg.weaviate_embedded_path,
                )
            )

            print(
                f"Weaviate Embedded running on: {url} with persistence path: {cfg.weaviate_embedded_path}"
            )
        else:
            self.client = Client(url, auth_client_secret=auth_credentials)

        self.index = WeaviateMemory.format_classname(cfg.memory_index)
        self._create_schema()

    @staticmethod
    def format_classname(index):
        """Return ``index`` with its first character capitalised.

        Only the first character is changed; the remainder of the name is
        kept verbatim. An empty name is returned unchanged.

        Args:
            index: The configured index name.

        Returns:
            The class name to use with Weaviate.
        """
        # weaviate uses capitalised index names
        # The python client uses the following code to format
        # index names before the corresponding class is created
        # Slicing (rather than indexing) avoids an IndexError on "".
        return index[:1].capitalize() + index[1:]

    def _create_schema(self):
        """Create the class for ``self.index`` unless the schema already has it."""
        schema = default_schema(self.index)
        if not self.client.schema.contains(schema):
            self.client.schema.create_class(schema)

    def _build_auth_credentials(self, cfg):
        """Pick the credentials object matching the configured auth settings.

        Username/password takes precedence over an API key.

        Args:
            cfg: Configuration object. Reads ``weaviate_username``,
                ``weaviate_password`` and ``weaviate_api_key``.

        Returns:
            ``weaviate.AuthClientPassword``, ``weaviate.AuthApiKey``, or
            ``None`` when no credentials are configured.
        """
        if cfg.weaviate_username and cfg.weaviate_password:
            return weaviate.AuthClientPassword(
                cfg.weaviate_username, cfg.weaviate_password
            )
        if cfg.weaviate_api_key:
            return weaviate.AuthApiKey(api_key=cfg.weaviate_api_key)
        return None

    def add(self, data):
        """Embed ``data`` and add it to the memory class.

        Args:
            data: The text to remember.

        Returns:
            A human-readable message containing the object's uuid and the
            text that was submitted.
        """
        vector = get_ada_embedding(data)

        # The uuid is derived from the text and the class name.
        doc_uuid = generate_uuid5(data, self.index)
        data_object = {"raw_text": data}

        with self.client.batch as batch:
            batch.add_data_object(
                uuid=doc_uuid,
                data_object=data_object,
                class_name=self.index,
                vector=vector,
            )

        return f"Inserting data into memory at uuid: {doc_uuid}:\n data: {data}"

    def get(self, data):
        """Return the single most relevant memory for ``data`` (as a list)."""
        return self.get_relevant(data, 1)

    def clear(self):
        """Remove every stored memory and recreate an empty class.

        Returns:
            The string ``"Obliterated"``.
        """
        # Drop only this memory's class. schema.delete_all() would remove
        # every class in the Weaviate instance, including unrelated ones.
        self.client.schema.delete_class(self.index)

        # Deleting the class removes its definition along with its objects,
        # therefore we need to re-create it afterwards.
        self._create_schema()

        return "Obliterated"

    def get_relevant(self, data, num_relevant=5):
        """Return stored texts closest to ``data`` by vector search.

        Args:
            data: The query text; it is embedded before searching.
            num_relevant: Maximum number of results to request.

        Returns:
            A list of ``raw_text`` strings (possibly empty). Any error while
            querying or reading the response is printed and yields ``[]``.
        """
        query_embedding = get_ada_embedding(data)
        try:
            results = (
                self.client.query.get(self.index, ["raw_text"])
                .with_near_vector({"vector": query_embedding, "certainty": 0.7})
                .with_limit(num_relevant)
                .do()
            )
            # An empty hit list naturally produces an empty result list.
            return [
                str(item["raw_text"]) for item in results["data"]["Get"][self.index]
            ]
        except Exception as err:
            # Deliberately best-effort: a failed lookup must not crash the
            # caller, so report the problem and fall back to "no memories".
            print(f"Unexpected error {err=}, {type(err)=}")
            return []

    def get_stats(self):
        """Return the ``meta`` aggregate (with count) for the memory class.

        Returns:
            The ``meta`` dict of the first aggregate entry, or ``{}`` when
            the aggregate result for the class is empty.
        """
        result = self.client.query.aggregate(self.index).with_meta_count().do()
        class_data = result["data"]["Aggregate"][self.index]

        return class_data[0]["meta"] if class_data else {}