Spaces:
Running
Running
# +-----------------------------------------------+ | |
# | | | |
# | Give Feedback / Get Help | | |
# | https://github.com/BerriAI/litellm/issues/new | | |
# | | | |
# +-----------------------------------------------+ | |
# | |
# Thank you users! We ❤️ you! - Krrish & Ishaan | |
import inspect | |
import json | |
# s/o [@Frank Colson](https://www.linkedin.com/in/frank-colson-422b9b183/) for this redis implementation | |
import os | |
from typing import List, Optional, Union | |
import redis # type: ignore | |
import redis.asyncio as async_redis # type: ignore | |
from litellm import get_secret, get_secret_str | |
from litellm.constants import REDIS_CONNECTION_POOL_TIMEOUT, REDIS_SOCKET_TIMEOUT | |
from ._logging import verbose_logger | |
def _get_redis_kwargs():
    """
    List the ``redis.Redis`` constructor keywords that may come from config.

    The constructor signature is introspected. Parameters that cannot be
    expressed as plain configuration values (``self``, ``connection_pool``,
    ``retry``) are dropped, and ``url`` is appended at the end.

    Returns:
        list[str]: allowed keyword names.
    """
    blocked = {"self", "connection_pool", "retry"}
    constructor_params = inspect.getfullargspec(redis.Redis).args
    allowed = [name for name in constructor_params if name not in blocked]
    allowed.append("url")
    return allowed
def _get_redis_url_kwargs(client=None): | |
if client is None: | |
client = redis.Redis.from_url | |
arg_spec = inspect.getfullargspec(redis.Redis.from_url) | |
# Only allow primitive arguments | |
exclude_args = { | |
"self", | |
"connection_pool", | |
"retry", | |
} | |
include_args = ["url"] | |
available_args = [x for x in arg_spec.args if x not in exclude_args] + include_args | |
return available_args | |
def _get_redis_cluster_kwargs(client=None):
    """
    List the keyword names that may be forwarded to ``redis.RedisCluster``.

    The ``RedisCluster`` constructor is introspected. The following are
    dropped:

    - non-primitive parameters (``self``, ``connection_pool``, ``retry``);
    - node-addressing parameters (``host``, ``port``, ``startup_nodes``),
      which callers in this module build and pass separately.

    ``password``, ``username`` and ``ssl`` are then added explicitly.
    Presumably ``RedisCluster`` accepts these only through ``**kwargs``, so
    they do not show up in its introspected signature.

    Args:
        client: unused; kept only so existing call signatures keep working.

    Returns:
        list[str]: allowed keyword names, without duplicates.
    """
    arg_spec = inspect.getfullargspec(redis.RedisCluster)
    # Only allow primitive arguments
    exclude_args = {"self", "connection_pool", "retry", "host", "port", "startup_nodes"}
    available_args = [x for x in arg_spec.args if x not in exclude_args]
    for extra in ("password", "username", "ssl"):
        if extra not in available_args:
            available_args.append(extra)
    return available_args
def _get_redis_env_kwarg_mapping():
    """
    Map each ``REDIS_<NAME>`` environment variable to its redis kwarg.

    Example: ``REDIS_HOST`` -> ``host``.

    Returns:
        dict[str, str]: environment-variable name -> keyword name.
    """
    env_prefix = "REDIS_"
    mapping = {}
    for kwarg_name in _get_redis_kwargs():
        mapping[env_prefix + kwarg_name.upper()] = kwarg_name
    return mapping
def _redis_kwargs_from_environment():
    """
    Collect redis kwargs from ``REDIS_*`` secrets/environment variables.

    Each known variable is looked up via ``get_secret``. Variables that
    resolve to ``None`` are skipped.

    Returns:
        dict: redis keyword name -> configured value.
    """
    found = {}
    for env_name, kwarg_name in _get_redis_env_kwarg_mapping().items():
        secret_value = get_secret(env_name, default_value=None)  # type: ignore
        if secret_value is None:
            continue
        found[kwarg_name] = secret_value
    return found
def get_redis_url_from_environment():
    """
    Build a Redis connection URL from environment variables.

    ``REDIS_URL`` is returned verbatim when it is set. Otherwise both
    ``REDIS_HOST`` and ``REDIS_PORT`` are required. ``REDIS_PASSWORD``, if
    present, is embedded as the URL password.

    Returns:
        str: a ``redis://`` URL.

    Raises:
        ValueError: if neither ``REDIS_URL`` nor both ``REDIS_HOST`` and
            ``REDIS_PORT`` are set.
    """
    from urllib.parse import quote

    if "REDIS_URL" in os.environ:
        return os.environ["REDIS_URL"]
    if "REDIS_HOST" not in os.environ or "REDIS_PORT" not in os.environ:
        raise ValueError(
            "Either 'REDIS_URL' or both 'REDIS_HOST' and 'REDIS_PORT' must be specified for Redis."
        )
    redis_password = ""
    if "REDIS_PASSWORD" in os.environ:
        # Percent-encode the password so reserved characters ('@', ':', '/',
        # '#', ...) cannot break the URL structure. redis-py's URL parser
        # unquotes the password component, so the raw value round-trips.
        redis_password = f":{quote(os.environ['REDIS_PASSWORD'], safe='')}@"
    return (
        f"redis://{redis_password}{os.environ['REDIS_HOST']}:{os.environ['REDIS_PORT']}"
    )
def _get_redis_client_logic(**env_overrides):
    """
    Build the merged kwargs dict shared by the sync and async client factories.

    Precedence: explicit ``env_overrides`` win over values found in
    ``REDIS_*`` environment variables. An override given as
    ``"os.environ/<NAME>"`` is resolved through ``get_secret``.

    Cluster and sentinel settings (``startup_nodes``, ``sentinel_nodes``,
    ``sentinel_password``, ``service_name``) fall back to their dedicated
    environment variables when they are not passed explicitly. Node lists
    given as strings are JSON-decoded.

    When ``url`` is set, ``host``/``port``/``db``/``password`` are dropped so
    the URL is the single source of connection details.

    Raises:
        ValueError: if none of ``url``, ``startup_nodes``, ``sentinel_nodes``
            or ``host`` is configured.
    """
    env_prefix = "os.environ/"
    # Resolve "os.environ/<key-name>" indirections. Only existing keys are
    # reassigned, so mutating the dict while iterating it is safe.
    for key, raw_value in env_overrides.items():
        if not (isinstance(raw_value, str) and raw_value.startswith(env_prefix)):
            continue
        env_overrides[key] = get_secret(raw_value.replace(env_prefix, ""))  # type: ignore

    redis_kwargs = dict(_redis_kwargs_from_environment())
    redis_kwargs.update(env_overrides)

    cluster_nodes: Optional[Union[str, list]] = redis_kwargs.get(
        "startup_nodes", None
    ) or get_secret("REDIS_CLUSTER_NODES")  # type: ignore
    if isinstance(cluster_nodes, str):
        redis_kwargs["startup_nodes"] = json.loads(cluster_nodes)

    sentinel_nodes: Optional[Union[str, list]] = redis_kwargs.get(
        "sentinel_nodes", None
    ) or get_secret("REDIS_SENTINEL_NODES")  # type: ignore
    if isinstance(sentinel_nodes, str):
        redis_kwargs["sentinel_nodes"] = json.loads(sentinel_nodes)

    sentinel_password: Optional[str] = redis_kwargs.get(
        "sentinel_password", None
    ) or get_secret_str("REDIS_SENTINEL_PASSWORD")
    if sentinel_password is not None:
        redis_kwargs["sentinel_password"] = sentinel_password

    service_name: Optional[str] = redis_kwargs.get(
        "service_name", None
    ) or get_secret("REDIS_SERVICE_NAME")  # type: ignore
    if service_name is not None:
        redis_kwargs["service_name"] = service_name

    if redis_kwargs.get("url") is not None:
        # The URL carries the connection details; drop conflicting fields.
        for superseded in ("host", "port", "db", "password"):
            redis_kwargs.pop(superseded, None)
    elif (
        redis_kwargs.get("startup_nodes") is None
        and redis_kwargs.get("sentinel_nodes") is None
        and redis_kwargs.get("host") is None
    ):
        raise ValueError("Either 'host' or 'url' must be specified for redis.")
    return redis_kwargs
def init_redis_cluster(redis_kwargs) -> redis.RedisCluster:
    """
    Create a synchronous ``redis.RedisCluster`` from merged redis kwargs.

    Startup nodes are taken from ``redis_kwargs["startup_nodes"]``, an
    iterable of dicts that are each expanded into ``ClusterNode(**item)``.
    Only when that key is absent or ``None`` are they read from the
    ``REDIS_CLUSTER_NODES`` secret, parsed as JSON. This matches the
    explicit-over-environment precedence of ``_get_redis_client_logic``.

    Only kwargs accepted by ``RedisCluster`` (see
    ``_get_redis_cluster_kwargs``) are forwarded. ``startup_nodes`` is popped
    from ``redis_kwargs`` as a side effect.

    Raises:
        ValueError: if ``REDIS_CLUSTER_NODES`` is not valid JSON, or if no
            startup nodes are configured at all.
    """
    if redis_kwargs.get("startup_nodes") is None:
        _redis_cluster_nodes_in_env: Optional[str] = get_secret("REDIS_CLUSTER_NODES")  # type: ignore
        if _redis_cluster_nodes_in_env is not None:
            try:
                redis_kwargs["startup_nodes"] = json.loads(_redis_cluster_nodes_in_env)
            except json.JSONDecodeError as err:
                raise ValueError(
                    "REDIS_CLUSTER_NODES environment variable is not valid JSON. Please ensure it's properly formatted."
                ) from err
    if redis_kwargs.get("startup_nodes") is None:
        raise ValueError(
            "Redis cluster requires 'startup_nodes' or the REDIS_CLUSTER_NODES environment variable."
        )
    verbose_logger.debug("init_redis_cluster: startup nodes are being initialized.")
    from redis.cluster import ClusterNode

    allowed = _get_redis_cluster_kwargs()
    cluster_kwargs = {
        name: value for name, value in redis_kwargs.items() if name in allowed
    }
    new_startup_nodes: List[ClusterNode] = [
        ClusterNode(**item) for item in redis_kwargs["startup_nodes"]
    ]
    redis_kwargs.pop("startup_nodes")
    return redis.RedisCluster(startup_nodes=new_startup_nodes, **cluster_kwargs)  # type: ignore
def _init_redis_sentinel(redis_kwargs) -> redis.Redis:
    """
    Return the master client for a Redis Sentinel deployment (sync).

    Reads ``sentinel_nodes``, ``service_name`` and the optional
    ``sentinel_password`` from ``redis_kwargs``. Other keys are not used.

    Raises:
        ValueError: if ``sentinel_nodes`` or ``service_name`` is missing or
            empty.
    """
    service_name = redis_kwargs.get("service_name")
    sentinel_nodes = redis_kwargs.get("sentinel_nodes")
    if not (sentinel_nodes and service_name):
        raise ValueError(
            "Both 'sentinel_nodes' and 'service_name' are required for Redis Sentinel."
        )
    verbose_logger.debug("init_redis_sentinel: sentinel nodes are being initialized.")
    # NOTE(review): redis-py applies these connection kwargs to the
    # master/replica connections, not to the sentinels themselves. Confirm
    # that is where sentinel_password is meant to apply.
    sentinel = redis.Sentinel(
        sentinel_nodes,
        socket_timeout=REDIS_SOCKET_TIMEOUT,
        password=redis_kwargs.get("sentinel_password"),
    )
    return sentinel.master_for(service_name)
def _init_async_redis_sentinel(redis_kwargs) -> async_redis.Redis:
    """
    Return the master client for a Redis Sentinel deployment (asyncio).

    This is the async counterpart of ``_init_redis_sentinel``. It reads
    ``sentinel_nodes``, ``service_name`` and the optional
    ``sentinel_password`` from ``redis_kwargs``.

    Raises:
        ValueError: if ``sentinel_nodes`` or ``service_name`` is missing or
            empty.
    """
    service_name = redis_kwargs.get("service_name")
    sentinel_nodes = redis_kwargs.get("sentinel_nodes")
    if not (sentinel_nodes and service_name):
        raise ValueError(
            "Both 'sentinel_nodes' and 'service_name' are required for Redis Sentinel."
        )
    verbose_logger.debug("init_redis_sentinel: sentinel nodes are being initialized.")
    # NOTE(review): as in the sync variant, this password is passed as a
    # connection kwarg of the Sentinel object. Confirm the intended target.
    sentinel = async_redis.Sentinel(
        sentinel_nodes,
        socket_timeout=REDIS_SOCKET_TIMEOUT,
        password=redis_kwargs.get("sentinel_password"),
    )
    return sentinel.master_for(service_name)
def get_redis_client(**env_overrides):
    """
    Build a synchronous Redis client from overrides plus ``REDIS_*`` config.

    The client type is chosen in this order:

    1. URL client, when ``url`` is set;
    2. ``RedisCluster``, when startup nodes or ``REDIS_CLUSTER_NODES`` are
       present;
    3. Sentinel master, when both ``sentinel_nodes`` and ``service_name``
       are set;
    4. plain ``redis.Redis``.
    """
    redis_kwargs = _get_redis_client_logic(**env_overrides)

    if redis_kwargs.get("url") is not None:
        allowed = _get_redis_url_kwargs()
        url_kwargs = {
            name: value for name, value in redis_kwargs.items() if name in allowed
        }
        return redis.Redis.from_url(**url_kwargs)

    wants_cluster = (
        "startup_nodes" in redis_kwargs
        or get_secret("REDIS_CLUSTER_NODES") is not None  # type: ignore
    )
    if wants_cluster:
        return init_redis_cluster(redis_kwargs)

    if "sentinel_nodes" in redis_kwargs and "service_name" in redis_kwargs:
        return _init_redis_sentinel(redis_kwargs)

    return redis.Redis(**redis_kwargs)
def get_redis_async_client(
    **env_overrides,
) -> async_redis.Redis:
    """
    Build an asyncio Redis client from overrides plus ``REDIS_*`` config.

    The client type is chosen in this order:

    1. ``async_redis.Redis.from_url``, when ``url`` is set; unsupported
       kwargs are logged and dropped;
    2. ``async_redis.RedisCluster``, when ``startup_nodes`` is present;
    3. async Sentinel master, when both ``sentinel_nodes`` and
       ``service_name`` are set;
    4. plain ``async_redis.Redis``.
    """
    redis_kwargs = _get_redis_client_logic(**env_overrides)

    if redis_kwargs.get("url") is not None:
        allowed = _get_redis_url_kwargs(client=async_redis.Redis.from_url)
        url_kwargs = {}
        for name, value in redis_kwargs.items():
            if name not in allowed:
                verbose_logger.debug(
                    "REDIS: ignoring argument: {}. Not an allowed async_redis.Redis.from_url arg.".format(
                        name
                    )
                )
                continue
            url_kwargs[name] = value
        return async_redis.Redis.from_url(**url_kwargs)

    if "startup_nodes" in redis_kwargs:
        from redis.cluster import ClusterNode

        allowed = _get_redis_cluster_kwargs()
        cluster_kwargs = {
            name: value for name, value in redis_kwargs.items() if name in allowed
        }
        nodes: List[ClusterNode] = [
            ClusterNode(**node) for node in redis_kwargs["startup_nodes"]
        ]
        redis_kwargs.pop("startup_nodes")
        return async_redis.RedisCluster(
            startup_nodes=nodes, **cluster_kwargs  # type: ignore
        )

    if "sentinel_nodes" in redis_kwargs and "service_name" in redis_kwargs:
        return _init_async_redis_sentinel(redis_kwargs)

    return async_redis.Redis(**redis_kwargs)
def get_redis_connection_pool(**env_overrides):
    """
    Build an ``async_redis.BlockingConnectionPool`` from overrides plus env
    config.

    When ``url`` is set, the pool is created from the URL alone. Otherwise
    the merged kwargs are forwarded to the pool, with these adjustments:

    - ``ssl`` is consumed to choose between ``SSLConnection`` and
      ``Connection``;
    - ``startup_nodes`` is removed.

    Both pool variants use ``REDIS_CONNECTION_POOL_TIMEOUT`` as the
    blocking-acquire timeout.
    """
    redis_kwargs = _get_redis_client_logic(**env_overrides)
    # Use a %s placeholder so logging actually interpolates the dict (an extra
    # arg with no placeholder is reported as a logging formatting error), and
    # mask secrets so they never reach debug logs.
    verbose_logger.debug(
        "get_redis_connection_pool: redis_kwargs: %s",
        {
            key: ("*****" if key in ("password", "sentinel_password") else value)
            for key, value in redis_kwargs.items()
        },
    )
    if "url" in redis_kwargs and redis_kwargs["url"] is not None:
        return async_redis.BlockingConnectionPool.from_url(
            timeout=REDIS_CONNECTION_POOL_TIMEOUT, url=redis_kwargs["url"]
        )

    # Enable TLS only for a truthy ssl setting. Env-sourced values may arrive
    # as strings, so "false"/"0"/"no"/"off"/"" are treated as disabled.
    ssl_setting = redis_kwargs.pop("ssl", None)
    if isinstance(ssl_setting, str):
        use_ssl = ssl_setting.strip().lower() not in ("", "0", "false", "no", "off")
    else:
        use_ssl = bool(ssl_setting)
    redis_kwargs["connection_class"] = (
        async_redis.SSLConnection if use_ssl else async_redis.Connection
    )
    redis_kwargs.pop("startup_nodes", None)
    return async_redis.BlockingConnectionPool(
        timeout=REDIS_CONNECTION_POOL_TIMEOUT, **redis_kwargs
    )