yangdx committed on
Commit
5a6d534
·
1 Parent(s): 4498e2f

Refactoring PostgreSQL AGE graph db implementation

Browse files
Files changed (1) hide show
  1. lightrag/kg/postgres_impl.py +120 -211
lightrag/kg/postgres_impl.py CHANGED
@@ -1064,31 +1064,11 @@ class PGGraphStorage(BaseGraphStorage):
1064
  if v.startswith("[") and v.endswith("]"):
1065
  if "::vertex" in v:
1066
  v = v.replace("::vertex", "")
1067
- vertexes = json.loads(v)
1068
- dl = []
1069
- for vertex in vertexes:
1070
- prop = vertex.get("properties")
1071
- if not prop:
1072
- prop = {}
1073
- prop["label"] = PGGraphStorage._decode_graph_label(
1074
- prop["node_id"]
1075
- )
1076
- dl.append(prop)
1077
- d[k] = dl
1078
 
1079
  elif "::edge" in v:
1080
  v = v.replace("::edge", "")
1081
- edges = json.loads(v)
1082
- dl = []
1083
- for edge in edges:
1084
- dl.append(
1085
- (
1086
- vertices[edge["start_id"]],
1087
- edge["label"],
1088
- vertices[edge["end_id"]],
1089
- )
1090
- )
1091
- d[k] = dl
1092
  else:
1093
  print("WARNING: unsupported type")
1094
  continue
@@ -1097,26 +1077,9 @@ class PGGraphStorage(BaseGraphStorage):
1097
  dtype = v.split("::")[-1]
1098
  v = v.split("::")[0]
1099
  if dtype == "vertex":
1100
- vertex = json.loads(v)
1101
- field = vertex.get("properties")
1102
- if not field:
1103
- field = {}
1104
- field["label"] = PGGraphStorage._decode_graph_label(
1105
- field["node_id"]
1106
- )
1107
- d[k] = field
1108
- # convert edge from id-label->id by replacing id with node information
1109
- # we only do this if the vertex was also returned in the query
1110
- # this is an attempt to be consistent with neo4j implementation
1111
  elif dtype == "edge":
1112
- edge = json.loads(v)
1113
- d[k] = (
1114
- vertices.get(edge["start_id"], {}),
1115
- edge[
1116
- "label"
1117
- ], # we don't use decode_graph_label(), since edge label is always "DIRECTED"
1118
- vertices.get(edge["end_id"], {}),
1119
- )
1120
  else:
1121
  d[k] = (
1122
  json.loads(v)
@@ -1152,56 +1115,6 @@ class PGGraphStorage(BaseGraphStorage):
1152
  )
1153
  return "{" + ", ".join(props) + "}"
1154
 
1155
- @staticmethod
1156
- def _encode_graph_label(label: str) -> str:
1157
- """
1158
- Since AGE supports only alphanumerical labels, we will encode generic label as HEX string
1159
-
1160
- Args:
1161
- label (str): the original label
1162
-
1163
- Returns:
1164
- str: the encoded label
1165
- """
1166
- return "x" + label.encode().hex()
1167
-
1168
- @staticmethod
1169
- def _decode_graph_label(encoded_label: str) -> str:
1170
- """
1171
- Since AGE supports only alphanumerical labels, we will encode generic label as HEX string
1172
-
1173
- Args:
1174
- encoded_label (str): the encoded label
1175
-
1176
- Returns:
1177
- str: the decoded label
1178
- """
1179
- return bytes.fromhex(encoded_label.removeprefix("x")).decode()
1180
-
1181
- @staticmethod
1182
- def _get_col_name(field: str, idx: int) -> str:
1183
- """
1184
- Convert a cypher return field to a pgsql select field
1185
- If possible keep the cypher column name, but create a generic name if necessary
1186
-
1187
- Args:
1188
- field (str): a return field from a cypher query to be formatted for pgsql
1189
- idx (int): the position of the field in the return statement
1190
-
1191
- Returns:
1192
- str: the field to be used in the pgsql select statement
1193
- """
1194
- # remove white space
1195
- field = field.strip()
1196
- # if an alias is provided for the field, use it
1197
- if " as " in field:
1198
- return field.split(" as ")[-1].strip()
1199
- # if the return value is an unnamed primitive, give it a generic name
1200
- if field.isnumeric() or field in ("true", "false", "null"):
1201
- return f"column_{idx}"
1202
- # otherwise return the value stripping out some common special chars
1203
- return field.replace("(", "_").replace(")", "")
1204
-
1205
  async def _query(
1206
  self,
1207
  query: str,
@@ -1252,10 +1165,10 @@ class PGGraphStorage(BaseGraphStorage):
1252
  return result
1253
 
1254
  async def has_node(self, node_id: str) -> bool:
1255
- entity_name_label = self._encode_graph_label(node_id.strip('"'))
1256
 
1257
  query = """SELECT * FROM cypher('%s', $$
1258
- MATCH (n:base {node_id: "%s"})
1259
  RETURN count(n) > 0 AS node_exists
1260
  $$) AS (node_exists bool)""" % (self.graph_name, entity_name_label)
1261
 
@@ -1264,11 +1177,11 @@ class PGGraphStorage(BaseGraphStorage):
1264
  return single_result["node_exists"]
1265
 
1266
  async def has_edge(self, source_node_id: str, target_node_id: str) -> bool:
1267
- src_label = self._encode_graph_label(source_node_id.strip('"'))
1268
- tgt_label = self._encode_graph_label(target_node_id.strip('"'))
1269
 
1270
  query = """SELECT * FROM cypher('%s', $$
1271
- MATCH (a:base {node_id: "%s"})-[r]-(b:base {node_id: "%s"})
1272
  RETURN COUNT(r) > 0 AS edge_exists
1273
  $$) AS (edge_exists bool)""" % (
1274
  self.graph_name,
@@ -1281,13 +1194,14 @@ class PGGraphStorage(BaseGraphStorage):
1281
  return single_result["edge_exists"]
1282
 
1283
  async def get_node(self, node_id: str) -> dict[str, str] | None:
1284
- label = self._encode_graph_label(node_id.strip('"'))
1285
  query = """SELECT * FROM cypher('%s', $$
1286
- MATCH (n:base {node_id: "%s"})
1287
  RETURN n
1288
  $$) AS (n agtype)""" % (self.graph_name, label)
1289
  record = await self._query(query)
1290
  if record:
 
1291
  node = record[0]
1292
  node_dict = node["n"]
1293
 
@@ -1295,10 +1209,10 @@ class PGGraphStorage(BaseGraphStorage):
1295
  return None
1296
 
1297
  async def node_degree(self, node_id: str) -> int:
1298
- label = self._encode_graph_label(node_id.strip('"'))
1299
 
1300
  query = """SELECT * FROM cypher('%s', $$
1301
- MATCH (n:base {node_id: "%s"})-[]->(x)
1302
  RETURN count(x) AS total_edge_count
1303
  $$) AS (total_edge_count integer)""" % (self.graph_name, label)
1304
  record = (await self._query(query))[0]
@@ -1322,11 +1236,11 @@ class PGGraphStorage(BaseGraphStorage):
1322
  async def get_edge(
1323
  self, source_node_id: str, target_node_id: str
1324
  ) -> dict[str, str] | None:
1325
- src_label = self._encode_graph_label(source_node_id.strip('"'))
1326
- tgt_label = self._encode_graph_label(target_node_id.strip('"'))
1327
 
1328
  query = """SELECT * FROM cypher('%s', $$
1329
- MATCH (a:base {node_id: "%s"})-[r]->(b:base {node_id: "%s"})
1330
  RETURN properties(r) as edge_properties
1331
  LIMIT 1
1332
  $$) AS (edge_properties agtype)""" % (
@@ -1336,6 +1250,7 @@ class PGGraphStorage(BaseGraphStorage):
1336
  )
1337
  record = await self._query(query)
1338
  if record and record[0] and record[0]["edge_properties"]:
 
1339
  result = record[0]["edge_properties"]
1340
 
1341
  return result
@@ -1345,10 +1260,10 @@ class PGGraphStorage(BaseGraphStorage):
1345
  Retrieves all edges (relationships) for a particular node identified by its label.
1346
  :return: list of dictionaries containing edge information
1347
  """
1348
- label = self._encode_graph_label(source_node_id.strip('"'))
1349
 
1350
  query = """SELECT * FROM cypher('%s', $$
1351
- MATCH (n:base {node_id: "%s"})
1352
  OPTIONAL MATCH (n)-[]-(connected:base)
1353
  RETURN n, connected
1354
  $$) AS (n agtype, connected agtype)""" % (
@@ -1362,24 +1277,17 @@ class PGGraphStorage(BaseGraphStorage):
1362
  source_node = record["n"] if record["n"] else None
1363
  connected_node = record["connected"] if record["connected"] else None
1364
 
1365
- source_label = (
1366
- source_node["node_id"]
1367
- if source_node and source_node["node_id"]
1368
- else None
1369
- )
1370
- target_label = (
1371
- connected_node["node_id"]
1372
- if connected_node and connected_node["node_id"]
1373
- else None
1374
- )
1375
 
1376
- if source_label and target_label:
1377
- edges.append(
1378
- (
1379
- self._decode_graph_label(source_label),
1380
- self._decode_graph_label(target_label),
1381
- )
1382
- )
1383
 
1384
  return edges
1385
 
@@ -1389,17 +1297,17 @@ class PGGraphStorage(BaseGraphStorage):
1389
  retry=retry_if_exception_type((PGGraphQueryException,)),
1390
  )
1391
  async def upsert_node(self, node_id: str, node_data: dict[str, str]) -> None:
1392
- label = self._encode_graph_label(node_id.strip('"'))
1393
- properties = node_data
1394
 
1395
  query = """SELECT * FROM cypher('%s', $$
1396
- MERGE (n:base {node_id: "%s"})
1397
  SET n += %s
1398
  RETURN n
1399
  $$) AS (n agtype)""" % (
1400
  self.graph_name,
1401
  label,
1402
- self._format_properties(properties),
1403
  )
1404
 
1405
  try:
@@ -1425,14 +1333,14 @@ class PGGraphStorage(BaseGraphStorage):
1425
  target_node_id (str): Label of the target node (used as identifier)
1426
  edge_data (dict): dictionary of properties to set on the edge
1427
  """
1428
- src_label = self._encode_graph_label(source_node_id.strip('"'))
1429
- tgt_label = self._encode_graph_label(target_node_id.strip('"'))
1430
- edge_properties = edge_data
1431
 
1432
  query = """SELECT * FROM cypher('%s', $$
1433
- MATCH (source:base {node_id: "%s"})
1434
  WITH source
1435
- MATCH (target:base {node_id: "%s"})
1436
  MERGE (source)-[r:DIRECTED]->(target)
1437
  SET r += %s
1438
  RETURN r
@@ -1440,7 +1348,7 @@ class PGGraphStorage(BaseGraphStorage):
1440
  self.graph_name,
1441
  src_label,
1442
  tgt_label,
1443
- self._format_properties(edge_properties),
1444
  )
1445
 
1446
  try:
@@ -1460,7 +1368,7 @@ class PGGraphStorage(BaseGraphStorage):
1460
  Args:
1461
  node_id (str): The ID of the node to delete.
1462
  """
1463
- label = self._encode_graph_label(node_id.strip('"'))
1464
 
1465
  query = """SELECT * FROM cypher('%s', $$
1466
  MATCH (n:base {entity_id: "%s"})
@@ -1480,14 +1388,12 @@ class PGGraphStorage(BaseGraphStorage):
1480
  Args:
1481
  node_ids (list[str]): A list of node IDs to remove.
1482
  """
1483
- encoded_node_ids = [
1484
- self._encode_graph_label(node_id.strip('"')) for node_id in node_ids
1485
- ]
1486
- node_id_list = ", ".join([f'"{node_id}"' for node_id in encoded_node_ids])
1487
 
1488
  query = """SELECT * FROM cypher('%s', $$
1489
  MATCH (n:base)
1490
- WHERE n.nentity_id IN [%s]
1491
  DETACH DELETE n
1492
  $$) AS (n agtype)""" % (self.graph_name, node_id_list)
1493
 
@@ -1505,11 +1411,11 @@ class PGGraphStorage(BaseGraphStorage):
1505
  edges (list[tuple[str, str]]): A list of edges to remove, where each edge is a tuple of (source_node_id, target_node_id).
1506
  """
1507
  for source, target in edges:
1508
- src_label = self._encode_graph_label(source.strip('"'))
1509
- tgt_label = self._encode_graph_label(target.strip('"'))
1510
 
1511
  query = """SELECT * FROM cypher('%s', $$
1512
- MATCH (a:base {node_id: "%s"})-[r]->(b:base {node_id: "%s"})
1513
  DELETE r
1514
  $$) AS (r agtype)""" % (self.graph_name, src_label, tgt_label)
1515
 
@@ -1560,95 +1466,98 @@ class PGGraphStorage(BaseGraphStorage):
1560
  return await embed_func()
1561
 
1562
  async def get_knowledge_graph(
1563
- self, node_label: str, max_depth: int = 5
 
 
 
1564
  ) -> KnowledgeGraph:
1565
  """
1566
- Retrieve a subgraph containing the specified node and its neighbors up to the specified depth.
1567
 
1568
  Args:
1569
- node_label (str): The label of the node to start from. If "*", the entire graph is returned.
1570
- max_depth (int): The maximum depth to traverse from the starting node.
 
1571
 
1572
  Returns:
1573
- KnowledgeGraph: The retrieved subgraph.
 
1574
  """
1575
- MAX_GRAPH_NODES = 1000
1576
 
1577
  # Build the query based on whether we want the full graph or a specific subgraph.
1578
  if node_label == "*":
1579
  query = f"""SELECT * FROM cypher('{self.graph_name}', $$
1580
- MATCH (n:base)
1581
- OPTIONAL MATCH (n)-[r]->(m:base)
1582
- RETURN n, r, m
1583
- LIMIT {MAX_GRAPH_NODES}
1584
- $$) AS (n agtype, r agtype, m agtype)"""
1585
  else:
1586
- encoded_label = self._encode_graph_label(node_label.strip('"'))
1587
  query = f"""SELECT * FROM cypher('{self.graph_name}', $$
1588
- MATCH (n:base {{entity_id: "{encoded_label}"}})
1589
- OPTIONAL MATCH p = (n)-[*..{max_depth}]-(m)
1590
- RETURN nodes(p) AS nodes, relationships(p) AS relationships
1591
- LIMIT {MAX_GRAPH_NODES}
1592
- $$) AS (nodes agtype, relationships agtype)"""
1593
 
1594
  results = await self._query(query)
1595
 
1596
- nodes = {}
1597
- edges = []
1598
- unique_edge_ids = set()
1599
-
1600
- def add_node(node_data: dict):
1601
- node_id = self._decode_graph_label(node_data["node_id"])
1602
- if node_id not in nodes:
1603
- nodes[node_id] = node_data
1604
-
1605
- def add_edge(edge_data: list):
1606
- src_id = self._decode_graph_label(edge_data[0]["node_id"])
1607
- tgt_id = self._decode_graph_label(edge_data[2]["node_id"])
1608
- edge_key = f"{src_id},{tgt_id}"
1609
- if edge_key not in unique_edge_ids:
1610
- unique_edge_ids.add(edge_key)
1611
- edges.append(
1612
- (
1613
- edge_key,
1614
- src_id,
1615
- tgt_id,
1616
- {"source": edge_data[0], "target": edge_data[2]},
1617
  )
1618
- )
1619
-
1620
- # Process the query results.
1621
- if node_label == "*":
1622
- for result in results:
1623
- if result.get("n"):
1624
- add_node(result["n"])
1625
- if result.get("m"):
1626
- add_node(result["m"])
1627
- if result.get("r"):
1628
- add_edge(result["r"])
1629
- else:
1630
- for result in results:
1631
- for node in result.get("nodes", []):
1632
- add_node(node)
1633
- for edge in result.get("relationships", []):
1634
- add_edge(edge)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1635
 
1636
- # Construct and return the KnowledgeGraph.
1637
  kg = KnowledgeGraph(
1638
- nodes=[
1639
- KnowledgeGraphNode(id=node_id, labels=[node_id], properties=node_data)
1640
- for node_id, node_data in nodes.items()
1641
- ],
1642
- edges=[
1643
- KnowledgeGraphEdge(
1644
- id=edge_id,
1645
- type="DIRECTED",
1646
- source=src,
1647
- target=tgt,
1648
- properties=props,
1649
- )
1650
- for edge_id, src, tgt, props in edges
1651
- ],
1652
  )
1653
 
1654
  return kg
 
1064
  if v.startswith("[") and v.endswith("]"):
1065
  if "::vertex" in v:
1066
  v = v.replace("::vertex", "")
1067
+ d[k] = json.loads(v)
 
 
 
 
 
 
 
 
 
 
1068
 
1069
  elif "::edge" in v:
1070
  v = v.replace("::edge", "")
1071
+ d[k] = json.loads(v)
 
 
 
 
 
 
 
 
 
 
1072
  else:
1073
  print("WARNING: unsupported type")
1074
  continue
 
1077
  dtype = v.split("::")[-1]
1078
  v = v.split("::")[0]
1079
  if dtype == "vertex":
1080
+ d[k] = json.loads(v)
 
 
 
 
 
 
 
 
 
 
1081
  elif dtype == "edge":
1082
+ d[k] = json.loads(v)
 
 
 
 
 
 
 
1083
  else:
1084
  d[k] = (
1085
  json.loads(v)
 
1115
  )
1116
  return "{" + ", ".join(props) + "}"
1117
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1118
  async def _query(
1119
  self,
1120
  query: str,
 
1165
  return result
1166
 
1167
  async def has_node(self, node_id: str) -> bool:
1168
+ entity_name_label = node_id.strip('"')
1169
 
1170
  query = """SELECT * FROM cypher('%s', $$
1171
+ MATCH (n:base {entity_id: "%s"})
1172
  RETURN count(n) > 0 AS node_exists
1173
  $$) AS (node_exists bool)""" % (self.graph_name, entity_name_label)
1174
 
 
1177
  return single_result["node_exists"]
1178
 
1179
  async def has_edge(self, source_node_id: str, target_node_id: str) -> bool:
1180
+ src_label = source_node_id.strip('"')
1181
+ tgt_label = target_node_id.strip('"')
1182
 
1183
  query = """SELECT * FROM cypher('%s', $$
1184
+ MATCH (a:base {entity_id: "%s"})-[r]-(b:base {entity_id: "%s"})
1185
  RETURN COUNT(r) > 0 AS edge_exists
1186
  $$) AS (edge_exists bool)""" % (
1187
  self.graph_name,
 
1194
  return single_result["edge_exists"]
1195
 
1196
  async def get_node(self, node_id: str) -> dict[str, str] | None:
1197
+ label = node_id.strip('"')
1198
  query = """SELECT * FROM cypher('%s', $$
1199
+ MATCH (n:base {entity_id: "%s"})
1200
  RETURN n
1201
  $$) AS (n agtype)""" % (self.graph_name, label)
1202
  record = await self._query(query)
1203
  if record:
1204
+ print(f"Record: {record}")
1205
  node = record[0]
1206
  node_dict = node["n"]
1207
 
 
1209
  return None
1210
 
1211
  async def node_degree(self, node_id: str) -> int:
1212
+ label = node_id.strip('"')
1213
 
1214
  query = """SELECT * FROM cypher('%s', $$
1215
+ MATCH (n:base {entity_id: "%s"})-[]->(x)
1216
  RETURN count(x) AS total_edge_count
1217
  $$) AS (total_edge_count integer)""" % (self.graph_name, label)
1218
  record = (await self._query(query))[0]
 
1236
  async def get_edge(
1237
  self, source_node_id: str, target_node_id: str
1238
  ) -> dict[str, str] | None:
1239
+ src_label = source_node_id.strip('"')
1240
+ tgt_label = target_node_id.strip('"')
1241
 
1242
  query = """SELECT * FROM cypher('%s', $$
1243
+ MATCH (a:base {entity_id: "%s"})-[r]->(b:base {entity_id: "%s"})
1244
  RETURN properties(r) as edge_properties
1245
  LIMIT 1
1246
  $$) AS (edge_properties agtype)""" % (
 
1250
  )
1251
  record = await self._query(query)
1252
  if record and record[0] and record[0]["edge_properties"]:
1253
+ print(f"Record: {record}")
1254
  result = record[0]["edge_properties"]
1255
 
1256
  return result
 
1260
  Retrieves all edges (relationships) for a particular node identified by its label.
1261
  :return: list of dictionaries containing edge information
1262
  """
1263
+ label = source_node_id.strip('"')
1264
 
1265
  query = """SELECT * FROM cypher('%s', $$
1266
+ MATCH (n:base {entity_id: "%s"})
1267
  OPTIONAL MATCH (n)-[]-(connected:base)
1268
  RETURN n, connected
1269
  $$) AS (n agtype, connected agtype)""" % (
 
1277
  source_node = record["n"] if record["n"] else None
1278
  connected_node = record["connected"] if record["connected"] else None
1279
 
1280
+ if (
1281
+ source_node
1282
+ and connected_node
1283
+ and "properties" in source_node
1284
+ and "properties" in connected_node
1285
+ ):
1286
+ source_label = source_node["properties"].get("entity_id")
1287
+ target_label = connected_node["properties"].get("entity_id")
 
 
1288
 
1289
+ if source_label and target_label:
1290
+ edges.append((source_label, target_label))
 
 
 
 
 
1291
 
1292
  return edges
1293
 
 
1297
  retry=retry_if_exception_type((PGGraphQueryException,)),
1298
  )
1299
  async def upsert_node(self, node_id: str, node_data: dict[str, str]) -> None:
1300
+ label = node_id.strip('"')
1301
+ properties = self._format_properties(node_data)
1302
 
1303
  query = """SELECT * FROM cypher('%s', $$
1304
+ MERGE (n:base {entity_id: "%s"})
1305
  SET n += %s
1306
  RETURN n
1307
  $$) AS (n agtype)""" % (
1308
  self.graph_name,
1309
  label,
1310
+ properties,
1311
  )
1312
 
1313
  try:
 
1333
  target_node_id (str): Label of the target node (used as identifier)
1334
  edge_data (dict): dictionary of properties to set on the edge
1335
  """
1336
+ src_label = source_node_id.strip('"')
1337
+ tgt_label = target_node_id.strip('"')
1338
+ edge_properties = self._format_properties(edge_data)
1339
 
1340
  query = """SELECT * FROM cypher('%s', $$
1341
+ MATCH (source:base {entity_id: "%s"})
1342
  WITH source
1343
+ MATCH (target:base {entity_id: "%s"})
1344
  MERGE (source)-[r:DIRECTED]->(target)
1345
  SET r += %s
1346
  RETURN r
 
1348
  self.graph_name,
1349
  src_label,
1350
  tgt_label,
1351
+ edge_properties,
1352
  )
1353
 
1354
  try:
 
1368
  Args:
1369
  node_id (str): The ID of the node to delete.
1370
  """
1371
+ label = node_id.strip('"')
1372
 
1373
  query = """SELECT * FROM cypher('%s', $$
1374
  MATCH (n:base {entity_id: "%s"})
 
1388
  Args:
1389
  node_ids (list[str]): A list of node IDs to remove.
1390
  """
1391
+ node_ids = [node_id.strip('"') for node_id in node_ids]
1392
+ node_id_list = ", ".join([f'"{node_id}"' for node_id in node_ids])
 
 
1393
 
1394
  query = """SELECT * FROM cypher('%s', $$
1395
  MATCH (n:base)
1396
+ WHERE n.entity_id IN [%s]
1397
  DETACH DELETE n
1398
  $$) AS (n agtype)""" % (self.graph_name, node_id_list)
1399
 
 
1411
  edges (list[tuple[str, str]]): A list of edges to remove, where each edge is a tuple of (source_node_id, target_node_id).
1412
  """
1413
  for source, target in edges:
1414
+ src_label = source.strip('"')
1415
+ tgt_label = target.strip('"')
1416
 
1417
  query = """SELECT * FROM cypher('%s', $$
1418
+ MATCH (a:base {entity_id: "%s"})-[r]->(b:base {entity_id: "%s"})
1419
  DELETE r
1420
  $$) AS (r agtype)""" % (self.graph_name, src_label, tgt_label)
1421
 
 
1466
  return await embed_func()
1467
 
1468
  async def get_knowledge_graph(
1469
+ self,
1470
+ node_label: str,
1471
+ max_depth: int = 3,
1472
+ max_nodes: int = MAX_GRAPH_NODES,
1473
  ) -> KnowledgeGraph:
1474
  """
1475
+ Retrieve a connected subgraph of nodes where the label includes the specified `node_label`.
1476
 
1477
  Args:
1478
+ node_label: Label of the starting node, * means all nodes
1479
+ max_depth: Maximum depth of the subgraph, Defaults to 3
1480
+ max_nodes: Maxiumu nodes to return by BFS, Defaults to 1000
1481
 
1482
  Returns:
1483
+ KnowledgeGraph object containing nodes and edges, with an is_truncated flag
1484
+ indicating whether the graph was truncated due to max_nodes limit
1485
  """
 
1486
 
1487
  # Build the query based on whether we want the full graph or a specific subgraph.
1488
  if node_label == "*":
1489
  query = f"""SELECT * FROM cypher('{self.graph_name}', $$
1490
+ MATCH (n:base)
1491
+ OPTIONAL MATCH (n)-[r]->(target:base)
1492
+ RETURN collect(distinct n) AS n, collect(distinct r) AS r
1493
+ LIMIT {MAX_GRAPH_NODES}
1494
+ $$) AS (n agtype, r agtype)"""
1495
  else:
1496
+ strip_label = node_label.strip('"')
1497
  query = f"""SELECT * FROM cypher('{self.graph_name}', $$
1498
+ MATCH (n:base {{entity_id: "{strip_label}"}})
1499
+ OPTIONAL MATCH p = (n)-[*..{max_depth}]-(m)
1500
+ RETURN nodes(p) AS n, relationships(p) AS r
1501
+ LIMIT {max_nodes}
1502
+ $$) AS (n agtype, r agtype)"""
1503
 
1504
  results = await self._query(query)
1505
 
1506
+ # Process the query results with deduplication by node and edge IDs
1507
+ nodes_dict = {}
1508
+ edges_dict = {}
1509
+ for result in results:
1510
+ # Handle single node cases
1511
+ if result.get("n") and isinstance(result["n"], dict):
1512
+ node_id = str(result["n"]["id"])
1513
+ if node_id not in nodes_dict:
1514
+ nodes_dict[node_id] = KnowledgeGraphNode(
1515
+ id=node_id,
1516
+ labels=[result["n"]["properties"]["entity_id"]],
1517
+ properties=result["n"]["properties"],
 
 
 
 
 
 
 
 
 
1518
  )
1519
+ # Handle node list cases
1520
+ elif result.get("n") and isinstance(result["n"], list):
1521
+ for node in result["n"]:
1522
+ if isinstance(node, dict) and "id" in node:
1523
+ node_id = str(node["id"])
1524
+ if node_id not in nodes_dict and "properties" in node:
1525
+ nodes_dict[node_id] = KnowledgeGraphNode(
1526
+ id=node_id,
1527
+ labels=[node["properties"]["entity_id"]],
1528
+ properties=node["properties"],
1529
+ )
1530
+
1531
+ # Handle single edge cases
1532
+ if result.get("r") and isinstance(result["r"], dict):
1533
+ edge_id = str(result["r"]["id"])
1534
+ if edge_id not in edges_dict:
1535
+ edges_dict[edge_id] = KnowledgeGraphEdge(
1536
+ id=edge_id,
1537
+ type="DIRECTED",
1538
+ source=str(result["r"]["start_id"]),
1539
+ target=str(result["r"]["end_id"]),
1540
+ properties=result["r"]["properties"],
1541
+ )
1542
+ # Handle edge list cases
1543
+ elif result.get("r") and isinstance(result["r"], list):
1544
+ for edge in result["r"]:
1545
+ if isinstance(edge, dict) and "id" in edge:
1546
+ edge_id = str(edge["id"])
1547
+ if edge_id not in edges_dict:
1548
+ edges_dict[edge_id] = KnowledgeGraphEdge(
1549
+ id=edge_id,
1550
+ type="DIRECTED",
1551
+ source=str(edge["start_id"]),
1552
+ target=str(edge["end_id"]),
1553
+ properties=edge["properties"],
1554
+ )
1555
 
1556
+ # Construct and return the KnowledgeGraph with deduplicated nodes and edges
1557
  kg = KnowledgeGraph(
1558
+ nodes=list(nodes_dict.values()),
1559
+ edges=list(edges_dict.values()),
1560
+ is_truncated=False,
 
 
 
 
 
 
 
 
 
 
 
1561
  )
1562
 
1563
  return kg