Spaces:
Sleeping
Sleeping
File size: 3,682 Bytes
287a0bc |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 |
import base64
import logging
from typing import Tuple, Any, cast
from overrides import override
from pydantic import SecretStr
from chromadb.auth import (
ServerAuthProvider,
ClientAuthProvider,
ServerAuthenticationRequest,
ServerAuthCredentialsProvider,
AuthInfoType,
BasicAuthCredentials,
ClientAuthCredentialsProvider,
ClientAuthResponse,
SimpleServerAuthenticationResponse,
)
from chromadb.auth.registry import register_provider, resolve_provider
from chromadb.config import System
from chromadb.telemetry.opentelemetry import (
OpenTelemetryGranularity,
trace_method,
)
from chromadb.utils import get_class
# Module-level logger, named after this module via ``__name__``.
logger = logging.getLogger(__name__)
# Names exported by a star-import of this module; the helper response class
# defined in this file is intentionally not listed.
__all__ = ["BasicAuthServerProvider", "BasicAuthClientProvider"]
class BasicAuthClientAuthResponse(ClientAuthResponse):
    """Client auth response that carries credentials as a Basic auth header."""

    def __init__(self, credentials: SecretStr) -> None:
        """Store the credentials token that is later placed in the header.

        Args:
            credentials: Secret token appended after the ``Basic `` prefix.
        """
        self._credentials = credentials

    @override
    def get_auth_info_type(self) -> AuthInfoType:
        """Report that the auth info is delivered as a header."""
        return AuthInfoType.HEADER

    @override
    def get_auth_info(self) -> Tuple[str, SecretStr]:
        """Return the header name and its secret-wrapped value.

        Returns:
            A ``("Authorization", SecretStr("Basic <token>"))`` pair.
        """
        token = self._credentials.get_secret_value()
        header_value = f"Basic {token}"
        # Re-wrap so the full header value stays masked like the raw token.
        return ("Authorization", SecretStr(header_value))
@register_provider("basic")
class BasicAuthClientProvider(ClientAuthProvider):
    """Client-side auth provider registered under the name ``"basic"``.

    Credentials come from the provider class named by the
    ``chroma_client_auth_credentials_provider`` setting.
    """

    _credentials_provider: ClientAuthCredentialsProvider[Any]

    def __init__(self, system: System) -> None:
        """Resolve and require the configured client credentials provider.

        Args:
            system: The system whose settings name the credentials provider
                and which supplies the provider instance.
        """
        super().__init__(system)
        self._settings = system.settings
        # The credentials provider setting must be present before it is read.
        system.settings.require("chroma_client_auth_credentials_provider")
        provider_name = str(system.settings.chroma_client_auth_credentials_provider)
        provider_cls = get_class(provider_name, ClientAuthCredentialsProvider)
        self._credentials_provider = system.require(provider_cls)

    @override
    def authenticate(self) -> ClientAuthResponse:
        """Build the auth response from the provider's current credentials.

        Returns:
            A ``BasicAuthClientAuthResponse`` holding the base64 encoding of
            the UTF-8 encoded credentials, wrapped in a ``SecretStr``.
        """
        secret = self._credentials_provider.get_credentials()
        raw_bytes = f"{secret.get_secret_value()}".encode("utf-8")
        encoded = base64.b64encode(raw_bytes).decode("utf-8")
        return BasicAuthClientAuthResponse(SecretStr(encoded))
@register_provider("basic")
class BasicAuthServerProvider(ServerAuthProvider):
    """Server-side auth provider registered under the name ``"basic"``.

    Credentials are parsed from the request's ``Authorization`` header and
    checked by the credentials provider named in the
    ``chroma_server_auth_credentials_provider`` setting.
    """

    _credentials_provider: ServerAuthCredentialsProvider

    def __init__(self, system: System) -> None:
        """Resolve and require the configured server credentials provider.

        Args:
            system: The system whose settings name the credentials provider
                and which supplies the provider instance.
        """
        super().__init__(system)
        self._settings = system.settings
        # The credentials provider setting must be present before it is read.
        system.settings.require("chroma_server_auth_credentials_provider")
        self._credentials_provider = cast(
            ServerAuthCredentialsProvider,
            system.require(
                resolve_provider(
                    str(system.settings.chroma_server_auth_credentials_provider),
                    ServerAuthCredentialsProvider,
                )
            ),
        )

    @trace_method("BasicAuthServerProvider.authenticate", OpenTelemetryGranularity.ALL)
    @override
    def authenticate(
        self, request: ServerAuthenticationRequest[Any]
    ) -> SimpleServerAuthenticationResponse:
        """Authenticate a request from its ``Authorization`` header.

        Args:
            request: The incoming request; its ``Authorization`` header is
                read via ``get_auth_info``.

        Returns:
            A response holding the validation result and the user identity
            reported by the credentials provider. If anything raises (for
            example while reading or parsing the header), the failure is
            logged and a ``(False, None)`` response is returned instead of
            propagating the exception.
        """
        try:
            _auth_header = request.get_auth_info(AuthInfoType.HEADER, "Authorization")
            # Parse the header once and reuse the result for both validation
            # and identity lookup instead of decoding it twice.
            _credentials = BasicAuthCredentials.from_header(_auth_header)
            _validation = self._credentials_provider.validate_credentials(
                _credentials
            )
            return SimpleServerAuthenticationResponse(
                _validation,
                self._credentials_provider.get_user_identity(_credentials),
            )
        except Exception as e:
            # Deliberately broad: any failure here means "not authenticated".
            # Lazy %-formatting renders the same message as before but only
            # when the record is actually emitted.
            logger.error("BasicAuthServerProvider.authenticate failed: %r", e)
            return SimpleServerAuthenticationResponse(False, None)
|