from typing import Any, Callable, Dict, Optional, cast from overrides import EnforceOverrides, override from chromadb.config import System from chromadb.segment.distributed import ( Memberlist, MemberlistProvider, SegmentDirectory, ) from chromadb.types import Segment from kubernetes import client, config, watch from kubernetes.client.rest import ApiException import threading from chromadb.utils.rendezvous_hash import assign, murmur3hasher # These could go in config but given that they will rarely change, they are here for now to avoid # polluting the config file further. WATCH_TIMEOUT_SECONDS = 60 KUBERNETES_NAMESPACE = "chroma" KUBERNETES_GROUP = "chroma.cluster" class MockMemberlistProvider(MemberlistProvider, EnforceOverrides): """A mock memberlist provider for testing""" _memberlist: Memberlist def __init__(self, system: System): super().__init__(system) self._memberlist = ["a", "b", "c"] @override def get_memberlist(self) -> Memberlist: return self._memberlist @override def set_memberlist_name(self, memberlist: str) -> None: pass # The mock provider does not need to set the memberlist name def update_memberlist(self, memberlist: Memberlist) -> None: """Updates the memberlist and calls all registered callbacks. This mocks an update from a k8s CR""" self._memberlist = memberlist for callback in self.callbacks: callback(memberlist) class CustomResourceMemberlistProvider(MemberlistProvider, EnforceOverrides): """A memberlist provider that uses a k8s custom resource to store the memberlist""" _kubernetes_api: client.CustomObjectsApi _memberlist_name: Optional[str] _curr_memberlist: Optional[Memberlist] _curr_memberlist_mutex: threading.Lock _watch_thread: Optional[threading.Thread] _kill_watch_thread: threading.Event def __init__(self, system: System): super().__init__(system) config.load_config() self._kubernetes_api = client.CustomObjectsApi() self._watch_thread = None self._memberlist_name = None self._curr_memberlist = None self._curr_memberlist_mutex = threading.Lock() self._kill_watch_thread = threading.Event() @override def start(self) -> None: if self._memberlist_name is None: raise ValueError("Memberlist name must be set before starting") self.get_memberlist() self._watch_worker_memberlist() return super().start() @override def stop(self) -> None: self._curr_memberlist = None self._memberlist_name = None # Stop the watch thread self._kill_watch_thread.set() if self._watch_thread is not None: self._watch_thread.join() self._watch_thread = None self._kill_watch_thread.clear() return super().stop() @override def reset_state(self) -> None: if not self._system.settings.require("allow_reset"): raise ValueError( "Resetting the database is not allowed. Set `allow_reset` to true in the config in tests or other non-production environments where reset should be permitted." ) if self._memberlist_name: self._kubernetes_api.patch_namespaced_custom_object( group=KUBERNETES_GROUP, version="v1", namespace=KUBERNETES_NAMESPACE, plural="memberlists", name=self._memberlist_name, body={ "kind": "MemberList", "spec": {"members": []}, }, ) @override def get_memberlist(self) -> Memberlist: if self._curr_memberlist is None: self._curr_memberlist = self._fetch_memberlist() return self._curr_memberlist @override def set_memberlist_name(self, memberlist: str) -> None: self._memberlist_name = memberlist def _fetch_memberlist(self) -> Memberlist: api_response = self._kubernetes_api.get_namespaced_custom_object( group=KUBERNETES_GROUP, version="v1", namespace=KUBERNETES_NAMESPACE, plural="memberlists", name=f"{self._memberlist_name}", ) api_response = cast(Dict[str, Any], api_response) if "spec" not in api_response: return [] response_spec = cast(Dict[str, Any], api_response["spec"]) return self._parse_response_memberlist(response_spec) def _watch_worker_memberlist(self) -> None: # TODO: We may want to make this watch function a library function that can be used by other # components that need to watch k8s custom resources. def run_watch() -> None: w = watch.Watch() def do_watch() -> None: for event in w.stream( self._kubernetes_api.list_namespaced_custom_object, group=KUBERNETES_GROUP, version="v1", namespace=KUBERNETES_NAMESPACE, plural="memberlists", field_selector=f"metadata.name={self._memberlist_name}", timeout_seconds=WATCH_TIMEOUT_SECONDS, ): event = cast(Dict[str, Any], event) response_spec = event["object"]["spec"] response_spec = cast(Dict[str, Any], response_spec) with self._curr_memberlist_mutex: self._curr_memberlist = self._parse_response_memberlist( response_spec ) self._notify(self._curr_memberlist) # Watch the custom resource for changes # Watch with a timeout and retry so we can gracefully stop this if needed while not self._kill_watch_thread.is_set(): try: do_watch() except ApiException as e: # If status code is 410, the watch has expired and we need to start a new one. if e.status == 410: pass return if self._watch_thread is None: thread = threading.Thread(target=run_watch, daemon=True) thread.start() self._watch_thread = thread else: raise Exception("A watch thread is already running.") def _parse_response_memberlist( self, api_response_spec: Dict[str, Any] ) -> Memberlist: if "members" not in api_response_spec: return [] return [m["url"] for m in api_response_spec["members"]] def _notify(self, memberlist: Memberlist) -> None: for callback in self.callbacks: callback(memberlist) class RendezvousHashSegmentDirectory(SegmentDirectory, EnforceOverrides): _memberlist_provider: MemberlistProvider _curr_memberlist_mutex: threading.Lock _curr_memberlist: Optional[Memberlist] def __init__(self, system: System): super().__init__(system) self._memberlist_provider = self.require(MemberlistProvider) memberlist_name = system.settings.require("worker_memberlist_name") self._memberlist_provider.set_memberlist_name(memberlist_name) self._curr_memberlist = None self._curr_memberlist_mutex = threading.Lock() @override def start(self) -> None: self._curr_memberlist = self._memberlist_provider.get_memberlist() self._memberlist_provider.register_updated_memberlist_callback( self._update_memberlist ) return super().start() @override def stop(self) -> None: self._memberlist_provider.unregister_updated_memberlist_callback( self._update_memberlist ) return super().stop() @override def get_segment_endpoint(self, segment: Segment) -> str: if self._curr_memberlist is None or len(self._curr_memberlist) == 0: raise ValueError("Memberlist is not initialized") assignment = assign(segment["id"].hex, self._curr_memberlist, murmur3hasher) assignment = f"{assignment}:50051" # TODO: make port configurable return assignment @override def register_updated_segment_callback( self, callback: Callable[[Segment], None] ) -> None: raise NotImplementedError() def _update_memberlist(self, memberlist: Memberlist) -> None: with self._curr_memberlist_mutex: self._curr_memberlist = memberlist