"""Pydantic type definitions for the LiteLLM proxy server."""
from pydantic import BaseModel, Extra, Field, root_validator
import enum
from typing import Optional, List, Union, Dict, Literal
from datetime import datetime
import uuid, json
class LiteLLMBase(BaseModel):
    """
    Implements default functions, all pydantic objects should have.

    Bridges pydantic v1/v2: each helper tries the v2 API first and falls back
    to the v1 equivalent when the v2 attribute does not exist.
    """

    def json(self, **kwargs):
        """Return the model's data as a plain dict (v2 ``model_dump`` / v1 ``dict``)."""
        try:
            return self.model_dump()  # noqa  (pydantic v2)
        except AttributeError:
            # if using pydantic v1 — model_dump is missing, not failing
            return self.dict()

    def fields_set(self):
        """Return the set of field names explicitly set on this instance."""
        try:
            return self.model_fields_set  # noqa  (pydantic v2)
        except AttributeError:
            # if using pydantic v1
            return self.__fields_set__
######### Request Class Definition ######
class ProxyChatCompletionRequest(LiteLLMBase):
    """
    Request body for the proxy's chat-completion endpoint.

    The first group of fields mirrors the OpenAI ``chat.completions``
    parameters; the second group is LiteLLM-specific. Undeclared extra
    fields are allowed (``Config.extra``) and forwarded to
    ``litellm.completion(**kwargs)``.
    """

    model: str
    messages: List[Dict[str, str]]
    temperature: Optional[float] = None
    top_p: Optional[float] = None
    n: Optional[int] = None
    stream: Optional[bool] = None
    stop: Optional[List[str]] = None
    max_tokens: Optional[int] = None
    presence_penalty: Optional[float] = None
    frequency_penalty: Optional[float] = None
    logit_bias: Optional[Dict[str, float]] = None
    user: Optional[str] = None
    response_format: Optional[Dict[str, str]] = None
    seed: Optional[int] = None
    tools: Optional[List[str]] = None
    tool_choice: Optional[str] = None
    functions: Optional[List[str]] = None  # soon to be deprecated
    function_call: Optional[str] = None  # soon to be deprecated

    # Optional LiteLLM params
    caching: Optional[bool] = None
    api_base: Optional[str] = None
    api_version: Optional[str] = None
    api_key: Optional[str] = None
    num_retries: Optional[int] = None
    context_window_fallback_dict: Optional[Dict[str, str]] = None
    fallbacks: Optional[List[str]] = None
    # NOTE: pydantic deep-copies field defaults per instance, so the {} literal is safe
    metadata: Optional[Dict[str, str]] = {}
    deployment_id: Optional[str] = None
    request_timeout: Optional[int] = None

    class Config:
        extra = "allow"  # allow params not defined here, these fall in litellm.completion(**kwargs)
class ModelInfoDelete(LiteLLMBase):
    """Payload carrying a model ``id`` (presumably for model-delete requests — confirm against caller)."""

    id: Optional[str]
class ModelInfo(LiteLLMBase):
    """
    Optional per-model metadata (mode, per-token cost, context limits).

    NOTE(review): the pre ``root_validator`` below replaces every absent field
    with an explicit ``None`` (and ``id`` with a fresh uuid4), so the declared
    defaults (0.0 costs, 2048 max_tokens) are overridden whenever the field is
    missing from the input — they only survive if the caller passes them.
    """

    id: Optional[str]
    mode: Optional[Literal["embedding", "chat", "completion"]]
    input_cost_per_token: Optional[float] = 0.0
    output_cost_per_token: Optional[float] = 0.0
    max_tokens: Optional[int] = 2048  # assume 2048 if not set
    # for azure models we need users to specify the base model, one azure you can call deployments - azure/my-random-model
    # we look up the base model in model_prices_and_context_window.json
    base_model: Optional[
        Literal[
            "gpt-4-1106-preview",
            "gpt-4-32k",
            "gpt-4",
            "gpt-3.5-turbo-16k",
            "gpt-3.5-turbo",
            "text-embedding-ada-002",
        ]
    ]

    class Config:
        extra = Extra.allow  # Allow extra fields
        protected_namespaces = ()

    @root_validator(pre=True)
    def set_model_info(cls, values):
        """Ensure every field key is present: mint an id, null out absent fields."""
        if values.get("id") is None:
            values.update({"id": str(uuid.uuid4())})
        if values.get("mode") is None:
            values.update({"mode": None})
        if values.get("input_cost_per_token") is None:
            values.update({"input_cost_per_token": None})
        if values.get("output_cost_per_token") is None:
            values.update({"output_cost_per_token": None})
        if values.get("max_tokens") is None:
            values.update({"max_tokens": None})
        if values.get("base_model") is None:
            values.update({"base_model": None})
        return values
class ModelParams(LiteLLMBase):
    """
    One entry of the server's model list: public name, the kwargs handed to
    litellm, and optional model metadata.
    """

    model_name: str
    litellm_params: dict
    model_info: ModelInfo

    class Config:
        protected_namespaces = ()

    @root_validator(pre=True)
    def set_model_info(cls, values):
        # Guarantee model_info is always populated — fall back to an empty ModelInfo
        # when it is absent or explicitly None.
        if values.get("model_info") is None:
            values["model_info"] = ModelInfo()
        return values
class GenerateKeyRequest(LiteLLMBase):
    """
    Request body for generating a new virtual API key.

    Defaults: keys last "1h"; list/dict fields default to empty (pydantic
    copies defaults per instance, so the shared literals are safe).
    """

    duration: Optional[str] = "1h"  # key lifetime
    models: Optional[list] = []
    aliases: Optional[dict] = {}
    config: Optional[dict] = {}
    spend: Optional[float] = 0
    user_id: Optional[str] = None
    max_parallel_requests: Optional[int] = None
    metadata: Optional[dict] = {}
class UpdateKeyRequest(LiteLLMBase):
    """
    Request body for updating an existing key.

    Fields defaulting to None presumably mean "leave unchanged" — confirm
    against the update handler.
    """

    key: str  # the key to update (required)
    duration: Optional[str] = None
    models: Optional[list] = None
    aliases: Optional[dict] = None
    config: Optional[dict] = None
    spend: Optional[float] = None
    user_id: Optional[str] = None
    max_parallel_requests: Optional[int] = None
    # NOTE(review): defaults to {} while every sibling field defaults to None —
    # confirm whether an update without metadata should really send {}.
    metadata: Optional[dict] = {}
class UserAPIKeyAuth(LiteLLMBase):  # the expected response object for user api key auth
    """
    Return the row in the db
    """

    api_key: Optional[str] = None
    models: list = []
    aliases: dict = {}
    config: dict = {}
    spend: Optional[float] = 0
    user_id: Optional[str] = None
    max_parallel_requests: Optional[int] = None
    duration: str = "1h"
    metadata: dict = {}
class GenerateKeyResponse(LiteLLMBase):
    """Response returned after key generation: the key, its expiry, and owner."""

    key: str
    expires: Optional[datetime]  # required but nullable: must be passed, may be None
    user_id: str
class _DeleteKeyObject(LiteLLMBase):
    """Internal wrapper for a single key slated for deletion."""

    key: str
class DeleteKeyRequest(LiteLLMBase):
    """Request body for deleting one or more keys."""

    keys: List[_DeleteKeyObject]
class NewUserRequest(GenerateKeyRequest):
    """Key-generation request for a new user; adds an optional spend cap."""

    max_budget: Optional[float] = None
class NewUserResponse(GenerateKeyResponse):
    """Key-generation response for a new user; echoes the optional spend cap."""

    max_budget: Optional[float] = None
class KeyManagementSystem(enum.Enum):
    """Supported backends for loading / decrypting secret keys."""

    GOOGLE_KMS = "google_kms"
    AZURE_KEY_VAULT = "azure_key_vault"
    LOCAL = "local"
class DynamoDBArgs(LiteLLMBase):
    """
    Arguments for instantiating the DynamoDB client and naming its tables.

    NOTE(review): read/write capacity units presumably only apply when
    billing_mode is "PROVISIONED_THROUGHPUT" — confirm against the DynamoDB wrapper.
    """

    billing_mode: Literal["PROVISIONED_THROUGHPUT", "PAY_PER_REQUEST"]
    read_capacity_units: Optional[int] = None
    write_capacity_units: Optional[int] = None
    region_name: str
    user_table_name: str = "LiteLLM_UserTable"
    key_table_name: str = "LiteLLM_VerificationToken"
    config_table_name: str = "LiteLLM_Config"
class ConfigGeneralSettings(LiteLLMBase):
    """
    Documents all the fields supported by `general_settings` in config.yaml

    Every field is optional; each carries its user-facing description in its
    ``Field(description=...)``.
    """

    completion_model: Optional[str] = Field(
        None, description="proxy level default model for all chat completion calls"
    )
    key_management_system: Optional[KeyManagementSystem] = Field(
        None, description="key manager to load keys from / decrypt keys with"
    )
    use_google_kms: Optional[bool] = Field(
        None, description="decrypt keys with google kms"
    )
    use_azure_key_vault: Optional[bool] = Field(
        None, description="load keys from azure key vault"
    )
    master_key: Optional[str] = Field(
        None, description="require a key for all calls to proxy"
    )
    database_url: Optional[str] = Field(
        None,
        description="connect to a postgres db - needed for generating temporary keys + tracking spend / key",
    )
    database_type: Optional[Literal["dynamo_db"]] = Field(
        None, description="to use dynamodb instead of postgres db"
    )
    database_args: Optional[DynamoDBArgs] = Field(
        None,
        description="custom args for instantiating dynamodb client - e.g. billing provision",
    )
    otel: Optional[bool] = Field(
        None,
        description="[BETA] OpenTelemetry support - this might change, use with caution.",
    )
    custom_auth: Optional[str] = Field(
        None,
        description="override user_api_key_auth with your own auth script - https://docs.litellm.ai/docs/proxy/virtual_keys#custom-auth",
    )
    max_parallel_requests: Optional[int] = Field(
        None, description="maximum parallel requests for each api key"
    )
    infer_model_from_keys: Optional[bool] = Field(
        None,
        description="for `/models` endpoint, infers available model based on environment keys (e.g. OPENAI_API_KEY)",
    )
    background_health_checks: Optional[bool] = Field(
        None, description="run health checks in background"
    )
    health_check_interval: int = Field(
        300, description="background health check interval in seconds"
    )
    alerting: Optional[List] = Field(
        None,
        description="List of alerting integrations. Today, just slack - `alerting: ['slack']`",
    )
    # NOTE(review): unit not stated here; description suggests a hang duration
    # (e.g. seconds) — confirm against the alerting implementation.
    alerting_threshold: Optional[int] = Field(
        None,
        description="sends alerts if requests hang for 5min+",
    )
class ConfigYAML(LiteLLMBase):
    """
    Documents all the fields supported by the config.yaml

    Top-level sections: environment_variables, model_list, litellm_settings,
    and general_settings (see ConfigGeneralSettings).
    """

    environment_variables: Optional[dict] = Field(
        None,
        description="Object to pass in additional environment variables via POST request",
    )
    model_list: Optional[List[ModelParams]] = Field(
        None,
        description="List of supported models on the server, with model-specific configs",
    )
    litellm_settings: Optional[dict] = Field(
        None,
        description="litellm Module settings. See __init__.py for all, example litellm.drop_params=True, litellm.set_verbose=True, litellm.api_base, litellm.cache",
    )
    general_settings: Optional[ConfigGeneralSettings] = None

    class Config:
        protected_namespaces = ()
class LiteLLM_VerificationToken(LiteLLMBase):
    """
    Row shape for a virtual-key (verification token) record.

    ``Union[..., None]`` fields without defaults are required-but-nullable:
    the caller must supply them, possibly as None.
    """

    token: str
    spend: float = 0.0
    expires: Union[str, None]
    models: List[str]
    aliases: Dict[str, str] = {}
    config: Dict[str, str] = {}
    user_id: Union[str, None]
    max_parallel_requests: Union[int, None]
    metadata: Dict[str, str] = {}
class LiteLLM_Config(LiteLLMBase):
    """A single config key/value row (value is an arbitrary JSON-like dict)."""

    param_name: str
    param_value: Dict
class LiteLLM_UserTable(LiteLLMBase):
    """
    Row shape for a user record: id, optional budget cap, running spend, email.
    """

    user_id: str
    max_budget: Optional[float]  # required but nullable; None presumably means no cap — confirm
    spend: float = 0.0
    user_email: Optional[str]  # required but nullable

    @root_validator(pre=True)
    def set_default_spend(cls, values):
        """Coerce a null spend (e.g. from a fresh DB row) to 0.0 before validation."""
        # Renamed from the copy-pasted `set_model_info` — this validator only
        # normalizes `spend` and has nothing to do with model info.
        if values.get("spend") is None:
            values.update({"spend": 0.0})
        return values