yangdx committed
Commit 3deb059 · 1 Parent(s): e0fe0ea

Fix linting

Files changed (3)
  1. lightrag/lightrag.py +6 -6
  2. lightrag/operate.py +13 -9
  3. lightrag/utils.py +6 -2
lightrag/lightrag.py CHANGED
@@ -197,12 +197,12 @@ class LightRAG:
         logger.debug(f"LightRAG init with param:\n {_print_config}\n")
 
         # Init embedding functions with separate instances for insert and query
-        self.insert_embedding_func = limit_async_func_call(self.embedding_func_max_async)(
-            self.embedding_func
-        )
-        self.query_embedding_func = limit_async_func_call(self.embedding_func_max_async_query)(
-            self.embedding_func
-        )
+        self.insert_embedding_func = limit_async_func_call(
+            self.embedding_func_max_async
+        )(self.embedding_func)
+        self.query_embedding_func = limit_async_func_call(
+            self.embedding_func_max_async_query
+        )(self.embedding_func)
 
         # Initialize all storages
         self.key_string_value_json_storage_cls: Type[BaseKVStorage] = (
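The change above is formatting only: the two wrappers keep their separate concurrency limits (embedding_func_max_async for inserts, embedding_func_max_async_query for queries); the limit argument simply moves onto its own line. As a rough sketch of what a limiter like limit_async_func_call can look like (assuming a semaphore-based design; this is not LightRAG's actual code):

# Sketch only: a plausible shape for a call limiter, not the library's implementation.
import asyncio
from functools import wraps


def limit_async_func_call(max_concurrent: int):
    """Decorator factory capping how many calls to an async function run at once."""
    semaphore = asyncio.Semaphore(max_concurrent)

    def decorator(func):
        @wraps(func)
        async def wrapper(*args, **kwargs):
            async with semaphore:  # wait for a free slot before invoking the function
                return await func(*args, **kwargs)

        return wrapper

    return decorator


# Usage mirroring the diff: one wrapper per workload, both sharing the same embedding_func.
# insert_embedding_func = limit_async_func_call(16)(embedding_func)
# query_embedding_func = limit_async_func_call(4)(embedding_func)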
lightrag/operate.py CHANGED
@@ -352,7 +352,7 @@ async def extract_entities(
         input_text: str, history_messages: list[dict[str, str]] = None
     ) -> str:
         if enable_llm_cache_for_entity_extract and llm_response_cache:
-            custom_llm = None
+            custom_llm = None
             if (
                 global_config["embedding_cache_config"]
                 and global_config["embedding_cache_config"]["enabled"]

@@ -360,10 +360,14 @@ async def extract_entities(
                 new_config = global_config.copy()
                 new_config["embedding_cache_config"] = None
                 new_config["enable_llm_cache"] = True
-
+
                 # create a llm function with new_config for handle_cache
                 async def custom_llm(
-                    prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs
+                    prompt,
+                    system_prompt=None,
+                    history_messages=[],
+                    keyword_extraction=False,
+                    **kwargs,
                 ) -> str:
                     # Merge new_config with the other kwargs so the other parameters are not overridden
                     merged_config = {**kwargs, **new_config}

@@ -374,7 +378,7 @@ async def extract_entities(
                         keyword_extraction=keyword_extraction,
                         **merged_config,
                     )
-
+
             if history_messages:
                 history = json.dumps(history_messages, ensure_ascii=False)
                 _prompt = history + "\n" + input_text

@@ -383,12 +387,12 @@ async def extract_entities(
 
             arg_hash = compute_args_hash(_prompt)
             cached_return, _1, _2, _3 = await handle_cache(
-                llm_response_cache,
-                arg_hash,
-                _prompt,
-                "default",
+                llm_response_cache,
+                arg_hash,
+                _prompt,
+                "default",
                 cache_type="default",
-                llm=custom_llm
+                llm=custom_llm,
             )
             if cached_return:
                 logger.debug(f"Found cache for {arg_hash}")
lightrag/utils.py CHANGED
@@ -491,7 +491,9 @@ def dequantize_embedding(
     return (quantized * scale + min_val).astype(np.float32)
 
 
-async def handle_cache(hashing_kv, args_hash, prompt, mode="default", cache_type=None, llm=None):
+async def handle_cache(
+    hashing_kv, args_hash, prompt, mode="default", cache_type=None, llm=None
+):
     """Generic cache handling function"""
     if hashing_kv is None or not hashing_kv.global_config.get("enable_llm_cache"):
         return None, None, None, None

@@ -528,7 +530,9 @@ async def handle_cache(hashing_kv, args_hash, prompt, mode="default", cache_type=None, llm=None):
             similarity_threshold=embedding_cache_config["similarity_threshold"],
             mode=mode,
             use_llm_check=use_llm_check,
-            llm_func=llm if (use_llm_check and llm is not None) else (llm_model_func if use_llm_check else None),
+            llm_func=llm
+            if (use_llm_check and llm is not None)
+            else (llm_model_func if use_llm_check else None),
             original_prompt=prompt if use_llm_check else None,
             cache_type=cache_type,
         )
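The llm_func reflow is likewise cosmetic: the conditional expression is broken across three lines, but the selection logic is identical. Unfolded into explicit branches (a readability sketch reusing the names from handle_cache), it reads:

# Equivalent branch form of the reflowed conditional expression above.
if use_llm_check and llm is not None:
    llm_func = llm  # a caller-supplied llm takes priority when the LLM check is enabled
elif use_llm_check:
    llm_func = llm_model_func  # otherwise fall back to the configured model function
else:
    llm_func = None  # LLM check disabled, so no function is passed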