dimatill committed on
Commit
b8ce148
·
1 Parent(s): 2a71867

asyncio optimizations

Browse files
Files changed (1) hide show
  1. lightrag/operate.py +106 -67
lightrag/operate.py CHANGED
@@ -941,28 +941,35 @@ async def _build_query_context(
941
  query_param,
942
  )
943
  else: # hybrid mode
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
944
  (
945
  ll_entities_context,
946
  ll_relations_context,
947
  ll_text_units_context,
948
- ) = await _get_node_data(
949
- ll_keywords,
950
- knowledge_graph_inst,
951
- entities_vdb,
952
- text_chunks_db,
953
- query_param,
954
- )
955
  (
956
  hl_entities_context,
957
  hl_relations_context,
958
  hl_text_units_context,
959
- ) = await _get_edge_data(
960
- hl_keywords,
961
- knowledge_graph_inst,
962
- relationships_vdb,
963
- text_chunks_db,
964
- query_param,
965
- )
966
  entities_context, relations_context, text_units_context = combine_contexts(
967
  [hl_entities_context, ll_entities_context],
968
  [hl_relations_context, ll_relations_context],
@@ -996,28 +1003,31 @@ async def _get_node_data(
996
  if not len(results):
997
  return "", "", ""
998
  # get entity information
999
- node_datas = await asyncio.gather(
1000
- *[knowledge_graph_inst.get_node(r["entity_name"]) for r in results]
 
 
 
 
 
1001
  )
 
1002
  if not all([n is not None for n in node_datas]):
1003
  logger.warning("Some nodes are missing, maybe the storage is damaged")
1004
 
1005
- # get entity degree
1006
- node_degrees = await asyncio.gather(
1007
- *[knowledge_graph_inst.node_degree(r["entity_name"]) for r in results]
1008
- )
1009
  node_datas = [
1010
  {**n, "entity_name": k["entity_name"], "rank": d}
1011
  for k, n, d in zip(results, node_datas, node_degrees)
1012
  if n is not None
1013
  ] # what is this text_chunks_db doing. dont remember it in airvx. check the diagram.
1014
  # get entitytext chunk
1015
- use_text_units = await _find_most_related_text_unit_from_entities(
1016
- node_datas, query_param, text_chunks_db, knowledge_graph_inst
1017
- )
1018
- # get relate edges
1019
- use_relations = await _find_most_related_edges_from_entities(
1020
- node_datas, query_param, knowledge_graph_inst
 
1021
  )
1022
  logger.info(
1023
  f"Local query uses {len(node_datas)} entites, {len(use_relations)} relations, {len(use_text_units)} text units"
@@ -1107,22 +1117,30 @@ async def _find_most_related_text_unit_from_entities(
1107
  }
1108
 
1109
  all_text_units_lookup = {}
 
1110
  for index, (this_text_units, this_edges) in enumerate(zip(text_units, edges)):
1111
  for c_id in this_text_units:
1112
  if c_id not in all_text_units_lookup:
1113
- all_text_units_lookup[c_id] = {
1114
- "data": await text_chunks_db.get_by_id(c_id),
1115
- "order": index,
1116
- "relation_counts": 0,
1117
- }
1118
 
1119
- if this_edges:
1120
- for e in this_edges:
1121
- if (
1122
- e[1] in all_one_hop_text_units_lookup
1123
- and c_id in all_one_hop_text_units_lookup[e[1]]
1124
- ):
1125
- all_text_units_lookup[c_id]["relation_counts"] += 1
 
 
 
 
 
 
 
 
 
 
 
1126
 
1127
  # Filter out None values and ensure data has content
1128
  all_text_units = [
@@ -1167,11 +1185,11 @@ async def _find_most_related_edges_from_entities(
1167
  seen.add(sorted_edge)
1168
  all_edges.append(sorted_edge)
1169
 
1170
- all_edges_pack = await asyncio.gather(
1171
- *[knowledge_graph_inst.get_edge(e[0], e[1]) for e in all_edges]
1172
- )
1173
- all_edges_degree = await asyncio.gather(
1174
- *[knowledge_graph_inst.edge_degree(e[0], e[1]) for e in all_edges]
1175
  )
1176
  all_edges_data = [
1177
  {"src_tgt": k, "rank": d, **v}
@@ -1201,15 +1219,21 @@ async def _get_edge_data(
1201
  if not len(results):
1202
  return "", "", ""
1203
 
1204
- edge_datas = await asyncio.gather(
1205
- *[knowledge_graph_inst.get_edge(r["src_id"], r["tgt_id"]) for r in results]
 
 
 
 
 
 
 
 
1206
  )
1207
 
1208
  if not all([n is not None for n in edge_datas]):
1209
  logger.warning("Some edges are missing, maybe the storage is damaged")
1210
- edge_degree = await asyncio.gather(
1211
- *[knowledge_graph_inst.edge_degree(r["src_id"], r["tgt_id"]) for r in results]
1212
- )
1213
  edge_datas = [
1214
  {
1215
  "src_id": k["src_id"],
@@ -1230,11 +1254,13 @@ async def _get_edge_data(
1230
  max_token_size=query_param.max_token_for_global_context,
1231
  )
1232
 
1233
- use_entities = await _find_most_related_entities_from_relationships(
1234
- edge_datas, query_param, knowledge_graph_inst
1235
- )
1236
- use_text_units = await _find_related_text_unit_from_relationships(
1237
- edge_datas, query_param, text_chunks_db, knowledge_graph_inst
 
 
1238
  )
1239
  logger.info(
1240
  f"Global query uses {len(use_entities)} entites, {len(edge_datas)} relations, {len(use_text_units)} text units"
@@ -1307,12 +1333,19 @@ async def _find_most_related_entities_from_relationships(
1307
  entity_names.append(e["tgt_id"])
1308
  seen.add(e["tgt_id"])
1309
 
1310
- node_datas = await asyncio.gather(
1311
- *[knowledge_graph_inst.get_node(entity_name) for entity_name in entity_names]
1312
- )
1313
-
1314
- node_degrees = await asyncio.gather(
1315
- *[knowledge_graph_inst.node_degree(entity_name) for entity_name in entity_names]
 
 
 
 
 
 
 
1316
  )
1317
  node_datas = [
1318
  {**n, "entity_name": k, "rank": d}
@@ -1340,16 +1373,22 @@ async def _find_related_text_unit_from_relationships(
1340
  ]
1341
  all_text_units_lookup = {}
1342
 
 
 
 
 
 
 
 
 
 
 
 
1343
  for index, unit_list in enumerate(text_units):
1344
  for c_id in unit_list:
1345
- if c_id not in all_text_units_lookup:
1346
- chunk_data = await text_chunks_db.get_by_id(c_id)
1347
- # Only store valid data
1348
- if chunk_data is not None and "content" in chunk_data:
1349
- all_text_units_lookup[c_id] = {
1350
- "data": chunk_data,
1351
- "order": index,
1352
- }
1353
 
1354
  if not all_text_units_lookup:
1355
  logger.warning("No valid text chunks found")
 
941
  query_param,
942
  )
943
  else: # hybrid mode
944
+ ll_data, hl_data = await asyncio.gather(
945
+ _get_node_data(
946
+ ll_keywords,
947
+ knowledge_graph_inst,
948
+ entities_vdb,
949
+ text_chunks_db,
950
+ query_param,
951
+ ),
952
+ _get_edge_data(
953
+ hl_keywords,
954
+ knowledge_graph_inst,
955
+ relationships_vdb,
956
+ text_chunks_db,
957
+ query_param,
958
+ ),
959
+ )
960
+
961
  (
962
  ll_entities_context,
963
  ll_relations_context,
964
  ll_text_units_context,
965
+ ) = ll_data
966
+
 
 
 
 
 
967
  (
968
  hl_entities_context,
969
  hl_relations_context,
970
  hl_text_units_context,
971
+ ) = hl_data
972
+
 
 
 
 
 
973
  entities_context, relations_context, text_units_context = combine_contexts(
974
  [hl_entities_context, ll_entities_context],
975
  [hl_relations_context, ll_relations_context],
 
1003
  if not len(results):
1004
  return "", "", ""
1005
  # get entity information
1006
+ node_datas, node_degrees = await asyncio.gather(
1007
+ asyncio.gather(
1008
+ *[knowledge_graph_inst.get_node(r["entity_name"]) for r in results]
1009
+ ),
1010
+ asyncio.gather(
1011
+ *[knowledge_graph_inst.node_degree(r["entity_name"]) for r in results]
1012
+ ),
1013
  )
1014
+
1015
  if not all([n is not None for n in node_datas]):
1016
  logger.warning("Some nodes are missing, maybe the storage is damaged")
1017
 
 
 
 
 
1018
  node_datas = [
1019
  {**n, "entity_name": k["entity_name"], "rank": d}
1020
  for k, n, d in zip(results, node_datas, node_degrees)
1021
  if n is not None
1022
  ] # what is this text_chunks_db doing. dont remember it in airvx. check the diagram.
1023
  # get entitytext chunk
1024
+ use_text_units, use_relations = await asyncio.gather(
1025
+ _find_most_related_text_unit_from_entities(
1026
+ node_datas, query_param, text_chunks_db, knowledge_graph_inst
1027
+ ),
1028
+ _find_most_related_edges_from_entities(
1029
+ node_datas, query_param, knowledge_graph_inst
1030
+ ),
1031
  )
1032
  logger.info(
1033
  f"Local query uses {len(node_datas)} entites, {len(use_relations)} relations, {len(use_text_units)} text units"
 
1117
  }
1118
 
1119
  all_text_units_lookup = {}
1120
+ tasks = []
1121
  for index, (this_text_units, this_edges) in enumerate(zip(text_units, edges)):
1122
  for c_id in this_text_units:
1123
  if c_id not in all_text_units_lookup:
1124
+ tasks.append((c_id, index, this_edges))
 
 
 
 
1125
 
1126
+ results = await asyncio.gather(
1127
+ *[text_chunks_db.get_by_id(c_id) for c_id, _, _ in tasks]
1128
+ )
1129
+
1130
+ for (c_id, index, this_edges), data in zip(tasks, results):
1131
+ all_text_units_lookup[c_id] = {
1132
+ "data": data,
1133
+ "order": index,
1134
+ "relation_counts": 0,
1135
+ }
1136
+
1137
+ if this_edges:
1138
+ for e in this_edges:
1139
+ if (
1140
+ e[1] in all_one_hop_text_units_lookup
1141
+ and c_id in all_one_hop_text_units_lookup[e[1]]
1142
+ ):
1143
+ all_text_units_lookup[c_id]["relation_counts"] += 1
1144
 
1145
  # Filter out None values and ensure data has content
1146
  all_text_units = [
 
1185
  seen.add(sorted_edge)
1186
  all_edges.append(sorted_edge)
1187
 
1188
+ all_edges_pack, all_edges_degree = await asyncio.gather(
1189
+ asyncio.gather(*[knowledge_graph_inst.get_edge(e[0], e[1]) for e in all_edges]),
1190
+ asyncio.gather(
1191
+ *[knowledge_graph_inst.edge_degree(e[0], e[1]) for e in all_edges]
1192
+ ),
1193
  )
1194
  all_edges_data = [
1195
  {"src_tgt": k, "rank": d, **v}
 
1219
  if not len(results):
1220
  return "", "", ""
1221
 
1222
+ edge_datas, edge_degree = await asyncio.gather(
1223
+ asyncio.gather(
1224
+ *[knowledge_graph_inst.get_edge(r["src_id"], r["tgt_id"]) for r in results]
1225
+ ),
1226
+ asyncio.gather(
1227
+ *[
1228
+ knowledge_graph_inst.edge_degree(r["src_id"], r["tgt_id"])
1229
+ for r in results
1230
+ ]
1231
+ ),
1232
  )
1233
 
1234
  if not all([n is not None for n in edge_datas]):
1235
  logger.warning("Some edges are missing, maybe the storage is damaged")
1236
+
 
 
1237
  edge_datas = [
1238
  {
1239
  "src_id": k["src_id"],
 
1254
  max_token_size=query_param.max_token_for_global_context,
1255
  )
1256
 
1257
+ use_entities, use_text_units = await asyncio.gather(
1258
+ _find_most_related_entities_from_relationships(
1259
+ edge_datas, query_param, knowledge_graph_inst
1260
+ ),
1261
+ _find_related_text_unit_from_relationships(
1262
+ edge_datas, query_param, text_chunks_db, knowledge_graph_inst
1263
+ ),
1264
  )
1265
  logger.info(
1266
  f"Global query uses {len(use_entities)} entites, {len(edge_datas)} relations, {len(use_text_units)} text units"
 
1333
  entity_names.append(e["tgt_id"])
1334
  seen.add(e["tgt_id"])
1335
 
1336
+ node_datas, node_degrees = await asyncio.gather(
1337
+ asyncio.gather(
1338
+ *[
1339
+ knowledge_graph_inst.get_node(entity_name)
1340
+ for entity_name in entity_names
1341
+ ]
1342
+ ),
1343
+ asyncio.gather(
1344
+ *[
1345
+ knowledge_graph_inst.node_degree(entity_name)
1346
+ for entity_name in entity_names
1347
+ ]
1348
+ ),
1349
  )
1350
  node_datas = [
1351
  {**n, "entity_name": k, "rank": d}
 
1373
  ]
1374
  all_text_units_lookup = {}
1375
 
1376
+ async def fetch_chunk_data(c_id, index):
1377
+ if c_id not in all_text_units_lookup:
1378
+ chunk_data = await text_chunks_db.get_by_id(c_id)
1379
+ # Only store valid data
1380
+ if chunk_data is not None and "content" in chunk_data:
1381
+ all_text_units_lookup[c_id] = {
1382
+ "data": chunk_data,
1383
+ "order": index,
1384
+ }
1385
+
1386
+ tasks = []
1387
  for index, unit_list in enumerate(text_units):
1388
  for c_id in unit_list:
1389
+ tasks.append(fetch_chunk_data(c_id, index))
1390
+
1391
+ await asyncio.gather(*tasks)
 
 
 
 
 
1392
 
1393
  if not all_text_units_lookup:
1394
  logger.warning("No valid text chunks found")