yangdx commited on
Commit
b20527f
·
1 Parent(s): 3b80bfc

Fix get_node error for PostgreSQL graph storage

Browse files
lightrag/base.py CHANGED
@@ -297,13 +297,13 @@ class BaseGraphStorage(StorageNameSpace, ABC):
297
 
298
  @abstractmethod
299
  async def get_node(self, node_id: str) -> dict[str, str] | None:
300
- """Get an edge by its source and target node ids."""
301
 
302
  @abstractmethod
303
  async def get_edge(
304
  self, source_node_id: str, target_node_id: str
305
  ) -> dict[str, str] | None:
306
- """Get all edges connected to a node."""
307
 
308
  @abstractmethod
309
  async def get_node_edges(self, source_node_id: str) -> list[tuple[str, str]] | None:
 
297
 
298
  @abstractmethod
299
  async def get_node(self, node_id: str) -> dict[str, str] | None:
300
+ """Get node by its label identifier, return only node properties"""
301
 
302
  @abstractmethod
303
  async def get_edge(
304
  self, source_node_id: str, target_node_id: str
305
  ) -> dict[str, str] | None:
306
+ """Get edge properties between two nodes"""
307
 
308
  @abstractmethod
309
  async def get_node_edges(self, source_node_id: str) -> list[tuple[str, str]] | None:
lightrag/kg/neo4j_impl.py CHANGED
@@ -267,7 +267,7 @@ class Neo4JStorage(BaseGraphStorage):
267
  raise
268
 
269
  async def get_node(self, node_id: str) -> dict[str, str] | None:
270
- """Get node by its label identifier.
271
 
272
  Args:
273
  node_id: The node label to look up
 
267
  raise
268
 
269
  async def get_node(self, node_id: str) -> dict[str, str] | None:
270
+ """Get node by its label identifier, return only node properties
271
 
272
  Args:
273
  node_id: The node label to look up
lightrag/kg/postgres_impl.py CHANGED
@@ -1194,6 +1194,8 @@ class PGGraphStorage(BaseGraphStorage):
1194
  return single_result["edge_exists"]
1195
 
1196
  async def get_node(self, node_id: str) -> dict[str, str] | None:
 
 
1197
  label = node_id.strip('"')
1198
  query = """SELECT * FROM cypher('%s', $$
1199
  MATCH (n:base {entity_id: "%s"})
@@ -1202,7 +1204,7 @@ class PGGraphStorage(BaseGraphStorage):
1202
  record = await self._query(query)
1203
  if record:
1204
  node = record[0]
1205
- node_dict = node["n"]
1206
 
1207
  return node_dict
1208
  return None
@@ -1235,6 +1237,8 @@ class PGGraphStorage(BaseGraphStorage):
1235
  async def get_edge(
1236
  self, source_node_id: str, target_node_id: str
1237
  ) -> dict[str, str] | None:
 
 
1238
  src_label = source_node_id.strip('"')
1239
  tgt_label = target_node_id.strip('"')
1240
 
 
1194
  return single_result["edge_exists"]
1195
 
1196
  async def get_node(self, node_id: str) -> dict[str, str] | None:
1197
+ """Get node by its label identifier, return only node properties"""
1198
+
1199
  label = node_id.strip('"')
1200
  query = """SELECT * FROM cypher('%s', $$
1201
  MATCH (n:base {entity_id: "%s"})
 
1204
  record = await self._query(query)
1205
  if record:
1206
  node = record[0]
1207
+ node_dict = node["n"]["properties"]
1208
 
1209
  return node_dict
1210
  return None
 
1237
  async def get_edge(
1238
  self, source_node_id: str, target_node_id: str
1239
  ) -> dict[str, str] | None:
1240
+ """Get edge properties between two nodes"""
1241
+
1242
  src_label = source_node_id.strip('"')
1243
  tgt_label = target_node_id.strip('"')
1244