# litellm/litellm/proxy/db/dynamo_db.py
# (web-page residue kept for provenance: "ka1kuk's picture", "Upload 235 files", commit 7db0ae4, verified)
from litellm.proxy.db.base_client import CustomDB
from litellm.proxy._types import (
DynamoDBArgs,
LiteLLM_VerificationToken,
LiteLLM_Config,
LiteLLM_UserTable,
)
from litellm import get_secret
from typing import Any, List, Literal, Optional, Union
import json
from datetime import datetime
class DynamoDBWrapper(CustomDB):
    """DynamoDB implementation of ``CustomDB`` built on ``aiodynamo``.

    Three logical tables are supported, addressed by the short names
    ``"user"``, ``"key"`` and ``"config"``.  Their DynamoDB table names come
    from ``database_arguments`` and their hash keys are ``user_id``,
    ``token`` and ``param_name`` respectively.

    Every operation opens its own ``aiohttp.ClientSession`` and closes it
    before returning.  ``aiodynamo``/``aiohttp`` are imported lazily inside
    the methods that need them.
    """

    from aiodynamo.credentials import Credentials, StaticCredentials

    credentials: Credentials

    def __init__(self, database_arguments: DynamoDBArgs):
        """Store the connection arguments and derive the table throughput.

        Args:
            database_arguments: Table names, region and billing settings.

        Raises:
            Exception: If ``billing_mode`` is ``"PROVISIONED_THROUGHPUT"``
                but the read/write capacity units are not both integers.
        """
        from aiodynamo.models import PayPerRequest, Throughput

        # Stays None for any billing mode other than the two handled below.
        self.throughput_type = None
        if database_arguments.billing_mode == "PAY_PER_REQUEST":
            self.throughput_type = PayPerRequest()
        elif database_arguments.billing_mode == "PROVISIONED_THROUGHPUT":
            read_units = database_arguments.read_capacity_units
            write_units = database_arguments.write_capacity_units
            # isinstance() already rejects None, so no separate None check.
            if not (isinstance(read_units, int) and isinstance(write_units, int)):
                raise Exception(
                    f"Invalid args passed in. Need to set both read_capacity_units and write_capacity_units. Args passed in - {database_arguments}"
                )
            self.throughput_type = Throughput(read=read_units, write=write_units)  # type: ignore
        self.database_arguments = database_arguments
        self.region_name = database_arguments.region_name

    def _table_spec(self, table_name: Literal["user", "key", "config"]):
        """Map a logical table name to ``(dynamo_table_name, hash_key_name)``.

        Raises:
            Exception: If ``table_name`` is not ``"user"``, ``"key"`` or
                ``"config"``.
        """
        if table_name == "user":
            return self.database_arguments.user_table_name, "user_id"
        if table_name == "key":
            return self.database_arguments.key_table_name, "token"
        if table_name == "config":
            return self.database_arguments.config_table_name, "param_name"
        raise Exception(
            f"Invalid table name. Needs to be one of - {self.database_arguments.user_table_name}, {self.database_arguments.key_table_name}, {self.database_arguments.config_table_name}"
        )

    async def connect(self):
        """
        Connect to DB, and creating / updating any tables

        Each of the user, key and config tables is created when it does not
        exist yet, using the throughput chosen in ``__init__``.

        Raises:
            Exception: If checking for or creating a table fails.  The
                original error is attached as ``__cause__``.
        """
        from aiodynamo.client import Client
        from aiodynamo.credentials import Credentials
        from aiodynamo.http.aiohttp import AIOHTTP
        from aiodynamo.models import KeySchema, KeySpec, KeyType
        from aiohttp import ClientSession

        # (DynamoDB table name, hash key attribute), in creation order.
        required_tables = (
            (self.database_arguments.user_table_name, "user_id"),
            (self.database_arguments.key_table_name, "token"),
            (self.database_arguments.config_table_name, "param_name"),
        )

        async with ClientSession() as session:
            client = Client(AIOHTTP(session), Credentials.auto(), self.region_name)
            for dynamo_name, hash_key in required_tables:
                try:
                    table = client.table(dynamo_name)
                    if not await table.exists():
                        await table.create(
                            self.throughput_type,
                            KeySchema(hash_key=KeySpec(hash_key, KeyType.string)),
                        )
                except Exception as e:
                    # Chain the cause so the underlying failure stays visible
                    # in the traceback instead of being discarded.
                    raise Exception(
                        f"Failed to create table - {dynamo_name}.\nPlease create a new table called {dynamo_name}\nAND set `hash_key` as '{hash_key}'"
                    ) from e

    async def insert_data(
        self, value: Any, table_name: Literal["user", "key", "config"]
    ):
        """Put ``value`` as an item into the selected table.

        ``datetime`` values are stored as ISO 8601 strings.  The conversion
        is done on a copy, so the caller's mapping is left untouched.

        Args:
            value: Mapping of attribute names to values (must provide
                ``.items()``).
            table_name: Logical table to write to.

        Raises:
            Exception: If ``table_name`` is not a known logical table.
        """
        from aiodynamo.client import Client
        from aiodynamo.credentials import Credentials
        from aiodynamo.http.aiohttp import AIOHTTP
        from aiohttp import ClientSession

        dynamo_name, _ = self._table_spec(table_name)
        item = {
            k: (v.isoformat() if isinstance(v, datetime) else v)
            for k, v in value.items()
        }

        async with ClientSession() as session:
            client = Client(AIOHTTP(session), Credentials.auto(), self.region_name)
            await client.table(dynamo_name).put_item(item=item)

    async def get_data(self, key: str, table_name: Literal["user", "key", "config"]):
        """Fetch one item by hash key and wrap it in its model type.

        Args:
            key: Value of the table's hash key.
            table_name: Logical table to read from.

        Returns:
            ``LiteLLM_UserTable``, ``LiteLLM_VerificationToken`` or
            ``LiteLLM_Config`` depending on ``table_name``.  For ``"key"``,
            string values of ``aliases``, ``config`` and ``metadata`` are
            JSON-decoded before the model is built.

        Raises:
            Exception: If ``table_name`` is not a known logical table.
        """
        from aiodynamo.client import Client
        from aiodynamo.credentials import Credentials
        from aiodynamo.http.aiohttp import AIOHTTP
        from aiohttp import ClientSession

        dynamo_name, key_name = self._table_spec(table_name)

        async with ClientSession() as session:
            client = Client(AIOHTTP(session), Credentials.auto(), self.region_name)
            response = await client.table(dynamo_name).get_item({key_name: key})

        if table_name == "user":
            return LiteLLM_UserTable(**response)
        if table_name == "key":
            json_fields = ("aliases", "config", "metadata")
            decoded = {
                k: (json.loads(v) if k in json_fields and isinstance(v, str) else v)
                for k, v in response.items()
            }
            return LiteLLM_VerificationToken(**decoded)
        # _table_spec() already rejected everything except "config" here.
        return LiteLLM_Config(**response)

    async def update_data(
        self, key: str, value: dict, table_name: Literal["user", "key", "config"]
    ):
        """Set the attributes in ``value`` on the item identified by ``key``.

        ``datetime`` values are stored as ISO 8601 strings.

        Args:
            key: Value of the table's hash key.
            value: Attribute names mapped to their new values.
            table_name: Logical table to update.

        Returns:
            Whatever ``table.update_item`` returns when called with
            ``ReturnValues.none``.

        Raises:
            Exception: If ``table_name`` is not a known logical table.
        """
        from aiodynamo.client import Client
        from aiodynamo.credentials import Credentials
        from aiodynamo.expressions import F, UpdateExpression, Value
        from aiodynamo.http.aiohttp import AIOHTTP
        from aiodynamo.models import ReturnValues
        from aiohttp import ClientSession

        async with ClientSession() as session:
            client = Client(AIOHTTP(session), Credentials.auto(), self.region_name)
            try:
                dynamo_name, key_name = self._table_spec(table_name)
                table = client.table(dynamo_name)
            except Exception as e:
                raise Exception(f"Error connecting to table - {str(e)}") from e

            # One (field, value) pair per attribute to set.
            actions: List = [
                (F(k), Value(value=v.isoformat() if isinstance(v, datetime) else v))
                for k, v in value.items()
            ]

            # NOTE(review): confirm the installed aiodynamo version accepts
            # `set_updates=` here; kept exactly as in the original code.
            update_expression = UpdateExpression(set_updates=actions)

            # Perform the update in DynamoDB
            result = await table.update_item(
                key={key_name: key},
                update_expression=update_expression,
                return_values=ReturnValues.none,
            )
            return result

    async def delete_data(
        self, keys: List[str], table_name: Literal["user", "key", "config"]
    ):
        """
        Not Implemented yet.

        Delegates to the base-class ``delete_data``.
        """
        return super().delete_data(keys, table_name)