drahnreb committed
Commit 4b7275e · 1 Parent(s): 0228302

fix truncation with global_config tokenizer
Files changed:
- lightrag/operate.py  +35 -7
- lightrag/utils.py  +1 -1
lightrag/operate.py  CHANGED
@@ -842,6 +842,7 @@ async def kg_query(
         relationships_vdb,
         text_chunks_db,
         query_param,
+        global_config,
     )
 
     if query_param.only_need_context:
@@ -1057,6 +1058,8 @@ async def mix_kg_vector_query(
     2. Retrieving relevant text chunks through vector similarity
     3. Combining both results for comprehensive answer generation
     """
+    # get tokenizer
+    tokenizer: Tokenizer = global_config["tokenizer"]
     # 1. Cache handling
     use_model_func = (
         query_param.model_func
@@ -1111,6 +1114,7 @@ async def mix_kg_vector_query(
             relationships_vdb,
             text_chunks_db,
             query_param,
+            global_config,
         )
 
         return context
@@ -1156,6 +1160,7 @@ async def mix_kg_vector_query(
             valid_chunks,
             key=lambda x: x["content"],
             max_token_size=query_param.max_token_for_text_unit,
+            tokenizer=tokenizer,
         )
 
         if not maybe_trun_chunks:
@@ -1213,7 +1218,6 @@ async def mix_kg_vector_query(
     if query_param.only_need_prompt:
         return sys_prompt
 
-    tokenizer: Tokenizer = global_config["tokenizer"]
     len_of_prompts = len(tokenizer.encode(query + sys_prompt))
     logger.debug(f"[mix_kg_vector_query]Prompt Tokens: {len_of_prompts}")
 
@@ -1263,6 +1267,7 @@ async def _build_query_context(
     relationships_vdb: BaseVectorStorage,
     text_chunks_db: BaseKVStorage,
     query_param: QueryParam,
+    global_config: dict[str, str],
 ):
     logger.info(f"Process {os.getpid()} buidling query context...")
     if query_param.mode == "local":
@@ -1272,6 +1277,7 @@ async def _build_query_context(
             entities_vdb,
             text_chunks_db,
             query_param,
+            global_config,
         )
     elif query_param.mode == "global":
         entities_context, relations_context, text_units_context = await _get_edge_data(
@@ -1280,6 +1286,7 @@ async def _build_query_context(
             relationships_vdb,
             text_chunks_db,
             query_param,
+            global_config,
         )
     else:  # hybrid mode
         ll_data = await _get_node_data(
@@ -1288,6 +1295,7 @@ async def _build_query_context(
             entities_vdb,
             text_chunks_db,
             query_param,
+            global_config,
         )
         hl_data = await _get_edge_data(
             hl_keywords,
@@ -1295,6 +1303,7 @@ async def _build_query_context(
             relationships_vdb,
             text_chunks_db,
             query_param,
+            global_config,
         )
 
         (
@@ -1341,6 +1350,7 @@ async def _get_node_data(
     entities_vdb: BaseVectorStorage,
     text_chunks_db: BaseKVStorage,
     query_param: QueryParam,
+    global_config: dict[str, str],
 ):
     # get similar entities
     logger.info(
@@ -1377,17 +1387,19 @@ async def _get_node_data(
     ]  # what is this text_chunks_db doing. dont remember it in airvx. check the diagram.
     # get entitytext chunk
     use_text_units = await _find_most_related_text_unit_from_entities(
-        node_datas, query_param, text_chunks_db, knowledge_graph_inst
+        node_datas, query_param, text_chunks_db, knowledge_graph_inst, global_config
     )
     use_relations = await _find_most_related_edges_from_entities(
-        node_datas, query_param, knowledge_graph_inst
+        node_datas, query_param, knowledge_graph_inst, global_config
     )
 
+    tokenizer: Tokenizer = global_config["tokenizer"]
     len_node_datas = len(node_datas)
     node_datas = truncate_list_by_token_size(
         node_datas,
         key=lambda x: x["description"] if x["description"] is not None else "",
         max_token_size=query_param.max_token_for_local_context,
+        tokenizer=tokenizer,
     )
     logger.debug(
         f"Truncate entities from {len_node_datas} to {len(node_datas)} (max tokens:{query_param.max_token_for_local_context})"
@@ -1481,6 +1493,7 @@ async def _find_most_related_text_unit_from_entities(
     query_param: QueryParam,
     text_chunks_db: BaseKVStorage,
     knowledge_graph_inst: BaseGraphStorage,
+    global_config: dict[str, str],
 ):
     text_units = [
         split_string_by_multi_markers(dp["source_id"], [GRAPH_FIELD_SEP])
@@ -1562,14 +1575,15 @@ async def _find_most_related_text_unit_from_entities(
         logger.warning("No valid text units found")
         return []
 
+    tokenizer: Tokenizer = global_config["tokenizer"]
     all_text_units = sorted(
         all_text_units, key=lambda x: (x["order"], -x["relation_counts"])
     )
-
     all_text_units = truncate_list_by_token_size(
         all_text_units,
         key=lambda x: x["data"]["content"],
         max_token_size=query_param.max_token_for_text_unit,
+        tokenizer=tokenizer,
     )
 
     logger.debug(
@@ -1584,6 +1598,7 @@ async def _find_most_related_edges_from_entities(
     node_datas: list[dict],
     query_param: QueryParam,
     knowledge_graph_inst: BaseGraphStorage,
+    global_config: dict[str, str],
 ):
     node_names = [dp["entity_name"] for dp in node_datas]
     batch_edges_dict = await knowledge_graph_inst.get_nodes_edges_batch(node_names)
@@ -1623,6 +1638,7 @@ async def _find_most_related_edges_from_entities(
         }
         all_edges_data.append(combined)
 
+    tokenizer: Tokenizer = global_config["tokenizer"]
     all_edges_data = sorted(
         all_edges_data, key=lambda x: (x["rank"], x["weight"]), reverse=True
     )
@@ -1630,6 +1646,7 @@ async def _find_most_related_edges_from_entities(
         all_edges_data,
         key=lambda x: x["description"] if x["description"] is not None else "",
         max_token_size=query_param.max_token_for_global_context,
+        tokenizer=tokenizer,
     )
 
     logger.debug(
@@ -1645,6 +1662,7 @@ async def _get_edge_data(
     relationships_vdb: BaseVectorStorage,
     text_chunks_db: BaseKVStorage,
     query_param: QueryParam,
+    global_config: dict[str, str],
 ):
     logger.info(
         f"Query edges: {keywords}, top_k: {query_param.top_k}, cosine: {relationships_vdb.cosine_better_than_threshold}"
@@ -1685,6 +1703,7 @@ async def _get_edge_data(
         }
         edge_datas.append(combined)
 
+    tokenizer: Tokenizer = global_config["tokenizer"]
     edge_datas = sorted(
         edge_datas, key=lambda x: (x["rank"], x["weight"]), reverse=True
     )
@@ -1692,13 +1711,14 @@ async def _get_edge_data(
         edge_datas,
         key=lambda x: x["description"] if x["description"] is not None else "",
         max_token_size=query_param.max_token_for_global_context,
+        tokenizer=tokenizer,
     )
     use_entities, use_text_units = await asyncio.gather(
         _find_most_related_entities_from_relationships(
-            edge_datas, query_param, knowledge_graph_inst
+            edge_datas, query_param, knowledge_graph_inst, global_config
         ),
         _find_related_text_unit_from_relationships(
-            edge_datas, query_param, text_chunks_db, knowledge_graph_inst
+            edge_datas, query_param, text_chunks_db, knowledge_graph_inst, global_config
         ),
     )
     logger.info(
@@ -1778,6 +1798,7 @@ async def _find_most_related_entities_from_relationships(
     edge_datas: list[dict],
     query_param: QueryParam,
     knowledge_graph_inst: BaseGraphStorage,
+    global_config: dict[str, str],
 ):
     entity_names = []
     seen = set()
@@ -1808,11 +1829,13 @@ async def _find_most_related_entities_from_relationships(
         combined = {**node, "entity_name": entity_name, "rank": degree}
         node_datas.append(combined)
 
+    tokenizer: Tokenizer = global_config["tokenizer"]
     len_node_datas = len(node_datas)
    node_datas = truncate_list_by_token_size(
         node_datas,
         key=lambda x: x["description"] if x["description"] is not None else "",
         max_token_size=query_param.max_token_for_local_context,
+        tokenizer=tokenizer,
     )
     logger.debug(
         f"Truncate entities from {len_node_datas} to {len(node_datas)} (max tokens:{query_param.max_token_for_local_context})"
@@ -1826,6 +1849,7 @@ async def _find_related_text_unit_from_relationships(
     query_param: QueryParam,
     text_chunks_db: BaseKVStorage,
     knowledge_graph_inst: BaseGraphStorage,
+    global_config: dict[str, str],
 ):
     text_units = [
         split_string_by_multi_markers(dp["source_id"], [GRAPH_FIELD_SEP])
@@ -1867,10 +1891,12 @@ async def _find_related_text_unit_from_relationships(
         logger.warning("No valid text chunks after filtering")
         return []
 
+    tokenizer: Tokenizer = global_config["tokenizer"]
     truncated_text_units = truncate_list_by_token_size(
         valid_text_units,
         key=lambda x: x["data"]["content"],
         max_token_size=query_param.max_token_for_text_unit,
+        tokenizer=tokenizer,
     )
 
     logger.debug(
@@ -1941,10 +1967,12 @@ async def naive_query(
         logger.warning("No valid chunks found after filtering")
         return PROMPTS["fail_response"]
 
+    tokenizer: Tokenizer = global_config["tokenizer"]
     maybe_trun_chunks = truncate_list_by_token_size(
         valid_chunks,
         key=lambda x: x["content"],
         max_token_size=query_param.max_token_for_text_unit,
+        tokenizer=tokenizer,
     )
 
     if not maybe_trun_chunks:
@@ -1982,7 +2010,6 @@ async def naive_query(
     if query_param.only_need_prompt:
         return sys_prompt
 
-    tokenizer: Tokenizer = global_config["tokenizer"]
     len_of_prompts = len(tokenizer.encode(query + sys_prompt))
     logger.debug(f"[naive_query]Prompt Tokens: {len_of_prompts}")
 
@@ -2101,6 +2128,7 @@ async def kg_query_with_keywords(
         relationships_vdb,
         text_chunks_db,
         query_param,
+        global_config,
     )
     if not context:
         return PROMPTS["fail_response"]
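The pattern across all of these hunks is the same: `global_config` is threaded down from the query entry points through `_build_query_context` into every retrieval helper, and each helper looks up the `Tokenizer` stored under `global_config["tokenizer"]` once before truncating or counting tokens. A minimal runnable sketch of that pattern follows; `WhitespaceTokenizer` and `prompt_token_count` are illustrative stand-ins, not LightRAG code.

from typing import Protocol


class Tokenizer(Protocol):
    """Shape of the object stored under global_config["tokenizer"]."""

    def encode(self, content: str) -> list[int]: ...


class WhitespaceTokenizer:
    """Illustrative tokenizer: one token per whitespace-separated word."""

    def encode(self, content: str) -> list[int]:
        return list(range(len(content.split())))


def prompt_token_count(query: str, sys_prompt: str, global_config: dict) -> int:
    # Mirrors the len_of_prompts computation in mix_kg_vector_query and
    # naive_query: the tokenizer comes from global_config, not a module default.
    tokenizer: Tokenizer = global_config["tokenizer"]
    return len(tokenizer.encode(query + sys_prompt))


global_config = {"tokenizer": WhitespaceTokenizer()}
print(
    prompt_token_count("why is the sky blue", "You are a helpful assistant.", global_config)
)  # -> 9 "words" under this toy scheme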
lightrag/utils.py  CHANGED
@@ -424,7 +424,7 @@ def is_float_regex(value: str) -> bool:
 
 
 def truncate_list_by_token_size(
-    list_data: list[Any], key: Callable[[Any], str], max_token_size: int
+    list_data: list[Any], key: Callable[[Any], str], max_token_size: int, tokenizer: Tokenizer
 ) -> list[int]:
     """Truncate a list of data by token size"""
     if max_token_size <= 0:
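For reference, here is a runnable sketch of the updated helper, under the assumption (implied by the call sites above, not shown in this diff) that it keeps leading items until the cumulative token count of `key(item)` exceeds `max_token_size`. `SimpleTokenizer` is an illustrative stand-in for the `Tokenizer` instance taken from `global_config["tokenizer"]`.

from typing import Any, Callable


class SimpleTokenizer:
    """Stand-in tokenizer: one token per whitespace-separated word."""

    def encode(self, content: str) -> list[int]:
        return list(range(len(content.split())))


def truncate_list_by_token_size(
    list_data: list[Any],
    key: Callable[[Any], str],
    max_token_size: int,
    tokenizer: SimpleTokenizer,
) -> list[Any]:
    """Truncate a list of data by token size."""
    if max_token_size <= 0:
        return []
    tokens = 0
    for i, data in enumerate(list_data):
        tokens += len(tokenizer.encode(key(data)))
        if tokens > max_token_size:
            return list_data[:i]  # drop the item that overflows the budget
    return list_data


chunks = [{"content": "alpha beta gamma"}, {"content": "delta epsilon"}]
kept = truncate_list_by_token_size(
    chunks, key=lambda x: x["content"], max_token_size=4, tokenizer=SimpleTokenizer()
)
print(len(kept))  # -> 1: the second chunk would push the total past 4 tokens

Passing the tokenizer in explicitly means truncation is measured with whatever tokenizer the user configured, rather than a hard-coded default encoding, which appears to be what the commit message means by "fix truncation with global_config tokenizer".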