Spaces:
Sleeping
Sleeping
from typing import Any, Callable, Dict, Optional, cast | |
from overrides import EnforceOverrides, override | |
from chromadb.config import System | |
from chromadb.segment.distributed import ( | |
Memberlist, | |
MemberlistProvider, | |
SegmentDirectory, | |
) | |
from chromadb.types import Segment | |
from kubernetes import client, config, watch | |
from kubernetes.client.rest import ApiException | |
import threading | |
from chromadb.utils.rendezvous_hash import assign, murmur3hasher | |
# These could go in the config, but given that they will rarely change they are kept here
# for now to avoid polluting the config file further.

# Passed as `timeout_seconds` to every watch stream, so the stream returns periodically
# and the watch loop gets a chance to notice a stop request.
WATCH_TIMEOUT_SECONDS = 60
# Namespace used for every memberlist custom-resource call (get, patch and watch).
KUBERNETES_NAMESPACE = "chroma"
# API group used for every memberlist custom-resource call (get, patch and watch).
KUBERNETES_GROUP = "chroma.cluster"
class MockMemberlistProvider(MemberlistProvider, EnforceOverrides):
    """A mock memberlist provider for testing.

    The memberlist lives in memory only. `update_memberlist` plays the role of
    a change to the k8s custom resource.
    """

    _memberlist: Memberlist

    def __init__(self, system: System):
        super().__init__(system)
        # Fixed starting membership so tests have something to hash against.
        self._memberlist = ["a", "b", "c"]

    # NOTE: EnforceOverrides rejects, at class-creation time, any method that
    # overrides a base-class method without being marked with @override.
    @override
    def get_memberlist(self) -> Memberlist:
        """Return the current in-memory memberlist."""
        return self._memberlist

    @override
    def set_memberlist_name(self, memberlist: str) -> None:
        """Do nothing: the mock provider does not need a memberlist name."""

    def update_memberlist(self, memberlist: Memberlist) -> None:
        """Replace the memberlist and call every registered callback with it.

        This mocks an update coming from a k8s custom resource.
        """
        self._memberlist = memberlist
        # `self.callbacks` is not defined in this class; presumably it is
        # maintained by MemberlistProvider's (un)register methods.
        for callback in self.callbacks:
            callback(memberlist)
class CustomResourceMemberlistProvider(MemberlistProvider, EnforceOverrides):
    """A memberlist provider that uses a k8s custom resource to store the memberlist.

    The memberlist is fetched lazily by `get_memberlist` and afterwards kept up
    to date by a background thread that watches the custom resource. Every
    watch event replaces the cached memberlist and is forwarded to the
    registered callbacks.
    """

    # Pause before re-opening the watch after a non-410 API error, so that a
    # persistent failure does not degenerate into a busy retry loop.
    _WATCH_RETRY_BACKOFF_SECONDS: float = 1.0

    _kubernetes_api: client.CustomObjectsApi
    _memberlist_name: Optional[str]
    _curr_memberlist: Optional[Memberlist]
    _curr_memberlist_mutex: threading.Lock
    _watch_thread: Optional[threading.Thread]
    _kill_watch_thread: threading.Event

    def __init__(self, system: System):
        super().__init__(system)
        config.load_config()
        self._kubernetes_api = client.CustomObjectsApi()
        self._watch_thread = None
        self._memberlist_name = None
        self._curr_memberlist = None
        self._curr_memberlist_mutex = threading.Lock()
        self._kill_watch_thread = threading.Event()

    # NOTE: EnforceOverrides rejects, at class-creation time, any method that
    # overrides a base-class method without being marked with @override.
    @override
    def start(self) -> None:
        """Prime the cached memberlist and start the background watch thread.

        Raises:
            ValueError: if `set_memberlist_name` has not been called yet.
            Exception: if a watch thread is already running.
        """
        if self._memberlist_name is None:
            raise ValueError("Memberlist name must be set before starting")
        self.get_memberlist()
        self._watch_worker_memberlist()
        return super().start()

    @override
    def stop(self) -> None:
        """Stop the watch thread, then forget the memberlist and its name.

        May block until the watch stream currently open returns; streams are
        opened with a WATCH_TIMEOUT_SECONDS timeout.
        """
        # Stop the watch thread *before* clearing any state. Clearing first
        # would let the still-running thread re-open the watch with a `None`
        # name, or repopulate the cache after it was cleared and leave a stale
        # memberlist behind for the next start().
        self._kill_watch_thread.set()
        if self._watch_thread is not None:
            self._watch_thread.join()
            self._watch_thread = None
        self._kill_watch_thread.clear()
        self._curr_memberlist = None
        self._memberlist_name = None
        return super().stop()

    @override
    def reset_state(self) -> None:
        """Patch the memberlist custom resource so that it has no members.

        Does nothing to the resource when no memberlist name is set.

        Raises:
            ValueError: if the `allow_reset` setting is not truthy.
        """
        if not self._system.settings.require("allow_reset"):
            raise ValueError(
                "Resetting the database is not allowed. Set `allow_reset` to true in the config in tests or other non-production environments where reset should be permitted."
            )
        if self._memberlist_name:
            self._kubernetes_api.patch_namespaced_custom_object(
                group=KUBERNETES_GROUP,
                version="v1",
                namespace=KUBERNETES_NAMESPACE,
                plural="memberlists",
                name=self._memberlist_name,
                body={
                    "kind": "MemberList",
                    "spec": {"members": []},
                },
            )

    @override
    def get_memberlist(self) -> Memberlist:
        """Return the cached memberlist, fetching it from k8s on first use."""
        if self._curr_memberlist is None:
            self._curr_memberlist = self._fetch_memberlist()
        return self._curr_memberlist

    @override
    def set_memberlist_name(self, memberlist: str) -> None:
        """Set the name of the memberlist custom resource to read and watch."""
        self._memberlist_name = memberlist

    def _fetch_memberlist(self) -> Memberlist:
        """Read the custom resource once and return the member urls it lists.

        A resource without a spec (missing or null) yields an empty memberlist.
        """
        api_response = self._kubernetes_api.get_namespaced_custom_object(
            group=KUBERNETES_GROUP,
            version="v1",
            namespace=KUBERNETES_NAMESPACE,
            plural="memberlists",
            name=f"{self._memberlist_name}",
        )
        api_response = cast(Dict[str, Any], api_response)
        response_spec = api_response.get("spec")
        if not response_spec:
            return []
        return self._parse_response_memberlist(cast(Dict[str, Any], response_spec))

    def _watch_worker_memberlist(self) -> None:
        """Start a daemon thread that mirrors the custom resource into the cache.

        Raises:
            Exception: if a watch thread is already running.
        """
        # TODO: We may want to make this watch function a library function that can be used by other
        # components that need to watch k8s custom resources.
        if self._watch_thread is not None:
            raise Exception("A watch thread is already running.")

        def run_watch() -> None:
            w = watch.Watch()

            def do_watch() -> None:
                for event in w.stream(
                    self._kubernetes_api.list_namespaced_custom_object,
                    group=KUBERNETES_GROUP,
                    version="v1",
                    namespace=KUBERNETES_NAMESPACE,
                    plural="memberlists",
                    field_selector=f"metadata.name={self._memberlist_name}",
                    timeout_seconds=WATCH_TIMEOUT_SECONDS,
                ):
                    # Once a stop was requested, do not apply or publish
                    # any further updates.
                    if self._kill_watch_thread.is_set():
                        return
                    event = cast(Dict[str, Any], event)
                    # Same rule as _fetch_memberlist: an object without a
                    # spec is an empty memberlist, not a reason to let the
                    # watch thread die on a KeyError.
                    response_spec = cast(
                        Dict[str, Any], event["object"].get("spec") or {}
                    )
                    memberlist = self._parse_response_memberlist(response_spec)
                    with self._curr_memberlist_mutex:
                        self._curr_memberlist = memberlist
                    # Notify with the value just parsed instead of re-reading
                    # the shared attribute outside the lock.
                    self._notify(memberlist)

            # Watch the custom resource for changes
            # Watch with a timeout and retry so we can gracefully stop this if needed
            while not self._kill_watch_thread.is_set():
                try:
                    do_watch()
                except ApiException as e:
                    # If status code is 410, the watch has expired and we need to start a new one.
                    if e.status == 410:
                        continue
                    # Other API errors are retried as well (best effort), but
                    # only after a short pause. Waiting on the kill event keeps
                    # stop() responsive during the pause.
                    self._kill_watch_thread.wait(self._WATCH_RETRY_BACKOFF_SECONDS)

        thread = threading.Thread(target=run_watch, daemon=True)
        thread.start()
        self._watch_thread = thread

    def _parse_response_memberlist(
        self, api_response_spec: Dict[str, Any]
    ) -> Memberlist:
        """Return the `url` of every entry under the spec's `members` key.

        A missing, null or empty `members` value yields an empty list.
        """
        members = api_response_spec.get("members") or []
        return [m["url"] for m in members]

    def _notify(self, memberlist: Memberlist) -> None:
        """Call every registered callback with `memberlist`."""
        # `self.callbacks` is not defined in this class; presumably it is
        # maintained by MemberlistProvider's (un)register methods.
        for callback in self.callbacks:
            callback(memberlist)
class RendezvousHashSegmentDirectory(SegmentDirectory, EnforceOverrides):
    """Segment directory that picks a member for each segment by rendezvous hashing.

    The memberlist comes from the required MemberlistProvider: it is read once
    in `start` and then replaced whenever the provider reports an update.
    """

    _memberlist_provider: MemberlistProvider
    _curr_memberlist_mutex: threading.Lock
    _curr_memberlist: Optional[Memberlist]

    def __init__(self, system: System):
        super().__init__(system)
        self._memberlist_provider = self.require(MemberlistProvider)
        memberlist_name = system.settings.require("worker_memberlist_name")
        self._memberlist_provider.set_memberlist_name(memberlist_name)
        self._curr_memberlist = None
        self._curr_memberlist_mutex = threading.Lock()

    # NOTE: EnforceOverrides rejects, at class-creation time, any method that
    # overrides a base-class method without being marked with @override.
    @override
    def start(self) -> None:
        """Load the current memberlist and subscribe to provider updates."""
        self._curr_memberlist = self._memberlist_provider.get_memberlist()
        self._memberlist_provider.register_updated_memberlist_callback(
            self._update_memberlist
        )
        return super().start()

    @override
    def stop(self) -> None:
        """Unsubscribe from provider updates."""
        self._memberlist_provider.unregister_updated_memberlist_callback(
            self._update_memberlist
        )
        return super().stop()

    @override
    def get_segment_endpoint(self, segment: Segment) -> str:
        """Return the `member:port` endpoint assigned to `segment`.

        Raises:
            ValueError: if no memberlist has been loaded yet or it is empty.
        """
        # Take a single snapshot under the lock. `_update_memberlist` may
        # replace the list from another thread between the emptiness check
        # and the hash assignment.
        with self._curr_memberlist_mutex:
            memberlist = self._curr_memberlist
        if not memberlist:
            raise ValueError("Memberlist is not initialized")
        assignment = assign(segment["id"].hex, memberlist, murmur3hasher)
        return f"{assignment}:50051"  # TODO: make port configurable

    @override
    def register_updated_segment_callback(
        self, callback: Callable[[Segment], None]
    ) -> None:
        """Not supported by this directory; always raises NotImplementedError."""
        raise NotImplementedError()

    def _update_memberlist(self, memberlist: Memberlist) -> None:
        """Provider callback: replace the cached memberlist."""
        with self._curr_memberlist_mutex:
            self._curr_memberlist = memberlist