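"""
Max-parallel-request and TPM/RPM rate limiting hook for the LiteLLM proxy.

Tracks per-minute request/token counters for API keys, per-key models, users,
teams, and end users in the internal usage cache, and rejects requests with an
HTTP 429 (including a `retry-after` header) when a configured limit is hit.
"""
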
import asyncio
import sys
from datetime import datetime, timedelta
from typing import TYPE_CHECKING, Any, List, Literal, Optional, Tuple, TypedDict, Union

from fastapi import HTTPException
from pydantic import BaseModel

import litellm
from litellm import DualCache, ModelResponse
from litellm._logging import verbose_proxy_logger
from litellm.integrations.custom_logger import CustomLogger
from litellm.litellm_core_utils.core_helpers import _get_parent_otel_span_from_kwargs
from litellm.proxy._types import CommonProxyErrors, CurrentItemRateLimit, UserAPIKeyAuth
from litellm.proxy.auth.auth_utils import (
    get_key_model_rpm_limit,
    get_key_model_tpm_limit,
)

if TYPE_CHECKING:
    from opentelemetry.trace import Span as _Span

    from litellm.proxy.utils import InternalUsageCache as _InternalUsageCache

    Span = Union[_Span, Any]
    InternalUsageCache = _InternalUsageCache
else:
    Span = Any
    InternalUsageCache = Any


class CacheObject(TypedDict):
    current_global_requests: Optional[dict]
    request_count_api_key: Optional[dict]
    request_count_api_key_model: Optional[dict]
    request_count_user_id: Optional[dict]
    request_count_team_id: Optional[dict]
    request_count_end_user_id: Optional[dict]


class _PROXY_MaxParallelRequestsHandler(CustomLogger):
    # Class variables or attributes

    def __init__(self, internal_usage_cache: InternalUsageCache):
        self.internal_usage_cache = internal_usage_cache

    def print_verbose(self, print_statement):
        try:
            verbose_proxy_logger.debug(print_statement)
            if litellm.set_verbose:
                print(print_statement)  # noqa
        except Exception:
            pass

    async def check_key_in_limits(
        self,
        user_api_key_dict: UserAPIKeyAuth,
        cache: DualCache,
        data: dict,
        call_type: str,
        max_parallel_requests: int,
        tpm_limit: int,
        rpm_limit: int,
        current: Optional[dict],
        request_count_api_key: str,
        rate_limit_type: Literal["key", "model_per_key", "user", "customer", "team"],
        values_to_update_in_cache: List[Tuple[Any, Any]],
    ) -> dict:
        verbose_proxy_logger.info(
            f"Current Usage of {rate_limit_type} in this minute: {current}"
        )
        if current is None:
            if max_parallel_requests == 0 or tpm_limit == 0 or rpm_limit == 0:
                # base case
                raise self.raise_rate_limit_error(
                    additional_details=f"{CommonProxyErrors.max_parallel_request_limit_reached.value}. Hit limit for {rate_limit_type}. Current limits: max_parallel_requests: {max_parallel_requests}, tpm_limit: {tpm_limit}, rpm_limit: {rpm_limit}"
                )
            new_val = {
                "current_requests": 1,
                "current_tpm": 0,
                "current_rpm": 1,
            }
            values_to_update_in_cache.append((request_count_api_key, new_val))
        elif (
            int(current["current_requests"]) < max_parallel_requests
            and current["current_tpm"] < tpm_limit
            and current["current_rpm"] < rpm_limit
        ):
            # Increase count for this token
            new_val = {
                "current_requests": current["current_requests"] + 1,
                "current_tpm": current["current_tpm"],
                "current_rpm": current["current_rpm"] + 1,
            }
            values_to_update_in_cache.append((request_count_api_key, new_val))
        else:
            raise HTTPException(
                status_code=429,
                detail=f"LiteLLM Rate Limit Handler for rate limit type = {rate_limit_type}. {CommonProxyErrors.max_parallel_request_limit_reached.value}. current rpm: {current['current_rpm']}, rpm limit: {rpm_limit}, current tpm: {current['current_tpm']}, tpm limit: {tpm_limit}, current max_parallel_requests: {current['current_requests']}, max_parallel_requests: {max_parallel_requests}",
                headers={"retry-after": str(self.time_to_next_minute())},
            )

        await self.internal_usage_cache.async_batch_set_cache(
            cache_list=values_to_update_in_cache,
            ttl=60,
            litellm_parent_otel_span=user_api_key_dict.parent_otel_span,
            local_only=True,
        )
        return new_val

    def time_to_next_minute(self) -> float:
        # Get the current time
        now = datetime.now()

        # Calculate the next minute
        next_minute = (now + timedelta(minutes=1)).replace(second=0, microsecond=0)

        # Calculate the difference in seconds
        seconds_to_next_minute = (next_minute - now).total_seconds()

        return seconds_to_next_minute
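
    # Example: at 12:34:56.250 the next minute boundary is 12:35:00, so the value
    # returned above (and sent back in the `retry-after` header) is 3.75 seconds.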

    def raise_rate_limit_error(
        self, additional_details: Optional[str] = None
    ) -> HTTPException:
        """
        Raise an HTTPException with a 429 status code and a retry-after header
        """
        error_message = "Max parallel request limit reached"
        if additional_details is not None:
            error_message = error_message + " " + additional_details
        raise HTTPException(
            status_code=429,
            detail=error_message,
            headers={"retry-after": str(self.time_to_next_minute())},
        )

    async def get_all_cache_objects(
        self,
        current_global_requests: Optional[str],
        request_count_api_key: Optional[str],
        request_count_api_key_model: Optional[str],
        request_count_user_id: Optional[str],
        request_count_team_id: Optional[str],
        request_count_end_user_id: Optional[str],
        parent_otel_span: Optional[Span] = None,
    ) -> CacheObject:
        keys = [
            current_global_requests,
            request_count_api_key,
            request_count_api_key_model,
            request_count_user_id,
            request_count_team_id,
            request_count_end_user_id,
        ]
        results = await self.internal_usage_cache.async_batch_get_cache(
            keys=keys,
            parent_otel_span=parent_otel_span,
        )

        if results is None:
            return CacheObject(
                current_global_requests=None,
                request_count_api_key=None,
                request_count_api_key_model=None,
                request_count_user_id=None,
                request_count_team_id=None,
                request_count_end_user_id=None,
            )

        return CacheObject(
            current_global_requests=results[0],
            request_count_api_key=results[1],
            request_count_api_key_model=results[2],
            request_count_user_id=results[3],
            request_count_team_id=results[4],
            request_count_end_user_id=results[5],
        )

    async def async_pre_call_hook(  # noqa: PLR0915
        self,
        user_api_key_dict: UserAPIKeyAuth,
        cache: DualCache,
        data: dict,
        call_type: str,
    ):
        self.print_verbose("Inside Max Parallel Request Pre-Call Hook")
        api_key = user_api_key_dict.api_key
        max_parallel_requests = user_api_key_dict.max_parallel_requests
        if max_parallel_requests is None:
            max_parallel_requests = sys.maxsize
        if data is None:
            data = {}
        global_max_parallel_requests = data.get("metadata", {}).get(
            "global_max_parallel_requests", None
        )
        tpm_limit = getattr(user_api_key_dict, "tpm_limit", sys.maxsize)
        if tpm_limit is None:
            tpm_limit = sys.maxsize
        rpm_limit = getattr(user_api_key_dict, "rpm_limit", sys.maxsize)
        if rpm_limit is None:
            rpm_limit = sys.maxsize
        values_to_update_in_cache: List[Tuple[Any, Any]] = (
            []
        )  # values that need to get updated in cache, will run a batch_set_cache after this function

        # ------------
        # Setup values
        # ------------
        new_val: Optional[dict] = None

        if global_max_parallel_requests is not None:
            # get value from cache
            _key = "global_max_parallel_requests"
            current_global_requests = await self.internal_usage_cache.async_get_cache(
                key=_key,
                local_only=True,
                litellm_parent_otel_span=user_api_key_dict.parent_otel_span,
            )
            # check if below limit
            if current_global_requests is None:
                current_global_requests = 1
            # if above -> raise error
            if current_global_requests >= global_max_parallel_requests:
                return self.raise_rate_limit_error(
                    additional_details=f"Hit Global Limit: Limit={global_max_parallel_requests}, current: {current_global_requests}"
                )
            # if below -> increment
            else:
                await self.internal_usage_cache.async_increment_cache(
                    key=_key,
                    value=1,
                    local_only=True,
                    litellm_parent_otel_span=user_api_key_dict.parent_otel_span,
                )

        _model = data.get("model", None)

        current_date = datetime.now().strftime("%Y-%m-%d")
        current_hour = datetime.now().strftime("%H")
        current_minute = datetime.now().strftime("%M")
        precise_minute = f"{current_date}-{current_hour}-{current_minute}"
        cache_objects: CacheObject = await self.get_all_cache_objects(
            current_global_requests=(
                "global_max_parallel_requests"
                if global_max_parallel_requests is not None
                else None
            ),
            request_count_api_key=(
                f"{api_key}::{precise_minute}::request_count"
                if api_key is not None
                else None
            ),
            request_count_api_key_model=(
                f"{api_key}::{_model}::{precise_minute}::request_count"
                if api_key is not None and _model is not None
                else None
            ),
            request_count_user_id=(
                f"{user_api_key_dict.user_id}::{precise_minute}::request_count"
                if user_api_key_dict.user_id is not None
                else None
            ),
            request_count_team_id=(
                f"{user_api_key_dict.team_id}::{precise_minute}::request_count"
                if user_api_key_dict.team_id is not None
                else None
            ),
            request_count_end_user_id=(
                f"{user_api_key_dict.end_user_id}::{precise_minute}::request_count"
                if user_api_key_dict.end_user_id is not None
                else None
            ),
            parent_otel_span=user_api_key_dict.parent_otel_span,
        )

        if api_key is not None:
            request_count_api_key = f"{api_key}::{precise_minute}::request_count"
            # CHECK IF REQUEST ALLOWED for key
            await self.check_key_in_limits(
                user_api_key_dict=user_api_key_dict,
                cache=cache,
                data=data,
                call_type=call_type,
                max_parallel_requests=max_parallel_requests,
                current=cache_objects["request_count_api_key"],
                request_count_api_key=request_count_api_key,
                tpm_limit=tpm_limit,
                rpm_limit=rpm_limit,
                rate_limit_type="key",
                values_to_update_in_cache=values_to_update_in_cache,
            )

        # Check if request under RPM/TPM per model for a given API Key
        if (
            get_key_model_tpm_limit(user_api_key_dict) is not None
            or get_key_model_rpm_limit(user_api_key_dict) is not None
        ):
            _model = data.get("model", None)
            request_count_api_key = (
                f"{api_key}::{_model}::{precise_minute}::request_count"
            )

            _tpm_limit_for_key_model = get_key_model_tpm_limit(user_api_key_dict)
            _rpm_limit_for_key_model = get_key_model_rpm_limit(user_api_key_dict)

            tpm_limit_for_model = None
            rpm_limit_for_model = None

            if _model is not None:
                if _tpm_limit_for_key_model:
                    tpm_limit_for_model = _tpm_limit_for_key_model.get(_model)

                if _rpm_limit_for_key_model:
                    rpm_limit_for_model = _rpm_limit_for_key_model.get(_model)

            new_val = await self.check_key_in_limits(
                user_api_key_dict=user_api_key_dict,
                cache=cache,
                data=data,
                call_type=call_type,
                max_parallel_requests=sys.maxsize,  # TODO: Support max parallel requests for a model
                current=cache_objects["request_count_api_key_model"],
                request_count_api_key=request_count_api_key,
                tpm_limit=tpm_limit_for_model or sys.maxsize,
                rpm_limit=rpm_limit_for_model or sys.maxsize,
                rate_limit_type="model_per_key",
                values_to_update_in_cache=values_to_update_in_cache,
            )

            _remaining_tokens = None
            _remaining_requests = None
            # Add remaining tokens, requests to metadata
            if new_val:
                if tpm_limit_for_model is not None:
                    _remaining_tokens = tpm_limit_for_model - new_val["current_tpm"]
                if rpm_limit_for_model is not None:
                    _remaining_requests = rpm_limit_for_model - new_val["current_rpm"]

            _remaining_limits_data = {
                f"litellm-key-remaining-tokens-{_model}": _remaining_tokens,
                f"litellm-key-remaining-requests-{_model}": _remaining_requests,
            }

            if "metadata" not in data:
                data["metadata"] = {}
            data["metadata"].update(_remaining_limits_data)

        # check if REQUEST ALLOWED for user_id
        user_id = user_api_key_dict.user_id
        if user_id is not None:
            user_tpm_limit = user_api_key_dict.user_tpm_limit
            user_rpm_limit = user_api_key_dict.user_rpm_limit
            if user_tpm_limit is None:
                user_tpm_limit = sys.maxsize
            if user_rpm_limit is None:
                user_rpm_limit = sys.maxsize

            request_count_api_key = f"{user_id}::{precise_minute}::request_count"
            # print(f"Checking if {request_count_api_key} is allowed to make request for minute {precise_minute}")
            await self.check_key_in_limits(
                user_api_key_dict=user_api_key_dict,
                cache=cache,
                data=data,
                call_type=call_type,
                max_parallel_requests=sys.maxsize,  # TODO: Support max parallel requests for a user
                current=cache_objects["request_count_user_id"],
                request_count_api_key=request_count_api_key,
                tpm_limit=user_tpm_limit,
                rpm_limit=user_rpm_limit,
                rate_limit_type="user",
                values_to_update_in_cache=values_to_update_in_cache,
            )

        # TEAM RATE LIMITS
        ## get team tpm/rpm limits
        team_id = user_api_key_dict.team_id
        if team_id is not None:
            team_tpm_limit = user_api_key_dict.team_tpm_limit
            team_rpm_limit = user_api_key_dict.team_rpm_limit

            if team_tpm_limit is None:
                team_tpm_limit = sys.maxsize
            if team_rpm_limit is None:
                team_rpm_limit = sys.maxsize

            request_count_api_key = f"{team_id}::{precise_minute}::request_count"
            # print(f"Checking if {request_count_api_key} is allowed to make request for minute {precise_minute}")
            await self.check_key_in_limits(
                user_api_key_dict=user_api_key_dict,
                cache=cache,
                data=data,
                call_type=call_type,
                max_parallel_requests=sys.maxsize,  # TODO: Support max parallel requests for a team
                current=cache_objects["request_count_team_id"],
                request_count_api_key=request_count_api_key,
                tpm_limit=team_tpm_limit,
                rpm_limit=team_rpm_limit,
                rate_limit_type="team",
                values_to_update_in_cache=values_to_update_in_cache,
            )

        # End-User Rate Limits
        # Only enforce if user passed `user` to /chat, /completions, /embeddings
        if user_api_key_dict.end_user_id:
            end_user_tpm_limit = getattr(
                user_api_key_dict, "end_user_tpm_limit", sys.maxsize
            )
            end_user_rpm_limit = getattr(
                user_api_key_dict, "end_user_rpm_limit", sys.maxsize
            )

            if end_user_tpm_limit is None:
                end_user_tpm_limit = sys.maxsize
            if end_user_rpm_limit is None:
                end_user_rpm_limit = sys.maxsize

            # now do the same tpm/rpm checks
            request_count_api_key = (
                f"{user_api_key_dict.end_user_id}::{precise_minute}::request_count"
            )
            # print(f"Checking if {request_count_api_key} is allowed to make request for minute {precise_minute}")
            await self.check_key_in_limits(
                user_api_key_dict=user_api_key_dict,
                cache=cache,
                data=data,
                call_type=call_type,
                max_parallel_requests=sys.maxsize,  # TODO: Support max parallel requests for an End-User
                request_count_api_key=request_count_api_key,
                current=cache_objects["request_count_end_user_id"],
                tpm_limit=end_user_tpm_limit,
                rpm_limit=end_user_rpm_limit,
                rate_limit_type="customer",
                values_to_update_in_cache=values_to_update_in_cache,
            )

        asyncio.create_task(
            self.internal_usage_cache.async_batch_set_cache(
                cache_list=values_to_update_in_cache,
                ttl=60,
                litellm_parent_otel_span=user_api_key_dict.parent_otel_span,
            )  # don't block execution for cache updates
        )

        return

    async def async_log_success_event(  # noqa: PLR0915
        self, kwargs, response_obj, start_time, end_time
    ):
        from litellm.proxy.common_utils.callback_utils import (
            get_model_group_from_litellm_kwargs,
        )

        litellm_parent_otel_span: Union[Span, None] = _get_parent_otel_span_from_kwargs(
            kwargs=kwargs
        )
        try:
            self.print_verbose("INSIDE parallel request limiter ASYNC SUCCESS LOGGING")

            global_max_parallel_requests = kwargs["litellm_params"]["metadata"].get(
                "global_max_parallel_requests", None
            )
            user_api_key = kwargs["litellm_params"]["metadata"]["user_api_key"]
            user_api_key_user_id = kwargs["litellm_params"]["metadata"].get(
                "user_api_key_user_id", None
            )
            user_api_key_team_id = kwargs["litellm_params"]["metadata"].get(
                "user_api_key_team_id", None
            )
            user_api_key_model_max_budget = kwargs["litellm_params"]["metadata"].get(
                "user_api_key_model_max_budget", None
            )
            user_api_key_end_user_id = kwargs.get("user")

            user_api_key_metadata = (
                kwargs["litellm_params"]["metadata"].get("user_api_key_metadata", {})
                or {}
            )

            # ------------
            # Setup values
            # ------------

            if global_max_parallel_requests is not None:
                # get value from cache
                _key = "global_max_parallel_requests"
                # decrement
                await self.internal_usage_cache.async_increment_cache(
                    key=_key,
                    value=-1,
                    local_only=True,
                    litellm_parent_otel_span=litellm_parent_otel_span,
                )

            current_date = datetime.now().strftime("%Y-%m-%d")
            current_hour = datetime.now().strftime("%H")
            current_minute = datetime.now().strftime("%M")
            precise_minute = f"{current_date}-{current_hour}-{current_minute}"

            total_tokens = 0
            if isinstance(response_obj, ModelResponse):
                total_tokens = response_obj.usage.total_tokens  # type: ignore

            # ------------
            # Update usage - API Key
            # ------------
            values_to_update_in_cache = []
            if user_api_key is not None:
                request_count_api_key = (
                    f"{user_api_key}::{precise_minute}::request_count"
                )

                current = await self.internal_usage_cache.async_get_cache(
                    key=request_count_api_key,
                    litellm_parent_otel_span=litellm_parent_otel_span,
                ) or {
                    "current_requests": 1,
                    "current_tpm": 0,
                    "current_rpm": 0,
                }

                new_val = {
                    "current_requests": max(current["current_requests"] - 1, 0),
                    "current_tpm": current["current_tpm"] + total_tokens,
                    "current_rpm": current["current_rpm"],
                }

                self.print_verbose(
                    f"updated_value in success call: {new_val}, precise_minute: {precise_minute}"
                )
                values_to_update_in_cache.append((request_count_api_key, new_val))

            # ------------
            # Update usage - model group + API Key
            # ------------
            model_group = get_model_group_from_litellm_kwargs(kwargs)
            if (
                user_api_key is not None
                and model_group is not None
                and (
                    "model_rpm_limit" in user_api_key_metadata
                    or "model_tpm_limit" in user_api_key_metadata
                    or user_api_key_model_max_budget is not None
                )
            ):
                request_count_api_key = (
                    f"{user_api_key}::{model_group}::{precise_minute}::request_count"
                )

                current = await self.internal_usage_cache.async_get_cache(
                    key=request_count_api_key,
                    litellm_parent_otel_span=litellm_parent_otel_span,
                ) or {
                    "current_requests": 1,
                    "current_tpm": 0,
                    "current_rpm": 0,
                }

                new_val = {
                    "current_requests": max(current["current_requests"] - 1, 0),
                    "current_tpm": current["current_tpm"] + total_tokens,
                    "current_rpm": current["current_rpm"],
                }

                self.print_verbose(
                    f"updated_value in success call: {new_val}, precise_minute: {precise_minute}"
                )
                values_to_update_in_cache.append((request_count_api_key, new_val))

            # ------------
            # Update usage - User
            # ------------
            if user_api_key_user_id is not None:
                total_tokens = 0
                if isinstance(response_obj, ModelResponse):
                    total_tokens = response_obj.usage.total_tokens  # type: ignore

                request_count_api_key = (
                    f"{user_api_key_user_id}::{precise_minute}::request_count"
                )

                current = await self.internal_usage_cache.async_get_cache(
                    key=request_count_api_key,
                    litellm_parent_otel_span=litellm_parent_otel_span,
                ) or {
                    "current_requests": 1,
                    "current_tpm": total_tokens,
                    "current_rpm": 1,
                }

                new_val = {
                    "current_requests": max(current["current_requests"] - 1, 0),
                    "current_tpm": current["current_tpm"] + total_tokens,
                    "current_rpm": current["current_rpm"],
                }

                self.print_verbose(
                    f"updated_value in success call: {new_val}, precise_minute: {precise_minute}"
                )
                values_to_update_in_cache.append((request_count_api_key, new_val))

            # ------------
            # Update usage - Team
            # ------------
            if user_api_key_team_id is not None:
                total_tokens = 0
                if isinstance(response_obj, ModelResponse):
                    total_tokens = response_obj.usage.total_tokens  # type: ignore

                request_count_api_key = (
                    f"{user_api_key_team_id}::{precise_minute}::request_count"
                )

                current = await self.internal_usage_cache.async_get_cache(
                    key=request_count_api_key,
                    litellm_parent_otel_span=litellm_parent_otel_span,
                ) or {
                    "current_requests": 1,
                    "current_tpm": total_tokens,
                    "current_rpm": 1,
                }

                new_val = {
                    "current_requests": max(current["current_requests"] - 1, 0),
                    "current_tpm": current["current_tpm"] + total_tokens,
                    "current_rpm": current["current_rpm"],
                }

                self.print_verbose(
                    f"updated_value in success call: {new_val}, precise_minute: {precise_minute}"
                )
                values_to_update_in_cache.append((request_count_api_key, new_val))

            # ------------
            # Update usage - End User
            # ------------
            if user_api_key_end_user_id is not None:
                total_tokens = 0
                if isinstance(response_obj, ModelResponse):
                    total_tokens = response_obj.usage.total_tokens  # type: ignore

                request_count_api_key = (
                    f"{user_api_key_end_user_id}::{precise_minute}::request_count"
                )

                current = await self.internal_usage_cache.async_get_cache(
                    key=request_count_api_key,
                    litellm_parent_otel_span=litellm_parent_otel_span,
                ) or {
                    "current_requests": 1,
                    "current_tpm": total_tokens,
                    "current_rpm": 1,
                }

                new_val = {
                    "current_requests": max(current["current_requests"] - 1, 0),
                    "current_tpm": current["current_tpm"] + total_tokens,
                    "current_rpm": current["current_rpm"],
                }

                self.print_verbose(
                    f"updated_value in success call: {new_val}, precise_minute: {precise_minute}"
                )
                values_to_update_in_cache.append((request_count_api_key, new_val))

            await self.internal_usage_cache.async_batch_set_cache(
                cache_list=values_to_update_in_cache,
                ttl=60,
                litellm_parent_otel_span=litellm_parent_otel_span,
            )
        except Exception as e:
            self.print_verbose(e)  # noqa

    async def async_log_failure_event(self, kwargs, response_obj, start_time, end_time):
        try:
            self.print_verbose("Inside Max Parallel Request Failure Hook")
            litellm_parent_otel_span: Union[Span, None] = (
                _get_parent_otel_span_from_kwargs(kwargs=kwargs)
            )
            _metadata = kwargs["litellm_params"].get("metadata", {}) or {}
            global_max_parallel_requests = _metadata.get(
                "global_max_parallel_requests", None
            )
            user_api_key = _metadata.get("user_api_key", None)
            self.print_verbose(f"user_api_key: {user_api_key}")
            if user_api_key is None:
                return

            ## decrement call count if call failed
            if CommonProxyErrors.max_parallel_request_limit_reached.value in str(
                kwargs["exception"]
            ):
                pass  # ignore failed calls due to max limit being reached
            else:
                # ------------
                # Setup values
                # ------------

                if global_max_parallel_requests is not None:
                    # get value from cache
                    _key = "global_max_parallel_requests"
                    (
                        await self.internal_usage_cache.async_get_cache(
                            key=_key,
                            local_only=True,
                            litellm_parent_otel_span=litellm_parent_otel_span,
                        )
                    )
                    # decrement
                    await self.internal_usage_cache.async_increment_cache(
                        key=_key,
                        value=-1,
                        local_only=True,
                        litellm_parent_otel_span=litellm_parent_otel_span,
                    )

                current_date = datetime.now().strftime("%Y-%m-%d")
                current_hour = datetime.now().strftime("%H")
                current_minute = datetime.now().strftime("%M")
                precise_minute = f"{current_date}-{current_hour}-{current_minute}"

                request_count_api_key = (
                    f"{user_api_key}::{precise_minute}::request_count"
                )

                # ------------
                # Update usage
                # ------------
                current = await self.internal_usage_cache.async_get_cache(
                    key=request_count_api_key,
                    litellm_parent_otel_span=litellm_parent_otel_span,
                ) or {
                    "current_requests": 1,
                    "current_tpm": 0,
                    "current_rpm": 0,
                }

                new_val = {
                    "current_requests": max(current["current_requests"] - 1, 0),
                    "current_tpm": current["current_tpm"],
                    "current_rpm": current["current_rpm"],
                }

                self.print_verbose(f"updated_value in failure call: {new_val}")

                await self.internal_usage_cache.async_set_cache(
                    request_count_api_key,
                    new_val,
                    ttl=60,
                    litellm_parent_otel_span=litellm_parent_otel_span,
                )  # save in cache for up to 1 min.
        except Exception as e:
            verbose_proxy_logger.exception(
                "Inside Parallel Request Limiter: An exception occurred - {}".format(
                    str(e)
                )
            )

    async def get_internal_user_object(
        self,
        user_id: str,
        user_api_key_dict: UserAPIKeyAuth,
    ) -> Optional[dict]:
        """
        Helper to get the 'Internal User Object'

        It uses the `get_user_object` function from `litellm.proxy.auth.auth_checks`

        We need this because the UserApiKeyAuth object does not contain the rpm/tpm
        limits for a User, AND there could be a perf impact from additionally reading
        the UserTable.
        """
        from litellm.proxy.auth.auth_checks import get_user_object
        from litellm.proxy.proxy_server import prisma_client

        try:
            _user_id_rate_limits = await get_user_object(
                user_id=user_id,
                prisma_client=prisma_client,
                user_api_key_cache=self.internal_usage_cache.dual_cache,
                user_id_upsert=False,
                parent_otel_span=user_api_key_dict.parent_otel_span,
                proxy_logging_obj=None,
            )

            if _user_id_rate_limits is None:
                return None

            return _user_id_rate_limits.model_dump()
        except Exception as e:
            verbose_proxy_logger.debug(
                "Parallel Request Limiter: Error getting user object - %s", str(e)
            )
            return None

    async def async_post_call_success_hook(
        self, data: dict, user_api_key_dict: UserAPIKeyAuth, response
    ):
        """
        Retrieve the key's remaining rate limits.
        """
        api_key = user_api_key_dict.api_key
        current_date = datetime.now().strftime("%Y-%m-%d")
        current_hour = datetime.now().strftime("%H")
        current_minute = datetime.now().strftime("%M")
        precise_minute = f"{current_date}-{current_hour}-{current_minute}"

        request_count_api_key = f"{api_key}::{precise_minute}::request_count"
        current: Optional[CurrentItemRateLimit] = (
            await self.internal_usage_cache.async_get_cache(
                key=request_count_api_key,
                litellm_parent_otel_span=user_api_key_dict.parent_otel_span,
            )
        )

        key_remaining_rpm_limit: Optional[int] = None
        key_rpm_limit: Optional[int] = None
        key_remaining_tpm_limit: Optional[int] = None
        key_tpm_limit: Optional[int] = None
        if current is not None:
            if user_api_key_dict.rpm_limit is not None:
                key_remaining_rpm_limit = (
                    user_api_key_dict.rpm_limit - current["current_rpm"]
                )
                key_rpm_limit = user_api_key_dict.rpm_limit
            if user_api_key_dict.tpm_limit is not None:
                key_remaining_tpm_limit = (
                    user_api_key_dict.tpm_limit - current["current_tpm"]
                )
                key_tpm_limit = user_api_key_dict.tpm_limit

        if hasattr(response, "_hidden_params"):
            _hidden_params = getattr(response, "_hidden_params")
        else:
            _hidden_params = None
        if _hidden_params is not None and (
            isinstance(_hidden_params, BaseModel) or isinstance(_hidden_params, dict)
        ):
            if isinstance(_hidden_params, BaseModel):
                _hidden_params = _hidden_params.model_dump()

            _additional_headers = _hidden_params.get("additional_headers", {}) or {}

            if key_remaining_rpm_limit is not None:
                _additional_headers["x-ratelimit-remaining-requests"] = (
                    key_remaining_rpm_limit
                )
            if key_rpm_limit is not None:
                _additional_headers["x-ratelimit-limit-requests"] = key_rpm_limit
            if key_remaining_tpm_limit is not None:
                _additional_headers["x-ratelimit-remaining-tokens"] = (
                    key_remaining_tpm_limit
                )
            if key_tpm_limit is not None:
                _additional_headers["x-ratelimit-limit-tokens"] = key_tpm_limit

            setattr(
                response,
                "_hidden_params",
                {**_hidden_params, "additional_headers": _additional_headers},
            )

        return await super().async_post_call_success_hook(
            data, user_api_key_dict, response
        )
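

if __name__ == "__main__":
    # Minimal usage sketch (illustration only, not part of the proxy wiring).
    # Assumption: InternalUsageCache can be built directly from an in-memory
    # DualCache; in the real proxy this object is created by proxy_server at
    # startup and shared across all hooks.
    from litellm.proxy.utils import InternalUsageCache as _UsageCache

    handler = _PROXY_MaxParallelRequestsHandler(
        internal_usage_cache=_UsageCache(dual_cache=DualCache())
    )

    async def _demo() -> None:
        # A key limited to 1 request per minute: the first call passes and seeds
        # the per-minute counter "{api_key}::{YYYY-MM-DD-HH-MM}::request_count";
        # a second call in the same minute would raise an HTTP 429.
        user = UserAPIKeyAuth(api_key="sk-demo", rpm_limit=1)
        await handler.async_pre_call_hook(
            user_api_key_dict=user,
            cache=DualCache(),
            data={"model": "gpt-3.5-turbo"},
            call_type="completion",
        )
        await asyncio.sleep(0)  # let the fire-and-forget cache update task run
        print(
            f"seconds until this minute's counters reset: {handler.time_to_next_minute()}"
        )

    asyncio.run(_demo())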