zrguo committed
Commit ee9eb3b · unverified · 2 parents: fa9d5f4 b8ce148

Merge pull request #642 from dimatill/main

Files changed (1): lightrag/operate.py (+106 −67)
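
This merge makes the keyword-lookup path in lightrag/operate.py concurrent: independent storage calls that were previously awaited one after another are now scheduled together with asyncio.gather. A minimal sketch of the pattern, assuming hypothetical fetch_node / fetch_degree coroutines in place of the real knowledge_graph_inst calls:

import asyncio


# Hypothetical coroutines standing in for the storage calls in operate.py.
async def fetch_node(name: str) -> dict:
    await asyncio.sleep(0.01)  # simulate one storage round trip
    return {"entity_name": name}


async def fetch_degree(name: str) -> int:
    await asyncio.sleep(0.01)
    return len(name)


async def main() -> None:
    names = ["alpha", "beta", "gamma"]

    # Before: each await waits for the previous call to finish.
    nodes = [await fetch_node(n) for n in names]
    degrees = [await fetch_degree(n) for n in names]

    # After: schedule everything at once; gather returns results in the
    # same order as its arguments, so later zips still line up.
    nodes, degrees = await asyncio.gather(
        asyncio.gather(*(fetch_node(n) for n in names)),
        asyncio.gather(*(fetch_degree(n) for n in names)),
    )
    print(nodes, degrees)


asyncio.run(main())

Because asyncio.gather preserves argument order, the zip over results, node_datas, and node_degrees in the diff below still pairs each entity with its own degree.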
lightrag/operate.py CHANGED
@@ -990,28 +990,35 @@ async def _build_query_context(
             query_param,
         )
     else:  # hybrid mode
+        ll_data, hl_data = await asyncio.gather(
+            _get_node_data(
+                ll_keywords,
+                knowledge_graph_inst,
+                entities_vdb,
+                text_chunks_db,
+                query_param,
+            ),
+            _get_edge_data(
+                hl_keywords,
+                knowledge_graph_inst,
+                relationships_vdb,
+                text_chunks_db,
+                query_param,
+            ),
+        )
+
         (
             ll_entities_context,
             ll_relations_context,
             ll_text_units_context,
-        ) = await _get_node_data(
-            ll_keywords,
-            knowledge_graph_inst,
-            entities_vdb,
-            text_chunks_db,
-            query_param,
-        )
+        ) = ll_data
+
         (
             hl_entities_context,
             hl_relations_context,
             hl_text_units_context,
-        ) = await _get_edge_data(
-            hl_keywords,
-            knowledge_graph_inst,
-            relationships_vdb,
-            text_chunks_db,
-            query_param,
-        )
+        ) = hl_data
+
         entities_context, relations_context, text_units_context = combine_contexts(
             [hl_entities_context, ll_entities_context],
             [hl_relations_context, ll_relations_context],
@@ -1045,28 +1052,31 @@ async def _get_node_data(
     if not len(results):
         return "", "", ""
     # get entity information
-    node_datas = await asyncio.gather(
-        *[knowledge_graph_inst.get_node(r["entity_name"]) for r in results]
+    node_datas, node_degrees = await asyncio.gather(
+        asyncio.gather(
+            *[knowledge_graph_inst.get_node(r["entity_name"]) for r in results]
+        ),
+        asyncio.gather(
+            *[knowledge_graph_inst.node_degree(r["entity_name"]) for r in results]
+        ),
     )
+
     if not all([n is not None for n in node_datas]):
         logger.warning("Some nodes are missing, maybe the storage is damaged")
 
-    # get entity degree
-    node_degrees = await asyncio.gather(
-        *[knowledge_graph_inst.node_degree(r["entity_name"]) for r in results]
-    )
     node_datas = [
         {**n, "entity_name": k["entity_name"], "rank": d}
        for k, n, d in zip(results, node_datas, node_degrees)
         if n is not None
     ]  # what is this text_chunks_db doing. dont remember it in airvx. check the diagram.
     # get entitytext chunk
-    use_text_units = await _find_most_related_text_unit_from_entities(
-        node_datas, query_param, text_chunks_db, knowledge_graph_inst
-    )
-    # get relate edges
-    use_relations = await _find_most_related_edges_from_entities(
-        node_datas, query_param, knowledge_graph_inst
+    use_text_units, use_relations = await asyncio.gather(
+        _find_most_related_text_unit_from_entities(
+            node_datas, query_param, text_chunks_db, knowledge_graph_inst
+        ),
+        _find_most_related_edges_from_entities(
+            node_datas, query_param, knowledge_graph_inst
+        ),
     )
     logger.info(
         f"Local query uses {len(node_datas)} entites, {len(use_relations)} relations, {len(use_text_units)} text units"
@@ -1156,22 +1166,30 @@ async def _find_most_related_text_unit_from_entities(
     }
 
     all_text_units_lookup = {}
+    tasks = []
     for index, (this_text_units, this_edges) in enumerate(zip(text_units, edges)):
         for c_id in this_text_units:
             if c_id not in all_text_units_lookup:
-                all_text_units_lookup[c_id] = {
-                    "data": await text_chunks_db.get_by_id(c_id),
-                    "order": index,
-                    "relation_counts": 0,
-                }
+                tasks.append((c_id, index, this_edges))
 
-            if this_edges:
-                for e in this_edges:
-                    if (
-                        e[1] in all_one_hop_text_units_lookup
-                        and c_id in all_one_hop_text_units_lookup[e[1]]
-                    ):
-                        all_text_units_lookup[c_id]["relation_counts"] += 1
+    results = await asyncio.gather(
+        *[text_chunks_db.get_by_id(c_id) for c_id, _, _ in tasks]
+    )
+
+    for (c_id, index, this_edges), data in zip(tasks, results):
+        all_text_units_lookup[c_id] = {
+            "data": data,
+            "order": index,
+            "relation_counts": 0,
+        }
+
+        if this_edges:
+            for e in this_edges:
+                if (
+                    e[1] in all_one_hop_text_units_lookup
+                    and c_id in all_one_hop_text_units_lookup[e[1]]
+                ):
+                    all_text_units_lookup[c_id]["relation_counts"] += 1
 
     # Filter out None values and ensure data has content
     all_text_units = [
@@ -1216,11 +1234,11 @@ async def _find_most_related_edges_from_entities(
                 seen.add(sorted_edge)
                 all_edges.append(sorted_edge)
 
-    all_edges_pack = await asyncio.gather(
-        *[knowledge_graph_inst.get_edge(e[0], e[1]) for e in all_edges]
-    )
-    all_edges_degree = await asyncio.gather(
-        *[knowledge_graph_inst.edge_degree(e[0], e[1]) for e in all_edges]
+    all_edges_pack, all_edges_degree = await asyncio.gather(
+        asyncio.gather(*[knowledge_graph_inst.get_edge(e[0], e[1]) for e in all_edges]),
+        asyncio.gather(
+            *[knowledge_graph_inst.edge_degree(e[0], e[1]) for e in all_edges]
+        ),
     )
     all_edges_data = [
         {"src_tgt": k, "rank": d, **v}
@@ -1250,15 +1268,21 @@ async def _get_edge_data(
     if not len(results):
         return "", "", ""
 
-    edge_datas = await asyncio.gather(
-        *[knowledge_graph_inst.get_edge(r["src_id"], r["tgt_id"]) for r in results]
+    edge_datas, edge_degree = await asyncio.gather(
+        asyncio.gather(
+            *[knowledge_graph_inst.get_edge(r["src_id"], r["tgt_id"]) for r in results]
+        ),
+        asyncio.gather(
+            *[
+                knowledge_graph_inst.edge_degree(r["src_id"], r["tgt_id"])
+                for r in results
+            ]
+        ),
    )
 
     if not all([n is not None for n in edge_datas]):
         logger.warning("Some edges are missing, maybe the storage is damaged")
-    edge_degree = await asyncio.gather(
-        *[knowledge_graph_inst.edge_degree(r["src_id"], r["tgt_id"]) for r in results]
-    )
+
     edge_datas = [
         {
             "src_id": k["src_id"],
@@ -1279,11 +1303,13 @@ async def _get_edge_data(
         max_token_size=query_param.max_token_for_global_context,
     )
 
-    use_entities = await _find_most_related_entities_from_relationships(
-        edge_datas, query_param, knowledge_graph_inst
-    )
-    use_text_units = await _find_related_text_unit_from_relationships(
-        edge_datas, query_param, text_chunks_db, knowledge_graph_inst
+    use_entities, use_text_units = await asyncio.gather(
+        _find_most_related_entities_from_relationships(
+            edge_datas, query_param, knowledge_graph_inst
+        ),
+        _find_related_text_unit_from_relationships(
+            edge_datas, query_param, text_chunks_db, knowledge_graph_inst
+        ),
     )
     logger.info(
         f"Global query uses {len(use_entities)} entites, {len(edge_datas)} relations, {len(use_text_units)} text units"
@@ -1356,12 +1382,19 @@ async def _find_most_related_entities_from_relationships(
             entity_names.append(e["tgt_id"])
             seen.add(e["tgt_id"])
 
-    node_datas = await asyncio.gather(
-        *[knowledge_graph_inst.get_node(entity_name) for entity_name in entity_names]
-    )
-
-    node_degrees = await asyncio.gather(
-        *[knowledge_graph_inst.node_degree(entity_name) for entity_name in entity_names]
+    node_datas, node_degrees = await asyncio.gather(
+        asyncio.gather(
+            *[
+                knowledge_graph_inst.get_node(entity_name)
+                for entity_name in entity_names
+            ]
+        ),
+        asyncio.gather(
+            *[
+                knowledge_graph_inst.node_degree(entity_name)
+                for entity_name in entity_names
+            ]
+        ),
     )
     node_datas = [
         {**n, "entity_name": k, "rank": d}
@@ -1389,16 +1422,22 @@ async def _find_related_text_unit_from_relationships(
     ]
     all_text_units_lookup = {}
 
+    async def fetch_chunk_data(c_id, index):
+        if c_id not in all_text_units_lookup:
+            chunk_data = await text_chunks_db.get_by_id(c_id)
+            # Only store valid data
+            if chunk_data is not None and "content" in chunk_data:
+                all_text_units_lookup[c_id] = {
+                    "data": chunk_data,
+                    "order": index,
+                }
+
+    tasks = []
     for index, unit_list in enumerate(text_units):
         for c_id in unit_list:
-            if c_id not in all_text_units_lookup:
-                chunk_data = await text_chunks_db.get_by_id(c_id)
-                # Only store valid data
-                if chunk_data is not None and "content" in chunk_data:
-                    all_text_units_lookup[c_id] = {
-                        "data": chunk_data,
-                        "order": index,
-                    }
+            tasks.append(fetch_chunk_data(c_id, index))
+
+    await asyncio.gather(*tasks)
 
     if not all_text_units_lookup:
         logger.warning("No valid text chunks found")
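
The chunk-fetching hunks above use a second variant of the same idea: collect the work first, fan it out with one gather, then zip the results back onto their metadata. A small self-contained sketch, with a hypothetical get_by_id standing in for text_chunks_db.get_by_id:

import asyncio


# Hypothetical chunk store standing in for text_chunks_db in this diff.
async def get_by_id(c_id: str) -> dict:
    await asyncio.sleep(0.01)
    return {"content": f"text of {c_id}"}


async def main() -> None:
    wanted = [("chunk-1", 0), ("chunk-2", 0), ("chunk-3", 1)]  # (id, order)

    # Fetch all chunks concurrently, then zip the results back onto their
    # metadata, mirroring the tasks/results loop in
    # _find_most_related_text_unit_from_entities.
    results = await asyncio.gather(*(get_by_id(c_id) for c_id, _ in wanted))
    lookup = {
        c_id: {"data": data, "order": order, "relation_counts": 0}
        for (c_id, order), data in zip(wanted, results)
    }
    print(lookup)


asyncio.run(main())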