drahnreb committed
Commit 4b7275e · Parent: 0228302

fix truncation with global_config tokenizer

Files changed (2):
  1. lightrag/operate.py +35 -7
  2. lightrag/utils.py +1 -1
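
Taken together, the two changes make tokenization explicit: operate.py fetches the Tokenizer stored under global_config["tokenizer"] near each truncation site (threading global_config through the query helpers to get it there), while utils.py adds a matching tokenizer parameter to truncate_list_by_token_size. A minimal sketch of the resulting call-site pattern, assembled from lines in the diff below (surrounding code elided):

    # fetch the tokenizer once from the shared config, then pass it explicitly
    tokenizer: Tokenizer = global_config["tokenizer"]
    maybe_trun_chunks = truncate_list_by_token_size(
        valid_chunks,
        key=lambda x: x["content"],
        max_token_size=query_param.max_token_for_text_unit,
        tokenizer=tokenizer,
    )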
lightrag/operate.py CHANGED

@@ -842,6 +842,7 @@ async def kg_query(
         relationships_vdb,
         text_chunks_db,
         query_param,
+        global_config,
     )

     if query_param.only_need_context:
@@ -1057,6 +1058,8 @@ async def mix_kg_vector_query(
    2. Retrieving relevant text chunks through vector similarity
    3. Combining both results for comprehensive answer generation
    """
+    # get tokenizer
+    tokenizer: Tokenizer = global_config["tokenizer"]
    # 1. Cache handling
    use_model_func = (
        query_param.model_func
@@ -1111,6 +1114,7 @@ async def mix_kg_vector_query(
        relationships_vdb,
        text_chunks_db,
        query_param,
+        global_config,
    )

    return context
@@ -1156,6 +1160,7 @@ async def mix_kg_vector_query(
        valid_chunks,
        key=lambda x: x["content"],
        max_token_size=query_param.max_token_for_text_unit,
+        tokenizer=tokenizer,
    )

    if not maybe_trun_chunks:
@@ -1213,7 +1218,6 @@ async def mix_kg_vector_query(
    if query_param.only_need_prompt:
        return sys_prompt

-    tokenizer: Tokenizer = global_config["tokenizer"]
    len_of_prompts = len(tokenizer.encode(query + sys_prompt))
    logger.debug(f"[mix_kg_vector_query]Prompt Tokens: {len_of_prompts}")

@@ -1263,6 +1267,7 @@ async def _build_query_context(
    relationships_vdb: BaseVectorStorage,
    text_chunks_db: BaseKVStorage,
    query_param: QueryParam,
+    global_config: dict[str, str],
):
    logger.info(f"Process {os.getpid()} buidling query context...")
    if query_param.mode == "local":
@@ -1272,6 +1277,7 @@ async def _build_query_context(
            entities_vdb,
            text_chunks_db,
            query_param,
+            global_config,
        )
    elif query_param.mode == "global":
        entities_context, relations_context, text_units_context = await _get_edge_data(
@@ -1280,6 +1286,7 @@ async def _build_query_context(
            relationships_vdb,
            text_chunks_db,
            query_param,
+            global_config,
        )
    else:  # hybrid mode
        ll_data = await _get_node_data(
@@ -1288,6 +1295,7 @@ async def _build_query_context(
            entities_vdb,
            text_chunks_db,
            query_param,
+            global_config,
        )
        hl_data = await _get_edge_data(
            hl_keywords,
@@ -1295,6 +1303,7 @@ async def _build_query_context(
            relationships_vdb,
            text_chunks_db,
            query_param,
+            global_config,
        )

    (
@@ -1341,6 +1350,7 @@ async def _get_node_data(
    entities_vdb: BaseVectorStorage,
    text_chunks_db: BaseKVStorage,
    query_param: QueryParam,
+    global_config: dict[str, str],
):
    # get similar entities
    logger.info(
@@ -1377,17 +1387,19 @@ async def _get_node_data(
    ]  # what is this text_chunks_db doing. dont remember it in airvx. check the diagram.
    # get entitytext chunk
    use_text_units = await _find_most_related_text_unit_from_entities(
-        node_datas, query_param, text_chunks_db, knowledge_graph_inst
+        node_datas, query_param, text_chunks_db, knowledge_graph_inst, global_config
    )
    use_relations = await _find_most_related_edges_from_entities(
-        node_datas, query_param, knowledge_graph_inst
+        node_datas, query_param, knowledge_graph_inst, global_config
    )

+    tokenizer: Tokenizer = global_config["tokenizer"]
    len_node_datas = len(node_datas)
    node_datas = truncate_list_by_token_size(
        node_datas,
        key=lambda x: x["description"] if x["description"] is not None else "",
        max_token_size=query_param.max_token_for_local_context,
+        tokenizer=tokenizer,
    )
    logger.debug(
        f"Truncate entities from {len_node_datas} to {len(node_datas)} (max tokens:{query_param.max_token_for_local_context})"
@@ -1481,6 +1493,7 @@ async def _find_most_related_text_unit_from_entities(
    query_param: QueryParam,
    text_chunks_db: BaseKVStorage,
    knowledge_graph_inst: BaseGraphStorage,
+    global_config: dict[str, str],
):
    text_units = [
        split_string_by_multi_markers(dp["source_id"], [GRAPH_FIELD_SEP])
@@ -1562,14 +1575,15 @@ async def _find_most_related_text_unit_from_entities(
        logger.warning("No valid text units found")
        return []

+    tokenizer: Tokenizer = global_config["tokenizer"]
    all_text_units = sorted(
        all_text_units, key=lambda x: (x["order"], -x["relation_counts"])
    )
-
    all_text_units = truncate_list_by_token_size(
        all_text_units,
        key=lambda x: x["data"]["content"],
        max_token_size=query_param.max_token_for_text_unit,
+        tokenizer=tokenizer,
    )

    logger.debug(
@@ -1584,6 +1598,7 @@ async def _find_most_related_edges_from_entities(
    node_datas: list[dict],
    query_param: QueryParam,
    knowledge_graph_inst: BaseGraphStorage,
+    global_config: dict[str, str],
):
    node_names = [dp["entity_name"] for dp in node_datas]
    batch_edges_dict = await knowledge_graph_inst.get_nodes_edges_batch(node_names)
@@ -1623,6 +1638,7 @@ async def _find_most_related_edges_from_entities(
        }
        all_edges_data.append(combined)

+    tokenizer: Tokenizer = global_config["tokenizer"]
    all_edges_data = sorted(
        all_edges_data, key=lambda x: (x["rank"], x["weight"]), reverse=True
    )
@@ -1630,6 +1646,7 @@ async def _find_most_related_edges_from_entities(
        all_edges_data,
        key=lambda x: x["description"] if x["description"] is not None else "",
        max_token_size=query_param.max_token_for_global_context,
+        tokenizer=tokenizer,
    )

    logger.debug(
@@ -1645,6 +1662,7 @@ async def _get_edge_data(
    relationships_vdb: BaseVectorStorage,
    text_chunks_db: BaseKVStorage,
    query_param: QueryParam,
+    global_config: dict[str, str],
):
    logger.info(
        f"Query edges: {keywords}, top_k: {query_param.top_k}, cosine: {relationships_vdb.cosine_better_than_threshold}"
@@ -1685,6 +1703,7 @@ async def _get_edge_data(
        }
        edge_datas.append(combined)

+    tokenizer: Tokenizer = global_config["tokenizer"]
    edge_datas = sorted(
        edge_datas, key=lambda x: (x["rank"], x["weight"]), reverse=True
    )
@@ -1692,13 +1711,14 @@ async def _get_edge_data(
        edge_datas,
        key=lambda x: x["description"] if x["description"] is not None else "",
        max_token_size=query_param.max_token_for_global_context,
+        tokenizer=tokenizer,
    )
    use_entities, use_text_units = await asyncio.gather(
        _find_most_related_entities_from_relationships(
-            edge_datas, query_param, knowledge_graph_inst
+            edge_datas, query_param, knowledge_graph_inst, global_config
        ),
        _find_related_text_unit_from_relationships(
-            edge_datas, query_param, text_chunks_db, knowledge_graph_inst
+            edge_datas, query_param, text_chunks_db, knowledge_graph_inst, global_config
        ),
    )
    logger.info(
@@ -1778,6 +1798,7 @@ async def _find_most_related_entities_from_relationships(
    edge_datas: list[dict],
    query_param: QueryParam,
    knowledge_graph_inst: BaseGraphStorage,
+    global_config: dict[str, str],
):
    entity_names = []
    seen = set()
@@ -1808,11 +1829,13 @@ async def _find_most_related_entities_from_relationships(
        combined = {**node, "entity_name": entity_name, "rank": degree}
        node_datas.append(combined)

+    tokenizer: Tokenizer = global_config["tokenizer"]
    len_node_datas = len(node_datas)
    node_datas = truncate_list_by_token_size(
        node_datas,
        key=lambda x: x["description"] if x["description"] is not None else "",
        max_token_size=query_param.max_token_for_local_context,
+        tokenizer=tokenizer,
    )
    logger.debug(
        f"Truncate entities from {len_node_datas} to {len(node_datas)} (max tokens:{query_param.max_token_for_local_context})"
@@ -1826,6 +1849,7 @@ async def _find_related_text_unit_from_relationships(
    query_param: QueryParam,
    text_chunks_db: BaseKVStorage,
    knowledge_graph_inst: BaseGraphStorage,
+    global_config: dict[str, str],
):
    text_units = [
        split_string_by_multi_markers(dp["source_id"], [GRAPH_FIELD_SEP])
@@ -1867,10 +1891,12 @@ async def _find_related_text_unit_from_relationships(
        logger.warning("No valid text chunks after filtering")
        return []

+    tokenizer: Tokenizer = global_config["tokenizer"]
    truncated_text_units = truncate_list_by_token_size(
        valid_text_units,
        key=lambda x: x["data"]["content"],
        max_token_size=query_param.max_token_for_text_unit,
+        tokenizer=tokenizer,
    )

    logger.debug(
@@ -1941,10 +1967,12 @@ async def naive_query(
        logger.warning("No valid chunks found after filtering")
        return PROMPTS["fail_response"]

+    tokenizer: Tokenizer = global_config["tokenizer"]
    maybe_trun_chunks = truncate_list_by_token_size(
        valid_chunks,
        key=lambda x: x["content"],
        max_token_size=query_param.max_token_for_text_unit,
+        tokenizer=tokenizer,
    )

    if not maybe_trun_chunks:
@@ -1982,7 +2010,6 @@ async def naive_query(
    if query_param.only_need_prompt:
        return sys_prompt

-    tokenizer: Tokenizer = global_config["tokenizer"]
    len_of_prompts = len(tokenizer.encode(query + sys_prompt))
    logger.debug(f"[naive_query]Prompt Tokens: {len_of_prompts}")

@@ -2101,6 +2128,7 @@ async def kg_query_with_keywords(
        relationships_vdb,
        text_chunks_db,
        query_param,
+        global_config,
    )
    if not context:
        return PROMPTS["fail_response"]
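
Note that operate.py relies only on the tokenizer object exposing an encode method (see the len(tokenizer.encode(query + sys_prompt)) calls above). A minimal sketch of that assumed interface; the Protocol spelling and the stand-in implementation below are illustrative, not LightRAG's actual Tokenizer class:

    from typing import Protocol

    class Tokenizer(Protocol):
        # the only capability the diff depends on: text -> token ids
        def encode(self, content: str) -> list[int]: ...

    class WhitespaceTokenizer:
        # hypothetical stand-in; a real global_config would wrap a model
        # tokenizer (e.g. tiktoken) behind the same encode() method
        def encode(self, content: str) -> list[int]:
            return list(range(len(content.split())))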
lightrag/utils.py CHANGED

@@ -424,7 +424,7 @@ def is_float_regex(value: str) -> bool:


def truncate_list_by_token_size(
-    list_data: list[Any], key: Callable[[Any], str], max_token_size: int
+    list_data: list[Any], key: Callable[[Any], str], max_token_size: int, tokenizer: Tokenizer
) -> list[int]:
    """Truncate a list of data by token size"""
    if max_token_size <= 0:
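
The diff touches only the signature; for context, here is a hedged sketch of a body consistent with it: walk the list, count the tokens of key(item) with the supplied tokenizer, and cut off once max_token_size is exceeded. The body is an assumption, not necessarily the repository's implementation (note the upstream return annotation is list[int], although a slice of list_data comes back):

    from typing import Any, Callable

    def truncate_list_by_token_size(
        list_data: list[Any],
        key: Callable[[Any], str],
        max_token_size: int,
        tokenizer: "Tokenizer",
    ) -> list[Any]:
        """Truncate a list of data by token size"""
        if max_token_size <= 0:
            return []
        tokens = 0
        for i, data in enumerate(list_data):
            # count tokens with the explicitly supplied tokenizer
            tokens += len(tokenizer.encode(key(data)))
            if tokens > max_token_size:
                return list_data[:i]  # keep only the prefix that fits
        return list_data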