Spaces:
Running
Running
import os | |
import html | |
from typing import Any, Union, cast, Optional | |
from dataclasses import dataclass | |
import networkx as nx | |
from graphgen.utils import logger | |
from .base_storage import BaseGraphStorage | |
class NetworkXStorage(BaseGraphStorage):
    """Graph storage backed by an in-memory networkx graph saved as GraphML.

    The graph lives in ``self._graph`` and is written to
    ``<working_dir>/<namespace>.graphml`` by ``index_done_callback``.

    NOTE(review): ``dataclass`` is imported by this module and the class
    defines ``__post_init__``, so it presumably relies on dataclass-style
    construction (its own ``@dataclass`` decorator or a dataclass base class)
    to provide ``working_dir`` / ``namespace`` and to invoke
    ``__post_init__`` -- confirm against ``BaseGraphStorage``.
    """

    @staticmethod
    def load_nx_graph(file_name) -> Optional[nx.Graph]:
        """Read a GraphML file, or return ``None`` if the file does not exist."""
        if os.path.exists(file_name):
            return nx.read_graphml(file_name)
        return None

    @staticmethod
    def write_nx_graph(graph: nx.Graph, file_name):
        """Log the size of ``graph`` and write it to ``file_name`` as GraphML."""
        logger.info(
            "Writing graph with %d nodes, %d edges",
            graph.number_of_nodes(),
            graph.number_of_edges(),
        )
        nx.write_graphml(graph, file_name)

    @staticmethod
    def stable_largest_connected_component(graph: nx.Graph) -> nx.Graph:
        """Refer to https://github.com/microsoft/graphrag/index/graph/utils/stable_lcc.py
        Return the largest connected component of the graph, with nodes and edges sorted in a stable way.

        Node labels are upper-cased, stripped and HTML-unescaped before the
        graph is stabilized, so every node label must be a string. The input
        graph is copied first and is not modified.
        """
        # Imported lazily so graspologic is only needed when this is called.
        from graspologic.utils import largest_connected_component

        graph = graph.copy()
        graph = cast(nx.Graph, largest_connected_component(graph))
        node_mapping = {
            node: html.unescape(node.upper().strip()) for node in graph.nodes()
        }  # type: ignore
        graph = nx.relabel_nodes(graph, node_mapping)
        return NetworkXStorage._stabilize_graph(graph)

    @staticmethod
    def _stabilize_graph(graph: nx.Graph) -> nx.Graph:
        """Refer to https://github.com/microsoft/graphrag/index/graph/utils/stable_lcc.py
        Ensure an undirected graph with the same relationships will always be read the same way.

        This is achieved by inserting nodes and edges into a new graph in
        sorted order.
        """
        fixed_graph = nx.DiGraph() if graph.is_directed() else nx.Graph()

        sorted_nodes = sorted(graph.nodes(data=True), key=lambda x: x[0])
        fixed_graph.add_nodes_from(sorted_nodes)

        edges = list(graph.edges(data=True))

        if not graph.is_directed():
            # For undirected graphs (a, b) and (b, a) are the same edge, so
            # orient every edge with the smaller endpoint first.
            def _sort_source_target(edge):
                source, target, edge_data = edge
                if source > target:
                    source, target = target, source
                return source, target, edge_data

            edges = [_sort_source_target(edge) for edge in edges]

        def _get_edge_key(source: Any, target: Any) -> str:
            return f"{source} -> {target}"

        edges = sorted(edges, key=lambda x: _get_edge_key(x[0], x[1]))
        fixed_graph.add_edges_from(edges)
        return fixed_graph

    def __post_init__(self):
        """
        Load the graph file if it exists; otherwise start with a new empty graph.
        """
        self._graphml_xml_file = os.path.join(
            self.working_dir, f"{self.namespace}.graphml"
        )
        preloaded_graph = NetworkXStorage.load_nx_graph(self._graphml_xml_file)
        if preloaded_graph is not None:
            logger.info(
                "Loaded graph from %s with %d nodes, %d edges",
                self._graphml_xml_file,
                preloaded_graph.number_of_nodes(),
                preloaded_graph.number_of_edges(),
            )
        # A networkx graph with zero nodes is falsy, so `preloaded_graph or
        # nx.Graph()` would silently discard a loaded-but-empty graph (and
        # with it its type and graph-level attributes). Test for None instead.
        self._graph = preloaded_graph if preloaded_graph is not None else nx.Graph()

    async def index_done_callback(self):
        """Write the current graph to its GraphML file."""
        NetworkXStorage.write_nx_graph(self._graph, self._graphml_xml_file)

    async def has_node(self, node_id: str) -> bool:
        """Return whether ``node_id`` is a node of the graph."""
        return self._graph.has_node(node_id)

    async def has_edge(self, source_node_id: str, target_node_id: str) -> bool:
        """Return whether an edge between the two nodes exists."""
        return self._graph.has_edge(source_node_id, target_node_id)

    async def get_node(self, node_id: str) -> Union[dict, None]:
        """Return the attribute dict of ``node_id``, or ``None`` if it is absent."""
        return self._graph.nodes.get(node_id)

    async def get_all_nodes(self) -> Union[list[dict], None]:
        """Return the graph's ``nodes(data=True)`` view.

        The value is the networkx view itself (``(node_id, attrs)`` pairs),
        not a copied list.
        """
        return self._graph.nodes(data=True)

    async def node_degree(self, node_id: str) -> int:
        """Return the degree of ``node_id``, or 0 if the node is not in the graph."""
        # `Graph.degree(n)` only yields an int when `n` is a node of the
        # graph; for an unknown node it hands back a degree view instead.
        if not self._graph.has_node(node_id):
            return 0
        return self._graph.degree(node_id)

    async def edge_degree(self, src_id: str, tgt_id: str) -> int:
        """Return the sum of the degrees of both endpoints.

        An endpoint that is not in the graph contributes 0.
        """
        graph = self._graph
        src_degree = graph.degree(src_id) if graph.has_node(src_id) else 0
        tgt_degree = graph.degree(tgt_id) if graph.has_node(tgt_id) else 0
        return src_degree + tgt_degree

    async def get_edge(
        self, source_node_id: str, target_node_id: str
    ) -> Union[dict, None]:
        """Return the attribute dict of the edge, or ``None`` if it is absent."""
        return self._graph.edges.get((source_node_id, target_node_id))

    async def get_all_edges(self) -> Union[list[dict], None]:
        """Return the graph's ``edges(data=True)`` view.

        The value is the networkx view itself (``(source, target, attrs)``
        triples), not a copied list.
        """
        return self._graph.edges(data=True)

    async def get_node_edges(self, source_node_id: str) -> Union[list[tuple[str, str]], None]:
        """Return the edges incident to ``source_node_id``.

        Each item is a ``(source, target, attrs)`` triple because the edges
        are requested with ``data=True``. Returns ``None`` when the node is
        not in the graph.
        """
        if self._graph.has_node(source_node_id):
            return list(self._graph.edges(source_node_id, data=True))
        return None

    async def get_graph(self) -> nx.Graph:
        """Return the underlying networkx graph object (not a copy)."""
        return self._graph

    async def upsert_node(self, node_id: str, node_data: dict[str, str]):
        """Add ``node_id`` with ``node_data`` as attributes, or update them if it exists."""
        self._graph.add_node(node_id, **node_data)

    async def update_node(self, node_id: str, node_data: dict[str, str]):
        """Merge ``node_data`` into an existing node; log a warning if it is missing."""
        if self._graph.has_node(node_id):
            self._graph.nodes[node_id].update(node_data)
        else:
            logger.warning("Node %s not found in the graph for update.", node_id)

    async def upsert_edge(
        self, source_node_id: str, target_node_id: str, edge_data: dict[str, str]
    ):
        """Add the edge with ``edge_data`` as attributes, or update them if it exists."""
        self._graph.add_edge(source_node_id, target_node_id, **edge_data)

    async def update_edge(self, source_node_id: str, target_node_id: str, edge_data: dict[str, str]):
        """Merge ``edge_data`` into an existing edge; log a warning if it is missing."""
        if self._graph.has_edge(source_node_id, target_node_id):
            self._graph.edges[(source_node_id, target_node_id)].update(edge_data)
        else:
            logger.warning(
                "Edge %s -> %s not found in the graph for update.",
                source_node_id,
                target_node_id,
            )

    async def delete_node(self, node_id: str):
        """
        Delete a node from the graph based on the specified node_id.

        Logs a warning instead of raising when the node does not exist.

        :param node_id: The node_id to delete
        """
        if self._graph.has_node(node_id):
            self._graph.remove_node(node_id)
            logger.info("Node %s deleted from the graph.", node_id)
        else:
            logger.warning("Node %s not found in the graph for deletion.", node_id)

    async def clear(self):
        """
        Clear the graph by removing all nodes and edges.
        """
        self._graph.clear()
        logger.info("Graph %s cleared.", self.namespace)