Shane Walker
committed on
feat(openai): add client configuration support to OpenAI integration
Browse files
Add support for custom client configurations in the OpenAI integration,
allowing for more flexible configuration of the AsyncOpenAI client.
This includes:
- Create a reusable helper function `create_openai_async_client`
- Add proper documentation for client configuration options
- Ensure consistent parameter precedence across the codebase
- Update the embedding function to support client configurations
- Add example script demonstrating custom client configuration usage
The changes maintain backward compatibility while providing a cleaner
and more maintainable approach to configuring OpenAI clients.
- lightrag/llm/openai.py +101 -26
lightrag/llm/openai.py
CHANGED
@@ -44,6 +44,43 @@ class InvalidResponseError(Exception):
|
|
44 |
pass
|
45 |
|
46 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
47 |
@retry(
|
48 |
stop=stop_after_attempt(3),
|
49 |
wait=wait_exponential(multiplier=1, min=4, max=10),
|
@@ -60,29 +97,54 @@ async def openai_complete_if_cache(
|
|
60 |
api_key: str | None = None,
|
61 |
**kwargs: Any,
|
62 |
) -> str:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
63 |
if history_messages is None:
|
64 |
history_messages = []
|
65 |
-
if not api_key:
|
66 |
-
api_key = os.environ["OPENAI_API_KEY"]
|
67 |
-
|
68 |
-
default_headers = {
|
69 |
-
"User-Agent": "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_8) LightRAG/{__api_version__}",
|
70 |
-
"Content-Type": "application/json",
|
71 |
-
}
|
72 |
|
73 |
# Set openai logger level to INFO when VERBOSE_DEBUG is off
|
74 |
if not VERBOSE_DEBUG and logger.level == logging.DEBUG:
|
75 |
logging.getLogger("openai").setLevel(logging.INFO)
|
76 |
|
77 |
-
|
78 |
-
|
79 |
-
|
80 |
-
|
81 |
-
|
82 |
-
|
|
|
|
|
83 |
)
|
|
|
|
|
84 |
kwargs.pop("hashing_kv", None)
|
85 |
kwargs.pop("keyword_extraction", None)
|
|
|
|
|
86 |
messages: list[dict[str, Any]] = []
|
87 |
if system_prompt:
|
88 |
messages.append({"role": "system", "content": system_prompt})
|
@@ -257,21 +319,34 @@ async def openai_embed(
|
|
257 |
model: str = "text-embedding-3-small",
|
258 |
base_url: str = None,
|
259 |
api_key: str = None,
|
|
|
260 |
) -> np.ndarray:
|
261 |
-
|
262 |
-
|
263 |
-
|
264 |
-
|
265 |
-
|
266 |
-
|
267 |
-
|
268 |
-
|
269 |
-
|
270 |
-
|
271 |
-
|
272 |
-
|
273 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
274 |
)
|
|
|
275 |
response = await openai_async_client.embeddings.create(
|
276 |
model=model, input=texts, encoding_format="float"
|
277 |
)
|
|
|
44 |
pass
|
45 |
|
46 |
|
47 |
+
def create_openai_async_client(
|
48 |
+
api_key: str | None = None,
|
49 |
+
base_url: str | None = None,
|
50 |
+
client_configs: dict[str, Any] = None,
|
51 |
+
) -> AsyncOpenAI:
|
52 |
+
"""Create an AsyncOpenAI client with the given configuration.
|
53 |
+
|
54 |
+
Args:
|
55 |
+
api_key: OpenAI API key. If None, uses the OPENAI_API_KEY environment variable.
|
56 |
+
base_url: Base URL for the OpenAI API. If None, uses the default OpenAI API URL.
|
57 |
+
client_configs: Additional configuration options for the AsyncOpenAI client.
|
58 |
+
These will override any default configurations but will be overridden by
|
59 |
+
explicit parameters (api_key, base_url).
|
60 |
+
|
61 |
+
Returns:
|
62 |
+
An AsyncOpenAI client instance.
|
63 |
+
"""
|
64 |
+
if not api_key:
|
65 |
+
api_key = os.environ["OPENAI_API_KEY"]
|
66 |
+
|
67 |
+
default_headers = {
|
68 |
+
"User-Agent": f"Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_8) LightRAG/{__api_version__}",
|
69 |
+
"Content-Type": "application/json",
|
70 |
+
}
|
71 |
+
|
72 |
+
if client_configs is None:
|
73 |
+
client_configs = {}
|
74 |
+
|
75 |
+
# Create a merged config dict with precedence: explicit params > client_configs > defaults
|
76 |
+
merged_configs = {**client_configs, "default_headers": default_headers, "api_key": api_key}
|
77 |
+
|
78 |
+
if base_url is not None:
|
79 |
+
merged_configs["base_url"] = base_url
|
80 |
+
|
81 |
+
return AsyncOpenAI(**merged_configs)
|
82 |
+
|
83 |
+
|
84 |
@retry(
|
85 |
stop=stop_after_attempt(3),
|
86 |
wait=wait_exponential(multiplier=1, min=4, max=10),
|
|
|
97 |
api_key: str | None = None,
|
98 |
**kwargs: Any,
|
99 |
) -> str:
|
100 |
+
"""Complete a prompt using OpenAI's API with caching support.
|
101 |
+
|
102 |
+
Args:
|
103 |
+
model: The OpenAI model to use.
|
104 |
+
prompt: The prompt to complete.
|
105 |
+
system_prompt: Optional system prompt to include.
|
106 |
+
history_messages: Optional list of previous messages in the conversation.
|
107 |
+
base_url: Optional base URL for the OpenAI API.
|
108 |
+
api_key: Optional OpenAI API key. If None, uses the OPENAI_API_KEY environment variable.
|
109 |
+
**kwargs: Additional keyword arguments to pass to the OpenAI API.
|
110 |
+
Special kwargs:
|
111 |
+
- openai_client_configs: Dict of configuration options for the AsyncOpenAI client.
|
112 |
+
These will be passed to the client constructor but will be overridden by
|
113 |
+
explicit parameters (api_key, base_url).
|
114 |
+
- hashing_kv: Will be removed from kwargs before passing to OpenAI.
|
115 |
+
- keyword_extraction: Will be removed from kwargs before passing to OpenAI.
|
116 |
+
|
117 |
+
Returns:
|
118 |
+
The completed text or an async iterator of text chunks if streaming.
|
119 |
+
|
120 |
+
Raises:
|
121 |
+
InvalidResponseError: If the response from OpenAI is invalid or empty.
|
122 |
+
APIConnectionError: If there is a connection error with the OpenAI API.
|
123 |
+
RateLimitError: If the OpenAI API rate limit is exceeded.
|
124 |
+
APITimeoutError: If the OpenAI API request times out.
|
125 |
+
"""
|
126 |
if history_messages is None:
|
127 |
history_messages = []
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
128 |
|
129 |
# Set openai logger level to INFO when VERBOSE_DEBUG is off
|
130 |
if not VERBOSE_DEBUG and logger.level == logging.DEBUG:
|
131 |
logging.getLogger("openai").setLevel(logging.INFO)
|
132 |
|
133 |
+
# Extract client configuration options
|
134 |
+
client_configs = kwargs.pop("openai_client_configs", {})
|
135 |
+
|
136 |
+
# Create the OpenAI client
|
137 |
+
openai_async_client = create_openai_async_client(
|
138 |
+
api_key=api_key,
|
139 |
+
base_url=base_url,
|
140 |
+
client_configs=client_configs
|
141 |
)
|
142 |
+
|
143 |
+
# Remove special kwargs that shouldn't be passed to OpenAI
|
144 |
kwargs.pop("hashing_kv", None)
|
145 |
kwargs.pop("keyword_extraction", None)
|
146 |
+
|
147 |
+
# Prepare messages
|
148 |
messages: list[dict[str, Any]] = []
|
149 |
if system_prompt:
|
150 |
messages.append({"role": "system", "content": system_prompt})
|
|
|
319 |
model: str = "text-embedding-3-small",
|
320 |
base_url: str = None,
|
321 |
api_key: str = None,
|
322 |
+
client_configs: dict[str, Any] = None,
|
323 |
) -> np.ndarray:
|
324 |
+
"""Generate embeddings for a list of texts using OpenAI's API.
|
325 |
+
|
326 |
+
Args:
|
327 |
+
texts: List of texts to embed.
|
328 |
+
model: The OpenAI embedding model to use.
|
329 |
+
base_url: Optional base URL for the OpenAI API.
|
330 |
+
api_key: Optional OpenAI API key. If None, uses the OPENAI_API_KEY environment variable.
|
331 |
+
client_configs: Additional configuration options for the AsyncOpenAI client.
|
332 |
+
These will override any default configurations but will be overridden by
|
333 |
+
explicit parameters (api_key, base_url).
|
334 |
+
|
335 |
+
Returns:
|
336 |
+
A numpy array of embeddings, one per input text.
|
337 |
+
|
338 |
+
Raises:
|
339 |
+
APIConnectionError: If there is a connection error with the OpenAI API.
|
340 |
+
RateLimitError: If the OpenAI API rate limit is exceeded.
|
341 |
+
APITimeoutError: If the OpenAI API request times out.
|
342 |
+
"""
|
343 |
+
# Create the OpenAI client
|
344 |
+
openai_async_client = create_openai_async_client(
|
345 |
+
api_key=api_key,
|
346 |
+
base_url=base_url,
|
347 |
+
client_configs=client_configs
|
348 |
)
|
349 |
+
|
350 |
response = await openai_async_client.embeddings.create(
|
351 |
model=model, input=texts, encoding_format="float"
|
352 |
)
|