from typing import Any, Dict, List, Sequence, Set | |
from uuid import UUID | |
from chromadb.config import Settings, System | |
from chromadb.ingest import CollectionAssignmentPolicy, Consumer | |
from chromadb.proto.chroma_pb2_grpc import ( | |
# SegmentServerServicer, | |
# add_SegmentServerServicer_to_server, | |
VectorReaderServicer, | |
add_VectorReaderServicer_to_server, | |
) | |
import chromadb.proto.chroma_pb2 as proto | |
import grpc | |
from concurrent import futures | |
from chromadb.proto.convert import ( | |
to_proto_vector_embedding_record | |
) | |
from chromadb.segment import SegmentImplementation, SegmentType | |
from chromadb.telemetry.opentelemetry import ( | |
OpenTelemetryClient | |
) | |
from chromadb.types import EmbeddingRecord | |
from chromadb.segment.distributed import MemberlistProvider, Memberlist | |
from chromadb.utils.rendezvous_hash import assign, murmur3hasher | |
from chromadb.ingest.impl.pulsar_admin import PulsarAdmin | |
import logging | |
import os | |
# This file is a prototype. It will be replaced with a real distributed segment server | |
# written in a different language. This is just a proof of concept to get the distributed | |
# segment type working end to end. | |
# Run this with python -m chromadb.segment.impl.distributed.server | |
# Maps each distributed segment type to the dotted import path of the class
# that implements it. Only the distributed HNSW vector segment is wired up
# in this prototype.
SEGMENT_TYPE_IMPLS = {
    SegmentType.HNSW_DISTRIBUTED: "chromadb.segment.impl.vector.local_persistent_hnsw.PersistentLocalHnswSegment",
}
class SegmentServer(VectorReaderServicer):
    """Prototype distributed segment server.

    Uses rendezvous hashing over the current worker memberlist to decide
    which pulsar topics this node owns, keeps consumer subscriptions in
    sync with that assignment, and will eventually serve vector reads for
    the segments fed by those topics. Vector queries are not implemented
    yet (see QueryVectors).
    """

    # Cache of instantiated segment implementations, keyed by segment id.
    _segment_cache: Dict[UUID, SegmentImplementation] = {}
    _system: System
    _opentelemetry_client: OpenTelemetryClient
    _memberlist_provider: MemberlistProvider
    _assignment_policy: CollectionAssignmentPolicy
    _curr_memberlist: Memberlist
    _assigned_topics: Set[str]
    # Topic name -> subscription id returned by the consumer, needed to
    # unsubscribe when a topic is reassigned away from this node.
    _topic_to_subscription: Dict[str, UUID]
    _consumer: Consumer

    def __init__(self, system: System) -> None:
        """Resolve dependencies from the system, compute the initial topic
        assignment, and register for memberlist updates.

        Args:
            system: The running chromadb System used for dependency lookup.
        """
        super().__init__()
        self._system = system
        # Init dependency services
        self._opentelemetry_client = system.require(OpenTelemetryClient)
        # TODO: add term and epoch to segment server
        self._memberlist_provider = system.require(MemberlistProvider)
        self._memberlist_provider.set_memberlist_name("worker-memberlist")
        self._assignment_policy = system.require(CollectionAssignmentPolicy)
        # HACK: topics must exist before the consumer can subscribe to them.
        self._create_pulsar_topics()
        self._consumer = system.require(Consumer)
        # Init data
        self._topic_to_subscription = {}
        self._assigned_topics = set()
        self._curr_memberlist = self._memberlist_provider.get_memberlist()
        self._compute_assigned_topics()
        # Re-evaluate our assignment whenever the memberlist changes.
        self._memberlist_provider.register_updated_memberlist_callback(
            self._on_memberlist_update
        )

    def _compute_assigned_topics(self) -> None:
        """Recompute the topics this node owns via rendezvous hashing and
        reconcile consumer subscriptions (subscribe/unsubscribe) to match.

        An empty memberlist yields an empty assignment; existing
        subscriptions are still torn down through the same reconcile path
        so `_topic_to_subscription` never goes stale.
        """
        if not self._curr_memberlist:
            new_assignments: Set[str] = set()
        else:
            my_ip = os.environ["MY_POD_IP"]
            new_assignments = {
                topic
                for topic in self._assignment_policy.get_topics()
                if assign(topic, self._curr_memberlist, murmur3hasher) == my_ip
            }
        # TODO: We need to lock around this assignment
        net_new_assignments = new_assignments - self._assigned_topics
        removed_assignments = self._assigned_topics - new_assignments
        for topic in removed_assignments:
            subscription = self._topic_to_subscription.pop(topic)
            self._consumer.unsubscribe(subscription)
        for topic in net_new_assignments:
            self._topic_to_subscription[topic] = self._consumer.subscribe(
                topic, self._on_message
            )
        self._assigned_topics = new_assignments
        logging.getLogger(__name__).info(
            "Topic assignment updated and now assigned to %d topics",
            len(self._assigned_topics),
        )

    def _on_memberlist_update(self, memberlist: Memberlist) -> None:
        """Called when the memberlist is updated."""
        self._curr_memberlist = memberlist
        if len(self._curr_memberlist) > 0:
            self._compute_assigned_topics()
        else:
            # Not an error: the cluster may just be starting up.
            logging.getLogger(__name__).warning("Memberlist is empty")

    def _on_message(self, embedding_records: Sequence[EmbeddingRecord]) -> None:
        """Called when a message is received from the consumer."""
        logger = logging.getLogger(__name__)
        logger.info("Received %d records", len(embedding_records))
        # Guard against an empty batch before peeking at the first record.
        if embedding_records:
            logger.info(
                "First record: %s is for collection %s",
                embedding_records[0],
                embedding_records[0]["collection_id"],
            )
        return None

    def _create_pulsar_topics(self) -> None:
        """This creates the pulsar topics used by the system.
        HACK: THIS IS COMPLETELY A HACK AND WILL BE REPLACED
        BY A PROPER TOPIC MANAGEMENT SYSTEM IN THE COORDINATOR"""
        topics = self._assignment_policy.get_topics()
        admin = PulsarAdmin(self._system)
        for topic in topics:
            admin.create_topic(topic)

    def QueryVectors(
        self, request: proto.QueryVectorsRequest, context: Any
    ) -> proto.QueryVectorsResponse:
        """gRPC handler for vector queries. Not implemented yet; always
        responds with UNIMPLEMENTED."""
        context.set_code(grpc.StatusCode.UNIMPLEMENTED)
        context.set_details("Query segment not implemented yet")
        return proto.QueryVectorsResponse()
if __name__ == "__main__":
    # Stand up the prototype segment server: build the system, attach the
    # VectorReader servicer to a thread-pool gRPC server, and block forever.
    logging.basicConfig(level=logging.INFO)
    system = System(Settings())
    grpc_server = grpc.server(futures.ThreadPoolExecutor(max_workers=10))
    servicer = SegmentServer(system)
    # add_SegmentServerServicer_to_server(servicer, grpc_server) # type: ignore
    add_VectorReaderServicer_to_server(servicer, grpc_server)  # type: ignore
    grpc_port = system.settings.require("chroma_server_grpc_port")
    grpc_server.add_insecure_port(f"[::]:{grpc_port}")
    # Start system components before accepting traffic, then park the
    # main thread until the server terminates.
    system.start()
    grpc_server.start()
    grpc_server.wait_for_termination()