from typing import TYPE_CHECKING, Any, Literal, Optional

from fastapi import HTTPException, status

import litellm

if TYPE_CHECKING:
    from litellm.router import Router as _Router

    LitellmRouter = _Router
else:
    LitellmRouter = Any


# Maps a router method name (the `route_type` passed to `route_request`) to the
# endpoint path reported in "model not found" errors.
ROUTE_ENDPOINT_MAPPING = {
    "acompletion": "/chat/completions",
    "atext_completion": "/completions",
    "aembedding": "/embeddings",
    "aimage_generation": "/image/generations",
    "aspeech": "/audio/speech",
    "atranscription": "/audio/transcriptions",
    "amoderation": "/moderations",
    "arerank": "/rerank",
    "aresponses": "/responses",
    "alist_input_items": "/responses/{response_id}/input_items",
    "aimage_edit": "/images/edits",
}


class ProxyModelNotFoundError(HTTPException):
    """HTTP 400 error raised when a request names a model that cannot be routed."""

    def __init__(self, route: str, model_name: str):
        """
        Args:
            route: Endpoint (or route type) the request was sent to.
            model_name: The model name supplied by the caller.
        """
        detail = {
            "error": f"{route}: Invalid model name passed in model={model_name}. Call `/v1/models` to view available models for your key."
        }
        super().__init__(status_code=status.HTTP_400_BAD_REQUEST, detail=detail)


def get_team_id_from_data(data: dict) -> Optional[str]:
    """
    Get the team id from the data's metadata or litellm_metadata params.

    `metadata` is checked first, then `litellm_metadata`. The first of the two
    that is present, not None, and contains the `user_api_key_team_id` key
    decides the result (even if the stored value is None).

    Returns:
        The stored team id, or None when neither metadata dict carries the key.
    """
    for metadata_key in ("metadata", "litellm_metadata"):
        metadata = data.get(metadata_key)
        if metadata is not None and "user_api_key_team_id" in metadata:
            return metadata.get("user_api_key_team_id")
    return None


async def route_request(
    data: dict,
    llm_router: Optional[LitellmRouter],
    user_model: Optional[str],
    route_type: Literal[
        "acompletion",
        "atext_completion",
        "aembedding",
        "aimage_generation",
        "aspeech",
        "atranscription",
        "amoderation",
        "arerank",
        "aresponses",
        "aget_responses",
        "adelete_responses",
        "alist_input_items",
        "_arealtime",  # private function for realtime API
        "aimage_edit",
    ],
):
    """
    Common helper to route the request

    Picks the object (`llm_router`, a per-request `litellm.Router`, or the
    `litellm` module) whose `route_type` method should handle `data`, calls it
    with `**data`, and returns the result of that call without awaiting it.

    Args:
        data: Request parameters forwarded as keyword arguments. May be mutated
            (`user_config` / `model` are popped, `model` may be rewritten to a
            team-specific name).
        llm_router: The proxy's router, or None if no router is configured.
        user_model: Model configured by the user; when set and there is no
            router, the request is sent straight to `litellm`.
        route_type: Name of the method to invoke on the chosen target.

    Returns:
        Whatever the selected `route_type` method returns.

    Raises:
        ProxyModelNotFoundError: If no routing rule matches the request.
    """
    team_id = get_team_id_from_data(data)
    router_model_names = llm_router.model_names if llm_router is not None else []
    # `model` is optional for some routes (e.g. moderation), so never index it
    # directly.
    model = data.get("model")

    if "api_key" in data or "api_base" in data:
        # Caller supplied their own credentials/base URL. Without a router,
        # fall back to the litellm module rather than failing on
        # `getattr(None, ...)`.
        target = llm_router if llm_router is not None else litellm
        return getattr(target, route_type)(**data)

    elif "user_config" in data:
        # Build a throwaway router from the caller-provided config.
        router_config = data.pop("user_config")
        user_router = litellm.Router(**router_config)
        ret_val = getattr(user_router, route_type)(**data)
        # NOTE(review): discard() runs before `ret_val` is awaited by the
        # caller — confirm the router is not needed once the call is created.
        user_router.discard()
        return ret_val

    elif (
        route_type == "acompletion"
        and isinstance(model, str)
        and "," in model
        and llm_router is not None
    ):
        # Comma-separated model list -> batch completion across the models.
        if data.get("fastest_response", False):
            return llm_router.abatch_completion_fastest_response(**data)
        models = [name.strip() for name in data.pop("model").split(",")]
        return llm_router.abatch_completion(models=models, **data)

    elif llm_router is not None:
        team_model_name = (
            llm_router.map_team_model(model, team_id)
            if team_id is not None and model is not None
            else None
        )
        if team_model_name is not None:
            data["model"] = team_model_name
            return getattr(llm_router, route_type)(**data)

        # Model group name, model id, or model-group alias known to the router.
        if (
            model in router_model_names
            or model in llm_router.get_model_ids()
            or (
                llm_router.model_group_alias is not None
                and model in llm_router.model_group_alias
            )
        ):
            return getattr(llm_router, route_type)(**data)

        if model in llm_router.deployment_names:
            return getattr(llm_router, route_type)(**data, specific_deployment=True)

        # From here on the model is not one of the router's model names.
        if llm_router.router_general_settings.pass_through_all_models:
            return getattr(litellm, route_type)(**data)
        if (
            llm_router.default_deployment is not None
            or len(llm_router.pattern_router.patterns) > 0
        ):
            return getattr(llm_router, route_type)(**data)
        if route_type in [
            "amoderation",
            "aget_responses",
            "adelete_responses",
            "alist_input_items",
        ]:
            # moderation endpoint does not require `model` parameter
            return getattr(llm_router, route_type)(**data)

    elif user_model is not None:
        return getattr(litellm, route_type)(**data)

    # if no route found then it's a bad request
    route_name = ROUTE_ENDPOINT_MAPPING.get(route_type, route_type)
    raise ProxyModelNotFoundError(
        route=route_name,
        model_name=data.get("model", ""),
    )