Spaces:
Sleeping
Sleeping
import ast | |
import asyncio | |
import json | |
import uuid | |
from base64 import b64encode | |
from datetime import datetime | |
from typing import Dict, List, Optional, Tuple, Union | |
from urllib.parse import parse_qs, urlencode, urlparse | |
import httpx | |
from fastapi import ( | |
APIRouter, | |
Depends, | |
HTTPException, | |
Request, | |
Response, | |
UploadFile, | |
status, | |
) | |
from fastapi.responses import StreamingResponse | |
from starlette.datastructures import UploadFile as StarletteUploadFile | |
import litellm | |
from litellm._logging import verbose_proxy_logger | |
from litellm.integrations.custom_logger import CustomLogger | |
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj | |
from litellm.litellm_core_utils.safe_json_dumps import safe_dumps | |
from litellm.llms.custom_httpx.http_handler import get_async_httpx_client | |
from litellm.proxy._types import ( | |
ConfigFieldInfo, | |
ConfigFieldUpdate, | |
PassThroughEndpointResponse, | |
PassThroughGenericEndpoint, | |
ProxyException, | |
UserAPIKeyAuth, | |
) | |
from litellm.proxy.auth.user_api_key_auth import user_api_key_auth | |
from litellm.proxy.common_request_processing import ProxyBaseLLMRequestProcessing | |
from litellm.proxy.common_utils.http_parsing_utils import _read_request_body | |
from litellm.secret_managers.main import get_secret_str | |
from litellm.types.llms.custom_http import httpxSpecialProvider | |
from litellm.types.passthrough_endpoints.pass_through_endpoints import ( | |
EndpointType, | |
PassthroughStandardLoggingPayload, | |
) | |
from litellm.types.utils import StandardLoggingUserAPIKeyMetadata | |
from .streaming_handler import PassThroughStreamingHandler | |
from .success_handler import PassThroughEndpointLogging | |
# Module-level FastAPI router instance.
router = APIRouter()
# Shared pass-through logging helper: handed to the streaming chunk processor
# and used for the non-streaming success callback in `pass_through_request`.
pass_through_endpoint_logging = PassThroughEndpointLogging()
def get_response_body(response: httpx.Response) -> Optional[dict]:
    """
    Return the JSON-decoded body of ``response``.

    Decoding is best-effort: any exception raised by ``response.json()``
    is swallowed and ``None`` is returned instead.
    """
    try:
        decoded_body = response.json()
    except Exception:
        # Deliberate best-effort: the caller only logs this value.
        return None
    else:
        return decoded_body
async def set_env_variables_in_header(custom_headers: Optional[dict]) -> Optional[dict]:
    """
    Resolve ``os.environ/<NAME>`` references inside configured headers.

    Only intended for headers defined on config.yaml, e.g.
    ``{"Authorization": "bearer os.environ/COHERE_API_KEY"}``.

    Special case: when ``LANGFUSE_PUBLIC_KEY`` / ``LANGFUSE_SECRET_KEY`` are
    present, they are not copied through; instead a single
    ``Authorization: Basic <base64(public:secret)>`` header is produced.
    Both keys are looked up as soon as either one is encountered, so a config
    that defines only one of them raises ``KeyError``.

    Returns ``None`` when ``custom_headers`` is ``None``, otherwise a new dict.
    """
    if custom_headers is None:
        return None

    secret_marker = "os.environ/"
    langfuse_header_names = ("LANGFUSE_PUBLIC_KEY", "LANGFUSE_SECRET_KEY")

    def _resolve_langfuse_credential(raw_value):
        # Only string values that *start with* the marker are looked up.
        if isinstance(raw_value, str) and raw_value.startswith(secret_marker):
            return get_secret_str(raw_value)
        return raw_value

    resolved_headers = {}
    for header_name, header_value in custom_headers.items():
        if header_name in langfuse_header_names:
            # langfuse requires b64 encoded basic-auth credentials; read both
            # raw values first (public, then secret) before resolving either.
            raw_public_key = custom_headers["LANGFUSE_PUBLIC_KEY"]
            raw_secret_key = custom_headers["LANGFUSE_SECRET_KEY"]
            public_key = _resolve_langfuse_credential(raw_public_key)
            secret_key = _resolve_langfuse_credential(raw_secret_key)
            credentials = f"{public_key}:{secret_key}".encode("utf-8")
            resolved_headers["Authorization"] = "Basic " + b64encode(
                credentials
            ).decode("ascii")
            continue

        # Every other header is copied as-is, then patched if it references
        # a secret.
        resolved_headers[header_name] = header_value
        if not (isinstance(header_value, str) and secret_marker in header_value):
            continue

        verbose_proxy_logger.debug(
            "pass through endpoint - looking up 'os.environ/' variable"
        )
        # The reference runs from the first marker to the end of the value.
        reference = header_value[header_value.find(secret_marker) :]
        verbose_proxy_logger.debug(
            "pass through endpoint - getting secret for variable name: %s",
            reference,
        )
        secret = get_secret_str(reference)
        if secret is not None:
            resolved_headers[header_name] = header_value.replace(reference, secret)

    return resolved_headers
async def chat_completion_pass_through_endpoint(  # noqa: PLR0915
    fastapi_response: Response,
    request: Request,
    adapter_id: str,
    user_api_key_dict: UserAPIKeyAuth,
):
    """
    Serve a pass-through route whose target is an adapter instead of a URL.

    The request body is parsed, enriched with proxy settings, run through the
    pre-call hooks and dispatched to ``aadapter_completion`` on either
    ``litellm`` or the proxy router, depending on where the model is known.

    Args:
        fastapi_response: Response object whose headers receive the
            `x-litellm-*` custom headers.
        request: Incoming request; its raw body is the completion payload.
        adapter_id: Id of the registered adapter, stored as
            ``data["adapter_id"]``.
        user_api_key_dict: Auth context of the caller.

    Returns:
        Whatever ``aadapter_completion`` returns.

    Raises:
        ProxyException: Every failure is reported to the post-call failure
            hook and re-raised as a ``ProxyException`` (code taken from the
            original exception's ``status_code``, defaulting to 500).
    """
    from litellm.proxy.proxy_server import (
        add_litellm_data_to_request,
        general_settings,
        llm_router,
        proxy_config,
        proxy_logging_obj,
        user_api_base,
        user_max_tokens,
        user_model,
        user_request_timeout,
        user_temperature,
        version,
    )

    data = {}
    try:
        body = await request.body()
        body_str = body.decode()
        # The body is first tried as a Python literal; JSON is the fallback.
        try:
            data = ast.literal_eval(body_str)
        except Exception:
            data = json.loads(body_str)
        data["adapter_id"] = adapter_id
        verbose_proxy_logger.debug(
            "Request received by LiteLLM:\n{}".format(json.dumps(data, indent=4)),
        )
        data["model"] = (
            general_settings.get("completion_model", None)  # server default
            or user_model  # model name passed via cli args
            or data.get("model", None)  # default passed in http request
        )
        if user_model:
            data["model"] = user_model
        data = await add_litellm_data_to_request(
            data=data,  # type: ignore
            request=request,
            general_settings=general_settings,
            user_api_key_dict=user_api_key_dict,
            version=version,
            proxy_config=proxy_config,
        )
        # override with user settings, these are params passed via cli
        if user_temperature:
            data["temperature"] = user_temperature
        if user_request_timeout:
            data["request_timeout"] = user_request_timeout
        if user_max_tokens:
            data["max_tokens"] = user_max_tokens
        if user_api_base:
            data["api_base"] = user_api_base

        ### MODEL ALIAS MAPPING ###
        # check if model name in model alias map
        # get the actual model name
        if data["model"] in litellm.model_alias_map:
            data["model"] = litellm.model_alias_map[data["model"]]

        ### CALL HOOKS ### - modify incoming data before calling the model
        data = await proxy_logging_obj.pre_call_hook(  # type: ignore
            user_api_key_dict=user_api_key_dict, data=data, call_type="text_completion"
        )

        ### ROUTE THE REQUESTs ###
        router_model_names = llm_router.model_names if llm_router is not None else []
        # skip router if user passed their key
        if "api_key" in data:
            llm_response = asyncio.create_task(litellm.aadapter_completion(**data))
        elif (
            llm_router is not None and data["model"] in router_model_names
        ):  # model in router model list
            llm_response = asyncio.create_task(llm_router.aadapter_completion(**data))
        elif (
            llm_router is not None
            and llm_router.model_group_alias is not None
            and data["model"] in llm_router.model_group_alias
        ):  # model set in model_group_alias
            llm_response = asyncio.create_task(llm_router.aadapter_completion(**data))
        elif (
            llm_router is not None and data["model"] in llm_router.deployment_names
        ):  # model in router deployments, calling a specific deployment on the router
            llm_response = asyncio.create_task(
                llm_router.aadapter_completion(**data, specific_deployment=True)
            )
        elif (
            llm_router is not None and data["model"] in llm_router.get_model_ids()
        ):  # model id known to the router
            llm_response = asyncio.create_task(llm_router.aadapter_completion(**data))
        elif (
            llm_router is not None
            and data["model"] not in router_model_names
            and llm_router.default_deployment is not None
        ):  # unknown model, but the router has a default deployment
            llm_response = asyncio.create_task(llm_router.aadapter_completion(**data))
        elif user_model is not None:  # `litellm --model <your-model-name>`
            llm_response = asyncio.create_task(litellm.aadapter_completion(**data))
        else:
            # `data["model"]` is always present at this point but may be None
            # (no server default, no CLI model, none in the request). Coerce
            # it so building the message cannot raise TypeError and turn this
            # 400 into a 500.
            invalid_model_name = str(data.get("model") or "")
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail={
                    "error": "completion: Invalid model name passed in model="
                    + invalid_model_name
                },
            )

        # Await the llm_response task
        response = await llm_response

        hidden_params = getattr(response, "_hidden_params", {}) or {}
        model_id = hidden_params.get("model_id", None) or ""
        cache_key = hidden_params.get("cache_key", None) or ""
        api_base = hidden_params.get("api_base", None) or ""
        response_cost = hidden_params.get("response_cost", None) or ""

        ### ALERTING ###
        # Scheduled in the background; the response is not delayed by it.
        asyncio.create_task(
            proxy_logging_obj.update_request_status(
                litellm_call_id=data.get("litellm_call_id", ""), status="success"
            )
        )
        verbose_proxy_logger.debug("final response: %s", response)
        fastapi_response.headers.update(
            ProxyBaseLLMRequestProcessing.get_custom_headers(
                user_api_key_dict=user_api_key_dict,
                model_id=model_id,
                cache_key=cache_key,
                api_base=api_base,
                version=version,
                response_cost=response_cost,
            )
        )
        verbose_proxy_logger.info("\nResponse from Litellm:\n{}".format(response))
        return response
    except Exception as e:
        await proxy_logging_obj.post_call_failure_hook(
            user_api_key_dict=user_api_key_dict, original_exception=e, request_data=data
        )
        verbose_proxy_logger.exception(
            "litellm.proxy.proxy_server.completion(): Exception occured - {}".format(
                str(e)
            )
        )
        error_msg = f"{str(e)}"
        raise ProxyException(
            message=getattr(e, "message", error_msg),
            type=getattr(e, "type", "None"),
            param=getattr(e, "param", "None"),
            code=getattr(e, "status_code", 500),
        )
class HttpPassThroughEndpointHelpers:
    """
    Namespace of stateless helpers used by the HTTP pass-through endpoints.

    None of the helpers use instance state, so they are all declared as
    static methods: they behave the same whether called on the class or on
    an instance.
    """

    @staticmethod
    def forward_headers_from_request(
        request: Request,
        headers: dict,
        forward_headers: Optional[bool] = False,
    ):
        """
        Helper to forward headers from original request

        When ``forward_headers`` is True, the incoming request headers (minus
        ``content-length`` and ``host``) are merged underneath ``headers``,
        i.e. ``headers`` wins on conflicts. Otherwise ``headers`` is returned
        untouched.
        """
        if forward_headers is True:
            request_headers = dict(request.headers)
            # Header We Should NOT forward
            request_headers.pop("content-length", None)
            request_headers.pop("host", None)
            # Combine request headers with custom headers
            headers = {**request_headers, **headers}
        return headers

    @staticmethod
    def get_response_headers(
        headers: httpx.Headers,
        litellm_call_id: Optional[str] = None,
        custom_headers: Optional[dict] = None,
    ) -> dict:
        """
        Build the headers returned to the client from the upstream headers.

        ``transfer-encoding`` and ``content-encoding`` are dropped
        (case-insensitively). ``x-litellm-call-id`` is added when
        ``litellm_call_id`` is given, and ``custom_headers`` are applied last
        so they override everything else.
        """
        excluded_headers = {"transfer-encoding", "content-encoding"}
        return_headers = {
            key: value
            for key, value in headers.items()
            if key.lower() not in excluded_headers
        }
        if litellm_call_id:
            return_headers["x-litellm-call-id"] = litellm_call_id
        if custom_headers:
            return_headers.update(custom_headers)
        return return_headers

    @staticmethod
    def get_endpoint_type(url: str) -> EndpointType:
        """
        Classify the target URL.

        URLs containing ``generateContent`` / ``streamGenerateContent`` map to
        ``EndpointType.VERTEX_AI``, the host ``api.anthropic.com`` maps to
        ``EndpointType.ANTHROPIC``, everything else is ``EndpointType.GENERIC``.
        """
        parsed_url = urlparse(url)
        if "generateContent" in url or "streamGenerateContent" in url:
            return EndpointType.VERTEX_AI
        elif parsed_url.hostname == "api.anthropic.com":
            return EndpointType.ANTHROPIC
        return EndpointType.GENERIC

    @staticmethod
    def get_merged_query_parameters(
        existing_url: httpx.URL, request_query_params: Dict[str, Union[str, list]]
    ) -> Dict[str, Union[str, List[str]]]:
        """
        Merge the query params already on ``existing_url`` with those of the
        incoming request. Params on the target URL take priority.
        """
        # Get the existing query params from the target URL
        existing_query_string = existing_url.query.decode("utf-8")
        existing_query_params = parse_qs(existing_query_string)
        # parse_qs returns a dict where each value is a list, so let's flatten it
        updated_existing_query_params = {
            k: v[0] if len(v) == 1 else v for k, v in existing_query_params.items()
        }
        # Merge the query params, giving priority to the existing ones
        return {**request_query_params, **updated_existing_query_params}

    @staticmethod
    async def _make_non_streaming_http_request(
        request: Request,
        async_client: httpx.AsyncClient,
        url: str,
        headers: dict,
        requested_query_params: Optional[dict] = None,
        custom_body: Optional[dict] = None,
    ) -> httpx.Response:
        """
        Make a non-streaming HTTP request
        If request is GET, don't include a JSON body
        """
        if request.method == "GET":
            response = await async_client.request(
                method=request.method,
                url=url,
                headers=headers,
                params=requested_query_params,
            )
        else:
            response = await async_client.request(
                method=request.method,
                url=url,
                headers=headers,
                params=requested_query_params,
                json=custom_body,
            )
        return response

    @staticmethod
    async def non_streaming_http_request_handler(
        request: Request,
        async_client: httpx.AsyncClient,
        url: httpx.URL,
        headers: dict,
        requested_query_params: Optional[dict] = None,
        _parsed_body: Optional[dict] = None,
    ) -> httpx.Response:
        """
        Handle non-streaming HTTP requests
        Handles special cases when GET requests, multipart/form-data requests, and generic httpx requests
        """
        if request.method == "GET":
            # GET: never send a body.
            response = await async_client.request(
                method=request.method,
                url=url,
                headers=headers,
                params=requested_query_params,
            )
        elif HttpPassThroughEndpointHelpers.is_multipart(request) is True:
            # multipart/form-data: re-read the form and forward files/fields.
            return await HttpPassThroughEndpointHelpers.make_multipart_http_request(
                request=request,
                async_client=async_client,
                url=url,
                headers=headers,
                requested_query_params=requested_query_params,
            )
        else:
            # Generic httpx method
            response = await async_client.request(
                method=request.method,
                url=url,
                headers=headers,
                params=requested_query_params,
                json=_parsed_body,
            )
        return response

    @staticmethod
    def is_multipart(request: Request) -> bool:
        """Check if the request is a multipart/form-data request"""
        return "multipart/form-data" in request.headers.get("content-type", "")

    @staticmethod
    async def _build_request_files_from_upload_file(
        upload_file: Union[UploadFile, StarletteUploadFile],
    ) -> Tuple[Optional[str], bytes, Optional[str]]:
        """
        Read an UploadFile and return its ``(filename, content, content_type)``
        tuple, the shape passed as a ``files`` entry to the httpx request.
        """
        file_content = await upload_file.read()
        return (upload_file.filename, file_content, upload_file.content_type)

    @staticmethod
    async def make_multipart_http_request(
        request: Request,
        async_client: httpx.AsyncClient,
        url: httpx.URL,
        headers: dict,
        requested_query_params: Optional[dict] = None,
    ) -> httpx.Response:
        """Process multipart/form-data requests, handling both files and form fields"""
        form_data = await request.form()
        files = {}
        form_data_dict = {}
        for field_name, field_value in form_data.items():
            if isinstance(field_value, (StarletteUploadFile, UploadFile)):
                # Uploaded file -> forwarded through `files=`.
                files[field_name] = (
                    await HttpPassThroughEndpointHelpers._build_request_files_from_upload_file(
                        upload_file=field_value
                    )
                )
            else:
                # Plain form field -> forwarded through `data=`.
                form_data_dict[field_name] = field_value
        response = await async_client.request(
            method=request.method,
            url=url,
            headers=headers,
            params=requested_query_params,
            files=files,
            data=form_data_dict,
        )
        return response
async def pass_through_request(  # noqa: PLR0915
    request: Request,
    target: str,
    custom_headers: dict,
    user_api_key_dict: UserAPIKeyAuth,
    custom_body: Optional[dict] = None,
    forward_headers: Optional[bool] = False,
    merge_query_params: Optional[bool] = False,
    query_params: Optional[dict] = None,
    stream: Optional[bool] = None,
):
    """
    Forward ``request`` to ``target`` and relay the upstream response.

    Args:
        request: Incoming request (method, headers, query string, body).
        target: Upstream URL.
        custom_headers: Headers always sent upstream; they override forwarded
            request headers.
        user_api_key_dict: Auth context of the caller, used for hooks, logging
            metadata and the `x-litellm-*` response headers.
        custom_body: Body to send instead of the parsed request body.
        forward_headers: Also forward the incoming request headers.
        merge_query_params: Merge the incoming query string into the target
            URL (params already on the target URL win).
        query_params: Explicit query params to send upstream.
        stream: When truthy, the upstream call is made as a streamed POST.

    Returns:
        A ``StreamingResponse`` for streamed / ``text/event-stream`` upstream
        responses, otherwise a ``Response`` with the upstream body.

    Raises:
        ProxyException: Every failure (including non-2xx upstream responses)
            is re-raised as a ``ProxyException`` carrying the `x-litellm-*`
            headers.
    """
    litellm_call_id = str(uuid.uuid4())
    url: Optional[httpx.URL] = None
    try:
        from litellm.litellm_core_utils.litellm_logging import Logging
        from litellm.proxy.proxy_server import proxy_logging_obj

        url = httpx.URL(target)
        headers = custom_headers
        headers = HttpPassThroughEndpointHelpers.forward_headers_from_request(
            request=request, headers=headers, forward_headers=forward_headers
        )

        if merge_query_params:
            # Create a new URL with the merged query params
            url = url.copy_with(
                query=urlencode(
                    HttpPassThroughEndpointHelpers.get_merged_query_parameters(
                        existing_url=url,
                        request_query_params=dict(request.query_params),
                    )
                ).encode("ascii")
            )

        endpoint_type: EndpointType = HttpPassThroughEndpointHelpers.get_endpoint_type(
            str(url)
        )

        _parsed_body = None
        if custom_body:
            _parsed_body = custom_body
        else:
            _parsed_body = await _read_request_body(request)
        verbose_proxy_logger.debug(
            "Pass through endpoint sending request to \nURL {}\nheaders: {}\nbody: {}\n".format(
                url, headers, _parsed_body
            )
        )

        ### CALL HOOKS ### - modify incoming data / reject request before calling the model
        _parsed_body = await proxy_logging_obj.pre_call_hook(
            user_api_key_dict=user_api_key_dict,
            data=_parsed_body,
            call_type="pass_through_endpoint",
        )

        async_client_obj = get_async_httpx_client(
            llm_provider=httpxSpecialProvider.PassThroughEndpoint,
            params={"timeout": 600},
        )
        async_client = async_client_obj.client

        # create logging object
        start_time = datetime.now()
        logging_obj = Logging(
            model="unknown",
            messages=[{"role": "user", "content": safe_dumps(_parsed_body)}],
            stream=False,
            call_type="pass_through_endpoint",
            start_time=start_time,
            litellm_call_id=litellm_call_id,
            function_id="1245",
        )
        passthrough_logging_payload = PassthroughStandardLoggingPayload(
            url=str(url),
            request_body=_parsed_body,
            request_method=getattr(request, "method", None),
        )
        kwargs = _init_kwargs_for_pass_through_endpoint(
            user_api_key_dict=user_api_key_dict,
            _parsed_body=_parsed_body,
            passthrough_logging_payload=passthrough_logging_payload,
            litellm_call_id=litellm_call_id,
            request=request,
            logging_obj=logging_obj,
        )

        # done for supporting 'parallel_request_limiter.py' with pass-through endpoints
        logging_obj.update_environment_variables(
            model="unknown",
            user="unknown",
            optional_params={},
            litellm_params=kwargs["litellm_params"],
            call_type="pass_through_endpoint",
        )
        logging_obj.model_call_details["litellm_call_id"] = litellm_call_id

        # combine url with query params for logging
        # Explicit `query_params` are used as-is. Without them the fallback
        # value is immediately reset to None, so nothing extra is sent (the
        # request's own query string only travels via the merged URL above).
        requested_query_params: Optional[dict] = (
            query_params or request.query_params.__dict__
        )
        if requested_query_params == request.query_params.__dict__:
            requested_query_params = None

        requested_query_params_str = None
        if requested_query_params:
            requested_query_params_str = "&".join(
                f"{k}={v}" for k, v in requested_query_params.items()
            )

        logging_url = str(url)
        if requested_query_params_str:
            if "?" in str(url):
                logging_url = str(url) + "&" + requested_query_params_str
            else:
                logging_url = str(url) + "?" + requested_query_params_str

        logging_obj.pre_call(
            input=[{"role": "user", "content": safe_dumps(_parsed_body)}],
            api_key="",
            additional_args={
                "complete_input_dict": _parsed_body,
                "api_base": str(logging_url),
                "headers": headers,
            },
        )

        if stream:
            # Caller asked for streaming: always issued as a POST.
            req = async_client.build_request(
                "POST",
                url,
                json=_parsed_body,
                params=requested_query_params,
                headers=headers,
            )
            response = await async_client.send(req, stream=stream)
            try:
                response.raise_for_status()
            except httpx.HTTPStatusError as e:
                # Read the streamed error body, then report it as decoded
                # text: passing the raw bytes would surface to the client as
                # a "b'...'" repr once the handler below stringifies it.
                await e.response.aread()
                raise HTTPException(
                    status_code=e.response.status_code, detail=e.response.text
                )
            return StreamingResponse(
                PassThroughStreamingHandler.chunk_processor(
                    response=response,
                    request_body=_parsed_body,
                    litellm_logging_obj=logging_obj,
                    endpoint_type=endpoint_type,
                    start_time=start_time,
                    passthrough_success_handler_obj=pass_through_endpoint_logging,
                    url_route=str(url),
                ),
                headers=HttpPassThroughEndpointHelpers.get_response_headers(
                    headers=response.headers,
                    litellm_call_id=litellm_call_id,
                ),
                status_code=response.status_code,
            )

        verbose_proxy_logger.debug("request method: {}".format(request.method))
        verbose_proxy_logger.debug("request url: {}".format(url))
        verbose_proxy_logger.debug("request headers: {}".format(headers))
        verbose_proxy_logger.debug(
            "requested_query_params={}".format(requested_query_params)
        )
        verbose_proxy_logger.debug("request body: {}".format(_parsed_body))

        response = (
            await HttpPassThroughEndpointHelpers.non_streaming_http_request_handler(
                request=request,
                async_client=async_client,
                url=url,
                headers=headers,
                requested_query_params=requested_query_params,
                _parsed_body=_parsed_body,
            )
        )
        verbose_proxy_logger.debug("response.headers= %s", response.headers)

        if _is_streaming_response(response) is True:
            # Upstream answered with an event stream even though streaming
            # was not requested explicitly.
            try:
                response.raise_for_status()
            except httpx.HTTPStatusError as e:
                # Same as above: report decoded text, not raw bytes.
                await e.response.aread()
                raise HTTPException(
                    status_code=e.response.status_code, detail=e.response.text
                )
            return StreamingResponse(
                PassThroughStreamingHandler.chunk_processor(
                    response=response,
                    request_body=_parsed_body,
                    litellm_logging_obj=logging_obj,
                    endpoint_type=endpoint_type,
                    start_time=start_time,
                    passthrough_success_handler_obj=pass_through_endpoint_logging,
                    url_route=str(url),
                ),
                headers=HttpPassThroughEndpointHelpers.get_response_headers(
                    headers=response.headers,
                    litellm_call_id=litellm_call_id,
                ),
                status_code=response.status_code,
            )

        try:
            response.raise_for_status()
        except httpx.HTTPStatusError as e:
            raise HTTPException(
                status_code=e.response.status_code, detail=e.response.text
            )

        if response.status_code >= 300:
            raise HTTPException(status_code=response.status_code, detail=response.text)

        content = await response.aread()

        ## LOG SUCCESS
        response_body: Optional[dict] = get_response_body(response)
        passthrough_logging_payload["response_body"] = response_body
        end_time = datetime.now()
        # Success logging runs in the background; the response is not delayed.
        asyncio.create_task(
            pass_through_endpoint_logging.pass_through_async_success_handler(
                httpx_response=response,
                response_body=response_body,
                url_route=str(url),
                result="",
                start_time=start_time,
                end_time=end_time,
                logging_obj=logging_obj,
                cache_hit=False,
                request_body=_parsed_body,
                **kwargs,
            )
        )

        ## CUSTOM HEADERS - `x-litellm-*`
        custom_headers = ProxyBaseLLMRequestProcessing.get_custom_headers(
            user_api_key_dict=user_api_key_dict,
            call_id=litellm_call_id,
            model_id=None,
            cache_key=None,
            api_base=str(url._uri_reference),
        )

        return Response(
            content=content,
            status_code=response.status_code,
            headers=HttpPassThroughEndpointHelpers.get_response_headers(
                headers=response.headers,
                custom_headers=custom_headers,
            ),
        )
    except Exception as e:
        # `url` is still None when parsing `target` itself failed.
        custom_headers = ProxyBaseLLMRequestProcessing.get_custom_headers(
            user_api_key_dict=user_api_key_dict,
            call_id=litellm_call_id,
            model_id=None,
            cache_key=None,
            api_base=str(url._uri_reference) if url else None,
        )
        verbose_proxy_logger.exception(
            "litellm.proxy.proxy_server.pass_through_endpoint(): Exception occured - {}".format(
                str(e)
            )
        )
        if isinstance(e, HTTPException):
            raise ProxyException(
                message=getattr(e, "message", str(e.detail)),
                type=getattr(e, "type", "None"),
                param=getattr(e, "param", "None"),
                code=getattr(e, "status_code", status.HTTP_400_BAD_REQUEST),
                headers=custom_headers,
            )
        else:
            error_msg = f"{str(e)}"
            raise ProxyException(
                message=getattr(e, "message", error_msg),
                type=getattr(e, "type", "None"),
                param=getattr(e, "param", "None"),
                code=getattr(e, "status_code", 500),
                headers=custom_headers,
            )
def _init_kwargs_for_pass_through_endpoint(
    request: Request,
    user_api_key_dict: UserAPIKeyAuth,
    passthrough_logging_payload: PassthroughStandardLoggingPayload,
    logging_obj: LiteLLMLoggingObj,
    _parsed_body: Optional[dict] = None,
    litellm_call_id: Optional[str] = None,
) -> dict:
    """
    Build the logging kwargs for a pass-through call.

    The metadata is assembled from the caller's key information, then updated
    with any ``litellm_metadata`` found in the body (that key is popped, i.e.
    removed from ``_parsed_body``), then with tags from the request headers.
    As a side effect the logging payload is also stored on
    ``logging_obj.model_call_details``.
    """
    body = _parsed_body or {}
    # Removed from the body so it is not part of what gets sent upstream.
    metadata_from_body: Optional[dict] = body.pop("litellm_metadata", None)

    key_metadata = StandardLoggingUserAPIKeyMetadata(
        user_api_key_hash=user_api_key_dict.api_key,
        user_api_key_alias=user_api_key_dict.key_alias,
        user_api_key_user_email=user_api_key_dict.user_email,
        user_api_key_user_id=user_api_key_dict.user_id,
        user_api_key_team_id=user_api_key_dict.team_id,
        user_api_key_org_id=user_api_key_dict.org_id,
        user_api_key_team_alias=user_api_key_dict.team_alias,
        user_api_key_end_user_id=user_api_key_dict.end_user_id,
    )
    request_metadata = dict(key_metadata)
    request_metadata["user_api_key"] = user_api_key_dict.api_key
    if metadata_from_body:
        request_metadata.update(metadata_from_body)
    request_metadata = _update_metadata_with_tags_in_header(
        request=request,
        metadata=request_metadata,
    )

    logging_obj.model_call_details["passthrough_logging_payload"] = (
        passthrough_logging_payload
    )
    return {
        "litellm_params": {
            "metadata": request_metadata,
        },
        "call_type": "pass_through_endpoint",
        "litellm_call_id": litellm_call_id,
        "passthrough_logging_payload": passthrough_logging_payload,
    }
def _update_metadata_with_tags_in_header(request: Request, metadata: dict) -> dict:
    """
    Copy a ``tags`` request header into ``metadata["tags"]``.

    The header value is split on commas (no whitespace trimming). When the
    header is missing or empty, ``metadata`` is returned unchanged.
    Used for google and vertex JS SDKs.
    """
    header_tags = request.headers.get("tags")
    if not header_tags:
        return metadata
    metadata["tags"] = header_tags.split(",")
    return metadata
def create_pass_through_route(
    endpoint,
    target: str,
    custom_headers: Optional[dict] = None,
    _forward_headers: Optional[bool] = False,
    _merge_query_params: Optional[bool] = False,
    dependencies: Optional[List] = None,
):
    """
    Build the route handler for one pass-through endpoint.

    ``target`` is first treated as an adapter (either a ``CustomLogger``
    instance or something ``get_instance_fn`` can load). If that fails for
    any reason, ``target`` is treated as a URL and the handler forwards
    requests via ``pass_through_request``.

    Note: registering an adapter replaces ``litellm.adapters`` with a
    single-entry list.
    """
    # check if target is an adapter.py or a url
    import uuid

    from litellm.proxy.types_utils.utils import get_instance_fn

    try:
        adapter = (
            target if isinstance(target, CustomLogger) else get_instance_fn(value=target)
        )
        adapter_id = str(uuid.uuid4())
        litellm.adapters = [{"id": adapter_id, "adapter": adapter}]

        async def endpoint_func(  # type: ignore
            request: Request,
            fastapi_response: Response,
            user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
        ):
            # Adapter-backed route.
            return await chat_completion_pass_through_endpoint(
                fastapi_response=fastapi_response,
                request=request,
                adapter_id=adapter_id,
                user_api_key_dict=user_api_key_dict,
            )

    except Exception:
        verbose_proxy_logger.debug("Defaulting to target being a url.")

        async def endpoint_func(  # type: ignore
            request: Request,
            fastapi_response: Response,
            user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
            query_params: Optional[dict] = None,
            custom_body: Optional[dict] = None,
            stream: Optional[
                bool
            ] = None,  # if pass-through endpoint is a streaming request
        ):
            # URL-backed route.
            return await pass_through_request(  # type: ignore
                request=request,
                target=target,
                custom_headers=custom_headers or {},
                user_api_key_dict=user_api_key_dict,
                forward_headers=_forward_headers,
                merge_query_params=_merge_query_params,
                query_params=query_params,
                stream=stream,
                custom_body=custom_body,
            )

    return endpoint_func
def _is_streaming_response(response: httpx.Response) -> bool:
    """Return True when the response's content-type contains ``text/event-stream``."""
    content_type = response.headers.get("content-type")
    return content_type is not None and "text/event-stream" in content_type
async def initialize_pass_through_endpoints(pass_through_endpoints: list):
    """
    Register every configured pass-through endpoint on the proxy app.

    For each entry: headers are resolved against ``os.environ/`` references,
    authentication is enabled when ``auth`` is "true" (premium users only,
    otherwise ``ValueError``), and a route accepting GET/POST/PUT/DELETE/PATCH
    is added. Entries without a ``target`` are skipped.
    """
    verbose_proxy_logger.debug("initializing pass through endpoints")
    from litellm.proxy._types import CommonProxyErrors, LiteLLMRoutes
    from litellm.proxy.proxy_server import app, premium_user

    for endpoint in pass_through_endpoints:
        target_value = endpoint.get("target")
        route_path = endpoint.get("path")
        resolved_headers = await set_env_variables_in_header(
            custom_headers=endpoint.get("headers")
        )
        forward_headers_flag = endpoint.get("forward_headers")
        merge_query_params_flag = endpoint.get("merge_query_params")
        auth_setting = endpoint.get("auth")

        route_dependencies = None
        auth_enabled = auth_setting is not None and str(auth_setting).lower() == "true"
        if auth_enabled:
            if premium_user is not True:
                raise ValueError(
                    "Error Setting Authentication on Pass Through Endpoint: {}".format(
                        CommonProxyErrors.not_premium_user.value
                    )
                )
            route_dependencies = [Depends(user_api_key_auth)]
            # Authenticated paths are added to the openai routes list.
            LiteLLMRoutes.openai_routes.value.append(route_path)

        if target_value is None:
            continue

        verbose_proxy_logger.debug(
            "adding pass through endpoint: %s, dependencies: %s",
            route_path,
            route_dependencies,
        )
        route_handler = create_pass_through_route(  # type: ignore
            endpoint=route_path,
            target=target_value,
            custom_headers=resolved_headers,
            _forward_headers=forward_headers_flag,
            _merge_query_params=merge_query_params_flag,
            dependencies=route_dependencies,
        )
        app.add_api_route(  # type: ignore
            path=route_path,
            endpoint=route_handler,
            methods=["GET", "POST", "PUT", "DELETE", "PATCH"],
            dependencies=route_dependencies,
        )
        verbose_proxy_logger.debug("Added new pass through endpoint: %s", route_path)
async def get_pass_through_endpoints(
    endpoint_id: Optional[str] = None,
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
    """
    GET configured pass through endpoint.

    If no endpoint_id given, return all configured endpoints; otherwise only
    those whose ``path`` equals ``endpoint_id``. An empty list is returned
    when the setting cannot be read or is unset.
    """
    from litellm.proxy.proxy_server import get_config_general_settings

    ## Get existing pass-through endpoint field value
    try:
        response: ConfigFieldInfo = await get_config_general_settings(
            field_name="pass_through_endpoints", user_api_key_dict=user_api_key_dict
        )
    except Exception:
        return PassThroughEndpointResponse(endpoints=[])

    pass_through_endpoint_data: Optional[List] = response.field_value
    if pass_through_endpoint_data is None:
        return PassThroughEndpointResponse(endpoints=[])

    def _as_endpoint(raw_entry) -> Optional[PassThroughGenericEndpoint]:
        # Stored entries may be plain dicts or already-built models;
        # anything else is ignored.
        if isinstance(raw_entry, dict):
            return PassThroughGenericEndpoint(**raw_entry)
        if isinstance(raw_entry, PassThroughGenericEndpoint):
            return raw_entry
        return None

    candidates = (_as_endpoint(raw_entry) for raw_entry in pass_through_endpoint_data)
    returned_endpoints = [
        candidate
        for candidate in candidates
        if candidate is not None
        and (endpoint_id is None or candidate.path == endpoint_id)
    ]
    return PassThroughEndpointResponse(endpoints=returned_endpoints)
async def update_pass_through_endpoints(request: Request, endpoint_id: str):
    """
    Update a pass-through endpoint

    Placeholder: no update logic exists yet, the call does nothing and
    returns ``None``.
    """
    return None
async def create_pass_through_endpoints(
    data: PassThroughGenericEndpoint,
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
    """
    Create new pass-through endpoint

    Appends ``data`` to the stored ``pass_through_endpoints`` general setting
    (creating the list when the setting is unset or cannot be read) and
    writes the setting back.
    """
    from litellm.proxy.proxy_server import (
        get_config_general_settings,
        update_config_general_settings,
    )

    setting_name = "pass_through_endpoints"

    ## Get existing pass-through endpoint field value
    try:
        response: ConfigFieldInfo = await get_config_general_settings(
            field_name=setting_name, user_api_key_dict=user_api_key_dict
        )
    except Exception:
        # Setting could not be read: start from an empty value.
        response = ConfigFieldInfo(field_name=setting_name, field_value=None)

    ## Update field with new endpoint
    new_entry = data.model_dump()
    if isinstance(response.field_value, List):
        response.field_value.append(new_entry)
    elif response.field_value is None:
        response.field_value = [new_entry]

    ## Update db
    await update_config_general_settings(
        data=ConfigFieldUpdate(
            field_name=setting_name,
            field_value=response.field_value,
            config_type="general_settings",
        ),
        user_api_key_dict=user_api_key_dict,
    )
async def delete_pass_through_endpoints(
    endpoint_id: str,
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
    """
    Delete a pass-through endpoint

    Returns - the deleted endpoint

    Raises ``HTTPException`` (400) when no endpoints are configured or when
    no configured endpoint has ``path == endpoint_id``. When several entries
    match, the last one is removed. The setting is written back whenever the
    stored value is a list, even if nothing matched.
    """
    from litellm.proxy.proxy_server import (
        get_config_general_settings,
        update_config_general_settings,
    )

    setting_name = "pass_through_endpoints"

    ## Get existing pass-through endpoint field value
    try:
        response: ConfigFieldInfo = await get_config_general_settings(
            field_name=setting_name, user_api_key_dict=user_api_key_dict
        )
    except Exception:
        response = ConfigFieldInfo(field_name=setting_name, field_value=None)

    ## Update field by removing endpoint
    pass_through_endpoint_data: Optional[List] = response.field_value
    if response.field_value is None or pass_through_endpoint_data is None:
        raise HTTPException(
            status_code=400,
            detail={"error": "There are no pass-through endpoints setup."},
        )

    response_obj: Optional[PassThroughGenericEndpoint] = None
    if isinstance(response.field_value, List):
        matched_idx: Optional[int] = None
        for position, raw_entry in enumerate(pass_through_endpoint_data):
            candidate: Optional[PassThroughGenericEndpoint] = None
            if isinstance(raw_entry, dict):
                candidate = PassThroughGenericEndpoint(**raw_entry)
            elif isinstance(raw_entry, PassThroughGenericEndpoint):
                candidate = raw_entry
            # No early exit: a later match overrides an earlier one.
            if candidate is not None and candidate.path == endpoint_id:
                matched_idx, response_obj = position, candidate

        if matched_idx is not None:
            pass_through_endpoint_data.pop(matched_idx)

        ## Update db
        await update_config_general_settings(
            data=ConfigFieldUpdate(
                field_name=setting_name,
                field_value=pass_through_endpoint_data,
                config_type="general_settings",
            ),
            user_api_key_dict=user_api_key_dict,
        )

    if response_obj is None:
        raise HTTPException(
            status_code=400,
            detail={
                "error": "Endpoint={} was not found in pass-through endpoint list.".format(
                    endpoint_id
                )
            },
        )
    return PassThroughEndpointResponse(endpoints=[response_obj])