drahnreb committed on
Commit
f47a464
·
1 Parent(s): 4b7275e

fix: take global_config from storage class

Browse files
Files changed (2) hide show
  1. lightrag/lightrag.py +1 -1
  2. lightrag/operate.py +10 -25
lightrag/lightrag.py CHANGED
@@ -7,7 +7,7 @@ import warnings
7
  from dataclasses import asdict, dataclass, field
8
  from datetime import datetime
9
  from functools import partial
10
- from typing import Any, AsyncIterator, Callable, Iterator, cast, final, Literal
11
 
12
  from lightrag.kg import (
13
  STORAGES,
 
7
  from dataclasses import asdict, dataclass, field
8
  from datetime import datetime
9
  from functools import partial
10
+ from typing import Any, AsyncIterator, Callable, Iterator, cast, final, Literal, Optional, List, Dict
11
 
12
  from lightrag.kg import (
13
  STORAGES,
lightrag/operate.py CHANGED
@@ -116,7 +116,6 @@ async def _handle_entity_relation_summary(
116
  use_llm_func: callable = global_config["llm_model_func"]
117
  tokenizer: Tokenizer = global_config["tokenizer"]
118
  llm_max_tokens = global_config["llm_model_max_token_size"]
119
- tiktoken_model_name = global_config["tiktoken_model_name"]
120
  summary_max_tokens = global_config["summary_to_max_tokens"]
121
 
122
  language = global_config["addon_params"].get(
@@ -842,7 +841,6 @@ async def kg_query(
842
  relationships_vdb,
843
  text_chunks_db,
844
  query_param,
845
- global_config,
846
  )
847
 
848
  if query_param.only_need_context:
@@ -1114,7 +1112,6 @@ async def mix_kg_vector_query(
1114
  relationships_vdb,
1115
  text_chunks_db,
1116
  query_param,
1117
- global_config,
1118
  )
1119
 
1120
  return context
@@ -1267,7 +1264,6 @@ async def _build_query_context(
1267
  relationships_vdb: BaseVectorStorage,
1268
  text_chunks_db: BaseKVStorage,
1269
  query_param: QueryParam,
1270
- global_config: dict[str, str],
1271
  ):
1272
  logger.info(f"Process {os.getpid()} buidling query context...")
1273
  if query_param.mode == "local":
@@ -1277,7 +1273,6 @@ async def _build_query_context(
1277
  entities_vdb,
1278
  text_chunks_db,
1279
  query_param,
1280
- global_config,
1281
  )
1282
  elif query_param.mode == "global":
1283
  entities_context, relations_context, text_units_context = await _get_edge_data(
@@ -1286,7 +1281,6 @@ async def _build_query_context(
1286
  relationships_vdb,
1287
  text_chunks_db,
1288
  query_param,
1289
- global_config,
1290
  )
1291
  else: # hybrid mode
1292
  ll_data = await _get_node_data(
@@ -1295,7 +1289,6 @@ async def _build_query_context(
1295
  entities_vdb,
1296
  text_chunks_db,
1297
  query_param,
1298
- global_config,
1299
  )
1300
  hl_data = await _get_edge_data(
1301
  hl_keywords,
@@ -1303,7 +1296,6 @@ async def _build_query_context(
1303
  relationships_vdb,
1304
  text_chunks_db,
1305
  query_param,
1306
- global_config,
1307
  )
1308
 
1309
  (
@@ -1350,7 +1342,6 @@ async def _get_node_data(
1350
  entities_vdb: BaseVectorStorage,
1351
  text_chunks_db: BaseKVStorage,
1352
  query_param: QueryParam,
1353
- global_config: dict[str, str],
1354
  ):
1355
  # get similar entities
1356
  logger.info(
@@ -1387,13 +1378,13 @@ async def _get_node_data(
1387
  ] # what is this text_chunks_db doing. dont remember it in airvx. check the diagram.
1388
  # get entitytext chunk
1389
  use_text_units = await _find_most_related_text_unit_from_entities(
1390
- node_datas, query_param, text_chunks_db, knowledge_graph_inst, global_config
1391
  )
1392
  use_relations = await _find_most_related_edges_from_entities(
1393
- node_datas, query_param, knowledge_graph_inst, global_config
1394
  )
1395
 
1396
- tokenizer: Tokenizer = global_config["tokenizer"]
1397
  len_node_datas = len(node_datas)
1398
  node_datas = truncate_list_by_token_size(
1399
  node_datas,
@@ -1493,7 +1484,6 @@ async def _find_most_related_text_unit_from_entities(
1493
  query_param: QueryParam,
1494
  text_chunks_db: BaseKVStorage,
1495
  knowledge_graph_inst: BaseGraphStorage,
1496
- global_config: dict[str, str],
1497
  ):
1498
  text_units = [
1499
  split_string_by_multi_markers(dp["source_id"], [GRAPH_FIELD_SEP])
@@ -1575,7 +1565,7 @@ async def _find_most_related_text_unit_from_entities(
1575
  logger.warning("No valid text units found")
1576
  return []
1577
 
1578
- tokenizer: Tokenizer = global_config["tokenizer"]
1579
  all_text_units = sorted(
1580
  all_text_units, key=lambda x: (x["order"], -x["relation_counts"])
1581
  )
@@ -1598,7 +1588,6 @@ async def _find_most_related_edges_from_entities(
1598
  node_datas: list[dict],
1599
  query_param: QueryParam,
1600
  knowledge_graph_inst: BaseGraphStorage,
1601
- global_config: dict[str, str],
1602
  ):
1603
  node_names = [dp["entity_name"] for dp in node_datas]
1604
  batch_edges_dict = await knowledge_graph_inst.get_nodes_edges_batch(node_names)
@@ -1638,7 +1627,7 @@ async def _find_most_related_edges_from_entities(
1638
  }
1639
  all_edges_data.append(combined)
1640
 
1641
- tokenizer: Tokenizer = global_config["tokenizer"]
1642
  all_edges_data = sorted(
1643
  all_edges_data, key=lambda x: (x["rank"], x["weight"]), reverse=True
1644
  )
@@ -1662,7 +1651,6 @@ async def _get_edge_data(
1662
  relationships_vdb: BaseVectorStorage,
1663
  text_chunks_db: BaseKVStorage,
1664
  query_param: QueryParam,
1665
- global_config: dict[str, str],
1666
  ):
1667
  logger.info(
1668
  f"Query edges: {keywords}, top_k: {query_param.top_k}, cosine: {relationships_vdb.cosine_better_than_threshold}"
@@ -1703,7 +1691,7 @@ async def _get_edge_data(
1703
  }
1704
  edge_datas.append(combined)
1705
 
1706
- tokenizer: Tokenizer = global_config["tokenizer"]
1707
  edge_datas = sorted(
1708
  edge_datas, key=lambda x: (x["rank"], x["weight"]), reverse=True
1709
  )
@@ -1715,10 +1703,10 @@ async def _get_edge_data(
1715
  )
1716
  use_entities, use_text_units = await asyncio.gather(
1717
  _find_most_related_entities_from_relationships(
1718
- edge_datas, query_param, knowledge_graph_inst, global_config
1719
  ),
1720
  _find_related_text_unit_from_relationships(
1721
- edge_datas, query_param, text_chunks_db, knowledge_graph_inst, global_config
1722
  ),
1723
  )
1724
  logger.info(
@@ -1798,7 +1786,6 @@ async def _find_most_related_entities_from_relationships(
1798
  edge_datas: list[dict],
1799
  query_param: QueryParam,
1800
  knowledge_graph_inst: BaseGraphStorage,
1801
- global_config: dict[str, str],
1802
  ):
1803
  entity_names = []
1804
  seen = set()
@@ -1829,7 +1816,7 @@ async def _find_most_related_entities_from_relationships(
1829
  combined = {**node, "entity_name": entity_name, "rank": degree}
1830
  node_datas.append(combined)
1831
 
1832
- tokenizer: Tokenizer = global_config["tokenizer"]
1833
  len_node_datas = len(node_datas)
1834
  node_datas = truncate_list_by_token_size(
1835
  node_datas,
@@ -1849,7 +1836,6 @@ async def _find_related_text_unit_from_relationships(
1849
  query_param: QueryParam,
1850
  text_chunks_db: BaseKVStorage,
1851
  knowledge_graph_inst: BaseGraphStorage,
1852
- global_config: dict[str, str],
1853
  ):
1854
  text_units = [
1855
  split_string_by_multi_markers(dp["source_id"], [GRAPH_FIELD_SEP])
@@ -1891,7 +1877,7 @@ async def _find_related_text_unit_from_relationships(
1891
  logger.warning("No valid text chunks after filtering")
1892
  return []
1893
 
1894
- tokenizer: Tokenizer = global_config["tokenizer"]
1895
  truncated_text_units = truncate_list_by_token_size(
1896
  valid_text_units,
1897
  key=lambda x: x["data"]["content"],
@@ -2128,7 +2114,6 @@ async def kg_query_with_keywords(
2128
  relationships_vdb,
2129
  text_chunks_db,
2130
  query_param,
2131
- global_config,
2132
  )
2133
  if not context:
2134
  return PROMPTS["fail_response"]
 
116
  use_llm_func: callable = global_config["llm_model_func"]
117
  tokenizer: Tokenizer = global_config["tokenizer"]
118
  llm_max_tokens = global_config["llm_model_max_token_size"]
 
119
  summary_max_tokens = global_config["summary_to_max_tokens"]
120
 
121
  language = global_config["addon_params"].get(
 
841
  relationships_vdb,
842
  text_chunks_db,
843
  query_param,
 
844
  )
845
 
846
  if query_param.only_need_context:
 
1112
  relationships_vdb,
1113
  text_chunks_db,
1114
  query_param,
 
1115
  )
1116
 
1117
  return context
 
1264
  relationships_vdb: BaseVectorStorage,
1265
  text_chunks_db: BaseKVStorage,
1266
  query_param: QueryParam,
 
1267
  ):
1268
  logger.info(f"Process {os.getpid()} buidling query context...")
1269
  if query_param.mode == "local":
 
1273
  entities_vdb,
1274
  text_chunks_db,
1275
  query_param,
 
1276
  )
1277
  elif query_param.mode == "global":
1278
  entities_context, relations_context, text_units_context = await _get_edge_data(
 
1281
  relationships_vdb,
1282
  text_chunks_db,
1283
  query_param,
 
1284
  )
1285
  else: # hybrid mode
1286
  ll_data = await _get_node_data(
 
1289
  entities_vdb,
1290
  text_chunks_db,
1291
  query_param,
 
1292
  )
1293
  hl_data = await _get_edge_data(
1294
  hl_keywords,
 
1296
  relationships_vdb,
1297
  text_chunks_db,
1298
  query_param,
 
1299
  )
1300
 
1301
  (
 
1342
  entities_vdb: BaseVectorStorage,
1343
  text_chunks_db: BaseKVStorage,
1344
  query_param: QueryParam,
 
1345
  ):
1346
  # get similar entities
1347
  logger.info(
 
1378
  ] # what is this text_chunks_db doing. dont remember it in airvx. check the diagram.
1379
  # get entitytext chunk
1380
  use_text_units = await _find_most_related_text_unit_from_entities(
1381
+ node_datas, query_param, text_chunks_db, knowledge_graph_inst,
1382
  )
1383
  use_relations = await _find_most_related_edges_from_entities(
1384
+ node_datas, query_param, knowledge_graph_inst,
1385
  )
1386
 
1387
+ tokenizer: Tokenizer = text_chunks_db.global_config.get("tokenizer")
1388
  len_node_datas = len(node_datas)
1389
  node_datas = truncate_list_by_token_size(
1390
  node_datas,
 
1484
  query_param: QueryParam,
1485
  text_chunks_db: BaseKVStorage,
1486
  knowledge_graph_inst: BaseGraphStorage,
 
1487
  ):
1488
  text_units = [
1489
  split_string_by_multi_markers(dp["source_id"], [GRAPH_FIELD_SEP])
 
1565
  logger.warning("No valid text units found")
1566
  return []
1567
 
1568
+ tokenizer: Tokenizer = text_chunks_db.global_config.get("tokenizer")
1569
  all_text_units = sorted(
1570
  all_text_units, key=lambda x: (x["order"], -x["relation_counts"])
1571
  )
 
1588
  node_datas: list[dict],
1589
  query_param: QueryParam,
1590
  knowledge_graph_inst: BaseGraphStorage,
 
1591
  ):
1592
  node_names = [dp["entity_name"] for dp in node_datas]
1593
  batch_edges_dict = await knowledge_graph_inst.get_nodes_edges_batch(node_names)
 
1627
  }
1628
  all_edges_data.append(combined)
1629
 
1630
+ tokenizer: Tokenizer = knowledge_graph_inst.global_config.get("tokenizer")
1631
  all_edges_data = sorted(
1632
  all_edges_data, key=lambda x: (x["rank"], x["weight"]), reverse=True
1633
  )
 
1651
  relationships_vdb: BaseVectorStorage,
1652
  text_chunks_db: BaseKVStorage,
1653
  query_param: QueryParam,
 
1654
  ):
1655
  logger.info(
1656
  f"Query edges: {keywords}, top_k: {query_param.top_k}, cosine: {relationships_vdb.cosine_better_than_threshold}"
 
1691
  }
1692
  edge_datas.append(combined)
1693
 
1694
+ tokenizer: Tokenizer = text_chunks_db.global_config.get("tokenizer")
1695
  edge_datas = sorted(
1696
  edge_datas, key=lambda x: (x["rank"], x["weight"]), reverse=True
1697
  )
 
1703
  )
1704
  use_entities, use_text_units = await asyncio.gather(
1705
  _find_most_related_entities_from_relationships(
1706
+ edge_datas, query_param, knowledge_graph_inst,
1707
  ),
1708
  _find_related_text_unit_from_relationships(
1709
+ edge_datas, query_param, text_chunks_db, knowledge_graph_inst,
1710
  ),
1711
  )
1712
  logger.info(
 
1786
  edge_datas: list[dict],
1787
  query_param: QueryParam,
1788
  knowledge_graph_inst: BaseGraphStorage,
 
1789
  ):
1790
  entity_names = []
1791
  seen = set()
 
1816
  combined = {**node, "entity_name": entity_name, "rank": degree}
1817
  node_datas.append(combined)
1818
 
1819
+ tokenizer: Tokenizer = knowledge_graph_inst.global_config.get("tokenizer")
1820
  len_node_datas = len(node_datas)
1821
  node_datas = truncate_list_by_token_size(
1822
  node_datas,
 
1836
  query_param: QueryParam,
1837
  text_chunks_db: BaseKVStorage,
1838
  knowledge_graph_inst: BaseGraphStorage,
 
1839
  ):
1840
  text_units = [
1841
  split_string_by_multi_markers(dp["source_id"], [GRAPH_FIELD_SEP])
 
1877
  logger.warning("No valid text chunks after filtering")
1878
  return []
1879
 
1880
+ tokenizer: Tokenizer = text_chunks_db.global_config.get("tokenizer")
1881
  truncated_text_units = truncate_list_by_token_size(
1882
  valid_text_units,
1883
  key=lambda x: x["data"]["content"],
 
2114
  relationships_vdb,
2115
  text_chunks_db,
2116
  query_param,
 
2117
  )
2118
  if not context:
2119
  return PROMPTS["fail_response"]