Abyl Ikhsanov committed on
Commit
74e6c1c
·
1 Parent(s): a7f6abf

Update llm.py

Browse files
Files changed (1) hide show
  1. lightrag/llm.py +80 -1
lightrag/llm.py CHANGED
@@ -4,7 +4,7 @@ import json
4
  import aioboto3
5
  import numpy as np
6
  import ollama
7
- from openai import AsyncOpenAI, APIConnectionError, RateLimitError, Timeout
8
  from tenacity import (
9
  retry,
10
  stop_after_attempt,
@@ -61,6 +61,49 @@ async def openai_complete_if_cache(
61
  )
62
  return response.choices[0].message.content
63
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
64
 
65
  class BedrockError(Exception):
66
  """Generic error for issues related to Amazon Bedrock"""
@@ -270,6 +313,16 @@ async def gpt_4o_mini_complete(
270
  **kwargs,
271
  )
272
 
 
 
 
 
 
 
 
 
 
 
273
 
274
  async def bedrock_complete(
275
  prompt, system_prompt=None, history_messages=[], **kwargs
@@ -332,6 +385,32 @@ async def openai_embedding(
332
  )
333
  return np.array([dp.embedding for dp in response.data])
334
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
335
 
336
  # @wrap_embedding_func_with_attrs(embedding_dim=1024, max_token_size=8192)
337
  # @retry(
 
4
  import aioboto3
5
  import numpy as np
6
  import ollama
7
+ from openai import AsyncOpenAI, APIConnectionError, RateLimitError, Timeout, AsyncAzureOpenAI
8
  from tenacity import (
9
  retry,
10
  stop_after_attempt,
 
61
  )
62
  return response.choices[0].message.content
63
 
64
@retry(
    stop=stop_after_attempt(3),
    wait=wait_exponential(multiplier=1, min=4, max=10),
    retry=retry_if_exception_type((RateLimitError, APIConnectionError, Timeout)),
)
async def azure_openai_complete_if_cache(
    model,
    prompt,
    system_prompt=None,
    history_messages=None,
    base_url=None,
    api_key=None,
    **kwargs,
):
    """Call an Azure OpenAI chat deployment, consulting/updating a KV cache.

    Args:
        model: Azure deployment name to invoke.
        prompt: User prompt appended as the final message (skipped if None).
        system_prompt: Optional system message prepended to the conversation.
        history_messages: Prior chat messages (role/content dicts); None means
            no history (avoids the shared mutable-default pitfall).
        base_url: Azure endpoint; exported to AZURE_OPENAI_ENDPOINT when given.
        api_key: API key; exported to AZURE_OPENAI_API_KEY when given.
        **kwargs: Forwarded to ``chat.completions.create``; ``hashing_kv``
            (a BaseKVStorage) is popped out and used as the response cache.

    Returns:
        str: The assistant message content.
    """
    if history_messages is None:
        history_messages = []
    # Explicit credentials take precedence over the existing environment.
    if api_key:
        os.environ["AZURE_OPENAI_API_KEY"] = api_key
    if base_url:
        os.environ["AZURE_OPENAI_ENDPOINT"] = base_url

    openai_async_client = AsyncAzureOpenAI(
        azure_endpoint=os.getenv("AZURE_OPENAI_ENDPOINT"),
        api_key=os.getenv("AZURE_OPENAI_API_KEY"),
        api_version=os.getenv("AZURE_OPENAI_API_VERSION"),
    )

    hashing_kv: BaseKVStorage = kwargs.pop("hashing_kv", None)
    messages = []
    if system_prompt:
        messages.append({"role": "system", "content": system_prompt})
    messages.extend(history_messages)
    if prompt is not None:
        messages.append({"role": "user", "content": prompt})

    # Short-circuit with a cached completion for an identical request.
    if hashing_kv is not None:
        args_hash = compute_args_hash(model, messages)
        if_cache_return = await hashing_kv.get_by_id(args_hash)
        if if_cache_return is not None:
            return if_cache_return["return"]

    response = await openai_async_client.chat.completions.create(
        model=model, messages=messages, **kwargs
    )

    if hashing_kv is not None:
        await hashing_kv.upsert(
            {args_hash: {"return": response.choices[0].message.content, "model": model}}
        )
    return response.choices[0].message.content
107
 
108
  class BedrockError(Exception):
109
  """Generic error for issues related to Amazon Bedrock"""
 
313
  **kwargs,
314
  )
315
 
316
async def azure_openai_complete(
    prompt, system_prompt=None, history_messages=None, **kwargs
) -> str:
    """Complete *prompt* against the fixed Azure deployment 'conversation-4o-mini'.

    Thin convenience wrapper over ``azure_openai_complete_if_cache``; see that
    function for parameter semantics and caching behavior.

    Args:
        prompt: User prompt to complete.
        system_prompt: Optional system message.
        history_messages: Prior chat messages; None means no history
            (avoids the shared mutable-default pitfall).
        **kwargs: Forwarded unchanged to the underlying call.

    Returns:
        str: The assistant message content.
    """
    return await azure_openai_complete_if_cache(
        "conversation-4o-mini",
        prompt,
        system_prompt=system_prompt,
        history_messages=history_messages if history_messages is not None else [],
        **kwargs,
    )
326
 
327
  async def bedrock_complete(
328
  prompt, system_prompt=None, history_messages=[], **kwargs
 
385
  )
386
  return np.array([dp.embedding for dp in response.data])
387
 
388
@wrap_embedding_func_with_attrs(embedding_dim=1536, max_token_size=8192)
@retry(
    stop=stop_after_attempt(3),
    wait=wait_exponential(multiplier=1, min=4, max=10),
    retry=retry_if_exception_type((RateLimitError, APIConnectionError, Timeout)),
)
async def azure_openai_embedding(
    texts: list[str],
    model: str = "text-embedding-3-small",
    base_url: str = None,
    api_key: str = None,
) -> np.ndarray:
    """Embed *texts* with an Azure OpenAI embedding deployment.

    Explicit credentials, when supplied, are exported into the process
    environment before the client is constructed; the per-text embeddings
    are then stacked into a single numpy array.
    """
    # Explicit arguments override whatever the environment already holds.
    if api_key:
        os.environ["AZURE_OPENAI_API_KEY"] = api_key
    if base_url:
        os.environ["AZURE_OPENAI_ENDPOINT"] = base_url

    client = AsyncAzureOpenAI(
        azure_endpoint=os.getenv("AZURE_OPENAI_ENDPOINT"),
        api_key=os.getenv("AZURE_OPENAI_API_KEY"),
        api_version=os.getenv("AZURE_OPENAI_API_VERSION"),
    )

    result = await client.embeddings.create(
        model=model, input=texts, encoding_format="float"
    )
    return np.array([item.embedding for item in result.data])
413
+
414
 
415
  # @wrap_embedding_func_with_attrs(embedding_dim=1024, max_token_size=8192)
416
  # @retry(