# Commit 287a0bc (author: badalsahani): "feat: chroma initial deploy"
from typing import Any, Dict, List, Sequence, Set
from uuid import UUID
from chromadb.config import Settings, System
from chromadb.ingest import CollectionAssignmentPolicy, Consumer
from chromadb.proto.chroma_pb2_grpc import (
# SegmentServerServicer,
# add_SegmentServerServicer_to_server,
VectorReaderServicer,
add_VectorReaderServicer_to_server,
)
import chromadb.proto.chroma_pb2 as proto
import grpc
from concurrent import futures
from chromadb.proto.convert import (
to_proto_vector_embedding_record
)
from chromadb.segment import SegmentImplementation, SegmentType
from chromadb.telemetry.opentelemetry import (
OpenTelemetryClient
)
from chromadb.types import EmbeddingRecord
from chromadb.segment.distributed import MemberlistProvider, Memberlist
from chromadb.utils.rendezvous_hash import assign, murmur3hasher
from chromadb.ingest.impl.pulsar_admin import PulsarAdmin
import logging
import os
# This file is a prototype. It will be replaced with a real distributed segment server
# written in a different language. This is just a proof of concept to get the distributed
# segment type working end to end.
# Run this with python -m chromadb.segment.impl.distributed.server
# Maps a segment type to the dotted import path of the class implementing it.
SEGMENT_TYPE_IMPLS: Dict[SegmentType, str] = {
    SegmentType.HNSW_DISTRIBUTED: (
        "chromadb.segment.impl.vector.local_persistent_hnsw."
        "PersistentLocalHnswSegment"
    ),
}
class SegmentServer(VectorReaderServicer):
    """Prototype gRPC vector-reader server for distributed segments.

    On construction it creates the Pulsar topics reported by the collection
    assignment policy. It then uses rendezvous hashing over the current
    memberlist to decide which of those topics this node subscribes to.
    The assignment is recomputed whenever the memberlist provider reports
    an update. Query RPCs are not implemented yet.
    """

    _segment_cache: Dict[UUID, SegmentImplementation]
    _system: System
    _opentelemetry_client: OpenTelemetryClient
    _memberlist_provider: MemberlistProvider
    _assignment_policy: CollectionAssignmentPolicy
    _curr_memberlist: Memberlist
    _assigned_topics: Set[str]
    _topic_to_subscription: Dict[str, UUID]
    _consumer: Consumer

    def __init__(self, system: System) -> None:
        """Resolve dependencies from *system*, create topics and subscribe.

        Args:
            system: Component system used to obtain the OpenTelemetry client,
                the memberlist provider, the collection assignment policy and
                the consumer.

        Raises:
            KeyError: If the initial memberlist is non-empty and the
                ``MY_POD_IP`` environment variable is not set.
        """
        super().__init__()
        self._system = system
        # Per-instance cache: a class-level ``= {}`` default would be one dict
        # shared by every SegmentServer instance.
        self._segment_cache = {}

        # Init dependency services
        self._opentelemetry_client = system.require(OpenTelemetryClient)
        # TODO: add term and epoch to segment server
        self._memberlist_provider = system.require(MemberlistProvider)
        self._memberlist_provider.set_memberlist_name("worker-memberlist")
        self._assignment_policy = system.require(CollectionAssignmentPolicy)
        self._create_pulsar_topics()
        self._consumer = system.require(Consumer)

        # Init data
        self._topic_to_subscription = {}
        self._assigned_topics = set()
        self._curr_memberlist = self._memberlist_provider.get_memberlist()
        self._compute_assigned_topics()
        self._memberlist_provider.register_updated_memberlist_callback(
            self._on_memberlist_update
        )

    def _compute_assigned_topics(self) -> None:
        """Use rendezvous hashing to compute the topics this node owns.

        Each topic from the assignment policy is hashed against the current
        memberlist with ``murmur3hasher``. A topic belongs to this node when
        the winning member equals this pod's IP (env var ``MY_POD_IP``).
        Subscriptions are then reconciled against the previous assignment.

        Raises:
            KeyError: If the memberlist is non-empty and ``MY_POD_IP`` is unset.
        """
        if not self._curr_memberlist:
            # No members: own nothing. Release live subscriptions instead of
            # only clearing the set. Otherwise they leak and get subscribed a
            # second time on the next non-empty assignment.
            self._apply_assignments(set())
            return
        topics = self._assignment_policy.get_topics()
        my_ip = os.environ["MY_POD_IP"]
        new_assignments = {
            topic
            for topic in topics
            if assign(topic, self._curr_memberlist, murmur3hasher) == my_ip
        }
        self._apply_assignments(new_assignments)
        print(
            f"Topic assigment updated and now assigned to {len(self._assigned_topics)} topics"
        )

    def _apply_assignments(self, new_assignments: Set[str]) -> None:
        """Unsubscribe dropped topics, subscribe new ones, and record the set.

        Args:
            new_assignments: The complete set of topics this node should own.
        """
        # TODO: We need to lock around this assignment
        removed_assignments = self._assigned_topics - new_assignments
        net_new_assignments = new_assignments - self._assigned_topics
        for topic in removed_assignments:
            subscription = self._topic_to_subscription[topic]
            self._consumer.unsubscribe(subscription)
            del self._topic_to_subscription[topic]
        for topic in net_new_assignments:
            subscription = self._consumer.subscribe(topic, self._on_message)
            self._topic_to_subscription[topic] = subscription
        self._assigned_topics = new_assignments

    def _on_memberlist_update(self, memberlist: Memberlist) -> None:
        """Memberlist-provider callback: store *memberlist* and reassign topics.

        An empty memberlist is reported and otherwise ignored, so existing
        subscriptions are kept. This case is not treated as an error because
        the cluster may simply be starting up.
        """
        self._curr_memberlist = memberlist
        if self._curr_memberlist:
            self._compute_assigned_topics()
        else:
            # In this case we'd want to warn that there are no members but
            # this is not an error, as it could be that the cluster is just starting up
            print("Memberlist is empty")

    def _on_message(self, embedding_records: Sequence[EmbeddingRecord]) -> None:
        """Consumer callback. Currently it only prints a summary of the batch."""
        print(f"Received {len(embedding_records)} records")
        if not embedding_records:
            # Nothing more to report; indexing [0] would raise IndexError.
            return None
        first = embedding_records[0]
        print(
            f"First record: {first} is for collection {first['collection_id']}"
        )
        return None

    def _create_pulsar_topics(self) -> None:
        """Create, via PulsarAdmin, every topic returned by the assignment policy.

        HACK: THIS IS COMPLETELY A HACK AND WILL BE REPLACED
        BY A PROPER TOPIC MANAGEMENT SYSTEM IN THE COORDINATOR
        """
        topics = self._assignment_policy.get_topics()
        admin = PulsarAdmin(self._system)
        for topic in topics:
            admin.create_topic(topic)

    def QueryVectors(
        self, request: proto.QueryVectorsRequest, context: Any
    ) -> proto.QueryVectorsResponse:
        """Not implemented yet.

        Sets UNIMPLEMENTED on *context* and returns an empty response.
        """
        context.set_code(grpc.StatusCode.UNIMPLEMENTED)
        context.set_details("Query segment not implemented yet")
        return proto.QueryVectorsResponse()

    # @trace_method(
    #     "SegmentServer.GetVectors", OpenTelemetryGranularity.OPERATION_AND_SEGMENT
    # )
    # def GetVectors(
    #     self, request: proto.GetVectorsRequest, context: Any
    # ) -> proto.GetVectorsResponse:
    #     segment_id = UUID(hex=request.segment_id)
    #     if segment_id not in self._segment_cache:
    #         context.set_code(grpc.StatusCode.NOT_FOUND)
    #         context.set_details("Segment not found")
    #         return proto.GetVectorsResponse()
    #     else:
    #         segment = self._segment_cache[segment_id]
    #         segment = cast(VectorReader, segment)
    #         segment_results = segment.get_vectors(request.ids)
    #         return_records = []
    #         for record in segment_results:
    #             # TODO: encoding should be based on stored encoding for segment
    #             # For now we just assume float32
    #             return_record = to_proto_vector_embedding_record(
    #                 record, ScalarEncoding.FLOAT32
    #             )
    #             return_records.append(return_record)
    #         return proto.GetVectorsResponse(records=return_records)

    # def _cls(self, segment: Segment) -> Type[SegmentImplementation]:
    #     classname = SEGMENT_TYPE_IMPLS[SegmentType(segment["type"])]
    #     cls = get_class(classname, SegmentImplementation)
    #     return cls

    # def _create_instance(self, segment: Segment) -> None:
    #     if segment["id"] not in self._segment_cache:
    #         cls = self._cls(segment)
    #         instance = cls(self._system, segment)
    #         instance.start()
    #         self._segment_cache[segment["id"]] = instance
if __name__ == "__main__":
    # Entry point: build the system, expose SegmentServer as the gRPC
    # VectorReader service on `chroma_server_grpc_port`, and serve until the
    # gRPC server terminates.
    logging.basicConfig(level=logging.INFO)
    system = System(Settings())
    server = grpc.server(futures.ThreadPoolExecutor(max_workers=10))
    segment_server = SegmentServer(system)
    # add_SegmentServerServicer_to_server(segment_server, server)  # type: ignore
    add_VectorReaderServicer_to_server(segment_server, server)  # type: ignore
    server.add_insecure_port(
        f"[::]:{system.settings.require('chroma_server_grpc_port')}"
    )
    system.start()
    server.start()
    try:
        server.wait_for_termination()
    finally:
        # Shut down cleanly on exit (including Ctrl-C). Stop accepting RPCs
        # first, then stop the components (consumer, memberlist watcher, ...)
        # that handlers and callbacks depend on.
        server.stop(grace=None)
        system.stop()