|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import copy, httpx |
|
from datetime import datetime |
|
from typing import Dict, List, Optional, Union, Literal, Any |
|
import random, threading, time, traceback, uuid |
|
import litellm, openai |
|
from litellm.caching import RedisCache, InMemoryCache, DualCache |
|
|
|
import logging, asyncio |
|
import inspect, concurrent.futures
|
from openai import AsyncOpenAI |
|
from collections import defaultdict |
|
from litellm.router_strategy.least_busy import LeastBusyLoggingHandler |
|
from litellm.router_strategy.lowest_tpm_rpm import LowestTPMLoggingHandler |
|
from litellm.router_strategy.lowest_latency import LowestLatencyLoggingHandler |
|
from litellm.llms.custom_httpx.azure_dall_e_2 import ( |
|
CustomHTTPTransport, |
|
AsyncCustomHTTPTransport, |
|
) |
|
from litellm.utils import ModelResponse, CustomStreamWrapper |
|
|
from litellm._logging import verbose_router_logger |
|
|
|
|
|
|
class Router: |
|
""" |
|
Example usage: |
|
```python |
|
from litellm import Router |
|
model_list = [ |
|
{ |
|
"model_name": "azure-gpt-3.5-turbo", # model alias |
|
"litellm_params": { # params for litellm completion/embedding call |
|
"model": "azure/<your-deployment-name-1>", |
|
"api_key": <your-api-key>, |
|
"api_version": <your-api-version>, |
|
"api_base": <your-api-base> |
|
}, |
|
}, |
|
{ |
|
"model_name": "azure-gpt-3.5-turbo", # model alias |
|
"litellm_params": { # params for litellm completion/embedding call |
|
"model": "azure/<your-deployment-name-2>", |
|
"api_key": <your-api-key>, |
|
"api_version": <your-api-version>, |
|
"api_base": <your-api-base> |
|
}, |
|
}, |
|
{ |
|
"model_name": "openai-gpt-3.5-turbo", # model alias |
|
"litellm_params": { # params for litellm completion/embedding call |
|
"model": "gpt-3.5-turbo", |
|
"api_key": <your-api-key>, |
|
}, |
|
    },
    ]
|
|
|
router = Router(model_list=model_list, fallbacks=[{"azure-gpt-3.5-turbo": "openai-gpt-3.5-turbo"}]) |
|
``` |
|
""" |
|
|
|
model_names: List = [] |
|
cache_responses: Optional[bool] = False |
|
default_cache_time_seconds: int = 1 * 60 * 60 |
|
num_retries: int = 0 |
|
tenacity = None |
|
    leastbusy_logger: Optional[LeastBusyLoggingHandler] = None
    lowesttpm_logger: Optional[LowestTPMLoggingHandler] = None
    lowestlatency_logger: Optional[LowestLatencyLoggingHandler] = None
|
|
|
def __init__( |
|
self, |
|
model_list: Optional[list] = None, |
|
|
|
redis_url: Optional[str] = None, |
|
redis_host: Optional[str] = None, |
|
redis_port: Optional[int] = None, |
|
redis_password: Optional[str] = None, |
|
cache_responses: Optional[bool] = False, |
|
cache_kwargs: dict = {}, |
|
        caching_groups: Optional[List[tuple]] = None,
|
client_ttl: int = 3600, |
|
|
|
num_retries: int = 0, |
|
timeout: Optional[float] = None, |
|
        default_litellm_params: Optional[dict] = None,
|
set_verbose: bool = False, |
|
fallbacks: List = [], |
|
allowed_fails: Optional[int] = None, |
|
context_window_fallbacks: List = [], |
|
model_group_alias: Optional[dict] = {}, |
|
retry_after: int = 0, |
|
routing_strategy: Literal[ |
|
"simple-shuffle", |
|
"least-busy", |
|
"usage-based-routing", |
|
"latency-based-routing", |
|
] = "simple-shuffle", |
|
routing_strategy_args: dict = {}, |
|
) -> None: |
|
self.set_verbose = set_verbose |
|
        self.deployment_names: List = []
|
self.deployment_latency_map = {} |
|
|
|
cache_type: Literal["local", "redis"] = "local" |
|
redis_cache = None |
|
cache_config = {} |
|
self.client_ttl = client_ttl |
|
if redis_url is not None or ( |
|
redis_host is not None |
|
and redis_port is not None |
|
and redis_password is not None |
|
): |
|
cache_type = "redis" |
|
|
|
if redis_url is not None: |
|
cache_config["url"] = redis_url |
|
|
|
if redis_host is not None: |
|
cache_config["host"] = redis_host |
|
|
|
if redis_port is not None: |
|
cache_config["port"] = str(redis_port) |
|
|
|
if redis_password is not None: |
|
cache_config["password"] = redis_password |
|
|
|
|
|
cache_config.update(cache_kwargs) |
|
redis_cache = RedisCache(**cache_config) |
|
|
|
if cache_responses: |
|
if litellm.cache is None: |
|
|
|
litellm.cache = litellm.Cache(type=cache_type, **cache_config) |
|
self.cache_responses = cache_responses |
|
self.cache = DualCache( |
|
redis_cache=redis_cache, in_memory_cache=InMemoryCache() |
|
) |
|
|
|
        if model_list:
            model_list = copy.deepcopy(model_list)
            self.set_model_list(model_list)
            self.healthy_deployments: List = self.model_list
            for m in model_list:
                self.deployment_latency_map[m["litellm_params"]["model"]] = 0
        else:
            # keep these attributes defined even when no model_list is passed in
            self.model_list = []
            self.healthy_deployments = []
|
|
|
self.allowed_fails = allowed_fails or litellm.allowed_fails |
|
        self.failed_calls = InMemoryCache()  # tracks failed calls per deployment; drives the cooldown logic
|
self.num_retries = num_retries or litellm.num_retries or 0 |
|
self.timeout = timeout or litellm.request_timeout |
|
self.retry_after = retry_after |
|
self.routing_strategy = routing_strategy |
|
self.fallbacks = fallbacks or litellm.fallbacks |
|
self.context_window_fallbacks = ( |
|
context_window_fallbacks or litellm.context_window_fallbacks |
|
) |
|
        self.model_exception_map: dict = {}  # maps deployment -> list of exception strings
        self.total_calls: defaultdict = defaultdict(int)  # total calls made per model
        self.fail_calls: defaultdict = defaultdict(int)  # failed calls per model
        self.success_calls: defaultdict = defaultdict(int)  # successful calls per model
        self.previous_models: List = []  # retry/fallback breadcrumbs (see log_retry)
        self.model_group_alias: dict = model_group_alias or {}  # maps alias -> model group name
|
|
|
|
|
        # default_litellm_params defaults to None and is replaced with a fresh dict
        # here, so the setdefault() calls below never mutate a shared default dict
        default_litellm_params = default_litellm_params or {}

        self.chat = litellm.Chat(params=default_litellm_params)

        self.default_litellm_params = default_litellm_params
        self.default_litellm_params.setdefault("timeout", timeout)
        self.default_litellm_params.setdefault("max_retries", 0)
        self.default_litellm_params.setdefault("metadata", {}).update(
            {"caching_groups": caching_groups}
        )
|
|
|
|
|
if routing_strategy == "least-busy": |
|
self.leastbusy_logger = LeastBusyLoggingHandler( |
|
router_cache=self.cache, model_list=self.model_list |
|
) |
|
|
|
if isinstance(litellm.input_callback, list): |
|
litellm.input_callback.append(self.leastbusy_logger) |
|
else: |
|
litellm.input_callback = [self.leastbusy_logger] |
|
if isinstance(litellm.callbacks, list): |
|
litellm.callbacks.append(self.leastbusy_logger) |
|
elif routing_strategy == "usage-based-routing": |
|
self.lowesttpm_logger = LowestTPMLoggingHandler( |
|
router_cache=self.cache, model_list=self.model_list |
|
) |
|
if isinstance(litellm.callbacks, list): |
|
litellm.callbacks.append(self.lowesttpm_logger) |
|
elif routing_strategy == "latency-based-routing": |
|
self.lowestlatency_logger = LowestLatencyLoggingHandler( |
|
router_cache=self.cache, |
|
model_list=self.model_list, |
|
routing_args=routing_strategy_args, |
|
) |
|
if isinstance(litellm.callbacks, list): |
|
litellm.callbacks.append(self.lowestlatency_logger) |
|
|
|
if isinstance(litellm.failure_callback, list): |
|
litellm.failure_callback.append(self.deployment_callback_on_failure) |
|
else: |
|
litellm.failure_callback = [self.deployment_callback_on_failure] |
|
        verbose_router_logger.debug(
            f"Initialized router with routing strategy: {self.routing_strategy}\n"
        )
|
|
|
|
|
|
|
def completion( |
|
self, model: str, messages: List[Dict[str, str]], **kwargs |
|
) -> Union[ModelResponse, CustomStreamWrapper]: |
|
""" |
|
Example usage: |
|
        response = router.completion(model="gpt-3.5-turbo", messages=[{"role": "user", "content": "Hey, how's it going?"}])
|
""" |
|
try: |
|
kwargs["model"] = model |
|
kwargs["messages"] = messages |
|
kwargs["original_function"] = self._completion |
|
timeout = kwargs.get("request_timeout", self.timeout) |
|
kwargs["num_retries"] = kwargs.get("num_retries", self.num_retries) |
|
kwargs.setdefault("metadata", {}).update({"model_group": model}) |
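            # run the call in a worker thread so future.result(timeout=...) can
            # enforce a wall-clock timeout across the whole retry/fallback chain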
|
with concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor: |
|
|
|
future = executor.submit(self.function_with_fallbacks, **kwargs) |
|
response = future.result(timeout=timeout) |
|
|
|
return response |
|
except Exception as e: |
|
raise e |
|
|
|
def _completion(self, model: str, messages: List[Dict[str, str]], **kwargs): |
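        """
        Internal sync entrypoint - picks an available deployment, merges default litellm params, and calls litellm.completion().
        """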
|
try: |
|
|
|
deployment = self.get_available_deployment( |
|
model=model, |
|
messages=messages, |
|
specific_deployment=kwargs.pop("specific_deployment", None), |
|
) |
|
kwargs.setdefault("metadata", {}).update( |
|
{"deployment": deployment["litellm_params"]["model"]} |
|
) |
|
data = deployment["litellm_params"].copy() |
|
kwargs["model_info"] = deployment.get("model_info", {}) |
|
for k, v in self.default_litellm_params.items(): |
|
                if k not in kwargs:
|
kwargs[k] = v |
|
elif k == "metadata": |
|
kwargs[k].update(v) |
|
potential_model_client = self._get_client( |
|
deployment=deployment, kwargs=kwargs |
|
) |
|
|
|
dynamic_api_key = kwargs.get("api_key", None) |
|
if ( |
|
dynamic_api_key is not None |
|
and potential_model_client is not None |
|
and dynamic_api_key != potential_model_client.api_key |
|
): |
|
model_client = None |
|
else: |
|
model_client = potential_model_client |
|
|
|
return litellm.completion( |
|
**{ |
|
**data, |
|
"messages": messages, |
|
"caching": self.cache_responses, |
|
"client": model_client, |
|
**kwargs, |
|
} |
|
) |
|
except Exception as e: |
|
raise e |
|
|
|
async def acompletion(self, model: str, messages: List[Dict[str, str]], **kwargs): |
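        """
        Async version of `completion` - routes the request through the retry/fallback logic before dispatching to `_acompletion`.

        Example usage:
        ```python
        response = await router.acompletion(model="gpt-3.5-turbo", messages=[{"role": "user", "content": "Hey, how's it going?"}])
        ```
        """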
|
try: |
|
kwargs["model"] = model |
|
kwargs["messages"] = messages |
|
kwargs["original_function"] = self._acompletion |
|
kwargs["num_retries"] = kwargs.get("num_retries", self.num_retries) |
|
timeout = kwargs.get("request_timeout", self.timeout) |
|
kwargs.setdefault("metadata", {}).update({"model_group": model}) |
|
|
|
response = await self.async_function_with_fallbacks(**kwargs) |
|
|
|
return response |
|
except Exception as e: |
|
raise e |
|
|
|
async def _acompletion(self, model: str, messages: List[Dict[str, str]], **kwargs): |
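        """
        Internal async entrypoint - same deployment-selection flow as `_completion`, plus per-deployment call accounting.
        """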
|
model_name = None |
|
try: |
|
verbose_router_logger.debug( |
|
f"Inside _acompletion()- model: {model}; kwargs: {kwargs}" |
|
) |
|
deployment = self.get_available_deployment( |
|
model=model, |
|
messages=messages, |
|
specific_deployment=kwargs.pop("specific_deployment", None), |
|
) |
|
kwargs.setdefault("metadata", {}).update( |
|
{"deployment": deployment["litellm_params"]["model"]} |
|
) |
|
kwargs["model_info"] = deployment.get("model_info", {}) |
|
data = deployment["litellm_params"].copy() |
|
model_name = data["model"] |
|
for k, v in self.default_litellm_params.items(): |
|
                if k not in kwargs:
|
kwargs[k] = v |
|
elif k == "metadata": |
|
kwargs[k].update(v) |
|
|
|
potential_model_client = self._get_client( |
|
deployment=deployment, kwargs=kwargs, client_type="async" |
|
) |
|
|
|
dynamic_api_key = kwargs.get("api_key", None) |
|
if ( |
|
dynamic_api_key is not None |
|
and potential_model_client is not None |
|
and dynamic_api_key != potential_model_client.api_key |
|
): |
|
model_client = None |
|
else: |
|
model_client = potential_model_client |
|
self.total_calls[model_name] += 1 |
|
response = await litellm.acompletion( |
|
**{ |
|
**data, |
|
"messages": messages, |
|
"caching": self.cache_responses, |
|
"client": model_client, |
|
"timeout": self.timeout, |
|
**kwargs, |
|
} |
|
) |
|
self.success_calls[model_name] += 1 |
|
verbose_router_logger.info( |
|
f"litellm.acompletion(model={model_name})\033[32m 200 OK\033[0m" |
|
) |
|
return response |
|
except Exception as e: |
|
verbose_router_logger.info( |
|
f"litellm.acompletion(model={model_name})\033[31m Exception {str(e)}\033[0m" |
|
) |
|
if model_name is not None: |
|
self.fail_calls[model_name] += 1 |
|
raise e |
|
|
|
def image_generation(self, prompt: str, model: str, **kwargs): |
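        """
        Sync image generation - routed through `function_with_fallbacks` -> `_image_generation`.
        """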
|
try: |
|
kwargs["model"] = model |
|
kwargs["prompt"] = prompt |
|
kwargs["original_function"] = self._image_generation |
|
kwargs["num_retries"] = kwargs.get("num_retries", self.num_retries) |
|
timeout = kwargs.get("request_timeout", self.timeout) |
|
kwargs.setdefault("metadata", {}).update({"model_group": model}) |
|
response = self.function_with_fallbacks(**kwargs) |
|
|
|
return response |
|
except Exception as e: |
|
raise e |
|
|
|
def _image_generation(self, prompt: str, model: str, **kwargs): |
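        """
        Internal sync image generation - mirrors `_completion`, but calls litellm.image_generation().
        """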
|
        model_name = None
        try:
|
verbose_router_logger.debug( |
|
f"Inside _image_generation()- model: {model}; kwargs: {kwargs}" |
|
) |
|
deployment = self.get_available_deployment( |
|
model=model, |
|
                messages=[{"role": "user", "content": "prompt"}],  # dummy message - image generation has no chat messages; used only for routing
|
specific_deployment=kwargs.pop("specific_deployment", None), |
|
) |
|
kwargs.setdefault("metadata", {}).update( |
|
{"deployment": deployment["litellm_params"]["model"]} |
|
) |
|
kwargs["model_info"] = deployment.get("model_info", {}) |
|
data = deployment["litellm_params"].copy() |
|
model_name = data["model"] |
|
for k, v in self.default_litellm_params.items(): |
|
                if k not in kwargs:
|
kwargs[k] = v |
|
elif k == "metadata": |
|
kwargs[k].update(v) |
|
|
|
            potential_model_client = self._get_client(
                deployment=deployment, kwargs=kwargs
            )
|
|
|
dynamic_api_key = kwargs.get("api_key", None) |
|
if ( |
|
dynamic_api_key is not None |
|
and potential_model_client is not None |
|
and dynamic_api_key != potential_model_client.api_key |
|
): |
|
model_client = None |
|
else: |
|
model_client = potential_model_client |
|
|
|
self.total_calls[model_name] += 1 |
|
response = litellm.image_generation( |
|
**{ |
|
**data, |
|
"prompt": prompt, |
|
"caching": self.cache_responses, |
|
"client": model_client, |
|
**kwargs, |
|
} |
|
) |
|
self.success_calls[model_name] += 1 |
|
verbose_router_logger.info( |
|
f"litellm.image_generation(model={model_name})\033[32m 200 OK\033[0m" |
|
) |
|
return response |
|
except Exception as e: |
|
verbose_router_logger.info( |
|
f"litellm.image_generation(model={model_name})\033[31m Exception {str(e)}\033[0m" |
|
) |
|
if model_name is not None: |
|
self.fail_calls[model_name] += 1 |
|
raise e |
|
|
|
async def aimage_generation(self, prompt: str, model: str, **kwargs): |
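        """
        Async image generation - routed through `async_function_with_fallbacks` -> `_aimage_generation`.
        """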
|
try: |
|
kwargs["model"] = model |
|
kwargs["prompt"] = prompt |
|
kwargs["original_function"] = self._aimage_generation |
|
kwargs["num_retries"] = kwargs.get("num_retries", self.num_retries) |
|
timeout = kwargs.get("request_timeout", self.timeout) |
|
kwargs.setdefault("metadata", {}).update({"model_group": model}) |
|
response = await self.async_function_with_fallbacks(**kwargs) |
|
|
|
return response |
|
except Exception as e: |
|
raise e |
|
|
|
async def _aimage_generation(self, prompt: str, model: str, **kwargs): |
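        """
        Internal async image generation - mirrors `_acompletion`, but calls litellm.aimage_generation().
        """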
|
        model_name = None
        try:
|
            verbose_router_logger.debug(
                f"Inside _aimage_generation()- model: {model}; kwargs: {kwargs}"
            )
|
deployment = self.get_available_deployment( |
|
model=model, |
|
                messages=[{"role": "user", "content": "prompt"}],  # dummy message - image generation has no chat messages; used only for routing
|
specific_deployment=kwargs.pop("specific_deployment", None), |
|
) |
|
kwargs.setdefault("metadata", {}).update( |
|
{"deployment": deployment["litellm_params"]["model"]} |
|
) |
|
kwargs["model_info"] = deployment.get("model_info", {}) |
|
data = deployment["litellm_params"].copy() |
|
model_name = data["model"] |
|
for k, v in self.default_litellm_params.items(): |
|
                if k not in kwargs:
|
kwargs[k] = v |
|
elif k == "metadata": |
|
kwargs[k].update(v) |
|
|
|
potential_model_client = self._get_client( |
|
deployment=deployment, kwargs=kwargs, client_type="async" |
|
) |
|
|
|
dynamic_api_key = kwargs.get("api_key", None) |
|
if ( |
|
dynamic_api_key is not None |
|
and potential_model_client is not None |
|
and dynamic_api_key != potential_model_client.api_key |
|
): |
|
model_client = None |
|
else: |
|
model_client = potential_model_client |
|
|
|
self.total_calls[model_name] += 1 |
|
response = await litellm.aimage_generation( |
|
**{ |
|
**data, |
|
"prompt": prompt, |
|
"caching": self.cache_responses, |
|
"client": model_client, |
|
**kwargs, |
|
} |
|
) |
|
self.success_calls[model_name] += 1 |
|
verbose_router_logger.info( |
|
f"litellm.aimage_generation(model={model_name})\033[32m 200 OK\033[0m" |
|
) |
|
return response |
|
except Exception as e: |
|
verbose_router_logger.info( |
|
f"litellm.aimage_generation(model={model_name})\033[31m Exception {str(e)}\033[0m" |
|
) |
|
if model_name is not None: |
|
self.fail_calls[model_name] += 1 |
|
raise e |
|
|
|
def text_completion( |
|
self, |
|
model: str, |
|
prompt: str, |
|
is_retry: Optional[bool] = False, |
|
is_fallback: Optional[bool] = False, |
|
is_async: Optional[bool] = False, |
|
**kwargs, |
|
): |
|
        messages = [{"role": "user", "content": prompt}]
        try:
            kwargs["num_retries"] = kwargs.get("num_retries", self.num_retries)
            kwargs.setdefault("metadata", {}).update({"model_group": model})

            # pick an available deployment from the model group
            deployment = self.get_available_deployment(
                model=model,
                messages=messages,
                specific_deployment=kwargs.pop("specific_deployment", None),
            )

            data = deployment["litellm_params"].copy()
            for k, v in self.default_litellm_params.items():
                if k not in kwargs:
                    kwargs[k] = v
                elif k == "metadata":
                    kwargs[k].update(v)

            # drop router-only kwargs before handing off to litellm, so they don't
            # override the selected deployment's params
            kwargs.pop("num_retries", None)
            return litellm.text_completion(
                **{
                    **data,
                    "prompt": prompt,
                    "caching": self.cache_responses,
                    **kwargs,
                }
            )
        except Exception as e:
            if self.num_retries > 0:
                kwargs["model"] = model
                kwargs["messages"] = messages
                kwargs["num_retries"] = self.num_retries
                kwargs["original_function"] = self.completion
                return self.function_with_retries(**kwargs)
            else:
                raise e
|
|
|
async def atext_completion( |
|
self, |
|
model: str, |
|
prompt: str, |
|
is_retry: Optional[bool] = False, |
|
is_fallback: Optional[bool] = False, |
|
is_async: Optional[bool] = False, |
|
**kwargs, |
|
): |
|
try: |
|
kwargs["model"] = model |
|
kwargs["prompt"] = prompt |
|
kwargs["original_function"] = self._atext_completion |
|
kwargs["num_retries"] = kwargs.get("num_retries", self.num_retries) |
|
timeout = kwargs.get("request_timeout", self.timeout) |
|
kwargs.setdefault("metadata", {}).update({"model_group": model}) |
|
response = await self.async_function_with_fallbacks(**kwargs) |
|
|
|
return response |
|
except Exception as e: |
|
raise e |
|
|
|
async def _atext_completion(self, model: str, prompt: str, **kwargs): |
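        """
        Internal async text completion - mirrors `_acompletion`, but calls litellm.atext_completion().
        """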
|
        model_name = None
        try:
|
verbose_router_logger.debug( |
|
f"Inside _atext_completion()- model: {model}; kwargs: {kwargs}" |
|
) |
|
deployment = self.get_available_deployment( |
|
model=model, |
|
messages=[{"role": "user", "content": prompt}], |
|
specific_deployment=kwargs.pop("specific_deployment", None), |
|
) |
|
kwargs.setdefault("metadata", {}).update( |
|
{"deployment": deployment["litellm_params"]["model"]} |
|
) |
|
kwargs["model_info"] = deployment.get("model_info", {}) |
|
data = deployment["litellm_params"].copy() |
|
model_name = data["model"] |
|
for k, v in self.default_litellm_params.items(): |
|
                if k not in kwargs:
|
kwargs[k] = v |
|
elif k == "metadata": |
|
kwargs[k].update(v) |
|
|
|
potential_model_client = self._get_client( |
|
deployment=deployment, kwargs=kwargs, client_type="async" |
|
) |
|
|
|
dynamic_api_key = kwargs.get("api_key", None) |
|
if ( |
|
dynamic_api_key is not None |
|
and potential_model_client is not None |
|
and dynamic_api_key != potential_model_client.api_key |
|
): |
|
model_client = None |
|
else: |
|
model_client = potential_model_client |
|
self.total_calls[model_name] += 1 |
|
response = await litellm.atext_completion( |
|
**{ |
|
**data, |
|
"prompt": prompt, |
|
"caching": self.cache_responses, |
|
"client": model_client, |
|
"timeout": self.timeout, |
|
**kwargs, |
|
} |
|
) |
|
self.success_calls[model_name] += 1 |
|
verbose_router_logger.info( |
|
f"litellm.atext_completion(model={model_name})\033[32m 200 OK\033[0m" |
|
) |
|
return response |
|
except Exception as e: |
|
verbose_router_logger.info( |
|
f"litellm.atext_completion(model={model_name})\033[31m Exception {str(e)}\033[0m" |
|
) |
|
if model_name is not None: |
|
self.fail_calls[model_name] += 1 |
|
raise e |
|
|
|
def embedding( |
|
self, |
|
model: str, |
|
input: Union[str, List], |
|
is_async: Optional[bool] = False, |
|
**kwargs, |
|
) -> Union[List[float], None]: |
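        """
        Sync embedding call - picks a deployment for the model group and calls litellm.embedding().

        Example usage (assumes a model group named "text-embedding-ada-002" is configured):
        ```python
        response = router.embedding(model="text-embedding-ada-002", input=["good morning from litellm"])
        ```
        """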
|
|
|
deployment = self.get_available_deployment( |
|
model=model, |
|
input=input, |
|
specific_deployment=kwargs.pop("specific_deployment", None), |
|
) |
|
kwargs.setdefault("model_info", {}) |
|
kwargs.setdefault("metadata", {}).update( |
|
{"model_group": model, "deployment": deployment["litellm_params"]["model"]} |
|
) |
|
data = deployment["litellm_params"].copy() |
|
for k, v in self.default_litellm_params.items(): |
|
            if k not in kwargs:
|
kwargs[k] = v |
|
elif k == "metadata": |
|
kwargs[k].update(v) |
|
potential_model_client = self._get_client(deployment=deployment, kwargs=kwargs) |
|
|
|
dynamic_api_key = kwargs.get("api_key", None) |
|
if ( |
|
dynamic_api_key is not None |
|
and potential_model_client is not None |
|
and dynamic_api_key != potential_model_client.api_key |
|
): |
|
model_client = None |
|
else: |
|
model_client = potential_model_client |
|
return litellm.embedding( |
|
**{ |
|
**data, |
|
"input": input, |
|
"caching": self.cache_responses, |
|
"client": model_client, |
|
**kwargs, |
|
} |
|
) |
|
|
|
async def aembedding( |
|
self, |
|
model: str, |
|
input: Union[str, List], |
|
is_async: Optional[bool] = True, |
|
**kwargs, |
|
) -> Union[List[float], None]: |
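        """
        Async embedding call - routed through `async_function_with_fallbacks` -> `_aembedding`.
        """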
|
try: |
|
kwargs["model"] = model |
|
kwargs["input"] = input |
|
kwargs["original_function"] = self._aembedding |
|
kwargs["num_retries"] = kwargs.get("num_retries", self.num_retries) |
|
timeout = kwargs.get("request_timeout", self.timeout) |
|
kwargs.setdefault("metadata", {}).update({"model_group": model}) |
|
response = await self.async_function_with_fallbacks(**kwargs) |
|
return response |
|
except Exception as e: |
|
raise e |
|
|
|
async def _aembedding(self, input: Union[str, List], model: str, **kwargs): |
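        """
        Internal async embedding - mirrors `_acompletion`, but calls litellm.aembedding().
        """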
|
        model_name = None
        try:
|
verbose_router_logger.debug( |
|
f"Inside _aembedding()- model: {model}; kwargs: {kwargs}" |
|
) |
|
deployment = self.get_available_deployment( |
|
model=model, |
|
input=input, |
|
specific_deployment=kwargs.pop("specific_deployment", None), |
|
) |
|
kwargs.setdefault("metadata", {}).update( |
|
{"deployment": deployment["litellm_params"]["model"]} |
|
) |
|
kwargs["model_info"] = deployment.get("model_info", {}) |
|
data = deployment["litellm_params"].copy() |
|
model_name = data["model"] |
|
for k, v in self.default_litellm_params.items(): |
|
                if k not in kwargs:
|
kwargs[k] = v |
|
elif k == "metadata": |
|
kwargs[k].update(v) |
|
|
|
potential_model_client = self._get_client( |
|
deployment=deployment, kwargs=kwargs, client_type="async" |
|
) |
|
|
|
dynamic_api_key = kwargs.get("api_key", None) |
|
if ( |
|
dynamic_api_key is not None |
|
and potential_model_client is not None |
|
and dynamic_api_key != potential_model_client.api_key |
|
): |
|
model_client = None |
|
else: |
|
model_client = potential_model_client |
|
|
|
self.total_calls[model_name] += 1 |
|
response = await litellm.aembedding( |
|
**{ |
|
**data, |
|
"input": input, |
|
"caching": self.cache_responses, |
|
"client": model_client, |
|
**kwargs, |
|
} |
|
) |
|
self.success_calls[model_name] += 1 |
|
verbose_router_logger.info( |
|
f"litellm.aembedding(model={model_name})\033[32m 200 OK\033[0m" |
|
) |
|
return response |
|
except Exception as e: |
|
verbose_router_logger.info( |
|
f"litellm.aembedding(model={model_name})\033[31m Exception {str(e)}\033[0m" |
|
) |
|
if model_name is not None: |
|
self.fail_calls[model_name] += 1 |
|
raise e |
|
|
|
async def async_function_with_fallbacks(self, *args, **kwargs): |
|
""" |
|
Try calling the function_with_retries |
|
If it fails after num_retries, fall back to another model group |
|
""" |
|
model_group = kwargs.get("model") |
|
fallbacks = kwargs.get("fallbacks", self.fallbacks) |
|
context_window_fallbacks = kwargs.get( |
|
"context_window_fallbacks", self.context_window_fallbacks |
|
) |
|
try: |
|
response = await self.async_function_with_retries(*args, **kwargs) |
|
verbose_router_logger.debug(f"Async Response: {response}") |
|
return response |
|
except Exception as e: |
|
            verbose_router_logger.debug(f"Traceback: {traceback.format_exc()}")
|
original_exception = e |
|
fallback_model_group = None |
|
try: |
|
                if hasattr(e, "status_code") and e.status_code == 400:
|
raise e |
|
                verbose_router_logger.debug("Trying to fallback b/w models")
|
if ( |
|
isinstance(e, litellm.ContextWindowExceededError) |
|
and context_window_fallbacks is not None |
|
): |
|
                    fallback_model_group = None
                    for item in context_window_fallbacks:
                        if list(item.keys())[0] == model_group:
                            fallback_model_group = item[model_group]
                            break
|
|
|
if fallback_model_group is None: |
|
raise original_exception |
|
|
|
for mg in fallback_model_group: |
|
""" |
|
Iterate through the model groups and try calling that deployment |
|
""" |
|
try: |
|
kwargs["model"] = mg |
|
response = await self.async_function_with_retries( |
|
*args, **kwargs |
|
) |
|
return response |
|
except Exception as e: |
|
pass |
|
elif fallbacks is not None: |
|
verbose_router_logger.debug(f"inside model fallbacks: {fallbacks}") |
|
for item in fallbacks: |
|
if list(item.keys())[0] == model_group: |
|
fallback_model_group = item[model_group] |
|
break |
|
if fallback_model_group is None: |
|
verbose_router_logger.info( |
|
f"No fallback model group found for original model_group={model_group}. Fallbacks={fallbacks}" |
|
) |
|
raise original_exception |
|
for mg in fallback_model_group: |
|
""" |
|
Iterate through the model groups and try calling that deployment |
|
""" |
|
try: |
|
|
|
kwargs = self.log_retry(kwargs=kwargs, e=original_exception) |
|
verbose_router_logger.info( |
|
f"Falling back to model_group = {mg}" |
|
) |
|
kwargs["model"] = mg |
|
kwargs["metadata"]["model_group"] = mg |
|
response = await self.async_function_with_retries( |
|
*args, **kwargs |
|
) |
|
return response |
|
except Exception as e: |
|
raise e |
|
except Exception as e: |
|
verbose_router_logger.debug(f"An exception occurred - {str(e)}") |
|
                traceback.print_exc()
                raise original_exception
            raise original_exception
|
|
|
async def async_function_with_retries(self, *args, **kwargs): |
|
verbose_router_logger.debug( |
|
f"Inside async function with retries: args - {args}; kwargs - {kwargs}" |
|
) |
|
original_function = kwargs.pop("original_function") |
|
fallbacks = kwargs.pop("fallbacks", self.fallbacks) |
|
context_window_fallbacks = kwargs.pop( |
|
"context_window_fallbacks", self.context_window_fallbacks |
|
) |
|
verbose_router_logger.debug( |
|
f"async function w/ retries: original_function - {original_function}" |
|
) |
|
num_retries = kwargs.pop("num_retries") |
|
try: |
|
|
|
response = await original_function(*args, **kwargs) |
|
return response |
|
except Exception as e: |
|
original_exception = e |
|
|
|
if ( |
|
isinstance(original_exception, litellm.ContextWindowExceededError) |
|
and context_window_fallbacks is None |
|
) or ( |
|
isinstance(original_exception, openai.RateLimitError) |
|
and fallbacks is not None |
|
): |
|
raise original_exception |
|
|
|
|
|
if "No models available" in str(e): |
|
timeout = litellm._calculate_retry_after( |
|
remaining_retries=num_retries, |
|
max_retries=num_retries, |
|
min_timeout=self.retry_after, |
|
) |
|
await asyncio.sleep(timeout) |
|
elif hasattr(original_exception, "status_code") and litellm._should_retry( |
|
status_code=original_exception.status_code |
|
): |
|
if hasattr(original_exception, "response") and hasattr( |
|
original_exception.response, "headers" |
|
): |
|
timeout = litellm._calculate_retry_after( |
|
remaining_retries=num_retries, |
|
max_retries=num_retries, |
|
response_headers=original_exception.response.headers, |
|
min_timeout=self.retry_after, |
|
) |
|
else: |
|
timeout = litellm._calculate_retry_after( |
|
remaining_retries=num_retries, |
|
max_retries=num_retries, |
|
min_timeout=self.retry_after, |
|
) |
|
await asyncio.sleep(timeout) |
|
else: |
|
raise original_exception |
|
|
|
|
|
if num_retries > 0: |
|
kwargs = self.log_retry(kwargs=kwargs, e=original_exception) |
|
|
|
for current_attempt in range(num_retries): |
|
verbose_router_logger.debug( |
|
f"retrying request. Current attempt - {current_attempt}; num retries: {num_retries}" |
|
) |
|
try: |
|
|
|
response = await original_function(*args, **kwargs) |
|
                    if inspect.iscoroutine(response):  # guard against an un-awaited coroutine being returned
                        response = await response
|
return response |
|
|
|
except Exception as e: |
|
|
|
kwargs = self.log_retry(kwargs=kwargs, e=e) |
|
remaining_retries = num_retries - current_attempt |
|
if "No models available" in str(e): |
|
timeout = litellm._calculate_retry_after( |
|
remaining_retries=remaining_retries, |
|
max_retries=num_retries, |
|
min_timeout=self.retry_after, |
|
) |
|
await asyncio.sleep(timeout) |
|
elif ( |
|
hasattr(e, "status_code") |
|
and hasattr(e, "response") |
|
and litellm._should_retry(status_code=e.status_code) |
|
): |
|
if hasattr(e.response, "headers"): |
|
timeout = litellm._calculate_retry_after( |
|
remaining_retries=remaining_retries, |
|
max_retries=num_retries, |
|
response_headers=e.response.headers, |
|
min_timeout=self.retry_after, |
|
) |
|
else: |
|
timeout = litellm._calculate_retry_after( |
|
remaining_retries=remaining_retries, |
|
max_retries=num_retries, |
|
min_timeout=self.retry_after, |
|
) |
|
await asyncio.sleep(timeout) |
|
else: |
|
raise e |
|
raise original_exception |
|
|
|
def function_with_fallbacks(self, *args, **kwargs): |
|
""" |
|
Try calling the function_with_retries |
|
If it fails after num_retries, fall back to another model group |
|
""" |
|
model_group = kwargs.get("model") |
|
fallbacks = kwargs.get("fallbacks", self.fallbacks) |
|
context_window_fallbacks = kwargs.get( |
|
"context_window_fallbacks", self.context_window_fallbacks |
|
) |
|
try: |
|
response = self.function_with_retries(*args, **kwargs) |
|
return response |
|
except Exception as e: |
|
original_exception = e |
|
            verbose_router_logger.debug(f"An exception occurred - {original_exception}")
|
try: |
|
verbose_router_logger.debug( |
|
f"Trying to fallback b/w models. Initial model group: {model_group}" |
|
) |
|
if ( |
|
isinstance(e, litellm.ContextWindowExceededError) |
|
and context_window_fallbacks is not None |
|
): |
|
                    fallback_model_group = None
                    for item in context_window_fallbacks:
                        if list(item.keys())[0] == model_group:
                            fallback_model_group = item[model_group]
                            break
|
|
|
if fallback_model_group is None: |
|
raise original_exception |
|
|
|
for mg in fallback_model_group: |
|
""" |
|
Iterate through the model groups and try calling that deployment |
|
""" |
|
try: |
|
|
|
kwargs = self.log_retry(kwargs=kwargs, e=original_exception) |
|
kwargs["model"] = mg |
|
response = self.function_with_fallbacks(*args, **kwargs) |
|
return response |
|
except Exception as e: |
|
pass |
|
elif fallbacks is not None: |
|
verbose_router_logger.debug(f"inside model fallbacks: {fallbacks}") |
|
fallback_model_group = None |
|
for item in fallbacks: |
|
if list(item.keys())[0] == model_group: |
|
fallback_model_group = item[model_group] |
|
break |
|
|
|
if fallback_model_group is None: |
|
raise original_exception |
|
|
|
for mg in fallback_model_group: |
|
""" |
|
Iterate through the model groups and try calling that deployment |
|
""" |
|
try: |
|
|
|
kwargs = self.log_retry(kwargs=kwargs, e=original_exception) |
|
kwargs["model"] = mg |
|
response = self.function_with_fallbacks(*args, **kwargs) |
|
return response |
|
except Exception as e: |
|
raise e |
|
except Exception as e: |
|
raise e |
|
raise original_exception |
|
|
|
def function_with_retries(self, *args, **kwargs): |
|
""" |
|
        Try calling the model up to num_retries times. Shuffle between available deployments.
|
""" |
|
verbose_router_logger.debug( |
|
f"Inside function with retries: args - {args}; kwargs - {kwargs}" |
|
) |
|
original_function = kwargs.pop("original_function") |
|
num_retries = kwargs.pop("num_retries") |
|
fallbacks = kwargs.pop("fallbacks", self.fallbacks) |
|
context_window_fallbacks = kwargs.pop( |
|
"context_window_fallbacks", self.context_window_fallbacks |
|
) |
|
try: |
|
|
|
response = original_function(*args, **kwargs) |
|
return response |
|
except Exception as e: |
|
original_exception = e |
|
verbose_router_logger.debug( |
|
f"num retries in function with retries: {num_retries}" |
|
) |
|
|
|
if ( |
|
isinstance(original_exception, litellm.ContextWindowExceededError) |
|
and context_window_fallbacks is None |
|
) or ( |
|
isinstance(original_exception, openai.RateLimitError) |
|
and fallbacks is not None |
|
): |
|
raise original_exception |
|
|
|
if num_retries > 0: |
|
kwargs = self.log_retry(kwargs=kwargs, e=original_exception) |
|
|
|
for current_attempt in range(num_retries): |
|
                verbose_router_logger.debug(
                    f"retrying request. Current attempt - {current_attempt}; num retries: {num_retries}"
                )
|
try: |
|
|
|
response = original_function(*args, **kwargs) |
|
return response |
|
|
|
except Exception as e: |
|
|
|
kwargs = self.log_retry(kwargs=kwargs, e=e) |
|
remaining_retries = num_retries - current_attempt |
|
if "No models available" in str(e): |
|
timeout = litellm._calculate_retry_after( |
|
remaining_retries=remaining_retries, |
|
max_retries=num_retries, |
|
min_timeout=self.retry_after, |
|
) |
|
time.sleep(timeout) |
|
elif ( |
|
hasattr(e, "status_code") |
|
and hasattr(e, "response") |
|
and litellm._should_retry(status_code=e.status_code) |
|
): |
|
if hasattr(e.response, "headers"): |
|
timeout = litellm._calculate_retry_after( |
|
remaining_retries=remaining_retries, |
|
max_retries=num_retries, |
|
response_headers=e.response.headers, |
|
min_timeout=self.retry_after, |
|
) |
|
else: |
|
timeout = litellm._calculate_retry_after( |
|
remaining_retries=remaining_retries, |
|
max_retries=num_retries, |
|
min_timeout=self.retry_after, |
|
) |
|
time.sleep(timeout) |
|
else: |
|
raise e |
|
raise original_exception |
|
|
|
|
|
|
|
def deployment_callback_on_failure( |
|
self, |
|
kwargs, |
|
completion_response, |
|
start_time, |
|
end_time, |
|
): |
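        """
        litellm failure callback - records the exception against the deployment and
        puts the failing deployment on cooldown (see `_set_cooldown_deployments`).
        """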
|
try: |
|
exception = kwargs.get("exception", None) |
|
exception_type = type(exception) |
|
exception_status = getattr(exception, "status_code", "") |
|
exception_cause = getattr(exception, "__cause__", "") |
|
exception_message = getattr(exception, "message", "") |
|
            exception_str = (
                f"{exception_type} Status: {exception_status} "
                f"Message: {exception_cause} {exception_message} "
                f"Full exception: {exception}"
            )
|
model_name = kwargs.get("model", None) |
|
custom_llm_provider = kwargs.get("litellm_params", {}).get( |
|
"custom_llm_provider", None |
|
) |
|
metadata = kwargs.get("litellm_params", {}).get("metadata", None) |
|
_model_info = kwargs.get("litellm_params", {}).get("model_info", {}) |
|
if isinstance(_model_info, dict): |
|
deployment_id = _model_info.get("id", None) |
|
                self._set_cooldown_deployments(deployment_id)
|
if metadata: |
|
deployment = metadata.get("deployment", None) |
|
deployment_exceptions = self.model_exception_map.get(deployment, []) |
|
deployment_exceptions.append(exception_str) |
|
self.model_exception_map[deployment] = deployment_exceptions |
|
verbose_router_logger.debug("\nEXCEPTION FOR DEPLOYMENTS\n") |
|
verbose_router_logger.debug(self.model_exception_map) |
|
for model in self.model_exception_map: |
|
                    verbose_router_logger.debug(
                        f"Model {model} had {len(self.model_exception_map[model])} exceptions"
                    )
|
if custom_llm_provider: |
|
model_name = f"{custom_llm_provider}/{model_name}" |
|
|
|
except Exception as e: |
|
raise e |
|
|
|
def log_retry(self, kwargs: dict, e: Exception) -> dict: |
|
""" |
|
When a retry or fallback happens, log the details of the just failed model call - similar to Sentry breadcrumbing |
|
""" |
|
try: |
|
|
|
previous_model = { |
|
"exception_type": type(e).__name__, |
|
"exception_string": str(e), |
|
} |
|
            for k, v in kwargs.items():
|
if k != "metadata": |
|
previous_model[k] = v |
|
elif k == "metadata" and isinstance(v, dict): |
|
previous_model["metadata"] = {} |
|
for metadata_k, metadata_v in kwargs["metadata"].items(): |
|
if metadata_k != "previous_models": |
|
                            previous_model["metadata"][metadata_k] = metadata_v
|
self.previous_models.append(previous_model) |
|
kwargs["metadata"]["previous_models"] = self.previous_models |
|
return kwargs |
|
except Exception as e: |
|
raise e |
|
|
|
def _set_cooldown_deployments(self, deployment: Optional[str] = None): |
|
""" |
|
Add a model to the list of models being cooled down for that minute, if it exceeds the allowed fails / minute |
|
""" |
|
if deployment is None: |
|
return |
|
|
|
current_minute = datetime.now().strftime("%H-%M") |
|
|
|
|
|
|
|
|
|
current_fails = self.failed_calls.get_cache(key=deployment) or 0 |
|
updated_fails = current_fails + 1 |
|
verbose_router_logger.debug( |
|
f"Attempting to add {deployment} to cooldown list. updated_fails: {updated_fails}; self.allowed_fails: {self.allowed_fails}" |
|
) |
|
if updated_fails > self.allowed_fails: |
|
|
|
cooldown_key = f"{current_minute}:cooldown_models" |
|
            cached_value = self.cache.get_cache(key=cooldown_key) or []

            verbose_router_logger.debug(f"adding {deployment} to cooldown models")

            if deployment not in cached_value:
                cached_value = cached_value + [deployment]
                self.cache.set_cache(value=cached_value, key=cooldown_key, ttl=1)
|
else: |
|
self.failed_calls.set_cache(key=deployment, value=updated_fails, ttl=1) |
|
|
|
def _get_cooldown_deployments(self): |
|
""" |
|
Get the list of models being cooled down for this minute |
|
""" |
|
current_minute = datetime.now().strftime("%H-%M") |
|
|
|
cooldown_key = f"{current_minute}:cooldown_models" |
|
|
|
|
|
|
|
|
|
cooldown_models = self.cache.get_cache(key=cooldown_key) or [] |
|
|
|
verbose_router_logger.debug(f"retrieve cooldown models: {cooldown_models}") |
|
return cooldown_models |
|
|
|
def set_client(self, model: dict): |
|
""" |
|
Initializes Azure/OpenAI clients. Stores them in cache, b/c of this - https://github.com/BerriAI/litellm/issues/1278 |
|
""" |
|
client_ttl = self.client_ttl |
|
litellm_params = model.get("litellm_params", {}) |
|
model_name = litellm_params.get("model") |
|
model_id = model["model_info"]["id"] |
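        # four clients are cached per deployment id: sync + async, each with a
        # default-timeout and a stream-timeout variant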
|
|
|
custom_llm_provider = litellm_params.get("custom_llm_provider") |
|
custom_llm_provider = custom_llm_provider or model_name.split("/", 1)[0] or "" |
|
default_api_base = None |
|
default_api_key = None |
|
if custom_llm_provider in litellm.openai_compatible_providers: |
|
_, custom_llm_provider, api_key, api_base = litellm.get_llm_provider( |
|
model=model_name |
|
) |
|
default_api_base = api_base |
|
default_api_key = api_key |
|
if ( |
|
model_name in litellm.open_ai_chat_completion_models |
|
or custom_llm_provider in litellm.openai_compatible_providers |
|
or custom_llm_provider == "azure" |
|
or custom_llm_provider == "custom_openai" |
|
or custom_llm_provider == "openai" |
|
or "ft:gpt-3.5-turbo" in model_name |
|
or model_name in litellm.open_ai_embedding_models |
|
): |
|
|
|
|
|
|
|
api_key = litellm_params.get("api_key") or default_api_key |
|
if api_key and api_key.startswith("os.environ/"): |
|
api_key_env_name = api_key.replace("os.environ/", "") |
|
api_key = litellm.get_secret(api_key_env_name) |
|
litellm_params["api_key"] = api_key |
|
|
|
api_base = litellm_params.get("api_base") |
|
base_url = litellm_params.get("base_url") |
|
            api_base = api_base or base_url or default_api_base
|
if api_base and api_base.startswith("os.environ/"): |
|
api_base_env_name = api_base.replace("os.environ/", "") |
|
api_base = litellm.get_secret(api_base_env_name) |
|
litellm_params["api_base"] = api_base |
|
|
|
api_version = litellm_params.get("api_version") |
|
if api_version and api_version.startswith("os.environ/"): |
|
api_version_env_name = api_version.replace("os.environ/", "") |
|
api_version = litellm.get_secret(api_version_env_name) |
|
litellm_params["api_version"] = api_version |
|
|
|
timeout = litellm_params.pop("timeout", None) |
|
if isinstance(timeout, str) and timeout.startswith("os.environ/"): |
|
timeout_env_name = timeout.replace("os.environ/", "") |
|
timeout = litellm.get_secret(timeout_env_name) |
|
litellm_params["timeout"] = timeout |
|
|
|
            stream_timeout = litellm_params.pop("stream_timeout", timeout)  # if no stream_timeout is set, default to timeout
|
if isinstance(stream_timeout, str) and stream_timeout.startswith( |
|
"os.environ/" |
|
): |
|
stream_timeout_env_name = stream_timeout.replace("os.environ/", "") |
|
stream_timeout = litellm.get_secret(stream_timeout_env_name) |
|
litellm_params["stream_timeout"] = stream_timeout |
|
|
|
max_retries = litellm_params.pop("max_retries", 2) |
|
if isinstance(max_retries, str) and max_retries.startswith("os.environ/"): |
|
max_retries_env_name = max_retries.replace("os.environ/", "") |
|
max_retries = litellm.get_secret(max_retries_env_name) |
|
litellm_params["max_retries"] = max_retries |
|
|
|
if "azure" in model_name: |
|
if api_base is None: |
|
raise ValueError( |
|
f"api_base is required for Azure OpenAI. Set it on your config. Model - {model}" |
|
) |
|
if api_version is None: |
|
api_version = "2023-07-01-preview" |
|
if "gateway.ai.cloudflare.com" in api_base: |
|
if not api_base.endswith("/"): |
|
api_base += "/" |
|
azure_model = model_name.replace("azure/", "") |
|
api_base += f"{azure_model}" |
|
cache_key = f"{model_id}_async_client" |
|
_client = openai.AsyncAzureOpenAI( |
|
api_key=api_key, |
|
base_url=api_base, |
|
api_version=api_version, |
|
timeout=timeout, |
|
max_retries=max_retries, |
|
http_client=httpx.AsyncClient( |
|
transport=AsyncCustomHTTPTransport(), |
|
limits=httpx.Limits( |
|
max_connections=1000, max_keepalive_connections=100 |
|
), |
|
), |
|
) |
|
self.cache.set_cache( |
|
key=cache_key, |
|
value=_client, |
|
ttl=client_ttl, |
|
local_only=True, |
|
) |
|
|
|
cache_key = f"{model_id}_client" |
|
_client = openai.AzureOpenAI( |
|
api_key=api_key, |
|
base_url=api_base, |
|
api_version=api_version, |
|
timeout=timeout, |
|
max_retries=max_retries, |
|
http_client=httpx.Client( |
|
transport=CustomHTTPTransport(), |
|
limits=httpx.Limits( |
|
max_connections=1000, max_keepalive_connections=100 |
|
), |
|
), |
|
) |
|
self.cache.set_cache( |
|
key=cache_key, |
|
value=_client, |
|
ttl=client_ttl, |
|
local_only=True, |
|
) |
|
|
|
cache_key = f"{model_id}_stream_async_client" |
|
_client = openai.AsyncAzureOpenAI( |
|
api_key=api_key, |
|
base_url=api_base, |
|
api_version=api_version, |
|
timeout=stream_timeout, |
|
max_retries=max_retries, |
|
http_client=httpx.AsyncClient( |
|
transport=AsyncCustomHTTPTransport(), |
|
limits=httpx.Limits( |
|
max_connections=1000, max_keepalive_connections=100 |
|
), |
|
), |
|
) |
|
self.cache.set_cache( |
|
key=cache_key, |
|
value=_client, |
|
ttl=client_ttl, |
|
local_only=True, |
|
) |
|
|
|
cache_key = f"{model_id}_stream_client" |
|
_client = openai.AzureOpenAI( |
|
api_key=api_key, |
|
base_url=api_base, |
|
api_version=api_version, |
|
timeout=stream_timeout, |
|
max_retries=max_retries, |
|
http_client=httpx.Client( |
|
transport=CustomHTTPTransport(), |
|
limits=httpx.Limits( |
|
max_connections=1000, max_keepalive_connections=100 |
|
), |
|
), |
|
) |
|
self.cache.set_cache( |
|
key=cache_key, |
|
value=_client, |
|
ttl=client_ttl, |
|
local_only=True, |
|
) |
|
else: |
|
verbose_router_logger.debug( |
|
f"Initializing Azure OpenAI Client for {model_name}, Api Base: {str(api_base)}, Api Key:{api_key}" |
|
) |
|
|
|
cache_key = f"{model_id}_async_client" |
|
_client = openai.AsyncAzureOpenAI( |
|
api_key=api_key, |
|
azure_endpoint=api_base, |
|
api_version=api_version, |
|
timeout=timeout, |
|
max_retries=max_retries, |
|
http_client=httpx.AsyncClient( |
|
transport=AsyncCustomHTTPTransport(), |
|
limits=httpx.Limits( |
|
max_connections=1000, max_keepalive_connections=100 |
|
), |
|
), |
|
) |
|
self.cache.set_cache( |
|
key=cache_key, |
|
value=_client, |
|
ttl=client_ttl, |
|
local_only=True, |
|
) |
|
|
|
cache_key = f"{model_id}_client" |
|
_client = openai.AzureOpenAI( |
|
api_key=api_key, |
|
azure_endpoint=api_base, |
|
api_version=api_version, |
|
timeout=timeout, |
|
max_retries=max_retries, |
|
http_client=httpx.Client( |
|
transport=CustomHTTPTransport(), |
|
limits=httpx.Limits( |
|
max_connections=1000, max_keepalive_connections=100 |
|
), |
|
), |
|
) |
|
self.cache.set_cache( |
|
key=cache_key, |
|
value=_client, |
|
ttl=client_ttl, |
|
local_only=True, |
|
) |
|
|
|
|
|
cache_key = f"{model_id}_stream_async_client" |
|
_client = openai.AsyncAzureOpenAI( |
|
api_key=api_key, |
|
azure_endpoint=api_base, |
|
api_version=api_version, |
|
timeout=stream_timeout, |
|
max_retries=max_retries, |
|
http_client=httpx.AsyncClient( |
|
transport=AsyncCustomHTTPTransport(), |
|
limits=httpx.Limits( |
|
max_connections=1000, max_keepalive_connections=100 |
|
), |
|
), |
|
) |
|
self.cache.set_cache( |
|
key=cache_key, |
|
value=_client, |
|
ttl=client_ttl, |
|
local_only=True, |
|
) |
|
|
|
cache_key = f"{model_id}_stream_client" |
|
_client = openai.AzureOpenAI( |
|
api_key=api_key, |
|
azure_endpoint=api_base, |
|
api_version=api_version, |
|
timeout=stream_timeout, |
|
max_retries=max_retries, |
|
http_client=httpx.Client( |
|
transport=CustomHTTPTransport(), |
|
limits=httpx.Limits( |
|
max_connections=1000, max_keepalive_connections=100 |
|
), |
|
), |
|
) |
|
self.cache.set_cache( |
|
key=cache_key, |
|
value=_client, |
|
ttl=client_ttl, |
|
local_only=True, |
|
) |
|
|
|
else: |
|
verbose_router_logger.debug( |
|
f"Initializing OpenAI Client for {model_name}, Api Base:{str(api_base)}, Api Key:{api_key}" |
|
) |
|
cache_key = f"{model_id}_async_client" |
|
_client = openai.AsyncOpenAI( |
|
api_key=api_key, |
|
base_url=api_base, |
|
timeout=timeout, |
|
max_retries=max_retries, |
|
http_client=httpx.AsyncClient( |
|
transport=AsyncCustomHTTPTransport(), |
|
limits=httpx.Limits( |
|
max_connections=1000, max_keepalive_connections=100 |
|
), |
|
), |
|
) |
|
self.cache.set_cache( |
|
key=cache_key, |
|
value=_client, |
|
ttl=client_ttl, |
|
local_only=True, |
|
) |
|
|
|
cache_key = f"{model_id}_client" |
|
_client = openai.OpenAI( |
|
api_key=api_key, |
|
base_url=api_base, |
|
timeout=timeout, |
|
max_retries=max_retries, |
|
http_client=httpx.Client( |
|
transport=CustomHTTPTransport(), |
|
limits=httpx.Limits( |
|
max_connections=1000, max_keepalive_connections=100 |
|
), |
|
), |
|
) |
|
self.cache.set_cache( |
|
key=cache_key, |
|
value=_client, |
|
ttl=client_ttl, |
|
local_only=True, |
|
) |
|
|
|
|
|
cache_key = f"{model_id}_stream_async_client" |
|
_client = openai.AsyncOpenAI( |
|
api_key=api_key, |
|
base_url=api_base, |
|
timeout=stream_timeout, |
|
max_retries=max_retries, |
|
http_client=httpx.AsyncClient( |
|
transport=AsyncCustomHTTPTransport(), |
|
limits=httpx.Limits( |
|
max_connections=1000, max_keepalive_connections=100 |
|
), |
|
), |
|
) |
|
self.cache.set_cache( |
|
key=cache_key, |
|
value=_client, |
|
ttl=client_ttl, |
|
local_only=True, |
|
) |
|
|
|
|
|
cache_key = f"{model_id}_stream_client" |
|
_client = openai.OpenAI( |
|
api_key=api_key, |
|
base_url=api_base, |
|
timeout=stream_timeout, |
|
max_retries=max_retries, |
|
http_client=httpx.Client( |
|
transport=CustomHTTPTransport(), |
|
limits=httpx.Limits( |
|
max_connections=1000, max_keepalive_connections=100 |
|
), |
|
), |
|
) |
|
self.cache.set_cache( |
|
key=cache_key, |
|
value=_client, |
|
ttl=client_ttl, |
|
local_only=True, |
|
) |
|
|
|
def set_model_list(self, model_list: list): |
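        """
        Stores a deep copy of the model list, assigns each deployment a stable uuid
        (`model_info.id`), validates the provider, and pre-initializes clients.
        """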
|
self.model_list = copy.deepcopy(model_list) |
|
|
|
|
|
|
for model in self.model_list: |
|
|
|
model_info = model.get("model_info", {}) |
|
model_info["id"] = model_info.get("id", str(uuid.uuid4())) |
|
model["model_info"] = model_info |
|
|
|
self.deployment_names.append(model["litellm_params"]["model"]) |
|
|
|
|
|
|
|
if ( |
|
model["litellm_params"].get("rpm") is None |
|
and model.get("rpm") is not None |
|
): |
|
model["litellm_params"]["rpm"] = model.get("rpm") |
|
if ( |
|
model["litellm_params"].get("tpm") is None |
|
and model.get("tpm") is not None |
|
): |
|
model["litellm_params"]["tpm"] = model.get("tpm") |
|
|
|
|
|
|
|
( |
|
_model, |
|
custom_llm_provider, |
|
dynamic_api_key, |
|
api_base, |
|
) = litellm.get_llm_provider( |
|
model=model["litellm_params"]["model"], |
|
custom_llm_provider=model["litellm_params"].get( |
|
"custom_llm_provider", None |
|
), |
|
) |
|
|
|
if custom_llm_provider not in litellm.provider_list: |
|
raise Exception(f"Unsupported provider - {custom_llm_provider}") |
|
|
|
self.set_client(model=model) |
|
|
|
verbose_router_logger.debug(f"\nInitialized Model List {self.model_list}") |
|
self.model_names = [m["model_name"] for m in model_list] |
|
|
|
def get_model_names(self): |
|
return self.model_names |
|
|
|
def _get_client(self, deployment, kwargs, client_type=None): |
|
""" |
|
Returns the appropriate client based on the given deployment, kwargs, and client_type. |
|
|
|
Parameters: |
|
deployment (dict): The deployment dictionary containing the clients. |
|
kwargs (dict): The keyword arguments passed to the function. |
|
client_type (str): The type of client to return. |
|
|
|
Returns: |
|
The appropriate client based on the given client_type and kwargs. |
|
""" |
|
model_id = deployment["model_info"]["id"] |
|
if client_type == "async": |
|
if kwargs.get("stream") == True: |
|
cache_key = f"{model_id}_stream_async_client" |
|
client = self.cache.get_cache(key=cache_key, local_only=True) |
|
if client is None: |
|
""" |
|
Re-initialize the client |
|
""" |
|
self.set_client(model=deployment) |
|
client = self.cache.get_cache(key=cache_key, local_only=True) |
|
return client |
|
else: |
|
cache_key = f"{model_id}_async_client" |
|
client = self.cache.get_cache(key=cache_key, local_only=True) |
|
if client is None: |
|
""" |
|
Re-initialize the client |
|
""" |
|
self.set_client(model=deployment) |
|
client = self.cache.get_cache(key=cache_key, local_only=True) |
|
return client |
|
else: |
|
if kwargs.get("stream") == True: |
|
cache_key = f"{model_id}_stream_client" |
|
client = self.cache.get_cache(key=cache_key) |
|
if client is None: |
|
""" |
|
Re-initialize the client |
|
""" |
|
self.set_client(model=deployment) |
|
client = self.cache.get_cache(key=cache_key) |
|
return client |
|
else: |
|
cache_key = f"{model_id}_client" |
|
client = self.cache.get_cache(key=cache_key) |
|
if client is None: |
|
""" |
|
Re-initialize the client |
|
""" |
|
self.set_client(model=deployment) |
|
client = self.cache.get_cache(key=cache_key) |
|
return client |
|
|
|
def get_available_deployment( |
|
self, |
|
model: str, |
|
messages: Optional[List[Dict[str, str]]] = None, |
|
input: Optional[Union[str, List]] = None, |
|
specific_deployment: Optional[bool] = False, |
|
): |
|
""" |
|
Returns the deployment based on routing strategy |
|
""" |
|
|
|
|
|
|
|
if specific_deployment == True: |
|
|
|
for deployment in self.model_list: |
|
deployment_model = deployment.get("litellm_params").get("model") |
|
if deployment_model == model: |
|
|
|
|
|
return deployment |
|
raise ValueError( |
|
f"LiteLLM Router: Trying to call specific deployment, but Model:{model} does not exist in Model List: {self.model_list}" |
|
) |
|
|
|
|
|
if model in self.model_group_alias: |
|
verbose_router_logger.debug( |
|
f"Using a model alias. Got Request for {model}, sending requests to {self.model_group_alias.get(model)}" |
|
) |
|
model = self.model_group_alias[model] |
|
|
|
|
|
|
|
healthy_deployments = [m for m in self.model_list if m["model_name"] == model] |
|
if len(healthy_deployments) == 0: |
|
|
|
healthy_deployments = [ |
|
m for m in self.model_list if m["litellm_params"]["model"] == model |
|
] |
|
|
|
verbose_router_logger.debug( |
|
f"initial list of deployments: {healthy_deployments}" |
|
) |
|
|
|
|
|
deployments_to_remove = [] |
|
|
|
cooldown_deployments = self._get_cooldown_deployments() |
|
verbose_router_logger.debug(f"cooldown deployments: {cooldown_deployments}") |
|
|
|
for deployment in healthy_deployments: |
|
deployment_id = deployment["model_info"]["id"] |
|
if deployment_id in cooldown_deployments: |
|
deployments_to_remove.append(deployment) |
|
|
|
for deployment in deployments_to_remove: |
|
healthy_deployments.remove(deployment) |
|
|
|
verbose_router_logger.debug( |
|
f"healthy deployments: length {len(healthy_deployments)} {healthy_deployments}" |
|
) |
|
if len(healthy_deployments) == 0: |
|
raise ValueError("No models available") |
|
        if litellm.model_alias_map and model in litellm.model_alias_map:
            model = litellm.model_alias_map[model]

        deployment = None
|
if self.routing_strategy == "least-busy" and self.leastbusy_logger is not None: |
|
deployment = self.leastbusy_logger.get_available_deployments( |
|
model_group=model, healthy_deployments=healthy_deployments |
|
) |
|
elif self.routing_strategy == "simple-shuffle": |
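            # simple-shuffle: if rpm or tpm limits are configured on the deployments,
            # do a weighted random pick so higher-capacity deployments receive
            # proportionally more traffic; otherwise pick uniformly at random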
|
|
|
|
|
rpm = healthy_deployments[0].get("litellm_params").get("rpm", None) |
|
if rpm is not None: |
|
|
|
rpms = [m["litellm_params"].get("rpm", 0) for m in healthy_deployments] |
|
verbose_router_logger.debug(f"\nrpms {rpms}") |
|
total_rpm = sum(rpms) |
|
weights = [rpm / total_rpm for rpm in rpms] |
|
verbose_router_logger.debug(f"\n weights {weights}") |
|
|
|
selected_index = random.choices(range(len(rpms)), weights=weights)[0] |
|
verbose_router_logger.debug(f"\n selected index, {selected_index}") |
|
deployment = healthy_deployments[selected_index] |
|
                return deployment
|
|
|
tpm = healthy_deployments[0].get("litellm_params").get("tpm", None) |
|
if tpm is not None: |
|
|
|
tpms = [m["litellm_params"].get("tpm", 0) for m in healthy_deployments] |
|
verbose_router_logger.debug(f"\ntpms {tpms}") |
|
total_tpm = sum(tpms) |
|
weights = [tpm / total_tpm for tpm in tpms] |
|
verbose_router_logger.debug(f"\n weights {weights}") |
|
|
|
selected_index = random.choices(range(len(tpms)), weights=weights)[0] |
|
verbose_router_logger.debug(f"\n selected index, {selected_index}") |
|
deployment = healthy_deployments[selected_index] |
|
                return deployment
|
|
|
|
|
item = random.choice(healthy_deployments) |
|
            return item
|
elif ( |
|
self.routing_strategy == "latency-based-routing" |
|
and self.lowestlatency_logger is not None |
|
): |
|
deployment = self.lowestlatency_logger.get_available_deployments( |
|
model_group=model, healthy_deployments=healthy_deployments |
|
) |
|
elif ( |
|
self.routing_strategy == "usage-based-routing" |
|
and self.lowesttpm_logger is not None |
|
): |
|
deployment = self.lowesttpm_logger.get_available_deployments( |
|
model_group=model, |
|
healthy_deployments=healthy_deployments, |
|
messages=messages, |
|
input=input, |
|
) |
|
|
|
if deployment is None: |
|
raise ValueError("No models available.") |
|
|
|
return deployment |
|
|
|
def flush_cache(self): |
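        """
        Clears the global litellm cache and the router's internal cache.
        """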
|
litellm.cache = None |
|
self.cache.flush_cache() |
|
|
|
def reset(self): |
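        """
        Resets global litellm callback lists and flushes the router's caches.
        """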
|
|
|
litellm.success_callback = [] |
|
        litellm._async_success_callback = []
|
litellm.failure_callback = [] |
|
litellm._async_failure_callback = [] |
|
self.flush_cache() |
|
|