Spaces:
Sleeping
Sleeping
from typing import Dict, List, Set, cast | |
from chromadb.types import EmbeddingRecord, Operation, SeqId, Vector | |
class Batch: | |
"""Used to model the set of changes as an atomic operation""" | |
_ids_to_records: Dict[str, EmbeddingRecord] | |
_deleted_ids: Set[str] | |
_written_ids: Set[str] | |
_upsert_add_ids: Set[str] # IDs that are being added in an upsert | |
add_count: int | |
update_count: int | |
max_seq_id: SeqId | |
def __init__(self) -> None: | |
self._ids_to_records = {} | |
self._deleted_ids = set() | |
self._written_ids = set() | |
self._upsert_add_ids = set() | |
self.add_count = 0 | |
self.update_count = 0 | |
self.max_seq_id = 0 | |
def __len__(self) -> int: | |
"""Get the number of changes in this batch""" | |
return len(self._written_ids) + len(self._deleted_ids) | |
def get_deleted_ids(self) -> List[str]: | |
"""Get the list of deleted embeddings in this batch""" | |
return list(self._deleted_ids) | |
def get_written_ids(self) -> List[str]: | |
"""Get the list of written embeddings in this batch""" | |
return list(self._written_ids) | |
def get_written_vectors(self, ids: List[str]) -> List[Vector]: | |
"""Get the list of vectors to write in this batch""" | |
return [cast(Vector, self._ids_to_records[id]["embedding"]) for id in ids] | |
def get_record(self, id: str) -> EmbeddingRecord: | |
"""Get the record for a given ID""" | |
return self._ids_to_records[id] | |
def is_deleted(self, id: str) -> bool: | |
"""Check if a given ID is deleted""" | |
return id in self._deleted_ids | |
def delete_count(self) -> int: | |
return len(self._deleted_ids) | |
def apply(self, record: EmbeddingRecord, exists_already: bool = False) -> None: | |
"""Apply an embedding record to this batch. Records passed to this method are assumed to be validated for correctness. | |
For example, a delete or update presumes the ID exists in the index. An add presumes the ID does not exist in the index. | |
The exists_already flag should be set to True if the ID does exist in the index, and False otherwise. | |
""" | |
id = record["id"] | |
if record["operation"] == Operation.DELETE: | |
# If the ID was previously written, remove it from the written set | |
# And update the add/update/delete counts | |
if id in self._written_ids: | |
self._written_ids.remove(id) | |
if self._ids_to_records[id]["operation"] == Operation.ADD: | |
self.add_count -= 1 | |
elif self._ids_to_records[id]["operation"] == Operation.UPDATE: | |
self.update_count -= 1 | |
self._deleted_ids.add(id) | |
elif self._ids_to_records[id]["operation"] == Operation.UPSERT: | |
if id in self._upsert_add_ids: | |
self.add_count -= 1 | |
self._upsert_add_ids.remove(id) | |
else: | |
self.update_count -= 1 | |
self._deleted_ids.add(id) | |
elif id not in self._deleted_ids: | |
self._deleted_ids.add(id) | |
# Remove the record from the batch | |
if id in self._ids_to_records: | |
del self._ids_to_records[id] | |
else: | |
self._ids_to_records[id] = record | |
self._written_ids.add(id) | |
# If the ID was previously deleted, remove it from the deleted set | |
# And update the delete count | |
if id in self._deleted_ids: | |
self._deleted_ids.remove(id) | |
# Update the add/update counts | |
if record["operation"] == Operation.UPSERT: | |
if not exists_already: | |
self.add_count += 1 | |
self._upsert_add_ids.add(id) | |
else: | |
self.update_count += 1 | |
elif record["operation"] == Operation.ADD: | |
self.add_count += 1 | |
elif record["operation"] == Operation.UPDATE: | |
self.update_count += 1 | |
self.max_seq_id = max(self.max_seq_id, record["seq_id"]) | |