Spaces:
Sleeping
Sleeping
import uuid | |
import pytest | |
import chromadb.test.property.strategies as strategies | |
from unittest.mock import patch | |
from dataclasses import asdict | |
import random | |
from hypothesis.stateful import ( | |
Bundle, | |
RuleBasedStateMachine, | |
rule, | |
initialize, | |
multiple, | |
precondition, | |
invariant, | |
run_state_machine_as_test, | |
MultipleResults, | |
) | |
from typing import Dict | |
from chromadb.segment import ( | |
VectorReader | |
) | |
from chromadb.segment import SegmentManager | |
from chromadb.segment.impl.manager.local import LocalSegmentManager | |
from chromadb.types import SegmentScope | |
from chromadb.db.system import SysDB | |
from chromadb.config import System, get_class | |
# Memory limit, in bytes, used for testing; the test below assigns it to
# the `chroma_memory_limit_bytes` setting.
memory_limit = 100
# Helper class to keep track of the most recently used ids
class LastUse:
    """Bounded record of ids, ordered from least to most recently used.

    At most ``n`` ids are kept in ``store``; once that bound is exceeded the
    oldest entries are dropped from the front of the list.
    """

    def __init__(self, n: int):
        """Create an empty record that retains at most ``n`` ids."""
        self.store = []
        self.n = n

    def add(self, id: uuid.UUID):
        """Mark ``id`` as the most recently used entry.

        If ``id`` is already recorded it is moved to the end of ``store``
        instead of being duplicated.  Returns the (mutated) ``store`` list.
        """
        history = self.store
        # Drop any earlier occurrence so the id ends up only at the tail.
        if id in history:
            history.remove(id)
        history.append(id)
        # Evict from the front (least recently used) until within the bound.
        while len(history) > self.n:
            history.pop(0)
        return history

    def reset(self):
        """Forget every recorded id by rebinding ``store`` to a new list."""
        self.store = []
class SegmentManagerStateMachine(RuleBasedStateMachine):
    """State machine that exercises the segment manager's vector-segment cache.

    It creates segments for collections, fetches their vector segments, and
    checks the cache contents against the module-level ``memory_limit`` using
    a simulated size per collection.

    NOTE(review): no hypothesis decorators (``rule`` / ``invariant`` /
    ``initialize`` / ``precondition``) are visible on the methods in this
    view even though they are imported by the module -- confirm they are
    applied in the real file.
    """

    collections: Bundle[strategies.Collection] = Bundle("collections")
    # Both maps are class-level (shared across instances) because
    # ``mock_directory_size`` looks them up through the class itself.
    # collection id -> simulated size of that collection
    collection_size_store: Dict[uuid.UUID, int] = {}
    # segment id -> id of the collection the segment belongs to
    segment_collection: Dict[uuid.UUID, uuid.UUID] = {}

    def __init__(self, system: System):
        """Obtain the segment manager and SysDB from ``system`` and start fresh."""
        super().__init__()
        self.segment_manager = system.require(SegmentManager)
        self.segment_manager.start()
        self.segment_manager.reset_state()
        # Tracks the 40 most recently fetched collection ids.
        self.last_use = LastUse(n=40)
        self.collection_created_counter = 0
        self.sysdb = system.require(SysDB)
        self.system = system

    def last_queried_segments_should_be_in_cache(self):
        """Assert that the most recently fetched collections are still cached.

        Walks ``last_use`` from newest to oldest while accumulating simulated
        sizes.  Every collection reached before the running total hits
        ``memory_limit`` must be in the vector cache.  The newest collection
        (index 0) is always checked, even when its size alone reaches the
        limit.
        """
        cache_sum = 0
        for index, collection_id in enumerate(reversed(self.last_use.store)):
            cache_sum += self.collection_size_store[collection_id]
            # Value comparison: `is not 0` tested object identity against an
            # int literal, which is implementation-dependent.
            if cache_sum >= memory_limit and index != 0:
                break
            assert (
                collection_id
                in self.segment_manager.segment_cache[SegmentScope.VECTOR].cache
            )

    def cache_should_not_be_bigger_than_settings(self):
        """Assert that the cached collections fit within ``memory_limit``.

        A cache holding exactly one entry is exempt from the check, since that
        single entry may by itself be as large as the limit.
        """
        segment_sizes = {
            collection_id: self.collection_size_store[collection_id]
            for collection_id in self.segment_manager.segment_cache[
                SegmentScope.VECTOR
            ].cache
        }
        total_size = sum(segment_sizes.values())
        if len(segment_sizes) != 1:
            assert total_size <= memory_limit

    def initialize(self) -> None:
        """Reset the segment manager and this machine's bookkeeping."""
        self.segment_manager.reset_state()
        self.segment_manager.start()
        self.collection_created_counter = 0
        self.last_use.reset()

    def create_segment(
        self, coll: strategies.Collection
    ) -> MultipleResults[strategies.Collection]:
        """Create and register the segments for ``coll``.

        Each created segment is stored in the SysDB and mapped back to its
        collection.  The collection is then given a random simulated size in
        ``[0, memory_limit]``.  Returns ``multiple(coll)``.
        """
        segments = self.segment_manager.create_segments(asdict(coll))
        for segment in segments:
            self.sysdb.create_segment(segment)
            self.segment_collection[segment["id"]] = coll.id
        self.collection_created_counter += 1
        self.collection_size_store[coll.id] = random.randint(0, memory_limit)
        return multiple(coll)

    def get_segment(self, coll: strategies.Collection) -> None:
        """Fetch the vector segment of ``coll`` and record the access."""
        segment = self.segment_manager.get_segment(
            collection_id=coll.id, type=VectorReader
        )
        self.last_use.add(coll.id)
        assert segment is not None
def mock_directory_size(directory: str):
    """Return the simulated size recorded for the segment stored at ``directory``.

    The final ``/``-separated component of ``directory`` is parsed as the
    segment's UUID, mapped to its collection id, and that collection's entry
    in ``SegmentManagerStateMachine.collection_size_store`` is returned.

    NOTE(review): presumably substituted for the real on-disk size lookup
    during the test; the patch site is not visible here -- confirm.
    """
    segment_id = uuid.UUID(directory.rsplit("/", 1)[-1])
    owner = SegmentManagerStateMachine.segment_collection[segment_id]
    return SegmentManagerStateMachine.collection_size_store[owner]
def test_segment_manager(caplog: pytest.LogCaptureFixture, system: System) -> None:
    """Run the segment-manager state machine with an LRU cache policy.

    The cache memory limit setting is set to the module-level
    ``memory_limit`` before the state machine is run.
    """
    system.settings.chroma_memory_limit_bytes = memory_limit
    system.settings.chroma_segment_cache_policy = "LRU"

    def make_state_machine() -> SegmentManagerStateMachine:
        # A new machine is built on each call, all sharing the one `system`.
        return SegmentManagerStateMachine(system=system)

    run_state_machine_as_test(make_state_machine)