File size: 3,791 Bytes
d6fdb17
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
""" Milvus memory storage provider."""
from pymilvus import Collection, CollectionSchema, DataType, FieldSchema, connections

from autogpt.memory.base import MemoryProviderSingleton, get_ada_embedding


class MilvusMemory(MemoryProviderSingleton):
    """Milvus memory storage provider.

    Each memory is stored as one row of a Milvus collection holding the raw
    text and its embedding (as produced by ``get_ada_embedding``). Lookups
    are vector searches over the ``embeddings`` field using the
    inner-product ("IP") metric.
    """

    def __init__(self, cfg) -> None:
        """Construct a milvus memory storage connection.

        Connects to the server, creates the collection (and its vector
        index) when missing, and loads the collection so it can be searched.

        Args:
            cfg (Config): Auto-GPT global config. Reads ``cfg.milvus_addr``
                and ``cfg.milvus_collection``.
        """
        # connect to milvus server.
        connections.connect(address=cfg.milvus_addr)
        fields = [
            # Primary key is generated by Milvus, so inserts never supply it.
            FieldSchema(name="pk", dtype=DataType.INT64, is_primary=True, auto_id=True),
            # dim=1536 is assumed to match the vector length returned by
            # get_ada_embedding — TODO confirm if the embedding model changes.
            FieldSchema(name="embeddings", dtype=DataType.FLOAT_VECTOR, dim=1536),
            FieldSchema(name="raw_text", dtype=DataType.VARCHAR, max_length=65535),
        ]

        # create collection if not exist and load it.
        self.milvus_collection = cfg.milvus_collection
        self.schema = CollectionSchema(fields, "auto-gpt memory storage")
        self.collection = Collection(self.milvus_collection, self.schema)
        # create index if not exist.
        if not self.collection.has_index():
            # Release first; presumably the index cannot be built while the
            # collection is loaded — verify against the Milvus version in use.
            self.collection.release()
            self._create_index()
        self.collection.load()

    def _create_index(self) -> None:
        """Create the HNSW inner-product index on the ``embeddings`` field.

        Single source of truth for the index definition, shared by
        ``__init__`` and ``clear`` so both always build the same index.
        """
        self.collection.create_index(
            "embeddings",
            {
                "metric_type": "IP",
                "index_type": "HNSW",
                "params": {"M": 8, "efConstruction": 64},
            },
            index_name="embeddings",
        )

    def add(self, data) -> str:
        """Add an embedding of data into memory.

        Args:
            data (str): The raw text to construct embedding index.

        Returns:
            str: log.
        """
        embedding = get_ada_embedding(data)
        # Column-based insert: one list per non-auto field, in schema order
        # (embeddings, raw_text); each list holds a single row here.
        result = self.collection.insert([[embedding], [data]])
        _text = (
            "Inserting data into memory at primary key: "
            f"{result.primary_keys[0]}:\n data: {data}"
        )
        return _text

    def get(self, data):
        """Return the most relevant data in memory.

        Args:
            data: The data to compare to.

        Returns:
            list: The result of ``get_relevant(data, 1)``, i.e. a list
                holding at most the single best match.
        """
        return self.get_relevant(data, 1)

    def clear(self) -> str:
        """Drop the collection and recreate it empty, indexed and loaded.

        Returns:
            str: log.
        """
        self.collection.drop()
        self.collection = Collection(self.milvus_collection, self.schema)
        self._create_index()
        self.collection.load()
        return "Obliviated"

    def get_relevant(self, data: str, num_relevant: int = 5):
        """Return the top-k relevant data in memory.

        Args:
            data: The data to compare to.
            num_relevant (int, optional): The max number of relevant data.
                Defaults to 5.

        Returns:
            list: The top-k relevant data.
        """
        # search the embedding and return the most relevant text.
        embedding = get_ada_embedding(data)
        search_params = {
            # Must be spelled "metric_type" (same key as used when the index
            # is created); the previous "metrics_type" spelling did not pass
            # the intended metric to the search.
            "metric_type": "IP",
            # NOTE(review): "nprobe" is an IVF-family search parameter; HNSW
            # indexes are tuned with "ef". Kept unchanged here — confirm and
            # switch if search quality needs tuning.
            "params": {"nprobe": 8},
        }
        result = self.collection.search(
            [embedding],
            "embeddings",
            search_params,
            num_relevant,
            output_fields=["raw_text"],
        )
        # One query vector was sent, so only the first hit list is relevant.
        return [item.entity.value_of_field("raw_text") for item in result[0]]

    def get_stats(self) -> str:
        """Return the stats of the milvus cache.

        Returns:
            str: The number of entities in the collection.
        """
        return f"Entities num: {self.collection.num_entities}"