# File size: 5,041 Bytes (revision 40a5d34)
"""Redis memory provider."""
from __future__ import annotations
from typing import Any
import numpy as np
import redis
from colorama import Fore, Style
from redis.commands.search.field import TextField, VectorField
from redis.commands.search.indexDefinition import IndexDefinition, IndexType
from redis.commands.search.query import Query
from autogpt.llm_utils import create_embedding_with_ada
from autogpt.logs import logger
from autogpt.memory.base import MemoryProviderSingleton
# Vector field of the search index: an HNSW index over 1536-dimensional
# FLOAT32 vectors compared by cosine distance.
# NOTE(review): 1536 is presumably the length of the vectors returned by
# create_embedding_with_ada -- confirm before changing the embedding model.
_EMBEDDING_FIELD = VectorField(
    "embedding",
    "HNSW",
    {
        "TYPE": "FLOAT32",
        "DIM": 1536,
        "DISTANCE_METRIC": "COSINE",
    },
)

# Fields of the search index: the stored text ("data") followed by its
# embedding vector ("embedding").
SCHEMA = [TextField("data"), _EMBEDDING_FIELD]
class RedisMemory(MemoryProviderSingleton):
    """Memory provider that keeps text and its embedding in Redis.

    Every memory is written as a Redis hash under the key
    ``"{cfg.memory_index}:{n}"`` and searched through the search index named
    ``cfg.memory_index``. The next free ``n`` is tracked on the instance as
    ``vec_num`` and mirrored in the Redis key
    ``"{cfg.memory_index}-vec_num"`` so it can be reloaded on the next start.
    """

    def __init__(self, cfg):
        """
        Initializes the Redis memory provider.

        Connects to Redis, optionally wipes the server, makes sure the search
        index exists and restores the slot counter stored in Redis.

        Args:
            cfg: The config object. The attributes read here are redis_host,
                redis_port, redis_password, wipe_redis_on_start and
                memory_index.

        Returns: None

        Raises:
            SystemExit: With exit code 1 when the Redis server cannot be
                reached.
        """
        self.cfg = cfg
        self.dimension = 1536
        self.redis = redis.Redis(
            host=cfg.redis_host,
            port=cfg.redis_port,
            password=cfg.redis_password,
            db=0,  # Cannot be changed
        )

        # Check redis connection
        try:
            self.redis.ping()
        except redis.ConnectionError as e:
            logger.typewriter_log(
                "FAILED TO CONNECT TO REDIS",
                Fore.RED,
                Style.BRIGHT + str(e) + Style.RESET_ALL,
            )
            logger.double_check(
                "Please ensure you have setup and configured Redis properly for use. "
                + f"You can check out {Fore.CYAN + Style.BRIGHT}"
                f"https://github.com/Torantulino/Auto-GPT#redis-setup{Style.RESET_ALL}"
                " to ensure you've set up everything correctly."
            )
            # Raise SystemExit directly: the ``exit`` builtin is injected by
            # the ``site`` module and is not guaranteed to exist.
            raise SystemExit(1) from e

        if cfg.wipe_redis_on_start:
            self.redis.flushall()

        self._create_index()

        # Resume numbering from the counter persisted by add(); a missing key
        # means nothing has been stored yet.
        existing_vec_num = self.redis.get(f"{cfg.memory_index}-vec_num")
        self.vec_num = int(existing_vec_num.decode("utf-8")) if existing_vec_num else 0

    def _create_index(self) -> None:
        """
        Creates the search index over the memory hashes (best effort).

        The index is named ``cfg.memory_index`` and covers every hash whose
        key starts with ``"{cfg.memory_index}:"``. A failure is printed rather
        than raised, so a failed creation never aborts the caller.
        NOTE(review): presumably the common failure is "index already
        exists" on a restart without wiping -- confirm.

        Returns: None
        """
        try:
            self.redis.ft(f"{self.cfg.memory_index}").create_index(
                fields=SCHEMA,
                definition=IndexDefinition(
                    prefix=[f"{self.cfg.memory_index}:"], index_type=IndexType.HASH
                ),
            )
        except Exception as e:
            print("Error creating Redis search index: ", e)

    def add(self, data: str) -> str:
        """
        Adds a data point to the memory.

        Args:
            data: The data to add. Text containing "Command Error:" is
                skipped and not stored.

        Returns: Message indicating that the data has been added, or an empty
            string when the data was skipped.
        """
        if "Command Error:" in data:
            return ""

        vector = create_embedding_with_ada(data)
        vector = np.array(vector).astype(np.float32).tobytes()

        slot = self.vec_num
        # Write the hash and the persisted counter in a single pipeline.
        pipe = self.redis.pipeline()
        pipe.hset(
            f"{self.cfg.memory_index}:{slot}",
            mapping={"data": data, "embedding": vector},
        )
        pipe.set(f"{self.cfg.memory_index}-vec_num", slot + 1)
        pipe.execute()

        # Advance the in-memory counter only after Redis accepted the write,
        # so a failed execute() does not leave it ahead of the stored value.
        self.vec_num = slot + 1
        return f"Inserting data into memory at index: {slot}:\ndata: {data}"

    def get(self, data: str) -> list[Any] | None:
        """
        Gets the data from the memory that is most relevant to the given data.

        Args:
            data: The data to compare to.

        Returns: A list holding the single most relevant entry, or None when
            the search failed.
        """
        return self.get_relevant(data, 1)

    def clear(self) -> str:
        """
        Clears the redis server and resets this provider to an empty state.

        FLUSHALL removes every key, including the persisted slot counter, so
        the in-memory counter is reset as well and the search index is
        created again for later add() and get_relevant() calls.
        NOTE(review): FLUSHALL is expected to drop the search index too --
        confirm against the deployed RediSearch version.

        Returns: A message indicating that the memory has been cleared.
        """
        self.redis.flushall()
        self.vec_num = 0
        self._create_index()
        return "Obliviated"

    def get_relevant(self, data: str, num_relevant: int = 5) -> list[Any] | None:
        """
        Returns the stored data that is most relevant to the given data.

        Args:
            data: The data to compare to.
            num_relevant: The number of relevant data to return.

        Returns: A list of the most relevant data, ordered by vector score,
            or None when the search raised an error.
        """
        query_embedding = create_embedding_with_ada(data)
        query_vector = np.array(query_embedding).astype(np.float32).tobytes()

        # KNN query over the "embedding" field; the vector itself is passed
        # as the $vector query parameter, not inlined in the query string.
        base_query = f"*=>[KNN {num_relevant} @embedding $vector AS vector_score]"
        query = (
            Query(base_query)
            .return_fields("data", "vector_score")
            .sort_by("vector_score")
            .dialect(2)
        )

        try:
            results = self.redis.ft(f"{self.cfg.memory_index}").search(
                query, query_params={"vector": query_vector}
            )
        except Exception as e:
            print("Error calling Redis search: ", e)
            return None
        return [result.data for result in results.docs]

    def get_stats(self):
        """
        Returns: The stats of the memory index.
        """
        return self.redis.ft(f"{self.cfg.memory_index}").info()
|