# badalsahani's picture
# feat: chroma initial deploy
# 287a0bc
from typing import Dict, List, Set, cast
from chromadb.types import EmbeddingRecord, Operation, SeqId, Vector
class Batch:
    """Used to model the set of changes as an atomic operation.

    A batch accumulates embedding records through :meth:`apply` and keeps
    track of which IDs currently count as written, which count as deleted,
    how many adds and updates have been recorded, and the highest sequence
    ID seen so far.
    """

    # Most recently applied non-delete record for each written ID.
    _ids_to_records: Dict[str, EmbeddingRecord]
    # IDs whose deletion is recorded by this batch.
    _deleted_ids: Set[str]
    # IDs that currently have a record stored in ``_ids_to_records``.
    _written_ids: Set[str]
    _upsert_add_ids: Set[str]  # IDs that are being added in an upsert
    add_count: int
    update_count: int
    # Largest ``seq_id`` of any record applied so far (0 for an empty batch).
    max_seq_id: SeqId

    def __init__(self) -> None:
        """Create an empty batch with all counters at zero."""
        self._ids_to_records = {}
        self._deleted_ids = set()
        self._written_ids = set()
        self._upsert_add_ids = set()
        self.add_count = 0
        self.update_count = 0
        self.max_seq_id = 0

    def __len__(self) -> int:
        """Get the number of changes in this batch (written IDs plus deleted IDs)."""
        return len(self._written_ids) + len(self._deleted_ids)

    def get_deleted_ids(self) -> List[str]:
        """Get the list of deleted embeddings in this batch.

        A new list is returned on every call; its order is unspecified
        because the IDs are stored in a set.
        """
        return list(self._deleted_ids)

    def get_written_ids(self) -> List[str]:
        """Get the list of written embeddings in this batch.

        A new list is returned on every call; its order is unspecified
        because the IDs are stored in a set.
        """
        return list(self._written_ids)

    def get_written_vectors(self, ids: List[str]) -> List[Vector]:
        """Get the list of vectors to write in this batch.

        The result is positionally aligned with ``ids``.

        Raises:
            KeyError: if any of ``ids`` has no record stored in this batch.
        """
        return [
            cast(Vector, self._ids_to_records[written_id]["embedding"])
            for written_id in ids
        ]

    # NOTE: the parameter name ``id`` shadows the builtin, but it is part of
    # the public signature and is kept so keyword callers keep working.
    def get_record(self, id: str) -> EmbeddingRecord:
        """Get the record for a given ID.

        Raises:
            KeyError: if no record is stored for ``id``.
        """
        return self._ids_to_records[id]

    def is_deleted(self, id: str) -> bool:
        """Check if a given ID is deleted"""
        return id in self._deleted_ids

    @property
    def delete_count(self) -> int:
        """Number of IDs currently recorded as deleted in this batch."""
        return len(self._deleted_ids)

    def apply(self, record: EmbeddingRecord, exists_already: bool = False) -> None:
        """Apply an embedding record to this batch.

        Records passed to this method are assumed to be validated for
        correctness. For example, a delete or update presumes the ID exists
        in the index. An add presumes the ID does not exist in the index.

        The exists_already flag should be set to True if the ID does exist in
        the index, and False otherwise. It is only consulted for UPSERT
        records, where it decides whether the upsert counts as an add or as
        an update.

        Note that ``add_count`` / ``update_count`` are incremented on every
        write, while a delete only reverses the increment belonging to the
        most recently stored record for that ID.

        ``max_seq_id`` is raised to the record's ``seq_id`` for every record,
        deletes included.
        """
        record_id = record["id"]
        if record["operation"] == Operation.DELETE:
            self._apply_delete(record_id)
        else:
            self._apply_write(record, exists_already)

        self.max_seq_id = max(self.max_seq_id, record["seq_id"])

    def _apply_delete(self, record_id: str) -> None:
        """Record a delete for ``record_id``, undoing any write held in this batch."""
        if record_id in self._written_ids:
            # The ID was previously written in this batch: take it out of the
            # written set and reverse the counter that write contributed.
            self._written_ids.remove(record_id)
            prior_operation = self._ids_to_records[record_id]["operation"]
            if prior_operation == Operation.ADD:
                # An add followed by a delete cancels out; the ID is not
                # put in the deleted set.
                self.add_count -= 1
            elif prior_operation == Operation.UPDATE:
                self.update_count -= 1
                self._deleted_ids.add(record_id)
            elif prior_operation == Operation.UPSERT:
                if record_id in self._upsert_add_ids:
                    # The upsert was counted as an add, so it cancels out
                    # the same way a plain add does.
                    self.add_count -= 1
                    self._upsert_add_ids.remove(record_id)
                else:
                    self.update_count -= 1
                    self._deleted_ids.add(record_id)
        else:
            # Not written in this batch; set.add is a no-op when the ID is
            # already marked as deleted.
            self._deleted_ids.add(record_id)

        # Remove the record from the batch (nothing to do if none is stored).
        self._ids_to_records.pop(record_id, None)

    def _apply_write(self, record: EmbeddingRecord, exists_already: bool) -> None:
        """Store a non-delete ``record`` and bump the matching add/update counter."""
        record_id = record["id"]
        self._ids_to_records[record_id] = record
        self._written_ids.add(record_id)

        # If the ID was previously deleted, remove it from the deleted set,
        # which also lowers delete_count.
        self._deleted_ids.discard(record_id)

        # Update the add/update counts
        operation = record["operation"]
        if operation == Operation.UPSERT:
            if exists_already:
                self.update_count += 1
            else:
                self.add_count += 1
                # Remembered so a later delete can tell this upsert was an add.
                self._upsert_add_ids.add(record_id)
        elif operation == Operation.ADD:
            self.add_count += 1
        elif operation == Operation.UPDATE:
            self.update_count += 1