yangdx
commited on
Commit
·
3deb059
1
Parent(s):
e0fe0ea
Fix linting
Browse files- lightrag/lightrag.py +6 -6
- lightrag/operate.py +13 -9
- lightrag/utils.py +6 -2
lightrag/lightrag.py
CHANGED
@@ -197,12 +197,12 @@ class LightRAG:
|
|
197 |
logger.debug(f"LightRAG init with param:\n {_print_config}\n")
|
198 |
|
199 |
# Init embedding functions with separate instances for insert and query
|
200 |
-
self.insert_embedding_func = limit_async_func_call(
|
201 |
-
self.
|
202 |
-
)
|
203 |
-
self.query_embedding_func = limit_async_func_call(
|
204 |
-
self.
|
205 |
-
)
|
206 |
|
207 |
# Initialize all storages
|
208 |
self.key_string_value_json_storage_cls: Type[BaseKVStorage] = (
|
|
|
197 |
logger.debug(f"LightRAG init with param:\n {_print_config}\n")
|
198 |
|
199 |
# Init embedding functions with separate instances for insert and query
|
200 |
+
self.insert_embedding_func = limit_async_func_call(
|
201 |
+
self.embedding_func_max_async
|
202 |
+
)(self.embedding_func)
|
203 |
+
self.query_embedding_func = limit_async_func_call(
|
204 |
+
self.embedding_func_max_async_query
|
205 |
+
)(self.embedding_func)
|
206 |
|
207 |
# Initialize all storages
|
208 |
self.key_string_value_json_storage_cls: Type[BaseKVStorage] = (
|
lightrag/operate.py
CHANGED
@@ -352,7 +352,7 @@ async def extract_entities(
|
|
352 |
input_text: str, history_messages: list[dict[str, str]] = None
|
353 |
) -> str:
|
354 |
if enable_llm_cache_for_entity_extract and llm_response_cache:
|
355 |
-
custom_llm = None
|
356 |
if (
|
357 |
global_config["embedding_cache_config"]
|
358 |
and global_config["embedding_cache_config"]["enabled"]
|
@@ -360,10 +360,14 @@ async def extract_entities(
|
|
360 |
new_config = global_config.copy()
|
361 |
new_config["embedding_cache_config"] = None
|
362 |
new_config["enable_llm_cache"] = True
|
363 |
-
|
364 |
# create a llm function with new_config for handle_cache
|
365 |
async def custom_llm(
|
366 |
-
prompt,
|
|
|
|
|
|
|
|
|
367 |
) -> str:
|
368 |
# 合并 new_config 和其他 kwargs,保证其他参数不被覆盖
|
369 |
merged_config = {**kwargs, **new_config}
|
@@ -374,7 +378,7 @@ async def extract_entities(
|
|
374 |
keyword_extraction=keyword_extraction,
|
375 |
**merged_config,
|
376 |
)
|
377 |
-
|
378 |
if history_messages:
|
379 |
history = json.dumps(history_messages, ensure_ascii=False)
|
380 |
_prompt = history + "\n" + input_text
|
@@ -383,12 +387,12 @@ async def extract_entities(
|
|
383 |
|
384 |
arg_hash = compute_args_hash(_prompt)
|
385 |
cached_return, _1, _2, _3 = await handle_cache(
|
386 |
-
llm_response_cache,
|
387 |
-
arg_hash,
|
388 |
-
_prompt,
|
389 |
-
"default",
|
390 |
cache_type="default",
|
391 |
-
llm=custom_llm
|
392 |
)
|
393 |
if cached_return:
|
394 |
logger.debug(f"Found cache for {arg_hash}")
|
|
|
352 |
input_text: str, history_messages: list[dict[str, str]] = None
|
353 |
) -> str:
|
354 |
if enable_llm_cache_for_entity_extract and llm_response_cache:
|
355 |
+
custom_llm = None
|
356 |
if (
|
357 |
global_config["embedding_cache_config"]
|
358 |
and global_config["embedding_cache_config"]["enabled"]
|
|
|
360 |
new_config = global_config.copy()
|
361 |
new_config["embedding_cache_config"] = None
|
362 |
new_config["enable_llm_cache"] = True
|
363 |
+
|
364 |
# create a llm function with new_config for handle_cache
|
365 |
async def custom_llm(
|
366 |
+
prompt,
|
367 |
+
system_prompt=None,
|
368 |
+
history_messages=[],
|
369 |
+
keyword_extraction=False,
|
370 |
+
**kwargs,
|
371 |
) -> str:
|
372 |
# 合并 new_config 和其他 kwargs,保证其他参数不被覆盖
|
373 |
merged_config = {**kwargs, **new_config}
|
|
|
378 |
keyword_extraction=keyword_extraction,
|
379 |
**merged_config,
|
380 |
)
|
381 |
+
|
382 |
if history_messages:
|
383 |
history = json.dumps(history_messages, ensure_ascii=False)
|
384 |
_prompt = history + "\n" + input_text
|
|
|
387 |
|
388 |
arg_hash = compute_args_hash(_prompt)
|
389 |
cached_return, _1, _2, _3 = await handle_cache(
|
390 |
+
llm_response_cache,
|
391 |
+
arg_hash,
|
392 |
+
_prompt,
|
393 |
+
"default",
|
394 |
cache_type="default",
|
395 |
+
llm=custom_llm,
|
396 |
)
|
397 |
if cached_return:
|
398 |
logger.debug(f"Found cache for {arg_hash}")
|
lightrag/utils.py
CHANGED
@@ -491,7 +491,9 @@ def dequantize_embedding(
|
|
491 |
return (quantized * scale + min_val).astype(np.float32)
|
492 |
|
493 |
|
494 |
-
async def handle_cache(
|
|
|
|
|
495 |
"""Generic cache handling function"""
|
496 |
if hashing_kv is None or not hashing_kv.global_config.get("enable_llm_cache"):
|
497 |
return None, None, None, None
|
@@ -528,7 +530,9 @@ async def handle_cache(hashing_kv, args_hash, prompt, mode="default", cache_type
|
|
528 |
similarity_threshold=embedding_cache_config["similarity_threshold"],
|
529 |
mode=mode,
|
530 |
use_llm_check=use_llm_check,
|
531 |
-
llm_func=llm
|
|
|
|
|
532 |
original_prompt=prompt if use_llm_check else None,
|
533 |
cache_type=cache_type,
|
534 |
)
|
|
|
491 |
return (quantized * scale + min_val).astype(np.float32)
|
492 |
|
493 |
|
494 |
+
async def handle_cache(
|
495 |
+
hashing_kv, args_hash, prompt, mode="default", cache_type=None, llm=None
|
496 |
+
):
|
497 |
"""Generic cache handling function"""
|
498 |
if hashing_kv is None or not hashing_kv.global_config.get("enable_llm_cache"):
|
499 |
return None, None, None, None
|
|
|
530 |
similarity_threshold=embedding_cache_config["similarity_threshold"],
|
531 |
mode=mode,
|
532 |
use_llm_check=use_llm_check,
|
533 |
+
llm_func=llm
|
534 |
+
if (use_llm_check and llm is not None)
|
535 |
+
else (llm_model_func if use_llm_check else None),
|
536 |
original_prompt=prompt if use_llm_check else None,
|
537 |
cache_type=cache_type,
|
538 |
)
|