inebrahim99 commited on
Commit
fb9573a
·
verified ·
1 Parent(s): 33ab0a2

Update PathRAG/storage.py

Browse files
Files changed (1) hide show
  1. PathRAG/storage.py +20 -50
PathRAG/storage.py CHANGED
@@ -154,7 +154,7 @@ class NanoVectorDBStorage(BaseVectorStorage):
154
  relations = [
155
  dp
156
  for dp in self.client_storage["data"]
157
- if dp["src_id"] == entity_name or dp["tgt_id"] == entity_name
158
  ]
159
  ids_to_delete = [relation["__id__"] for relation in relations]
160
 
@@ -181,10 +181,6 @@ class NetworkXStorage(BaseGraphStorage):
181
  if os.path.exists(file_name):
182
  return nx.read_graphml(file_name)
183
  return None
184
- # def load_nx_graph(file_name) -> nx.Graph:
185
- # if os.path.exists(file_name):
186
- # return nx.read_graphml(file_name)
187
- # return None
188
 
189
  @staticmethod
190
  def write_nx_graph(graph: nx.DiGraph, file_name):
@@ -195,49 +191,27 @@ class NetworkXStorage(BaseGraphStorage):
195
 
196
  @staticmethod
197
  def stable_largest_connected_component(graph: nx.Graph) -> nx.Graph:
198
- """Refer to https://github.com/microsoft/graphrag/index/graph/utils/stable_lcc.py
199
- Return the largest connected component of the graph, with nodes and edges sorted in a stable way.
200
- """
201
  from graspologic.utils import largest_connected_component
202
-
203
  graph = graph.copy()
204
  graph = cast(nx.Graph, largest_connected_component(graph))
205
  node_mapping = {
206
  node: html.unescape(node.upper().strip()) for node in graph.nodes()
207
- } # type: ignore
208
  graph = nx.relabel_nodes(graph, node_mapping)
209
  return NetworkXStorage._stabilize_graph(graph)
210
 
211
  @staticmethod
212
  def _stabilize_graph(graph: nx.Graph) -> nx.Graph:
213
- """Refer to https://github.com/microsoft/graphrag/index/graph/utils/stable_lcc.py
214
- Ensure an undirected graph with the same relationships will always be read the same way.
215
- """
216
  fixed_graph = nx.DiGraph() if graph.is_directed() else nx.Graph()
217
-
218
- sorted_nodes = graph.nodes(data=True)
219
- sorted_nodes = sorted(sorted_nodes, key=lambda x: x[0])
220
-
221
  fixed_graph.add_nodes_from(sorted_nodes)
 
222
  edges = list(graph.edges(data=True))
223
-
224
  if not graph.is_directed():
225
-
226
- def _sort_source_target(edge):
227
- source, target, edge_data = edge
228
- if source > target:
229
- temp = source
230
- source = target
231
- target = temp
232
- return source, target, edge_data
233
-
234
- edges = [_sort_source_target(edge) for edge in edges]
235
-
236
- def _get_edge_key(source: Any, target: Any) -> str:
237
- return f"{source} -> {target}"
238
-
239
- edges = sorted(edges, key=lambda x: _get_edge_key(x[0], x[1]))
240
-
241
  fixed_graph.add_edges_from(edges)
242
  return fixed_graph
243
 
@@ -282,21 +256,20 @@ class NetworkXStorage(BaseGraphStorage):
282
  if self._graph.has_node(source_node_id):
283
  return list(self._graph.edges(source_node_id))
284
  return None
 
285
  async def get_node_in_edges(self, source_node_id: str):
286
  if self._graph.has_node(source_node_id):
287
  return list(self._graph.in_edges(source_node_id))
288
  return None
 
289
  async def get_node_out_edges(self, source_node_id: str):
290
  if self._graph.has_node(source_node_id):
291
  return list(self._graph.out_edges(source_node_id))
292
  return None
293
 
294
- async def get_pagerank(self,source_node_id:str):
295
- pagerank_list=nx.pagerank(self._graph)
296
- if source_node_id in pagerank_list:
297
- return pagerank_list[source_node_id]
298
- else:
299
- print("pagerank failed")
300
 
301
  async def upsert_node(self, node_id: str, node_data: dict[str, str]):
302
  self._graph.add_node(node_id, **node_data)
@@ -307,11 +280,6 @@ class NetworkXStorage(BaseGraphStorage):
307
  self._graph.add_edge(source_node_id, target_node_id, **edge_data)
308
 
309
  async def delete_node(self, node_id: str):
310
- """
311
- Delete a node from the graph based on the specified node_id.
312
-
313
- :param node_id: The node_id to delete
314
- """
315
  if self._graph.has_node(node_id):
316
  self._graph.remove_node(node_id)
317
  logger.info(f"Node {node_id} deleted from the graph.")
@@ -323,19 +291,21 @@ class NetworkXStorage(BaseGraphStorage):
323
  raise ValueError(f"Node embedding algorithm {algorithm} not supported")
324
  return await self._node_embed_algorithms[algorithm]()
325
 
326
- # @TODO: NOT USED
327
  async def _node2vec_embed(self):
328
  from graspologic import embed
329
-
330
  embeddings, nodes = embed.node2vec_embed(
331
  self._graph,
332
  **self.global_config["node2vec_params"],
333
  )
334
-
335
  nodes_ids = [self._graph.nodes[node_id]["id"] for node_id in nodes]
336
  return embeddings, nodes_ids
337
 
338
- async def edges(self):
339
- return self._graph.edges()
340
  async def nodes(self):
 
341
  return self._graph.nodes()
 
 
 
 
 
154
  relations = [
155
  dp
156
  for dp in self.client_storage["data"]
157
+ if dp.get("src_id") == entity_name or dp.get("tgt_id") == entity_name
158
  ]
159
  ids_to_delete = [relation["__id__"] for relation in relations]
160
 
 
181
  if os.path.exists(file_name):
182
  return nx.read_graphml(file_name)
183
  return None
 
 
 
 
184
 
185
  @staticmethod
186
  def write_nx_graph(graph: nx.DiGraph, file_name):
 
191
 
192
  @staticmethod
193
  def stable_largest_connected_component(graph: nx.Graph) -> nx.Graph:
 
 
 
194
  from graspologic.utils import largest_connected_component
 
195
  graph = graph.copy()
196
  graph = cast(nx.Graph, largest_connected_component(graph))
197
  node_mapping = {
198
  node: html.unescape(node.upper().strip()) for node in graph.nodes()
199
+ }
200
  graph = nx.relabel_nodes(graph, node_mapping)
201
  return NetworkXStorage._stabilize_graph(graph)
202
 
203
  @staticmethod
204
  def _stabilize_graph(graph: nx.Graph) -> nx.Graph:
 
 
 
205
  fixed_graph = nx.DiGraph() if graph.is_directed() else nx.Graph()
206
+ sorted_nodes = sorted(list(graph.nodes(data=True)), key=lambda x: x[0])
 
 
 
207
  fixed_graph.add_nodes_from(sorted_nodes)
208
+
209
  edges = list(graph.edges(data=True))
 
210
  if not graph.is_directed():
211
+ edges = sorted(edges, key=lambda x: (min(x[0], x[1]), max(x[0], x[1])))
212
+ else:
213
+ edges = sorted(edges, key=lambda x: (x[0], x[1]))
214
+
 
 
 
 
 
 
 
 
 
 
 
 
215
  fixed_graph.add_edges_from(edges)
216
  return fixed_graph
217
 
 
256
  if self._graph.has_node(source_node_id):
257
  return list(self._graph.edges(source_node_id))
258
  return None
259
+
260
  async def get_node_in_edges(self, source_node_id: str):
261
  if self._graph.has_node(source_node_id):
262
  return list(self._graph.in_edges(source_node_id))
263
  return None
264
+
265
  async def get_node_out_edges(self, source_node_id: str):
266
  if self._graph.has_node(source_node_id):
267
  return list(self._graph.out_edges(source_node_id))
268
  return None
269
 
270
+ async def get_pagerank(self, source_node_id: str):
271
+ pagerank_list = nx.pagerank(self._graph)
272
+ return pagerank_list.get(source_node_id)
 
 
 
273
 
274
  async def upsert_node(self, node_id: str, node_data: dict[str, str]):
275
  self._graph.add_node(node_id, **node_data)
 
280
  self._graph.add_edge(source_node_id, target_node_id, **edge_data)
281
 
282
  async def delete_node(self, node_id: str):
 
 
 
 
 
283
  if self._graph.has_node(node_id):
284
  self._graph.remove_node(node_id)
285
  logger.info(f"Node {node_id} deleted from the graph.")
 
291
  raise ValueError(f"Node embedding algorithm {algorithm} not supported")
292
  return await self._node_embed_algorithms[algorithm]()
293
 
 
294
  async def _node2vec_embed(self):
295
  from graspologic import embed
 
296
  embeddings, nodes = embed.node2vec_embed(
297
  self._graph,
298
  **self.global_config["node2vec_params"],
299
  )
 
300
  nodes_ids = [self._graph.nodes[node_id]["id"] for node_id in nodes]
301
  return embeddings, nodes_ids
302
 
303
+ # --- CHANGE: Added missing methods ---
304
+ # These methods are required by operate.py for graph traversal.
305
  async def nodes(self):
306
+ """Returns all nodes in the graph."""
307
  return self._graph.nodes()
308
+
309
+ async def edges(self):
310
+ """Returns all edges in the graph."""
311
+ return self._graph.edges()