yangdx committed on
Commit e95fe5c · 1 Parent(s): 036d85a

Unify llm_response_cache and hashing_kv to prevent creating an independent hashing_kv.

lightrag/api/lightrag_server.py CHANGED
@@ -323,7 +323,7 @@ def create_app(args):
             vector_db_storage_cls_kwargs={
                 "cosine_better_than_threshold": args.cosine_threshold
             },
-            enable_llm_cache_for_entity_extract=args.enable_llm_cache,  # Read from args
+            enable_llm_cache_for_entity_extract=args.enable_llm_cache_for_extract,
             embedding_cache_config={
                 "enabled": True,
                 "similarity_threshold": 0.95,
@@ -352,7 +352,7 @@ def create_app(args):
             vector_db_storage_cls_kwargs={
                 "cosine_better_than_threshold": args.cosine_threshold
             },
-            enable_llm_cache_for_entity_extract=args.enable_llm_cache,  # Read from args
+            enable_llm_cache_for_entity_extract=args.enable_llm_cache_for_extract,
             embedding_cache_config={
                 "enabled": True,
                 "similarity_threshold": 0.95,
@@ -416,7 +416,7 @@ def create_app(args):
             "doc_status_storage": args.doc_status_storage,
             "graph_storage": args.graph_storage,
             "vector_storage": args.vector_storage,
-            "enable_llm_cache": args.enable_llm_cache,
+            "enable_llm_cache_for_extract": args.enable_llm_cache_for_extract,
         },
         "update_status": update_status,
     }
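
To make the rename concrete, here is a minimal sketch of how the dedicated extraction-cache flag now flows from the parsed CLI args into the LightRAG constructor. `build_rag` and `working_dir` are illustrative placeholders, not names from the repo, and the real `create_app` passes many more kwargs:

```python
from lightrag import LightRAG

def build_rag(args):
    # Sketch only: the real create_app passes many more constructor kwargs.
    return LightRAG(
        working_dir=args.working_dir,  # illustrative placeholder
        # Entity-extraction caching now has its own flag instead of
        # piggybacking on the general args.enable_llm_cache:
        enable_llm_cache_for_entity_extract=args.enable_llm_cache_for_extract,
        embedding_cache_config={
            "enabled": True,
            "similarity_threshold": 0.95,
        },
    )
```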
lightrag/api/utils_api.py CHANGED
@@ -361,7 +361,7 @@ def parse_args(is_uvicorn_mode: bool = False) -> argparse.Namespace:
     args.chunk_overlap_size = get_env_value("CHUNK_OVERLAP_SIZE", 100, int)

     # Inject LLM cache configuration
-    args.enable_llm_cache = get_env_value(
+    args.enable_llm_cache_for_extract = get_env_value(
         "ENABLE_LLM_CACHE_FOR_EXTRACT",
         False,
         bool
@@ -460,8 +460,8 @@ def display_splash_screen(args: argparse.Namespace) -> None:
     ASCIIColors.yellow(f"{args.cosine_threshold}")
     ASCIIColors.white("    ├─ Top-K: ", end="")
     ASCIIColors.yellow(f"{args.top_k}")
-    ASCIIColors.white("    └─ LLM Cache Enabled: ", end="")
-    ASCIIColors.yellow(f"{args.enable_llm_cache}")
+    ASCIIColors.white("    └─ LLM Cache for Extraction Enabled: ", end="")
+    ASCIIColors.yellow(f"{args.enable_llm_cache_for_extract}")

     # System Configuration
     ASCIIColors.magenta("\n💾 Storage Configuration:")
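
The parse_args fix matters because the value read from ENABLE_LLM_CACHE_FOR_EXTRACT was previously stored on the wrong attribute (`args.enable_llm_cache`), conflating it with the general query cache. For reference, a minimal sketch of a `get_env_value`-style helper compatible with the calls above; its behavior here is an assumption, not the actual implementation in utils_api.py:

```python
import os

def get_env_value(name, default, value_type=str):
    """Read an environment variable and cast it (sketch of the assumed helper)."""
    raw = os.environ.get(name)
    if raw is None:
        return default
    if value_type is bool:
        # Treat common truthy spellings as True (an assumption).
        return raw.strip().lower() in ("1", "true", "yes", "on")
    return value_type(raw)

# After this commit the flag lands on its own attribute:
# args.enable_llm_cache_for_extract = get_env_value("ENABLE_LLM_CACHE_FOR_EXTRACT", False, bool)
```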
lightrag/lightrag.py CHANGED
@@ -354,6 +354,7 @@ class LightRAG:
             namespace=make_namespace(
                 self.namespace_prefix, NameSpace.KV_STORE_LLM_RESPONSE_CACHE
             ),
+            global_config=asdict(self),  # Add global_config to ensure cache works properly
             embedding_func=self.embedding_func,
         )

@@ -404,18 +405,8 @@ class LightRAG:
             embedding_func=None,
         )

-        if self.llm_response_cache and hasattr(
-            self.llm_response_cache, "global_config"
-        ):
-            hashing_kv = self.llm_response_cache
-        else:
-            hashing_kv = self.key_string_value_json_storage_cls(  # type: ignore
-                namespace=make_namespace(
-                    self.namespace_prefix, NameSpace.KV_STORE_LLM_RESPONSE_CACHE
-                ),
-                global_config=asdict(self),
-                embedding_func=self.embedding_func,
-            )
+        # Directly use llm_response_cache, don't create a new object
+        hashing_kv = self.llm_response_cache

         self.llm_model_func = limit_async_func_call(self.llm_model_max_async)(
             partial(
@@ -1260,16 +1251,7 @@ class LightRAG:
                 self.text_chunks,
                 param,
                 asdict(self),
-                hashing_kv=self.llm_response_cache
-                if self.llm_response_cache
-                and hasattr(self.llm_response_cache, "global_config")
-                else self.key_string_value_json_storage_cls(
-                    namespace=make_namespace(
-                        self.namespace_prefix, NameSpace.KV_STORE_LLM_RESPONSE_CACHE
-                    ),
-                    global_config=asdict(self),
-                    embedding_func=self.embedding_func,
-                ),
+                hashing_kv=self.llm_response_cache,  # Directly use llm_response_cache
                 system_prompt=system_prompt,
             )
         elif param.mode == "naive":
@@ -1279,16 +1261,7 @@ class LightRAG:
                 self.text_chunks,
                 param,
                 asdict(self),
-                hashing_kv=self.llm_response_cache
-                if self.llm_response_cache
-                and hasattr(self.llm_response_cache, "global_config")
-                else self.key_string_value_json_storage_cls(
-                    namespace=make_namespace(
-                        self.namespace_prefix, NameSpace.KV_STORE_LLM_RESPONSE_CACHE
-                    ),
-                    global_config=asdict(self),
-                    embedding_func=self.embedding_func,
-                ),
+                hashing_kv=self.llm_response_cache,  # Directly use llm_response_cache
                 system_prompt=system_prompt,
             )
         elif param.mode == "mix":
@@ -1301,16 +1274,7 @@ class LightRAG:
                 self.text_chunks,
                 param,
                 asdict(self),
-                hashing_kv=self.llm_response_cache
-                if self.llm_response_cache
-                and hasattr(self.llm_response_cache, "global_config")
-                else self.key_string_value_json_storage_cls(
-                    namespace=make_namespace(
-                        self.namespace_prefix, NameSpace.KV_STORE_LLM_RESPONSE_CACHE
-                    ),
-                    global_config=asdict(self),
-                    embedding_func=self.embedding_func,
-                ),
+                hashing_kv=self.llm_response_cache,  # Directly use llm_response_cache
                 system_prompt=system_prompt,
             )
         else:
@@ -1344,14 +1308,7 @@ class LightRAG:
             text=query,
             param=param,
             global_config=asdict(self),
-            hashing_kv=self.llm_response_cache
-            or self.key_string_value_json_storage_cls(
-                namespace=make_namespace(
-                    self.namespace_prefix, NameSpace.KV_STORE_LLM_RESPONSE_CACHE
-                ),
-                global_config=asdict(self),
-                embedding_func=self.embedding_func,
-            ),
+            hashing_kv=self.llm_response_cache,  # Directly use llm_response_cache
         )

         param.hl_keywords = hl_keywords
@@ -1375,16 +1332,7 @@ class LightRAG:
                 self.text_chunks,
                 param,
                 asdict(self),
-                hashing_kv=self.llm_response_cache
-                if self.llm_response_cache
-                and hasattr(self.llm_response_cache, "global_config")
-                else self.key_string_value_json_storage_cls(
-                    namespace=make_namespace(
-                        self.namespace_prefix, NameSpace.KV_STORE_LLM_RESPONSE_CACHE
-                    ),
-                    global_config=asdict(self),
-                    embedding_func=self.embedding_func,
-                ),
+                hashing_kv=self.llm_response_cache,  # Directly use llm_response_cache
             )
         elif param.mode == "naive":
             response = await naive_query(
@@ -1393,16 +1341,7 @@ class LightRAG:
                 self.text_chunks,
                 param,
                 asdict(self),
-                hashing_kv=self.llm_response_cache
-                if self.llm_response_cache
-                and hasattr(self.llm_response_cache, "global_config")
-                else self.key_string_value_json_storage_cls(
-                    namespace=make_namespace(
-                        self.namespace_prefix, NameSpace.KV_STORE_LLM_RESPONSE_CACHE
-                    ),
-                    global_config=asdict(self),
-                    embedding_func=self.embedding_func,
-                ),
+                hashing_kv=self.llm_response_cache,  # Directly use llm_response_cache
             )
         elif param.mode == "mix":
             response = await mix_kg_vector_query(
@@ -1414,16 +1353,7 @@ class LightRAG:
                 self.text_chunks,
                 param,
                 asdict(self),
-                hashing_kv=self.llm_response_cache
-                if self.llm_response_cache
-                and hasattr(self.llm_response_cache, "global_config")
-                else self.key_string_value_json_storage_cls(
-                    namespace=make_namespace(
-                        self.namespace_prefix, NameSpace.KV_STORE_LLM_RESPONSE_CACHE
-                    ),
-                    global_config=asdict(self),
-                    embedding_func=self.embedding_func,
-                ),
+                hashing_kv=self.llm_response_cache,  # Directly use llm_response_cache
             )
         else:
             raise ValueError(f"Unknown mode {param.mode}")
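
Condensed from the hunks above, the before/after pattern repeated at every query call site looks roughly like this. The two helper functions are illustrative names for the sake of the sketch, not code in the repo:

```python
# Before: each call site re-derived a hashing_kv, building an independent
# KV storage object whenever llm_response_cache lacked a global_config.
def resolve_hashing_kv_before(self):
    if self.llm_response_cache and hasattr(self.llm_response_cache, "global_config"):
        return self.llm_response_cache
    return self.key_string_value_json_storage_cls(
        namespace=make_namespace(
            self.namespace_prefix, NameSpace.KV_STORE_LLM_RESPONSE_CACHE
        ),
        global_config=asdict(self),
        embedding_func=self.embedding_func,
    )

# After: __init__ builds llm_response_cache with global_config once, so the
# fallback (and the duplicate cache object it created) is no longer needed.
def resolve_hashing_kv_after(self):
    return self.llm_response_cache
```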
lightrag/operate.py CHANGED
@@ -410,7 +410,6 @@ async def extract_entities(
             _prompt,
             "default",
             cache_type="extract",
-            force_llm_cache=True,
         )
         if cached_return:
             logger.debug(f"Found cache for {arg_hash}")
@@ -432,6 +431,7 @@ async def extract_entities(
                     cache_type="extract",
                 ),
             )
+            logger.info(f"Extract: saved cache for {arg_hash}")
             return res

         if history_messages:
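
With force_llm_cache removed, extract_entities no longer bypasses the cache gate; it simply calls handle_cache and lets the new per-purpose flags decide. A sketch of the call after this commit, where arg_hash, _prompt, and hashing_kv come from the surrounding function (elided here):

```python
# Inside extract_entities' LLM wrapper (sketch; surrounding code elided):
cached_return, quantized, min_val, max_val = await handle_cache(
    hashing_kv,
    arg_hash,
    _prompt,
    "default",              # extraction runs under the "default" mode
    cache_type="extract",
    # force_llm_cache=True removed: gating now lives inside handle_cache,
    # keyed on enable_llm_cache_for_entity_extract.
)
if cached_return:
    logger.debug(f"Found cache for {arg_hash}")
```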
lightrag/utils.py CHANGED
@@ -633,15 +633,15 @@ async def handle_cache(
     prompt,
     mode="default",
     cache_type=None,
-    force_llm_cache=False,
 ):
     """Generic cache handling function"""
-    if hashing_kv is None or not (
-        force_llm_cache or hashing_kv.global_config.get("enable_llm_cache")
-    ):
+    if hashing_kv is None:
         return None, None, None, None

-    if mode != "default":
+    if mode != "default":  # handle cache for all types of query
+        if not hashing_kv.global_config.get("enable_llm_cache"):
+            return None, None, None, None
+
         # Get embedding cache configuration
         embedding_cache_config = hashing_kv.global_config.get(
             "embedding_cache_config",
@@ -651,8 +651,7 @@ async def handle_cache(
         use_llm_check = embedding_cache_config.get("use_llm_check", False)

         quantized = min_val = max_val = None
-        if is_embedding_cache_enabled:
-            # Use embedding cache
+        if is_embedding_cache_enabled:  # Use embedding similarity to match cache
             current_embedding = await hashing_kv.embedding_func([prompt])
             llm_model_func = hashing_kv.global_config.get("llm_model_func")
             quantized, min_val, max_val = quantize_embedding(current_embedding[0])
@@ -674,8 +673,13 @@ async def handle_cache(
                 logger.debug(f"Embedding cached missed(mode:{mode} type:{cache_type})")
                 return None, quantized, min_val, max_val

-    # For default mode or is_embedding_cache_enabled is False, use regular cache
-    # default mode is for extract_entities or naive query
+    else:  # handle cache for entity extraction
+        if not hashing_kv.global_config.get("enable_llm_cache_for_entity_extract"):
+            return None, None, None, None
+
+    # Conditions under which code reaches this point:
+    # 1. Any query mode: enable_llm_cache is True and embedding similarity matching is not enabled
+    # 2. Entity extraction: enable_llm_cache_for_entity_extract is True
     if exists_func(hashing_kv, "get_by_mode_and_id"):
         mode_cache = await hashing_kv.get_by_mode_and_id(mode, args_hash) or {}
     else:
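
Distilled from the hunks above, the gating that handle_cache now applies can be summarized as a standalone predicate. This is a sketch for illustration only; the real function goes on to perform the embedding-similarity lookup and the mode-keyed KV lookup:

```python
def cache_enabled(global_config: dict, mode: str) -> bool:
    """Mirror of handle_cache's new gating: query modes (mode != "default")
    are governed by enable_llm_cache; entity extraction (mode == "default")
    by enable_llm_cache_for_entity_extract."""
    if mode != "default":
        return bool(global_config.get("enable_llm_cache"))
    return bool(global_config.get("enable_llm_cache_for_entity_extract"))

# Example: under force_llm_cache=True, extraction was always cached;
# now it is cached only when explicitly enabled.
assert cache_enabled({"enable_llm_cache": True}, "hybrid") is True
assert cache_enabled({}, "default") is False
assert cache_enabled({"enable_llm_cache_for_entity_extract": True}, "default") is True
```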