"""HTTP Basic authentication providers for Chroma clients and servers."""
import base64
import logging
from typing import Tuple, Any, cast

from overrides import override
from pydantic import SecretStr

from chromadb.auth import (
    ServerAuthProvider,
    ClientAuthProvider,
    ServerAuthenticationRequest,
    ServerAuthCredentialsProvider,
    AuthInfoType,
    BasicAuthCredentials,
    ClientAuthCredentialsProvider,
    ClientAuthResponse,
    SimpleServerAuthenticationResponse,
)
from chromadb.auth.registry import register_provider, resolve_provider
from chromadb.config import System
from chromadb.telemetry.opentelemetry import (
    OpenTelemetryGranularity,
    trace_method,
)
from chromadb.utils import get_class

logger = logging.getLogger(__name__)

__all__ = ["BasicAuthServerProvider", "BasicAuthClientProvider"]


class BasicAuthClientAuthResponse(ClientAuthResponse):
    """Client auth response that yields an ``Authorization: Basic ...`` header."""

    def __init__(self, credentials: SecretStr) -> None:
        """Store the token that is placed after the ``Basic `` prefix.

        Args:
            credentials: Secret-wrapped token. ``BasicAuthClientProvider``
                passes the base64-encoded credentials here.
        """
        self._credentials = credentials

    @override
    def get_auth_info_type(self) -> AuthInfoType:
        """Return ``AuthInfoType.HEADER``: the auth info is a request header."""
        return AuthInfoType.HEADER

    @override
    def get_auth_info(self) -> Tuple[str, SecretStr]:
        """Return the header name and its secret-wrapped ``Basic <token>`` value."""
        return "Authorization", SecretStr(
            f"Basic {self._credentials.get_secret_value()}"
        )


@register_provider("basic")
class BasicAuthClientProvider(ClientAuthProvider):
    """Client-side provider, registered as ``"basic"``, that builds Basic auth headers."""

    _credentials_provider: ClientAuthCredentialsProvider[Any]

    def __init__(self, system: System) -> None:
        """Look up the configured client credentials provider.

        Args:
            system: Chroma system whose settings name the credentials
                provider class in ``chroma_client_auth_credentials_provider``.
        """
        super().__init__(system)
        self._settings = system.settings
        system.settings.require("chroma_client_auth_credentials_provider")
        # NOTE(review): BasicAuthServerProvider resolves its credentials
        # provider with resolve_provider() while this class uses get_class();
        # confirm whether that difference is intentional.
        self._credentials_provider = system.require(
            get_class(
                str(system.settings.chroma_client_auth_credentials_provider),
                ClientAuthCredentialsProvider,
            )
        )

    @override
    def authenticate(self) -> ClientAuthResponse:
        """Return a response holding the base64-encoded credentials.

        Returns:
            A ``BasicAuthClientAuthResponse`` wrapping the UTF-8 credentials
            encoded with base64.
        """
        secret = self._credentials_provider.get_credentials()
        # The f-string coerces the secret value to text before encoding.
        raw = f"{secret.get_secret_value()}".encode("utf-8")
        token = base64.b64encode(raw).decode("utf-8")
        return BasicAuthClientAuthResponse(SecretStr(token))


@register_provider("basic")
class BasicAuthServerProvider(ServerAuthProvider):
    """Server-side provider, registered as ``"basic"``, that checks Basic auth headers."""

    _credentials_provider: ServerAuthCredentialsProvider

    def __init__(self, system: System) -> None:
        """Look up the configured server credentials provider.

        Args:
            system: Chroma system whose settings name the credentials
                provider in ``chroma_server_auth_credentials_provider``.
        """
        super().__init__(system)
        self._settings = system.settings
        system.settings.require("chroma_server_auth_credentials_provider")
        self._credentials_provider = cast(
            ServerAuthCredentialsProvider,
            system.require(
                resolve_provider(
                    str(system.settings.chroma_server_auth_credentials_provider),
                    ServerAuthCredentialsProvider,
                )
            ),
        )

    @trace_method("BasicAuthServerProvider.authenticate", OpenTelemetryGranularity.ALL)
    @override
    def authenticate(
        self, request: ServerAuthenticationRequest[Any]
    ) -> SimpleServerAuthenticationResponse:
        """Validate the ``Authorization`` header of an incoming request.

        Args:
            request: Request from which the ``Authorization`` header is read.

        Returns:
            A response carrying the validation result and the user identity
            reported by the credentials provider. If any step raises, the
            error is logged and a ``(False, None)`` response is returned
            instead of propagating the exception.
        """
        try:
            auth_header = request.get_auth_info(AuthInfoType.HEADER, "Authorization")
            # Parse the header once and reuse it for validation and identity.
            credentials = BasicAuthCredentials.from_header(auth_header)
            is_valid = self._credentials_provider.validate_credentials(credentials)
            return SimpleServerAuthenticationResponse(
                is_valid,
                self._credentials_provider.get_user_identity(credentials),
            )
        except Exception as e:
            # Deliberate catch-all: any failure is treated as "not authenticated".
            logger.error("BasicAuthServerProvider.authenticate failed: %r", e)
            return SimpleServerAuthenticationResponse(False, None)