GraphGen / graphgen /models /vis /community_visualizer.py
github-actions[bot]
Auto-sync from demo at Thu Aug 28 10:06:44 UTC 2025
ba8b592
raw
history blame
1.3 kB
from dataclasses import dataclass
from typing import Dict
import matplotlib.pyplot as plt
import networkx as nx
@dataclass
class Visualizer:
"""
Class for visualizing graphs using NetworkX and Matplotlib.
"""
graph: nx.Graph = None
communities: Dict[str, int] = None
layout: str = "spring"
max_nodes: int = 1000
node_size: int = 10
alpha: float = 0.6
def visualize(self, save_path: str = None):
n = self.graph.number_of_nodes()
if self.layout == "spring":
k = max(0.1, 1.0 / (n**0.5))
pos = nx.spring_layout(self.graph, k=k, seed=42)
else:
raise ValueError(f"Unknown layout: {self.layout}")
plt.figure(figsize=(10, 10))
node_colors = [self.communities.get(node, 0) for node in self.graph.nodes()]
nx.draw_networkx_nodes(
self.graph,
pos,
node_size=self.node_size,
node_color=node_colors,
cmap=plt.cm.tab20,
alpha=self.alpha,
)
nx.draw_networkx_edges(self.graph, pos, alpha=0.3, width=0.2)
plt.axis("off")
if save_path:
plt.savefig(save_path, dpi=300, bbox_inches="tight")
print("Saved to", save_path)
else:
plt.show()