Shane Walker committed
Commit 11e4c6f · unverified · 1 Parent(s): 0943277

feat(openai): add client configuration support to OpenAI integration


Add support for custom client configurations in the OpenAI integration,
allowing for more flexible configuration of the AsyncOpenAI client.
This includes:

- Create a reusable helper function `create_openai_async_client`
- Add proper documentation for client configuration options
- Ensure consistent parameter precedence across the codebase
- Update the embedding function to support client configurations
- Add example script demonstrating custom client configuration usage

The changes maintain backward compatibility while providing a cleaner
and more maintainable approach to configuring OpenAI clients.
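
For illustration, a minimal caller-side sketch of the new configuration hooks follows. The `openai_client_configs` kwarg and the `client_configs` parameter come from this change; the model name and the specific options passed through (`timeout`, `max_retries`, which AsyncOpenAI accepts) are assumptions about what a caller might configure, not part of this commit.

import asyncio
import numpy as np

# Hypothetical usage sketch; assumes lightrag is installed and
# OPENAI_API_KEY is set in the environment.
from lightrag.llm.openai import openai_complete_if_cache, openai_embed


async def main() -> None:
    # Completion: extra client options ride along in `openai_client_configs`
    # and are applied when the AsyncOpenAI client is constructed.
    answer = await openai_complete_if_cache(
        "gpt-4o-mini",  # placeholder model name
        "Summarize LightRAG in one sentence.",
        openai_client_configs={"timeout": 30.0, "max_retries": 5},
    )
    print(answer)

    # Embeddings: the same options are exposed directly as `client_configs`.
    vectors: np.ndarray = await openai_embed(
        ["hello", "world"],
        model="text-embedding-3-small",
        client_configs={"timeout": 30.0},
    )
    print(vectors.shape)


if __name__ == "__main__":
    asyncio.run(main())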

Files changed (1)
  1. lightrag/llm/openai.py +101 -26
lightrag/llm/openai.py CHANGED
@@ -44,6 +44,43 @@ class InvalidResponseError(Exception):
     pass
 
 
+def create_openai_async_client(
+    api_key: str | None = None,
+    base_url: str | None = None,
+    client_configs: dict[str, Any] = None,
+) -> AsyncOpenAI:
+    """Create an AsyncOpenAI client with the given configuration.
+
+    Args:
+        api_key: OpenAI API key. If None, uses the OPENAI_API_KEY environment variable.
+        base_url: Base URL for the OpenAI API. If None, uses the default OpenAI API URL.
+        client_configs: Additional configuration options for the AsyncOpenAI client.
+            These will override any default configurations but will be overridden by
+            explicit parameters (api_key, base_url).
+
+    Returns:
+        An AsyncOpenAI client instance.
+    """
+    if not api_key:
+        api_key = os.environ["OPENAI_API_KEY"]
+
+    default_headers = {
+        "User-Agent": f"Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_8) LightRAG/{__api_version__}",
+        "Content-Type": "application/json",
+    }
+
+    if client_configs is None:
+        client_configs = {}
+
+    # Create a merged config dict with precedence: explicit params > client_configs > defaults
+    merged_configs = {**client_configs, "default_headers": default_headers, "api_key": api_key}
+
+    if base_url is not None:
+        merged_configs["base_url"] = base_url
+
+    return AsyncOpenAI(**merged_configs)
+
+
 @retry(
     stop=stop_after_attempt(3),
     wait=wait_exponential(multiplier=1, min=4, max=10),
@@ -60,29 +97,54 @@ async def openai_complete_if_cache(
     api_key: str | None = None,
     **kwargs: Any,
 ) -> str:
+    """Complete a prompt using OpenAI's API with caching support.
+
+    Args:
+        model: The OpenAI model to use.
+        prompt: The prompt to complete.
+        system_prompt: Optional system prompt to include.
+        history_messages: Optional list of previous messages in the conversation.
+        base_url: Optional base URL for the OpenAI API.
+        api_key: Optional OpenAI API key. If None, uses the OPENAI_API_KEY environment variable.
+        **kwargs: Additional keyword arguments to pass to the OpenAI API.
+            Special kwargs:
+            - openai_client_configs: Dict of configuration options for the AsyncOpenAI client.
+                These will be passed to the client constructor but will be overridden by
+                explicit parameters (api_key, base_url).
+            - hashing_kv: Will be removed from kwargs before passing to OpenAI.
+            - keyword_extraction: Will be removed from kwargs before passing to OpenAI.
+
+    Returns:
+        The completed text or an async iterator of text chunks if streaming.
+
+    Raises:
+        InvalidResponseError: If the response from OpenAI is invalid or empty.
+        APIConnectionError: If there is a connection error with the OpenAI API.
+        RateLimitError: If the OpenAI API rate limit is exceeded.
+        APITimeoutError: If the OpenAI API request times out.
+    """
     if history_messages is None:
         history_messages = []
-    if not api_key:
-        api_key = os.environ["OPENAI_API_KEY"]
-
-    default_headers = {
-        "User-Agent": "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_8) LightRAG/{__api_version__}",
-        "Content-Type": "application/json",
-    }
 
     # Set openai logger level to INFO when VERBOSE_DEBUG is off
     if not VERBOSE_DEBUG and logger.level == logging.DEBUG:
         logging.getLogger("openai").setLevel(logging.INFO)
 
-    openai_async_client = (
-        AsyncOpenAI(default_headers=default_headers, api_key=api_key)
-        if base_url is None
-        else AsyncOpenAI(
-            base_url=base_url, default_headers=default_headers, api_key=api_key
-        )
+    # Extract client configuration options
+    client_configs = kwargs.pop("openai_client_configs", {})
+
+    # Create the OpenAI client
+    openai_async_client = create_openai_async_client(
+        api_key=api_key,
+        base_url=base_url,
+        client_configs=client_configs
     )
+
+    # Remove special kwargs that shouldn't be passed to OpenAI
     kwargs.pop("hashing_kv", None)
     kwargs.pop("keyword_extraction", None)
+
+    # Prepare messages
     messages: list[dict[str, Any]] = []
     if system_prompt:
         messages.append({"role": "system", "content": system_prompt})
@@ -257,21 +319,34 @@ async def openai_embed(
     model: str = "text-embedding-3-small",
     base_url: str = None,
     api_key: str = None,
+    client_configs: dict[str, Any] = None,
 ) -> np.ndarray:
-    if not api_key:
-        api_key = os.environ["OPENAI_API_KEY"]
-
-    default_headers = {
-        "User-Agent": f"Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_8) LightRAG/{__api_version__}",
-        "Content-Type": "application/json",
-    }
-    openai_async_client = (
-        AsyncOpenAI(default_headers=default_headers, api_key=api_key)
-        if base_url is None
-        else AsyncOpenAI(
-            base_url=base_url, default_headers=default_headers, api_key=api_key
-        )
+    """Generate embeddings for a list of texts using OpenAI's API.
+
+    Args:
+        texts: List of texts to embed.
+        model: The OpenAI embedding model to use.
+        base_url: Optional base URL for the OpenAI API.
+        api_key: Optional OpenAI API key. If None, uses the OPENAI_API_KEY environment variable.
+        client_configs: Additional configuration options for the AsyncOpenAI client.
+            These will override any default configurations but will be overridden by
+            explicit parameters (api_key, base_url).
+
+    Returns:
+        A numpy array of embeddings, one per input text.
+
+    Raises:
+        APIConnectionError: If there is a connection error with the OpenAI API.
+        RateLimitError: If the OpenAI API rate limit is exceeded.
+        APITimeoutError: If the OpenAI API request times out.
+    """
+    # Create the OpenAI client
+    openai_async_client = create_openai_async_client(
+        api_key=api_key,
+        base_url=base_url,
+        client_configs=client_configs
    )
+
     response = await openai_async_client.embeddings.create(
         model=model, input=texts, encoding_format="float"
     )
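
A note on the precedence rule the helper implements: because the explicit parameters are merged in last, anything a caller places under `api_key` or `default_headers` inside `client_configs` is overridden, while other options pass through untouched. A minimal standalone sketch of that merge (illustrative values only, not LightRAG code):

# Sketch of the merge order used by create_openai_async_client:
# explicit params > client_configs > defaults.
client_configs = {"timeout": 30.0, "api_key": "sk-from-configs"}
default_headers = {"Content-Type": "application/json"}
api_key = "sk-explicit"

merged = {**client_configs, "default_headers": default_headers, "api_key": api_key}
assert merged["api_key"] == "sk-explicit"  # explicit parameter wins
assert merged["timeout"] == 30.0           # passthrough option preserved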