"""Property-based tests for token authentication against the FastAPI server.

Each test starts a server/client pair through ``_fastapi_fixture`` using
Hypothesis-generated auth settings and checks that valid tokens are accepted
while malformed or mismatching tokens are rejected.
"""
import string
from typing import Dict, Any

import hypothesis.strategies as st
import pytest
from hypothesis import given, settings

from chromadb.api import ServerAPI
from chromadb.config import System
from chromadb.test.conftest import _fastapi_fixture


@st.composite
def token_config(draw: st.DrawFn) -> Dict[str, Any]:
    """Draw a matching client/server token-auth configuration.

    The same generated token (1-50 characters taken from digits, ASCII
    letters and punctuation) is used for both the server and the client
    credentials, so a configuration drawn here is expected to authenticate.

    Returns:
        A dict with the transport header (or ``None``), the server and client
        credentials, the provider names for both sides, the server
        credentials provider and an ``is_persistent`` flag.
    """
    token_header = draw(st.sampled_from(["AUTHORIZATION", "X_CHROMA_TOKEN", None]))
    # Providers are drawn both as short alias and as fully qualified class path.
    server_provider = draw(
        st.sampled_from(["token", "chromadb.auth.token.TokenAuthServerProvider"])
    )
    client_provider = draw(
        st.sampled_from(["token", "chromadb.auth.token.TokenAuthClientProvider"])
    )
    server_credentials_provider = draw(
        st.sampled_from(
            ["chromadb.auth.token.TokenConfigServerAuthCredentialsProvider"]
        )
    )
    token = draw(
        st.text(
            alphabet=string.digits + string.ascii_letters + string.punctuation,
            min_size=1,
            max_size=50,
        )
    )
    persistence = draw(st.booleans())
    return {
        "token_transport_header": token_header,
        "chroma_server_auth_credentials": token,
        "chroma_client_auth_credentials": token,
        "chroma_server_auth_provider": server_provider,
        "chroma_client_auth_provider": client_provider,
        "chroma_server_auth_credentials_provider": server_credentials_provider,
        "is_persistent": persistence,
    }


@settings(max_examples=10)
@given(token_config())
def test_fastapi_server_token_auth(token_config: Dict[str, Any]) -> None:
    """A client presenting the server's own token can use the API."""
    api = _fastapi_fixture(
        is_persistent=token_config["is_persistent"],
        chroma_server_auth_provider=token_config["chroma_server_auth_provider"],
        chroma_server_auth_credentials_provider=token_config[
            "chroma_server_auth_credentials_provider"
        ],
        chroma_server_auth_credentials=token_config["chroma_server_auth_credentials"],
        chroma_client_auth_provider=token_config["chroma_client_auth_provider"],
        chroma_client_auth_token_transport_header=token_config[
            "token_transport_header"
        ],
        chroma_server_auth_token_transport_header=token_config[
            "token_transport_header"
        ],
        chroma_client_auth_credentials=token_config["chroma_client_auth_credentials"],
    )
    # NOTE(review): the fixture iterator is advanced once and never resumed;
    # presumably any teardown it performs after yielding is skipped - confirm
    # against _fastapi_fixture.
    _sys: System = next(api)
    _sys.reset_state()
    _api = _sys.instance(ServerAPI)
    _api.heartbeat()
    assert _api.list_collections() == []


@st.composite
def random_token(draw: st.DrawFn) -> str:
    """Draw a short (1-5 character) alphanumeric string."""
    return draw(
        st.text(alphabet=string.ascii_letters + string.digits, min_size=1, max_size=5)
    )


@st.composite
def invalid_token(draw: st.DrawFn) -> str:
    """Draw a token built only from characters outside the valid alphabet.

    The alphabet is every character of ``string.printable`` that is not a
    digit, ASCII letter or punctuation character.
    """
    opposite_alphabet = set(string.printable) - set(
        string.digits + string.ascii_letters + string.punctuation
    )
    # Sort the set so the alphabet order does not depend on string hashing;
    # a stable order keeps Hypothesis examples reproducible between runs.
    token = draw(st.text(alphabet=sorted(opposite_alphabet), min_size=1, max_size=50))
    return token


@settings(max_examples=10)
@given(tconf=token_config(), inval_tok=invalid_token())
def test_invalid_token(tconf: Dict[str, Any], inval_tok: str) -> None:
    """Setting up a client with a malformed token fails with "Invalid token"."""
    api = _fastapi_fixture(
        is_persistent=tconf["is_persistent"],
        chroma_server_auth_provider=tconf["chroma_server_auth_provider"],
        chroma_server_auth_credentials_provider=tconf[
            "chroma_server_auth_credentials_provider"
        ],
        chroma_server_auth_credentials=tconf["chroma_server_auth_credentials"],
        chroma_server_auth_token_transport_header=tconf["token_transport_header"],
        chroma_client_auth_provider=tconf["chroma_client_auth_provider"],
        chroma_client_auth_token_transport_header=tconf["token_transport_header"],
        chroma_client_auth_credentials=inval_tok,
    )
    with pytest.raises(Exception) as e:
        _sys: System = next(api)
        _sys.reset_state()
        _sys.instance(ServerAPI)
    # Inspect the raised exception itself: str() of the ExceptionInfo wrapper
    # is only a (possibly truncated) repr, not the exception message.
    assert "Invalid token" in str(e.value)


@settings(max_examples=10)
@given(token_config(), random_token())
def test_fastapi_server_token_auth_wrong_token(
    token_config: Dict[str, Any], random_token: str
) -> None:
    """A client whose token differs from the server's is rejected.

    The client token is the server token with a random suffix appended, so it
    is well-formed but never equal to the server credentials. ``heartbeat`` is
    still expected to succeed; ``list_collections`` must raise "Unauthorized".
    """
    api = _fastapi_fixture(
        is_persistent=token_config["is_persistent"],
        chroma_server_auth_provider=token_config["chroma_server_auth_provider"],
        chroma_server_auth_credentials_provider=token_config[
            "chroma_server_auth_credentials_provider"
        ],
        chroma_server_auth_credentials=token_config["chroma_server_auth_credentials"],
        chroma_server_auth_token_transport_header=token_config[
            "token_transport_header"
        ],
        chroma_client_auth_provider=token_config["chroma_client_auth_provider"],
        chroma_client_auth_token_transport_header=token_config[
            "token_transport_header"
        ],
        chroma_client_auth_credentials=token_config["chroma_client_auth_credentials"]
        + random_token,
    )
    _sys: System = next(api)
    _sys.reset_state()
    _api = _sys.instance(ServerAPI)
    _api.heartbeat()
    with pytest.raises(Exception) as e:
        _api.list_collections()
    # Check the exception message, not the repr of the ExceptionInfo wrapper.
    assert "Unauthorized" in str(e.value)