import os

import pipmaster as pm

# Install runtime dependencies on first use if they are missing.
if not pm.is_installed("openai"):
    pm.install("openai")
if not pm.is_installed("tenacity"):
    pm.install("tenacity")

from openai import (
    AsyncAzureOpenAI,
    APIConnectionError,
    RateLimitError,
    APITimeoutError,
)
from tenacity import (
    retry,
    stop_after_attempt,
    wait_exponential,
    retry_if_exception_type,
)

from lightrag.utils import (
    wrap_embedding_func_with_attrs,
    locate_json_string_body_from_string,
    safe_unicode_decode,
)

import numpy as np


@retry(
    stop=stop_after_attempt(3),
    wait=wait_exponential(multiplier=1, min=4, max=10),
    retry=retry_if_exception_type(
        (RateLimitError, APIConnectionError, APITimeoutError)
    ),
)
async def azure_openai_complete_if_cache(
    model,
    prompt,
    system_prompt=None,
    history_messages=None,
    base_url=None,
    api_key=None,
    api_version=None,
    **kwargs,
):
    """Query an Azure OpenAI chat deployment, optionally streaming the reply.

    Credentials fall back to the AZURE_OPENAI_* environment variables when
    not supplied. Returns the response text, or an async generator of text
    chunks when the underlying API response is a stream.
    """
    if history_messages is None:
        history_messages = []
    # Explicit arguments take precedence over the environment.
    if api_key:
        os.environ["AZURE_OPENAI_API_KEY"] = api_key
    if base_url:
        os.environ["AZURE_OPENAI_ENDPOINT"] = base_url
    if api_version:
        os.environ["AZURE_OPENAI_API_VERSION"] = api_version

    openai_async_client = AsyncAzureOpenAI(
        azure_endpoint=os.getenv("AZURE_OPENAI_ENDPOINT"),
        api_key=os.getenv("AZURE_OPENAI_API_KEY"),
        api_version=os.getenv("AZURE_OPENAI_API_VERSION"),
    )
    # hashing_kv is a LightRAG caching handle, not an OpenAI parameter.
    kwargs.pop("hashing_kv", None)
    messages = []
    if system_prompt:
        messages.append({"role": "system", "content": system_prompt})
    messages.extend(history_messages)
    if prompt is not None:
        messages.append({"role": "user", "content": prompt})

    if "response_format" in kwargs:
        # Structured output goes through the beta parse endpoint.
        response = await openai_async_client.beta.chat.completions.parse(
            model=model, messages=messages, **kwargs
        )
    else:
        response = await openai_async_client.chat.completions.create(
            model=model, messages=messages, **kwargs
        )

    if hasattr(response, "__aiter__"):
        # Streaming response: yield each chunk's text as it arrives.
        async def inner():
            async for chunk in response:
                if len(chunk.choices) == 0:
                    continue
                content = chunk.choices[0].delta.content
                if content is None:
                    continue
                if r"\u" in content:
                    content = safe_unicode_decode(content.encode("utf-8"))
                yield content

        return inner()
    else:
        content = response.choices[0].message.content
        if r"\u" in content:
            content = safe_unicode_decode(content.encode("utf-8"))
        return content
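

# A minimal usage sketch, not exercised by the library itself: consuming the
# streaming form of azure_openai_complete_if_cache. The deployment name
# "gpt-4o-mini" and stream=True are illustrative assumptions, not fixed API.
async def _example_stream_completion() -> None:
    stream = await azure_openai_complete_if_cache(
        "gpt-4o-mini", "Say hello.", stream=True
    )
    async for piece in stream:
        print(piece, end="", flush=True)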


async def azure_openai_complete(
    prompt, system_prompt=None, history_messages=None, keyword_extraction=False, **kwargs
) -> str:
    """Complete `prompt` with the deployment named by LLM_MODEL (default gpt-4o-mini).

    When keyword_extraction is true, the JSON body is extracted from the reply.
    """
    result = await azure_openai_complete_if_cache(
        os.getenv("LLM_MODEL", "gpt-4o-mini"),
        prompt,
        system_prompt=system_prompt,
        history_messages=history_messages,
        **kwargs,
    )
    if keyword_extraction:
        return locate_json_string_body_from_string(result)
    return result


@wrap_embedding_func_with_attrs(embedding_dim=1536, max_token_size=8191)
@retry(
    stop=stop_after_attempt(3),
    wait=wait_exponential(multiplier=1, min=4, max=10),
    retry=retry_if_exception_type(
        (RateLimitError, APIConnectionError, APITimeoutError)
    ),
)
async def azure_openai_embed(
    texts: list[str],
    model: str = os.getenv("EMBEDDING_MODEL", "text-embedding-3-small"),
    base_url: str = None,
    api_key: str = None,
    api_version: str = None,
) -> np.ndarray:
    """Embed `texts` with an Azure OpenAI embedding deployment.

    Credentials fall back to the AZURE_OPENAI_* environment variables when
    not supplied. Returns one embedding row per input text.
    """
    # Explicit arguments take precedence over the environment.
    if api_key:
        os.environ["AZURE_OPENAI_API_KEY"] = api_key
    if base_url:
        os.environ["AZURE_OPENAI_ENDPOINT"] = base_url
    if api_version:
        os.environ["AZURE_OPENAI_API_VERSION"] = api_version

    openai_async_client = AsyncAzureOpenAI(
        azure_endpoint=os.getenv("AZURE_OPENAI_ENDPOINT"),
        api_key=os.getenv("AZURE_OPENAI_API_KEY"),
        api_version=os.getenv("AZURE_OPENAI_API_VERSION"),
    )

    response = await openai_async_client.embeddings.create(
        model=model, input=texts, encoding_format="float"
    )
    return np.array([dp.embedding for dp in response.data])
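

# Hedged end-to-end sketch: assumes AZURE_OPENAI_ENDPOINT, AZURE_OPENAI_API_KEY,
# and AZURE_OPENAI_API_VERSION are set, and that the LLM_MODEL and
# EMBEDDING_MODEL deployments exist in the target Azure resource.
if __name__ == "__main__":
    import asyncio

    async def _demo() -> None:
        answer = await azure_openai_complete("What does this module do?")
        print(answer)

        vectors = await azure_openai_embed(["hello", "world"])
        print(vectors.shape)  # (2, 1536) for text-embedding-3-small

    asyncio.run(_demo())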