File size: 7,756 Bytes
287a0bc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
from typing import Any, Dict, List, Sequence, Set
from uuid import UUID
from chromadb.config import Settings, System
from chromadb.ingest import CollectionAssignmentPolicy, Consumer
from chromadb.proto.chroma_pb2_grpc import (
    # SegmentServerServicer,
    # add_SegmentServerServicer_to_server,
    VectorReaderServicer,
    add_VectorReaderServicer_to_server,
)
import chromadb.proto.chroma_pb2 as proto
import grpc
from concurrent import futures
from chromadb.proto.convert import (
    to_proto_vector_embedding_record
)
from chromadb.segment import SegmentImplementation, SegmentType
from chromadb.telemetry.opentelemetry import (
    OpenTelemetryClient
)
from chromadb.types import EmbeddingRecord
from chromadb.segment.distributed import MemberlistProvider, Memberlist
from chromadb.utils.rendezvous_hash import assign, murmur3hasher
from chromadb.ingest.impl.pulsar_admin import PulsarAdmin
import logging
import os

# This file is a prototype. It will be replaced with a real distributed segment server
# written in a different language. This is just a proof of concept to get the distributed
# segment type working end to end.

# Run this with python -m chromadb.segment.impl.distributed.server

# Maps a segment type to the dotted path of the class that implements it.
# Only the distributed HNSW segment type is registered here.
SEGMENT_TYPE_IMPLS: Dict[SegmentType, str] = {
    SegmentType.HNSW_DISTRIBUTED: (
        "chromadb.segment.impl.vector.local_persistent_hnsw"
        ".PersistentLocalHnswSegment"
    ),
}


class SegmentServer(VectorReaderServicer):
    """Prototype gRPC servicer for the distributed segment type.

    On construction the server creates the Pulsar topics reported by the
    collection assignment policy, reads the current memberlist and subscribes
    to every topic that rendezvous hashing assigns to this node. The node is
    identified by the ``MY_POD_IP`` environment variable. Assignments are
    recomputed each time the memberlist provider reports a non-empty
    memberlist.
    """

    # Segment instances keyed by segment id. Declared here and created per
    # instance in __init__ so that separate servers never share one dict.
    _segment_cache: Dict[UUID, SegmentImplementation]
    _system: System
    _opentelemetry_client: OpenTelemetryClient
    _memberlist_provider: MemberlistProvider
    _assignment_policy: CollectionAssignmentPolicy
    _curr_memberlist: Memberlist
    # Topics this node is currently responsible for.
    _assigned_topics: Set[str]
    # Subscription id returned by the consumer for each subscribed topic.
    _topic_to_subscription: Dict[str, UUID]
    _consumer: Consumer

    def __init__(self, system: System) -> None:
        """Wire up dependencies and compute the initial topic assignment.

        Args:
            system: The system from which dependency services are required.
        """
        super().__init__()
        self._system = system
        self._segment_cache = {}

        # Init dependency services
        self._opentelemetry_client = system.require(OpenTelemetryClient)
        # TODO: add term and epoch to segment server
        self._memberlist_provider = system.require(MemberlistProvider)
        self._memberlist_provider.set_memberlist_name("worker-memberlist")
        self._assignment_policy = system.require(CollectionAssignmentPolicy)
        self._create_pulsar_topics()
        self._consumer = system.require(Consumer)

        # Init data
        self._topic_to_subscription = {}
        self._assigned_topics = set()
        self._curr_memberlist = self._memberlist_provider.get_memberlist()
        self._compute_assigned_topics()

        # Registered last so the callback only runs on fully initialised state.
        self._memberlist_provider.register_updated_memberlist_callback(
            self._on_memberlist_update
        )

    def _unsubscribe_topics(self, topics: Set[str]) -> None:
        """Unsubscribe from each topic in ``topics`` and forget its subscription."""
        for topic in topics:
            subscription = self._topic_to_subscription[topic]
            self._consumer.unsubscribe(subscription)
            del self._topic_to_subscription[topic]

    def _compute_assigned_topics(self) -> None:
        """Uses rendezvous hashing to compute the topics that this node is responsible for

        Topics no longer assigned to this node are unsubscribed and newly
        assigned topics are subscribed with ``_on_message`` as the callback.

        Raises:
            KeyError: If the ``MY_POD_IP`` environment variable is not set
                while the memberlist is non-empty.
        """
        if not self._curr_memberlist:
            # With no members nothing can be assigned to this node. Release
            # any subscription still held; otherwise it would be leaked and
            # the topic subscribed a second time on the next assignment.
            self._unsubscribe_topics(set(self._topic_to_subscription))
            self._assigned_topics = set()
            return

        topics = self._assignment_policy.get_topics()
        my_ip = os.environ["MY_POD_IP"]
        new_assignments_set = {
            topic
            for topic in topics
            if assign(topic, self._curr_memberlist, murmur3hasher) == my_ip
        }
        # TODO: We need to lock around this assignment
        net_new_assignments = new_assignments_set - self._assigned_topics
        removed_assignments = self._assigned_topics - new_assignments_set

        self._unsubscribe_topics(removed_assignments)

        for topic in net_new_assignments:
            subscription = self._consumer.subscribe(topic, self._on_message)
            self._topic_to_subscription[topic] = subscription

        self._assigned_topics = new_assignments_set
        print(
            f"Topic assigment updated and now assigned to {len(self._assigned_topics)} topics"
        )

    def _on_memberlist_update(self, memberlist: Memberlist) -> None:
        """Called when the memberlist is updated

        Stores the new memberlist and recomputes the topic assignment when it
        has at least one member. An empty memberlist leaves the current
        assignment untouched.
        """
        self._curr_memberlist = memberlist
        if self._curr_memberlist:
            self._compute_assigned_topics()
        else:
            # In this case we'd want to warn that there are no members but
            # this is not an error, as it could be that the cluster is just starting up
            print("Memberlist is empty")

    def _on_message(self, embedding_records: Sequence[EmbeddingRecord]) -> None:
        """Called when a message is received from the consumer

        Prints the batch size and, when the batch is not empty, the first
        record together with its collection id.
        """
        print(f"Received {len(embedding_records)} records")
        if not embedding_records:
            # An empty batch has no first record to report.
            return None
        print(
            f"First record: {embedding_records[0]} is for collection {embedding_records[0]['collection_id']}"
        )
        return None

    def _create_pulsar_topics(self) -> None:
        """This creates the pulsar topics used by the system.
        HACK: THIS IS COMPLETELY A HACK AND WILL BE REPLACED
        BY A PROPER TOPIC MANAGEMENT SYSTEM IN THE COORDINATOR"""
        topics = self._assignment_policy.get_topics()
        admin = PulsarAdmin(self._system)
        for topic in topics:
            admin.create_topic(topic)

    def QueryVectors(
        self, request: proto.QueryVectorsRequest, context: Any
    ) -> proto.QueryVectorsResponse:
        """Placeholder RPC handler: reports UNIMPLEMENTED and returns an empty response."""
        context.set_code(grpc.StatusCode.UNIMPLEMENTED)
        context.set_details("Query segment not implemented yet")
        return proto.QueryVectorsResponse()

    # @trace_method(
    #     "SegmentServer.GetVectors", OpenTelemetryGranularity.OPERATION_AND_SEGMENT
    # )
    # def GetVectors(
    #     self, request: proto.GetVectorsRequest, context: Any
    # ) -> proto.GetVectorsResponse:
    #     segment_id = UUID(hex=request.segment_id)
    #     if segment_id not in self._segment_cache:
    #         context.set_code(grpc.StatusCode.NOT_FOUND)
    #         context.set_details("Segment not found")
    #         return proto.GetVectorsResponse()
    #     else:
    #         segment = self._segment_cache[segment_id]
    #         segment = cast(VectorReader, segment)
    #         segment_results = segment.get_vectors(request.ids)
    #         return_records = []
    #         for record in segment_results:
    #             # TODO: encoding should be based on stored encoding for segment
    #             # For now we just assume float32
    #             return_record = to_proto_vector_embedding_record(
    #                 record, ScalarEncoding.FLOAT32
    #             )
    #             return_records.append(return_record)
    #         return proto.GetVectorsResponse(records=return_records)

    # def _cls(self, segment: Segment) -> Type[SegmentImplementation]:
    #     classname = SEGMENT_TYPE_IMPLS[SegmentType(segment["type"])]
    #     cls = get_class(classname, SegmentImplementation)
    #     return cls

    # def _create_instance(self, segment: Segment) -> None:
    #     if segment["id"] not in self._segment_cache:
    #         cls = self._cls(segment)
    #         instance = cls(self._system, segment)
    #         instance.start()
    #         self._segment_cache[segment["id"]] = instance


if __name__ == "__main__":
    # Script entry point: build the system, expose the segment server over
    # gRPC on the configured port, then block until the server terminates.
    logging.basicConfig(level=logging.INFO)

    system = System(Settings())
    worker_pool = futures.ThreadPoolExecutor(max_workers=10)
    grpc_server = grpc.server(worker_pool)

    servicer = SegmentServer(system)
    # add_SegmentServerServicer_to_server(segment_server, server)  # type: ignore
    add_VectorReaderServicer_to_server(servicer, grpc_server)  # type: ignore

    grpc_port = system.settings.require('chroma_server_grpc_port')
    grpc_server.add_insecure_port(f"[::]:{grpc_port}")

    # Start the system components before the server begins accepting requests.
    system.start()
    grpc_server.start()
    grpc_server.wait_for_termination()