yangdx committed on
Commit
80bef8b
·
1 Parent(s): 2bf79e3

Remove unused node embedding functionality from graph storage

Browse files
lightrag/kg/age_impl.py CHANGED
@@ -6,7 +6,6 @@ import sys
6
  from contextlib import asynccontextmanager
7
  from dataclasses import dataclass
8
  from typing import Any, Dict, List, NamedTuple, Optional, Union, final
9
- import numpy as np
10
  import pipmaster as pm
11
  from lightrag.types import KnowledgeGraph, KnowledgeGraphNode, KnowledgeGraphEdge
12
 
@@ -668,21 +667,6 @@ class AGEStorage(BaseGraphStorage):
668
  logger.error(f"Error during edge deletion: {str(e)}")
669
  raise
670
 
671
- async def embed_nodes(
672
- self, algorithm: str
673
- ) -> tuple[np.ndarray[Any, Any], list[str]]:
674
- """Embed nodes using the specified algorithm
675
-
676
- Args:
677
- algorithm: Name of the embedding algorithm
678
-
679
- Returns:
680
- tuple: (embedding matrix, list of node identifiers)
681
- """
682
- if algorithm not in self._node_embed_algorithms:
683
- raise ValueError(f"Node embedding algorithm {algorithm} not supported")
684
- return await self._node_embed_algorithms[algorithm]()
685
-
686
  async def get_all_labels(self) -> list[str]:
687
  """Get all node labels in the database
688
 
 
6
  from contextlib import asynccontextmanager
7
  from dataclasses import dataclass
8
  from typing import Any, Dict, List, NamedTuple, Optional, Union, final
 
9
  import pipmaster as pm
10
  from lightrag.types import KnowledgeGraph, KnowledgeGraphNode, KnowledgeGraphEdge
11
 
 
667
  logger.error(f"Error during edge deletion: {str(e)}")
668
  raise
669
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
670
  async def get_all_labels(self) -> list[str]:
671
  """Get all node labels in the database
672
 
lightrag/kg/gremlin_impl.py CHANGED
@@ -6,9 +6,6 @@ import pipmaster as pm
6
  from dataclasses import dataclass
7
  from typing import Any, Dict, List, final
8
 
9
- import numpy as np
10
-
11
-
12
  from tenacity import (
13
  retry,
14
  retry_if_exception_type,
@@ -419,27 +416,6 @@ class GremlinStorage(BaseGraphStorage):
419
  logger.error(f"Error during node deletion: {str(e)}")
420
  raise
421
 
422
- async def embed_nodes(
423
- self, algorithm: str
424
- ) -> tuple[np.ndarray[Any, Any], list[str]]:
425
- """
426
- Embed nodes using the specified algorithm.
427
- Currently, only node2vec is supported but never called.
428
-
429
- Args:
430
- algorithm: The name of the embedding algorithm to use
431
-
432
- Returns:
433
- A tuple of (embeddings, node_ids)
434
-
435
- Raises:
436
- NotImplementedError: If the specified algorithm is not supported
437
- ValueError: If the algorithm is not supported
438
- """
439
- if algorithm not in self._node_embed_algorithms:
440
- raise ValueError(f"Node embedding algorithm {algorithm} not supported")
441
- return await self._node_embed_algorithms[algorithm]()
442
-
443
  async def get_all_labels(self) -> list[str]:
444
  """
445
  Get all node entity_names in the graph
 
6
  from dataclasses import dataclass
7
  from typing import Any, Dict, List, final
8
 
 
 
 
9
  from tenacity import (
10
  retry,
11
  retry_if_exception_type,
 
416
  logger.error(f"Error during node deletion: {str(e)}")
417
  raise
418
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
419
  async def get_all_labels(self) -> list[str]:
420
  """
421
  Get all node entity_names in the graph
lightrag/kg/mongo_impl.py CHANGED
@@ -663,20 +663,6 @@ class MongoGraphStorage(BaseGraphStorage):
663
  # Remove the node doc
664
  await self.collection.delete_one({"_id": node_id})
665
 
666
- #
667
- # -------------------------------------------------------------------------
668
- # EMBEDDINGS (NOT IMPLEMENTED)
669
- # -------------------------------------------------------------------------
670
- #
671
-
672
- async def embed_nodes(
673
- self, algorithm: str
674
- ) -> tuple[np.ndarray[Any, Any], list[str]]:
675
- """
676
- Placeholder for demonstration, raises NotImplementedError.
677
- """
678
- raise NotImplementedError("Node embedding is not used in lightrag.")
679
-
680
  #
681
  # -------------------------------------------------------------------------
682
  # QUERY
 
663
  # Remove the node doc
664
  await self.collection.delete_one({"_id": node_id})
665
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
666
  #
667
  # -------------------------------------------------------------------------
668
  # QUERY
lightrag/kg/neo4j_impl.py CHANGED
@@ -2,8 +2,7 @@ import inspect
2
  import os
3
  import re
4
  from dataclasses import dataclass
5
- from typing import Any, final
6
- import numpy as np
7
  import configparser
8
 
9
 
@@ -1126,11 +1125,6 @@ class Neo4JStorage(BaseGraphStorage):
1126
  logger.error(f"Error during edge deletion: {str(e)}")
1127
  raise
1128
 
1129
- async def embed_nodes(
1130
- self, algorithm: str
1131
- ) -> tuple[np.ndarray[Any, Any], list[str]]:
1132
- raise NotImplementedError
1133
-
1134
  async def drop(self) -> dict[str, str]:
1135
  """Drop all data from storage and clean up resources
1136
 
 
2
  import os
3
  import re
4
  from dataclasses import dataclass
5
+ from typing import final
 
6
  import configparser
7
 
8
 
 
1125
  logger.error(f"Error during edge deletion: {str(e)}")
1126
  raise
1127
 
 
 
 
 
 
1128
  async def drop(self) -> dict[str, str]:
1129
  """Drop all data from storage and clean up resources
1130
 
lightrag/kg/networkx_impl.py CHANGED
@@ -1,7 +1,6 @@
1
  import os
2
  from dataclasses import dataclass
3
- from typing import Any, final
4
- import numpy as np
5
 
6
  from lightrag.types import KnowledgeGraph, KnowledgeGraphNode, KnowledgeGraphEdge
7
  from lightrag.utils import logger
@@ -16,7 +15,6 @@ if not pm.is_installed("graspologic"):
16
  pm.install("graspologic")
17
 
18
  import networkx as nx
19
- from graspologic import embed
20
  from .shared_storage import (
21
  get_storage_lock,
22
  get_update_flag,
@@ -42,40 +40,6 @@ class NetworkXStorage(BaseGraphStorage):
42
  )
43
  nx.write_graphml(graph, file_name)
44
 
45
- # TODO:deprecated, remove later
46
- @staticmethod
47
- def _stabilize_graph(graph: nx.Graph) -> nx.Graph:
48
- """Refer to https://github.com/microsoft/graphrag/index/graph/utils/stable_lcc.py
49
- Ensure an undirected graph with the same relationships will always be read the same way.
50
- """
51
- fixed_graph = nx.DiGraph() if graph.is_directed() else nx.Graph()
52
-
53
- sorted_nodes = graph.nodes(data=True)
54
- sorted_nodes = sorted(sorted_nodes, key=lambda x: x[0])
55
-
56
- fixed_graph.add_nodes_from(sorted_nodes)
57
- edges = list(graph.edges(data=True))
58
-
59
- if not graph.is_directed():
60
-
61
- def _sort_source_target(edge):
62
- source, target, edge_data = edge
63
- if source > target:
64
- temp = source
65
- source = target
66
- target = temp
67
- return source, target, edge_data
68
-
69
- edges = [_sort_source_target(edge) for edge in edges]
70
-
71
- def _get_edge_key(source: Any, target: Any) -> str:
72
- return f"{source} -> {target}"
73
-
74
- edges = sorted(edges, key=lambda x: _get_edge_key(x[0], x[1]))
75
-
76
- fixed_graph.add_edges_from(edges)
77
- return fixed_graph
78
-
79
  def __post_init__(self):
80
  self._graphml_xml_file = os.path.join(
81
  self.global_config["working_dir"], f"graph_{self.namespace}.graphml"
@@ -191,24 +155,6 @@ class NetworkXStorage(BaseGraphStorage):
191
  else:
192
  logger.warning(f"Node {node_id} not found in the graph for deletion.")
193
 
194
- # TODO: NOT USED
195
- async def embed_nodes(
196
- self, algorithm: str
197
- ) -> tuple[np.ndarray[Any, Any], list[str]]:
198
- if algorithm not in self._node_embed_algorithms:
199
- raise ValueError(f"Node embedding algorithm {algorithm} not supported")
200
- return await self._node_embed_algorithms[algorithm]()
201
-
202
- # TODO: NOT USED
203
- async def _node2vec_embed(self):
204
- graph = await self._get_graph()
205
- embeddings, nodes = embed.node2vec_embed(
206
- graph,
207
- **self.global_config["node2vec_params"],
208
- )
209
- nodes_ids = [graph.nodes[node_id]["id"] for node_id in nodes]
210
- return embeddings, nodes_ids
211
-
212
  async def remove_nodes(self, nodes: list[str]):
213
  """Delete multiple nodes
214
 
 
1
  import os
2
  from dataclasses import dataclass
3
+ from typing import final
 
4
 
5
  from lightrag.types import KnowledgeGraph, KnowledgeGraphNode, KnowledgeGraphEdge
6
  from lightrag.utils import logger
 
15
  pm.install("graspologic")
16
 
17
  import networkx as nx
 
18
  from .shared_storage import (
19
  get_storage_lock,
20
  get_update_flag,
 
40
  )
41
  nx.write_graphml(graph, file_name)
42
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
43
  def __post_init__(self):
44
  self._graphml_xml_file = os.path.join(
45
  self.global_config["working_dir"], f"graph_{self.namespace}.graphml"
 
155
  else:
156
  logger.warning(f"Node {node_id} not found in the graph for deletion.")
157
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
158
  async def remove_nodes(self, nodes: list[str]):
159
  """Delete multiple nodes
160
 
lightrag/kg/postgres_impl.py CHANGED
@@ -1485,24 +1485,6 @@ class PGGraphStorage(BaseGraphStorage):
1485
  labels = [result["label"] for result in results]
1486
  return labels
1487
 
1488
- async def embed_nodes(
1489
- self, algorithm: str
1490
- ) -> tuple[np.ndarray[Any, Any], list[str]]:
1491
- """
1492
- Generate node embeddings using the specified algorithm.
1493
-
1494
- Args:
1495
- algorithm (str): The name of the embedding algorithm to use.
1496
-
1497
- Returns:
1498
- tuple[np.ndarray[Any, Any], list[str]]: A tuple containing the embeddings and the corresponding node IDs.
1499
- """
1500
- if algorithm not in self._node_embed_algorithms:
1501
- raise ValueError(f"Unsupported embedding algorithm: {algorithm}")
1502
-
1503
- embed_func = self._node_embed_algorithms[algorithm]
1504
- return await embed_func()
1505
-
1506
  async def get_knowledge_graph(
1507
  self,
1508
  node_label: str,
 
1485
  labels = [result["label"] for result in results]
1486
  return labels
1487
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1488
  async def get_knowledge_graph(
1489
  self,
1490
  node_label: str,
lightrag/kg/tidb_impl.py CHANGED
@@ -800,13 +800,6 @@ class TiDBGraphStorage(BaseGraphStorage):
800
  }
801
  await self.db.execute(merge_sql, data)
802
 
803
- async def embed_nodes(
804
- self, algorithm: str
805
- ) -> tuple[np.ndarray[Any, Any], list[str]]:
806
- if algorithm not in self._node_embed_algorithms:
807
- raise ValueError(f"Node embedding algorithm {algorithm} not supported")
808
- return await self._node_embed_algorithms[algorithm]()
809
-
810
  # Query
811
 
812
  async def has_node(self, node_id: str) -> bool:
 
800
  }
801
  await self.db.execute(merge_sql, data)
802
 
 
 
 
 
 
 
 
803
  # Query
804
 
805
  async def has_node(self, node_id: str) -> bool: