|
from litellm.proxy.db.base_client import CustomDB |
|
from litellm.proxy._types import ( |
|
DynamoDBArgs, |
|
LiteLLM_VerificationToken, |
|
LiteLLM_Config, |
|
LiteLLM_UserTable, |
|
) |
|
from litellm import get_secret |
|
from typing import Any, List, Literal, Optional, Union |
|
import json |
|
from datetime import datetime |
|
|
|
|
|
class DynamoDBWrapper(CustomDB):
    """
    ``CustomDB`` implementation that stores LiteLLM users, keys and config in DynamoDB.

    Tables are addressed by the logical names ``"user"``, ``"key"`` and
    ``"config"``; the physical DynamoDB table names come from the
    ``DynamoDBArgs`` given to ``__init__``. Every operation opens a fresh
    ``aiohttp`` session and an ``aiodynamo`` client that authenticates via
    ``Credentials.auto()`` against ``region_name``.

    NOTE: most third-party imports are function-local, but the class-level
    import below still requires ``aiodynamo`` whenever this class is defined.
    """

    from aiodynamo.credentials import Credentials, StaticCredentials

    # Declared for typing only; nothing in this class assigns it.
    credentials: Credentials

    # Hash (partition) key attribute of each logical table.
    _HASH_KEYS = {"user": "user_id", "key": "token", "config": "param_name"}

    # Key-table attributes that get_data JSON-decodes when stored as strings.
    _JSON_KEY_FIELDS = ("aliases", "config", "metadata")

    def __init__(self, database_arguments: DynamoDBArgs):
        """
        Store connection settings and resolve the table throughput mode.

        Args:
            database_arguments: DynamoDB settings (table names, region, billing
                mode and, for ``"PROVISIONED_THROUGHPUT"``, capacity units).

        Raises:
            Exception: billing mode is ``"PROVISIONED_THROUGHPUT"`` but either
                capacity unit is not an ``int``.
        """
        from aiodynamo.models import PayPerRequest, Throughput

        # Stays None for any billing mode other than the two handled below.
        self.throughput_type = None
        billing_mode = database_arguments.billing_mode
        if billing_mode == "PAY_PER_REQUEST":
            self.throughput_type = PayPerRequest()
        elif billing_mode == "PROVISIONED_THROUGHPUT":
            read_units = database_arguments.read_capacity_units
            write_units = database_arguments.write_capacity_units
            # isinstance(..., int) already rules out None.
            if not (isinstance(read_units, int) and isinstance(write_units, int)):
                raise Exception(
                    f"Invalid args passed in. Need to set both read_capacity_units and write_capacity_units. Args passed in - {database_arguments}"
                )
            self.throughput_type = Throughput(read=read_units, write=write_units)

        self.database_arguments = database_arguments
        self.region_name = database_arguments.region_name

    def _make_client(self, session):
        """Build an aiodynamo client bound to ``session`` for ``self.region_name``."""
        from aiodynamo.client import Client
        from aiodynamo.credentials import Credentials
        from aiodynamo.http.aiohttp import AIOHTTP

        return Client(AIOHTTP(session), Credentials.auto(), self.region_name)

    def _table_name(self, table_name: str) -> str:
        """
        Map a logical table name to its configured DynamoDB table name.

        Raises:
            ValueError: ``table_name`` is not ``"user"``, ``"key"`` or ``"config"``.
        """
        physical_names = {
            "user": self.database_arguments.user_table_name,
            "key": self.database_arguments.key_table_name,
            "config": self.database_arguments.config_table_name,
        }
        try:
            return physical_names[table_name]
        except KeyError:
            raise ValueError(
                f"Invalid table name - {table_name!r}. Needs to be one of - 'user', 'key', 'config'"
            ) from None

    def _get_table(self, client, table_name: str):
        """Return ``(table, hash_key_name)`` for a logical table name (validated)."""
        return client.table(self._table_name(table_name)), self._HASH_KEYS[table_name]

    async def _create_table_if_missing(self, client, table_name: str):
        """
        Create the table for ``table_name`` (string hash key) if it does not exist.

        Raises:
            Exception: the existence check or creation failed; the message tells
                the user which table / hash key to create by hand, and the
                original error is chained as ``__cause__``.
        """
        from aiodynamo.models import KeySchema, KeySpec, KeyType

        physical_name = self._table_name(table_name)
        hash_key = self._HASH_KEYS[table_name]
        try:
            table = client.table(physical_name)
            if not await table.exists():
                await table.create(
                    self.throughput_type,
                    KeySchema(hash_key=KeySpec(hash_key, KeyType.string)),
                )
        except Exception as e:
            raise Exception(
                f"Failed to create table - {physical_name}.\nPlease create a new table called {physical_name}\nAND set `hash_key` as '{hash_key}'"
            ) from e

    async def connect(self):
        """
        Connect to DB, creating the user, key and config tables if they do not exist.

        Raises:
            Exception: a table could not be checked/created (see
                ``_create_table_if_missing``). Tables are processed in the order
                user, key, config and processing stops at the first failure.
        """
        from aiohttp import ClientSession

        async with ClientSession() as session:
            client = self._make_client(session)
            for table_name in ("user", "key", "config"):
                await self._create_table_if_missing(client, table_name)

    async def insert_data(
        self, value: Any, table_name: Literal["user", "key", "config"]
    ):
        """
        Write ``value`` as an item (``put_item``) into the given logical table.

        ``datetime`` values are serialized with ``isoformat()``. A shallow copy
        is written, so the caller's mapping is left unmodified.

        Raises:
            ValueError: unknown ``table_name``.
        """
        from aiohttp import ClientSession

        item = {
            k: v.isoformat() if isinstance(v, datetime) else v
            for k, v in value.items()
        }

        async with ClientSession() as session:
            table, _ = self._get_table(self._make_client(session), table_name)
            await table.put_item(item=item)

    async def get_data(self, key: str, table_name: Literal["user", "key", "config"]):
        """
        Fetch one item by its hash key and convert it to the matching model.

        Returns:
            ``LiteLLM_UserTable`` for ``"user"``, ``LiteLLM_VerificationToken``
            for ``"key"`` (with ``aliases``/``config``/``metadata`` JSON-decoded
            when stored as strings) and ``LiteLLM_Config`` for ``"config"``.

        Raises:
            ValueError: unknown ``table_name``. Errors raised by
                ``table.get_item`` (e.g. for a missing item) propagate unchanged.
        """
        from aiohttp import ClientSession

        async with ClientSession() as session:
            table, key_name = self._get_table(self._make_client(session), table_name)
            response = await table.get_item({key_name: key})

        if table_name == "user":
            return LiteLLM_UserTable(**response)
        if table_name == "key":
            decoded = {
                k: json.loads(v)
                if k in self._JSON_KEY_FIELDS and isinstance(v, str)
                else v
                for k, v in response.items()
            }
            return LiteLLM_VerificationToken(**decoded)
        # table_name was validated by _get_table, so this is "config".
        return LiteLLM_Config(**response)

    async def update_data(
        self, key: str, value: dict, table_name: Literal["user", "key", "config"]
    ):
        """
        SET every attribute in ``value`` on the item whose hash key equals ``key``.

        ``datetime`` values are serialized with ``isoformat()``; ``value`` itself
        is not modified.

        Returns:
            Whatever ``table.update_item`` returns with ``ReturnValues.none``.

        Raises:
            ValueError: unknown ``table_name``.
        """
        from aiodynamo.expressions import F, UpdateExpression, Value
        from aiodynamo.models import ReturnValues
        from aiohttp import ClientSession

        async with ClientSession() as session:
            table, key_name = self._get_table(self._make_client(session), table_name)

            actions: List = [
                (F(k), Value(value=v.isoformat() if isinstance(v, datetime) else v))
                for k, v in value.items()
            ]
            update_expression = UpdateExpression(set_updates=actions)

            return await table.update_item(
                key={key_name: key},
                update_expression=update_expression,
                return_values=ReturnValues.none,
            )

    async def delete_data(
        self, keys: List[str], table_name: Literal["user", "key", "config"]
    ):
        """
        Not Implemented yet - defers to ``CustomDB.delete_data``.

        The base result is awaited when it is awaitable, so this works whether
        the base method is sync or async (previously a coroutine could be
        returned un-awaited).
        """
        import inspect

        result = super().delete_data(keys, table_name)
        if inspect.isawaitable(result):
            result = await result
        return result
|
|