yangdx
committed on
Commit
·
5a6d534
1
Parent(s):
4498e2f
Refactoring PostgreSQL AGE graph db implementation
Browse files
- lightrag/kg/postgres_impl.py +120 -211
lightrag/kg/postgres_impl.py
CHANGED
|
@@ -1064,31 +1064,11 @@ class PGGraphStorage(BaseGraphStorage):
|
|
| 1064 |
if v.startswith("[") and v.endswith("]"):
|
| 1065 |
if "::vertex" in v:
|
| 1066 |
v = v.replace("::vertex", "")
|
| 1067 |
-
|
| 1068 |
-
dl = []
|
| 1069 |
-
for vertex in vertexes:
|
| 1070 |
-
prop = vertex.get("properties")
|
| 1071 |
-
if not prop:
|
| 1072 |
-
prop = {}
|
| 1073 |
-
prop["label"] = PGGraphStorage._decode_graph_label(
|
| 1074 |
-
prop["node_id"]
|
| 1075 |
-
)
|
| 1076 |
-
dl.append(prop)
|
| 1077 |
-
d[k] = dl
|
| 1078 |
|
| 1079 |
elif "::edge" in v:
|
| 1080 |
v = v.replace("::edge", "")
|
| 1081 |
-
|
| 1082 |
-
dl = []
|
| 1083 |
-
for edge in edges:
|
| 1084 |
-
dl.append(
|
| 1085 |
-
(
|
| 1086 |
-
vertices[edge["start_id"]],
|
| 1087 |
-
edge["label"],
|
| 1088 |
-
vertices[edge["end_id"]],
|
| 1089 |
-
)
|
| 1090 |
-
)
|
| 1091 |
-
d[k] = dl
|
| 1092 |
else:
|
| 1093 |
print("WARNING: unsupported type")
|
| 1094 |
continue
|
|
@@ -1097,26 +1077,9 @@ class PGGraphStorage(BaseGraphStorage):
|
|
| 1097 |
dtype = v.split("::")[-1]
|
| 1098 |
v = v.split("::")[0]
|
| 1099 |
if dtype == "vertex":
|
| 1100 |
-
|
| 1101 |
-
field = vertex.get("properties")
|
| 1102 |
-
if not field:
|
| 1103 |
-
field = {}
|
| 1104 |
-
field["label"] = PGGraphStorage._decode_graph_label(
|
| 1105 |
-
field["node_id"]
|
| 1106 |
-
)
|
| 1107 |
-
d[k] = field
|
| 1108 |
-
# convert edge from id-label->id by replacing id with node information
|
| 1109 |
-
# we only do this if the vertex was also returned in the query
|
| 1110 |
-
# this is an attempt to be consistent with neo4j implementation
|
| 1111 |
elif dtype == "edge":
|
| 1112 |
-
|
| 1113 |
-
d[k] = (
|
| 1114 |
-
vertices.get(edge["start_id"], {}),
|
| 1115 |
-
edge[
|
| 1116 |
-
"label"
|
| 1117 |
-
], # we don't use decode_graph_label(), since edge label is always "DIRECTED"
|
| 1118 |
-
vertices.get(edge["end_id"], {}),
|
| 1119 |
-
)
|
| 1120 |
else:
|
| 1121 |
d[k] = (
|
| 1122 |
json.loads(v)
|
|
@@ -1152,56 +1115,6 @@ class PGGraphStorage(BaseGraphStorage):
|
|
| 1152 |
)
|
| 1153 |
return "{" + ", ".join(props) + "}"
|
| 1154 |
|
| 1155 |
-
@staticmethod
|
| 1156 |
-
def _encode_graph_label(label: str) -> str:
|
| 1157 |
-
"""
|
| 1158 |
-
Since AGE supports only alphanumerical labels, we will encode generic label as HEX string
|
| 1159 |
-
|
| 1160 |
-
Args:
|
| 1161 |
-
label (str): the original label
|
| 1162 |
-
|
| 1163 |
-
Returns:
|
| 1164 |
-
str: the encoded label
|
| 1165 |
-
"""
|
| 1166 |
-
return "x" + label.encode().hex()
|
| 1167 |
-
|
| 1168 |
-
@staticmethod
|
| 1169 |
-
def _decode_graph_label(encoded_label: str) -> str:
|
| 1170 |
-
"""
|
| 1171 |
-
Since AGE supports only alphanumerical labels, we will encode generic label as HEX string
|
| 1172 |
-
|
| 1173 |
-
Args:
|
| 1174 |
-
encoded_label (str): the encoded label
|
| 1175 |
-
|
| 1176 |
-
Returns:
|
| 1177 |
-
str: the decoded label
|
| 1178 |
-
"""
|
| 1179 |
-
return bytes.fromhex(encoded_label.removeprefix("x")).decode()
|
| 1180 |
-
|
| 1181 |
-
@staticmethod
|
| 1182 |
-
def _get_col_name(field: str, idx: int) -> str:
|
| 1183 |
-
"""
|
| 1184 |
-
Convert a cypher return field to a pgsql select field
|
| 1185 |
-
If possible keep the cypher column name, but create a generic name if necessary
|
| 1186 |
-
|
| 1187 |
-
Args:
|
| 1188 |
-
field (str): a return field from a cypher query to be formatted for pgsql
|
| 1189 |
-
idx (int): the position of the field in the return statement
|
| 1190 |
-
|
| 1191 |
-
Returns:
|
| 1192 |
-
str: the field to be used in the pgsql select statement
|
| 1193 |
-
"""
|
| 1194 |
-
# remove white space
|
| 1195 |
-
field = field.strip()
|
| 1196 |
-
# if an alias is provided for the field, use it
|
| 1197 |
-
if " as " in field:
|
| 1198 |
-
return field.split(" as ")[-1].strip()
|
| 1199 |
-
# if the return value is an unnamed primitive, give it a generic name
|
| 1200 |
-
if field.isnumeric() or field in ("true", "false", "null"):
|
| 1201 |
-
return f"column_{idx}"
|
| 1202 |
-
# otherwise return the value stripping out some common special chars
|
| 1203 |
-
return field.replace("(", "_").replace(")", "")
|
| 1204 |
-
|
| 1205 |
async def _query(
|
| 1206 |
self,
|
| 1207 |
query: str,
|
|
@@ -1252,10 +1165,10 @@ class PGGraphStorage(BaseGraphStorage):
|
|
| 1252 |
return result
|
| 1253 |
|
| 1254 |
async def has_node(self, node_id: str) -> bool:
|
| 1255 |
-
entity_name_label =
|
| 1256 |
|
| 1257 |
query = """SELECT * FROM cypher('%s', $$
|
| 1258 |
-
MATCH (n:base {
|
| 1259 |
RETURN count(n) > 0 AS node_exists
|
| 1260 |
$$) AS (node_exists bool)""" % (self.graph_name, entity_name_label)
|
| 1261 |
|
|
@@ -1264,11 +1177,11 @@ class PGGraphStorage(BaseGraphStorage):
|
|
| 1264 |
return single_result["node_exists"]
|
| 1265 |
|
| 1266 |
async def has_edge(self, source_node_id: str, target_node_id: str) -> bool:
|
| 1267 |
-
src_label =
|
| 1268 |
-
tgt_label =
|
| 1269 |
|
| 1270 |
query = """SELECT * FROM cypher('%s', $$
|
| 1271 |
-
MATCH (a:base {
|
| 1272 |
RETURN COUNT(r) > 0 AS edge_exists
|
| 1273 |
$$) AS (edge_exists bool)""" % (
|
| 1274 |
self.graph_name,
|
|
@@ -1281,13 +1194,14 @@ class PGGraphStorage(BaseGraphStorage):
|
|
| 1281 |
return single_result["edge_exists"]
|
| 1282 |
|
| 1283 |
async def get_node(self, node_id: str) -> dict[str, str] | None:
|
| 1284 |
-
label =
|
| 1285 |
query = """SELECT * FROM cypher('%s', $$
|
| 1286 |
-
MATCH (n:base {
|
| 1287 |
RETURN n
|
| 1288 |
$$) AS (n agtype)""" % (self.graph_name, label)
|
| 1289 |
record = await self._query(query)
|
| 1290 |
if record:
|
|
|
|
| 1291 |
node = record[0]
|
| 1292 |
node_dict = node["n"]
|
| 1293 |
|
|
@@ -1295,10 +1209,10 @@ class PGGraphStorage(BaseGraphStorage):
|
|
| 1295 |
return None
|
| 1296 |
|
| 1297 |
async def node_degree(self, node_id: str) -> int:
|
| 1298 |
-
label =
|
| 1299 |
|
| 1300 |
query = """SELECT * FROM cypher('%s', $$
|
| 1301 |
-
MATCH (n:base {
|
| 1302 |
RETURN count(x) AS total_edge_count
|
| 1303 |
$$) AS (total_edge_count integer)""" % (self.graph_name, label)
|
| 1304 |
record = (await self._query(query))[0]
|
|
@@ -1322,11 +1236,11 @@ class PGGraphStorage(BaseGraphStorage):
|
|
| 1322 |
async def get_edge(
|
| 1323 |
self, source_node_id: str, target_node_id: str
|
| 1324 |
) -> dict[str, str] | None:
|
| 1325 |
-
src_label =
|
| 1326 |
-
tgt_label =
|
| 1327 |
|
| 1328 |
query = """SELECT * FROM cypher('%s', $$
|
| 1329 |
-
MATCH (a:base {
|
| 1330 |
RETURN properties(r) as edge_properties
|
| 1331 |
LIMIT 1
|
| 1332 |
$$) AS (edge_properties agtype)""" % (
|
|
@@ -1336,6 +1250,7 @@ class PGGraphStorage(BaseGraphStorage):
|
|
| 1336 |
)
|
| 1337 |
record = await self._query(query)
|
| 1338 |
if record and record[0] and record[0]["edge_properties"]:
|
|
|
|
| 1339 |
result = record[0]["edge_properties"]
|
| 1340 |
|
| 1341 |
return result
|
|
@@ -1345,10 +1260,10 @@ class PGGraphStorage(BaseGraphStorage):
|
|
| 1345 |
Retrieves all edges (relationships) for a particular node identified by its label.
|
| 1346 |
:return: list of dictionaries containing edge information
|
| 1347 |
"""
|
| 1348 |
-
label =
|
| 1349 |
|
| 1350 |
query = """SELECT * FROM cypher('%s', $$
|
| 1351 |
-
MATCH (n:base {
|
| 1352 |
OPTIONAL MATCH (n)-[]-(connected:base)
|
| 1353 |
RETURN n, connected
|
| 1354 |
$$) AS (n agtype, connected agtype)""" % (
|
|
@@ -1362,24 +1277,17 @@ class PGGraphStorage(BaseGraphStorage):
|
|
| 1362 |
source_node = record["n"] if record["n"] else None
|
| 1363 |
connected_node = record["connected"] if record["connected"] else None
|
| 1364 |
|
| 1365 |
-
|
| 1366 |
-
source_node
|
| 1367 |
-
|
| 1368 |
-
|
| 1369 |
-
|
| 1370 |
-
|
| 1371 |
-
|
| 1372 |
-
|
| 1373 |
-
else None
|
| 1374 |
-
)
|
| 1375 |
|
| 1376 |
-
|
| 1377 |
-
|
| 1378 |
-
(
|
| 1379 |
-
self._decode_graph_label(source_label),
|
| 1380 |
-
self._decode_graph_label(target_label),
|
| 1381 |
-
)
|
| 1382 |
-
)
|
| 1383 |
|
| 1384 |
return edges
|
| 1385 |
|
|
@@ -1389,17 +1297,17 @@ class PGGraphStorage(BaseGraphStorage):
|
|
| 1389 |
retry=retry_if_exception_type((PGGraphQueryException,)),
|
| 1390 |
)
|
| 1391 |
async def upsert_node(self, node_id: str, node_data: dict[str, str]) -> None:
|
| 1392 |
-
label =
|
| 1393 |
-
properties = node_data
|
| 1394 |
|
| 1395 |
query = """SELECT * FROM cypher('%s', $$
|
| 1396 |
-
MERGE (n:base {
|
| 1397 |
SET n += %s
|
| 1398 |
RETURN n
|
| 1399 |
$$) AS (n agtype)""" % (
|
| 1400 |
self.graph_name,
|
| 1401 |
label,
|
| 1402 |
-
|
| 1403 |
)
|
| 1404 |
|
| 1405 |
try:
|
|
@@ -1425,14 +1333,14 @@ class PGGraphStorage(BaseGraphStorage):
|
|
| 1425 |
target_node_id (str): Label of the target node (used as identifier)
|
| 1426 |
edge_data (dict): dictionary of properties to set on the edge
|
| 1427 |
"""
|
| 1428 |
-
src_label =
|
| 1429 |
-
tgt_label =
|
| 1430 |
-
edge_properties = edge_data
|
| 1431 |
|
| 1432 |
query = """SELECT * FROM cypher('%s', $$
|
| 1433 |
-
MATCH (source:base {
|
| 1434 |
WITH source
|
| 1435 |
-
MATCH (target:base {
|
| 1436 |
MERGE (source)-[r:DIRECTED]->(target)
|
| 1437 |
SET r += %s
|
| 1438 |
RETURN r
|
|
@@ -1440,7 +1348,7 @@ class PGGraphStorage(BaseGraphStorage):
|
|
| 1440 |
self.graph_name,
|
| 1441 |
src_label,
|
| 1442 |
tgt_label,
|
| 1443 |
-
|
| 1444 |
)
|
| 1445 |
|
| 1446 |
try:
|
|
@@ -1460,7 +1368,7 @@ class PGGraphStorage(BaseGraphStorage):
|
|
| 1460 |
Args:
|
| 1461 |
node_id (str): The ID of the node to delete.
|
| 1462 |
"""
|
| 1463 |
-
label =
|
| 1464 |
|
| 1465 |
query = """SELECT * FROM cypher('%s', $$
|
| 1466 |
MATCH (n:base {entity_id: "%s"})
|
|
@@ -1480,14 +1388,12 @@ class PGGraphStorage(BaseGraphStorage):
|
|
| 1480 |
Args:
|
| 1481 |
node_ids (list[str]): A list of node IDs to remove.
|
| 1482 |
"""
|
| 1483 |
-
|
| 1484 |
-
|
| 1485 |
-
]
|
| 1486 |
-
node_id_list = ", ".join([f'"{node_id}"' for node_id in encoded_node_ids])
|
| 1487 |
|
| 1488 |
query = """SELECT * FROM cypher('%s', $$
|
| 1489 |
MATCH (n:base)
|
| 1490 |
-
WHERE n.
|
| 1491 |
DETACH DELETE n
|
| 1492 |
$$) AS (n agtype)""" % (self.graph_name, node_id_list)
|
| 1493 |
|
|
@@ -1505,11 +1411,11 @@ class PGGraphStorage(BaseGraphStorage):
|
|
| 1505 |
edges (list[tuple[str, str]]): A list of edges to remove, where each edge is a tuple of (source_node_id, target_node_id).
|
| 1506 |
"""
|
| 1507 |
for source, target in edges:
|
| 1508 |
-
src_label =
|
| 1509 |
-
tgt_label =
|
| 1510 |
|
| 1511 |
query = """SELECT * FROM cypher('%s', $$
|
| 1512 |
-
MATCH (a:base {
|
| 1513 |
DELETE r
|
| 1514 |
$$) AS (r agtype)""" % (self.graph_name, src_label, tgt_label)
|
| 1515 |
|
|
@@ -1560,95 +1466,98 @@ class PGGraphStorage(BaseGraphStorage):
|
|
| 1560 |
return await embed_func()
|
| 1561 |
|
| 1562 |
async def get_knowledge_graph(
|
| 1563 |
-
self,
|
|
|
|
|
|
|
|
|
|
| 1564 |
) -> KnowledgeGraph:
|
| 1565 |
"""
|
| 1566 |
-
Retrieve a subgraph
|
| 1567 |
|
| 1568 |
Args:
|
| 1569 |
-
node_label
|
| 1570 |
-
max_depth
|
|
|
|
| 1571 |
|
| 1572 |
Returns:
|
| 1573 |
-
KnowledgeGraph
|
|
|
|
| 1574 |
"""
|
| 1575 |
-
MAX_GRAPH_NODES = 1000
|
| 1576 |
|
| 1577 |
# Build the query based on whether we want the full graph or a specific subgraph.
|
| 1578 |
if node_label == "*":
|
| 1579 |
query = f"""SELECT * FROM cypher('{self.graph_name}', $$
|
| 1580 |
-
|
| 1581 |
-
|
| 1582 |
-
|
| 1583 |
-
|
| 1584 |
-
|
| 1585 |
else:
|
| 1586 |
-
|
| 1587 |
query = f"""SELECT * FROM cypher('{self.graph_name}', $$
|
| 1588 |
-
|
| 1589 |
-
|
| 1590 |
-
|
| 1591 |
-
|
| 1592 |
-
|
| 1593 |
|
| 1594 |
results = await self._query(query)
|
| 1595 |
|
| 1596 |
-
|
| 1597 |
-
|
| 1598 |
-
|
| 1599 |
-
|
| 1600 |
-
|
| 1601 |
-
|
| 1602 |
-
|
| 1603 |
-
|
| 1604 |
-
|
| 1605 |
-
|
| 1606 |
-
|
| 1607 |
-
|
| 1608 |
-
edge_key = f"{src_id},{tgt_id}"
|
| 1609 |
-
if edge_key not in unique_edge_ids:
|
| 1610 |
-
unique_edge_ids.add(edge_key)
|
| 1611 |
-
edges.append(
|
| 1612 |
-
(
|
| 1613 |
-
edge_key,
|
| 1614 |
-
src_id,
|
| 1615 |
-
tgt_id,
|
| 1616 |
-
{"source": edge_data[0], "target": edge_data[2]},
|
| 1617 |
)
|
| 1618 |
-
|
| 1619 |
-
|
| 1620 |
-
|
| 1621 |
-
|
| 1622 |
-
|
| 1623 |
-
|
| 1624 |
-
|
| 1625 |
-
|
| 1626 |
-
|
| 1627 |
-
|
| 1628 |
-
|
| 1629 |
-
|
| 1630 |
-
|
| 1631 |
-
|
| 1632 |
-
|
| 1633 |
-
|
| 1634 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1635 |
|
| 1636 |
-
# Construct and return the KnowledgeGraph
|
| 1637 |
kg = KnowledgeGraph(
|
| 1638 |
-
nodes=
|
| 1639 |
-
|
| 1640 |
-
|
| 1641 |
-
],
|
| 1642 |
-
edges=[
|
| 1643 |
-
KnowledgeGraphEdge(
|
| 1644 |
-
id=edge_id,
|
| 1645 |
-
type="DIRECTED",
|
| 1646 |
-
source=src,
|
| 1647 |
-
target=tgt,
|
| 1648 |
-
properties=props,
|
| 1649 |
-
)
|
| 1650 |
-
for edge_id, src, tgt, props in edges
|
| 1651 |
-
],
|
| 1652 |
)
|
| 1653 |
|
| 1654 |
return kg
|
|
|
|
| 1064 |
if v.startswith("[") and v.endswith("]"):
|
| 1065 |
if "::vertex" in v:
|
| 1066 |
v = v.replace("::vertex", "")
|
| 1067 |
+
d[k] = json.loads(v)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1068 |
|
| 1069 |
elif "::edge" in v:
|
| 1070 |
v = v.replace("::edge", "")
|
| 1071 |
+
d[k] = json.loads(v)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1072 |
else:
|
| 1073 |
print("WARNING: unsupported type")
|
| 1074 |
continue
|
|
|
|
| 1077 |
dtype = v.split("::")[-1]
|
| 1078 |
v = v.split("::")[0]
|
| 1079 |
if dtype == "vertex":
|
| 1080 |
+
d[k] = json.loads(v)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1081 |
elif dtype == "edge":
|
| 1082 |
+
d[k] = json.loads(v)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1083 |
else:
|
| 1084 |
d[k] = (
|
| 1085 |
json.loads(v)
|
|
|
|
| 1115 |
)
|
| 1116 |
return "{" + ", ".join(props) + "}"
|
| 1117 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1118 |
async def _query(
|
| 1119 |
self,
|
| 1120 |
query: str,
|
|
|
|
| 1165 |
return result
|
| 1166 |
|
| 1167 |
async def has_node(self, node_id: str) -> bool:
|
| 1168 |
+
entity_name_label = node_id.strip('"')
|
| 1169 |
|
| 1170 |
query = """SELECT * FROM cypher('%s', $$
|
| 1171 |
+
MATCH (n:base {entity_id: "%s"})
|
| 1172 |
RETURN count(n) > 0 AS node_exists
|
| 1173 |
$$) AS (node_exists bool)""" % (self.graph_name, entity_name_label)
|
| 1174 |
|
|
|
|
| 1177 |
return single_result["node_exists"]
|
| 1178 |
|
| 1179 |
async def has_edge(self, source_node_id: str, target_node_id: str) -> bool:
|
| 1180 |
+
src_label = source_node_id.strip('"')
|
| 1181 |
+
tgt_label = target_node_id.strip('"')
|
| 1182 |
|
| 1183 |
query = """SELECT * FROM cypher('%s', $$
|
| 1184 |
+
MATCH (a:base {entity_id: "%s"})-[r]-(b:base {entity_id: "%s"})
|
| 1185 |
RETURN COUNT(r) > 0 AS edge_exists
|
| 1186 |
$$) AS (edge_exists bool)""" % (
|
| 1187 |
self.graph_name,
|
|
|
|
| 1194 |
return single_result["edge_exists"]
|
| 1195 |
|
| 1196 |
async def get_node(self, node_id: str) -> dict[str, str] | None:
|
| 1197 |
+
label = node_id.strip('"')
|
| 1198 |
query = """SELECT * FROM cypher('%s', $$
|
| 1199 |
+
MATCH (n:base {entity_id: "%s"})
|
| 1200 |
RETURN n
|
| 1201 |
$$) AS (n agtype)""" % (self.graph_name, label)
|
| 1202 |
record = await self._query(query)
|
| 1203 |
if record:
|
| 1204 |
+
print(f"Record: {record}")
|
| 1205 |
node = record[0]
|
| 1206 |
node_dict = node["n"]
|
| 1207 |
|
|
|
|
| 1209 |
return None
|
| 1210 |
|
| 1211 |
async def node_degree(self, node_id: str) -> int:
|
| 1212 |
+
label = node_id.strip('"')
|
| 1213 |
|
| 1214 |
query = """SELECT * FROM cypher('%s', $$
|
| 1215 |
+
MATCH (n:base {entity_id: "%s"})-[]->(x)
|
| 1216 |
RETURN count(x) AS total_edge_count
|
| 1217 |
$$) AS (total_edge_count integer)""" % (self.graph_name, label)
|
| 1218 |
record = (await self._query(query))[0]
|
|
|
|
| 1236 |
async def get_edge(
|
| 1237 |
self, source_node_id: str, target_node_id: str
|
| 1238 |
) -> dict[str, str] | None:
|
| 1239 |
+
src_label = source_node_id.strip('"')
|
| 1240 |
+
tgt_label = target_node_id.strip('"')
|
| 1241 |
|
| 1242 |
query = """SELECT * FROM cypher('%s', $$
|
| 1243 |
+
MATCH (a:base {entity_id: "%s"})-[r]->(b:base {entity_id: "%s"})
|
| 1244 |
RETURN properties(r) as edge_properties
|
| 1245 |
LIMIT 1
|
| 1246 |
$$) AS (edge_properties agtype)""" % (
|
|
|
|
| 1250 |
)
|
| 1251 |
record = await self._query(query)
|
| 1252 |
if record and record[0] and record[0]["edge_properties"]:
|
| 1253 |
+
print(f"Record: {record}")
|
| 1254 |
result = record[0]["edge_properties"]
|
| 1255 |
|
| 1256 |
return result
|
|
|
|
| 1260 |
Retrieves all edges (relationships) for a particular node identified by its label.
|
| 1261 |
:return: list of dictionaries containing edge information
|
| 1262 |
"""
|
| 1263 |
+
label = source_node_id.strip('"')
|
| 1264 |
|
| 1265 |
query = """SELECT * FROM cypher('%s', $$
|
| 1266 |
+
MATCH (n:base {entity_id: "%s"})
|
| 1267 |
OPTIONAL MATCH (n)-[]-(connected:base)
|
| 1268 |
RETURN n, connected
|
| 1269 |
$$) AS (n agtype, connected agtype)""" % (
|
|
|
|
| 1277 |
source_node = record["n"] if record["n"] else None
|
| 1278 |
connected_node = record["connected"] if record["connected"] else None
|
| 1279 |
|
| 1280 |
+
if (
|
| 1281 |
+
source_node
|
| 1282 |
+
and connected_node
|
| 1283 |
+
and "properties" in source_node
|
| 1284 |
+
and "properties" in connected_node
|
| 1285 |
+
):
|
| 1286 |
+
source_label = source_node["properties"].get("entity_id")
|
| 1287 |
+
target_label = connected_node["properties"].get("entity_id")
|
|
|
|
|
|
|
| 1288 |
|
| 1289 |
+
if source_label and target_label:
|
| 1290 |
+
edges.append((source_label, target_label))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1291 |
|
| 1292 |
return edges
|
| 1293 |
|
|
|
|
| 1297 |
retry=retry_if_exception_type((PGGraphQueryException,)),
|
| 1298 |
)
|
| 1299 |
async def upsert_node(self, node_id: str, node_data: dict[str, str]) -> None:
|
| 1300 |
+
label = node_id.strip('"')
|
| 1301 |
+
properties = self._format_properties(node_data)
|
| 1302 |
|
| 1303 |
query = """SELECT * FROM cypher('%s', $$
|
| 1304 |
+
MERGE (n:base {entity_id: "%s"})
|
| 1305 |
SET n += %s
|
| 1306 |
RETURN n
|
| 1307 |
$$) AS (n agtype)""" % (
|
| 1308 |
self.graph_name,
|
| 1309 |
label,
|
| 1310 |
+
properties,
|
| 1311 |
)
|
| 1312 |
|
| 1313 |
try:
|
|
|
|
| 1333 |
target_node_id (str): Label of the target node (used as identifier)
|
| 1334 |
edge_data (dict): dictionary of properties to set on the edge
|
| 1335 |
"""
|
| 1336 |
+
src_label = source_node_id.strip('"')
|
| 1337 |
+
tgt_label = target_node_id.strip('"')
|
| 1338 |
+
edge_properties = self._format_properties(edge_data)
|
| 1339 |
|
| 1340 |
query = """SELECT * FROM cypher('%s', $$
|
| 1341 |
+
MATCH (source:base {entity_id: "%s"})
|
| 1342 |
WITH source
|
| 1343 |
+
MATCH (target:base {entity_id: "%s"})
|
| 1344 |
MERGE (source)-[r:DIRECTED]->(target)
|
| 1345 |
SET r += %s
|
| 1346 |
RETURN r
|
|
|
|
| 1348 |
self.graph_name,
|
| 1349 |
src_label,
|
| 1350 |
tgt_label,
|
| 1351 |
+
edge_properties,
|
| 1352 |
)
|
| 1353 |
|
| 1354 |
try:
|
|
|
|
| 1368 |
Args:
|
| 1369 |
node_id (str): The ID of the node to delete.
|
| 1370 |
"""
|
| 1371 |
+
label = node_id.strip('"')
|
| 1372 |
|
| 1373 |
query = """SELECT * FROM cypher('%s', $$
|
| 1374 |
MATCH (n:base {entity_id: "%s"})
|
|
|
|
| 1388 |
Args:
|
| 1389 |
node_ids (list[str]): A list of node IDs to remove.
|
| 1390 |
"""
|
| 1391 |
+
node_ids = [node_id.strip('"') for node_id in node_ids]
|
| 1392 |
+
node_id_list = ", ".join([f'"{node_id}"' for node_id in node_ids])
|
|
|
|
|
|
|
| 1393 |
|
| 1394 |
query = """SELECT * FROM cypher('%s', $$
|
| 1395 |
MATCH (n:base)
|
| 1396 |
+
WHERE n.entity_id IN [%s]
|
| 1397 |
DETACH DELETE n
|
| 1398 |
$$) AS (n agtype)""" % (self.graph_name, node_id_list)
|
| 1399 |
|
|
|
|
| 1411 |
edges (list[tuple[str, str]]): A list of edges to remove, where each edge is a tuple of (source_node_id, target_node_id).
|
| 1412 |
"""
|
| 1413 |
for source, target in edges:
|
| 1414 |
+
src_label = source.strip('"')
|
| 1415 |
+
tgt_label = target.strip('"')
|
| 1416 |
|
| 1417 |
query = """SELECT * FROM cypher('%s', $$
|
| 1418 |
+
MATCH (a:base {entity_id: "%s"})-[r]->(b:base {entity_id: "%s"})
|
| 1419 |
DELETE r
|
| 1420 |
$$) AS (r agtype)""" % (self.graph_name, src_label, tgt_label)
|
| 1421 |
|
|
|
|
| 1466 |
return await embed_func()
|
| 1467 |
|
| 1468 |
async def get_knowledge_graph(
|
| 1469 |
+
self,
|
| 1470 |
+
node_label: str,
|
| 1471 |
+
max_depth: int = 3,
|
| 1472 |
+
max_nodes: int = MAX_GRAPH_NODES,
|
| 1473 |
) -> KnowledgeGraph:
|
| 1474 |
"""
|
| 1475 |
+
Retrieve a connected subgraph of nodes where the label includes the specified `node_label`.
|
| 1476 |
|
| 1477 |
Args:
|
| 1478 |
+
node_label: Label of the starting node, * means all nodes
|
| 1479 |
+
max_depth: Maximum depth of the subgraph, Defaults to 3
|
| 1480 |
+
max_nodes: Maximum nodes to return by BFS, Defaults to 1000
|
| 1481 |
|
| 1482 |
Returns:
|
| 1483 |
+
KnowledgeGraph object containing nodes and edges, with an is_truncated flag
|
| 1484 |
+
indicating whether the graph was truncated due to max_nodes limit
|
| 1485 |
"""
|
|
|
|
| 1486 |
|
| 1487 |
# Build the query based on whether we want the full graph or a specific subgraph.
|
| 1488 |
if node_label == "*":
|
| 1489 |
query = f"""SELECT * FROM cypher('{self.graph_name}', $$
|
| 1490 |
+
MATCH (n:base)
|
| 1491 |
+
OPTIONAL MATCH (n)-[r]->(target:base)
|
| 1492 |
+
RETURN collect(distinct n) AS n, collect(distinct r) AS r
|
| 1493 |
+
LIMIT {MAX_GRAPH_NODES}
|
| 1494 |
+
$$) AS (n agtype, r agtype)"""
|
| 1495 |
else:
|
| 1496 |
+
strip_label = node_label.strip('"')
|
| 1497 |
query = f"""SELECT * FROM cypher('{self.graph_name}', $$
|
| 1498 |
+
MATCH (n:base {{entity_id: "{strip_label}"}})
|
| 1499 |
+
OPTIONAL MATCH p = (n)-[*..{max_depth}]-(m)
|
| 1500 |
+
RETURN nodes(p) AS n, relationships(p) AS r
|
| 1501 |
+
LIMIT {max_nodes}
|
| 1502 |
+
$$) AS (n agtype, r agtype)"""
|
| 1503 |
|
| 1504 |
results = await self._query(query)
|
| 1505 |
|
| 1506 |
+
# Process the query results with deduplication by node and edge IDs
|
| 1507 |
+
nodes_dict = {}
|
| 1508 |
+
edges_dict = {}
|
| 1509 |
+
for result in results:
|
| 1510 |
+
# Handle single node cases
|
| 1511 |
+
if result.get("n") and isinstance(result["n"], dict):
|
| 1512 |
+
node_id = str(result["n"]["id"])
|
| 1513 |
+
if node_id not in nodes_dict:
|
| 1514 |
+
nodes_dict[node_id] = KnowledgeGraphNode(
|
| 1515 |
+
id=node_id,
|
| 1516 |
+
labels=[result["n"]["properties"]["entity_id"]],
|
| 1517 |
+
properties=result["n"]["properties"],
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1518 |
)
|
| 1519 |
+
# Handle node list cases
|
| 1520 |
+
elif result.get("n") and isinstance(result["n"], list):
|
| 1521 |
+
for node in result["n"]:
|
| 1522 |
+
if isinstance(node, dict) and "id" in node:
|
| 1523 |
+
node_id = str(node["id"])
|
| 1524 |
+
if node_id not in nodes_dict and "properties" in node:
|
| 1525 |
+
nodes_dict[node_id] = KnowledgeGraphNode(
|
| 1526 |
+
id=node_id,
|
| 1527 |
+
labels=[node["properties"]["entity_id"]],
|
| 1528 |
+
properties=node["properties"],
|
| 1529 |
+
)
|
| 1530 |
+
|
| 1531 |
+
# Handle single edge cases
|
| 1532 |
+
if result.get("r") and isinstance(result["r"], dict):
|
| 1533 |
+
edge_id = str(result["r"]["id"])
|
| 1534 |
+
if edge_id not in edges_dict:
|
| 1535 |
+
edges_dict[edge_id] = KnowledgeGraphEdge(
|
| 1536 |
+
id=edge_id,
|
| 1537 |
+
type="DIRECTED",
|
| 1538 |
+
source=str(result["r"]["start_id"]),
|
| 1539 |
+
target=str(result["r"]["end_id"]),
|
| 1540 |
+
properties=result["r"]["properties"],
|
| 1541 |
+
)
|
| 1542 |
+
# Handle edge list cases
|
| 1543 |
+
elif result.get("r") and isinstance(result["r"], list):
|
| 1544 |
+
for edge in result["r"]:
|
| 1545 |
+
if isinstance(edge, dict) and "id" in edge:
|
| 1546 |
+
edge_id = str(edge["id"])
|
| 1547 |
+
if edge_id not in edges_dict:
|
| 1548 |
+
edges_dict[edge_id] = KnowledgeGraphEdge(
|
| 1549 |
+
id=edge_id,
|
| 1550 |
+
type="DIRECTED",
|
| 1551 |
+
source=str(edge["start_id"]),
|
| 1552 |
+
target=str(edge["end_id"]),
|
| 1553 |
+
properties=edge["properties"],
|
| 1554 |
+
)
|
| 1555 |
|
| 1556 |
+
# Construct and return the KnowledgeGraph with deduplicated nodes and edges
|
| 1557 |
kg = KnowledgeGraph(
|
| 1558 |
+
nodes=list(nodes_dict.values()),
|
| 1559 |
+
edges=list(edges_dict.values()),
|
| 1560 |
+
is_truncated=False,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1561 |
)
|
| 1562 |
|
| 1563 |
return kg
|