File size: 4,397 Bytes
287a0bc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
import uuid

import pytest
import chromadb.test.property.strategies as strategies
from unittest.mock import patch
from dataclasses import asdict
import random
from hypothesis.stateful import (
    Bundle,
    RuleBasedStateMachine,
    rule,
    initialize,
    multiple,
    precondition,
    invariant,
    run_state_machine_as_test,
    MultipleResults,
)
from typing import Dict
from chromadb.segment import (
    VectorReader
)
from chromadb.segment import SegmentManager

from chromadb.segment.impl.manager.local import LocalSegmentManager
from chromadb.types import SegmentScope
from chromadb.db.system import SysDB
from chromadb.config import System, get_class

# Memory limit used for testing: applied as `chroma_memory_limit_bytes` by the
# test below and used as the upper bound for the mocked collection sizes.
memory_limit = 100

# Helper class to keep track of the most recently used ids
class LastUse:
    """Bounded record of ids, ordered from least to most recently used."""

    def __init__(self, n: int):
        # ``n`` caps how many ids are remembered at once.
        self.store = []
        self.n = n

    def add(self, id: uuid.UUID):
        """Mark ``id`` as the most recently used entry and return the record.

        An id that is already tracked is moved to the end. A new id is
        appended, after which the oldest entries are dropped until at most
        ``n`` remain.
        """
        is_new = id not in self.store
        if not is_new:
            self.store.remove(id)
        self.store.append(id)
        if is_new:
            # Only a newly added id can push the record over its capacity.
            while self.n < len(self.store):
                del self.store[0]
        return self.store

    def reset(self):
        """Forget every tracked id."""
        self.store = []


class SegmentManagerStateMachine(RuleBasedStateMachine):
    """Stateful property test for the segment manager's vector-segment cache.

    Rules create collections (each given a random mocked on-disk size) and
    fetch their vector segments. Invariants then compare the contents of
    ``segment_manager.segment_cache[SegmentScope.VECTOR].cache`` against a
    model of recent use kept in ``self.last_use``.
    """

    collections: Bundle[strategies.Collection]
    collections = Bundle("collections")
    # Both mappings live on the class (not the instance) because the static
    # ``mock_directory_size`` reads them through the class.
    # Collection id -> mocked size.
    collection_size_store: Dict[uuid.UUID, int] = {}
    # Segment id -> id of the collection the segment was created for.
    segment_collection: Dict[uuid.UUID, uuid.UUID] = {}

    def __init__(self, system: System):
        """Start and reset the segment manager and set up the usage model.

        Args:
            system: The system from which ``SegmentManager`` and ``SysDB``
                are obtained via ``require``.
        """
        super().__init__()
        self.segment_manager = system.require(SegmentManager)
        self.segment_manager.start()
        self.segment_manager.reset_state()
        self.last_use = LastUse(n=40)
        self.collection_created_counter = 0
        self.sysdb = system.require(SysDB)
        self.system = system

    @invariant()
    def last_queried_segments_should_be_in_cache(self):
        """Assert that the most recently fetched collections are still cached.

        Walks ``last_use`` from most to least recent while accumulating the
        mocked sizes. The most recent entry must always be cached; older
        entries are only required while the running total stays below
        ``memory_limit``.
        """
        cache_sum = 0
        for index, collection_id in enumerate(reversed(self.last_use.store)):
            cache_sum += self.collection_size_store[collection_id]
            # Value comparison, not identity: ``is not 0`` relies on CPython's
            # small-int caching and emits a SyntaxWarning on Python 3.8+.
            if cache_sum >= memory_limit and index != 0:
                break
            assert (
                collection_id
                in self.segment_manager.segment_cache[SegmentScope.VECTOR].cache
            )

    @invariant()
    @precondition(lambda self: self.system.settings.is_persistent is True)
    def cache_should_not_be_bigger_than_settings(self):
        """Assert that the cached entries' mocked sizes fit in ``memory_limit``.

        Only evaluated when the system settings are persistent. A cache that
        holds exactly one entry is not checked.
        """
        vector_cache = self.segment_manager.segment_cache[SegmentScope.VECTOR].cache
        segment_sizes = {
            cached_id: self.collection_size_store[cached_id]
            for cached_id in vector_cache
        }
        total_size = sum(segment_sizes.values())
        if len(segment_sizes) != 1:
            assert total_size <= memory_limit

    @initialize()
    def initialize(self) -> None:
        """Reset the segment manager and the usage model before rules run."""
        self.segment_manager.reset_state()
        self.segment_manager.start()
        self.collection_created_counter = 0
        self.last_use.reset()

    @rule(target=collections, coll=strategies.collections())
    @precondition(lambda self: self.collection_created_counter <= 50)
    def create_segment(
            self, coll: strategies.Collection
    ) -> MultipleResults[strategies.Collection]:
        """Create segments for ``coll`` and give the collection a mocked size.

        Every segment returned by the segment manager is registered with the
        sysdb and mapped back to ``coll.id``. The collection is then assigned
        a random size in ``[0, memory_limit]``. The rule is only enabled while
        ``collection_created_counter <= 50``.

        Returns:
            The collection, added to the ``collections`` bundle.
        """
        segments = self.segment_manager.create_segments(asdict(coll))
        for segment in segments:
            self.sysdb.create_segment(segment)
            self.segment_collection[segment["id"]] = coll.id
        self.collection_created_counter += 1
        self.collection_size_store[coll.id] = random.randint(0, memory_limit)
        return multiple(coll)

    @rule(coll=collections)
    def get_segment(self, coll: strategies.Collection) -> None:
        """Fetch the vector segment of ``coll`` and record the access."""
        segment = self.segment_manager.get_segment(collection_id=coll.id, type=VectorReader)
        self.last_use.add(coll.id)
        assert segment is not None

    @staticmethod
    def mock_directory_size(directory: str):
        """Return the mocked size of the collection that owns ``directory``.

        The last ``/``-separated component of ``directory`` is parsed as a
        segment UUID, mapped to its collection via ``segment_collection``,
        and that collection's entry in ``collection_size_store`` is returned.
        NOTE(review): assumes ``/`` path separators and no trailing slash.
        """
        path_id = directory.split("/").pop()
        collection_id = SegmentManagerStateMachine.segment_collection[uuid.UUID(path_id)]
        return SegmentManagerStateMachine.collection_size_store[collection_id]


@patch(
    "chromadb.segment.impl.manager.local.get_directory_size",
    SegmentManagerStateMachine.mock_directory_size,
)
def test_segment_manager(caplog: pytest.LogCaptureFixture, system: System) -> None:
    """Run the segment-manager state machine with an LRU cache policy.

    Directory sizes seen by the local segment manager are replaced by the
    state machine's mocked sizes, and the memory limit setting is set to
    ``memory_limit``.
    """
    settings = system.settings
    settings.chroma_memory_limit_bytes = memory_limit
    settings.chroma_segment_cache_policy = "LRU"

    def build_machine() -> SegmentManagerStateMachine:
        # Factory handed to Hypothesis so it can create the machine itself.
        return SegmentManagerStateMachine(system=system)

    run_state_machine_as_test(build_machine)