Spaces:
Paused
Paused
File size: 6,003 Bytes
ab2ded1 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 |
#
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""
Reference:
- [graphrag](https://github.com/microsoft/graphrag)
"""
import logging
from typing import Any, cast, List
import html
from graspologic.partition import hierarchical_leiden
from graspologic.utils import largest_connected_component
import networkx as nx
from networkx import is_empty
log = logging.getLogger(__name__)
def _stabilize_graph(graph: nx.Graph) -> nx.Graph:
    """Rebuild ``graph`` with nodes and edges inserted in a deterministic order.

    Nodes are added sorted by their identifier. Edges are added sorted by the
    string ``"<source> -> <target>"``. The result is a new ``nx.DiGraph`` when
    the input is directed and a new ``nx.Graph`` otherwise; node and edge
    attribute data are carried over.
    """
    directed = graph.is_directed()
    stable = nx.DiGraph() if directed else nx.Graph()

    # Insert nodes in identifier order so iteration over the result is stable.
    stable.add_nodes_from(sorted(graph.nodes(data=True), key=lambda item: item[0]))

    edge_list = list(graph.edges(data=True))
    if not directed:
        # In an undirected graph (A, B) and (B, A) are the same edge, but the
        # orientation in which it is stored influences the node order that
        # consumers observe. Orient every edge as (smaller, larger) so the
        # same relationships always produce the same layout.
        edge_list = [
            (right, left, attrs) if left > right else (left, right, attrs)
            for left, right, attrs in edge_list
        ]

    # Stable sort on the textual edge key.
    edge_list.sort(key=lambda edge: f"{edge[0]} -> {edge[1]}")
    stable.add_edges_from(edge_list)
    return stable
def normalize_node_names(graph: nx.Graph | nx.DiGraph) -> nx.Graph | nx.DiGraph:
    """Relabel every node with a normalised form of its name.

    Each name is upper-cased, stripped of leading/trailing whitespace and then
    passed through ``html.unescape``. The graph returned by
    ``nx.relabel_nodes`` for that mapping is handed back to the caller.
    """
    # NOTE(review): the name is upper-cased *before* unescaping, so entities
    # that only exist in lower case (e.g. "&apos;") presumably stay encoded.
    # Confirm whether that is intended before changing the order.
    renamed = {}
    for name in graph.nodes():
        renamed[name] = html.unescape(name.upper().strip())  # type: ignore
    return nx.relabel_nodes(graph, renamed)
def stable_largest_connected_component(graph: nx.Graph) -> nx.Graph:
    """Extract the largest connected component in a reproducible layout.

    Works on a copy of ``graph``: takes its largest connected component,
    normalises the node names, and rebuilds the result with nodes and edges
    in a stable order.
    """
    component = largest_connected_component(graph.copy())
    component = cast(nx.Graph, component)
    return _stabilize_graph(normalize_node_names(component))
def _compute_leiden_communities(
    graph: nx.Graph | nx.DiGraph,
    max_cluster_size: int,
    use_lcc: bool,
    seed=0xDEADBEEF,
) -> dict[int, dict[str, int]]:
    """Run hierarchical Leiden and index the outcome by level.

    Args:
        graph: Graph to cluster.
        max_cluster_size: Forwarded to ``hierarchical_leiden``.
        use_lcc: When true, cluster only the stabilised largest connected
            component (whose node names have been normalised).
        seed: Forwarded to ``hierarchical_leiden`` as ``random_seed``.

    Returns:
        ``{level: {node: cluster}}``; an empty dict when ``is_empty(graph)``
        is true.
    """
    communities_by_level: dict[int, dict[str, int]] = {}
    if not is_empty(graph):
        target = stable_largest_connected_component(graph) if use_lcc else graph
        assignments = hierarchical_leiden(
            target, max_cluster_size=max_cluster_size, random_seed=seed
        )
        for entry in assignments:
            communities_by_level.setdefault(entry.level, {})[entry.node] = entry.cluster
    return communities_by_level
def run(graph: nx.Graph, args: dict[str, Any]) -> dict[int, dict[str, dict]]:
    """Cluster ``graph`` with hierarchical Leiden and group nodes per community.

    Args:
        graph: Graph to cluster. The node attributes ``rank`` (default 0) and
            ``weight`` (default 1) are multiplied and summed into the weight
            of the community a node belongs to.
        args: Options read with ``dict.get``:
            ``max_cluster_size`` (default 12), ``use_lcc`` (default True),
            ``verbose`` (default False), ``seed`` (default 0xDEADBEEF) and
            ``levels`` (default: every level that was produced, sorted).

    Returns:
        ``{level: {community_id: {"weight": ..., "nodes": [...]}}}`` where
        ``community_id`` is the cluster id converted to ``str``. Weights are
        divided by the largest weight of their level; when that maximum is 0
        they are left at 0. An empty dict is returned for a graph without
        nodes.

    Raises:
        KeyError: If a requested level is not among the computed levels.
    """
    max_cluster_size = args.get("max_cluster_size", 12)
    use_lcc = args.get("use_lcc", True)
    if args.get("verbose", False):
        log.info(
            "Running leiden with max_cluster_size=%s, lcc=%s", max_cluster_size, use_lcc
        )
    if not graph.nodes():
        return {}

    node_id_to_community_map = _compute_leiden_communities(
        graph=graph,
        max_cluster_size=max_cluster_size,
        use_lcc=use_lcc,
        seed=args.get("seed", 0xDEADBEEF),
    )

    levels = args.get("levels")
    # If they don't pass in levels, use them all
    if levels is None:
        levels = sorted(node_id_to_community_map.keys())

    results_by_level: dict[int, dict[str, dict]] = {}
    for level in levels:
        result: dict[str, dict] = {}
        results_by_level[level] = result
        for node_id, raw_community_id in node_id_to_community_map[level].items():
            community_id = str(raw_community_id)
            if community_id not in result:
                result[community_id] = {"weight": 0, "nodes": []}
            community = result[community_id]
            community["nodes"].append(node_id)
            # NOTE(review): with use_lcc the ids come from a graph whose node
            # names were normalised, but they are looked up in the original
            # ``graph`` here -- this presumably relies on the input names
            # already being normalised; confirm against callers.
            attrs = graph.nodes[node_id]
            community["weight"] += attrs.get("rank", 0) * attrs.get("weight", 1)

        # Normalise weights to the level's maximum. A maximum of 0 (e.g. no
        # node carries a "rank") would raise ZeroDivisionError, so leave the
        # weights untouched in that case.
        max_weight = max((comm["weight"] for comm in result.values()), default=0)
        if not max_weight:
            continue
        for comm in result.values():
            comm["weight"] /= max_weight
    return results_by_level
def add_community_info2graph(graph: nx.Graph, commu_info: dict[str, dict[str, dict]]):
    """Store each node's community id per level in its ``community`` attribute.

    ``commu_info`` maps a level to ``{community_id: {"nodes": [...]}}``. For
    every listed node, ``graph.nodes[node]["community"][level]`` is set to the
    community id, creating the ``community`` dict on first use. The graph is
    modified in place.
    """
    # NOTE: another function with this exact name is defined later in this
    # file; once the module is loaded that later definition replaces this one.
    for level, clusters in commu_info.items():
        for community_id, members in clusters.items():
            for node_id in members["nodes"]:
                graph.nodes[node_id].setdefault("community", {})[level] = community_id
def add_community_info2graph(graph: nx.Graph, nodes: List[str], community_title):
    """Append ``community_title`` to the ``communities`` list of each node.

    The ``communities`` attribute is created as an empty list on nodes that
    do not have it yet. The graph is modified in place.
    """
    for node_id in nodes:
        graph.nodes[node_id].setdefault("communities", []).append(community_title)
|