Spaces:
Configuration error
Configuration error
File size: 6,362 Bytes
447ebeb |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 |
import logging
import os
from unittest.mock import Mock, patch
import pytest
from litellm.secret_managers.main import get_secret
# Module-wide logging: put the root logger at DEBUG (basicConfig is a no-op if
# the root logger already has handlers) and create this module's logger, which
# MockHTTPHandler and the tests use for debug traces.
logging.basicConfig(level="DEBUG")
logger = logging.getLogger(name=__name__)
# Test doubles: MockHTTPHandler (below) stands in for litellm's HTTPHandler; the mock_oidc_cache fixture stands in for oidc_cache.
class MockHTTPHandler:
    """Stand-in for litellm's ``HTTPHandler`` that returns a canned response.

    Tests install an instance as the patched ``HTTPHandler``'s return value and
    may tweak ``status_code`` / ``text`` / ``json_data`` before calling
    ``get_secret`` to simulate different replies. The arguments of the most
    recent ``get`` call are recorded so tests can assert on them (for example
    the ``audience`` query parameter).
    """

    def __init__(self, timeout):
        self.timeout = timeout
        # Canned response fields; tests mutate these to simulate failures.
        self.status_code = 200
        self.text = "mocked_token"
        self.json_data = {"value": "mocked_token"}
        # Initialised up front so reading them before any ``get`` call yields
        # None instead of raising AttributeError.
        self.last_url = None
        self.last_params = None
        self.last_headers = None

    def get(self, url, params=None, headers=None):
        """Record the request and return a ``Mock`` response built from the canned fields.

        Args:
            url: Requested URL (recorded in ``last_url``).
            params: Query parameters (recorded in ``last_params``).
            headers: Request headers (recorded in ``last_headers``).

        Returns:
            A ``Mock`` whose ``status_code``/``text`` mirror this handler's
            fields and whose ``json()`` returns ``json_data``.
        """
        self.last_url = url
        self.last_params = params
        self.last_headers = headers
        logger.debug(
            "MockHTTPHandler.get called with url=%s, params=%s, headers=%s",
            url,
            params,
            headers,
        )
        mock_response = Mock()
        mock_response.status_code = self.status_code
        mock_response.text = self.text
        mock_response.json.return_value = self.json_data
        return mock_response
@pytest.fixture
def mock_oidc_cache():
    """Provide a ``Mock`` oidc_cache that ``get_secret`` actually uses.

    The mock is patched over ``litellm.secret_managers.main.oidc_cache`` (the
    same target the ``@patch``-decorated tests use) for the duration of the
    test. Without the patch, configuring it would have no effect on
    ``get_secret``. It defaults to a cache miss; ``set_cache`` and other
    methods are auto-created ``Mock`` attributes available for assertions.
    """
    cache = Mock()
    cache.get_cache.return_value = None  # default: cache miss
    with patch("litellm.secret_managers.main.oidc_cache", cache):
        yield cache
@pytest.fixture
def mock_env():
    """Give the test an empty ``os.environ`` that is restored afterwards.

    Yields ``os.environ`` itself, so tests assign variables on it directly.
    """
    with patch.dict(os.environ, clear=True):
        yield os.environ
@patch("litellm.secret_managers.main.oidc_cache")
@patch("litellm.secret_managers.main.HTTPHandler")
def test_oidc_google_success(mock_http_handler, mock_oidc_cache):
    """A fresh Google OIDC token is fetched for the audience and then cached.

    ``@patch`` decorators apply bottom-up, so ``mock_http_handler`` is the
    patched ``HTTPHandler`` class and ``mock_oidc_cache`` the patched cache.
    """
    mock_oidc_cache.get_cache.return_value = None  # force a cache miss
    mock_handler = MockHTTPHandler(timeout=600.0)
    mock_http_handler.return_value = mock_handler
    secret_name = "oidc/google/https://example.com/api"
    result = get_secret(secret_name)
    assert result == "mocked_token"
    # Everything after "oidc/google/" is expected as the ``audience`` parameter.
    assert mock_handler.last_params == {"audience": "https://example.com/api"}
    # 3540 s is the TTL this test expects get_secret to cache Google tokens with.
    mock_oidc_cache.set_cache.assert_called_once_with(
        key=secret_name, value="mocked_token", ttl=3540
    )
@patch("litellm.secret_managers.main.oidc_cache")
def test_oidc_google_cached(mock_oidc_cache):
    """A cached Google OIDC token is returned without constructing an HTTP client."""
    mock_oidc_cache.get_cache.return_value = "cached_token"
    secret_name = "oidc/google/https://example.com/api"
    # Patch the HTTPHandler name get_secret actually resolves (the same target
    # the other tests patch). Patching ``litellm.HTTPHandler`` instead would
    # leave the real handler in place and make ``assert_not_called`` vacuous.
    with patch("litellm.secret_managers.main.HTTPHandler") as mock_http:
        result = get_secret(secret_name)
    assert result == "cached_token", f"Expected cached token, got {result}"
    mock_oidc_cache.get_cache.assert_called_with(key=secret_name)
    mock_http.assert_not_called()
def test_oidc_google_failure(mock_oidc_cache):
    """A non-200 reply from the Google token request surfaces as ValueError."""
    mock_oidc_cache.get_cache.return_value = None
    failing_handler = MockHTTPHandler(timeout=600.0)
    failing_handler.status_code = 400
    audience_secret = "oidc/google/https://example.com/api"
    patched_http = patch(
        "litellm.secret_managers.main.HTTPHandler", return_value=failing_handler
    )
    with patched_http, pytest.raises(ValueError, match="Google OIDC provider failed"):
        get_secret(audience_secret)
def test_oidc_circleci_success(monkeypatch):
    """The CircleCI provider returns the value of CIRCLE_OIDC_TOKEN."""
    expected_token = "circleci_token"
    monkeypatch.setenv("CIRCLE_OIDC_TOKEN", expected_token)
    assert get_secret("oidc/circleci/test-audience") == expected_token
def test_oidc_circleci_failure(monkeypatch):
    """Without CIRCLE_OIDC_TOKEN the CircleCI provider raises ValueError."""
    monkeypatch.delenv("CIRCLE_OIDC_TOKEN", raising=False)
    expected_message = "CIRCLE_OIDC_TOKEN not found in environment"
    with pytest.raises(ValueError, match=expected_message):
        get_secret("oidc/circleci/test-audience")
@patch("litellm.secret_managers.main.oidc_cache")
@patch("litellm.secret_managers.main.HTTPHandler")
def test_oidc_github_success(mock_http_handler, mock_oidc_cache, mock_env):
    """GitHub OIDC: request a token for the audience, then cache it.

    The two ``@patch`` mocks fill the first two parameters (bottom-up order);
    ``mock_env`` is the empty-environment fixture.
    """
    mock_env.update(
        {
            "ACTIONS_ID_TOKEN_REQUEST_URL": "https://github.com/token",
            "ACTIONS_ID_TOKEN_REQUEST_TOKEN": "github_token",
        }
    )
    mock_oidc_cache.get_cache.return_value = None
    handler = MockHTTPHandler(timeout=600.0)
    mock_http_handler.return_value = handler
    secret_name = "oidc/github/github-audience"

    token = get_secret(secret_name)

    assert token == "mocked_token", f"Expected token 'mocked_token', got {token}"
    assert handler.last_params == {"audience": "github-audience"}
    logger.debug(f"set_cache call args: {mock_oidc_cache.set_cache.call_args}")
    # Exactly one cache write, with the 295 s TTL this test expects for GitHub tokens.
    mock_oidc_cache.set_cache.assert_called_once_with(
        key=secret_name, value="mocked_token", ttl=295
    )
def test_oidc_github_missing_env(monkeypatch):
    """GitHub OIDC lookup raises ValueError when the Actions token env vars are absent."""
    # Remove the variables explicitly rather than trusting the ambient environment.
    # They are set, for example, in a GitHub Actions job with OIDC enabled; the
    # lookup would then attempt a token request instead of raising.
    monkeypatch.delenv("ACTIONS_ID_TOKEN_REQUEST_URL", raising=False)
    monkeypatch.delenv("ACTIONS_ID_TOKEN_REQUEST_TOKEN", raising=False)
    secret_name = "oidc/github/github-audience"
    with pytest.raises(
        ValueError,
        match="ACTIONS_ID_TOKEN_REQUEST_URL or ACTIONS_ID_TOKEN_REQUEST_TOKEN not found in environment",
    ):
        get_secret(secret_name)
def test_oidc_azure_file_success(mock_env, tmp_path):
    """Azure OIDC returns the contents of the file named by AZURE_FEDERATED_TOKEN_FILE."""
    federated_file = tmp_path / "token.txt"
    federated_file.write_text("azure_token")
    mock_env["AZURE_FEDERATED_TOKEN_FILE"] = str(federated_file)
    assert get_secret("oidc/azure/azure-audience") == "azure_token"
@patch("litellm.secret_managers.main.get_azure_ad_token_provider")
def test_oidc_azure_ad_token_success(mock_get_azure_ad_token_provider, monkeypatch):
    """Without a federated token file, Azure OIDC uses the AD token provider.

    The audience after "oidc/azure/" is passed as ``azure_scope``, and the
    returned provider is called with no arguments to obtain the token.
    """
    # With AZURE_FEDERATED_TOKEN_FILE set, get_secret returns that file's contents
    # instead (see test_oidc_azure_file_success). Clear it so an ambient value
    # on the test machine cannot divert this test onto the file path.
    monkeypatch.delenv("AZURE_FEDERATED_TOKEN_FILE", raising=False)
    mock_token_provider = Mock(return_value="azure_ad_token")
    mock_get_azure_ad_token_provider.return_value = mock_token_provider
    secret_name = "oidc/azure/api://azure-audience"
    result = get_secret(secret_name)
    assert result == "azure_ad_token"
    mock_get_azure_ad_token_provider.assert_called_once_with(
        azure_scope="api://azure-audience"
    )
    mock_token_provider.assert_called_once_with()
def test_oidc_file_success(tmp_path):
    """The ``file`` provider returns the contents of the path after ``oidc/file/``."""
    path = tmp_path / "token.txt"
    path.write_text("file_token")
    assert get_secret(f"oidc/file/{path}") == "file_token"
def test_oidc_env_success(mock_env):
    """The ``env`` provider returns the value of the named environment variable."""
    variable, value = "CUSTOM_TOKEN", "env_token"
    mock_env[variable] = value
    assert get_secret(f"oidc/env/{variable}") == value
def test_oidc_env_path_success(mock_env, tmp_path):
    """The ``env_path`` provider reads the file whose path the named variable holds."""
    path_file = tmp_path / "token.txt"
    path_file.write_text("env_path_token")
    variable = "TOKEN_PATH"
    mock_env[variable] = str(path_file)
    assert get_secret(f"oidc/env_path/{variable}") == "env_path_token"
def test_unsupported_oidc_provider():
    """An unknown provider segment after ``oidc/`` is rejected with ValueError."""
    expected_message = "Unsupported OIDC provider"
    with pytest.raises(ValueError, match=expected_message):
        get_secret("oidc/unsupported/unsupported-audience")
|