ka1kuk committed (verified)
Commit 7db0ae4 · Parent: 20a7d21

Upload 235 files

This view is limited to 50 files because the commit contains too many changes. See the raw diff for the remaining files.
Files changed (50):
  1. litellm/__init__.py +557 -0
  2. litellm/_logging.py +30 -0
  3. litellm/_redis.py +93 -0
  4. litellm/_version.py +6 -0
  5. litellm/budget_manager.py +206 -0
  6. litellm/caching.py +678 -0
  7. litellm/cost.json +5 -0
  8. litellm/deprecated_litellm_server/.env.template +43 -0
  9. litellm/deprecated_litellm_server/Dockerfile +10 -0
  10. litellm/deprecated_litellm_server/README.md +3 -0
  11. litellm/deprecated_litellm_server/__init__.py +2 -0
  12. litellm/deprecated_litellm_server/main.py +193 -0
  13. litellm/deprecated_litellm_server/requirements.txt +7 -0
  14. litellm/deprecated_litellm_server/server_utils.py +85 -0
  15. litellm/exceptions.py +200 -0
  16. litellm/integrations/__init__.py +1 -0
  17. litellm/integrations/aispend.py +177 -0
  18. litellm/integrations/berrispend.py +184 -0
  19. litellm/integrations/custom_logger.py +130 -0
  20. litellm/integrations/dynamodb.py +92 -0
  21. litellm/integrations/helicone.py +114 -0
  22. litellm/integrations/langfuse.py +191 -0
  23. litellm/integrations/langsmith.py +75 -0
  24. litellm/integrations/litedebugger.py +262 -0
  25. litellm/integrations/llmonitor.py +127 -0
  26. litellm/integrations/prompt_layer.py +72 -0
  27. litellm/integrations/s3.py +150 -0
  28. litellm/integrations/supabase.py +117 -0
  29. litellm/integrations/traceloop.py +114 -0
  30. litellm/integrations/weights_biases.py +223 -0
  31. litellm/llms/__init__.py +1 -0
  32. litellm/llms/ai21.py +212 -0
  33. litellm/llms/aleph_alpha.py +304 -0
  34. litellm/llms/anthropic.py +215 -0
  35. litellm/llms/azure.py +799 -0
  36. litellm/llms/base.py +45 -0
  37. litellm/llms/baseten.py +164 -0
  38. litellm/llms/bedrock.py +799 -0
  39. litellm/llms/cloudflare.py +176 -0
  40. litellm/llms/cohere.py +293 -0
  41. litellm/llms/custom_httpx/azure_dall_e_2.py +136 -0
  42. litellm/llms/custom_httpx/bedrock_async.py +0 -0
  43. litellm/llms/gemini.py +222 -0
  44. litellm/llms/huggingface_llms_metadata/hf_conversational_models.txt +2523 -0
  45. litellm/llms/huggingface_llms_metadata/hf_text_generation_models.txt +0 -0
  46. litellm/llms/huggingface_restapi.py +750 -0
  47. litellm/llms/maritalk.py +189 -0
  48. litellm/llms/nlp_cloud.py +243 -0
  49. litellm/llms/ollama.py +400 -0
  50. litellm/llms/ollama_chat.py +333 -0
litellm/__init__.py ADDED
@@ -0,0 +1,557 @@
1
+ ### INIT VARIABLES ###
2
+ import threading, requests
3
+ from typing import Callable, List, Optional, Dict, Union, Any
4
+ from litellm.caching import Cache
5
+ from litellm._logging import set_verbose
6
+ from litellm.proxy._types import KeyManagementSystem
7
+ import httpx
8
+
9
+ input_callback: List[Union[str, Callable]] = []
10
+ success_callback: List[Union[str, Callable]] = []
11
+ failure_callback: List[Union[str, Callable]] = []
12
+ callbacks: List[Callable] = []
13
+ _async_input_callback: List[
14
+ Callable
15
+ ] = [] # internal variable - async custom callbacks are routed here.
16
+ _async_success_callback: List[
17
+ Union[str, Callable]
18
+ ] = [] # internal variable - async custom callbacks are routed here.
19
+ _async_failure_callback: List[
20
+ Callable
21
+ ] = [] # internal variable - async custom callbacks are routed here.
22
+ pre_call_rules: List[Callable] = []
23
+ post_call_rules: List[Callable] = []
24
+ email: Optional[
25
+ str
26
+ ] = None # Not used anymore, will be removed in next MAJOR release - https://github.com/BerriAI/litellm/discussions/648
27
+ token: Optional[
28
+ str
29
+ ] = None # Not used anymore, will be removed in next MAJOR release - https://github.com/BerriAI/litellm/discussions/648
30
+ telemetry = True
31
+ max_tokens = 256 # OpenAI Defaults
32
+ drop_params = False
33
+ retry = True
34
+ api_key: Optional[str] = None
35
+ openai_key: Optional[str] = None
36
+ azure_key: Optional[str] = None
37
+ anthropic_key: Optional[str] = None
38
+ replicate_key: Optional[str] = None
39
+ cohere_key: Optional[str] = None
40
+ maritalk_key: Optional[str] = None
41
+ ai21_key: Optional[str] = None
42
+ openrouter_key: Optional[str] = None
43
+ huggingface_key: Optional[str] = None
44
+ vertex_project: Optional[str] = None
45
+ vertex_location: Optional[str] = None
46
+ togetherai_api_key: Optional[str] = None
47
+ cloudflare_api_key: Optional[str] = None
48
+ baseten_key: Optional[str] = None
49
+ aleph_alpha_key: Optional[str] = None
50
+ nlp_cloud_key: Optional[str] = None
51
+ use_client: bool = False
52
+ logging: bool = True
53
+ caching: bool = False # Not used anymore, will be removed in next MAJOR release - https://github.com/BerriAI/litellm/discussions/648
54
+ caching_with_models: bool = False # Not used anymore, will be removed in next MAJOR release - https://github.com/BerriAI/litellm/discussions/648
55
+ cache: Optional[
56
+ Cache
57
+ ] = None # cache object <- use this - https://docs.litellm.ai/docs/caching
58
+ model_alias_map: Dict[str, str] = {}
59
+ model_group_alias_map: Dict[str, str] = {}
60
+ max_budget: float = 0.0 # set the max budget across all providers
61
+ _openai_completion_params = [
62
+ "functions",
63
+ "function_call",
64
+ "temperature",
66
+ "top_p",
67
+ "n",
68
+ "stream",
69
+ "stop",
70
+ "max_tokens",
71
+ "presence_penalty",
72
+ "frequency_penalty",
73
+ "logit_bias",
74
+ "user",
75
+ "request_timeout",
76
+ "api_base",
77
+ "api_version",
78
+ "api_key",
79
+ "deployment_id",
80
+ "organization",
81
+ "base_url",
82
+ "default_headers",
83
+ "timeout",
84
+ "response_format",
85
+ "seed",
86
+ "tools",
87
+ "tool_choice",
88
+ "max_retries",
89
+ ]
90
+ _litellm_completion_params = [
91
+ "metadata",
92
+ "acompletion",
93
+ "caching",
94
+ "mock_response",
95
+ "api_key",
96
+ "api_version",
97
+ "api_base",
98
+ "force_timeout",
99
+ "logger_fn",
100
+ "verbose",
101
+ "custom_llm_provider",
102
+ "litellm_logging_obj",
103
+ "litellm_call_id",
104
+ "use_client",
105
+ "id",
106
+ "fallbacks",
107
+ "azure",
108
+ "headers",
109
+ "model_list",
110
+ "num_retries",
111
+ "context_window_fallback_dict",
112
+ "roles",
113
+ "final_prompt_value",
114
+ "bos_token",
115
+ "eos_token",
116
+ "request_timeout",
117
+ "complete_response",
118
+ "self",
119
+ "client",
120
+ "rpm",
121
+ "tpm",
122
+ "input_cost_per_token",
123
+ "output_cost_per_token",
124
+ "hf_model_name",
125
+ "model_info",
126
+ "proxy_server_request",
127
+ "preset_cache_key",
128
+ ]
129
+ _current_cost = 0 # private variable, used if max budget is set
130
+ error_logs: Dict = {}
131
+ add_function_to_prompt: bool = False # if function calling not supported by api, append function call details to system prompt
132
+ client_session: Optional[httpx.Client] = None
133
+ aclient_session: Optional[httpx.AsyncClient] = None
134
+ model_fallbacks: Optional[List] = None # Deprecated for 'litellm.fallbacks'
135
+ model_cost_map_url: str = "https://raw.githubusercontent.com/BerriAI/litellm/main/model_prices_and_context_window.json"
136
+ suppress_debug_info = False
137
+ dynamodb_table_name: Optional[str] = None
138
+ s3_callback_params: Optional[Dict] = None
139
+ #### RELIABILITY ####
140
+ request_timeout: Optional[float] = 6000
141
+ num_retries: Optional[int] = None # per model endpoint
142
+ fallbacks: Optional[List] = None
143
+ context_window_fallbacks: Optional[List] = None
144
+ allowed_fails: int = 0
145
+ num_retries_per_request: Optional[
146
+ int
147
+ ] = None # for the request overall (incl. fallbacks + model retries)
148
+ ####### SECRET MANAGERS #####################
149
+ secret_manager_client: Optional[
150
+ Any
151
+ ] = None # list of instantiated key management clients - e.g. azure kv, infisical, etc.
152
+ _google_kms_resource_name: Optional[str] = None
153
+ _key_management_system: Optional[KeyManagementSystem] = None
154
+ #############################################
155
+
156
+
157
+ def get_model_cost_map(url: str):
158
+ try:
159
+ with requests.get(
160
+ url, timeout=5
161
+ ) as response: # set a 5 second timeout for the get request
162
+ response.raise_for_status() # Raise an exception if the request is unsuccessful
163
+ content = response.json()
164
+ return content
165
+ except Exception as e:
166
+ import importlib.resources
167
+ import json
168
+
169
+ with importlib.resources.open_text(
170
+ "litellm", "model_prices_and_context_window_backup.json"
171
+ ) as f:
172
+ content = json.load(f)
173
+ return content
174
+
175
+
176
+ model_cost = get_model_cost_map(url=model_cost_map_url)
177
+ custom_prompt_dict: Dict[str, dict] = {}
178
+
179
+
180
+ ####### THREAD-SPECIFIC DATA ###################
181
+ class MyLocal(threading.local):
182
+ def __init__(self):
183
+ self.user = "Hello World"
184
+
185
+
186
+ _thread_context = MyLocal()
187
+
188
+
189
+ def identify(event_details):
190
+ # Store user in thread local data
191
+ if "user" in event_details:
192
+ _thread_context.user = event_details["user"]
193
+
194
+
195
+ ####### ADDITIONAL PARAMS ################### configurable params if you use proxy models like Helicone, map spend to org id, etc.
196
+ api_base = None
197
+ headers = None
198
+ api_version = None
199
+ organization = None
200
+ config_path = None
201
+ ####### COMPLETION MODELS ###################
202
+ open_ai_chat_completion_models: List = []
203
+ open_ai_text_completion_models: List = []
204
+ cohere_models: List = []
205
+ anthropic_models: List = []
206
+ openrouter_models: List = []
207
+ vertex_language_models: List = []
208
+ vertex_vision_models: List = []
209
+ vertex_chat_models: List = []
210
+ vertex_code_chat_models: List = []
211
+ vertex_text_models: List = []
212
+ vertex_code_text_models: List = []
213
+ ai21_models: List = []
214
+ nlp_cloud_models: List = []
215
+ aleph_alpha_models: List = []
216
+ bedrock_models: List = []
217
+ deepinfra_models: List = []
218
+ perplexity_models: List = []
219
+ for key, value in model_cost.items():
220
+ if value.get("litellm_provider") == "openai":
221
+ open_ai_chat_completion_models.append(key)
222
+ elif value.get("litellm_provider") == "text-completion-openai":
223
+ open_ai_text_completion_models.append(key)
224
+ elif value.get("litellm_provider") == "cohere":
225
+ cohere_models.append(key)
226
+ elif value.get("litellm_provider") == "anthropic":
227
+ anthropic_models.append(key)
228
+ elif value.get("litellm_provider") == "openrouter":
229
+ openrouter_models.append(key)
230
+ elif value.get("litellm_provider") == "vertex_ai-text-models":
231
+ vertex_text_models.append(key)
232
+ elif value.get("litellm_provider") == "vertex_ai-code-text-models":
233
+ vertex_code_text_models.append(key)
234
+ elif value.get("litellm_provider") == "vertex_ai-language-models":
235
+ vertex_language_models.append(key)
236
+ elif value.get("litellm_provider") == "vertex_ai-vision-models":
237
+ vertex_vision_models.append(key)
238
+ elif value.get("litellm_provider") == "vertex_ai-chat-models":
239
+ vertex_chat_models.append(key)
240
+ elif value.get("litellm_provider") == "vertex_ai-code-chat-models":
241
+ vertex_code_chat_models.append(key)
242
+ elif value.get("litellm_provider") == "ai21":
243
+ ai21_models.append(key)
244
+ elif value.get("litellm_provider") == "nlp_cloud":
245
+ nlp_cloud_models.append(key)
246
+ elif value.get("litellm_provider") == "aleph_alpha":
247
+ aleph_alpha_models.append(key)
248
+ elif value.get("litellm_provider") == "bedrock":
249
+ bedrock_models.append(key)
250
+ elif value.get("litellm_provider") == "deepinfra":
251
+ deepinfra_models.append(key)
252
+ elif value.get("litellm_provider") == "perplexity":
253
+ perplexity_models.append(key)
254
+
255
+ # known openai compatible endpoints - we'll eventually move this list to the model_prices_and_context_window.json dictionary
256
+ openai_compatible_endpoints: List = [
257
+ "api.perplexity.ai",
258
+ "api.endpoints.anyscale.com/v1",
259
+ "api.deepinfra.com/v1/openai",
260
+ "api.mistral.ai/v1",
261
+ ]
262
+
263
+ # this is maintained for Exception Mapping
264
+ openai_compatible_providers: List = [
265
+ "anyscale",
266
+ "mistral",
267
+ "deepinfra",
268
+ "perplexity",
269
+ "xinference",
270
+ ]
271
+
272
+
273
+ # well supported replicate llms
274
+ replicate_models: List = [
275
+ # llama replicate supported LLMs
276
+ "replicate/llama-2-70b-chat:2796ee9483c3fd7aa2e171d38f4ca12251a30609463dcfd4cd76703f22e96cdf",
277
+ "a16z-infra/llama-2-13b-chat:2a7f981751ec7fdf87b5b91ad4db53683a98082e9ff7bfd12c8cd5ea85980a52",
278
+ "meta/codellama-13b:1c914d844307b0588599b8393480a3ba917b660c7e9dfae681542b5325f228db",
279
+ # Vicuna
280
+ "replicate/vicuna-13b:6282abe6a492de4145d7bb601023762212f9ddbbe78278bd6771c8b3b2f2a13b",
281
+ "joehoover/instructblip-vicuna13b:c4c54e3c8c97cd50c2d2fec9be3b6065563ccf7d43787fb99f84151b867178fe",
282
+ # Flan T-5
283
+ "daanelson/flan-t5-large:ce962b3f6792a57074a601d3979db5839697add2e4e02696b3ced4c022d4767f",
284
+ # Others
285
+ "replicate/dolly-v2-12b:ef0e1aefc61f8e096ebe4db6b2bacc297daf2ef6899f0f7e001ec445893500e5",
286
+ "replit/replit-code-v1-3b:b84f4c074b807211cd75e3e8b1589b6399052125b4c27106e43d47189e8415ad",
287
+ ]
288
+
289
+ huggingface_models: List = [
290
+ "meta-llama/Llama-2-7b-hf",
291
+ "meta-llama/Llama-2-7b-chat-hf",
292
+ "meta-llama/Llama-2-13b-hf",
293
+ "meta-llama/Llama-2-13b-chat-hf",
294
+ "meta-llama/Llama-2-70b-hf",
295
+ "meta-llama/Llama-2-70b-chat-hf",
296
+ "meta-llama/Llama-2-7b",
297
+ "meta-llama/Llama-2-7b-chat",
298
+ "meta-llama/Llama-2-13b",
299
+ "meta-llama/Llama-2-13b-chat",
300
+ "meta-llama/Llama-2-70b",
301
+ "meta-llama/Llama-2-70b-chat",
302
+ ] # these have been tested on extensively. But by default all text2text-generation and text-generation models are supported by liteLLM. - https://docs.litellm.ai/docs/providers
303
+
304
+ together_ai_models: List = [
305
+ # llama llms - chat
306
+ "togethercomputer/llama-2-70b-chat",
307
+ # llama llms - language / instruct
308
+ "togethercomputer/llama-2-70b",
309
+ "togethercomputer/LLaMA-2-7B-32K",
310
+ "togethercomputer/Llama-2-7B-32K-Instruct",
311
+ "togethercomputer/llama-2-7b",
312
+ # falcon llms
313
+ "togethercomputer/falcon-40b-instruct",
314
+ "togethercomputer/falcon-7b-instruct",
315
+ # alpaca
316
+ "togethercomputer/alpaca-7b",
317
+ # chat llms
318
+ "HuggingFaceH4/starchat-alpha",
319
+ # code llms
320
+ "togethercomputer/CodeLlama-34b",
321
+ "togethercomputer/CodeLlama-34b-Instruct",
322
+ "togethercomputer/CodeLlama-34b-Python",
323
+ "defog/sqlcoder",
324
+ "NumbersStation/nsql-llama-2-7B",
325
+ "WizardLM/WizardCoder-15B-V1.0",
326
+ "WizardLM/WizardCoder-Python-34B-V1.0",
327
+ # language llms
328
+ "NousResearch/Nous-Hermes-Llama2-13b",
329
+ "Austism/chronos-hermes-13b",
330
+ "upstage/SOLAR-0-70b-16bit",
331
+ "WizardLM/WizardLM-70B-V1.0",
332
+ ] # supports all together ai models, just pass in the model id e.g. completion(model="together_computer/replit_code_3b",...)
333
+
334
+
335
+ baseten_models: List = [
336
+ "qvv0xeq",
337
+ "q841o8w",
338
+ "31dxrj3",
339
+ ] # FALCON 7B # WizardLM # Mosaic ML
340
+
341
+
342
+ # used for Cost Tracking & Token counting
343
+ # https://azure.microsoft.com/en-in/pricing/details/cognitive-services/openai-service/
344
+ # Azure returns gpt-35-turbo in their responses, we need to map this to azure/gpt-3.5-turbo for token counting
345
+ azure_llms = {
346
+ "gpt-35-turbo": "azure/gpt-35-turbo",
347
+ "gpt-35-turbo-16k": "azure/gpt-35-turbo-16k",
348
+ "gpt-35-turbo-instruct": "azure/gpt-35-turbo-instruct",
349
+ }
350
+
351
+ azure_embedding_models = {
352
+ "ada": "azure/ada",
353
+ }
354
+
355
+ petals_models = [
356
+ "petals-team/StableBeluga2",
357
+ ]
358
+
359
+ ollama_models = ["llama2"]
360
+
361
+ maritalk_models = ["maritalk"]
362
+
363
+ model_list = (
364
+ open_ai_chat_completion_models
365
+ + open_ai_text_completion_models
366
+ + cohere_models
367
+ + anthropic_models
368
+ + replicate_models
369
+ + openrouter_models
370
+ + huggingface_models
371
+ + vertex_chat_models
372
+ + vertex_text_models
373
+ + ai21_models
374
+ + together_ai_models
375
+ + baseten_models
376
+ + aleph_alpha_models
377
+ + nlp_cloud_models
378
+ + ollama_models
379
+ + bedrock_models
380
+ + deepinfra_models
381
+ + perplexity_models
382
+ + maritalk_models
383
+ )
384
+
385
+ provider_list: List = [
386
+ "openai",
387
+ "custom_openai",
388
+ "text-completion-openai",
389
+ "cohere",
390
+ "anthropic",
391
+ "replicate",
392
+ "huggingface",
393
+ "together_ai",
394
+ "openrouter",
395
+ "vertex_ai",
396
+ "palm",
397
+ "gemini",
398
+ "ai21",
399
+ "baseten",
400
+ "azure",
401
+ "sagemaker",
402
+ "bedrock",
403
+ "vllm",
404
+ "nlp_cloud",
405
+ "petals",
406
+ "oobabooga",
407
+ "ollama",
408
+ "ollama_chat",
409
+ "deepinfra",
410
+ "perplexity",
411
+ "anyscale",
412
+ "mistral",
413
+ "maritalk",
414
+ "voyage",
415
+ "cloudflare",
416
+ "xinference",
417
+ "custom", # custom apis
418
+ ]
419
+
420
+ models_by_provider: dict = {
421
+ "openai": open_ai_chat_completion_models + open_ai_text_completion_models,
422
+ "cohere": cohere_models,
423
+ "anthropic": anthropic_models,
424
+ "replicate": replicate_models,
425
+ "huggingface": huggingface_models,
426
+ "together_ai": together_ai_models,
427
+ "baseten": baseten_models,
428
+ "openrouter": openrouter_models,
429
+ "vertex_ai": vertex_chat_models + vertex_text_models,
430
+ "ai21": ai21_models,
431
+ "bedrock": bedrock_models,
432
+ "petals": petals_models,
433
+ "ollama": ollama_models,
434
+ "deepinfra": deepinfra_models,
435
+ "perplexity": perplexity_models,
436
+ "maritalk": maritalk_models,
437
+ }
438
+
439
+ # mapping for those models which have larger equivalents
440
+ longer_context_model_fallback_dict: dict = {
441
+ # openai chat completion models
442
+ "gpt-3.5-turbo": "gpt-3.5-turbo-16k",
443
+ "gpt-3.5-turbo-0301": "gpt-3.5-turbo-16k-0301",
444
+ "gpt-3.5-turbo-0613": "gpt-3.5-turbo-16k-0613",
445
+ "gpt-4": "gpt-4-32k",
446
+ "gpt-4-0314": "gpt-4-32k-0314",
447
+ "gpt-4-0613": "gpt-4-32k-0613",
448
+ # anthropic
449
+ "claude-instant-1": "claude-2",
450
+ "claude-instant-1.2": "claude-2",
451
+ # vertexai
452
+ "chat-bison": "chat-bison-32k",
453
+ "chat-bison@001": "chat-bison-32k",
454
+ "codechat-bison": "codechat-bison-32k",
455
+ "codechat-bison@001": "codechat-bison-32k",
456
+ # openrouter
457
+ "openrouter/openai/gpt-3.5-turbo": "openrouter/openai/gpt-3.5-turbo-16k",
458
+ "openrouter/anthropic/claude-instant-v1": "openrouter/anthropic/claude-2",
459
+ }
460
+
461
+ ####### EMBEDDING MODELS ###################
462
+ open_ai_embedding_models: List = ["text-embedding-ada-002"]
463
+ cohere_embedding_models: List = [
464
+ "embed-english-v3.0",
465
+ "embed-english-light-v3.0",
466
+ "embed-multilingual-v3.0",
467
+ "embed-english-v2.0",
468
+ "embed-english-light-v2.0",
469
+ "embed-multilingual-v2.0",
470
+ ]
471
+ bedrock_embedding_models: List = [
472
+ "amazon.titan-embed-text-v1",
473
+ "cohere.embed-english-v3",
474
+ "cohere.embed-multilingual-v3",
475
+ ]
476
+
477
+ all_embedding_models = (
478
+ open_ai_embedding_models + cohere_embedding_models + bedrock_embedding_models
479
+ )
480
+
481
+ ####### IMAGE GENERATION MODELS ###################
482
+ openai_image_generation_models = ["dall-e-2", "dall-e-3"]
483
+
484
+
485
+ from .timeout import timeout
486
+ from .utils import (
487
+ client,
488
+ exception_type,
489
+ get_optional_params,
490
+ modify_integration,
491
+ token_counter,
492
+ cost_per_token,
493
+ completion_cost,
494
+ get_litellm_params,
495
+ Logging,
496
+ acreate,
497
+ get_model_list,
498
+ get_max_tokens,
499
+ get_model_info,
500
+ register_prompt_template,
501
+ validate_environment,
502
+ check_valid_key,
503
+ get_llm_provider,
504
+ register_model,
505
+ encode,
506
+ decode,
507
+ _calculate_retry_after,
508
+ _should_retry,
509
+ get_secret,
510
+ )
511
+ from .llms.huggingface_restapi import HuggingfaceConfig
512
+ from .llms.anthropic import AnthropicConfig
513
+ from .llms.replicate import ReplicateConfig
514
+ from .llms.cohere import CohereConfig
515
+ from .llms.ai21 import AI21Config
516
+ from .llms.together_ai import TogetherAIConfig
517
+ from .llms.cloudflare import CloudflareConfig
518
+ from .llms.palm import PalmConfig
519
+ from .llms.gemini import GeminiConfig
520
+ from .llms.nlp_cloud import NLPCloudConfig
521
+ from .llms.aleph_alpha import AlephAlphaConfig
522
+ from .llms.petals import PetalsConfig
523
+ from .llms.vertex_ai import VertexAIConfig
524
+ from .llms.sagemaker import SagemakerConfig
525
+ from .llms.ollama import OllamaConfig
526
+ from .llms.maritalk import MaritTalkConfig
527
+ from .llms.bedrock import (
528
+ AmazonTitanConfig,
529
+ AmazonAI21Config,
530
+ AmazonAnthropicConfig,
531
+ AmazonCohereConfig,
532
+ AmazonLlamaConfig,
533
+ )
534
+ from .llms.openai import OpenAIConfig, OpenAITextCompletionConfig
535
+ from .llms.azure import AzureOpenAIConfig, AzureOpenAIError
536
+ from .main import * # type: ignore
537
+ from .integrations import *
538
+ from .exceptions import (
539
+ AuthenticationError,
540
+ InvalidRequestError,
541
+ BadRequestError,
542
+ NotFoundError,
543
+ RateLimitError,
544
+ ServiceUnavailableError,
545
+ OpenAIError,
546
+ ContextWindowExceededError,
547
+ ContentPolicyViolationError,
548
+ BudgetExceededError,
549
+ APIError,
550
+ Timeout,
551
+ APIConnectionError,
552
+ APIResponseValidationError,
553
+ UnprocessableEntityError,
554
+ )
555
+ from .budget_manager import BudgetManager
556
+ from .proxy.proxy_cli import run_server
557
+ from .router import Router
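
Everything above is module-level configuration that downstream code reads at call time. As a point of reference, here is a minimal sketch — not part of the commit — of how an application might set a few of these globals after importing the package; only attributes defined in this file are used, and the values are illustrative.

# Sketch only: adjust litellm's module-level settings defined in __init__.py above.
import litellm

litellm.set_verbose = False                      # logging flag re-exported from litellm._logging
litellm.drop_params = True                       # silently drop provider-unsupported params
litellm.request_timeout = 600                    # reliability setting (seconds)
litellm.model_alias_map = {"prod-model": "gpt-3.5-turbo"}

# model_list / models_by_provider are built above from the model-cost map:
print(len(litellm.model_list), "models known to litellm")
print(litellm.models_by_provider.get("cohere", [])[:3])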
litellm/_logging.py ADDED
@@ -0,0 +1,30 @@
+ import logging
+
+ set_verbose = False
+
+ # Create a handler for the logger (you may need to adapt this based on your needs)
+ handler = logging.StreamHandler()
+ handler.setLevel(logging.DEBUG)
+
+ # Create a formatter and set it for the handler
+
+ formatter = logging.Formatter("\033[92m%(name)s - %(levelname)s\033[0m: %(message)s")
+
+ handler.setFormatter(formatter)
+
+
+ def print_verbose(print_statement):
+     try:
+         if set_verbose:
+             print(print_statement)  # noqa
+     except:
+         pass
+
+
+ verbose_proxy_logger = logging.getLogger("LiteLLM Proxy")
+ verbose_router_logger = logging.getLogger("LiteLLM Router")
+ verbose_logger = logging.getLogger("LiteLLM")
+
+ # Add the handler to the logger
+ verbose_router_logger.addHandler(handler)
+ verbose_proxy_logger.addHandler(handler)
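
A short usage sketch for the loggers above (not part of the commit). Note that only the router and proxy loggers get the coloured StreamHandler in this file; verbose_logger is created without a handler, so the sketch attaches one explicitly.

# Sketch only: wire LiteLLM's named loggers into an application's logging setup.
import logging
from litellm import _logging

_logging.set_verbose = True                          # enables print_verbose() output
_logging.verbose_router_logger.setLevel(logging.INFO)
_logging.verbose_router_logger.info("router event")  # emitted via the handler added above

_logging.verbose_logger.addHandler(logging.StreamHandler())  # no handler attached by default
_logging.verbose_logger.setLevel(logging.DEBUG)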
litellm/_redis.py ADDED
@@ -0,0 +1,93 @@
+ # +-----------------------------------------------+
+ # |                                               |
+ # |           Give Feedback / Get Help            |
+ # | https://github.com/BerriAI/litellm/issues/new |
+ # |                                               |
+ # +-----------------------------------------------+
+ #
+ #  Thank you users! We ❤️ you! - Krrish & Ishaan
+
+ # s/o [@Frank Colson](https://www.linkedin.com/in/frank-colson-422b9b183/) for this redis implementation
+ import os
+ import inspect
+ import redis, litellm
+ from typing import List, Optional
+
+
+ def _get_redis_kwargs():
+     arg_spec = inspect.getfullargspec(redis.Redis)
+
+     # Only allow primitive arguments
+     exclude_args = {
+         "self",
+         "connection_pool",
+         "retry",
+     }
+
+     include_args = ["url"]
+
+     available_args = [x for x in arg_spec.args if x not in exclude_args] + include_args
+
+     return available_args
+
+
+ def _get_redis_env_kwarg_mapping():
+     PREFIX = "REDIS_"
+
+     return {f"{PREFIX}{x.upper()}": x for x in _get_redis_kwargs()}
+
+
+ def _redis_kwargs_from_environment():
+     mapping = _get_redis_env_kwarg_mapping()
+
+     return_dict = {}
+     for k, v in mapping.items():
+         value = litellm.get_secret(k, default_value=None)  # check os.environ/key vault
+         if value is not None:
+             return_dict[v] = value
+     return return_dict
+
+
+ def get_redis_url_from_environment():
+     if "REDIS_URL" in os.environ:
+         return os.environ["REDIS_URL"]
+
+     if "REDIS_HOST" not in os.environ or "REDIS_PORT" not in os.environ:
+         raise ValueError(
+             "Either 'REDIS_URL' or both 'REDIS_HOST' and 'REDIS_PORT' must be specified for Redis."
+         )
+
+     if "REDIS_PASSWORD" in os.environ:
+         redis_password = f":{os.environ['REDIS_PASSWORD']}@"
+     else:
+         redis_password = ""
+
+     return (
+         f"redis://{redis_password}{os.environ['REDIS_HOST']}:{os.environ['REDIS_PORT']}"
+     )
+
+
+ def get_redis_client(**env_overrides):
+     ### check if "os.environ/<key-name>" passed in
+     for k, v in env_overrides.items():
+         if isinstance(v, str) and v.startswith("os.environ/"):
+             v = v.replace("os.environ/", "")
+             value = litellm.get_secret(v)
+             env_overrides[k] = value
+
+     redis_kwargs = {
+         **_redis_kwargs_from_environment(),
+         **env_overrides,
+     }
+
+     if "url" in redis_kwargs and redis_kwargs["url"] is not None:
+         redis_kwargs.pop("host", None)
+         redis_kwargs.pop("port", None)
+         redis_kwargs.pop("db", None)
+         redis_kwargs.pop("password", None)
+
+         return redis.Redis.from_url(**redis_kwargs)
+     elif "host" not in redis_kwargs or redis_kwargs["host"] is None:
+         raise ValueError("Either 'host' or 'url' must be specified for redis.")
+     litellm.print_verbose(f"redis_kwargs: {redis_kwargs}")
+     return redis.Redis(**redis_kwargs)
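
For orientation, a hedged usage sketch (not in the commit): _get_redis_env_kwarg_mapping turns every primitive redis.Redis() keyword into a REDIS_* environment variable, so a client can be built from the environment alone. The host and port below are placeholders for a locally running Redis.

# Sketch only: assumes a reachable Redis server at the placeholder host/port.
import os
from litellm._redis import get_redis_client, get_redis_url_from_environment

os.environ["REDIS_HOST"] = "localhost"   # resolved to redis.Redis(host=...)
os.environ["REDIS_PORT"] = "6379"        # resolved to redis.Redis(port=...)

client = get_redis_client()              # kwargs picked up from the REDIS_* variables
client.set("litellm-smoke-test", "ok")

print(get_redis_url_from_environment())  # -> "redis://localhost:6379"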
litellm/_version.py ADDED
@@ -0,0 +1,6 @@
+ import importlib_metadata
+
+ try:
+     version = importlib_metadata.version("litellm")
+ except:
+     pass
litellm/budget_manager.py ADDED
@@ -0,0 +1,206 @@
+ import os, json, time
+ import litellm
+ from litellm.utils import ModelResponse
+ import requests, threading
+ from typing import Optional, Union, Literal
+
+
+ class BudgetManager:
+     def __init__(
+         self,
+         project_name: str,
+         client_type: str = "local",
+         api_base: Optional[str] = None,
+     ):
+         self.client_type = client_type
+         self.project_name = project_name
+         self.api_base = api_base or "https://api.litellm.ai"
+         ## load the data or init the initial dictionaries
+         self.load_data()
+
+     def print_verbose(self, print_statement):
+         try:
+             if litellm.set_verbose:
+                 import logging
+
+                 logging.info(print_statement)
+         except:
+             pass
+
+     def load_data(self):
+         if self.client_type == "local":
+             # Check if user dict file exists
+             if os.path.isfile("user_cost.json"):
+                 # Load the user dict
+                 with open("user_cost.json", "r") as json_file:
+                     self.user_dict = json.load(json_file)
+             else:
+                 self.print_verbose("User Dictionary not found!")
+                 self.user_dict = {}
+             self.print_verbose(f"user dict from local: {self.user_dict}")
+         elif self.client_type == "hosted":
+             # Load the user_dict from hosted db
+             url = self.api_base + "/get_budget"
+             headers = {"Content-Type": "application/json"}
+             data = {"project_name": self.project_name}
+             response = requests.post(url, headers=headers, json=data)
+             response = response.json()
+             if response["status"] == "error":
+                 self.user_dict = (
+                     {}
+                 )  # assume this means the user dict hasn't been stored yet
+             else:
+                 self.user_dict = response["data"]
+
+     def create_budget(
+         self,
+         total_budget: float,
+         user: str,
+         duration: Optional[Literal["daily", "weekly", "monthly", "yearly"]] = None,
+         created_at: float = time.time(),
+     ):
+         self.user_dict[user] = {"total_budget": total_budget}
+         if duration is None:
+             return self.user_dict[user]
+
+         if duration == "daily":
+             duration_in_days = 1
+         elif duration == "weekly":
+             duration_in_days = 7
+         elif duration == "monthly":
+             duration_in_days = 28
+         elif duration == "yearly":
+             duration_in_days = 365
+         else:
+             raise ValueError(
+                 """duration needs to be one of ["daily", "weekly", "monthly", "yearly"]"""
+             )
+         self.user_dict[user] = {
+             "total_budget": total_budget,
+             "duration": duration_in_days,
+             "created_at": created_at,
+             "last_updated_at": created_at,
+         }
+         self._save_data_thread()  # [Non-Blocking] Update persistent storage without blocking execution
+         return self.user_dict[user]
+
+     def projected_cost(self, model: str, messages: list, user: str):
+         text = "".join(message["content"] for message in messages)
+         prompt_tokens = litellm.token_counter(model=model, text=text)
+         prompt_cost, _ = litellm.cost_per_token(
+             model=model, prompt_tokens=prompt_tokens, completion_tokens=0
+         )
+         current_cost = self.user_dict[user].get("current_cost", 0)
+         projected_cost = prompt_cost + current_cost
+         return projected_cost
+
+     def get_total_budget(self, user: str):
+         return self.user_dict[user]["total_budget"]
+
+     def update_cost(
+         self,
+         user: str,
+         completion_obj: Optional[ModelResponse] = None,
+         model: Optional[str] = None,
+         input_text: Optional[str] = None,
+         output_text: Optional[str] = None,
+     ):
+         if model and input_text and output_text:
+             prompt_tokens = litellm.token_counter(
+                 model=model, messages=[{"role": "user", "content": input_text}]
+             )
+             completion_tokens = litellm.token_counter(
+                 model=model, messages=[{"role": "user", "content": output_text}]
+             )
+             (
+                 prompt_tokens_cost_usd_dollar,
+                 completion_tokens_cost_usd_dollar,
+             ) = litellm.cost_per_token(
+                 model=model,
+                 prompt_tokens=prompt_tokens,
+                 completion_tokens=completion_tokens,
+             )
+             cost = prompt_tokens_cost_usd_dollar + completion_tokens_cost_usd_dollar
+         elif completion_obj:
+             cost = litellm.completion_cost(completion_response=completion_obj)
+             model = completion_obj[
+                 "model"
+             ]  # if this throws an error try, model = completion_obj['model']
+         else:
+             raise ValueError(
+                 "Either a chat completion object or the text response needs to be passed in. Learn more - https://docs.litellm.ai/docs/budget_manager"
+             )
+
+         self.user_dict[user]["current_cost"] = cost + self.user_dict[user].get(
+             "current_cost", 0
+         )
+         if "model_cost" in self.user_dict[user]:
+             self.user_dict[user]["model_cost"][model] = cost + self.user_dict[user][
+                 "model_cost"
+             ].get(model, 0)
+         else:
+             self.user_dict[user]["model_cost"] = {model: cost}
+
+         self._save_data_thread()  # [Non-Blocking] Update persistent storage without blocking execution
+         return {"user": self.user_dict[user]}
+
+     def get_current_cost(self, user):
+         return self.user_dict[user].get("current_cost", 0)
+
+     def get_model_cost(self, user):
+         return self.user_dict[user].get("model_cost", 0)
+
+     def is_valid_user(self, user: str) -> bool:
+         return user in self.user_dict
+
+     def get_users(self):
+         return list(self.user_dict.keys())
+
+     def reset_cost(self, user):
+         self.user_dict[user]["current_cost"] = 0
+         self.user_dict[user]["model_cost"] = {}
+         return {"user": self.user_dict[user]}
+
+     def reset_on_duration(self, user: str):
+         # Get current and creation time
+         last_updated_at = self.user_dict[user]["last_updated_at"]
+         current_time = time.time()
+
+         # Convert duration from days to seconds
+         duration_in_seconds = self.user_dict[user]["duration"] * 24 * 60 * 60
+
+         # Check if duration has elapsed
+         if current_time - last_updated_at >= duration_in_seconds:
+             # Reset cost if duration has elapsed and update the creation time
+             self.reset_cost(user)
+             self.user_dict[user]["last_updated_at"] = current_time
+             self._save_data_thread()  # Save the data
+
+     def update_budget_all_users(self):
+         for user in self.get_users():
+             if "duration" in self.user_dict[user]:
+                 self.reset_on_duration(user)
+
+     def _save_data_thread(self):
+         thread = threading.Thread(
+             target=self.save_data
+         )  # [Non-Blocking]: saves data without blocking execution
+         thread.start()
+
+     def save_data(self):
+         if self.client_type == "local":
+             import json
+
+             # save the user dict
+             with open("user_cost.json", "w") as json_file:
+                 json.dump(
+                     self.user_dict, json_file, indent=4
+                 )  # Indent for pretty formatting
+             return {"status": "success"}
+         elif self.client_type == "hosted":
+             url = self.api_base + "/set_budget"
+             headers = {"Content-Type": "application/json"}
+             data = {"project_name": self.project_name, "user_dict": self.user_dict}
+             response = requests.post(url, headers=headers, json=data)
+             response = response.json()
+             return response
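
A minimal end-to-end sketch of the class above (not part of the commit). The model name and budget values are placeholders, and litellm.completion() is assumed from the package's public API; with client_type="local" the manager persists spend to user_cost.json in the working directory.

# Sketch only: track per-user spend locally and gate calls on the remaining budget.
import litellm
from litellm import BudgetManager

budget_manager = BudgetManager(project_name="demo_project")
budget_manager.create_budget(total_budget=10.0, user="user-123", duration="monthly")

if budget_manager.get_current_cost(user="user-123") < budget_manager.get_total_budget("user-123"):
    response = litellm.completion(
        model="gpt-3.5-turbo",
        messages=[{"role": "user", "content": "hello"}],
    )
    budget_manager.update_cost(completion_obj=response, user="user-123")

print(budget_manager.get_model_cost(user="user-123"))   # per-model spend for this user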
litellm/caching.py ADDED
@@ -0,0 +1,678 @@
1
+ # +-----------------------------------------------+
2
+ # | |
3
+ # | Give Feedback / Get Help |
4
+ # | https://github.com/BerriAI/litellm/issues/new |
5
+ # | |
6
+ # +-----------------------------------------------+
7
+ #
8
+ # Thank you users! We ❤️ you! - Krrish & Ishaan
9
+
10
+ import litellm
11
+ import time, logging
12
+ import json, traceback, ast, hashlib
13
+ from typing import Optional, Literal, List, Union, Any
14
+ from openai._models import BaseModel as OpenAIObject
15
+
16
+
17
+ def print_verbose(print_statement):
18
+ try:
19
+ if litellm.set_verbose:
20
+ print(print_statement) # noqa
21
+ except:
22
+ pass
23
+
24
+
25
+ class BaseCache:
26
+ def set_cache(self, key, value, **kwargs):
27
+ raise NotImplementedError
28
+
29
+ def get_cache(self, key, **kwargs):
30
+ raise NotImplementedError
31
+
32
+
33
+ class InMemoryCache(BaseCache):
34
+ def __init__(self):
35
+ # if users don't provider one, use the default litellm cache
36
+ self.cache_dict = {}
37
+ self.ttl_dict = {}
38
+
39
+ def set_cache(self, key, value, **kwargs):
40
+ self.cache_dict[key] = value
41
+ if "ttl" in kwargs:
42
+ self.ttl_dict[key] = time.time() + kwargs["ttl"]
43
+
44
+ def get_cache(self, key, **kwargs):
45
+ if key in self.cache_dict:
46
+ if key in self.ttl_dict:
47
+ if time.time() > self.ttl_dict[key]:
48
+ self.cache_dict.pop(key, None)
49
+ return None
50
+ original_cached_response = self.cache_dict[key]
51
+ try:
52
+ cached_response = json.loads(original_cached_response)
53
+ except:
54
+ cached_response = original_cached_response
55
+ return cached_response
56
+ return None
57
+
58
+ def flush_cache(self):
59
+ self.cache_dict.clear()
60
+ self.ttl_dict.clear()
61
+
62
+
63
+ class RedisCache(BaseCache):
64
+ def __init__(self, host=None, port=None, password=None, **kwargs):
65
+ import redis
66
+
67
+ # if users don't provider one, use the default litellm cache
68
+ from ._redis import get_redis_client
69
+
70
+ redis_kwargs = {}
71
+ if host is not None:
72
+ redis_kwargs["host"] = host
73
+ if port is not None:
74
+ redis_kwargs["port"] = port
75
+ if password is not None:
76
+ redis_kwargs["password"] = password
77
+
78
+ redis_kwargs.update(kwargs)
79
+
80
+ self.redis_client = get_redis_client(**redis_kwargs)
81
+
82
+ def set_cache(self, key, value, **kwargs):
83
+ ttl = kwargs.get("ttl", None)
84
+ print_verbose(f"Set Redis Cache: key: {key}\nValue {value}")
85
+ try:
86
+ self.redis_client.set(name=key, value=str(value), ex=ttl)
87
+ except Exception as e:
88
+ # NON blocking - notify users Redis is throwing an exception
89
+ logging.debug("LiteLLM Caching: set() - Got exception from REDIS : ", e)
90
+
91
+ def get_cache(self, key, **kwargs):
92
+ try:
93
+ print_verbose(f"Get Redis Cache: key: {key}")
94
+ cached_response = self.redis_client.get(key)
95
+ print_verbose(
96
+ f"Got Redis Cache: key: {key}, cached_response {cached_response}"
97
+ )
98
+ if cached_response != None:
99
+ # cached_response comes back from Redis as bytes - decode and parse it into a dict/ModelResponse
100
+ cached_response = cached_response.decode(
101
+ "utf-8"
102
+ ) # Convert bytes to string
103
+ try:
104
+ cached_response = json.loads(
105
+ cached_response
106
+ ) # Convert string to dictionary
107
+ except:
108
+ cached_response = ast.literal_eval(cached_response)
109
+ return cached_response
110
+ except Exception as e:
111
+ # NON blocking - notify users Redis is throwing an exception
112
+ traceback.print_exc()
113
+ logging.debug("LiteLLM Caching: get() - Got exception from REDIS: ", e)
114
+
115
+ def flush_cache(self):
116
+ self.redis_client.flushall()
117
+
118
+
119
+ class S3Cache(BaseCache):
120
+ def __init__(
121
+ self,
122
+ s3_bucket_name,
123
+ s3_region_name=None,
124
+ s3_api_version=None,
125
+ s3_use_ssl=True,
126
+ s3_verify=None,
127
+ s3_endpoint_url=None,
128
+ s3_aws_access_key_id=None,
129
+ s3_aws_secret_access_key=None,
130
+ s3_aws_session_token=None,
131
+ s3_config=None,
132
+ **kwargs,
133
+ ):
134
+ import boto3
135
+
136
+ self.bucket_name = s3_bucket_name
137
+ # Create an S3 client with custom endpoint URL
138
+ self.s3_client = boto3.client(
139
+ "s3",
140
+ region_name=s3_region_name,
141
+ endpoint_url=s3_endpoint_url,
142
+ api_version=s3_api_version,
143
+ use_ssl=s3_use_ssl,
144
+ verify=s3_verify,
145
+ aws_access_key_id=s3_aws_access_key_id,
146
+ aws_secret_access_key=s3_aws_secret_access_key,
147
+ aws_session_token=s3_aws_session_token,
148
+ config=s3_config,
149
+ **kwargs,
150
+ )
151
+
152
+ def set_cache(self, key, value, **kwargs):
153
+ try:
154
+ print_verbose(f"LiteLLM SET Cache - S3. Key={key}. Value={value}")
155
+ ttl = kwargs.get("ttl", None)
156
+ # Convert value to JSON before storing in S3
157
+ serialized_value = json.dumps(value)
158
+ if ttl is not None:
159
+ cache_control = f"immutable, max-age={ttl}, s-maxage={ttl}"
160
+ import datetime
161
+
162
+ # Calculate expiration time
163
+ expiration_time = datetime.datetime.now() + ttl
164
+
165
+ # Upload the data to S3 with the calculated expiration time
166
+ self.s3_client.put_object(
167
+ Bucket=self.bucket_name,
168
+ Key=key,
169
+ Body=serialized_value,
170
+ Expires=expiration_time,
171
+ CacheControl=cache_control,
172
+ ContentType="application/json",
173
+ ContentLanguage="en",
174
+ ContentDisposition=f"inline; filename=\"{key}.json\""
175
+ )
176
+ else:
177
+ cache_control = "immutable, max-age=31536000, s-maxage=31536000"
178
+ # Upload the data to S3 without specifying Expires
179
+ self.s3_client.put_object(
180
+ Bucket=self.bucket_name,
181
+ Key=key,
182
+ Body=serialized_value,
183
+ CacheControl=cache_control,
184
+ ContentType="application/json",
185
+ ContentLanguage="en",
186
+ ContentDisposition=f"inline; filename=\"{key}.json\""
187
+ )
188
+ except Exception as e:
189
+ # NON blocking - notify users S3 is throwing an exception
190
+ print_verbose(f"S3 Caching: set_cache() - Got exception from S3: {e}")
191
+
192
+ def get_cache(self, key, **kwargs):
193
+ import boto3, botocore
194
+
195
+ try:
196
+ print_verbose(f"Get S3 Cache: key: {key}")
197
+ # Download the data from S3
198
+ cached_response = self.s3_client.get_object(
199
+ Bucket=self.bucket_name, Key=key
200
+ )
201
+
202
+ if cached_response != None:
203
+ # the S3 object body is bytes - decode and parse it into a dict/ModelResponse
204
+ cached_response = (
205
+ cached_response["Body"].read().decode("utf-8")
206
+ ) # Convert bytes to string
207
+ try:
208
+ cached_response = json.loads(
209
+ cached_response
210
+ ) # Convert string to dictionary
211
+ except Exception as e:
212
+ cached_response = ast.literal_eval(cached_response)
213
+ if type(cached_response) is not dict:
214
+ cached_response = dict(cached_response)
215
+ print_verbose(
216
+ f"Got S3 Cache: key: {key}, cached_response {cached_response}. Type Response {type(cached_response)}"
217
+ )
218
+
219
+ return cached_response
220
+ except botocore.exceptions.ClientError as e:
221
+ if e.response["Error"]["Code"] == "NoSuchKey":
222
+ print_verbose(
223
+ f"S3 Cache: The specified key '{key}' does not exist in the S3 bucket."
224
+ )
225
+ return None
226
+
227
+ except Exception as e:
228
+ # NON blocking - notify users S3 is throwing an exception
229
+ traceback.print_exc()
230
+ print_verbose(f"S3 Caching: get_cache() - Got exception from S3: {e}")
231
+
232
+ def flush_cache(self):
233
+ pass
234
+
235
+
236
+ class DualCache(BaseCache):
237
+ """
238
+ This updates both Redis and an in-memory cache simultaneously.
239
+ When data is updated or inserted, it is written to both the in-memory cache + Redis.
240
+ This ensures that even if Redis hasn't been updated yet, the in-memory cache reflects the most recent data.
241
+ """
242
+
243
+ def __init__(
244
+ self,
245
+ in_memory_cache: Optional[InMemoryCache] = None,
246
+ redis_cache: Optional[RedisCache] = None,
247
+ ) -> None:
248
+ super().__init__()
249
+ # If in_memory_cache is not provided, use the default InMemoryCache
250
+ self.in_memory_cache = in_memory_cache or InMemoryCache()
251
+ # If redis_cache is not provided, use the default RedisCache
252
+ self.redis_cache = redis_cache
253
+
254
+ def set_cache(self, key, value, local_only: bool = False, **kwargs):
255
+ # Update both Redis and in-memory cache
256
+ try:
257
+ print_verbose(f"set cache: key: {key}; value: {value}")
258
+ if self.in_memory_cache is not None:
259
+ self.in_memory_cache.set_cache(key, value, **kwargs)
260
+
261
+ if self.redis_cache is not None and local_only == False:
262
+ self.redis_cache.set_cache(key, value, **kwargs)
263
+ except Exception as e:
264
+ print_verbose(e)
265
+
266
+ def get_cache(self, key, local_only: bool = False, **kwargs):
267
+ # Try to fetch from in-memory cache first
268
+ try:
269
+ print_verbose(f"get cache: cache key: {key}; local_only: {local_only}")
270
+ result = None
271
+ if self.in_memory_cache is not None:
272
+ in_memory_result = self.in_memory_cache.get_cache(key, **kwargs)
273
+
274
+ print_verbose(f"in_memory_result: {in_memory_result}")
275
+ if in_memory_result is not None:
276
+ result = in_memory_result
277
+
278
+ if result is None and self.redis_cache is not None and local_only == False:
279
+ # If not found in in-memory cache, try fetching from Redis
280
+ redis_result = self.redis_cache.get_cache(key, **kwargs)
281
+
282
+ if redis_result is not None:
283
+ # Update in-memory cache with the value from Redis
284
+ self.in_memory_cache.set_cache(key, redis_result, **kwargs)
285
+
286
+ result = redis_result
287
+
288
+ print_verbose(f"get cache: cache result: {result}")
289
+ return result
290
+ except Exception as e:
291
+ traceback.print_exc()
292
+
293
+ def flush_cache(self):
294
+ if self.in_memory_cache is not None:
295
+ self.in_memory_cache.flush_cache()
296
+ if self.redis_cache is not None:
297
+ self.redis_cache.flush_cache()
298
+
299
+
300
+ #### LiteLLM.Completion / Embedding Cache ####
301
+ class Cache:
302
+ def __init__(
303
+ self,
304
+ type: Optional[Literal["local", "redis", "s3"]] = "local",
305
+ host: Optional[str] = None,
306
+ port: Optional[str] = None,
307
+ password: Optional[str] = None,
308
+ supported_call_types: Optional[
309
+ List[Literal["completion", "acompletion", "embedding", "aembedding"]]
310
+ ] = ["completion", "acompletion", "embedding", "aembedding"],
311
+ # s3 Bucket, boto3 configuration
312
+ s3_bucket_name: Optional[str] = None,
313
+ s3_region_name: Optional[str] = None,
314
+ s3_api_version: Optional[str] = None,
315
+ s3_use_ssl: Optional[bool] = True,
316
+ s3_verify: Optional[Union[bool, str]] = None,
317
+ s3_endpoint_url: Optional[str] = None,
318
+ s3_aws_access_key_id: Optional[str] = None,
319
+ s3_aws_secret_access_key: Optional[str] = None,
320
+ s3_aws_session_token: Optional[str] = None,
321
+ s3_config: Optional[Any] = None,
322
+ **kwargs,
323
+ ):
324
+ """
325
+ Initializes the cache based on the given type.
326
+
327
+ Args:
328
+ type (str, optional): The type of cache to initialize. Can be "local" or "redis". Defaults to "local".
329
+ host (str, optional): The host address for the Redis cache. Required if type is "redis".
330
+ port (int, optional): The port number for the Redis cache. Required if type is "redis".
331
+ password (str, optional): The password for the Redis cache. Required if type is "redis".
332
+ supported_call_types (list, optional): List of call types to cache for. Defaults to cache == on for all call types.
333
+ **kwargs: Additional keyword arguments for redis.Redis() cache
334
+
335
+ Raises:
336
+ ValueError: If an invalid cache type is provided.
337
+
338
+ Returns:
339
+ None. Cache is set as a litellm param
340
+ """
341
+ if type == "redis":
342
+ self.cache: BaseCache = RedisCache(host, port, password, **kwargs)
343
+ if type == "local":
344
+ self.cache = InMemoryCache()
345
+ if type == "s3":
346
+ self.cache = S3Cache(
347
+ s3_bucket_name=s3_bucket_name,
348
+ s3_region_name=s3_region_name,
349
+ s3_api_version=s3_api_version,
350
+ s3_use_ssl=s3_use_ssl,
351
+ s3_verify=s3_verify,
352
+ s3_endpoint_url=s3_endpoint_url,
353
+ s3_aws_access_key_id=s3_aws_access_key_id,
354
+ s3_aws_secret_access_key=s3_aws_secret_access_key,
355
+ s3_aws_session_token=s3_aws_session_token,
356
+ s3_config=s3_config,
357
+ **kwargs,
358
+ )
359
+ if "cache" not in litellm.input_callback:
360
+ litellm.input_callback.append("cache")
361
+ if "cache" not in litellm.success_callback:
362
+ litellm.success_callback.append("cache")
363
+ if "cache" not in litellm._async_success_callback:
364
+ litellm._async_success_callback.append("cache")
365
+ self.supported_call_types = supported_call_types # default to ["completion", "acompletion", "embedding", "aembedding"]
366
+ self.type = type
367
+
368
+ def get_cache_key(self, *args, **kwargs):
369
+ """
370
+ Get the cache key for the given arguments.
371
+
372
+ Args:
373
+ *args: args to litellm.completion() or embedding()
374
+ **kwargs: kwargs to litellm.completion() or embedding()
375
+
376
+ Returns:
377
+ str: The cache key generated from the arguments, or None if no cache key could be generated.
378
+ """
379
+ cache_key = ""
380
+ print_verbose(f"\nGetting Cache key. Kwargs: {kwargs}")
381
+
382
+ # for streaming, we use preset_cache_key. It's created in wrapper(), we do this because optional params like max_tokens, get transformed for bedrock -> max_new_tokens
383
+ if kwargs.get("litellm_params", {}).get("preset_cache_key", None) is not None:
384
+ print_verbose(f"\nReturning preset cache key: {cache_key}")
385
+ return kwargs.get("litellm_params", {}).get("preset_cache_key", None)
386
+
387
+ # sort kwargs by keys, since model: [gpt-4, temperature: 0.2, max_tokens: 200] == [temperature: 0.2, max_tokens: 200, model: gpt-4]
388
+ completion_kwargs = [
389
+ "model",
390
+ "messages",
391
+ "temperature",
392
+ "top_p",
393
+ "n",
394
+ "stop",
395
+ "max_tokens",
396
+ "presence_penalty",
397
+ "frequency_penalty",
398
+ "logit_bias",
399
+ "user",
400
+ "response_format",
401
+ "seed",
402
+ "tools",
403
+ "tool_choice",
404
+ ]
405
+ embedding_only_kwargs = [
406
+ "input",
407
+ "encoding_format",
408
+ ] # embedding kwargs = model, input, user, encoding_format. Model, user are checked in completion_kwargs
409
+
410
+ # combined_kwargs - NEEDS to be ordered across get_cache_key(). Do not use a set()
411
+ combined_kwargs = completion_kwargs + embedding_only_kwargs
412
+ for param in combined_kwargs:
413
+ # ignore litellm params here
414
+ if param in kwargs:
415
+ # check if param == model and model_group is passed in, then override model with model_group
416
+ if param == "model":
417
+ model_group = None
418
+ caching_group = None
419
+ metadata = kwargs.get("metadata", None)
420
+ litellm_params = kwargs.get("litellm_params", {})
421
+ if metadata is not None:
423
+ model_group = metadata.get("model_group", None)
424
+ caching_groups = metadata.get("caching_groups", None)
425
+ if caching_groups:
426
+ for group in caching_groups:
427
+ if model_group in group:
428
+ caching_group = group
429
+ break
430
+ if litellm_params is not None:
431
+ metadata = litellm_params.get("metadata", None)
432
+ if metadata is not None:
433
+ model_group = metadata.get("model_group", None)
434
+ caching_groups = metadata.get("caching_groups", None)
435
+ if caching_groups:
436
+ for group in caching_groups:
437
+ if model_group in group:
438
+ caching_group = group
439
+ break
440
+ param_value = (
441
+ caching_group or model_group or kwargs[param]
442
+ ) # use caching_group, if set then model_group if it exists, else use kwargs["model"]
443
+ else:
444
+ if kwargs[param] is None:
445
+ continue # ignore None params
446
+ param_value = kwargs[param]
447
+ cache_key += f"{str(param)}: {str(param_value)}"
448
+ print_verbose(f"\nCreated cache key: {cache_key}")
449
+ # Use hashlib to create a sha256 hash of the cache key
450
+ hash_object = hashlib.sha256(cache_key.encode())
451
+ # Hexadecimal representation of the hash
452
+ hash_hex = hash_object.hexdigest()
453
+ print_verbose(f"Hashed cache key (SHA-256): {hash_hex}")
454
+ return hash_hex
455
+
456
+ def generate_streaming_content(self, content):
457
+ chunk_size = 5 # Adjust the chunk size as needed
458
+ for i in range(0, len(content), chunk_size):
459
+ yield {
460
+ "choices": [
461
+ {
462
+ "delta": {
463
+ "role": "assistant",
464
+ "content": content[i : i + chunk_size],
465
+ }
466
+ }
467
+ ]
468
+ }
469
+ time.sleep(0.02)
470
+
471
+ def get_cache(self, *args, **kwargs):
472
+ """
473
+ Retrieves the cached result for the given arguments.
474
+
475
+ Args:
476
+ *args: args to litellm.completion() or embedding()
477
+ **kwargs: kwargs to litellm.completion() or embedding()
478
+
479
+ Returns:
480
+ The cached result if it exists, otherwise None.
481
+ """
482
+ try: # never block execution
483
+ if "cache_key" in kwargs:
484
+ cache_key = kwargs["cache_key"]
485
+ else:
486
+ cache_key = self.get_cache_key(*args, **kwargs)
487
+ if cache_key is not None:
488
+ cache_control_args = kwargs.get("cache", {})
489
+ max_age = cache_control_args.get(
490
+ "s-max-age", cache_control_args.get("s-maxage", float("inf"))
491
+ )
492
+ cached_result = self.cache.get_cache(cache_key)
493
+ # Check if a timestamp was stored with the cached response
494
+ if (
495
+ cached_result is not None
496
+ and isinstance(cached_result, dict)
497
+ and "timestamp" in cached_result
498
+ and max_age is not None
499
+ ):
500
+ timestamp = cached_result["timestamp"]
501
+ current_time = time.time()
502
+
503
+ # Calculate age of the cached response
504
+ response_age = current_time - timestamp
505
+
506
+ # Check if the cached response is older than the max-age
507
+ if response_age > max_age:
508
+ print_verbose(
509
+ f"Cached response for key {cache_key} is too old. Max-age: {max_age}s, Age: {response_age}s"
510
+ )
511
+ return None # Cached response is too old
512
+
513
+ # If the response is fresh, or there's no max-age requirement, return the cached response
514
+ # pull the stored response out of the cached entry and parse it back into a dict if it was cached as a JSON string
515
+ cached_response = cached_result.get("response")
516
+ try:
517
+ if isinstance(cached_response, dict):
518
+ pass
519
+ else:
520
+ cached_response = json.loads(
521
+ cached_response
522
+ ) # Convert string to dictionary
523
+ except:
524
+ cached_response = ast.literal_eval(cached_response)
525
+ return cached_response
526
+ return cached_result
527
+ except Exception as e:
528
+ print_verbose(f"An exception occurred: {traceback.format_exc()}")
529
+ return None
530
+
531
+ def add_cache(self, result, *args, **kwargs):
532
+ """
533
+ Adds a result to the cache.
534
+
535
+ Args:
536
+ *args: args to litellm.completion() or embedding()
537
+ **kwargs: kwargs to litellm.completion() or embedding()
538
+
539
+ Returns:
540
+ None
541
+ """
542
+ try:
543
+ if "cache_key" in kwargs:
544
+ cache_key = kwargs["cache_key"]
545
+ else:
546
+ cache_key = self.get_cache_key(*args, **kwargs)
547
+ if cache_key is not None:
548
+ if isinstance(result, OpenAIObject):
549
+ result = result.model_dump_json()
550
+
551
+ ## Get Cache-Controls ##
552
+ if kwargs.get("cache", None) is not None and isinstance(
553
+ kwargs.get("cache"), dict
554
+ ):
555
+ for k, v in kwargs.get("cache").items():
556
+ if k == "ttl":
557
+ kwargs["ttl"] = v
558
+ cached_data = {"timestamp": time.time(), "response": result}
559
+ self.cache.set_cache(cache_key, cached_data, **kwargs)
560
+ except Exception as e:
561
+ print_verbose(f"LiteLLM Cache: Exception add_cache: {str(e)}")
562
+ traceback.print_exc()
563
+ pass
564
+
565
+ async def _async_add_cache(self, result, *args, **kwargs):
566
+ self.add_cache(result, *args, **kwargs)
567
+
568
+
569
+ def enable_cache(
570
+ type: Optional[Literal["local", "redis", "s3"]] = "local",
571
+ host: Optional[str] = None,
572
+ port: Optional[str] = None,
573
+ password: Optional[str] = None,
574
+ supported_call_types: Optional[
575
+ List[Literal["completion", "acompletion", "embedding", "aembedding"]]
576
+ ] = ["completion", "acompletion", "embedding", "aembedding"],
577
+ **kwargs,
578
+ ):
579
+ """
580
+ Enable cache with the specified configuration.
581
+
582
+ Args:
583
+ type (Optional[Literal["local", "redis"]]): The type of cache to enable. Defaults to "local".
584
+ host (Optional[str]): The host address of the cache server. Defaults to None.
585
+ port (Optional[str]): The port number of the cache server. Defaults to None.
586
+ password (Optional[str]): The password for the cache server. Defaults to None.
587
+ supported_call_types (Optional[List[Literal["completion", "acompletion", "embedding", "aembedding"]]]):
588
+ The supported call types for the cache. Defaults to ["completion", "acompletion", "embedding", "aembedding"].
589
+ **kwargs: Additional keyword arguments.
590
+
591
+ Returns:
592
+ None
593
+
594
+ Raises:
595
+ None
596
+ """
597
+ print_verbose("LiteLLM: Enabling Cache")
598
+ if "cache" not in litellm.input_callback:
599
+ litellm.input_callback.append("cache")
600
+ if "cache" not in litellm.success_callback:
601
+ litellm.success_callback.append("cache")
602
+ if "cache" not in litellm._async_success_callback:
603
+ litellm._async_success_callback.append("cache")
604
+
605
+ if litellm.cache == None:
606
+ litellm.cache = Cache(
607
+ type=type,
608
+ host=host,
609
+ port=port,
610
+ password=password,
611
+ supported_call_types=supported_call_types,
612
+ **kwargs,
613
+ )
614
+ print_verbose(f"LiteLLM: Cache enabled, litellm.cache={litellm.cache}")
615
+ print_verbose(f"LiteLLM Cache: {vars(litellm.cache)}")
616
+
617
+
618
+ def update_cache(
619
+ type: Optional[Literal["local", "redis"]] = "local",
620
+ host: Optional[str] = None,
621
+ port: Optional[str] = None,
622
+ password: Optional[str] = None,
623
+ supported_call_types: Optional[
624
+ List[Literal["completion", "acompletion", "embedding", "aembedding"]]
625
+ ] = ["completion", "acompletion", "embedding", "aembedding"],
626
+ **kwargs,
627
+ ):
628
+ """
629
+ Update the cache for LiteLLM.
630
+
631
+ Args:
632
+ type (Optional[Literal["local", "redis"]]): The type of cache. Defaults to "local".
633
+ host (Optional[str]): The host of the cache. Defaults to None.
634
+ port (Optional[str]): The port of the cache. Defaults to None.
635
+ password (Optional[str]): The password for the cache. Defaults to None.
636
+ supported_call_types (Optional[List[Literal["completion", "acompletion", "embedding", "aembedding"]]]):
637
+ The supported call types for the cache. Defaults to ["completion", "acompletion", "embedding", "aembedding"].
638
+ **kwargs: Additional keyword arguments for the cache.
639
+
640
+ Returns:
641
+ None
642
+
643
+ """
644
+ print_verbose("LiteLLM: Updating Cache")
645
+ litellm.cache = Cache(
646
+ type=type,
647
+ host=host,
648
+ port=port,
649
+ password=password,
650
+ supported_call_types=supported_call_types,
651
+ **kwargs,
652
+ )
653
+ print_verbose(f"LiteLLM: Cache Updated, litellm.cache={litellm.cache}")
654
+ print_verbose(f"LiteLLM Cache: {vars(litellm.cache)}")
655
+
656
+
657
+ def disable_cache():
658
+ """
659
+ Disable the cache used by LiteLLM.
660
+
661
+ This function disables the cache used by the LiteLLM module. It removes the cache-related callbacks from the input_callback, success_callback, and _async_success_callback lists. It also sets the litellm.cache attribute to None.
662
+
663
+ Parameters:
664
+ None
665
+
666
+ Returns:
667
+ None
668
+ """
669
+ from contextlib import suppress
670
+
671
+ print_verbose("LiteLLM: Disabling Cache")
672
+ with suppress(ValueError):
673
+ litellm.input_callback.remove("cache")
674
+ litellm.success_callback.remove("cache")
675
+ litellm._async_success_callback.remove("cache")
676
+
677
+ litellm.cache = None
678
+ print_verbose(f"LiteLLM: Cache disabled, litellm.cache={litellm.cache}")
litellm/cost.json ADDED
@@ -0,0 +1,5 @@
+ {
+     "gpt-3.5-turbo-0613": 0.00015000000000000001,
+     "claude-2": 0.00016454,
+     "gpt-4-0613": 0.015408
+ }
litellm/deprecated_litellm_server/.env.template ADDED
@@ -0,0 +1,43 @@
+ # # set AUTH STRATEGY FOR LLM APIs - Defaults to using Environment Variables
+ # AUTH_STRATEGY = "ENV" # ENV or DYNAMIC, ENV always reads from environment variables, DYNAMIC reads request headers to set LLM api keys
+
+ # OPENAI_API_KEY = ""
+
+ # HUGGINGFACE_API_KEY=""
+
+ # TOGETHERAI_API_KEY=""
+
+ # REPLICATE_API_KEY=""
+
+ # ## bedrock / sagemaker
+ # AWS_ACCESS_KEY_ID = ""
+ # AWS_SECRET_ACCESS_KEY = ""
+
+ # AZURE_API_KEY = ""
+ # AZURE_API_BASE = ""
+ # AZURE_API_VERSION = ""
+
+ # ANTHROPIC_API_KEY = ""
+
+ # COHERE_API_KEY = ""
+
+ # ## CONFIG FILE ##
+ # # CONFIG_FILE_PATH = "" # uncomment to point to config file
+
+ # ## LOGGING ##
+
+ # SET_VERBOSE = "False" # set to 'True' to see detailed input/output logs
+
+ # ### LANGFUSE
+ # LANGFUSE_PUBLIC_KEY = ""
+ # LANGFUSE_SECRET_KEY = ""
+ # # Optional, defaults to https://cloud.langfuse.com
+ # LANGFUSE_HOST = "" # optional
+
+
+ # ## CACHING ##
+
+ # ### REDIS
+ # REDIS_HOST = ""
+ # REDIS_PORT = ""
+ # REDIS_PASSWORD = ""
litellm/deprecated_litellm_server/Dockerfile ADDED
@@ -0,0 +1,10 @@
+ # FROM python:3.10
+
+ # ENV LITELLM_CONFIG_PATH="/litellm.secrets.toml"
+ # COPY . /app
+ # WORKDIR /app
+ # RUN pip install -r requirements.txt
+
+ # EXPOSE $PORT
+
+ # CMD exec uvicorn main:app --host 0.0.0.0 --port $PORT --workers 10
litellm/deprecated_litellm_server/README.md ADDED
@@ -0,0 +1,3 @@
+ # litellm-server [experimental]
+
+ Deprecated. See litellm/proxy
litellm/deprecated_litellm_server/__init__.py ADDED
@@ -0,0 +1,2 @@
+ # from .main import *
+ # from .server_utils import *
litellm/deprecated_litellm_server/main.py ADDED
@@ -0,0 +1,193 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # import os, traceback
2
+ # from fastapi import FastAPI, Request, HTTPException
3
+ # from fastapi.routing import APIRouter
4
+ # from fastapi.responses import StreamingResponse, FileResponse
5
+ # from fastapi.middleware.cors import CORSMiddleware
6
+ # import json, sys
7
+ # from typing import Optional
8
+ # sys.path.insert(
9
+ # 0, os.path.abspath("../")
10
+ # ) # Adds the parent directory to the system path - for litellm local dev
11
+ # import litellm
12
+
13
+ # try:
14
+ # from litellm.deprecated_litellm_server.server_utils import set_callbacks, load_router_config, print_verbose
15
+ # except ImportError:
16
+ # from litellm.deprecated_litellm_server.server_utils import set_callbacks, load_router_config, print_verbose
17
+ # import dotenv
18
+ # dotenv.load_dotenv() # load env variables
19
+
20
+ # app = FastAPI(docs_url="/", title="LiteLLM API")
21
+ # router = APIRouter()
22
+ # origins = ["*"]
23
+
24
+ # app.add_middleware(
25
+ # CORSMiddleware,
26
+ # allow_origins=origins,
27
+ # allow_credentials=True,
28
+ # allow_methods=["*"],
29
+ # allow_headers=["*"],
30
+ # )
31
+ # #### GLOBAL VARIABLES ####
32
+ # llm_router: Optional[litellm.Router] = None
33
+ # llm_model_list: Optional[list] = None
34
+ # server_settings: Optional[dict] = None
35
+
36
+ # set_callbacks() # sets litellm callbacks for logging if they exist in the environment
37
+
38
+ # if "CONFIG_FILE_PATH" in os.environ:
39
+ # llm_router, llm_model_list, server_settings = load_router_config(router=llm_router, config_file_path=os.getenv("CONFIG_FILE_PATH"))
40
+ # else:
41
+ # llm_router, llm_model_list, server_settings = load_router_config(router=llm_router)
42
+ # #### API ENDPOINTS ####
43
+ # @router.get("/v1/models")
44
+ # @router.get("/models") # if project requires model list
45
+ # def model_list():
46
+ # all_models = litellm.utils.get_valid_models()
47
+ # if llm_model_list:
48
+ # all_models += llm_model_list
49
+ # return dict(
50
+ # data=[
51
+ # {
52
+ # "id": model,
53
+ # "object": "model",
54
+ # "created": 1677610602,
55
+ # "owned_by": "openai",
56
+ # }
57
+ # for model in all_models
58
+ # ],
59
+ # object="list",
60
+ # )
61
+ # # for streaming
62
+ # def data_generator(response):
63
+
64
+ # for chunk in response:
65
+
66
+ # yield f"data: {json.dumps(chunk)}\n\n"
67
+
68
+ # @router.post("/v1/completions")
69
+ # @router.post("/completions")
70
+ # async def completion(request: Request):
71
+ # data = await request.json()
72
+ # response = litellm.completion(
73
+ # **data
74
+ # )
75
+ # if 'stream' in data and data['stream'] == True: # use generate_responses to stream responses
76
+ # return StreamingResponse(data_generator(response), media_type='text/event-stream')
77
+ # return response
78
+
79
+ # @router.post("/v1/embeddings")
80
+ # @router.post("/embeddings")
81
+ # async def embedding(request: Request):
82
+ # try:
83
+ # data = await request.json()
84
+ # # default to always using the "ENV" variables, only if AUTH_STRATEGY==DYNAMIC then reads headers
85
+ # if os.getenv("AUTH_STRATEGY", None) == "DYNAMIC" and "authorization" in request.headers: # if users pass LLM api keys as part of header
86
+ # api_key = request.headers.get("authorization")
87
+ # api_key = api_key.replace("Bearer", "").strip() # type: ignore
88
+ # if len(api_key.strip()) > 0:
89
+ # api_key = api_key
90
+ # data["api_key"] = api_key
91
+ # response = litellm.embedding(
92
+ # **data
93
+ # )
94
+ # return response
95
+ # except Exception as e:
96
+ # error_traceback = traceback.format_exc()
97
+ # error_msg = f"{str(e)}\n\n{error_traceback}"
98
+ # return {"error": error_msg}
99
+
100
+ # @router.post("/v1/chat/completions")
101
+ # @router.post("/chat/completions")
102
+ # @router.post("/openai/deployments/{model:path}/chat/completions") # azure compatible endpoint
103
+ # async def chat_completion(request: Request, model: Optional[str] = None):
104
+ # global llm_model_list, server_settings
105
+ # try:
106
+ # data = await request.json()
107
+ # server_model = server_settings.get("completion_model", None) if server_settings else None
108
+ # data["model"] = server_model or model or data["model"]
109
+ # ## CHECK KEYS ##
110
+ # # default to always using the "ENV" variables, only if AUTH_STRATEGY==DYNAMIC then reads headers
111
+ # # env_validation = litellm.validate_environment(model=data["model"])
112
+ # # if (env_validation['keys_in_environment'] is False or os.getenv("AUTH_STRATEGY", None) == "DYNAMIC") and ("authorization" in request.headers or "api-key" in request.headers): # if users pass LLM api keys as part of header
113
+ # # if "authorization" in request.headers:
114
+ # # api_key = request.headers.get("authorization")
115
+ # # elif "api-key" in request.headers:
116
+ # # api_key = request.headers.get("api-key")
117
+ # # print(f"api_key in headers: {api_key}")
118
+ # # if " " in api_key:
119
+ # # api_key = api_key.split(" ")[1]
120
+ # # print(f"api_key split: {api_key}")
121
+ # # if len(api_key) > 0:
122
+ # # api_key = api_key
123
+ # # data["api_key"] = api_key
124
+ # # print(f"api_key in data: {api_key}")
125
+ # ## CHECK CONFIG ##
126
+ # if llm_model_list and data["model"] in [m["model_name"] for m in llm_model_list]:
127
+ # for m in llm_model_list:
128
+ # if data["model"] == m["model_name"]:
129
+ # for key, value in m["litellm_params"].items():
130
+ # data[key] = value
131
+ # break
132
+ # response = litellm.completion(
133
+ # **data
134
+ # )
135
+ # if 'stream' in data and data['stream'] == True: # use generate_responses to stream responses
136
+ # return StreamingResponse(data_generator(response), media_type='text/event-stream')
137
+ # return response
138
+ # except Exception as e:
139
+ # error_traceback = traceback.format_exc()
140
+
141
+ # error_msg = f"{str(e)}\n\n{error_traceback}"
142
+ # # return {"error": error_msg}
143
+ # raise HTTPException(status_code=500, detail=error_msg)
144
+
145
+ # @router.post("/router/completions")
146
+ # async def router_completion(request: Request):
147
+ # global llm_router
148
+ # try:
149
+ # data = await request.json()
150
+ # if "model_list" in data:
151
+ # llm_router = litellm.Router(model_list=data.pop("model_list"))
152
+ # if llm_router is None:
153
+ # raise Exception("Save model list via config.yaml. Eg.: ` docker build -t myapp --build-arg CONFIG_FILE=myconfig.yaml .` or pass it in as model_list=[..] as part of the request body")
154
+
155
+ # # openai.ChatCompletion.create replacement
156
+ # response = await llm_router.acompletion(model="gpt-3.5-turbo",
157
+ # messages=[{"role": "user", "content": "Hey, how's it going?"}])
158
+
159
+ # if 'stream' in data and data['stream'] == True: # use generate_responses to stream responses
160
+ # return StreamingResponse(data_generator(response), media_type='text/event-stream')
161
+ # return response
162
+ # except Exception as e:
163
+ # error_traceback = traceback.format_exc()
164
+ # error_msg = f"{str(e)}\n\n{error_traceback}"
165
+ # return {"error": error_msg}
166
+
167
+ # @router.post("/router/embedding")
168
+ # async def router_embedding(request: Request):
169
+ # global llm_router
170
+ # try:
171
+ # data = await request.json()
172
+ # if "model_list" in data:
173
+ # llm_router = litellm.Router(model_list=data.pop("model_list"))
174
+ # if llm_router is None:
175
+ # raise Exception("Save model list via config.yaml. Eg.: ` docker build -t myapp --build-arg CONFIG_FILE=myconfig.yaml .` or pass it in as model_list=[..] as part of the request body")
176
+
177
+ # response = await llm_router.aembedding(model="gpt-3.5-turbo", # type: ignore
178
+ # messages=[{"role": "user", "content": "Hey, how's it going?"}])
179
+
180
+ # if 'stream' in data and data['stream'] == True: # use generate_responses to stream responses
181
+ # return StreamingResponse(data_generator(response), media_type='text/event-stream')
182
+ # return response
183
+ # except Exception as e:
184
+ # error_traceback = traceback.format_exc()
185
+ # error_msg = f"{str(e)}\n\n{error_traceback}"
186
+ # return {"error": error_msg}
187
+
188
+ # @router.get("/")
189
+ # async def home(request: Request):
190
+ # return "LiteLLM: RUNNING"
191
+
192
+
193
+ # app.include_router(router)
litellm/deprecated_litellm_server/requirements.txt ADDED
@@ -0,0 +1,7 @@
+ # openai
+ # fastapi
+ # uvicorn
+ # boto3
+ # litellm
+ # python-dotenv
+ # redis
litellm/deprecated_litellm_server/server_utils.py ADDED
@@ -0,0 +1,85 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # import os, litellm
2
+ # import pkg_resources
3
+ # import dotenv
4
+ # dotenv.load_dotenv() # load env variables
5
+
6
+ # def print_verbose(print_statement):
7
+ # pass
8
+
9
+ # def get_package_version(package_name):
10
+ # try:
11
+ # package = pkg_resources.get_distribution(package_name)
12
+ # return package.version
13
+ # except pkg_resources.DistributionNotFound:
14
+ # return None
15
+
16
+ # # Usage example
17
+ # package_name = "litellm"
18
+ # version = get_package_version(package_name)
19
+ # if version:
20
+ # print_verbose(f"The version of {package_name} is {version}")
21
+ # else:
22
+ # print_verbose(f"{package_name} is not installed")
23
+ # import yaml
24
+ # import dotenv
25
+ # from typing import Optional
26
+ # dotenv.load_dotenv() # load env variables
27
+
28
+ # def set_callbacks():
29
+ # ## LOGGING
30
+ # if len(os.getenv("SET_VERBOSE", "")) > 0:
31
+ # if os.getenv("SET_VERBOSE") == "True":
32
+ # litellm.set_verbose = True
33
+ # print_verbose("\033[92mLiteLLM: Switched on verbose logging\033[0m")
34
+ # else:
35
+ # litellm.set_verbose = False
36
+
37
+ # ### LANGFUSE
38
+ # if (len(os.getenv("LANGFUSE_PUBLIC_KEY", "")) > 0 and len(os.getenv("LANGFUSE_SECRET_KEY", ""))) > 0 or len(os.getenv("LANGFUSE_HOST", "")) > 0:
39
+ # litellm.success_callback = ["langfuse"]
40
+ # print_verbose("\033[92mLiteLLM: Switched on Langfuse feature\033[0m")
41
+
42
+ # ## CACHING
43
+ # ### REDIS
44
+ # # if len(os.getenv("REDIS_HOST", "")) > 0 and len(os.getenv("REDIS_PORT", "")) > 0 and len(os.getenv("REDIS_PASSWORD", "")) > 0:
45
+ # # print(f"redis host: {os.getenv('REDIS_HOST')}; redis port: {os.getenv('REDIS_PORT')}; password: {os.getenv('REDIS_PASSWORD')}")
46
+ # # from litellm.caching import Cache
47
+ # # litellm.cache = Cache(type="redis", host=os.getenv("REDIS_HOST"), port=os.getenv("REDIS_PORT"), password=os.getenv("REDIS_PASSWORD"))
48
+ # # print("\033[92mLiteLLM: Switched on Redis caching\033[0m")
49
+
50
+
51
+ # def load_router_config(router: Optional[litellm.Router], config_file_path: Optional[str]='/app/config.yaml'):
52
+ # config = {}
53
+ # server_settings = {}
54
+ # try:
55
+ # if os.path.exists(config_file_path): # type: ignore
56
+ # with open(config_file_path, 'r') as file: # type: ignore
57
+ # config = yaml.safe_load(file)
58
+ # else:
59
+ # pass
60
+ # except:
61
+ # pass
62
+
63
+ # ## SERVER SETTINGS (e.g. default completion model = 'ollama/mistral')
64
+ # server_settings = config.get("server_settings", None)
65
+ # if server_settings:
66
+ # server_settings = server_settings
67
+
68
+ # ## LITELLM MODULE SETTINGS (e.g. litellm.drop_params=True,..)
69
+ # litellm_settings = config.get('litellm_settings', None)
70
+ # if litellm_settings:
71
+ # for key, value in litellm_settings.items():
72
+ # setattr(litellm, key, value)
73
+
74
+ # ## MODEL LIST
75
+ # model_list = config.get('model_list', None)
76
+ # if model_list:
77
+ # router = litellm.Router(model_list=model_list)
78
+
79
+ # ## ENVIRONMENT VARIABLES
80
+ # environment_variables = config.get('environment_variables', None)
81
+ # if environment_variables:
82
+ # for key, value in environment_variables.items():
83
+ # os.environ[key] = value
84
+
85
+ # return router, model_list, server_settings
litellm/exceptions.py ADDED
@@ -0,0 +1,200 @@
+ # +-----------------------------------------------+
+ # |                                               |
+ # |           Give Feedback / Get Help            |
+ # | https://github.com/BerriAI/litellm/issues/new |
+ # |                                               |
+ # +-----------------------------------------------+
+ #
+ #  Thank you users! We ❤️ you! - Krrish & Ishaan
+
+ ## LiteLLM versions of the OpenAI Exception Types
+
+ from openai import (
+     AuthenticationError,
+     BadRequestError,
+     NotFoundError,
+     RateLimitError,
+     APIStatusError,
+     OpenAIError,
+     APIError,
+     APITimeoutError,
+     APIConnectionError,
+     APIResponseValidationError,
+     UnprocessableEntityError,
+ )
+ import httpx
+
+
+ class AuthenticationError(AuthenticationError):  # type: ignore
+     def __init__(self, message, llm_provider, model, response: httpx.Response):
+         self.status_code = 401
+         self.message = message
+         self.llm_provider = llm_provider
+         self.model = model
+         super().__init__(
+             self.message, response=response, body=None
+         )  # Call the base class constructor with the parameters it needs
+
+
+ # raise when invalid models passed, example gpt-8
+ class NotFoundError(NotFoundError):  # type: ignore
+     def __init__(self, message, model, llm_provider, response: httpx.Response):
+         self.status_code = 404
+         self.message = message
+         self.model = model
+         self.llm_provider = llm_provider
+         super().__init__(
+             self.message, response=response, body=None
+         )  # Call the base class constructor with the parameters it needs
+
+
+ class BadRequestError(BadRequestError):  # type: ignore
+     def __init__(self, message, model, llm_provider, response: httpx.Response):
+         self.status_code = 400
+         self.message = message
+         self.model = model
+         self.llm_provider = llm_provider
+         super().__init__(
+             self.message, response=response, body=None
+         )  # Call the base class constructor with the parameters it needs
+
+
+ class UnprocessableEntityError(UnprocessableEntityError):  # type: ignore
+     def __init__(self, message, model, llm_provider, response: httpx.Response):
+         self.status_code = 422
+         self.message = message
+         self.model = model
+         self.llm_provider = llm_provider
+         super().__init__(
+             self.message, response=response, body=None
+         )  # Call the base class constructor with the parameters it needs
+
+
+ class Timeout(APITimeoutError):  # type: ignore
+     def __init__(self, message, model, llm_provider):
+         self.status_code = 408
+         self.message = message
+         self.model = model
+         self.llm_provider = llm_provider
+         request = httpx.Request(method="POST", url="https://api.openai.com/v1")
+         super().__init__(
+             request=request
+         )  # Call the base class constructor with the parameters it needs
+
+
+ class RateLimitError(RateLimitError):  # type: ignore
+     def __init__(self, message, llm_provider, model, response: httpx.Response):
+         self.status_code = 429
+         self.message = message
+         self.llm_provider = llm_provider
+         self.model = model
+         super().__init__(
+             self.message, response=response, body=None
+         )  # Call the base class constructor with the parameters it needs
+
+
+ # sub class of bad request error - meant to give more granularity for error handling context window exceeded errors
+ class ContextWindowExceededError(BadRequestError):  # type: ignore
+     def __init__(self, message, model, llm_provider, response: httpx.Response):
+         self.status_code = 400
+         self.message = message
+         self.model = model
+         self.llm_provider = llm_provider
+         super().__init__(
+             message=self.message,
+             model=self.model,  # type: ignore
+             llm_provider=self.llm_provider,  # type: ignore
+             response=response,
+         )  # Call the base class constructor with the parameters it needs
+
+
+ class ContentPolicyViolationError(BadRequestError):  # type: ignore
+     # Error code: 400 - {'error': {'code': 'content_policy_violation', 'message': 'Your request was rejected as a result of our safety system. Image descriptions generated from your prompt may contain text that is not allowed by our safety system. If you believe this was done in error, your request may succeed if retried, or by adjusting your prompt.', 'param': None, 'type': 'invalid_request_error'}}
+     def __init__(self, message, model, llm_provider, response: httpx.Response):
+         self.status_code = 400
+         self.message = message
+         self.model = model
+         self.llm_provider = llm_provider
+         super().__init__(
+             message=self.message,
+             model=self.model,  # type: ignore
+             llm_provider=self.llm_provider,  # type: ignore
+             response=response,
+         )  # Call the base class constructor with the parameters it needs
+
+
+ class ServiceUnavailableError(APIStatusError):  # type: ignore
+     def __init__(self, message, llm_provider, model, response: httpx.Response):
+         self.status_code = 503
+         self.message = message
+         self.llm_provider = llm_provider
+         self.model = model
+         super().__init__(
+             self.message, response=response, body=None
+         )  # Call the base class constructor with the parameters it needs
+
+
+ # raise this when the API returns an invalid response object - https://github.com/openai/openai-python/blob/1be14ee34a0f8e42d3f9aa5451aa4cb161f1781f/openai/api_requestor.py#L401
+ class APIError(APIError):  # type: ignore
+     def __init__(
+         self, status_code, message, llm_provider, model, request: httpx.Request
+     ):
+         self.status_code = status_code
+         self.message = message
+         self.llm_provider = llm_provider
+         self.model = model
+         super().__init__(self.message, request=request, body=None)  # type: ignore
+
+
+ # raised if an invalid request (not get, delete, put, post) is made
+ class APIConnectionError(APIConnectionError):  # type: ignore
+     def __init__(self, message, llm_provider, model, request: httpx.Request):
+         self.message = message
+         self.llm_provider = llm_provider
+         self.model = model
+         self.status_code = 500
+         super().__init__(message=self.message, request=request)
+
+
+ # raised if the API response cannot be validated / parsed into the expected response object
+ class APIResponseValidationError(APIResponseValidationError):  # type: ignore
+     def __init__(self, message, llm_provider, model):
+         self.message = message
+         self.llm_provider = llm_provider
+         self.model = model
+         request = httpx.Request(method="POST", url="https://api.openai.com/v1")
+         response = httpx.Response(status_code=500, request=request)
+         super().__init__(response=response, body=None, message=message)
+
+
+ class OpenAIError(OpenAIError):  # type: ignore
+     def __init__(self, original_exception):
+         self.status_code = original_exception.http_status
+         super().__init__(
+             http_body=original_exception.http_body,
+             http_status=original_exception.http_status,
+             json_body=original_exception.json_body,
+             headers=original_exception.headers,
+             code=original_exception.code,
+         )
+         self.llm_provider = "openai"
+
+
+ class BudgetExceededError(Exception):
+     def __init__(self, current_cost, max_budget):
+         self.current_cost = current_cost
+         self.max_budget = max_budget
+         message = f"Budget has been exceeded! Current cost: {current_cost}, Max budget: {max_budget}"
+         super().__init__(message)
+
+
+ ## DEPRECATED ##
+ class InvalidRequestError(BadRequestError):  # type: ignore
+     def __init__(self, message, model, llm_provider):
+         self.status_code = 400
+         self.message = message
+         self.model = model
+         self.llm_provider = llm_provider
+         super().__init__(
+             self.message, f"{self.model}"
+         )  # Call the base class constructor with the parameters it needs
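Because these classes subclass the OpenAI exception types, callers can handle provider-agnostic failures while still reading the extra llm_provider, model and status_code fields. A hedged sketch (model name and handling logic are illustrative):

    import litellm
    from litellm.exceptions import (
        ContextWindowExceededError,
        RateLimitError,
        APIConnectionError,
    )

    try:
        litellm.completion(
            model="gpt-3.5-turbo",
            messages=[{"role": "user", "content": "hello"}],
        )
    except ContextWindowExceededError as e:
        print(f"{e.llm_provider}/{e.model}: prompt too long ({e.status_code}) - truncate and retry")
    except RateLimitError as e:
        print(f"rate limited by {e.llm_provider} ({e.status_code}) - back off and retry")
    except APIConnectionError as e:
        print(f"connection problem: {e.message}")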
litellm/integrations/__init__.py ADDED
@@ -0,0 +1 @@
+ from . import *
litellm/integrations/aispend.py ADDED
@@ -0,0 +1,177 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #### What this does ####
2
+ # On success + failure, log events to aispend.io
3
+ import dotenv, os
4
+ import requests
5
+
6
+ dotenv.load_dotenv() # Loading env variables using dotenv
7
+ import traceback
8
+ import datetime
9
+
10
+ model_cost = {
11
+ "gpt-3.5-turbo": {
12
+ "max_tokens": 4000,
13
+ "input_cost_per_token": 0.0000015,
14
+ "output_cost_per_token": 0.000002,
15
+ },
16
+ "gpt-35-turbo": {
17
+ "max_tokens": 4000,
18
+ "input_cost_per_token": 0.0000015,
19
+ "output_cost_per_token": 0.000002,
20
+ }, # azure model name
21
+ "gpt-3.5-turbo-0613": {
22
+ "max_tokens": 4000,
23
+ "input_cost_per_token": 0.0000015,
24
+ "output_cost_per_token": 0.000002,
25
+ },
26
+ "gpt-3.5-turbo-0301": {
27
+ "max_tokens": 4000,
28
+ "input_cost_per_token": 0.0000015,
29
+ "output_cost_per_token": 0.000002,
30
+ },
31
+ "gpt-3.5-turbo-16k": {
32
+ "max_tokens": 16000,
33
+ "input_cost_per_token": 0.000003,
34
+ "output_cost_per_token": 0.000004,
35
+ },
36
+ "gpt-35-turbo-16k": {
37
+ "max_tokens": 16000,
38
+ "input_cost_per_token": 0.000003,
39
+ "output_cost_per_token": 0.000004,
40
+ }, # azure model name
41
+ "gpt-3.5-turbo-16k-0613": {
42
+ "max_tokens": 16000,
43
+ "input_cost_per_token": 0.000003,
44
+ "output_cost_per_token": 0.000004,
45
+ },
46
+ "gpt-4": {
47
+ "max_tokens": 8000,
48
+ "input_cost_per_token": 0.000003,
49
+ "output_cost_per_token": 0.00006,
50
+ },
51
+ "gpt-4-0613": {
52
+ "max_tokens": 8000,
53
+ "input_cost_per_token": 0.000003,
54
+ "output_cost_per_token": 0.00006,
55
+ },
56
+ "gpt-4-32k": {
57
+ "max_tokens": 8000,
58
+ "input_cost_per_token": 0.00006,
59
+ "output_cost_per_token": 0.00012,
60
+ },
61
+ "claude-instant-1": {
62
+ "max_tokens": 100000,
63
+ "input_cost_per_token": 0.00000163,
64
+ "output_cost_per_token": 0.00000551,
65
+ },
66
+ "claude-2": {
67
+ "max_tokens": 100000,
68
+ "input_cost_per_token": 0.00001102,
69
+ "output_cost_per_token": 0.00003268,
70
+ },
71
+ "text-bison-001": {
72
+ "max_tokens": 8192,
73
+ "input_cost_per_token": 0.000004,
74
+ "output_cost_per_token": 0.000004,
75
+ },
76
+ "chat-bison-001": {
77
+ "max_tokens": 4096,
78
+ "input_cost_per_token": 0.000002,
79
+ "output_cost_per_token": 0.000002,
80
+ },
81
+ "command-nightly": {
82
+ "max_tokens": 4096,
83
+ "input_cost_per_token": 0.000015,
84
+ "output_cost_per_token": 0.000015,
85
+ },
86
+ }
87
+
88
+
89
+ class AISpendLogger:
90
+ # Class variables or attributes
91
+ def __init__(self):
92
+ # Instance variables
93
+ self.account_id = os.getenv("AISPEND_ACCOUNT_ID")
94
+ self.api_key = os.getenv("AISPEND_API_KEY")
95
+
96
+ def price_calculator(self, model, response_obj, start_time, end_time):
97
+ # try and find if the model is in the model_cost map
98
+ # else default to the average of the costs
99
+ prompt_tokens_cost_usd_dollar = 0
100
+ completion_tokens_cost_usd_dollar = 0
101
+ if model in model_cost:
102
+ prompt_tokens_cost_usd_dollar = (
103
+ model_cost[model]["input_cost_per_token"]
104
+ * response_obj["usage"]["prompt_tokens"]
105
+ )
106
+ completion_tokens_cost_usd_dollar = (
107
+ model_cost[model]["output_cost_per_token"]
108
+ * response_obj["usage"]["completion_tokens"]
109
+ )
110
+ elif "replicate" in model:
111
+ # replicate models are charged based on time
112
+ # llama 2 runs on an nvidia a100 which costs $0.0032 per second - https://replicate.com/replicate/llama-2-70b-chat
113
+ model_run_time = end_time - start_time # assuming time in seconds
114
+ cost_usd_dollar = model_run_time * 0.0032
115
+ prompt_tokens_cost_usd_dollar = cost_usd_dollar / 2
116
+ completion_tokens_cost_usd_dollar = cost_usd_dollar / 2
117
+ else:
118
+ # calculate average input cost
119
+ input_cost_sum = 0
120
+ output_cost_sum = 0
121
+ for model in model_cost:
122
+ input_cost_sum += model_cost[model]["input_cost_per_token"]
123
+ output_cost_sum += model_cost[model]["output_cost_per_token"]
124
+ avg_input_cost = input_cost_sum / len(model_cost.keys())
125
+ avg_output_cost = output_cost_sum / len(model_cost.keys())
126
+ prompt_tokens_cost_usd_dollar = (
127
+ model_cost[model]["input_cost_per_token"]
128
+ * response_obj["usage"]["prompt_tokens"]
129
+ )
130
+ completion_tokens_cost_usd_dollar = (
131
+ model_cost[model]["output_cost_per_token"]
132
+ * response_obj["usage"]["completion_tokens"]
133
+ )
134
+ return prompt_tokens_cost_usd_dollar, completion_tokens_cost_usd_dollar
135
+
136
+ def log_event(self, model, response_obj, start_time, end_time, print_verbose):
137
+ # Method definition
138
+ try:
139
+ print_verbose(
140
+ f"AISpend Logging - Enters logging function for model {model}"
141
+ )
142
+
143
+ url = f"https://aispend.io/api/v1/accounts/{self.account_id}/data"
144
+ headers = {
145
+ "Authorization": f"Bearer {self.api_key}",
146
+ "Content-Type": "application/json",
147
+ }
148
+
149
+ response_timestamp = datetime.datetime.fromtimestamp(
150
+ int(response_obj["created"])
151
+ ).strftime("%Y-%m-%d")
152
+
153
+ (
154
+ prompt_tokens_cost_usd_dollar,
155
+ completion_tokens_cost_usd_dollar,
156
+ ) = self.price_calculator(model, response_obj, start_time, end_time)
157
+ prompt_tokens_cost_usd_cent = prompt_tokens_cost_usd_dollar * 100
158
+ completion_tokens_cost_usd_cent = completion_tokens_cost_usd_dollar * 100
159
+ data = [
160
+ {
161
+ "requests": 1,
162
+ "requests_context": 1,
163
+ "context_tokens": response_obj["usage"]["prompt_tokens"],
164
+ "requests_generated": 1,
165
+ "generated_tokens": response_obj["usage"]["completion_tokens"],
166
+ "recorded_date": response_timestamp,
167
+ "model_id": response_obj["model"],
168
+ "generated_tokens_cost_usd_cent": prompt_tokens_cost_usd_cent,
169
+ "context_tokens_cost_usd_cent": completion_tokens_cost_usd_cent,
170
+ }
171
+ ]
172
+
173
+ print_verbose(f"AISpend Logging - final data object: {data}")
174
+ except:
175
+ # traceback.print_exc()
176
+ print_verbose(f"AISpend Logging Error - {traceback.format_exc()}")
177
+ pass
litellm/integrations/berrispend.py ADDED
@@ -0,0 +1,184 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #### What this does ####
2
+ # On success + failure, log events to BerriSpend (berrispend.berri.ai)
3
+ import dotenv, os
4
+ import requests
5
+
6
+ dotenv.load_dotenv() # Loading env variables using dotenv
7
+ import traceback
8
+ import datetime
9
+
10
+ model_cost = {
11
+ "gpt-3.5-turbo": {
12
+ "max_tokens": 4000,
13
+ "input_cost_per_token": 0.0000015,
14
+ "output_cost_per_token": 0.000002,
15
+ },
16
+ "gpt-35-turbo": {
17
+ "max_tokens": 4000,
18
+ "input_cost_per_token": 0.0000015,
19
+ "output_cost_per_token": 0.000002,
20
+ }, # azure model name
21
+ "gpt-3.5-turbo-0613": {
22
+ "max_tokens": 4000,
23
+ "input_cost_per_token": 0.0000015,
24
+ "output_cost_per_token": 0.000002,
25
+ },
26
+ "gpt-3.5-turbo-0301": {
27
+ "max_tokens": 4000,
28
+ "input_cost_per_token": 0.0000015,
29
+ "output_cost_per_token": 0.000002,
30
+ },
31
+ "gpt-3.5-turbo-16k": {
32
+ "max_tokens": 16000,
33
+ "input_cost_per_token": 0.000003,
34
+ "output_cost_per_token": 0.000004,
35
+ },
36
+ "gpt-35-turbo-16k": {
37
+ "max_tokens": 16000,
38
+ "input_cost_per_token": 0.000003,
39
+ "output_cost_per_token": 0.000004,
40
+ }, # azure model name
41
+ "gpt-3.5-turbo-16k-0613": {
42
+ "max_tokens": 16000,
43
+ "input_cost_per_token": 0.000003,
44
+ "output_cost_per_token": 0.000004,
45
+ },
46
+ "gpt-4": {
47
+ "max_tokens": 8000,
48
+ "input_cost_per_token": 0.000003,
49
+ "output_cost_per_token": 0.00006,
50
+ },
51
+ "gpt-4-0613": {
52
+ "max_tokens": 8000,
53
+ "input_cost_per_token": 0.000003,
54
+ "output_cost_per_token": 0.00006,
55
+ },
56
+ "gpt-4-32k": {
57
+ "max_tokens": 8000,
58
+ "input_cost_per_token": 0.00006,
59
+ "output_cost_per_token": 0.00012,
60
+ },
61
+ "claude-instant-1": {
62
+ "max_tokens": 100000,
63
+ "input_cost_per_token": 0.00000163,
64
+ "output_cost_per_token": 0.00000551,
65
+ },
66
+ "claude-2": {
67
+ "max_tokens": 100000,
68
+ "input_cost_per_token": 0.00001102,
69
+ "output_cost_per_token": 0.00003268,
70
+ },
71
+ "text-bison-001": {
72
+ "max_tokens": 8192,
73
+ "input_cost_per_token": 0.000004,
74
+ "output_cost_per_token": 0.000004,
75
+ },
76
+ "chat-bison-001": {
77
+ "max_tokens": 4096,
78
+ "input_cost_per_token": 0.000002,
79
+ "output_cost_per_token": 0.000002,
80
+ },
81
+ "command-nightly": {
82
+ "max_tokens": 4096,
83
+ "input_cost_per_token": 0.000015,
84
+ "output_cost_per_token": 0.000015,
85
+ },
86
+ }
87
+
88
+
89
+ class BerriSpendLogger:
90
+ # Class variables or attributes
91
+ def __init__(self):
92
+ # Instance variables
93
+ self.account_id = os.getenv("BERRISPEND_ACCOUNT_ID")
94
+
95
+ def price_calculator(self, model, response_obj, start_time, end_time):
96
+ # try and find if the model is in the model_cost map
97
+ # else default to the average of the costs
98
+ prompt_tokens_cost_usd_dollar = 0
99
+ completion_tokens_cost_usd_dollar = 0
100
+ if model in model_cost:
101
+ prompt_tokens_cost_usd_dollar = (
102
+ model_cost[model]["input_cost_per_token"]
103
+ * response_obj["usage"]["prompt_tokens"]
104
+ )
105
+ completion_tokens_cost_usd_dollar = (
106
+ model_cost[model]["output_cost_per_token"]
107
+ * response_obj["usage"]["completion_tokens"]
108
+ )
109
+ elif "replicate" in model:
110
+ # replicate models are charged based on time
111
+ # llama 2 runs on an nvidia a100 which costs $0.0032 per second - https://replicate.com/replicate/llama-2-70b-chat
112
+ model_run_time = end_time - start_time # assuming time in seconds
113
+ cost_usd_dollar = model_run_time * 0.0032
114
+ prompt_tokens_cost_usd_dollar = cost_usd_dollar / 2
115
+ completion_tokens_cost_usd_dollar = cost_usd_dollar / 2
116
+ else:
117
+ # calculate average input cost
118
+ input_cost_sum = 0
119
+ output_cost_sum = 0
120
+ for model in model_cost:
121
+ input_cost_sum += model_cost[model]["input_cost_per_token"]
122
+ output_cost_sum += model_cost[model]["output_cost_per_token"]
123
+ avg_input_cost = input_cost_sum / len(model_cost.keys())
124
+ avg_output_cost = output_cost_sum / len(model_cost.keys())
125
+ prompt_tokens_cost_usd_dollar = (
126
+ model_cost[model]["input_cost_per_token"]
127
+ * response_obj["usage"]["prompt_tokens"]
128
+ )
129
+ completion_tokens_cost_usd_dollar = (
130
+ model_cost[model]["output_cost_per_token"]
131
+ * response_obj["usage"]["completion_tokens"]
132
+ )
133
+ return prompt_tokens_cost_usd_dollar, completion_tokens_cost_usd_dollar
134
+
135
+ def log_event(
136
+ self, model, messages, response_obj, start_time, end_time, print_verbose
137
+ ):
138
+ # Method definition
139
+ try:
140
+ print_verbose(
141
+ f"BerriSpend Logging - Enters logging function for model {model}"
142
+ )
143
+
144
+ url = f"https://berrispend.berri.ai/spend"
145
+ headers = {"Content-Type": "application/json"}
146
+
147
+ (
148
+ prompt_tokens_cost_usd_dollar,
149
+ completion_tokens_cost_usd_dollar,
150
+ ) = self.price_calculator(model, response_obj, start_time, end_time)
151
+ total_cost = (
152
+ prompt_tokens_cost_usd_dollar + completion_tokens_cost_usd_dollar
153
+ )
154
+
155
+ response_time = (end_time - start_time).total_seconds()
156
+ if "response" in response_obj:
157
+ data = [
158
+ {
159
+ "response_time": response_time,
160
+ "model_id": response_obj["model"],
161
+ "total_cost": total_cost,
162
+ "messages": messages,
163
+ "response": response_obj["choices"][0]["message"]["content"],
164
+ "account_id": self.account_id,
165
+ }
166
+ ]
167
+ elif "error" in response_obj:
168
+ data = [
169
+ {
170
+ "response_time": response_time,
171
+ "model_id": response_obj["model"],
172
+ "total_cost": total_cost,
173
+ "messages": messages,
174
+ "error": response_obj["error"],
175
+ "account_id": self.account_id,
176
+ }
177
+ ]
178
+
179
+ print_verbose(f"BerriSpend Logging - final data object: {data}")
180
+ response = requests.post(url, headers=headers, json=data)
181
+ except:
182
+ # traceback.print_exc()
183
+ print_verbose(f"BerriSpend Logging Error - {traceback.format_exc()}")
184
+ pass
litellm/integrations/custom_logger.py ADDED
@@ -0,0 +1,130 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #### What this does ####
2
+ # On success, logs events to a user-defined custom callback (function or CustomLogger subclass)
3
+ import dotenv, os
4
+ import requests
5
+ from litellm.proxy._types import UserAPIKeyAuth
6
+ from litellm.caching import DualCache
7
+ from typing import Literal
8
+
9
+ dotenv.load_dotenv() # Loading env variables using dotenv
10
+ import traceback
11
+
12
+
13
+ class CustomLogger: # https://docs.litellm.ai/docs/observability/custom_callback#callback-class
14
+ # Class variables or attributes
15
+ def __init__(self):
16
+ pass
17
+
18
+ def log_pre_api_call(self, model, messages, kwargs):
19
+ pass
20
+
21
+ def log_post_api_call(self, kwargs, response_obj, start_time, end_time):
22
+ pass
23
+
24
+ def log_stream_event(self, kwargs, response_obj, start_time, end_time):
25
+ pass
26
+
27
+ def log_success_event(self, kwargs, response_obj, start_time, end_time):
28
+ pass
29
+
30
+ def log_failure_event(self, kwargs, response_obj, start_time, end_time):
31
+ pass
32
+
33
+ #### ASYNC ####
34
+
35
+ async def async_log_stream_event(self, kwargs, response_obj, start_time, end_time):
36
+ pass
37
+
38
+ async def async_log_pre_api_call(self, model, messages, kwargs):
39
+ pass
40
+
41
+ async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
42
+ pass
43
+
44
+ async def async_log_failure_event(self, kwargs, response_obj, start_time, end_time):
45
+ pass
46
+
47
+ #### CALL HOOKS - proxy only ####
48
+ """
49
+ Control / modify the incoming and outgoing data before calling the model
50
+ """
51
+
52
+ async def async_pre_call_hook(
53
+ self,
54
+ user_api_key_dict: UserAPIKeyAuth,
55
+ cache: DualCache,
56
+ data: dict,
57
+ call_type: Literal["completion", "embeddings"],
58
+ ):
59
+ pass
60
+
61
+ async def async_post_call_failure_hook(
62
+ self, original_exception: Exception, user_api_key_dict: UserAPIKeyAuth
63
+ ):
64
+ pass
65
+
66
+ #### SINGLE-USE #### - https://docs.litellm.ai/docs/observability/custom_callback#using-your-custom-callback-function
67
+
68
+ def log_input_event(self, model, messages, kwargs, print_verbose, callback_func):
69
+ try:
70
+ kwargs["model"] = model
71
+ kwargs["messages"] = messages
72
+ kwargs["log_event_type"] = "pre_api_call"
73
+ callback_func(
74
+ kwargs,
75
+ )
76
+ print_verbose(f"Custom Logger - model call details: {kwargs}")
77
+ except:
78
+ traceback.print_exc()
79
+ print_verbose(f"Custom Logger Error - {traceback.format_exc()}")
80
+
81
+ async def async_log_input_event(
82
+ self, model, messages, kwargs, print_verbose, callback_func
83
+ ):
84
+ try:
85
+ kwargs["model"] = model
86
+ kwargs["messages"] = messages
87
+ kwargs["log_event_type"] = "pre_api_call"
88
+ await callback_func(
89
+ kwargs,
90
+ )
91
+ print_verbose(f"Custom Logger - model call details: {kwargs}")
92
+ except:
93
+ traceback.print_exc()
94
+ print_verbose(f"Custom Logger Error - {traceback.format_exc()}")
95
+
96
+ def log_event(
97
+ self, kwargs, response_obj, start_time, end_time, print_verbose, callback_func
98
+ ):
99
+ # Method definition
100
+ try:
101
+ kwargs["log_event_type"] = "post_api_call"
102
+ callback_func(
103
+ kwargs, # kwargs to func
104
+ response_obj,
105
+ start_time,
106
+ end_time,
107
+ )
108
+ print_verbose(f"Custom Logger - final response object: {response_obj}")
109
+ except:
110
+ # traceback.print_exc()
111
+ print_verbose(f"Custom Logger Error - {traceback.format_exc()}")
112
+ pass
113
+
114
+ async def async_log_event(
115
+ self, kwargs, response_obj, start_time, end_time, print_verbose, callback_func
116
+ ):
117
+ # Method definition
118
+ try:
119
+ kwargs["log_event_type"] = "post_api_call"
120
+ await callback_func(
121
+ kwargs, # kwargs to func
122
+ response_obj,
123
+ start_time,
124
+ end_time,
125
+ )
126
+ print_verbose(f"Custom Logger - final response object: {response_obj}")
127
+ except:
128
+ # traceback.print_exc()
129
+ print_verbose(f"Custom Logger Error - {traceback.format_exc()}")
130
+ pass
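CustomLogger is meant to be subclassed; only the hooks you need have to be overridden. A minimal sketch (class name and print statements are illustrative; registering the instance on litellm.callbacks follows the custom-callback docs linked in the class definition):

    import litellm
    from litellm.integrations.custom_logger import CustomLogger

    class MyLogger(CustomLogger):
        def log_success_event(self, kwargs, response_obj, start_time, end_time):
            print(f"success: {kwargs.get('model')} in {end_time - start_time}")

        def log_failure_event(self, kwargs, response_obj, start_time, end_time):
            print(f"failure: {kwargs.get('model')}")

    litellm.callbacks = [MyLogger()]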
litellm/integrations/dynamodb.py ADDED
@@ -0,0 +1,92 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #### What this does ####
2
+ # On success + failure, log events to DynamoDB
3
+
4
+ import dotenv, os
5
+ import requests
6
+
7
+ dotenv.load_dotenv() # Loading env variables using dotenv
8
+ import traceback
9
+ import datetime, subprocess, sys
10
+ import litellm, uuid
11
+ from litellm._logging import print_verbose
12
+
13
+
14
+ class DyanmoDBLogger:
15
+ # Class variables or attributes
16
+
17
+ def __init__(self):
18
+ # Instance variables
19
+ import boto3
20
+
21
+ self.dynamodb = boto3.resource(
22
+ "dynamodb", region_name=os.environ["AWS_REGION_NAME"]
23
+ )
24
+ if litellm.dynamodb_table_name is None:
25
+ raise ValueError(
26
+ "LiteLLM Error, trying to use DynamoDB but not table name passed. Create a table and set `litellm.dynamodb_table_name=<your-table>`"
27
+ )
28
+ self.table_name = litellm.dynamodb_table_name
29
+
30
+ async def _async_log_event(
31
+ self, kwargs, response_obj, start_time, end_time, print_verbose
32
+ ):
33
+ self.log_event(kwargs, response_obj, start_time, end_time, print_verbose)
34
+
35
+ def log_event(self, kwargs, response_obj, start_time, end_time, print_verbose):
36
+ try:
37
+ print_verbose(
38
+ f"DynamoDB Logging - Enters logging function for model {kwargs}"
39
+ )
40
+
41
+ # construct payload to send to DynamoDB
42
+ # follows the same params as langfuse.py
43
+ litellm_params = kwargs.get("litellm_params", {})
44
+ metadata = (
45
+ litellm_params.get("metadata", {}) or {}
46
+ ) # if litellm_params['metadata'] == None
47
+ messages = kwargs.get("messages")
48
+ optional_params = kwargs.get("optional_params", {})
49
+ call_type = kwargs.get("call_type", "litellm.completion")
50
+ usage = response_obj["usage"]
51
+ id = response_obj.get("id", str(uuid.uuid4()))
52
+
53
+ # Build the initial payload
54
+ payload = {
55
+ "id": id,
56
+ "call_type": call_type,
57
+ "startTime": start_time,
58
+ "endTime": end_time,
59
+ "model": kwargs.get("model", ""),
60
+ "user": kwargs.get("user", ""),
61
+ "modelParameters": optional_params,
62
+ "messages": messages,
63
+ "response": response_obj,
64
+ "usage": usage,
65
+ "metadata": metadata,
66
+ }
67
+
68
+ # Ensure everything in the payload is converted to str
69
+ for key, value in payload.items():
70
+ try:
71
+ payload[key] = str(value)
72
+ except:
73
+ # non blocking if it can't cast to a str
74
+ pass
75
+
76
+ print_verbose(f"\nDynamoDB Logger - Logging payload = {payload}")
77
+
78
+ # put data in DynamoDB
79
+ table = self.dynamodb.Table(self.table_name)
80
+ # Assuming log_data is a dictionary with log information
81
+ response = table.put_item(Item=payload)
82
+
83
+ print_verbose(f"Response from DynamoDB:{str(response)}")
84
+
85
+ print_verbose(
86
+ f"DynamoDB Layer Logging - final response object: {response_obj}"
87
+ )
88
+ return response
89
+ except:
90
+ traceback.print_exc()
91
+ print_verbose(f"DynamoDB Layer Error - {traceback.format_exc()}")
92
+ pass
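A hedged setup sketch for the DynamoDB logger (table name and region are placeholders; it assumes boto3 is installed, AWS credentials are available to boto3, and that the logger is registered under litellm's "dynamodb" success-callback string):

    import os
    import litellm

    os.environ["AWS_REGION_NAME"] = "us-west-2"     # read in DyanmoDBLogger.__init__
    litellm.dynamodb_table_name = "litellm-logs"    # required, otherwise __init__ raises ValueError
    litellm.success_callback = ["dynamodb"]

    litellm.completion(
        model="gpt-3.5-turbo",
        messages=[{"role": "user", "content": "hello"}],
    )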
litellm/integrations/helicone.py ADDED
@@ -0,0 +1,114 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #### What this does ####
2
+ # On success, logs events to Helicone
3
+ import dotenv, os
4
+ import requests
5
+
6
+ dotenv.load_dotenv() # Loading env variables using dotenv
7
+ import traceback
8
+
9
+
10
+ class HeliconeLogger:
11
+ # Class variables or attributes
12
+ helicone_model_list = ["gpt", "claude"]
13
+
14
+ def __init__(self):
15
+ # Instance variables
16
+ self.provider_url = "https://api.openai.com/v1"
17
+ self.key = os.getenv("HELICONE_API_KEY")
18
+
19
+ def claude_mapping(self, model, messages, response_obj):
20
+ from anthropic import HUMAN_PROMPT, AI_PROMPT
21
+
22
+ prompt = f"{HUMAN_PROMPT}"
23
+ for message in messages:
24
+ if "role" in message:
25
+ if message["role"] == "user":
26
+ prompt += f"{HUMAN_PROMPT}{message['content']}"
27
+ else:
28
+ prompt += f"{AI_PROMPT}{message['content']}"
29
+ else:
30
+ prompt += f"{HUMAN_PROMPT}{message['content']}"
31
+ prompt += f"{AI_PROMPT}"
32
+ claude_provider_request = {"model": model, "prompt": prompt}
33
+
34
+ claude_response_obj = {
35
+ "completion": response_obj["choices"][0]["message"]["content"],
36
+ "model": model,
37
+ "stop_reason": "stop_sequence",
38
+ }
39
+
40
+ return claude_provider_request, claude_response_obj
41
+
42
+ def log_success(
43
+ self, model, messages, response_obj, start_time, end_time, print_verbose
44
+ ):
45
+ # Method definition
46
+ try:
47
+ print_verbose(
48
+ f"Helicone Logging - Enters logging function for model {model}"
49
+ )
50
+ model = (
51
+ model
52
+ if any(
53
+ accepted_model in model
54
+ for accepted_model in self.helicone_model_list
55
+ )
56
+ else "gpt-3.5-turbo"
57
+ )
58
+ provider_request = {"model": model, "messages": messages}
59
+
60
+ if "claude" in model:
61
+ provider_request, response_obj = self.claude_mapping(
62
+ model=model, messages=messages, response_obj=response_obj
63
+ )
64
+
65
+ providerResponse = {
66
+ "json": response_obj,
67
+ "headers": {"openai-version": "2020-10-01"},
68
+ "status": 200,
69
+ }
70
+
71
+ # Code to be executed
72
+ url = "https://api.hconeai.com/oai/v1/log"
73
+ headers = {
74
+ "Authorization": f"Bearer {self.key}",
75
+ "Content-Type": "application/json",
76
+ }
77
+ start_time_seconds = int(start_time.timestamp())
78
+ start_time_milliseconds = int(
79
+ (start_time.timestamp() - start_time_seconds) * 1000
80
+ )
81
+ end_time_seconds = int(end_time.timestamp())
82
+ end_time_milliseconds = int(
83
+ (end_time.timestamp() - end_time_seconds) * 1000
84
+ )
85
+ data = {
86
+ "providerRequest": {
87
+ "url": self.provider_url,
88
+ "json": provider_request,
89
+ "meta": {"Helicone-Auth": f"Bearer {self.key}"},
90
+ },
91
+ "providerResponse": providerResponse,
92
+ "timing": {
93
+ "startTime": {
94
+ "seconds": start_time_seconds,
95
+ "milliseconds": start_time_milliseconds,
96
+ },
97
+ "endTime": {
98
+ "seconds": end_time_seconds,
99
+ "milliseconds": end_time_milliseconds,
100
+ },
101
+ }, # {"seconds": .., "milliseconds": ..}
102
+ }
103
+ response = requests.post(url, headers=headers, json=data)
104
+ if response.status_code == 200:
105
+ print_verbose("Helicone Logging - Success!")
106
+ else:
107
+ print_verbose(
108
+ f"Helicone Logging - Error Request was not successful. Status Code: {response.status_code}"
109
+ )
110
+ print_verbose(f"Helicone Logging - Error {response.text}")
111
+ except:
112
+ # traceback.print_exc()
113
+ print_verbose(f"Helicone Logging Error - {traceback.format_exc()}")
114
+ pass
litellm/integrations/langfuse.py ADDED
@@ -0,0 +1,191 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #### What this does ####
2
+ # On success, logs events to Langfuse
3
+ import dotenv, os
4
+ import requests
5
+ import requests
6
+ from datetime import datetime
7
+
8
+ dotenv.load_dotenv() # Loading env variables using dotenv
9
+ import traceback
10
+ from packaging.version import Version
11
+
12
+
13
+ class LangFuseLogger:
14
+ # Class variables or attributes
15
+ def __init__(self):
16
+ try:
17
+ from langfuse import Langfuse
18
+ except Exception as e:
19
+ raise Exception(
20
+ f"\033[91mLangfuse not installed, try running 'pip install langfuse' to fix this error: {e}\033[0m"
21
+ )
22
+ # Instance variables
23
+ self.secret_key = os.getenv("LANGFUSE_SECRET_KEY")
24
+ self.public_key = os.getenv("LANGFUSE_PUBLIC_KEY")
25
+ self.langfuse_host = os.getenv("LANGFUSE_HOST", "https://cloud.langfuse.com")
26
+ self.langfuse_release = os.getenv("LANGFUSE_RELEASE")
27
+ self.langfuse_debug = os.getenv("LANGFUSE_DEBUG")
28
+ self.Langfuse = Langfuse(
29
+ public_key=self.public_key,
30
+ secret_key=self.secret_key,
31
+ host=self.langfuse_host,
32
+ release=self.langfuse_release,
33
+ debug=self.langfuse_debug,
34
+ )
35
+
36
+ def log_event(
37
+ self, kwargs, response_obj, start_time, end_time, user_id, print_verbose
38
+ ):
39
+ # Method definition
40
+
41
+ try:
42
+ print_verbose(
43
+ f"Langfuse Logging - Enters logging function for model {kwargs}"
44
+ )
45
+ litellm_params = kwargs.get("litellm_params", {})
46
+ metadata = (
47
+ litellm_params.get("metadata", {}) or {}
48
+ ) # if litellm_params['metadata'] == None
49
+ prompt = [kwargs.get("messages")]
50
+ optional_params = kwargs.get("optional_params", {})
51
+
52
+ optional_params.pop("functions", None)
53
+ optional_params.pop("tools", None)
54
+
55
+ # langfuse only accepts str, int, bool, float for logging
56
+ for param, value in optional_params.items():
57
+ if not isinstance(value, (str, int, bool, float)):
58
+ try:
59
+ optional_params[param] = str(value)
60
+ except:
61
+ # if casting value to str fails don't block logging
62
+ pass
63
+
64
+ # end of processing langfuse ########################
65
+ input = prompt
66
+ output = response_obj["choices"][0]["message"].json()
67
+ print_verbose(
68
+ f"OUTPUT IN LANGFUSE: {output}; original: {response_obj['choices'][0]['message']}"
69
+ )
70
+ self._log_langfuse_v2(
71
+ user_id,
72
+ metadata,
73
+ output,
74
+ start_time,
75
+ end_time,
76
+ kwargs,
77
+ optional_params,
78
+ input,
79
+ response_obj,
80
+ ) if self._is_langfuse_v2() else self._log_langfuse_v1(
81
+ user_id,
82
+ metadata,
83
+ output,
84
+ start_time,
85
+ end_time,
86
+ kwargs,
87
+ optional_params,
88
+ input,
89
+ response_obj,
90
+ )
91
+
92
+ self.Langfuse.flush()
93
+ print_verbose(
94
+ f"Langfuse Layer Logging - final response object: {response_obj}"
95
+ )
96
+ except:
97
+ traceback.print_exc()
98
+ print_verbose(f"Langfuse Layer Error - {traceback.format_exc()}")
99
+ pass
100
+
101
+ async def _async_log_event(
102
+ self, kwargs, response_obj, start_time, end_time, user_id, print_verbose
103
+ ):
104
+ self.log_event(
105
+ kwargs, response_obj, start_time, end_time, user_id, print_verbose
106
+ )
107
+
108
+ def _is_langfuse_v2(self):
109
+ import langfuse
110
+
111
+ return Version(langfuse.version.__version__) >= Version("2.0.0")
112
+
113
+ def _log_langfuse_v1(
114
+ self,
115
+ user_id,
116
+ metadata,
117
+ output,
118
+ start_time,
119
+ end_time,
120
+ kwargs,
121
+ optional_params,
122
+ input,
123
+ response_obj,
124
+ ):
125
+ from langfuse.model import CreateTrace, CreateGeneration
126
+
127
+ print(
128
+ "Please upgrade langfuse to v2.0.0 or higher: https://github.com/langfuse/langfuse-python/releases/tag/v2.0.1"
129
+ )
130
+
131
+ trace = self.Langfuse.trace(
132
+ CreateTrace(
133
+ name=metadata.get("generation_name", "litellm-completion"),
134
+ input=input,
135
+ output=output,
136
+ userId=user_id,
137
+ )
138
+ )
139
+
140
+ trace.generation(
141
+ CreateGeneration(
142
+ name=metadata.get("generation_name", "litellm-completion"),
143
+ startTime=start_time,
144
+ endTime=end_time,
145
+ model=kwargs["model"],
146
+ modelParameters=optional_params,
147
+ input=input,
148
+ output=output,
149
+ usage={
150
+ "prompt_tokens": response_obj["usage"]["prompt_tokens"],
151
+ "completion_tokens": response_obj["usage"]["completion_tokens"],
152
+ },
153
+ metadata=metadata,
154
+ )
155
+ )
156
+
157
+ def _log_langfuse_v2(
158
+ self,
159
+ user_id,
160
+ metadata,
161
+ output,
162
+ start_time,
163
+ end_time,
164
+ kwargs,
165
+ optional_params,
166
+ input,
167
+ response_obj,
168
+ ):
169
+ trace = self.Langfuse.trace(
170
+ name=metadata.get("generation_name", "litellm-completion"),
171
+ input=input,
172
+ output=output,
173
+ user_id=metadata.get("trace_user_id", user_id),
174
+ id=metadata.get("trace_id", None),
175
+ )
176
+
177
+ trace.generation(
178
+ name=metadata.get("generation_name", "litellm-completion"),
179
+ id=metadata.get("generation_id", None),
180
+ startTime=start_time,
181
+ endTime=end_time,
182
+ model=kwargs["model"],
183
+ modelParameters=optional_params,
184
+ input=input,
185
+ output=output,
186
+ usage={
187
+ "prompt_tokens": response_obj["usage"]["prompt_tokens"],
188
+ "completion_tokens": response_obj["usage"]["completion_tokens"],
189
+ },
190
+ metadata=metadata,
191
+ )
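Wiring this logger up only requires the Langfuse keys and the "langfuse" success callback (the same wiring the deprecated server's set_callbacks() uses earlier in this commit); trace and generation names come from the request metadata keys read in _log_langfuse_v2. Key values below are placeholders:

    import os
    import litellm

    os.environ["LANGFUSE_PUBLIC_KEY"] = "pk-..."
    os.environ["LANGFUSE_SECRET_KEY"] = "sk-..."
    litellm.success_callback = ["langfuse"]

    litellm.completion(
        model="gpt-3.5-turbo",
        messages=[{"role": "user", "content": "hello"}],
        metadata={"generation_name": "demo-generation", "trace_user_id": "user-123"},
    )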
litellm/integrations/langsmith.py ADDED
@@ -0,0 +1,75 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #### What this does ####
2
+ # On success, logs events to Langsmith
3
+ import dotenv, os
4
+ import requests
5
+ import requests
6
+ from datetime import datetime
7
+
8
+ dotenv.load_dotenv() # Loading env variables using dotenv
9
+ import traceback
10
+
11
+
12
+ class LangsmithLogger:
13
+ # Class variables or attributes
14
+ def __init__(self):
15
+ self.langsmith_api_key = os.getenv("LANGSMITH_API_KEY")
16
+
17
+ def log_event(self, kwargs, response_obj, start_time, end_time, print_verbose):
18
+ # Method definition
19
+ # inspired by Langsmith http api here: https://github.com/langchain-ai/langsmith-cookbook/blob/main/tracing-examples/rest/rest.ipynb
20
+ metadata = {}
21
+ if "litellm_params" in kwargs:
22
+ metadata = kwargs["litellm_params"].get("metadata", {})
23
+ # set project name and run_name for langsmith logging
24
+ # users can pass project_name and run name to litellm.completion()
25
+ # Example: litellm.completion(model, messages, metadata={"project_name": "my-litellm-project", "run_name": "my-langsmith-run"})
26
+ # if not set litellm will use default project_name = litellm-completion, run_name = LLMRun
27
+ project_name = metadata.get("project_name", "litellm-completion")
28
+ run_name = metadata.get("run_name", "LLMRun")
29
+ print_verbose(
30
+ f"Langsmith Logging - project_name: {project_name}, run_name {run_name}"
31
+ )
32
+ try:
33
+ print_verbose(
34
+ f"Langsmith Logging - Enters logging function for model {kwargs}"
35
+ )
36
+ import requests
37
+ import datetime
38
+ from datetime import timezone
39
+
40
+ try:
41
+ start_time = kwargs["start_time"].astimezone(timezone.utc).isoformat()
42
+ end_time = kwargs["end_time"].astimezone(timezone.utc).isoformat()
43
+ except:
44
+ start_time = datetime.datetime.utcnow().isoformat()
45
+ end_time = datetime.datetime.utcnow().isoformat()
46
+
47
+ # filter out kwargs to not include any dicts, langsmith throws an error when trying to log kwargs
48
+ new_kwargs = {}
49
+ for key in kwargs:
50
+ value = kwargs[key]
51
+ if key == "start_time" or key == "end_time":
52
+ pass
53
+ elif type(value) != dict:
54
+ new_kwargs[key] = value
55
+
56
+ requests.post(
57
+ "https://api.smith.langchain.com/runs",
58
+ json={
59
+ "name": run_name,
60
+ "run_type": "llm", # this should always be llm, since litellm always logs llm calls. Langsmith allow us to log "chain"
61
+ "inputs": {**new_kwargs},
62
+ "outputs": response_obj.json(),
63
+ "session_name": project_name,
64
+ "start_time": start_time,
65
+ "end_time": end_time,
66
+ },
67
+ headers={"x-api-key": self.langsmith_api_key},
68
+ )
69
+ print_verbose(
70
+ f"Langsmith Layer Logging - final response object: {response_obj}"
71
+ )
72
+ except:
73
+ # traceback.print_exc()
74
+ print_verbose(f"Langsmith Layer Error - {traceback.format_exc()}")
75
+ pass
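A hedged usage sketch for the logger above: project_name and run_name travel through the metadata dict on litellm.completion(), exactly as the comments in log_event() describe. The model name and API key are placeholders, and registering the callback via litellm.success_callback follows litellm's usual callback pattern:

    import os
    import litellm

    os.environ["LANGSMITH_API_KEY"] = "<your-langsmith-key>"   # placeholder
    litellm.success_callback = ["langsmith"]                   # route successful calls to LangsmithLogger

    litellm.completion(
        model="gpt-3.5-turbo",                                 # placeholder model
        messages=[{"role": "user", "content": "hello"}],
        # picked up above: project_name -> session_name, run_name -> name
        metadata={"project_name": "my-litellm-project", "run_name": "my-langsmith-run"},
    )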
litellm/integrations/litedebugger.py ADDED
@@ -0,0 +1,262 @@
1
+ import requests, traceback, json, os
2
+ import types
3
+
4
+
5
+ class LiteDebugger:
6
+ user_email = None
7
+ dashboard_url = None
8
+
9
+ def __init__(self, email=None):
10
+ self.api_url = "https://api.litellm.ai/debugger"
11
+ self.validate_environment(email)
12
+ pass
13
+
14
+ def validate_environment(self, email):
15
+ try:
16
+ self.user_email = (
17
+ email or os.getenv("LITELLM_TOKEN") or os.getenv("LITELLM_EMAIL")
18
+ )
19
+ if (
20
+ self.user_email == None
21
+ ): # if users are trying to use_client=True but token not set
22
+ raise ValueError(
23
+ "litellm.use_client = True but no token or email passed. Please set it in litellm.token"
24
+ )
25
+ self.dashboard_url = "https://admin.litellm.ai/" + self.user_email
26
+ try:
27
+ print(
28
+ f"\033[92mHere's your LiteLLM Dashboard 👉 \033[94m\033[4m{self.dashboard_url}\033[0m"
29
+ )
30
+ except:
31
+ print(f"Here's your LiteLLM Dashboard 👉 {self.dashboard_url}")
32
+ if self.user_email == None:
33
+ raise ValueError(
34
+ "[Non-Blocking Error] LiteLLMDebugger: Missing LITELLM_TOKEN. Set it in your environment. Eg.: os.environ['LITELLM_TOKEN']= <your_email>"
35
+ )
36
+ except Exception as e:
37
+ raise ValueError(
38
+ "[Non-Blocking Error] LiteLLMDebugger: Missing LITELLM_TOKEN. Set it in your environment. Eg.: os.environ['LITELLM_TOKEN']= <your_email>"
39
+ )
40
+
41
+ def input_log_event(
42
+ self,
43
+ model,
44
+ messages,
45
+ end_user,
46
+ litellm_call_id,
47
+ call_type,
48
+ print_verbose,
49
+ litellm_params,
50
+ optional_params,
51
+ ):
52
+ print_verbose(
53
+ f"LiteDebugger: Pre-API Call Logging for call id {litellm_call_id}"
54
+ )
55
+ try:
56
+ print_verbose(
57
+ f"LiteLLMDebugger: Logging - Enters input logging function for model {model}"
58
+ )
59
+
60
+ def remove_key_value(dictionary, key):
61
+ new_dict = dictionary.copy() # Create a copy of the original dictionary
62
+ new_dict.pop(key) # Remove the specified key-value pair from the copy
63
+ return new_dict
64
+
65
+ updated_litellm_params = remove_key_value(litellm_params, "logger_fn")
66
+
67
+ if call_type == "embedding":
68
+ for (
69
+ message
70
+ ) in (
71
+ messages
72
+ ): # assuming the input is a list as required by the embedding function
73
+ litellm_data_obj = {
74
+ "model": model,
75
+ "messages": [{"role": "user", "content": message}],
76
+ "end_user": end_user,
77
+ "status": "initiated",
78
+ "litellm_call_id": litellm_call_id,
79
+ "user_email": self.user_email,
80
+ "litellm_params": updated_litellm_params,
81
+ "optional_params": optional_params,
82
+ }
83
+ print_verbose(
84
+ f"LiteLLMDebugger: Logging - logged data obj {litellm_data_obj}"
85
+ )
86
+ response = requests.post(
87
+ url=self.api_url,
88
+ headers={"content-type": "application/json"},
89
+ data=json.dumps(litellm_data_obj),
90
+ )
91
+ print_verbose(f"LiteDebugger: embedding api response - {response.text}")
92
+ elif call_type == "completion":
93
+ litellm_data_obj = {
94
+ "model": model,
95
+ "messages": messages
96
+ if isinstance(messages, list)
97
+ else [{"role": "user", "content": messages}],
98
+ "end_user": end_user,
99
+ "status": "initiated",
100
+ "litellm_call_id": litellm_call_id,
101
+ "user_email": self.user_email,
102
+ "litellm_params": updated_litellm_params,
103
+ "optional_params": optional_params,
104
+ }
105
+ print_verbose(
106
+ f"LiteLLMDebugger: Logging - logged data obj {litellm_data_obj}"
107
+ )
108
+ response = requests.post(
109
+ url=self.api_url,
110
+ headers={"content-type": "application/json"},
111
+ data=json.dumps(litellm_data_obj),
112
+ )
113
+ print_verbose(
114
+ f"LiteDebugger: completion api response - {response.text}"
115
+ )
116
+ except:
117
+ print_verbose(
118
+ f"[Non-Blocking Error] LiteDebugger: Logging Error - {traceback.format_exc()}"
119
+ )
120
+ pass
121
+
122
+ def post_call_log_event(
123
+ self, original_response, litellm_call_id, print_verbose, call_type, stream
124
+ ):
125
+ print_verbose(
126
+ f"LiteDebugger: Post-API Call Logging for call id {litellm_call_id}"
127
+ )
128
+ try:
129
+ if call_type == "embedding":
130
+ litellm_data_obj = {
131
+ "status": "received",
132
+ "additional_details": {
133
+ "original_response": str(
134
+ original_response["data"][0]["embedding"][:5]
135
+ )
136
+ }, # don't store the entire vector
137
+ "litellm_call_id": litellm_call_id,
138
+ "user_email": self.user_email,
139
+ }
140
+ elif call_type == "completion" and not stream:
141
+ litellm_data_obj = {
142
+ "status": "received",
143
+ "additional_details": {"original_response": original_response},
144
+ "litellm_call_id": litellm_call_id,
145
+ "user_email": self.user_email,
146
+ }
147
+ elif call_type == "completion" and stream:
148
+ litellm_data_obj = {
149
+ "status": "received",
150
+ "additional_details": {
151
+ "original_response": "Streamed response"
152
+ if isinstance(original_response, types.GeneratorType)
153
+ else original_response
154
+ },
155
+ "litellm_call_id": litellm_call_id,
156
+ "user_email": self.user_email,
157
+ }
158
+ print_verbose(f"litedebugger post-call data object - {litellm_data_obj}")
159
+ response = requests.post(
160
+ url=self.api_url,
161
+ headers={"content-type": "application/json"},
162
+ data=json.dumps(litellm_data_obj),
163
+ )
164
+ print_verbose(f"LiteDebugger: api response - {response.text}")
165
+ except:
166
+ print_verbose(
167
+ f"[Non-Blocking Error] LiteDebugger: Logging Error - {traceback.format_exc()}"
168
+ )
169
+
170
+ def log_event(
171
+ self,
172
+ end_user,
173
+ response_obj,
174
+ start_time,
175
+ end_time,
176
+ litellm_call_id,
177
+ print_verbose,
178
+ call_type,
179
+ stream=False,
180
+ ):
181
+ print_verbose(
182
+ f"LiteDebugger: Success/Failure Call Logging for call id {litellm_call_id}"
183
+ )
184
+ try:
185
+ print_verbose(
186
+ f"LiteLLMDebugger: Success/Failure Logging - Enters handler logging function for function {call_type} and stream set to {stream} with response object {response_obj}"
187
+ )
188
+ total_cost = 0 # [TODO] implement cost tracking
189
+ response_time = (end_time - start_time).total_seconds()
190
+ if call_type == "completion" and stream == False:
191
+ litellm_data_obj = {
192
+ "response_time": response_time,
193
+ "total_cost": total_cost,
194
+ "response": response_obj["choices"][0]["message"]["content"],
195
+ "litellm_call_id": litellm_call_id,
196
+ "status": "success",
197
+ }
198
+ print_verbose(
199
+ f"LiteDebugger: Logging - final data object: {litellm_data_obj}"
200
+ )
201
+ response = requests.post(
202
+ url=self.api_url,
203
+ headers={"content-type": "application/json"},
204
+ data=json.dumps(litellm_data_obj),
205
+ )
206
+ elif call_type == "embedding":
207
+ litellm_data_obj = {
208
+ "response_time": response_time,
209
+ "total_cost": total_cost,
210
+ "response": str(response_obj["data"][0]["embedding"][:5]),
211
+ "litellm_call_id": litellm_call_id,
212
+ "status": "success",
213
+ }
214
+ response = requests.post(
215
+ url=self.api_url,
216
+ headers={"content-type": "application/json"},
217
+ data=json.dumps(litellm_data_obj),
218
+ )
219
+ elif call_type == "completion" and stream == True:
220
+ if len(response_obj["content"]) > 0: # don't log the empty strings
221
+ litellm_data_obj = {
222
+ "response_time": response_time,
223
+ "total_cost": total_cost,
224
+ "response": response_obj["content"],
225
+ "litellm_call_id": litellm_call_id,
226
+ "status": "success",
227
+ }
228
+ print_verbose(
229
+ f"LiteDebugger: Logging - final data object: {litellm_data_obj}"
230
+ )
231
+ response = requests.post(
232
+ url=self.api_url,
233
+ headers={"content-type": "application/json"},
234
+ data=json.dumps(litellm_data_obj),
235
+ )
236
+ elif "error" in response_obj:
237
+ if "Unable to map your input to a model." in response_obj["error"]:
238
+ total_cost = 0
239
+ litellm_data_obj = {
240
+ "response_time": response_time,
241
+ "model": response_obj["model"],
242
+ "total_cost": total_cost,
243
+ "error": response_obj["error"],
244
+ "end_user": end_user,
245
+ "litellm_call_id": litellm_call_id,
246
+ "status": "failure",
247
+ "user_email": self.user_email,
248
+ }
249
+ print_verbose(
250
+ f"LiteDebugger: Logging - final data object: {litellm_data_obj}"
251
+ )
252
+ response = requests.post(
253
+ url=self.api_url,
254
+ headers={"content-type": "application/json"},
255
+ data=json.dumps(litellm_data_obj),
256
+ )
257
+ print_verbose(f"LiteDebugger: api response - {response.text}")
258
+ except:
259
+ print_verbose(
260
+ f"[Non-Blocking Error] LiteDebugger: Logging Error - {traceback.format_exc()}"
261
+ )
262
+ pass
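A minimal setup sketch implied by the error messages in validate_environment() above (the email value is a placeholder):

    import os
    import litellm

    os.environ["LITELLM_TOKEN"] = "you@example.com"  # placeholder, per the os.environ['LITELLM_TOKEN'] hint above
    litellm.use_client = True                        # enables the LiteDebugger callbacks shown above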
litellm/integrations/llmonitor.py ADDED
@@ -0,0 +1,127 @@
1
+ #### What this does ####
2
+ # On success + failure, log events to LLMonitor
3
+ import datetime
4
+ import traceback
5
+ import dotenv
6
+ import os
7
+ import requests
8
+
9
+ dotenv.load_dotenv() # Loading env variables using dotenv
10
+
11
+
12
+ # convert to {completion: xx, tokens: xx}
13
+ def parse_usage(usage):
14
+ return {
15
+ "completion": usage["completion_tokens"] if "completion_tokens" in usage else 0,
16
+ "prompt": usage["prompt_tokens"] if "prompt_tokens" in usage else 0,
17
+ }
18
+
19
+
20
+ def parse_messages(input):
21
+ if input is None:
22
+ return None
23
+
24
+ def clean_message(message):
25
+ # if it is a string, return it as is
26
+ if isinstance(message, str):
27
+ return message
28
+
29
+ if "message" in message:
30
+ return clean_message(message["message"])
31
+ text = message["content"]
32
+ if text == None:
33
+ text = message.get("function_call", None)
34
+
35
+ return {
36
+ "role": message["role"],
37
+ "text": text,
38
+ }
39
+
40
+ if isinstance(input, list):
41
+ if len(input) == 1:
42
+ return clean_message(input[0])
43
+ else:
44
+ return [clean_message(msg) for msg in input]
45
+ else:
46
+ return clean_message(input)
47
+
48
+
49
+ class LLMonitorLogger:
50
+ # Class variables or attributes
51
+ def __init__(self):
52
+ # Instance variables
53
+ self.api_url = os.getenv("LLMONITOR_API_URL") or "https://app.llmonitor.com"
54
+ self.app_id = os.getenv("LLMONITOR_APP_ID")
55
+
56
+ def log_event(
57
+ self,
58
+ type,
59
+ event,
60
+ run_id,
61
+ model,
62
+ print_verbose,
63
+ input=None,
64
+ user_id=None,
65
+ response_obj=None,
66
+ start_time=datetime.datetime.now(),
67
+ end_time=datetime.datetime.now(),
68
+ error=None,
69
+ ):
70
+ # Method definition
71
+ try:
72
+ print_verbose(f"LLMonitor Logging - Logging request for model {model}")
73
+
74
+ if response_obj:
75
+ usage = (
76
+ parse_usage(response_obj["usage"])
77
+ if "usage" in response_obj
78
+ else None
79
+ )
80
+ output = response_obj["choices"] if "choices" in response_obj else None
81
+ else:
82
+ usage = None
83
+ output = None
84
+
85
+ if error:
86
+ error_obj = {"stack": error}
87
+
88
+ else:
89
+ error_obj = None
90
+
91
+ data = [
92
+ {
93
+ "type": type,
94
+ "name": model,
95
+ "runId": run_id,
96
+ "app": self.app_id,
97
+ "event": "start",
98
+ "timestamp": start_time.isoformat(),
99
+ "userId": user_id,
100
+ "input": parse_messages(input),
101
+ },
102
+ {
103
+ "type": type,
104
+ "runId": run_id,
105
+ "app": self.app_id,
106
+ "event": event,
107
+ "error": error_obj,
108
+ "timestamp": end_time.isoformat(),
109
+ "userId": user_id,
110
+ "output": parse_messages(output),
111
+ "tokensUsage": usage,
112
+ },
113
+ ]
114
+
115
+ print_verbose(f"LLMonitor Logging - final data object: {data}")
116
+
117
+ response = requests.post(
118
+ self.api_url + "/api/report",
119
+ headers={"Content-Type": "application/json"},
120
+ json={"events": data},
121
+ )
122
+
123
+ print_verbose(f"LLMonitor Logging - response: {response}")
124
+ except:
125
+ # traceback.print_exc()
126
+ print_verbose(f"LLMonitor Logging Error - {traceback.format_exc()}")
127
+ pass
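parse_usage() and parse_messages() above are pure helpers, so their output shapes are easy to illustrate; the values below are made up:

    usage = {"prompt_tokens": 12, "completion_tokens": 34}
    parse_usage(usage)
    # -> {"completion": 34, "prompt": 12}

    parse_messages([{"role": "user", "content": "hi"}])
    # -> {"role": "user", "text": "hi"}   (single-item lists are unwrapped)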
litellm/integrations/prompt_layer.py ADDED
@@ -0,0 +1,72 @@
1
+ #### What this does ####
2
+ # On success, logs events to Promptlayer
3
+ import dotenv, os
4
+ import requests
5
+ import requests
6
+
7
+ dotenv.load_dotenv() # Loading env variables using dotenv
8
+ import traceback
9
+
10
+
11
+ class PromptLayerLogger:
12
+ # Class variables or attributes
13
+ def __init__(self):
14
+ # Instance variables
15
+ self.key = os.getenv("PROMPTLAYER_API_KEY")
16
+
17
+ def log_event(self, kwargs, response_obj, start_time, end_time, print_verbose):
18
+ # Method definition
19
+ try:
20
+ new_kwargs = {}
21
+ new_kwargs["model"] = kwargs["model"]
22
+ new_kwargs["messages"] = kwargs["messages"]
23
+
24
+ # add kwargs["optional_params"] to new_kwargs
25
+ for optional_param in kwargs["optional_params"]:
26
+ new_kwargs[optional_param] = kwargs["optional_params"][optional_param]
27
+
28
+ print_verbose(
29
+ f"Prompt Layer Logging - Enters logging function for model kwargs: {new_kwargs}\n, response: {response_obj}"
30
+ )
31
+
32
+ request_response = requests.post(
33
+ "https://api.promptlayer.com/rest/track-request",
34
+ json={
35
+ "function_name": "openai.ChatCompletion.create",
36
+ "kwargs": new_kwargs,
37
+ "tags": ["hello", "world"],
38
+ "request_response": dict(response_obj),
39
+ "request_start_time": int(start_time.timestamp()),
40
+ "request_end_time": int(end_time.timestamp()),
41
+ "api_key": self.key,
42
+ # Optional params for PromptLayer
43
+ # "prompt_id": "<PROMPT ID>",
44
+ # "prompt_input_variables": "<Dictionary of variables for prompt>",
45
+ # "prompt_version":1,
46
+ },
47
+ )
48
+ print_verbose(
49
+ f"Prompt Layer Logging: success - final response object: {request_response.text}"
50
+ )
51
+ response_json = request_response.json()
52
+ if "success" not in request_response.json():
53
+ raise Exception("Promptlayer did not successfully log the response!")
54
+
55
+ if "request_id" in response_json:
56
+ print(kwargs["litellm_params"]["metadata"])
57
+ if kwargs["litellm_params"]["metadata"] is not None:
58
+ response = requests.post(
59
+ "https://api.promptlayer.com/rest/track-metadata",
60
+ json={
61
+ "request_id": response_json["request_id"],
62
+ "api_key": self.key,
63
+ "metadata": kwargs["litellm_params"]["metadata"],
64
+ },
65
+ )
66
+ print_verbose(
67
+ f"Prompt Layer Logging: success - metadata post response object: {response.text}"
68
+ )
69
+
70
+ except:
71
+ print_verbose(f"error: Prompt Layer Error - {traceback.format_exc()}")
72
+ pass
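A hedged usage sketch: the logger above posts the request to /rest/track-request and, when a request_id comes back, forwards kwargs["litellm_params"]["metadata"] to /rest/track-metadata. The API key, model, and metadata values are placeholders; the success_callback registration follows litellm's callback pattern:

    import os
    import litellm

    os.environ["PROMPTLAYER_API_KEY"] = "<your-promptlayer-key>"  # placeholder
    litellm.success_callback = ["promptlayer"]

    litellm.completion(
        model="gpt-3.5-turbo",
        messages=[{"role": "user", "content": "hello"}],
        metadata={"environment": "dev"},  # forwarded to track-metadata above
    )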
litellm/integrations/s3.py ADDED
@@ -0,0 +1,150 @@
1
+ #### What this does ####
2
+ # On success + failure, log events to s3
3
+
4
+ import dotenv, os
5
+ import requests
6
+
7
+ dotenv.load_dotenv() # Loading env variables using dotenv
8
+ import traceback
9
+ import datetime, subprocess, sys
10
+ import litellm, uuid
11
+ from litellm._logging import print_verbose
12
+
13
+
14
+ class S3Logger:
15
+ # Class variables or attributes
16
+ def __init__(
17
+ self,
18
+ s3_bucket_name=None,
19
+ s3_region_name=None,
20
+ s3_api_version=None,
21
+ s3_use_ssl=True,
22
+ s3_verify=None,
23
+ s3_endpoint_url=None,
24
+ s3_aws_access_key_id=None,
25
+ s3_aws_secret_access_key=None,
26
+ s3_aws_session_token=None,
27
+ s3_config=None,
28
+ **kwargs,
29
+ ):
30
+ import boto3
31
+
32
+ try:
33
+ print_verbose("in init s3 logger")
34
+
35
+ if litellm.s3_callback_params is not None:
36
+ # read in .env variables - example os.environ/AWS_BUCKET_NAME
37
+ for key, value in litellm.s3_callback_params.items():
38
+ if type(value) is str and value.startswith("os.environ/"):
39
+ litellm.s3_callback_params[key] = litellm.get_secret(value)
40
+ # now set s3 params from litellm.s3_logger_params
41
+ s3_bucket_name = litellm.s3_callback_params.get("s3_bucket_name")
42
+ s3_region_name = litellm.s3_callback_params.get("s3_region_name")
43
+ s3_api_version = litellm.s3_callback_params.get("s3_api_version")
44
+ s3_use_ssl = litellm.s3_callback_params.get("s3_use_ssl")
45
+ s3_verify = litellm.s3_callback_params.get("s3_verify")
46
+ s3_endpoint_url = litellm.s3_callback_params.get("s3_endpoint_url")
47
+ s3_aws_access_key_id = litellm.s3_callback_params.get(
48
+ "s3_aws_access_key_id"
49
+ )
50
+ s3_aws_secret_access_key = litellm.s3_callback_params.get(
51
+ "s3_aws_secret_access_key"
52
+ )
53
+ s3_aws_session_token = litellm.s3_callback_params.get(
54
+ "s3_aws_session_token"
55
+ )
56
+ s3_config = litellm.s3_callback_params.get("s3_config")
57
+ # done reading litellm.s3_callback_params
58
+
59
+ self.bucket_name = s3_bucket_name
60
+ # Create an S3 client with custom endpoint URL
61
+ self.s3_client = boto3.client(
62
+ "s3",
63
+ region_name=s3_region_name,
64
+ endpoint_url=s3_endpoint_url,
65
+ api_version=s3_api_version,
66
+ use_ssl=s3_use_ssl,
67
+ verify=s3_verify,
68
+ aws_access_key_id=s3_aws_access_key_id,
69
+ aws_secret_access_key=s3_aws_secret_access_key,
70
+ aws_session_token=s3_aws_session_token,
71
+ config=s3_config,
72
+ **kwargs,
73
+ )
74
+ except Exception as e:
75
+ print_verbose(f"Got exception on init s3 client {str(e)}")
76
+ raise e
77
+
78
+ async def _async_log_event(
79
+ self, kwargs, response_obj, start_time, end_time, print_verbose
80
+ ):
81
+ self.log_event(kwargs, response_obj, start_time, end_time, print_verbose)
82
+
83
+ def log_event(self, kwargs, response_obj, start_time, end_time, print_verbose):
84
+ try:
85
+ print_verbose(f"s3 Logging - Enters logging function for model {kwargs}")
86
+
87
+ # construct payload to send to s3
88
+ # follows the same params as langfuse.py
89
+ litellm_params = kwargs.get("litellm_params", {})
90
+ metadata = (
91
+ litellm_params.get("metadata", {}) or {}
92
+ ) # if litellm_params['metadata'] == None
93
+ messages = kwargs.get("messages")
94
+ optional_params = kwargs.get("optional_params", {})
95
+ call_type = kwargs.get("call_type", "litellm.completion")
96
+ cache_hit = kwargs.get("cache_hit", False)
97
+ usage = response_obj["usage"]
98
+ id = response_obj.get("id", str(uuid.uuid4()))
99
+
100
+ # Build the initial payload
101
+ payload = {
102
+ "id": id,
103
+ "call_type": call_type,
104
+ "cache_hit": cache_hit,
105
+ "startTime": start_time,
106
+ "endTime": end_time,
107
+ "model": kwargs.get("model", ""),
108
+ "user": kwargs.get("user", ""),
109
+ "modelParameters": optional_params,
110
+ "messages": messages,
111
+ "response": response_obj,
112
+ "usage": usage,
113
+ "metadata": metadata,
114
+ }
115
+
116
+ # Ensure everything in the payload is converted to str
117
+ for key, value in payload.items():
118
+ try:
119
+ payload[key] = str(value)
120
+ except:
121
+ # non blocking if it can't cast to a str
122
+ pass
123
+
124
+ s3_object_key = (
125
+ payload["id"] + "-time=" + str(start_time)
126
+ ) # we need the s3 key to include the time, so we log cache hits too
127
+
128
+ import json
129
+
130
+ payload = json.dumps(payload)
131
+
132
+ print_verbose(f"\ns3 Logger - Logging payload = {payload}")
133
+
134
+ response = self.s3_client.put_object(
135
+ Bucket=self.bucket_name,
136
+ Key=s3_object_key,
137
+ Body=payload,
138
+ ContentType="application/json",
139
+ ContentLanguage="en",
140
+ ContentDisposition=f'inline; filename="{s3_object_key}.json"',  # use the object key, not the leftover loop variable
141
+ )
142
+
143
+ print_verbose(f"Response from s3:{str(response)}")
144
+
145
+ print_verbose(f"s3 Layer Logging - final response object: {response_obj}")
146
+ return response
147
+ except Exception as e:
148
+ traceback.print_exc()
149
+ print_verbose(f"s3 Layer Error - {str(e)}\n{traceback.format_exc()}")
150
+ pass
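A hedged configuration sketch for the init path above: litellm.s3_callback_params is read on startup, and any value of the form "os.environ/NAME" is resolved through litellm.get_secret(). Bucket, region, and env var names are placeholders:

    import litellm

    litellm.success_callback = ["s3"]
    litellm.s3_callback_params = {
        "s3_bucket_name": "my-litellm-logs",        # placeholder bucket
        "s3_region_name": "us-west-2",              # placeholder region
        # "os.environ/..." values are resolved via litellm.get_secret() above
        "s3_aws_access_key_id": "os.environ/AWS_ACCESS_KEY_ID",
        "s3_aws_secret_access_key": "os.environ/AWS_SECRET_ACCESS_KEY",
    }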
litellm/integrations/supabase.py ADDED
@@ -0,0 +1,117 @@
1
+ #### What this does ####
2
+ # On success + failure, log events to Supabase
3
+
4
+ import dotenv, os
5
+ import requests
6
+
7
+ dotenv.load_dotenv() # Loading env variables using dotenv
8
+ import traceback
9
+ import datetime, subprocess, sys
10
+ import litellm
11
+
12
+
13
+ class Supabase:
14
+ # Class variables or attributes
15
+ supabase_table_name = "request_logs"
16
+
17
+ def __init__(self):
18
+ # Instance variables
19
+ self.supabase_url = os.getenv("SUPABASE_URL")
20
+ self.supabase_key = os.getenv("SUPABASE_KEY")
21
+ try:
22
+ import supabase
23
+ except ImportError:
24
+ subprocess.check_call([sys.executable, "-m", "pip", "install", "supabase"])
25
+ import supabase
26
+ self.supabase_client = supabase.create_client(
27
+ self.supabase_url, self.supabase_key
28
+ )
29
+
30
+ def input_log_event(
31
+ self, model, messages, end_user, litellm_call_id, print_verbose
32
+ ):
33
+ try:
34
+ print_verbose(
35
+ f"Supabase Logging - Enters input logging function for model {model}"
36
+ )
37
+ supabase_data_obj = {
38
+ "model": model,
39
+ "messages": messages,
40
+ "end_user": end_user,
41
+ "status": "initiated",
42
+ "litellm_call_id": litellm_call_id,
43
+ }
44
+ data, count = (
45
+ self.supabase_client.table(self.supabase_table_name)
46
+ .insert(supabase_data_obj)
47
+ .execute()
48
+ )
49
+ print_verbose(f"data: {data}")
50
+ except:
51
+ print_verbose(f"Supabase Logging Error - {traceback.format_exc()}")
52
+ pass
53
+
54
+ def log_event(
55
+ self,
56
+ model,
57
+ messages,
58
+ end_user,
59
+ response_obj,
60
+ start_time,
61
+ end_time,
62
+ litellm_call_id,
63
+ print_verbose,
64
+ ):
65
+ try:
66
+ print_verbose(
67
+ f"Supabase Logging - Enters logging function for model {model}, response_obj: {response_obj}"
68
+ )
69
+
70
+ total_cost = litellm.completion_cost(completion_response=response_obj)
71
+
72
+ response_time = (end_time - start_time).total_seconds()
73
+ if "choices" in response_obj:
74
+ supabase_data_obj = {
75
+ "response_time": response_time,
76
+ "model": response_obj["model"],
77
+ "total_cost": total_cost,
78
+ "messages": messages,
79
+ "response": response_obj["choices"][0]["message"]["content"],
80
+ "end_user": end_user,
81
+ "litellm_call_id": litellm_call_id,
82
+ "status": "success",
83
+ }
84
+ print_verbose(
85
+ f"Supabase Logging - final data object: {supabase_data_obj}"
86
+ )
87
+ data, count = (
88
+ self.supabase_client.table(self.supabase_table_name)
89
+ .upsert(supabase_data_obj, on_conflict="litellm_call_id")
90
+ .execute()
91
+ )
92
+ elif "error" in response_obj:
93
+ if "Unable to map your input to a model." in response_obj["error"]:
94
+ total_cost = 0
95
+ supabase_data_obj = {
96
+ "response_time": response_time,
97
+ "model": response_obj["model"],
98
+ "total_cost": total_cost,
99
+ "messages": messages,
100
+ "error": response_obj["error"],
101
+ "end_user": end_user,
102
+ "litellm_call_id": litellm_call_id,
103
+ "status": "failure",
104
+ }
105
+ print_verbose(
106
+ f"Supabase Logging - final data object: {supabase_data_obj}"
107
+ )
108
+ data, count = (
109
+ self.supabase_client.table(self.supabase_table_name)
110
+ .upsert(supabase_data_obj, on_conflict="litellm_call_id")
111
+ .execute()
112
+ )
113
+
114
+ except:
115
+ # traceback.print_exc()
116
+ print_verbose(f"Supabase Logging Error - {traceback.format_exc()}")
117
+ pass
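A minimal setup sketch for the logger above, which reads SUPABASE_URL / SUPABASE_KEY and upserts rows into the request_logs table; the URL and key are placeholders:

    import os
    import litellm

    os.environ["SUPABASE_URL"] = "https://<project>.supabase.co"  # placeholder
    os.environ["SUPABASE_KEY"] = "<service-role-key>"             # placeholder
    litellm.success_callback = ["supabase"]   # rows land in the "request_logs" table above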
litellm/integrations/traceloop.py ADDED
@@ -0,0 +1,114 @@
1
+ class TraceloopLogger:
2
+ def __init__(self):
3
+ from traceloop.sdk.tracing.tracing import TracerWrapper
4
+ from traceloop.sdk import Traceloop
5
+
6
+ Traceloop.init(app_name="Litellm-Server", disable_batch=True)
7
+ self.tracer_wrapper = TracerWrapper()
8
+
9
+ def log_event(self, kwargs, response_obj, start_time, end_time, print_verbose):
10
+ from opentelemetry.trace import SpanKind
11
+ from opentelemetry.semconv.ai import SpanAttributes
12
+
13
+ try:
14
+ tracer = self.tracer_wrapper.get_tracer()
15
+
16
+ model = kwargs.get("model")
17
+
18
+ # LiteLLM uses the standard OpenAI library, so it's already handled by Traceloop SDK
19
+ if kwargs.get("litellm_params").get("custom_llm_provider") == "openai":
20
+ return
21
+
22
+ optional_params = kwargs.get("optional_params", {})
23
+ with tracer.start_as_current_span(
24
+ "litellm.completion",
25
+ kind=SpanKind.CLIENT,
26
+ ) as span:
27
+ if span.is_recording():
28
+ span.set_attribute(
29
+ SpanAttributes.LLM_REQUEST_MODEL, kwargs.get("model")
30
+ )
31
+ if "stop" in optional_params:
32
+ span.set_attribute(
33
+ SpanAttributes.LLM_CHAT_STOP_SEQUENCES,
34
+ optional_params.get("stop"),
35
+ )
36
+ if "frequency_penalty" in optional_params:
37
+ span.set_attribute(
38
+ SpanAttributes.LLM_FREQUENCY_PENALTY,
39
+ optional_params.get("frequency_penalty"),
40
+ )
41
+ if "presence_penalty" in optional_params:
42
+ span.set_attribute(
43
+ SpanAttributes.LLM_PRESENCE_PENALTY,
44
+ optional_params.get("presence_penalty"),
45
+ )
46
+ if "top_p" in optional_params:
47
+ span.set_attribute(
48
+ SpanAttributes.LLM_TOP_P, optional_params.get("top_p")
49
+ )
50
+ if "tools" in optional_params or "functions" in optional_params:
51
+ span.set_attribute(
52
+ SpanAttributes.LLM_REQUEST_FUNCTIONS,
53
+ optional_params.get(
54
+ "tools", optional_params.get("functions")
55
+ ),
56
+ )
57
+ if "user" in optional_params:
58
+ span.set_attribute(
59
+ SpanAttributes.LLM_USER, optional_params.get("user")
60
+ )
61
+ if "max_tokens" in optional_params:
62
+ span.set_attribute(
63
+ SpanAttributes.LLM_REQUEST_MAX_TOKENS,
64
+ optional_params.get("max_tokens"),  # read from optional_params, which the check above inspects
65
+ )
66
+ if "temperature" in optional_params:
67
+ span.set_attribute(
68
+ SpanAttributes.LLM_TEMPERATURE, optional_params.get("temperature")
69
+ )
70
+
71
+ for idx, prompt in enumerate(kwargs.get("messages")):
72
+ span.set_attribute(
73
+ f"{SpanAttributes.LLM_PROMPTS}.{idx}.role",
74
+ prompt.get("role"),
75
+ )
76
+ span.set_attribute(
77
+ f"{SpanAttributes.LLM_PROMPTS}.{idx}.content",
78
+ prompt.get("content"),
79
+ )
80
+
81
+ span.set_attribute(
82
+ SpanAttributes.LLM_RESPONSE_MODEL, response_obj.get("model")
83
+ )
84
+ usage = response_obj.get("usage")
85
+ if usage:
86
+ span.set_attribute(
87
+ SpanAttributes.LLM_USAGE_TOTAL_TOKENS,
88
+ usage.get("total_tokens"),
89
+ )
90
+ span.set_attribute(
91
+ SpanAttributes.LLM_USAGE_COMPLETION_TOKENS,
92
+ usage.get("completion_tokens"),
93
+ )
94
+ span.set_attribute(
95
+ SpanAttributes.LLM_USAGE_PROMPT_TOKENS,
96
+ usage.get("prompt_tokens"),
97
+ )
98
+
99
+ for idx, choice in enumerate(response_obj.get("choices")):
100
+ span.set_attribute(
101
+ f"{SpanAttributes.LLM_COMPLETIONS}.{idx}.finish_reason",
102
+ choice.get("finish_reason"),
103
+ )
104
+ span.set_attribute(
105
+ f"{SpanAttributes.LLM_COMPLETIONS}.{idx}.role",
106
+ choice.get("message").get("role"),
107
+ )
108
+ span.set_attribute(
109
+ f"{SpanAttributes.LLM_COMPLETIONS}.{idx}.content",
110
+ choice.get("message").get("content"),
111
+ )
112
+
113
+ except Exception as e:
114
+ print_verbose(f"Traceloop Layer Error - {e}")
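A short registration sketch for the logger above; note the early return in log_event(): OpenAI calls are skipped because the Traceloop SDK already instruments the OpenAI client directly:

    import litellm

    litellm.success_callback = ["traceloop"]  # instantiates TraceloopLogger above
    # Spans are only emitted for non-OpenAI providers, per the custom_llm_provider check above.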
litellm/integrations/weights_biases.py ADDED
@@ -0,0 +1,223 @@
1
+ imported_openAIResponse = True
2
+ try:
3
+ import io
4
+ import logging
5
+ import sys
6
+ from typing import Any, Dict, List, Optional, TypeVar
7
+
8
+ from wandb.sdk.data_types import trace_tree
9
+
10
+ if sys.version_info >= (3, 8):
11
+ from typing import Literal, Protocol
12
+ else:
13
+ from typing_extensions import Literal, Protocol
14
+
15
+ logger = logging.getLogger(__name__)
16
+
17
+ K = TypeVar("K", bound=str)
18
+ V = TypeVar("V")
19
+
20
+ class OpenAIResponse(Protocol[K, V]): # type: ignore
21
+ # contains a (known) object attribute
22
+ object: Literal["chat.completion", "edit", "text_completion"]
23
+
24
+ def __getitem__(self, key: K) -> V:
25
+ ... # pragma: no cover
26
+
27
+ def get(self, key: K, default: Optional[V] = None) -> Optional[V]:
28
+ ... # pragma: no cover
29
+
30
+ class OpenAIRequestResponseResolver:
31
+ def __call__(
32
+ self,
33
+ request: Dict[str, Any],
34
+ response: OpenAIResponse,
35
+ time_elapsed: float,
36
+ ) -> Optional[trace_tree.WBTraceTree]:
37
+ try:
38
+ if response["object"] == "edit":
39
+ return self._resolve_edit(request, response, time_elapsed)
40
+ elif response["object"] == "text_completion":
41
+ return self._resolve_completion(request, response, time_elapsed)
42
+ elif response["object"] == "chat.completion":
43
+ return self._resolve_chat_completion(
44
+ request, response, time_elapsed
45
+ )
46
+ else:
47
+ logger.info(f"Unknown OpenAI response object: {response['object']}")
48
+ except Exception as e:
49
+ logger.warning(f"Failed to resolve request/response: {e}")
50
+ return None
51
+
52
+ @staticmethod
53
+ def results_to_trace_tree(
54
+ request: Dict[str, Any],
55
+ response: OpenAIResponse,
56
+ results: List[trace_tree.Result],
57
+ time_elapsed: float,
58
+ ) -> trace_tree.WBTraceTree:
59
+ """Converts the request, response, and results into a trace tree.
60
+
61
+ params:
62
+ request: The request dictionary
63
+ response: The response object
64
+ results: A list of results object
65
+ time_elapsed: The time elapsed in seconds
66
+ returns:
67
+ A wandb trace tree object.
68
+ """
69
+ start_time_ms = int(round(response["created"] * 1000))
70
+ end_time_ms = start_time_ms + int(round(time_elapsed * 1000))
71
+ span = trace_tree.Span(
72
+ name=f"{response.get('model', 'openai')}_{response['object']}_{response.get('created')}",
73
+ attributes=dict(response), # type: ignore
74
+ start_time_ms=start_time_ms,
75
+ end_time_ms=end_time_ms,
76
+ span_kind=trace_tree.SpanKind.LLM,
77
+ results=results,
78
+ )
79
+ model_obj = {"request": request, "response": response, "_kind": "openai"}
80
+ return trace_tree.WBTraceTree(root_span=span, model_dict=model_obj)
81
+
82
+ def _resolve_edit(
83
+ self,
84
+ request: Dict[str, Any],
85
+ response: OpenAIResponse,
86
+ time_elapsed: float,
87
+ ) -> trace_tree.WBTraceTree:
88
+ """Resolves the request and response objects for `openai.Edit`."""
89
+ request_str = (
90
+ f"\n\n**Instruction**: {request['instruction']}\n\n"
91
+ f"**Input**: {request['input']}\n"
92
+ )
93
+ choices = [
94
+ f"\n\n**Edited**: {choice['text']}\n" for choice in response["choices"]
95
+ ]
96
+
97
+ return self._request_response_result_to_trace(
98
+ request=request,
99
+ response=response,
100
+ request_str=request_str,
101
+ choices=choices,
102
+ time_elapsed=time_elapsed,
103
+ )
104
+
105
+ def _resolve_completion(
106
+ self,
107
+ request: Dict[str, Any],
108
+ response: OpenAIResponse,
109
+ time_elapsed: float,
110
+ ) -> trace_tree.WBTraceTree:
111
+ """Resolves the request and response objects for `openai.Completion`."""
112
+ request_str = f"\n\n**Prompt**: {request['prompt']}\n"
113
+ choices = [
114
+ f"\n\n**Completion**: {choice['text']}\n"
115
+ for choice in response["choices"]
116
+ ]
117
+
118
+ return self._request_response_result_to_trace(
119
+ request=request,
120
+ response=response,
121
+ request_str=request_str,
122
+ choices=choices,
123
+ time_elapsed=time_elapsed,
124
+ )
125
+
126
+ def _resolve_chat_completion(
127
+ self,
128
+ request: Dict[str, Any],
129
+ response: OpenAIResponse,
130
+ time_elapsed: float,
131
+ ) -> trace_tree.WBTraceTree:
132
+ """Resolves the request and response objects for `openai.Completion`."""
133
+ prompt = io.StringIO()
134
+ for message in request["messages"]:
135
+ prompt.write(f"\n\n**{message['role']}**: {message['content']}\n")
136
+ request_str = prompt.getvalue()
137
+
138
+ choices = [
139
+ f"\n\n**{choice['message']['role']}**: {choice['message']['content']}\n"
140
+ for choice in response["choices"]
141
+ ]
142
+
143
+ return self._request_response_result_to_trace(
144
+ request=request,
145
+ response=response,
146
+ request_str=request_str,
147
+ choices=choices,
148
+ time_elapsed=time_elapsed,
149
+ )
150
+
151
+ def _request_response_result_to_trace(
152
+ self,
153
+ request: Dict[str, Any],
154
+ response: OpenAIResponse,
155
+ request_str: str,
156
+ choices: List[str],
157
+ time_elapsed: float,
158
+ ) -> trace_tree.WBTraceTree:
159
+ """Resolves the request and response objects for `openai.Completion`."""
160
+ results = [
161
+ trace_tree.Result(
162
+ inputs={"request": request_str},
163
+ outputs={"response": choice},
164
+ )
165
+ for choice in choices
166
+ ]
167
+ trace = self.results_to_trace_tree(request, response, results, time_elapsed)
168
+ return trace
169
+
170
+ except:
171
+ imported_openAIResponse = False
172
+
173
+
174
+ #### What this does ####
175
+ # On success, logs events to Weights & Biases
176
+ import dotenv, os
177
+ import requests
178
+ import requests
179
+ from datetime import datetime
180
+
181
+ dotenv.load_dotenv() # Loading env variables using dotenv
182
+ import traceback
183
+
184
+
185
+ class WeightsBiasesLogger:
186
+ # Class variables or attributes
187
+ def __init__(self):
188
+ try:
189
+ import wandb
190
+ except:
191
+ raise Exception(
192
+ "\033[91m wandb not installed, try running 'pip install wandb' to fix this error\033[0m"
193
+ )
194
+ if imported_openAIResponse == False:
195
+ raise Exception(
196
+ "\033[91m wandb is installed, but its trace_tree / OpenAIResponse helpers could not be imported - try running 'pip install -U wandb' to fix this error\033[0m"
197
+ )
198
+ self.resolver = OpenAIRequestResponseResolver()
199
+
200
+ def log_event(self, kwargs, response_obj, start_time, end_time, print_verbose):
201
+ # Method definition
202
+ import wandb
203
+
204
+ try:
205
+ print_verbose(f"W&B Logging - Enters logging function for model {kwargs}")
206
+ run = wandb.init()
207
+ print_verbose(response_obj)
208
+
209
+ trace = self.resolver(
210
+ kwargs, response_obj, (end_time - start_time).total_seconds()
211
+ )
212
+
213
+ if trace is not None:
214
+ run.log({"trace": trace})
215
+
216
+ run.finish()
217
+ print_verbose(
218
+ f"W&B Logging Logging - final response object: {response_obj}"
219
+ )
220
+ except:
221
+ # traceback.print_exc()
222
+ print_verbose(f"W&B Logging Layer Error - {traceback.format_exc()}")
223
+ pass
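A worked sketch of the timestamp arithmetic in results_to_trace_tree() above: response["created"] is in seconds since the epoch and time_elapsed is in seconds, and both are converted to milliseconds for the W&B span. The numbers are made up:

    created = 1_700_000_000   # response["created"], seconds since epoch
    time_elapsed = 1.25       # seconds

    start_time_ms = int(round(created * 1000))                      # 1700000000000
    end_time_ms = start_time_ms + int(round(time_elapsed * 1000))   # start + 1250 ms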
litellm/llms/__init__.py ADDED
@@ -0,0 +1 @@
1
+ from . import *
litellm/llms/ai21.py ADDED
@@ -0,0 +1,212 @@
1
+ import os, types, traceback
2
+ import json
3
+ from enum import Enum
4
+ import requests
5
+ import time, httpx
6
+ from typing import Callable, Optional
7
+ from litellm.utils import ModelResponse, Choices, Message
8
+ import litellm
9
+
10
+
11
+ class AI21Error(Exception):
12
+ def __init__(self, status_code, message):
13
+ self.status_code = status_code
14
+ self.message = message
15
+ self.request = httpx.Request(
16
+ method="POST", url="https://api.ai21.com/studio/v1/"
17
+ )
18
+ self.response = httpx.Response(status_code=status_code, request=self.request)
19
+ super().__init__(
20
+ self.message
21
+ ) # Call the base class constructor with the parameters it needs
22
+
23
+
24
+ class AI21Config:
25
+ """
26
+ Reference: https://docs.ai21.com/reference/j2-complete-ref
27
+
28
+ The class `AI21Config` provides configuration for the AI21's API interface. Below are the parameters:
29
+
30
+ - `numResults` (int32): Number of completions to sample and return. Optional, default is 1. If the temperature is greater than 0 (non-greedy decoding), a value greater than 1 can be meaningful.
31
+
32
+ - `maxTokens` (int32): The maximum number of tokens to generate per result. Optional, default is 16. If no `stopSequences` are given, generation stops after producing `maxTokens`.
33
+
34
+ - `minTokens` (int32): The minimum number of tokens to generate per result. Optional, default is 0. If `stopSequences` are given, they are ignored until `minTokens` are generated.
35
+
36
+ - `temperature` (float): Modifies the distribution from which tokens are sampled. Optional, default is 0.7. A value of 0 essentially disables sampling and results in greedy decoding.
37
+
38
+ - `topP` (float): Used for sampling tokens from the corresponding top percentile of probability mass. Optional, default is 1. For instance, a value of 0.9 considers only tokens comprising the top 90% probability mass.
39
+
40
+ - `stopSequences` (array of strings): Stops decoding if any of the input strings is generated. Optional.
41
+
42
+ - `topKReturn` (int32): Range between 0 to 10, including both. Optional, default is 0. Specifies the top-K alternative tokens to return. A non-zero value includes the string representations and log-probabilities for each of the top-K alternatives at each position.
43
+
44
+ - `frequencyPenalty` (object): Placeholder for frequency penalty object.
45
+
46
+ - `presencePenalty` (object): Placeholder for presence penalty object.
47
+
48
+ - `countPenalty` (object): Placeholder for count penalty object.
49
+ """
50
+
51
+ numResults: Optional[int] = None
52
+ maxTokens: Optional[int] = None
53
+ minTokens: Optional[int] = None
54
+ temperature: Optional[float] = None
55
+ topP: Optional[float] = None
56
+ stopSequences: Optional[list] = None
57
+ topKReturn: Optional[int] = None
58
+ frequencePenalty: Optional[dict] = None
59
+ presencePenalty: Optional[dict] = None
60
+ countPenalty: Optional[dict] = None
61
+
62
+ def __init__(
63
+ self,
64
+ numResults: Optional[int] = None,
65
+ maxTokens: Optional[int] = None,
66
+ minTokens: Optional[int] = None,
67
+ temperature: Optional[float] = None,
68
+ topP: Optional[float] = None,
69
+ stopSequences: Optional[list] = None,
70
+ topKReturn: Optional[int] = None,
71
+ frequencePenalty: Optional[dict] = None,
72
+ presencePenalty: Optional[dict] = None,
73
+ countPenalty: Optional[dict] = None,
74
+ ) -> None:
75
+ locals_ = locals()
76
+ for key, value in locals_.items():
77
+ if key != "self" and value is not None:
78
+ setattr(self.__class__, key, value)
79
+
80
+ @classmethod
81
+ def get_config(cls):
82
+ return {
83
+ k: v
84
+ for k, v in cls.__dict__.items()
85
+ if not k.startswith("__")
86
+ and not isinstance(
87
+ v,
88
+ (
89
+ types.FunctionType,
90
+ types.BuiltinFunctionType,
91
+ classmethod,
92
+ staticmethod,
93
+ ),
94
+ )
95
+ and v is not None
96
+ }
97
+
98
+
99
+ def validate_environment(api_key):
100
+ if api_key is None:
101
+ raise ValueError(
102
+ "Missing AI21 API Key - A call is being made to ai21 but no key is set either in the environment variables or via params"
103
+ )
104
+ headers = {
105
+ "accept": "application/json",
106
+ "content-type": "application/json",
107
+ "Authorization": "Bearer " + api_key,
108
+ }
109
+ return headers
110
+
111
+
112
+ def completion(
113
+ model: str,
114
+ messages: list,
115
+ api_base: str,
116
+ model_response: ModelResponse,
117
+ print_verbose: Callable,
118
+ encoding,
119
+ api_key,
120
+ logging_obj,
121
+ optional_params=None,
122
+ litellm_params=None,
123
+ logger_fn=None,
124
+ ):
125
+ headers = validate_environment(api_key)
126
+ model = model
127
+ prompt = ""
128
+ for message in messages:
129
+ if "role" in message:
130
+ if message["role"] == "user":
131
+ prompt += f"{message['content']}"
132
+ else:
133
+ prompt += f"{message['content']}"
134
+ else:
135
+ prompt += f"{message['content']}"
136
+
137
+ ## Load Config
138
+ config = litellm.AI21Config.get_config()
139
+ for k, v in config.items():
140
+ if (
141
+ k not in optional_params
142
+ ): # completion(top_k=3) > ai21_config(top_k=3) <- allows for dynamic variables to be passed in
143
+ optional_params[k] = v
144
+
145
+ data = {
146
+ "prompt": prompt,
147
+ # "instruction": prompt, # some baseten models require the prompt to be passed in via the 'instruction' kwarg
148
+ **optional_params,
149
+ }
150
+
151
+ ## LOGGING
152
+ logging_obj.pre_call(
153
+ input=prompt,
154
+ api_key=api_key,
155
+ additional_args={"complete_input_dict": data},
156
+ )
157
+ ## COMPLETION CALL
158
+ response = requests.post(
159
+ api_base + model + "/complete", headers=headers, data=json.dumps(data)
160
+ )
161
+ if response.status_code != 200:
162
+ raise AI21Error(status_code=response.status_code, message=response.text)
163
+ if "stream" in optional_params and optional_params["stream"] == True:
164
+ return response.iter_lines()
165
+ else:
166
+ ## LOGGING
167
+ logging_obj.post_call(
168
+ input=prompt,
169
+ api_key=api_key,
170
+ original_response=response.text,
171
+ additional_args={"complete_input_dict": data},
172
+ )
173
+ ## RESPONSE OBJECT
174
+ completion_response = response.json()
175
+ try:
176
+ choices_list = []
177
+ for idx, item in enumerate(completion_response["completions"]):
178
+ if len(item["data"]["text"]) > 0:
179
+ message_obj = Message(content=item["data"]["text"])
180
+ else:
181
+ message_obj = Message(content=None)
182
+ choice_obj = Choices(
183
+ finish_reason=item["finishReason"]["reason"],
184
+ index=idx + 1,
185
+ message=message_obj,
186
+ )
187
+ choices_list.append(choice_obj)
188
+ model_response["choices"] = choices_list
189
+ except Exception as e:
190
+ raise AI21Error(
191
+ message=traceback.format_exc(), status_code=response.status_code
192
+ )
193
+
194
+ ## CALCULATING USAGE - baseten charges on time, not tokens - have some mapping of cost here.
195
+ prompt_tokens = len(encoding.encode(prompt))
196
+ completion_tokens = len(
197
+ encoding.encode(model_response["choices"][0]["message"].get("content"))
198
+ )
199
+
200
+ model_response["created"] = int(time.time())
201
+ model_response["model"] = model
202
+ model_response["usage"] = {
203
+ "prompt_tokens": prompt_tokens,
204
+ "completion_tokens": completion_tokens,
205
+ "total_tokens": prompt_tokens + completion_tokens,
206
+ }
207
+ return model_response
208
+
209
+
210
+ def embedding():
211
+ # logic for parsing in - calling - parsing out model embedding calls
212
+ pass
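A small sketch of the "## Load Config" merge in completion() above: values passed on the call itself win over module-level defaults set via litellm.AI21Config. The parameter values are arbitrary:

    import litellm

    litellm.AI21Config(maxTokens=256, temperature=0.3)   # module-level defaults

    optional_params = {"temperature": 0.9}               # value passed on the call itself
    for k, v in litellm.AI21Config.get_config().items():
        if k not in optional_params:                     # the call-level value wins
            optional_params[k] = v

    # optional_params is now {"temperature": 0.9, "maxTokens": 256}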
litellm/llms/aleph_alpha.py ADDED
@@ -0,0 +1,304 @@
1
+ import os, types
2
+ import json
3
+ from enum import Enum
4
+ import requests
5
+ import time
6
+ from typing import Callable, Optional
7
+ import litellm
8
+ from litellm.utils import ModelResponse, Choices, Message, Usage
9
+ import httpx
10
+
11
+
12
+ class AlephAlphaError(Exception):
13
+ def __init__(self, status_code, message):
14
+ self.status_code = status_code
15
+ self.message = message
16
+ self.request = httpx.Request(
17
+ method="POST", url="https://api.aleph-alpha.com/complete"
18
+ )
19
+ self.response = httpx.Response(status_code=status_code, request=self.request)
20
+ super().__init__(
21
+ self.message
22
+ ) # Call the base class constructor with the parameters it needs
23
+
24
+
25
+ class AlephAlphaConfig:
26
+ """
27
+ Reference: https://docs.aleph-alpha.com/api/complete/
28
+
29
+ The `AlephAlphaConfig` class represents the configuration for the Aleph Alpha API. Here are the properties:
30
+
31
+ - `maximum_tokens` (integer, required): The maximum number of tokens to be generated by the completion. The sum of input tokens and maximum tokens may not exceed 2048.
32
+
33
+ - `minimum_tokens` (integer, optional; default value: 0): Generate at least this number of tokens before an end-of-text token is generated.
34
+
35
+ - `echo` (boolean, optional; default value: false): Whether to echo the prompt in the completion.
36
+
37
+ - `temperature` (number, nullable; default value: 0): Adjusts how creatively the model generates outputs. Use combinations of temperature, top_k, and top_p sensibly.
38
+
39
+ - `top_k` (integer, nullable; default value: 0): Introduces randomness into token generation by considering the top k most likely options.
40
+
41
+ - `top_p` (number, nullable; default value: 0): Adds randomness by considering the smallest set of tokens whose cumulative probability exceeds top_p.
42
+
43
+ - `presence_penalty`, `frequency_penalty`, `sequence_penalty` (number, nullable; default value: 0): Various penalties that can reduce repetition.
44
+
45
+ - `sequence_penalty_min_length` (integer; default value: 2): Minimum number of tokens to be considered as a sequence.
46
+
47
+ - `repetition_penalties_include_prompt`, `repetition_penalties_include_completion`, `use_multiplicative_presence_penalty`,`use_multiplicative_frequency_penalty`,`use_multiplicative_sequence_penalty` (boolean, nullable; default value: false): Various settings that adjust how the repetition penalties are applied.
48
+
49
+ - `penalty_bias` (string, nullable): Text used in addition to the penalized tokens for repetition penalties.
50
+
51
+ - `penalty_exceptions` (string[], nullable): Strings that may be generated without penalty.
52
+
53
+ - `penalty_exceptions_include_stop_sequences` (boolean, nullable; default value: true): Include all stop_sequences in penalty_exceptions.
54
+
55
+ - `best_of` (integer, nullable; default value: 1): The number of completions will be generated on the server side.
56
+
57
+ - `n` (integer, nullable; default value: 1): The number of completions to return.
58
+
59
+ - `logit_bias` (object, nullable): Adjust the logit scores before sampling.
60
+
61
+ - `log_probs` (integer, nullable): Number of top log probabilities for each token generated.
62
+
63
+ - `stop_sequences` (string[], nullable): List of strings that will stop generation if they're generated.
64
+
65
+ - `tokens` (boolean, nullable; default value: false): Flag indicating whether individual tokens of the completion should be returned or not.
66
+
67
+ - `raw_completion` (boolean; default value: false): if True, the raw completion of the model will be returned.
68
+
69
+ - `disable_optimizations` (boolean, nullable; default value: false): Disables any applied optimizations to both your prompt and completion.
70
+
71
+ - `completion_bias_inclusion`, `completion_bias_exclusion` (string[], default value: []): Set of strings to bias the generation of tokens.
72
+
73
+ - `completion_bias_inclusion_first_token_only`, `completion_bias_exclusion_first_token_only` (boolean; default value: false): Consider only the first token for the completion_bias_inclusion/exclusion.
74
+
75
+ - `contextual_control_threshold` (number, nullable): Control over how similar tokens are controlled.
76
+
77
+ - `control_log_additive` (boolean; default value: true): Method of applying control to attention scores.
78
+ """
79
+
80
+ maximum_tokens: Optional[
81
+ int
82
+ ] = litellm.max_tokens # aleph alpha requires max tokens
83
+ minimum_tokens: Optional[int] = None
84
+ echo: Optional[bool] = None
85
+ temperature: Optional[int] = None
86
+ top_k: Optional[int] = None
87
+ top_p: Optional[int] = None
88
+ presence_penalty: Optional[int] = None
89
+ frequency_penalty: Optional[int] = None
90
+ sequence_penalty: Optional[int] = None
91
+ sequence_penalty_min_length: Optional[int] = None
92
+ repetition_penalties_include_prompt: Optional[bool] = None
93
+ repetition_penalties_include_completion: Optional[bool] = None
94
+ use_multiplicative_presence_penalty: Optional[bool] = None
95
+ use_multiplicative_frequency_penalty: Optional[bool] = None
96
+ use_multiplicative_sequence_penalty: Optional[bool] = None
97
+ penalty_bias: Optional[str] = None
98
+ penalty_exceptions_include_stop_sequences: Optional[bool] = None
99
+ best_of: Optional[int] = None
100
+ n: Optional[int] = None
101
+ logit_bias: Optional[dict] = None
102
+ log_probs: Optional[int] = None
103
+ stop_sequences: Optional[list] = None
104
+ tokens: Optional[bool] = None
105
+ raw_completion: Optional[bool] = None
106
+ disable_optimizations: Optional[bool] = None
107
+ completion_bias_inclusion: Optional[list] = None
108
+ completion_bias_exclusion: Optional[list] = None
109
+ completion_bias_inclusion_first_token_only: Optional[bool] = None
110
+ completion_bias_exclusion_first_token_only: Optional[bool] = None
111
+ contextual_control_threshold: Optional[int] = None
112
+ control_log_additive: Optional[bool] = None
113
+
114
+ def __init__(
115
+ self,
116
+ maximum_tokens: Optional[int] = None,
117
+ minimum_tokens: Optional[int] = None,
118
+ echo: Optional[bool] = None,
119
+ temperature: Optional[int] = None,
120
+ top_k: Optional[int] = None,
121
+ top_p: Optional[int] = None,
122
+ presence_penalty: Optional[int] = None,
123
+ frequency_penalty: Optional[int] = None,
124
+ sequence_penalty: Optional[int] = None,
125
+ sequence_penalty_min_length: Optional[int] = None,
126
+ repetition_penalties_include_prompt: Optional[bool] = None,
127
+ repetition_penalties_include_completion: Optional[bool] = None,
128
+ use_multiplicative_presence_penalty: Optional[bool] = None,
129
+ use_multiplicative_frequency_penalty: Optional[bool] = None,
130
+ use_multiplicative_sequence_penalty: Optional[bool] = None,
131
+ penalty_bias: Optional[str] = None,
132
+ penalty_exceptions_include_stop_sequences: Optional[bool] = None,
133
+ best_of: Optional[int] = None,
134
+ n: Optional[int] = None,
135
+ logit_bias: Optional[dict] = None,
136
+ log_probs: Optional[int] = None,
137
+ stop_sequences: Optional[list] = None,
138
+ tokens: Optional[bool] = None,
139
+ raw_completion: Optional[bool] = None,
140
+ disable_optimizations: Optional[bool] = None,
141
+ completion_bias_inclusion: Optional[list] = None,
142
+ completion_bias_exclusion: Optional[list] = None,
143
+         completion_bias_inclusion_first_token_only: Optional[bool] = None,
+         completion_bias_exclusion_first_token_only: Optional[bool] = None,
+         contextual_control_threshold: Optional[int] = None,
+         control_log_additive: Optional[bool] = None,
+     ) -> None:
+         locals_ = locals()
+         for key, value in locals_.items():
+             if key != "self" and value is not None:
+                 setattr(self.__class__, key, value)
+
+     @classmethod
+     def get_config(cls):
+         return {
+             k: v
+             for k, v in cls.__dict__.items()
+             if not k.startswith("__")
+             and not isinstance(
+                 v,
+                 (
+                     types.FunctionType,
+                     types.BuiltinFunctionType,
+                     classmethod,
+                     staticmethod,
+                 ),
+             )
+             and v is not None
+         }
+
+
+ def validate_environment(api_key):
+     headers = {
+         "accept": "application/json",
+         "content-type": "application/json",
+     }
+     if api_key:
+         headers["Authorization"] = f"Bearer {api_key}"
+     return headers
+
+
+ def completion(
+     model: str,
+     messages: list,
+     api_base: str,
+     model_response: ModelResponse,
+     print_verbose: Callable,
+     encoding,
+     api_key,
+     logging_obj,
+     optional_params=None,
+     litellm_params=None,
+     logger_fn=None,
+     default_max_tokens_to_sample=None,
+ ):
+     headers = validate_environment(api_key)
+
+     ## Load Config
+     config = litellm.AlephAlphaConfig.get_config()
+     for k, v in config.items():
+         if (
+             k not in optional_params
+         ):  # completion(top_k=3) > aleph_alpha_config(top_k=3) <- allows for dynamic variables to be passed in
+             optional_params[k] = v
+
+     completion_url = api_base
+     model = model
+     prompt = ""
+     if "control" in model:  # follow the ###Instruction / ###Response format
+         for idx, message in enumerate(messages):
+             if "role" in message:
+                 if (
+                     idx == 0
+                 ):  # set first message as instruction (required), let later user messages be input
+                     prompt += f"###Instruction: {message['content']}"
+                 else:
+                     if message["role"] == "system":
+                         prompt += f"###Instruction: {message['content']}"
+                     elif message["role"] == "user":
+                         prompt += f"###Input: {message['content']}"
+                     else:
+                         prompt += f"###Response: {message['content']}"
+             else:
+                 prompt += f"{message['content']}"
+     else:
+         prompt = " ".join(message["content"] for message in messages)
+     data = {
+         "model": model,
+         "prompt": prompt,
+         **optional_params,
+     }
+
+     ## LOGGING
+     logging_obj.pre_call(
+         input=prompt,
+         api_key=api_key,
+         additional_args={"complete_input_dict": data},
+     )
+     ## COMPLETION CALL
+     response = requests.post(
+         completion_url,
+         headers=headers,
+         data=json.dumps(data),
+         stream=optional_params["stream"] if "stream" in optional_params else False,
+     )
+     if "stream" in optional_params and optional_params["stream"] == True:
+         return response.iter_lines()
+     else:
+         ## LOGGING
+         logging_obj.post_call(
+             input=prompt,
+             api_key=api_key,
+             original_response=response.text,
+             additional_args={"complete_input_dict": data},
+         )
+         print_verbose(f"raw model_response: {response.text}")
+         ## RESPONSE OBJECT
+         completion_response = response.json()
+         if "error" in completion_response:
+             raise AlephAlphaError(
+                 message=completion_response["error"],
+                 status_code=response.status_code,
+             )
+         else:
+             try:
+                 choices_list = []
+                 for idx, item in enumerate(completion_response["completions"]):
+                     if len(item["completion"]) > 0:
+                         message_obj = Message(content=item["completion"])
+                     else:
+                         message_obj = Message(content=None)
+                     choice_obj = Choices(
+                         finish_reason=item["finish_reason"],
+                         index=idx + 1,
+                         message=message_obj,
+                     )
+                     choices_list.append(choice_obj)
+                 model_response["choices"] = choices_list
+             except:
+                 raise AlephAlphaError(
+                     message=json.dumps(completion_response),
+                     status_code=response.status_code,
+                 )
+
+         ## CALCULATING USAGE
+         prompt_tokens = len(encoding.encode(prompt))
+         completion_tokens = len(
+             encoding.encode(model_response["choices"][0]["message"]["content"])
+         )
+
+         model_response["created"] = int(time.time())
+         model_response["model"] = model
+         usage = Usage(
+             prompt_tokens=prompt_tokens,
+             completion_tokens=completion_tokens,
+             total_tokens=prompt_tokens + completion_tokens,
+         )
+         model_response.usage = usage
+         return model_response
+
+
+ def embedding():
+     # logic for parsing in - calling - parsing out model embedding calls
+     pass
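The handler above relies on the same config-merge convention as the other providers in this upload: `AlephAlphaConfig.get_config()` exposes any class-level defaults, and they are copied into `optional_params` only when the caller has not already supplied that key. A minimal, self-contained sketch of that behaviour (the class and parameter names below are illustrative, not part of litellm):

    import types
    from typing import Optional

    class ExampleProviderConfig:  # hypothetical stand-in for AlephAlphaConfig
        maximum_tokens: Optional[int] = 256   # a class-level default
        temperature: Optional[float] = None   # unset -> not exported

        @classmethod
        def get_config(cls):
            # keep plain, non-None, non-dunder class attributes, as above
            return {
                k: v
                for k, v in cls.__dict__.items()
                if not k.startswith("__")
                and not isinstance(v, (types.FunctionType, classmethod, staticmethod))
                and v is not None
            }

    optional_params = {"temperature": 0.1}        # what the caller passed
    for k, v in ExampleProviderConfig.get_config().items():
        if k not in optional_params:              # caller values always win
            optional_params[k] = v
    print(optional_params)                        # {'temperature': 0.1, 'maximum_tokens': 256}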
litellm/llms/anthropic.py ADDED
@@ -0,0 +1,215 @@
+ import os, types
+ import json
+ from enum import Enum
+ import requests
+ import time
+ from typing import Callable, Optional
+ from litellm.utils import ModelResponse, Usage
+ import litellm
+ from .prompt_templates.factory import prompt_factory, custom_prompt
+ import httpx
+
+
+ class AnthropicConstants(Enum):
+     HUMAN_PROMPT = "\n\nHuman: "
+     AI_PROMPT = "\n\nAssistant: "
+
+
+ class AnthropicError(Exception):
+     def __init__(self, status_code, message):
+         self.status_code = status_code
+         self.message = message
+         self.request = httpx.Request(
+             method="POST", url="https://api.anthropic.com/v1/complete"
+         )
+         self.response = httpx.Response(status_code=status_code, request=self.request)
+         super().__init__(
+             self.message
+         )  # Call the base class constructor with the parameters it needs
+
+
+ class AnthropicConfig:
+     """
+     Reference: https://docs.anthropic.com/claude/reference/complete_post
+
+     to pass metadata to anthropic, it's {"user_id": "any-relevant-information"}
+     """
+
+     max_tokens_to_sample: Optional[
+         int
+     ] = litellm.max_tokens  # anthropic requires a default
+     stop_sequences: Optional[list] = None
+     temperature: Optional[int] = None
+     top_p: Optional[int] = None
+     top_k: Optional[int] = None
+     metadata: Optional[dict] = None
+
+     def __init__(
+         self,
+         max_tokens_to_sample: Optional[int] = 256,  # anthropic requires a default
+         stop_sequences: Optional[list] = None,
+         temperature: Optional[int] = None,
+         top_p: Optional[int] = None,
+         top_k: Optional[int] = None,
+         metadata: Optional[dict] = None,
+     ) -> None:
+         locals_ = locals()
+         for key, value in locals_.items():
+             if key != "self" and value is not None:
+                 setattr(self.__class__, key, value)
+
+     @classmethod
+     def get_config(cls):
+         return {
+             k: v
+             for k, v in cls.__dict__.items()
+             if not k.startswith("__")
+             and not isinstance(
+                 v,
+                 (
+                     types.FunctionType,
+                     types.BuiltinFunctionType,
+                     classmethod,
+                     staticmethod,
+                 ),
+             )
+             and v is not None
+         }
+
+
+ # makes headers for API call
+ def validate_environment(api_key):
+     if api_key is None:
+         raise ValueError(
+             "Missing Anthropic API Key - A call is being made to anthropic but no key is set either in the environment variables or via params"
+         )
+     headers = {
+         "accept": "application/json",
+         "anthropic-version": "2023-06-01",
+         "content-type": "application/json",
+         "x-api-key": api_key,
+     }
+     return headers
+
+
+ def completion(
+     model: str,
+     messages: list,
+     api_base: str,
+     custom_prompt_dict: dict,
+     model_response: ModelResponse,
+     print_verbose: Callable,
+     encoding,
+     api_key,
+     logging_obj,
+     optional_params=None,
+     litellm_params=None,
+     logger_fn=None,
+ ):
+     headers = validate_environment(api_key)
+     if model in custom_prompt_dict:
+         # check if the model has a registered custom prompt
+         model_prompt_details = custom_prompt_dict[model]
+         prompt = custom_prompt(
+             role_dict=model_prompt_details["roles"],
+             initial_prompt_value=model_prompt_details["initial_prompt_value"],
+             final_prompt_value=model_prompt_details["final_prompt_value"],
+             messages=messages,
+         )
+     else:
+         prompt = prompt_factory(
+             model=model, messages=messages, custom_llm_provider="anthropic"
+         )
+
+     ## Load Config
+     config = litellm.AnthropicConfig.get_config()
+     for k, v in config.items():
+         if (
+             k not in optional_params
+         ):  # completion(top_k=3) > anthropic_config(top_k=3) <- allows for dynamic variables to be passed in
+             optional_params[k] = v
+
+     data = {
+         "model": model,
+         "prompt": prompt,
+         **optional_params,
+     }
+
+     ## LOGGING
+     logging_obj.pre_call(
+         input=prompt,
+         api_key=api_key,
+         additional_args={"complete_input_dict": data, "api_base": api_base},
+     )
+
+     ## COMPLETION CALL
+     if "stream" in optional_params and optional_params["stream"] == True:
+         response = requests.post(
+             api_base,
+             headers=headers,
+             data=json.dumps(data),
+             stream=optional_params["stream"],
+         )
+
+         if response.status_code != 200:
+             raise AnthropicError(
+                 status_code=response.status_code, message=response.text
+             )
+
+         return response.iter_lines()
+     else:
+         response = requests.post(api_base, headers=headers, data=json.dumps(data))
+         if response.status_code != 200:
+             raise AnthropicError(
+                 status_code=response.status_code, message=response.text
+             )
+
+         ## LOGGING
+         logging_obj.post_call(
+             input=prompt,
+             api_key=api_key,
+             original_response=response.text,
+             additional_args={"complete_input_dict": data},
+         )
+         print_verbose(f"raw model_response: {response.text}")
+         ## RESPONSE OBJECT
+         try:
+             completion_response = response.json()
+         except:
+             raise AnthropicError(
+                 message=response.text, status_code=response.status_code
+             )
+         if "error" in completion_response:
+             raise AnthropicError(
+                 message=str(completion_response["error"]),
+                 status_code=response.status_code,
+             )
+         else:
+             if len(completion_response["completion"]) > 0:
+                 model_response["choices"][0]["message"][
+                     "content"
+                 ] = completion_response["completion"]
+             model_response.choices[0].finish_reason = completion_response["stop_reason"]
+
+         ## CALCULATING USAGE
+         prompt_tokens = len(
+             encoding.encode(prompt)
+         )  ##[TODO] use the anthropic tokenizer here
+         completion_tokens = len(
+             encoding.encode(model_response["choices"][0]["message"].get("content", ""))
+         )  ##[TODO] use the anthropic tokenizer here
+
+         model_response["created"] = int(time.time())
+         model_response["model"] = model
+         usage = Usage(
+             prompt_tokens=prompt_tokens,
+             completion_tokens=completion_tokens,
+             total_tokens=prompt_tokens + completion_tokens,
+         )
+         model_response.usage = usage
+         return model_response
+
+
+ def embedding():
+     # logic for parsing in - calling - parsing out model embedding calls
+     pass
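For a non-streaming call, the handler above boils down to a single POST against the Anthropic text-completions endpoint. A rough sketch of the headers and body it assembles, assuming `prompt_factory` emits the classic Human/Assistant format (no request is actually sent here, and the key is a placeholder):

    import json

    HUMAN_PROMPT = "\n\nHuman: "     # mirrors AnthropicConstants above
    AI_PROMPT = "\n\nAssistant: "

    messages = [{"role": "user", "content": "Say hello"}]
    prompt = "".join(
        (HUMAN_PROMPT if m["role"] == "user" else AI_PROMPT) + m["content"]
        for m in messages
    ) + AI_PROMPT                    # Claude completes after the Assistant tag

    headers = {
        "accept": "application/json",
        "anthropic-version": "2023-06-01",
        "content-type": "application/json",
        "x-api-key": "sk-ant-...",   # placeholder
    }
    data = {"model": "claude-instant-1", "prompt": prompt, "max_tokens_to_sample": 256}
    print(json.dumps({"headers": list(headers), "body": data}, indent=2))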
litellm/llms/azure.py ADDED
@@ -0,0 +1,799 @@
1
+ from typing import Optional, Union, Any
2
+ import types, requests
3
+ from .base import BaseLLM
4
+ from litellm.utils import (
5
+ ModelResponse,
6
+ Choices,
7
+ Message,
8
+ CustomStreamWrapper,
9
+ convert_to_model_response_object,
10
+ )
11
+ from typing import Callable, Optional
12
+ from litellm import OpenAIConfig
13
+ import litellm, json
14
+ import httpx
15
+ from .custom_httpx.azure_dall_e_2 import CustomHTTPTransport, AsyncCustomHTTPTransport
16
+ from openai import AzureOpenAI, AsyncAzureOpenAI
17
+
18
+
19
+ class AzureOpenAIError(Exception):
20
+ def __init__(
21
+ self,
22
+ status_code,
23
+ message,
24
+ request: Optional[httpx.Request] = None,
25
+ response: Optional[httpx.Response] = None,
26
+ ):
27
+ self.status_code = status_code
28
+ self.message = message
29
+ if request:
30
+ self.request = request
31
+ else:
32
+ self.request = httpx.Request(method="POST", url="https://api.openai.com/v1")
33
+ if response:
34
+ self.response = response
35
+ else:
36
+ self.response = httpx.Response(
37
+ status_code=status_code, request=self.request
38
+ )
39
+ super().__init__(
40
+ self.message
41
+ ) # Call the base class constructor with the parameters it needs
42
+
43
+
44
+ class AzureOpenAIConfig(OpenAIConfig):
45
+ """
46
+ Reference: https://platform.openai.com/docs/api-reference/chat/create
47
+
48
+ The class `AzureOpenAIConfig` provides configuration for the OpenAI's Chat API interface, for use with Azure. It inherits from `OpenAIConfig`. Below are the parameters::
49
+
50
+ - `frequency_penalty` (number or null): Defaults to 0. Allows a value between -2.0 and 2.0. Positive values penalize new tokens based on their existing frequency in the text so far, thereby minimizing repetition.
51
+
52
+ - `function_call` (string or object): This optional parameter controls how the model calls functions.
53
+
54
+ - `functions` (array): An optional parameter. It is a list of functions for which the model may generate JSON inputs.
55
+
56
+ - `logit_bias` (map): This optional parameter modifies the likelihood of specified tokens appearing in the completion.
57
+
58
+ - `max_tokens` (integer or null): This optional parameter helps to set the maximum number of tokens to generate in the chat completion.
59
+
60
+ - `n` (integer or null): This optional parameter helps to set how many chat completion choices to generate for each input message.
61
+
62
+ - `presence_penalty` (number or null): Defaults to 0. It penalizes new tokens based on if they appear in the text so far, hence increasing the model's likelihood to talk about new topics.
63
+
64
+ - `stop` (string / array / null): Specifies up to 4 sequences where the API will stop generating further tokens.
65
+
66
+ - `temperature` (number or null): Defines the sampling temperature to use, varying between 0 and 2.
67
+
68
+ - `top_p` (number or null): An alternative to sampling with temperature, used for nucleus sampling.
69
+ """
70
+
71
+ def __init__(
72
+ self,
73
+ frequency_penalty: Optional[int] = None,
74
+ function_call: Optional[Union[str, dict]] = None,
75
+ functions: Optional[list] = None,
76
+ logit_bias: Optional[dict] = None,
77
+ max_tokens: Optional[int] = None,
78
+ n: Optional[int] = None,
79
+ presence_penalty: Optional[int] = None,
80
+ stop: Optional[Union[str, list]] = None,
81
+ temperature: Optional[int] = None,
82
+ top_p: Optional[int] = None,
83
+ ) -> None:
84
+ super().__init__(
85
+ frequency_penalty,
86
+ function_call,
87
+ functions,
88
+ logit_bias,
89
+ max_tokens,
90
+ n,
91
+ presence_penalty,
92
+ stop,
93
+ temperature,
94
+ top_p,
95
+ )
96
+
97
+
98
+ class AzureChatCompletion(BaseLLM):
99
+ def __init__(self) -> None:
100
+ super().__init__()
101
+
102
+ def validate_environment(self, api_key, azure_ad_token):
103
+ headers = {
104
+ "content-type": "application/json",
105
+ }
106
+ if api_key is not None:
107
+ headers["api-key"] = api_key
108
+ elif azure_ad_token is not None:
109
+ headers["Authorization"] = f"Bearer {azure_ad_token}"
110
+ return headers
111
+
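# Illustration, not part of the commit: the two header shapes the helper above
# can return, depending on which credential is supplied (values are placeholders).
#   AzureChatCompletion().validate_environment(api_key="AZURE_KEY", azure_ad_token=None)
#       -> {"content-type": "application/json", "api-key": "AZURE_KEY"}
#   AzureChatCompletion().validate_environment(api_key=None, azure_ad_token="AAD_TOKEN")
#       -> {"content-type": "application/json", "Authorization": "Bearer AAD_TOKEN"}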
112
+ def completion(
113
+ self,
114
+ model: str,
115
+ messages: list,
116
+ model_response: ModelResponse,
117
+ api_key: str,
118
+ api_base: str,
119
+ api_version: str,
120
+ api_type: str,
121
+ azure_ad_token: str,
122
+ print_verbose: Callable,
123
+ timeout,
124
+ logging_obj,
125
+ optional_params,
126
+ litellm_params,
127
+ logger_fn,
128
+ acompletion: bool = False,
129
+ headers: Optional[dict] = None,
130
+ client=None,
131
+ ):
132
+ super().completion()
133
+ exception_mapping_worked = False
134
+ try:
135
+ if model is None or messages is None:
136
+ raise AzureOpenAIError(
137
+ status_code=422, message=f"Missing model or messages"
138
+ )
139
+
140
+ max_retries = optional_params.pop("max_retries", 2)
141
+
142
+ ### CHECK IF CLOUDFLARE AI GATEWAY ###
143
+ ### if so - set the model as part of the base url
144
+ if "gateway.ai.cloudflare.com" in api_base:
145
+ ## build base url - assume api base includes resource name
146
+ if client is None:
147
+ if not api_base.endswith("/"):
148
+ api_base += "/"
149
+ api_base += f"{model}"
150
+
151
+ azure_client_params = {
152
+ "api_version": api_version,
153
+ "base_url": f"{api_base}",
154
+ "http_client": litellm.client_session,
155
+ "max_retries": max_retries,
156
+ "timeout": timeout,
157
+ }
158
+ if api_key is not None:
159
+ azure_client_params["api_key"] = api_key
160
+ elif azure_ad_token is not None:
161
+ azure_client_params["azure_ad_token"] = azure_ad_token
162
+
163
+ if acompletion is True:
164
+ client = AsyncAzureOpenAI(**azure_client_params)
165
+ else:
166
+ client = AzureOpenAI(**azure_client_params)
167
+
168
+ data = {"model": None, "messages": messages, **optional_params}
169
+ else:
170
+ data = {
171
+ "model": model, # type: ignore
172
+ "messages": messages,
173
+ **optional_params,
174
+ }
175
+
176
+ if acompletion is True:
177
+ if optional_params.get("stream", False):
178
+ return self.async_streaming(
179
+ logging_obj=logging_obj,
180
+ api_base=api_base,
181
+ data=data,
182
+ model=model,
183
+ api_key=api_key,
184
+ api_version=api_version,
185
+ azure_ad_token=azure_ad_token,
186
+ timeout=timeout,
187
+ client=client,
188
+ )
189
+ else:
190
+ return self.acompletion(
191
+ api_base=api_base,
192
+ data=data,
193
+ model_response=model_response,
194
+ api_key=api_key,
195
+ api_version=api_version,
196
+ model=model,
197
+ azure_ad_token=azure_ad_token,
198
+ timeout=timeout,
199
+ client=client,
200
+ logging_obj=logging_obj,
201
+ )
202
+ elif "stream" in optional_params and optional_params["stream"] == True:
203
+ return self.streaming(
204
+ logging_obj=logging_obj,
205
+ api_base=api_base,
206
+ data=data,
207
+ model=model,
208
+ api_key=api_key,
209
+ api_version=api_version,
210
+ azure_ad_token=azure_ad_token,
211
+ timeout=timeout,
212
+ client=client,
213
+ )
214
+ else:
215
+ ## LOGGING
216
+ logging_obj.pre_call(
217
+ input=messages,
218
+ api_key=api_key,
219
+ additional_args={
220
+ "headers": {
221
+ "api_key": api_key,
222
+ "azure_ad_token": azure_ad_token,
223
+ },
224
+ "api_version": api_version,
225
+ "api_base": api_base,
226
+ "complete_input_dict": data,
227
+ },
228
+ )
229
+ if not isinstance(max_retries, int):
230
+ raise AzureOpenAIError(
231
+ status_code=422, message="max retries must be an int"
232
+ )
233
+ # init AzureOpenAI Client
234
+ azure_client_params = {
235
+ "api_version": api_version,
236
+ "azure_endpoint": api_base,
237
+ "azure_deployment": model,
238
+ "http_client": litellm.client_session,
239
+ "max_retries": max_retries,
240
+ "timeout": timeout,
241
+ }
242
+ if api_key is not None:
243
+ azure_client_params["api_key"] = api_key
244
+ elif azure_ad_token is not None:
245
+ azure_client_params["azure_ad_token"] = azure_ad_token
246
+ if client is None:
247
+ azure_client = AzureOpenAI(**azure_client_params)
248
+ else:
249
+ azure_client = client
250
+ response = azure_client.chat.completions.create(**data, timeout=timeout) # type: ignore
251
+ stringified_response = response.model_dump()
252
+ ## LOGGING
253
+ logging_obj.post_call(
254
+ input=messages,
255
+ api_key=api_key,
256
+ original_response=stringified_response,
257
+ additional_args={
258
+ "headers": headers,
259
+ "api_version": api_version,
260
+ "api_base": api_base,
261
+ },
262
+ )
263
+ return convert_to_model_response_object(
264
+ response_object=stringified_response,
265
+ model_response_object=model_response,
266
+ )
267
+ except AzureOpenAIError as e:
268
+ exception_mapping_worked = True
269
+ raise e
270
+ except Exception as e:
271
+ if hasattr(e, "status_code"):
272
+ raise AzureOpenAIError(status_code=e.status_code, message=str(e))
273
+ else:
274
+ raise AzureOpenAIError(status_code=500, message=str(e))
275
+
276
+ async def acompletion(
277
+ self,
278
+ api_key: str,
279
+ api_version: str,
280
+ model: str,
281
+ api_base: str,
282
+ data: dict,
283
+ timeout: Any,
284
+ model_response: ModelResponse,
285
+ azure_ad_token: Optional[str] = None,
286
+ client=None, # this is the AsyncAzureOpenAI
287
+ logging_obj=None,
288
+ ):
289
+ response = None
290
+ try:
291
+ max_retries = data.pop("max_retries", 2)
292
+ if not isinstance(max_retries, int):
293
+ raise AzureOpenAIError(
294
+ status_code=422, message="max retries must be an int"
295
+ )
296
+
297
+ # init AzureOpenAI Client
298
+ azure_client_params = {
299
+ "api_version": api_version,
300
+ "azure_endpoint": api_base,
301
+ "azure_deployment": model,
302
+ "http_client": litellm.client_session,
303
+ "max_retries": max_retries,
304
+ "timeout": timeout,
305
+ }
306
+ if api_key is not None:
307
+ azure_client_params["api_key"] = api_key
308
+ elif azure_ad_token is not None:
309
+ azure_client_params["azure_ad_token"] = azure_ad_token
310
+ if client is None:
311
+ azure_client = AsyncAzureOpenAI(**azure_client_params)
312
+ else:
313
+ azure_client = client
314
+ ## LOGGING
315
+ logging_obj.pre_call(
316
+ input=data["messages"],
317
+ api_key=azure_client.api_key,
318
+ additional_args={
319
+ "headers": {"Authorization": f"Bearer {azure_client.api_key}"},
320
+ "api_base": azure_client._base_url._uri_reference,
321
+ "acompletion": True,
322
+ "complete_input_dict": data,
323
+ },
324
+ )
325
+ response = await azure_client.chat.completions.create(
326
+ **data, timeout=timeout
327
+ )
328
+ return convert_to_model_response_object(
329
+ response_object=response.model_dump(),
330
+ model_response_object=model_response,
331
+ )
332
+ except AzureOpenAIError as e:
333
+ exception_mapping_worked = True
334
+ raise e
335
+ except Exception as e:
336
+ if hasattr(e, "status_code"):
337
+ raise e
338
+ else:
339
+ raise AzureOpenAIError(status_code=500, message=str(e))
340
+
341
+ def streaming(
342
+ self,
343
+ logging_obj,
344
+ api_base: str,
345
+ api_key: str,
346
+ api_version: str,
347
+ data: dict,
348
+ model: str,
349
+ timeout: Any,
350
+ azure_ad_token: Optional[str] = None,
351
+ client=None,
352
+ ):
353
+ max_retries = data.pop("max_retries", 2)
354
+ if not isinstance(max_retries, int):
355
+ raise AzureOpenAIError(
356
+ status_code=422, message="max retries must be an int"
357
+ )
358
+ # init AzureOpenAI Client
359
+ azure_client_params = {
360
+ "api_version": api_version,
361
+ "azure_endpoint": api_base,
362
+ "azure_deployment": model,
363
+ "http_client": litellm.client_session,
364
+ "max_retries": max_retries,
365
+ "timeout": timeout,
366
+ }
367
+ if api_key is not None:
368
+ azure_client_params["api_key"] = api_key
369
+ elif azure_ad_token is not None:
370
+ azure_client_params["azure_ad_token"] = azure_ad_token
371
+ if client is None:
372
+ azure_client = AzureOpenAI(**azure_client_params)
373
+ else:
374
+ azure_client = client
375
+ ## LOGGING
376
+ logging_obj.pre_call(
377
+ input=data["messages"],
378
+ api_key=azure_client.api_key,
379
+ additional_args={
380
+ "headers": {"Authorization": f"Bearer {azure_client.api_key}"},
381
+ "api_base": azure_client._base_url._uri_reference,
382
+ "acompletion": True,
383
+ "complete_input_dict": data,
384
+ },
385
+ )
386
+ response = azure_client.chat.completions.create(**data, timeout=timeout)
387
+ streamwrapper = CustomStreamWrapper(
388
+ completion_stream=response,
389
+ model=model,
390
+ custom_llm_provider="azure",
391
+ logging_obj=logging_obj,
392
+ )
393
+ return streamwrapper
394
+
395
+ async def async_streaming(
396
+ self,
397
+ logging_obj,
398
+ api_base: str,
399
+ api_key: str,
400
+ api_version: str,
401
+ data: dict,
402
+ model: str,
403
+ timeout: Any,
404
+ azure_ad_token: Optional[str] = None,
405
+ client=None,
406
+ ):
407
+ try:
408
+ # init AzureOpenAI Client
409
+ azure_client_params = {
410
+ "api_version": api_version,
411
+ "azure_endpoint": api_base,
412
+ "azure_deployment": model,
413
+ "http_client": litellm.client_session,
414
+ "max_retries": data.pop("max_retries", 2),
415
+ "timeout": timeout,
416
+ }
417
+ if api_key is not None:
418
+ azure_client_params["api_key"] = api_key
419
+ elif azure_ad_token is not None:
420
+ azure_client_params["azure_ad_token"] = azure_ad_token
421
+ if client is None:
422
+ azure_client = AsyncAzureOpenAI(**azure_client_params)
423
+ else:
424
+ azure_client = client
425
+ ## LOGGING
426
+ logging_obj.pre_call(
427
+ input=data["messages"],
428
+ api_key=azure_client.api_key,
429
+ additional_args={
430
+ "headers": {"Authorization": f"Bearer {azure_client.api_key}"},
431
+ "api_base": azure_client._base_url._uri_reference,
432
+ "acompletion": True,
433
+ "complete_input_dict": data,
434
+ },
435
+ )
436
+ response = await azure_client.chat.completions.create(
437
+ **data, timeout=timeout
438
+ )
439
+ # return response
440
+ streamwrapper = CustomStreamWrapper(
441
+ completion_stream=response,
442
+ model=model,
443
+ custom_llm_provider="azure",
444
+ logging_obj=logging_obj,
445
+ )
446
+ return streamwrapper ## DO NOT make this into an async for ... loop, it will yield an async generator, which won't raise errors if the response fails
447
+ except Exception as e:
448
+ if hasattr(e, "status_code"):
449
+ raise AzureOpenAIError(status_code=e.status_code, message=str(e))
450
+ else:
451
+ raise AzureOpenAIError(status_code=500, message=str(e))
452
+
453
+ async def aembedding(
454
+ self,
455
+ data: dict,
456
+ model_response: ModelResponse,
457
+ azure_client_params: dict,
458
+ api_key: str,
459
+ input: list,
460
+ client=None,
461
+ logging_obj=None,
462
+ timeout=None,
463
+ ):
464
+ response = None
465
+ try:
466
+ if client is None:
467
+ openai_aclient = AsyncAzureOpenAI(**azure_client_params)
468
+ else:
469
+ openai_aclient = client
470
+ response = await openai_aclient.embeddings.create(**data, timeout=timeout)
471
+ stringified_response = response.model_dump()
472
+ ## LOGGING
473
+ logging_obj.post_call(
474
+ input=input,
475
+ api_key=api_key,
476
+ additional_args={"complete_input_dict": data},
477
+ original_response=stringified_response,
478
+ )
479
+ return convert_to_model_response_object(
480
+ response_object=stringified_response,
481
+ model_response_object=model_response,
482
+ response_type="embedding",
483
+ )
484
+ except Exception as e:
485
+ ## LOGGING
486
+ logging_obj.post_call(
487
+ input=input,
488
+ api_key=api_key,
489
+ additional_args={"complete_input_dict": data},
490
+ original_response=str(e),
491
+ )
492
+ raise e
493
+
494
+ def embedding(
495
+ self,
496
+ model: str,
497
+ input: list,
498
+ api_key: str,
499
+ api_base: str,
500
+ api_version: str,
501
+ timeout: float,
502
+ logging_obj=None,
503
+ model_response=None,
504
+ optional_params=None,
505
+ azure_ad_token: Optional[str] = None,
506
+ client=None,
507
+ aembedding=None,
508
+ ):
509
+ super().embedding()
510
+ exception_mapping_worked = False
511
+ if self._client_session is None:
512
+ self._client_session = self.create_client_session()
513
+ try:
514
+ data = {"model": model, "input": input, **optional_params}
515
+ max_retries = data.pop("max_retries", 2)
516
+ if not isinstance(max_retries, int):
517
+ raise AzureOpenAIError(
518
+ status_code=422, message="max retries must be an int"
519
+ )
520
+
521
+ # init AzureOpenAI Client
522
+ azure_client_params = {
523
+ "api_version": api_version,
524
+ "azure_endpoint": api_base,
525
+ "azure_deployment": model,
526
+ "http_client": litellm.client_session,
527
+ "max_retries": max_retries,
528
+ "timeout": timeout,
529
+ }
530
+ if api_key is not None:
531
+ azure_client_params["api_key"] = api_key
532
+ elif azure_ad_token is not None:
533
+ azure_client_params["azure_ad_token"] = azure_ad_token
534
+
535
+ ## LOGGING
536
+ logging_obj.pre_call(
537
+ input=input,
538
+ api_key=api_key,
539
+ additional_args={
540
+ "complete_input_dict": data,
541
+ "headers": {"api_key": api_key, "azure_ad_token": azure_ad_token},
542
+ },
543
+ )
544
+
545
+ if aembedding == True:
546
+ response = self.aembedding(
547
+ data=data,
548
+ input=input,
549
+ logging_obj=logging_obj,
550
+ api_key=api_key,
551
+ model_response=model_response,
552
+ azure_client_params=azure_client_params,
553
+ timeout=timeout,
554
+ )
555
+ return response
556
+ if client is None:
557
+ azure_client = AzureOpenAI(**azure_client_params) # type: ignore
558
+ else:
559
+ azure_client = client
560
+ ## COMPLETION CALL
561
+ response = azure_client.embeddings.create(**data, timeout=timeout) # type: ignore
562
+ ## LOGGING
563
+ logging_obj.post_call(
564
+ input=input,
565
+ api_key=api_key,
566
+ additional_args={"complete_input_dict": data, "api_base": api_base},
567
+ original_response=response,
568
+ )
569
+
570
+ return convert_to_model_response_object(response_object=response.model_dump(), model_response_object=model_response, response_type="embedding") # type: ignore
571
+ except AzureOpenAIError as e:
572
+ exception_mapping_worked = True
573
+ raise e
574
+ except Exception as e:
575
+ if hasattr(e, "status_code"):
576
+ raise AzureOpenAIError(status_code=e.status_code, message=str(e))
577
+ else:
578
+ raise AzureOpenAIError(status_code=500, message=str(e))
579
+
580
+ async def aimage_generation(
581
+ self,
582
+ data: dict,
583
+ model_response: ModelResponse,
584
+ azure_client_params: dict,
585
+ api_key: str,
586
+ input: list,
587
+ client=None,
588
+ logging_obj=None,
589
+ timeout=None,
590
+ ):
591
+ response = None
592
+ try:
593
+ if client is None:
594
+ client_session = litellm.aclient_session or httpx.AsyncClient(
595
+ transport=AsyncCustomHTTPTransport(),
596
+ )
597
+ openai_aclient = AsyncAzureOpenAI(
598
+ http_client=client_session, **azure_client_params
599
+ )
600
+ else:
601
+ openai_aclient = client
602
+ response = await openai_aclient.images.generate(**data, timeout=timeout)
603
+ stringified_response = response.model_dump()
604
+ ## LOGGING
605
+ logging_obj.post_call(
606
+ input=input,
607
+ api_key=api_key,
608
+ additional_args={"complete_input_dict": data},
609
+ original_response=stringified_response,
610
+ )
611
+ return convert_to_model_response_object(
612
+ response_object=stringified_response,
613
+ model_response_object=model_response,
614
+ response_type="image_generation",
615
+ )
616
+ except Exception as e:
617
+ ## LOGGING
618
+ logging_obj.post_call(
619
+ input=input,
620
+ api_key=api_key,
621
+ additional_args={"complete_input_dict": data},
622
+ original_response=str(e),
623
+ )
624
+ raise e
625
+
626
+ def image_generation(
627
+ self,
628
+ prompt: str,
629
+ timeout: float,
630
+ model: Optional[str] = None,
631
+ api_key: Optional[str] = None,
632
+ api_base: Optional[str] = None,
633
+ api_version: Optional[str] = None,
634
+ model_response: Optional[litellm.utils.ImageResponse] = None,
635
+ azure_ad_token: Optional[str] = None,
636
+ logging_obj=None,
637
+ optional_params=None,
638
+ client=None,
639
+ aimg_generation=None,
640
+ ):
641
+ exception_mapping_worked = False
642
+ try:
643
+ if model and len(model) > 0:
644
+ model = model
645
+ else:
646
+ model = None
647
+ data = {"model": model, "prompt": prompt, **optional_params}
648
+ max_retries = data.pop("max_retries", 2)
649
+ if not isinstance(max_retries, int):
650
+ raise AzureOpenAIError(
651
+ status_code=422, message="max retries must be an int"
652
+ )
653
+
654
+ # init AzureOpenAI Client
655
+ azure_client_params = {
656
+ "api_version": api_version,
657
+ "azure_endpoint": api_base,
658
+ "azure_deployment": model,
659
+ "max_retries": max_retries,
660
+ "timeout": timeout,
661
+ }
662
+ if api_key is not None:
663
+ azure_client_params["api_key"] = api_key
664
+ elif azure_ad_token is not None:
665
+ azure_client_params["azure_ad_token"] = azure_ad_token
666
+
667
+ if aimg_generation == True:
668
+ response = self.aimage_generation(data=data, input=input, logging_obj=logging_obj, model_response=model_response, api_key=api_key, client=client, azure_client_params=azure_client_params, timeout=timeout) # type: ignore
669
+ return response
670
+
671
+ if client is None:
672
+ client_session = litellm.client_session or httpx.Client(
673
+ transport=CustomHTTPTransport(),
674
+ )
675
+ azure_client = AzureOpenAI(http_client=client_session, **azure_client_params) # type: ignore
676
+ else:
677
+ azure_client = client
678
+
679
+ ## LOGGING
680
+ logging_obj.pre_call(
681
+ input=prompt,
682
+ api_key=azure_client.api_key,
683
+ additional_args={
684
+ "headers": {"Authorization": f"Bearer {azure_client.api_key}"},
685
+ "api_base": azure_client._base_url._uri_reference,
686
+ "acompletion": False,
687
+ "complete_input_dict": data,
688
+ },
689
+ )
690
+
691
+ ## COMPLETION CALL
692
+ response = azure_client.images.generate(**data, timeout=timeout) # type: ignore
693
+ ## LOGGING
694
+ logging_obj.post_call(
695
+ input=input,
696
+ api_key=api_key,
697
+ additional_args={"complete_input_dict": data},
698
+ original_response=response,
699
+ )
700
+ # return response
701
+ return convert_to_model_response_object(response_object=response.model_dump(), model_response_object=model_response, response_type="image_generation") # type: ignore
702
+ except AzureOpenAIError as e:
703
+ exception_mapping_worked = True
704
+ raise e
705
+ except Exception as e:
706
+ if hasattr(e, "status_code"):
707
+ raise AzureOpenAIError(status_code=e.status_code, message=str(e))
708
+ else:
709
+ raise AzureOpenAIError(status_code=500, message=str(e))
710
+
711
+ async def ahealth_check(
712
+ self,
713
+ model: Optional[str],
714
+ api_key: str,
715
+ api_base: str,
716
+ api_version: str,
717
+ timeout: float,
718
+ mode: str,
719
+ messages: Optional[list] = None,
720
+ input: Optional[list] = None,
721
+ prompt: Optional[str] = None,
722
+ ):
723
+ client_session = litellm.aclient_session or httpx.AsyncClient(
724
+ transport=AsyncCustomHTTPTransport(), # handle dall-e-2 calls
725
+ )
726
+ if "gateway.ai.cloudflare.com" in api_base:
727
+ ## build base url - assume api base includes resource name
728
+ if not api_base.endswith("/"):
729
+ api_base += "/"
730
+ api_base += f"{model}"
731
+ client = AsyncAzureOpenAI(
732
+ base_url=api_base,
733
+ api_version=api_version,
734
+ api_key=api_key,
735
+ timeout=timeout,
736
+ http_client=client_session,
737
+ )
738
+ model = None
739
+ # cloudflare ai gateway, needs model=None
740
+ else:
741
+ client = AsyncAzureOpenAI(
742
+ api_version=api_version,
743
+ azure_endpoint=api_base,
744
+ api_key=api_key,
745
+ timeout=timeout,
746
+ http_client=client_session,
747
+ )
748
+
749
+ # only run this check if it's not cloudflare ai gateway
750
+ if model is None and mode != "image_generation":
751
+ raise Exception("model is not set")
752
+
753
+ completion = None
754
+
755
+ if mode == "completion":
756
+ completion = await client.completions.with_raw_response.create(
757
+ model=model, # type: ignore
758
+ prompt=prompt, # type: ignore
759
+ )
760
+ elif mode == "chat":
761
+ if messages is None:
762
+ raise Exception("messages is not set")
763
+ completion = await client.chat.completions.with_raw_response.create(
764
+ model=model, # type: ignore
765
+ messages=messages, # type: ignore
766
+ )
767
+ elif mode == "embedding":
768
+ if input is None:
769
+ raise Exception("input is not set")
770
+ completion = await client.embeddings.with_raw_response.create(
771
+ model=model, # type: ignore
772
+ input=input, # type: ignore
773
+ )
774
+ elif mode == "image_generation":
775
+ if prompt is None:
776
+ raise Exception("prompt is not set")
777
+ completion = await client.images.with_raw_response.generate(
778
+ model=model, # type: ignore
779
+ prompt=prompt, # type: ignore
780
+ )
781
+ else:
782
+ raise Exception("mode not set")
783
+ response = {}
784
+
785
+ if completion is None or not hasattr(completion, "headers"):
786
+ raise Exception("invalid completion response")
787
+
788
+ if (
789
+ completion.headers.get("x-ratelimit-remaining-requests", None) is not None
790
+ ): # not provided for dall-e requests
791
+ response["x-ratelimit-remaining-requests"] = completion.headers[
792
+ "x-ratelimit-remaining-requests"
793
+ ]
794
+
795
+ if completion.headers.get("x-ratelimit-remaining-tokens", None) is not None:
796
+ response["x-ratelimit-remaining-tokens"] = completion.headers[
797
+ "x-ratelimit-remaining-tokens"
798
+ ]
799
+ return response
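One detail worth noting in the file above is the Cloudflare AI Gateway branch used by both `completion()` and `ahealth_check()`: the deployment name is appended to the gateway URL and the client is then built with `base_url=` instead of `azure_endpoint=`, with the request model left as None. A standalone sketch of just that URL handling (the gateway URL below is a made-up placeholder):

    api_base = "https://gateway.ai.cloudflare.com/v1/ACCOUNT/GATEWAY/azure-openai/my-resource"
    model = "my-gpt-35-turbo-deployment"

    if "gateway.ai.cloudflare.com" in api_base:
        if not api_base.endswith("/"):
            api_base += "/"
        api_base += f"{model}"
        model_for_request = None   # deployment is already encoded in the URL
    else:
        model_for_request = model

    print(api_base)           # .../azure-openai/my-resource/my-gpt-35-turbo-deployment
    print(model_for_request)  # None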
litellm/llms/base.py ADDED
@@ -0,0 +1,45 @@
+ ## This is a template base class to be used for adding new LLM providers via API calls
+ import litellm
+ import httpx
+ from typing import Optional
+
+
+ class BaseLLM:
+     _client_session: Optional[httpx.Client] = None
+
+     def create_client_session(self):
+         if litellm.client_session:
+             _client_session = litellm.client_session
+         else:
+             _client_session = httpx.Client()
+
+         return _client_session
+
+     def create_aclient_session(self):
+         if litellm.aclient_session:
+             _aclient_session = litellm.aclient_session
+         else:
+             _aclient_session = httpx.AsyncClient()
+
+         return _aclient_session
+
+     def __exit__(self):
+         if hasattr(self, "_client_session"):
+             self._client_session.close()
+
+     async def __aexit__(self, exc_type, exc_val, exc_tb):
+         if hasattr(self, "_aclient_session"):
+             await self._aclient_session.aclose()
+
+     def validate_environment(self):  # set up the environment required to run the model
+         pass
+
+     def completion(
+         self, *args, **kwargs
+     ):  # logic for parsing in - calling - parsing out model completion calls
+         pass
+
+     def embedding(
+         self, *args, **kwargs
+     ):  # logic for parsing in - calling - parsing out model embedding calls
+         pass
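New providers in this package are expected to build on this template. A minimal sketch of a subclass, with a hypothetical provider name and a deliberately thin completion method:

    from litellm.llms.base import BaseLLM

    class MyProviderLLM(BaseLLM):
        def validate_environment(self, api_key=None):
            if not api_key:
                raise ValueError("missing API key for my_provider")
            return {"Authorization": f"Bearer {api_key}"}

        def completion(self, *args, **kwargs):
            # reuse the shared httpx session managed by the base class
            if self._client_session is None:
                self._client_session = self.create_client_session()
            # ... build the request, call the provider, map the response ...
            return super().completion(*args, **kwargs)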
litellm/llms/baseten.py ADDED
@@ -0,0 +1,164 @@
+ import os
+ import json
+ from enum import Enum
+ import requests
+ import time
+ from typing import Callable
+ from litellm.utils import ModelResponse, Usage
+
+
+ class BasetenError(Exception):
+     def __init__(self, status_code, message):
+         self.status_code = status_code
+         self.message = message
+         super().__init__(
+             self.message
+         )  # Call the base class constructor with the parameters it needs
+
+
+ def validate_environment(api_key):
+     headers = {
+         "accept": "application/json",
+         "content-type": "application/json",
+     }
+     if api_key:
+         headers["Authorization"] = f"Api-Key {api_key}"
+     return headers
+
+
+ def completion(
+     model: str,
+     messages: list,
+     model_response: ModelResponse,
+     print_verbose: Callable,
+     encoding,
+     api_key,
+     logging_obj,
+     optional_params=None,
+     litellm_params=None,
+     logger_fn=None,
+ ):
+     headers = validate_environment(api_key)
+     completion_url_fragment_1 = "https://app.baseten.co/models/"
+     completion_url_fragment_2 = "/predict"
+     model = model
+     prompt = ""
+     for message in messages:
+         if "role" in message:
+             if message["role"] == "user":
+                 prompt += f"{message['content']}"
+             else:
+                 prompt += f"{message['content']}"
+         else:
+             prompt += f"{message['content']}"
+     data = {
+         "inputs": prompt,
+         "prompt": prompt,
+         "parameters": optional_params,
+         "stream": True
+         if "stream" in optional_params and optional_params["stream"] == True
+         else False,
+     }
+
+     ## LOGGING
+     logging_obj.pre_call(
+         input=prompt,
+         api_key=api_key,
+         additional_args={"complete_input_dict": data},
+     )
+     ## COMPLETION CALL
+     response = requests.post(
+         completion_url_fragment_1 + model + completion_url_fragment_2,
+         headers=headers,
+         data=json.dumps(data),
+         stream=True
+         if "stream" in optional_params and optional_params["stream"] == True
+         else False,
+     )
+     if "text/event-stream" in response.headers["Content-Type"] or (
+         "stream" in optional_params and optional_params["stream"] == True
+     ):
+         return response.iter_lines()
+     else:
+         ## LOGGING
+         logging_obj.post_call(
+             input=prompt,
+             api_key=api_key,
+             original_response=response.text,
+             additional_args={"complete_input_dict": data},
+         )
+         print_verbose(f"raw model_response: {response.text}")
+         ## RESPONSE OBJECT
+         completion_response = response.json()
+         if "error" in completion_response:
+             raise BasetenError(
+                 message=completion_response["error"],
+                 status_code=response.status_code,
+             )
+         else:
+             if "model_output" in completion_response:
+                 if (
+                     isinstance(completion_response["model_output"], dict)
+                     and "data" in completion_response["model_output"]
+                     and isinstance(completion_response["model_output"]["data"], list)
+                 ):
+                     model_response["choices"][0]["message"][
+                         "content"
+                     ] = completion_response["model_output"]["data"][0]
+                 elif isinstance(completion_response["model_output"], str):
+                     model_response["choices"][0]["message"][
+                         "content"
+                     ] = completion_response["model_output"]
+             elif "completion" in completion_response and isinstance(
+                 completion_response["completion"], str
+             ):
+                 model_response["choices"][0]["message"][
+                     "content"
+                 ] = completion_response["completion"]
+             elif isinstance(completion_response, list) and len(completion_response) > 0:
+                 if "generated_text" not in completion_response:
+                     raise BasetenError(
+                         message=f"Unable to parse response. Original response: {response.text}",
+                         status_code=response.status_code,
+                     )
+                 model_response["choices"][0]["message"][
+                     "content"
+                 ] = completion_response[0]["generated_text"]
+                 ## GETTING LOGPROBS
+                 if (
+                     "details" in completion_response[0]
+                     and "tokens" in completion_response[0]["details"]
+                 ):
+                     model_response.choices[0].finish_reason = completion_response[0][
+                         "details"
+                     ]["finish_reason"]
+                     sum_logprob = 0
+                     for token in completion_response[0]["details"]["tokens"]:
+                         sum_logprob += token["logprob"]
+                     model_response["choices"][0]["message"]._logprobs = sum_logprob
+             else:
+                 raise BasetenError(
+                     message=f"Unable to parse response. Original response: {response.text}",
+                     status_code=response.status_code,
+                 )
+
+         ## CALCULATING USAGE - baseten charges on time, not tokens - have some mapping of cost here.
+         prompt_tokens = len(encoding.encode(prompt))
+         completion_tokens = len(
+             encoding.encode(model_response["choices"][0]["message"]["content"])
+         )
+
+         model_response["created"] = int(time.time())
+         model_response["model"] = model
+         usage = Usage(
+             prompt_tokens=prompt_tokens,
+             completion_tokens=completion_tokens,
+             total_tokens=prompt_tokens + completion_tokens,
+         )
+         model_response.usage = usage
+         return model_response
+
+
+ def embedding():
+     # logic for parsing in - calling - parsing out model embedding calls
+     pass
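Unlike the other handlers in this upload, Baseten authenticates with an `Api-Key` scheme rather than `Bearer`, and the model version id is spliced into the predict URL. A small sketch of the request target this handler builds (the id and key below are placeholders):

    model_id = "abc123"   # Baseten model version id (placeholder)
    api_key = "btn_xxx"   # placeholder

    url = "https://app.baseten.co/models/" + model_id + "/predict"
    headers = {
        "accept": "application/json",
        "content-type": "application/json",
        "Authorization": f"Api-Key {api_key}",
    }
    print(url)                        # https://app.baseten.co/models/abc123/predict
    print(headers["Authorization"])   # Api-Key btn_xxx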
litellm/llms/bedrock.py ADDED
@@ -0,0 +1,799 @@
1
+ import json, copy, types
2
+ import os
3
+ from enum import Enum
4
+ import time
5
+ from typing import Callable, Optional, Any, Union
6
+ import litellm
7
+ from litellm.utils import ModelResponse, get_secret, Usage
8
+ from .prompt_templates.factory import prompt_factory, custom_prompt
9
+ import httpx
10
+
11
+
12
+ class BedrockError(Exception):
13
+ def __init__(self, status_code, message):
14
+ self.status_code = status_code
15
+ self.message = message
16
+ self.request = httpx.Request(
17
+ method="POST", url="https://us-west-2.console.aws.amazon.com/bedrock"
18
+ )
19
+ self.response = httpx.Response(status_code=status_code, request=self.request)
20
+ super().__init__(
21
+ self.message
22
+ ) # Call the base class constructor with the parameters it needs
23
+
24
+
25
+ class AmazonTitanConfig:
26
+ """
27
+ Reference: https://us-west-2.console.aws.amazon.com/bedrock/home?region=us-west-2#/providers?model=titan-text-express-v1
28
+
29
+ Supported Params for the Amazon Titan models:
30
+
31
+ - `maxTokenCount` (integer) max tokens,
32
+ - `stopSequences` (string[]) list of stop sequence strings
33
+ - `temperature` (float) temperature for model,
34
+ - `topP` (int) top p for model
35
+ """
36
+
37
+ maxTokenCount: Optional[int] = None
38
+ stopSequences: Optional[list] = None
39
+ temperature: Optional[float] = None
40
+ topP: Optional[int] = None
41
+
42
+ def __init__(
43
+ self,
44
+ maxTokenCount: Optional[int] = None,
45
+ stopSequences: Optional[list] = None,
46
+ temperature: Optional[float] = None,
47
+ topP: Optional[int] = None,
48
+ ) -> None:
49
+ locals_ = locals()
50
+ for key, value in locals_.items():
51
+ if key != "self" and value is not None:
52
+ setattr(self.__class__, key, value)
53
+
54
+ @classmethod
55
+ def get_config(cls):
56
+ return {
57
+ k: v
58
+ for k, v in cls.__dict__.items()
59
+ if not k.startswith("__")
60
+ and not isinstance(
61
+ v,
62
+ (
63
+ types.FunctionType,
64
+ types.BuiltinFunctionType,
65
+ classmethod,
66
+ staticmethod,
67
+ ),
68
+ )
69
+ and v is not None
70
+ }
71
+
72
+
73
+ class AmazonAnthropicConfig:
74
+ """
75
+ Reference: https://us-west-2.console.aws.amazon.com/bedrock/home?region=us-west-2#/providers?model=claude
76
+
77
+ Supported Params for the Amazon / Anthropic models:
78
+
79
+ - `max_tokens_to_sample` (integer) max tokens,
80
+ - `temperature` (float) model temperature,
81
+ - `top_k` (integer) top k,
82
+ - `top_p` (integer) top p,
83
+ - `stop_sequences` (string[]) list of stop sequences - e.g. ["\\n\\nHuman:"],
84
+ - `anthropic_version` (string) version of anthropic for bedrock - e.g. "bedrock-2023-05-31"
85
+ """
86
+
87
+ max_tokens_to_sample: Optional[int] = litellm.max_tokens
88
+ stop_sequences: Optional[list] = None
89
+ temperature: Optional[float] = None
90
+ top_k: Optional[int] = None
91
+ top_p: Optional[int] = None
92
+ anthropic_version: Optional[str] = None
93
+
94
+ def __init__(
95
+ self,
96
+ max_tokens_to_sample: Optional[int] = None,
97
+ stop_sequences: Optional[list] = None,
98
+ temperature: Optional[float] = None,
99
+ top_k: Optional[int] = None,
100
+ top_p: Optional[int] = None,
101
+ anthropic_version: Optional[str] = None,
102
+ ) -> None:
103
+ locals_ = locals()
104
+ for key, value in locals_.items():
105
+ if key != "self" and value is not None:
106
+ setattr(self.__class__, key, value)
107
+
108
+ @classmethod
109
+ def get_config(cls):
110
+ return {
111
+ k: v
112
+ for k, v in cls.__dict__.items()
113
+ if not k.startswith("__")
114
+ and not isinstance(
115
+ v,
116
+ (
117
+ types.FunctionType,
118
+ types.BuiltinFunctionType,
119
+ classmethod,
120
+ staticmethod,
121
+ ),
122
+ )
123
+ and v is not None
124
+ }
125
+
126
+
127
+ class AmazonCohereConfig:
128
+ """
129
+ Reference: https://us-west-2.console.aws.amazon.com/bedrock/home?region=us-west-2#/providers?model=command
130
+
131
+ Supported Params for the Amazon / Cohere models:
132
+
133
+ - `max_tokens` (integer) max tokens,
134
+ - `temperature` (float) model temperature,
135
+ - `return_likelihood` (string) n/a
136
+ """
137
+
138
+ max_tokens: Optional[int] = None
139
+ temperature: Optional[float] = None
140
+ return_likelihood: Optional[str] = None
141
+
142
+ def __init__(
143
+ self,
144
+ max_tokens: Optional[int] = None,
145
+ temperature: Optional[float] = None,
146
+ return_likelihood: Optional[str] = None,
147
+ ) -> None:
148
+ locals_ = locals()
149
+ for key, value in locals_.items():
150
+ if key != "self" and value is not None:
151
+ setattr(self.__class__, key, value)
152
+
153
+ @classmethod
154
+ def get_config(cls):
155
+ return {
156
+ k: v
157
+ for k, v in cls.__dict__.items()
158
+ if not k.startswith("__")
159
+ and not isinstance(
160
+ v,
161
+ (
162
+ types.FunctionType,
163
+ types.BuiltinFunctionType,
164
+ classmethod,
165
+ staticmethod,
166
+ ),
167
+ )
168
+ and v is not None
169
+ }
170
+
171
+
172
+ class AmazonAI21Config:
173
+ """
174
+ Reference: https://us-west-2.console.aws.amazon.com/bedrock/home?region=us-west-2#/providers?model=j2-ultra
175
+
176
+ Supported Params for the Amazon / AI21 models:
177
+
178
+ - `maxTokens` (int32): The maximum number of tokens to generate per result. Optional, default is 16. If no `stopSequences` are given, generation stops after producing `maxTokens`.
179
+
180
+ - `temperature` (float): Modifies the distribution from which tokens are sampled. Optional, default is 0.7. A value of 0 essentially disables sampling and results in greedy decoding.
181
+
182
+ - `topP` (float): Used for sampling tokens from the corresponding top percentile of probability mass. Optional, default is 1. For instance, a value of 0.9 considers only tokens comprising the top 90% probability mass.
183
+
184
+ - `stopSequences` (array of strings): Stops decoding if any of the input strings is generated. Optional.
185
+
186
+ - `frequencyPenalty` (object): Placeholder for frequency penalty object.
187
+
188
+ - `presencePenalty` (object): Placeholder for presence penalty object.
189
+
190
+ - `countPenalty` (object): Placeholder for count penalty object.
191
+ """
192
+
193
+ maxTokens: Optional[int] = None
194
+ temperature: Optional[float] = None
195
+ topP: Optional[float] = None
196
+ stopSequences: Optional[list] = None
197
+ frequencePenalty: Optional[dict] = None
198
+ presencePenalty: Optional[dict] = None
199
+ countPenalty: Optional[dict] = None
200
+
201
+ def __init__(
202
+ self,
203
+ maxTokens: Optional[int] = None,
204
+ temperature: Optional[float] = None,
205
+ topP: Optional[float] = None,
206
+ stopSequences: Optional[list] = None,
207
+ frequencePenalty: Optional[dict] = None,
208
+ presencePenalty: Optional[dict] = None,
209
+ countPenalty: Optional[dict] = None,
210
+ ) -> None:
211
+ locals_ = locals()
212
+ for key, value in locals_.items():
213
+ if key != "self" and value is not None:
214
+ setattr(self.__class__, key, value)
215
+
216
+ @classmethod
217
+ def get_config(cls):
218
+ return {
219
+ k: v
220
+ for k, v in cls.__dict__.items()
221
+ if not k.startswith("__")
222
+ and not isinstance(
223
+ v,
224
+ (
225
+ types.FunctionType,
226
+ types.BuiltinFunctionType,
227
+ classmethod,
228
+ staticmethod,
229
+ ),
230
+ )
231
+ and v is not None
232
+ }
233
+
234
+
235
+ class AnthropicConstants(Enum):
236
+ HUMAN_PROMPT = "\n\nHuman: "
237
+ AI_PROMPT = "\n\nAssistant: "
238
+
239
+
240
+ class AmazonLlamaConfig:
241
+ """
242
+ Reference: https://us-west-2.console.aws.amazon.com/bedrock/home?region=us-west-2#/providers?model=meta.llama2-13b-chat-v1
243
+
244
+ Supported Params for the Amazon / Meta Llama models:
245
+
246
+ - `max_gen_len` (integer) max tokens,
247
+ - `temperature` (float) temperature for model,
248
+ - `top_p` (float) top p for model
249
+ """
250
+
251
+ max_gen_len: Optional[int] = None
252
+ temperature: Optional[float] = None
253
+ topP: Optional[float] = None
254
+
255
+ def __init__(
256
+ self,
257
+ maxTokenCount: Optional[int] = None,
258
+ temperature: Optional[float] = None,
259
+ topP: Optional[int] = None,
260
+ ) -> None:
261
+ locals_ = locals()
262
+ for key, value in locals_.items():
263
+ if key != "self" and value is not None:
264
+ setattr(self.__class__, key, value)
265
+
266
+ @classmethod
267
+ def get_config(cls):
268
+ return {
269
+ k: v
270
+ for k, v in cls.__dict__.items()
271
+ if not k.startswith("__")
272
+ and not isinstance(
273
+ v,
274
+ (
275
+ types.FunctionType,
276
+ types.BuiltinFunctionType,
277
+ classmethod,
278
+ staticmethod,
279
+ ),
280
+ )
281
+ and v is not None
282
+ }
283
+
284
+
285
+ def init_bedrock_client(
286
+ region_name=None,
287
+ aws_access_key_id: Optional[str] = None,
288
+ aws_secret_access_key: Optional[str] = None,
289
+ aws_region_name: Optional[str] = None,
290
+ aws_bedrock_runtime_endpoint: Optional[str] = None,
291
+ ):
292
+ # check for custom AWS_REGION_NAME and use it if not passed to init_bedrock_client
293
+ litellm_aws_region_name = get_secret("AWS_REGION_NAME", None)
294
+ standard_aws_region_name = get_secret("AWS_REGION", None)
295
+
296
+ ## CHECK IS 'os.environ/' passed in
297
+ # Define the list of parameters to check
298
+ params_to_check = [
299
+ aws_access_key_id,
300
+ aws_secret_access_key,
301
+ aws_region_name,
302
+ aws_bedrock_runtime_endpoint,
303
+ ]
304
+
305
+ # Iterate over parameters and update if needed
306
+ for i, param in enumerate(params_to_check):
307
+ if param and param.startswith("os.environ/"):
308
+ params_to_check[i] = get_secret(param)
309
+ # Assign updated values back to parameters
310
+ (
311
+ aws_access_key_id,
312
+ aws_secret_access_key,
313
+ aws_region_name,
314
+ aws_bedrock_runtime_endpoint,
315
+ ) = params_to_check
316
+ if region_name:
317
+ pass
318
+ elif aws_region_name:
319
+ region_name = aws_region_name
320
+ elif litellm_aws_region_name:
321
+ region_name = litellm_aws_region_name
322
+ elif standard_aws_region_name:
323
+ region_name = standard_aws_region_name
324
+ else:
325
+ raise BedrockError(
326
+ message="AWS region not set: set AWS_REGION_NAME or AWS_REGION env variable or in .env file",
327
+ status_code=401,
328
+ )
329
+
330
+ # check for custom AWS_BEDROCK_RUNTIME_ENDPOINT and use it if not passed to init_bedrock_client
331
+ env_aws_bedrock_runtime_endpoint = get_secret("AWS_BEDROCK_RUNTIME_ENDPOINT")
332
+ if aws_bedrock_runtime_endpoint:
333
+ endpoint_url = aws_bedrock_runtime_endpoint
334
+ elif env_aws_bedrock_runtime_endpoint:
335
+ endpoint_url = env_aws_bedrock_runtime_endpoint
336
+ else:
337
+ endpoint_url = f"https://bedrock-runtime.{region_name}.amazonaws.com"
338
+
339
+ import boto3
340
+
341
+ if aws_access_key_id != None:
342
+ # uses auth params passed to completion
343
+ # aws_access_key_id is not None, assume user is trying to auth using litellm.completion
344
+
345
+ client = boto3.client(
346
+ service_name="bedrock-runtime",
347
+ aws_access_key_id=aws_access_key_id,
348
+ aws_secret_access_key=aws_secret_access_key,
349
+ region_name=region_name,
350
+ endpoint_url=endpoint_url,
351
+ )
352
+ else:
353
+ # aws_access_key_id is None, assume user is trying to auth using env variables
354
+ # boto3 automatically reads env variables
355
+
356
+ client = boto3.client(
357
+ service_name="bedrock-runtime",
358
+ region_name=region_name,
359
+ endpoint_url=endpoint_url,
360
+ )
361
+
362
+ return client
363
+
364
+
365
+ def convert_messages_to_prompt(model, messages, provider, custom_prompt_dict):
366
+ # handle anthropic prompts using anthropic constants
367
+ if provider == "anthropic":
368
+ if model in custom_prompt_dict:
369
+ # check if the model has a registered custom prompt
370
+ model_prompt_details = custom_prompt_dict[model]
371
+ prompt = custom_prompt(
372
+ role_dict=model_prompt_details["roles"],
373
+ initial_prompt_value=model_prompt_details["initial_prompt_value"],
374
+ final_prompt_value=model_prompt_details["final_prompt_value"],
375
+ messages=messages,
376
+ )
377
+ else:
378
+ prompt = prompt_factory(
379
+ model=model, messages=messages, custom_llm_provider="anthropic"
380
+ )
381
+ else:
382
+ prompt = ""
383
+ for message in messages:
384
+ if "role" in message:
385
+ if message["role"] == "user":
386
+ prompt += f"{message['content']}"
387
+ else:
388
+ prompt += f"{message['content']}"
389
+ else:
390
+ prompt += f"{message['content']}"
391
+ return prompt
392
+
393
+
394
+ """
395
+ BEDROCK AUTH Keys/Vars
396
+ os.environ['AWS_ACCESS_KEY_ID'] = ""
397
+ os.environ['AWS_SECRET_ACCESS_KEY'] = ""
398
+ """
399
+
400
+
401
+ # set os.environ['AWS_REGION_NAME'] = <your-region_name>
402
+
403
+
404
+ def completion(
405
+ model: str,
406
+ messages: list,
407
+ custom_prompt_dict: dict,
408
+ model_response: ModelResponse,
409
+ print_verbose: Callable,
410
+ encoding,
411
+ logging_obj,
412
+ optional_params=None,
413
+ litellm_params=None,
414
+ logger_fn=None,
415
+ ):
416
+ exception_mapping_worked = False
417
+ try:
418
+ # pop aws_secret_access_key, aws_access_key_id, aws_region_name from kwargs, since completion calls fail with them
419
+ aws_secret_access_key = optional_params.pop("aws_secret_access_key", None)
420
+ aws_access_key_id = optional_params.pop("aws_access_key_id", None)
421
+ aws_region_name = optional_params.pop("aws_region_name", None)
422
+ aws_bedrock_runtime_endpoint = optional_params.pop(
423
+ "aws_bedrock_runtime_endpoint", None
424
+ )
425
+
426
+ # use passed in BedrockRuntime.Client if provided, otherwise create a new one
427
+ client = optional_params.pop("aws_bedrock_client", None)
428
+
429
+ # only init client, if user did not pass one
430
+ if client is None:
431
+ client = init_bedrock_client(
432
+ aws_access_key_id=aws_access_key_id,
433
+ aws_secret_access_key=aws_secret_access_key,
434
+ aws_region_name=aws_region_name,
435
+ aws_bedrock_runtime_endpoint=aws_bedrock_runtime_endpoint,
436
+ )
437
+
438
+ model = model
439
+ modelId = (
440
+ optional_params.pop("model_id", None) or model
441
+ ) # default to model if not passed
442
+ provider = model.split(".")[0]
443
+ prompt = convert_messages_to_prompt(
444
+ model, messages, provider, custom_prompt_dict
445
+ )
446
+ inference_params = copy.deepcopy(optional_params)
447
+ stream = inference_params.pop("stream", False)
448
+ if provider == "anthropic":
449
+ ## LOAD CONFIG
450
+ config = litellm.AmazonAnthropicConfig.get_config()
451
+ for k, v in config.items():
452
+ if (
453
+ k not in inference_params
454
+ ): # completion(top_k=3) > anthropic_config(top_k=3) <- allows for dynamic variables to be passed in
455
+ inference_params[k] = v
456
+ data = json.dumps({"prompt": prompt, **inference_params})
457
+ elif provider == "ai21":
458
+ ## LOAD CONFIG
459
+ config = litellm.AmazonAI21Config.get_config()
460
+ for k, v in config.items():
461
+ if (
462
+ k not in inference_params
463
+ ): # completion(top_k=3) > anthropic_config(top_k=3) <- allows for dynamic variables to be passed in
464
+ inference_params[k] = v
465
+
466
+ data = json.dumps({"prompt": prompt, **inference_params})
467
+ elif provider == "cohere":
468
+ ## LOAD CONFIG
469
+ config = litellm.AmazonCohereConfig.get_config()
470
+ for k, v in config.items():
471
+ if (
472
+ k not in inference_params
473
+ ): # completion(top_k=3) > anthropic_config(top_k=3) <- allows for dynamic variables to be passed in
474
+ inference_params[k] = v
475
+ if optional_params.get("stream", False) == True:
476
+ inference_params[
477
+ "stream"
478
+ ] = True # cohere requires stream = True in inference params
479
+ data = json.dumps({"prompt": prompt, **inference_params})
480
+ elif provider == "meta":
481
+ ## LOAD CONFIG
482
+ config = litellm.AmazonLlamaConfig.get_config()
483
+ for k, v in config.items():
484
+ if (
485
+ k not in inference_params
486
+ ): # completion(top_k=3) > anthropic_config(top_k=3) <- allows for dynamic variables to be passed in
487
+ inference_params[k] = v
488
+ data = json.dumps({"prompt": prompt, **inference_params})
489
+ elif provider == "amazon": # amazon titan
490
+ ## LOAD CONFIG
491
+ config = litellm.AmazonTitanConfig.get_config()
492
+ for k, v in config.items():
493
+ if (
494
+ k not in inference_params
495
+ ): # completion(top_k=3) > amazon_config(top_k=3) <- allows for dynamic variables to be passed in
496
+ inference_params[k] = v
497
+
498
+ data = json.dumps(
499
+ {
500
+ "inputText": prompt,
501
+ "textGenerationConfig": inference_params,
502
+ }
503
+ )
504
+ else:
505
+ data = json.dumps({})
506
+
507
+ ## COMPLETION CALL
508
+ accept = "application/json"
509
+ contentType = "application/json"
510
+ if stream == True:
511
+ if provider == "ai21":
512
+ ## LOGGING
513
+ request_str = f"""
514
+ response = client.invoke_model(
515
+ body={data},
516
+ modelId={modelId},
517
+ accept=accept,
518
+ contentType=contentType
519
+ )
520
+ """
521
+ logging_obj.pre_call(
522
+ input=prompt,
523
+ api_key="",
524
+ additional_args={
525
+ "complete_input_dict": data,
526
+ "request_str": request_str,
527
+ },
528
+ )
529
+
530
+ response = client.invoke_model(
531
+ body=data, modelId=modelId, accept=accept, contentType=contentType
532
+ )
533
+
534
+ response = response.get("body").read()
535
+ return response
536
+ else:
537
+ ## LOGGING
538
+ request_str = f"""
539
+ response = client.invoke_model_with_response_stream(
540
+ body={data},
541
+ modelId={modelId},
542
+ accept=accept,
543
+ contentType=contentType
544
+ )
545
+ """
546
+ logging_obj.pre_call(
547
+ input=prompt,
548
+ api_key="",
549
+ additional_args={
550
+ "complete_input_dict": data,
551
+ "request_str": request_str,
552
+ },
553
+ )
554
+
555
+ response = client.invoke_model_with_response_stream(
556
+ body=data, modelId=modelId, accept=accept, contentType=contentType
557
+ )
558
+ response = response.get("body")
559
+ return response
560
+ try:
561
+ ## LOGGING
562
+ request_str = f"""
563
+ response = client.invoke_model(
564
+ body={data},
565
+ modelId={modelId},
566
+ accept=accept,
567
+ contentType=contentType
568
+ )
569
+ """
570
+ logging_obj.pre_call(
571
+ input=prompt,
572
+ api_key="",
573
+ additional_args={
574
+ "complete_input_dict": data,
575
+ "request_str": request_str,
576
+ },
577
+ )
578
+ response = client.invoke_model(
579
+ body=data, modelId=modelId, accept=accept, contentType=contentType
580
+ )
581
+ except client.exceptions.ValidationException as e:
582
+ if "The provided model identifier is invalid" in str(e):
583
+ raise BedrockError(status_code=404, message=str(e))
584
+ raise BedrockError(status_code=400, message=str(e))
585
+ except Exception as e:
586
+ raise BedrockError(status_code=500, message=str(e))
587
+
588
+ response_body = json.loads(response.get("body").read())
589
+
590
+ ## LOGGING
591
+ logging_obj.post_call(
592
+ input=prompt,
593
+ api_key="",
594
+ original_response=json.dumps(response_body),
595
+ additional_args={"complete_input_dict": data},
596
+ )
597
+ print_verbose(f"raw model_response: {response}")
598
+ ## RESPONSE OBJECT
599
+ outputText = "default"
600
+ if provider == "ai21":
601
+ outputText = response_body.get("completions")[0].get("data").get("text")
602
+ elif provider == "anthropic":
603
+ outputText = response_body["completion"]
604
+ model_response["finish_reason"] = response_body["stop_reason"]
605
+ elif provider == "cohere":
606
+ outputText = response_body["generations"][0]["text"]
607
+ elif provider == "meta":
608
+ outputText = response_body["generation"]
609
+ else: # amazon titan
610
+ outputText = response_body.get("results")[0].get("outputText")
611
+
612
+ response_metadata = response.get("ResponseMetadata", {})
613
+ if response_metadata.get("HTTPStatusCode", 500) >= 400:
614
+ raise BedrockError(
615
+ message=outputText,
616
+ status_code=response_metadata.get("HTTPStatusCode", 500),
617
+ )
618
+ else:
619
+ try:
620
+ if len(outputText) > 0:
621
+ model_response["choices"][0]["message"]["content"] = outputText
622
+ except:
623
+ raise BedrockError(
624
+ message=json.dumps(outputText),
625
+ status_code=response_metadata.get("HTTPStatusCode", 500),
626
+ )
627
+
628
+ ## CALCULATING USAGE - estimate tokens with the encoding, since usage counts aren't parsed from the Bedrock response above
629
+ prompt_tokens = len(encoding.encode(prompt))
630
+ completion_tokens = len(
631
+ encoding.encode(model_response["choices"][0]["message"].get("content", ""))
632
+ )
633
+
634
+ model_response["created"] = int(time.time())
635
+ model_response["model"] = model
636
+ usage = Usage(
637
+ prompt_tokens=prompt_tokens,
638
+ completion_tokens=completion_tokens,
639
+ total_tokens=prompt_tokens + completion_tokens,
640
+ )
641
+ model_response.usage = usage
642
+ return model_response
643
+ except BedrockError as e:
644
+ exception_mapping_worked = True
645
+ raise e
646
+ except Exception as e:
647
+ if exception_mapping_worked:
648
+ raise e
649
+ else:
650
+ import traceback
651
+
652
+ raise BedrockError(status_code=500, message=traceback.format_exc())
653
+
654
+
655
+ def _embedding_func_single(
656
+ model: str,
657
+ input: str,
658
+ client: Any,
659
+ optional_params=None,
660
+ encoding=None,
661
+ logging_obj=None,
662
+ ):
663
+ # logic for parsing in - calling - parsing out model embedding calls
664
+ ## FORMAT EMBEDDING INPUT ##
665
+ provider = model.split(".")[0]
666
+ inference_params = copy.deepcopy(optional_params)
667
+ inference_params.pop(
668
+ "user", None
669
+ ) # make sure user is not passed in for bedrock call
670
+ modelId = (
671
+ optional_params.pop("model_id", None) or model
672
+ ) # default to model if not passed
673
+ if provider == "amazon":
674
+ input = input.replace(os.linesep, " ")
675
+ data = {"inputText": input, **inference_params}
676
+ # data = json.dumps(data)
677
+ elif provider == "cohere":
678
+ inference_params["input_type"] = inference_params.get(
679
+ "input_type", "search_document"
680
+ ) # aws bedrock example default - https://us-east-1.console.aws.amazon.com/bedrock/home?region=us-east-1#/providers?model=cohere.embed-english-v3
681
+ data = {"texts": [input], **inference_params} # type: ignore
682
+ body = json.dumps(data).encode("utf-8")
683
+ ## LOGGING
684
+ request_str = f"""
685
+ response = client.invoke_model(
686
+ body={body},
687
+ modelId={modelId},
688
+ accept="*/*",
689
+ contentType="application/json",
690
+ )""" # type: ignore
691
+ logging_obj.pre_call(
692
+ input=input,
693
+ api_key="", # boto3 is used for init.
694
+ additional_args={
695
+ "complete_input_dict": {"model": modelId, "texts": input},
696
+ "request_str": request_str,
697
+ },
698
+ )
699
+ try:
700
+ response = client.invoke_model(
701
+ body=body,
702
+ modelId=modelId,
703
+ accept="*/*",
704
+ contentType="application/json",
705
+ )
706
+ response_body = json.loads(response.get("body").read())
707
+ ## LOGGING
708
+ logging_obj.post_call(
709
+ input=input,
710
+ api_key="",
711
+ additional_args={"complete_input_dict": data},
712
+ original_response=json.dumps(response_body),
713
+ )
714
+ if provider == "cohere":
715
+ response = response_body.get("embeddings")
716
+ # flatten list
717
+ response = [item for sublist in response for item in sublist]
718
+ return response
719
+ elif provider == "amazon":
720
+ return response_body.get("embedding")
721
+ except Exception as e:
722
+ raise BedrockError(
723
+ message=f"Embedding Error with model {model}: {e}", status_code=500
724
+ )
725
+
726
+
727
+ def embedding(
728
+ model: str,
729
+ input: Union[list, str],
730
+ api_key: Optional[str] = None,
731
+ logging_obj=None,
732
+ model_response=None,
733
+ optional_params=None,
734
+ encoding=None,
735
+ ):
736
+ ### BOTO3 INIT ###
737
+ # pop aws_secret_access_key, aws_access_key_id, aws_region_name from kwargs, since completion calls fail with them
738
+ aws_secret_access_key = optional_params.pop("aws_secret_access_key", None)
739
+ aws_access_key_id = optional_params.pop("aws_access_key_id", None)
740
+ aws_region_name = optional_params.pop("aws_region_name", None)
741
+ aws_bedrock_runtime_endpoint = optional_params.pop(
742
+ "aws_bedrock_runtime_endpoint", None
743
+ )
744
+
745
+ # use passed in BedrockRuntime.Client if provided, otherwise create a new one
746
+ client = init_bedrock_client(
747
+ aws_access_key_id=aws_access_key_id,
748
+ aws_secret_access_key=aws_secret_access_key,
749
+ aws_region_name=aws_region_name,
750
+ aws_bedrock_runtime_endpoint=aws_bedrock_runtime_endpoint,
751
+ )
752
+ if type(input) == str:
753
+ embeddings = [
754
+ _embedding_func_single(
755
+ model,
756
+ input,
757
+ optional_params=optional_params,
758
+ client=client,
759
+ logging_obj=logging_obj,
760
+ )
761
+ ]
762
+ else:
763
+ ## Embedding Call
764
+ embeddings = [
765
+ _embedding_func_single(
766
+ model,
767
+ i,
768
+ optional_params=optional_params,
769
+ client=client,
770
+ logging_obj=logging_obj,
771
+ )
772
+ for i in input
773
+ ] # [TODO]: make these parallel calls
774
+
775
+ ## Populate OpenAI compliant dictionary
776
+ embedding_response = []
777
+ for idx, embedding in enumerate(embeddings):
778
+ embedding_response.append(
779
+ {
780
+ "object": "embedding",
781
+ "index": idx,
782
+ "embedding": embedding,
783
+ }
784
+ )
785
+ model_response["object"] = "list"
786
+ model_response["data"] = embedding_response
787
+ model_response["model"] = model
788
+ input_tokens = 0
789
+
790
+ input_str = "".join(input)
791
+
792
+ input_tokens += len(encoding.encode(input_str))
793
+
794
+ usage = Usage(
795
+ prompt_tokens=input_tokens, completion_tokens=0, total_tokens=input_tokens + 0
796
+ )
797
+ model_response.usage = usage
798
+
799
+ return model_response
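For context, a minimal usage sketch for the Bedrock handler above (assumptions: litellm and boto3 are installed, AWS credentials plus AWS_REGION_NAME are already exported, and the Claude v2 model id is enabled in the target account):

import litellm

# Credentials are read from the environment, so init_bedrock_client() above can
# build the boto3 "bedrock-runtime" client without explicit keys.
response = litellm.completion(
    model="bedrock/anthropic.claude-v2",  # inside this module the provider is parsed from the id before the first "."
    messages=[{"role": "user", "content": "Say hello from Bedrock"}],
    max_tokens=100,
)
print(response.choices[0].message.content)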
litellm/llms/cloudflare.py ADDED
@@ -0,0 +1,176 @@
1
+ import os, types
2
+ import json
3
+ from enum import Enum
4
+ import requests
5
+ import time
6
+ from typing import Callable, Optional
7
+ import litellm
8
+ import httpx
9
+ from litellm.utils import ModelResponse, Usage
10
+ from .prompt_templates.factory import prompt_factory, custom_prompt
11
+
12
+
13
+ class CloudflareError(Exception):
14
+ def __init__(self, status_code, message):
15
+ self.status_code = status_code
16
+ self.message = message
17
+ self.request = httpx.Request(method="POST", url="https://api.cloudflare.com")
18
+ self.response = httpx.Response(status_code=status_code, request=self.request)
19
+ super().__init__(
20
+ self.message
21
+ ) # Call the base class constructor with the parameters it needs
22
+
23
+
24
+ class CloudflareConfig:
25
+ max_tokens: Optional[int] = None
26
+ stream: Optional[bool] = None
27
+
28
+ def __init__(
29
+ self,
30
+ max_tokens: Optional[int] = None,
31
+ stream: Optional[bool] = None,
32
+ ) -> None:
33
+ locals_ = locals()
34
+ for key, value in locals_.items():
35
+ if key != "self" and value is not None:
36
+ setattr(self.__class__, key, value)
37
+
38
+ @classmethod
39
+ def get_config(cls):
40
+ return {
41
+ k: v
42
+ for k, v in cls.__dict__.items()
43
+ if not k.startswith("__")
44
+ and not isinstance(
45
+ v,
46
+ (
47
+ types.FunctionType,
48
+ types.BuiltinFunctionType,
49
+ classmethod,
50
+ staticmethod,
51
+ ),
52
+ )
53
+ and v is not None
54
+ }
55
+
56
+
57
+ def validate_environment(api_key):
58
+ if api_key is None:
59
+ raise ValueError(
60
+ "Missing CloudflareError API Key - A call is being made to cloudflare but no key is set either in the environment variables or via params"
61
+ )
62
+ headers = {
63
+ "accept": "application/json",
64
+ "content-type": "application/json",
65
+ "Authorization": "Bearer " + api_key,
66
+ }
67
+ return headers
68
+
69
+
70
+ def completion(
71
+ model: str,
72
+ messages: list,
73
+ api_base: str,
74
+ model_response: ModelResponse,
75
+ print_verbose: Callable,
76
+ encoding,
77
+ api_key,
78
+ logging_obj,
79
+ custom_prompt_dict={},
80
+ optional_params=None,
81
+ litellm_params=None,
82
+ logger_fn=None,
83
+ ):
84
+ headers = validate_environment(api_key)
85
+
86
+ ## Load Config
87
+ config = litellm.CloudflareConfig.get_config()
88
+ for k, v in config.items():
89
+ if k not in optional_params:
90
+ optional_params[k] = v
91
+
92
+ print_verbose(f"CUSTOM PROMPT DICT: {custom_prompt_dict}; model: {model}")
93
+ if model in custom_prompt_dict:
94
+ # check if the model has a registered custom prompt
95
+ model_prompt_details = custom_prompt_dict[model]
96
+ prompt = custom_prompt(
97
+ role_dict=model_prompt_details.get("roles", {}),
98
+ initial_prompt_value=model_prompt_details.get("initial_prompt_value", ""),
99
+ final_prompt_value=model_prompt_details.get("final_prompt_value", ""),
100
+ bos_token=model_prompt_details.get("bos_token", ""),
101
+ eos_token=model_prompt_details.get("eos_token", ""),
102
+ messages=messages,
103
+ )
104
+
105
+ # cloudflare adds the model to the api base
106
+ api_base = api_base + model
107
+
108
+ data = {
109
+ "messages": messages,
110
+ **optional_params,
111
+ }
112
+
113
+ ## LOGGING
114
+ logging_obj.pre_call(
115
+ input=messages,
116
+ api_key=api_key,
117
+ additional_args={
118
+ "headers": headers,
119
+ "api_base": api_base,
120
+ "complete_input_dict": data,
121
+ },
122
+ )
123
+
124
+ ## COMPLETION CALL
125
+ if "stream" in optional_params and optional_params["stream"] == True:
126
+ response = requests.post(
127
+ api_base,
128
+ headers=headers,
129
+ data=json.dumps(data),
130
+ stream=optional_params["stream"],
131
+ )
132
+ return response.iter_lines()
133
+ else:
134
+ response = requests.post(api_base, headers=headers, data=json.dumps(data))
135
+ ## LOGGING
136
+ logging_obj.post_call(
137
+ input=messages,
138
+ api_key=api_key,
139
+ original_response=response.text,
140
+ additional_args={"complete_input_dict": data},
141
+ )
142
+ print_verbose(f"raw model_response: {response.text}")
143
+ ## RESPONSE OBJECT
144
+ if response.status_code != 200:
145
+ raise CloudflareError(
146
+ status_code=response.status_code, message=response.text
147
+ )
148
+ completion_response = response.json()
149
+
150
+ model_response["choices"][0]["message"]["content"] = completion_response[
151
+ "result"
152
+ ]["response"]
153
+
154
+ ## CALCULATING USAGE
155
+ print_verbose(
156
+ f"CALCULATING CLOUDFLARE TOKEN USAGE. Model Response: {model_response}; model_response['choices'][0]['message'].get('content', ''): {model_response['choices'][0]['message'].get('content', None)}"
157
+ )
158
+ prompt_tokens = litellm.utils.get_token_count(messages=messages, model=model)
159
+ completion_tokens = len(
160
+ encoding.encode(model_response["choices"][0]["message"].get("content", ""))
161
+ )
162
+
163
+ model_response["created"] = int(time.time())
164
+ model_response["model"] = "cloudflare/" + model
165
+ usage = Usage(
166
+ prompt_tokens=prompt_tokens,
167
+ completion_tokens=completion_tokens,
168
+ total_tokens=prompt_tokens + completion_tokens,
169
+ )
170
+ model_response.usage = usage
171
+ return model_response
172
+
173
+
174
+ def embedding():
175
+ # logic for parsing in - calling - parsing out model embedding calls
176
+ pass
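For context, a minimal usage sketch for the Workers AI handler above; the CLOUDFLARE_API_KEY / CLOUDFLARE_ACCOUNT_ID variable names and the model id follow litellm's provider conventions and are assumptions here, not defined in this file.

import os
import litellm

os.environ["CLOUDFLARE_API_KEY"] = "<your-api-token>"      # placeholder
os.environ["CLOUDFLARE_ACCOUNT_ID"] = "<your-account-id>"  # placeholder

# litellm builds the account-scoped api_base and then calls completion() above,
# which appends the model name to it.
response = litellm.completion(
    model="cloudflare/@cf/meta/llama-2-7b-chat-int8",
    messages=[{"role": "user", "content": "What is Workers AI?"}],
)
print(response.choices[0].message.content)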
litellm/llms/cohere.py ADDED
@@ -0,0 +1,293 @@
1
+ import os, types
2
+ import json
3
+ from enum import Enum
4
+ import requests
5
+ import time, traceback
6
+ from typing import Callable, Optional
7
+ from litellm.utils import ModelResponse, Choices, Message, Usage
8
+ import litellm
9
+ import httpx
10
+
11
+
12
+ class CohereError(Exception):
13
+ def __init__(self, status_code, message):
14
+ self.status_code = status_code
15
+ self.message = message
16
+ self.request = httpx.Request(
17
+ method="POST", url="https://api.cohere.ai/v1/generate"
18
+ )
19
+ self.response = httpx.Response(status_code=status_code, request=self.request)
20
+ super().__init__(
21
+ self.message
22
+ ) # Call the base class constructor with the parameters it needs
23
+
24
+
25
+ class CohereConfig:
26
+ """
27
+ Reference: https://docs.cohere.com/reference/generate
28
+
29
+ The class `CohereConfig` provides configuration for Cohere's API interface. Below are the parameters:
30
+
31
+ - `num_generations` (integer): Maximum number of generations returned. Default is 1, with a minimum value of 1 and a maximum value of 5.
32
+
33
+ - `max_tokens` (integer): Maximum number of tokens the model will generate as part of the response. Default value is 20.
34
+
35
+ - `truncate` (string): Specifies how the API handles inputs longer than maximum token length. Options include NONE, START, END. Default is END.
36
+
37
+ - `temperature` (number): A non-negative float controlling the randomness in generation. Lower temperatures result in less random generations. Default is 0.75.
38
+
39
+ - `preset` (string): Identifier of a custom preset, a combination of parameters such as prompt, temperature etc.
40
+
41
+ - `end_sequences` (array of strings): The generated text gets cut at the beginning of the earliest occurrence of an end sequence, which will be excluded from the text.
42
+
43
+ - `stop_sequences` (array of strings): The generated text gets cut at the end of the earliest occurrence of a stop sequence, which will be included in the text.
44
+
45
+ - `k` (integer): Limits generation at each step to top `k` most likely tokens. Default is 0.
46
+
47
+ - `p` (number): Limits generation at each step to most likely tokens with total probability mass of `p`. Default is 0.
48
+
49
+ - `frequency_penalty` (number): Reduces repetitiveness of generated tokens. Higher values apply stronger penalties to previously occurred tokens.
50
+
51
+ - `presence_penalty` (number): Reduces repetitiveness of generated tokens. Similar to frequency_penalty, but this penalty applies equally to all tokens that have already appeared.
52
+
53
+ - `return_likelihoods` (string): Specifies how and if token likelihoods are returned with the response. Options include GENERATION, ALL and NONE.
54
+
55
+ - `logit_bias` (object): Used to prevent the model from generating unwanted tokens or to incentivize it to include desired tokens. e.g. {"hello_world": 1233}
56
+ """
57
+
58
+ num_generations: Optional[int] = None
59
+ max_tokens: Optional[int] = None
60
+ truncate: Optional[str] = None
61
+ temperature: Optional[float] = None
62
+ preset: Optional[str] = None
63
+ end_sequences: Optional[list] = None
64
+ stop_sequences: Optional[list] = None
65
+ k: Optional[int] = None
66
+ p: Optional[float] = None
67
+ frequency_penalty: Optional[float] = None
68
+ presence_penalty: Optional[float] = None
69
+ return_likelihoods: Optional[str] = None
70
+ logit_bias: Optional[dict] = None
71
+
72
+ def __init__(
73
+ self,
74
+ num_generations: Optional[int] = None,
75
+ max_tokens: Optional[int] = None,
76
+ truncate: Optional[str] = None,
77
+ temperature: Optional[float] = None,
78
+ preset: Optional[str] = None,
79
+ end_sequences: Optional[list] = None,
80
+ stop_sequences: Optional[list] = None,
81
+ k: Optional[int] = None,
82
+ p: Optional[float] = None,
83
+ frequency_penalty: Optional[float] = None,
84
+ presence_penalty: Optional[float] = None,
85
+ return_likelihoods: Optional[str] = None,
86
+ logit_bias: Optional[dict] = None,
87
+ ) -> None:
88
+ locals_ = locals()
89
+ for key, value in locals_.items():
90
+ if key != "self" and value is not None:
91
+ setattr(self.__class__, key, value)
92
+
93
+ @classmethod
94
+ def get_config(cls):
95
+ return {
96
+ k: v
97
+ for k, v in cls.__dict__.items()
98
+ if not k.startswith("__")
99
+ and not isinstance(
100
+ v,
101
+ (
102
+ types.FunctionType,
103
+ types.BuiltinFunctionType,
104
+ classmethod,
105
+ staticmethod,
106
+ ),
107
+ )
108
+ and v is not None
109
+ }
110
+
111
+
112
+ def validate_environment(api_key):
113
+ headers = {
114
+ "accept": "application/json",
115
+ "content-type": "application/json",
116
+ }
117
+ if api_key:
118
+ headers["Authorization"] = f"Bearer {api_key}"
119
+ return headers
120
+
121
+
122
+ def completion(
123
+ model: str,
124
+ messages: list,
125
+ api_base: str,
126
+ model_response: ModelResponse,
127
+ print_verbose: Callable,
128
+ encoding,
129
+ api_key,
130
+ logging_obj,
131
+ optional_params=None,
132
+ litellm_params=None,
133
+ logger_fn=None,
134
+ ):
135
+ headers = validate_environment(api_key)
136
+ completion_url = api_base
137
+ model = model
138
+ prompt = " ".join(message["content"] for message in messages)
139
+
140
+ ## Load Config
141
+ config = litellm.CohereConfig.get_config()
142
+ for k, v in config.items():
143
+ if (
144
+ k not in optional_params
145
+ ): # completion(top_k=3) > cohere_config(top_k=3) <- allows for dynamic variables to be passed in
146
+ optional_params[k] = v
147
+
148
+ data = {
149
+ "model": model,
150
+ "prompt": prompt,
151
+ **optional_params,
152
+ }
153
+
154
+ ## LOGGING
155
+ logging_obj.pre_call(
156
+ input=prompt,
157
+ api_key=api_key,
158
+ additional_args={
159
+ "complete_input_dict": data,
160
+ "headers": headers,
161
+ "api_base": completion_url,
162
+ },
163
+ )
164
+ ## COMPLETION CALL
165
+ response = requests.post(
166
+ completion_url,
167
+ headers=headers,
168
+ data=json.dumps(data),
169
+ stream=optional_params["stream"] if "stream" in optional_params else False,
170
+ )
171
+ ## error handling for cohere calls
172
+ if response.status_code != 200:
173
+ raise CohereError(message=response.text, status_code=response.status_code)
174
+
175
+ if "stream" in optional_params and optional_params["stream"] == True:
176
+ return response.iter_lines()
177
+ else:
178
+ ## LOGGING
179
+ logging_obj.post_call(
180
+ input=prompt,
181
+ api_key=api_key,
182
+ original_response=response.text,
183
+ additional_args={"complete_input_dict": data},
184
+ )
185
+ print_verbose(f"raw model_response: {response.text}")
186
+ ## RESPONSE OBJECT
187
+ completion_response = response.json()
188
+ if "error" in completion_response:
189
+ raise CohereError(
190
+ message=completion_response["error"],
191
+ status_code=response.status_code,
192
+ )
193
+ else:
194
+ try:
195
+ choices_list = []
196
+ for idx, item in enumerate(completion_response["generations"]):
197
+ if len(item["text"]) > 0:
198
+ message_obj = Message(content=item["text"])
199
+ else:
200
+ message_obj = Message(content=None)
201
+ choice_obj = Choices(
202
+ finish_reason=item["finish_reason"],
203
+ index=idx + 1,
204
+ message=message_obj,
205
+ )
206
+ choices_list.append(choice_obj)
207
+ model_response["choices"] = choices_list
208
+ except Exception as e:
209
+ raise CohereError(
210
+ message=response.text, status_code=response.status_code
211
+ )
212
+
213
+ ## CALCULATING USAGE
214
+ prompt_tokens = len(encoding.encode(prompt))
215
+ completion_tokens = len(
216
+ encoding.encode(model_response["choices"][0]["message"].get("content", ""))
217
+ )
218
+
219
+ model_response["created"] = int(time.time())
220
+ model_response["model"] = model
221
+ usage = Usage(
222
+ prompt_tokens=prompt_tokens,
223
+ completion_tokens=completion_tokens,
224
+ total_tokens=prompt_tokens + completion_tokens,
225
+ )
226
+ model_response.usage = usage
227
+ return model_response
228
+
229
+
230
+ def embedding(
231
+ model: str,
232
+ input: list,
233
+ api_key: Optional[str] = None,
234
+ logging_obj=None,
235
+ model_response=None,
236
+ encoding=None,
237
+ optional_params=None,
238
+ ):
239
+ headers = validate_environment(api_key)
240
+ embed_url = "https://api.cohere.ai/v1/embed"
241
+ model = model
242
+ data = {"model": model, "texts": input, **optional_params}
243
+
244
+ if "3" in model and "input_type" not in data:
245
+ # cohere v3 embedding models require input_type, if no input_type is provided, default to "search_document"
246
+ data["input_type"] = "search_document"
247
+
248
+ ## LOGGING
249
+ logging_obj.pre_call(
250
+ input=input,
251
+ api_key=api_key,
252
+ additional_args={"complete_input_dict": data},
253
+ )
254
+ ## COMPLETION CALL
255
+ response = requests.post(embed_url, headers=headers, data=json.dumps(data))
256
+ ## LOGGING
257
+ logging_obj.post_call(
258
+ input=input,
259
+ api_key=api_key,
260
+ additional_args={"complete_input_dict": data},
261
+ original_response=response,
262
+ )
263
+ """
264
+ response
265
+ {
266
+ 'object': "list",
267
+ 'data': [
268
+
269
+ ]
270
+ 'model',
271
+ 'usage'
272
+ }
273
+ """
274
+ if response.status_code != 200:
275
+ raise CohereError(message=response.text, status_code=response.status_code)
276
+ embeddings = response.json()["embeddings"]
277
+ output_data = []
278
+ for idx, embedding in enumerate(embeddings):
279
+ output_data.append(
280
+ {"object": "embedding", "index": idx, "embedding": embedding}
281
+ )
282
+ model_response["object"] = "list"
283
+ model_response["data"] = output_data
284
+ model_response["model"] = model
285
+ input_tokens = 0
286
+ for text in input:
287
+ input_tokens += len(encoding.encode(text))
288
+
289
+ model_response["usage"] = {
290
+ "prompt_tokens": input_tokens,
291
+ "total_tokens": input_tokens,
292
+ }
293
+ return model_response
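For context, a minimal usage sketch exercising both Cohere handlers above; it assumes a valid COHERE_API_KEY and that the model aliases shown are routed to this module by litellm.

import os
import litellm

os.environ["COHERE_API_KEY"] = "<your-cohere-key>"  # placeholder

# Text generation goes through completion() above.
response = litellm.completion(
    model="command-nightly",
    messages=[{"role": "user", "content": "Write a one-line greeting."}],
    max_tokens=50,
)
print(response.choices[0].message.content)

# Embeddings go through embedding() above; v3 models default to
# input_type="search_document" when none is supplied.
emb = litellm.embedding(model="embed-english-v3.0", input=["hello world"])
print(len(emb.data[0]["embedding"]))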
litellm/llms/custom_httpx/azure_dall_e_2.py ADDED
@@ -0,0 +1,136 @@