File size: 4,170 Bytes
287a0bc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
from typing import Dict, List, Set, cast

from chromadb.types import EmbeddingRecord, Operation, SeqId, Vector


class Batch:
    """Used to model the set of changes as an atomic operation.

    A batch accumulates write (ADD / UPDATE / UPSERT) and DELETE records and
    collapses repeated operations on the same ID, so that each ID ends up
    either in the written set (holding its latest record) or in the deleted
    set, never both. ``add_count`` and ``update_count`` count each written ID
    at most once, according to how its current write was classified.
    """

    # Latest written record for every ID in the written set.
    _ids_to_records: Dict[str, EmbeddingRecord]
    # IDs whose net effect in this batch is a delete.
    _deleted_ids: Set[str]
    # IDs whose latest record in this batch is a write (ADD/UPDATE/UPSERT).
    _written_ids: Set[str]
    # IDs whose current write is an UPSERT that was counted as an add
    # (i.e. applied with exists_already=False).
    _upsert_add_ids: Set[str]
    # Number of written IDs currently counted as adds.
    add_count: int
    # Number of written IDs currently counted as updates.
    update_count: int
    # Highest seq_id among all records applied to this batch.
    max_seq_id: SeqId

    def __init__(self) -> None:
        self._ids_to_records = {}
        self._deleted_ids = set()
        self._written_ids = set()
        self._upsert_add_ids = set()
        self.add_count = 0
        self.update_count = 0
        self.max_seq_id = 0

    def __len__(self) -> int:
        """Get the number of changes in this batch (written plus deleted IDs)."""
        return len(self._written_ids) + len(self._deleted_ids)

    def get_deleted_ids(self) -> List[str]:
        """Get the list of deleted embedding IDs in this batch (unordered)."""
        return list(self._deleted_ids)

    def get_written_ids(self) -> List[str]:
        """Get the list of written embedding IDs in this batch (unordered)."""
        return list(self._written_ids)

    def get_written_vectors(self, ids: List[str]) -> List[Vector]:
        """Get the vectors to write for ``ids``, in the same order as ``ids``.

        Raises KeyError if an ID has no written record in this batch.
        """
        return [
            cast(Vector, self._ids_to_records[record_id]["embedding"])
            for record_id in ids
        ]

    def get_record(self, id: str) -> EmbeddingRecord:
        """Get the written record for a given ID (KeyError if not written)."""
        return self._ids_to_records[id]

    def is_deleted(self, id: str) -> bool:
        """Check if a given ID is deleted in this batch"""
        return id in self._deleted_ids

    @property
    def delete_count(self) -> int:
        """Number of IDs whose net effect in this batch is a delete."""
        return len(self._deleted_ids)

    def apply(self, record: EmbeddingRecord, exists_already: bool = False) -> None:
        """Apply an embedding record to this batch. Records passed to this method are assumed to be validated for correctness.
        For example, a delete or update presumes the ID exists in the index. An add presumes the ID does not exist in the index.
        The exists_already flag should be set to True if the ID does exist in the index, and False otherwise.

        Repeated operations on the same ID within one batch are collapsed:
        a later write replaces an earlier one (whose add/update count is
        retracted first, so an ID is never counted twice), and a delete
        cancels an earlier write.
        """
        record_id = record["id"]

        if record["operation"] == Operation.DELETE:
            if record_id in self._written_ids:
                # Undo the earlier write. A write counted as an add refers to
                # an ID that (per the contract above) is not in the index, so
                # there is nothing to delete there.
                was_add = self._retract_write(record_id)
                self._written_ids.remove(record_id)
                if not was_add:
                    self._deleted_ids.add(record_id)
            else:
                self._deleted_ids.add(record_id)

            # Remove the record from the batch
            self._ids_to_records.pop(record_id, None)

        else:
            # A later write supersedes an earlier one for the same ID; retract
            # the earlier write's count so the ID is not counted twice.
            if record_id in self._written_ids:
                self._retract_write(record_id)

            self._ids_to_records[record_id] = record
            self._written_ids.add(record_id)

            # A write after a delete cancels the delete
            self._deleted_ids.discard(record_id)

            # Update the add/update counts
            if record["operation"] == Operation.UPSERT:
                if not exists_already:
                    self.add_count += 1
                    self._upsert_add_ids.add(record_id)
                else:
                    self.update_count += 1
            elif record["operation"] == Operation.ADD:
                self.add_count += 1
            elif record["operation"] == Operation.UPDATE:
                self.update_count += 1

        self.max_seq_id = max(self.max_seq_id, record["seq_id"])

    def _retract_write(self, record_id: str) -> bool:
        """Undo the add/update count of the record currently written for ``record_id``.

        Returns True if that write was counted as an add, False otherwise.
        The caller is responsible for updating ``_written_ids`` and
        ``_ids_to_records``.
        """
        operation = self._ids_to_records[record_id]["operation"]
        if operation == Operation.ADD:
            self.add_count -= 1
            return True
        if operation == Operation.UPSERT and record_id in self._upsert_add_ids:
            self.add_count -= 1
            self._upsert_add_ids.remove(record_id)
            return True
        if operation in (Operation.UPDATE, Operation.UPSERT):
            self.update_count -= 1
        return False