Spaces:
Runtime error
Runtime error
import itertools
import json
import os
from collections import deque

import matplotlib.pyplot as plt
import networkx as nx
class Node:
    """A named tree node holding a value, a parent reference and children.

    Equality and hashing use ``name`` only, so two ``Node`` objects with the
    same name are interchangeable as set members, dict keys or graph nodes.
    """

    def __init__(self, name: str, value=None, parent=None, children: "list | None" = None):
        self.name = name
        # Always build a fresh set; ``None`` replaces the former shared ``[]`` default.
        self.children = set(children) if children is not None else set()
        self.parent = parent
        self.value = value

    def __repr__(self):
        return self.name

    def __str__(self):
        return self.name

    def __eq__(self, other):
        # Let Python fall back to its default for non-Node operands instead of
        # raising AttributeError on ``other.name``.
        if not isinstance(other, Node):
            return NotImplemented
        return self.name == other.name

    def __hash__(self) -> int:
        return hash(self.name)

    # pickle / copy support: the state is the plain instance attribute dict.
    # (The class must not define a ``__dict__`` method: doing so shadowed the
    # real instance dict and broke ``vars()``, ``pickle`` and ``copy``.)
    def __getstate__(self):
        return self.__dict__

    def to_dict(self) -> dict:
        """Return a JSON-serializable dict of the node's attributes.

        ``children`` becomes a list sorted by ``str`` (a set is neither JSON
        serializable nor stably ordered). Any ``Node`` found in ``parent`` or
        ``children`` is reduced to its name; other values pass through as-is.
        """

        def ref(item):
            return item.name if isinstance(item, Node) else item

        return {
            "name": self.name,
            "children": sorted((ref(c) for c in self.children), key=str),
            "parent": ref(self.parent),
            "value": self.value,
        }

    def to_json(self) -> str:
        """Return a JSON string representation of the node.

        Raises:
            TypeError: If ``value`` (or a non-Node parent/child) is not JSON serializable.
        """
        return json.dumps(self.to_dict())

    def add_child(self, child):
        self.children.add(child)

    def has_children(self):
        return bool(self.children)

    def set_parent(self, new_parent):
        self.parent = new_parent

    def set_value(self, new_value):
        self.value = new_value
def read_json(fname: str) -> dict:
    """Load a ``.json`` file and return its top-level content as a dict.

    Args:
        fname: Path to the file, as ``str`` or any path-like object. The name
            must end with ``.json``.

    Returns:
        ``dict(data)`` of the parsed document.

    Raises:
        AssertionError: If the name does not end with ``.json``. It is raised
            explicitly (not via ``assert``), so the check survives ``python -O``
            while keeping the exception type existing callers may catch.
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If the content is not valid JSON.
    """
    path = os.fspath(fname)
    suffix = b".json" if isinstance(path, bytes) else ".json"
    if not path.endswith(suffix):
        raise AssertionError("File must be a json file")
    # JSON text is UTF-8; don't depend on the platform's locale encoding.
    with open(path, "r", encoding="utf-8") as f:
        data = json.load(f)
    return dict(data)
def build_tree_from_dict(data: dict, connect_children: bool = True):
    """Build an undirected networkx graph of ``Node`` objects from a dict.

    Expected schema, one entry per node::

        {name: {"value": ..., "parent": parent_name, "children": [child_name, ...]}}

    Each ``Node`` keeps ``parent`` and ``children`` as the names given in
    ``data``. Every node is added to the graph with a ``value`` attribute and
    linked to each of its children.

    Args:
        data: Mapping of node name to its info dict (schema above).
        connect_children: If True, also add an edge between every pair of
            siblings that share a parent.

    Returns:
        ``(G, nodes_dict)``, where ``G`` is an ``nx.Graph`` and ``nodes_dict``
        maps each name to its ``Node``.

    Raises:
        KeyError: If an info dict lacks "value", "parent" or "children", or if
            a child name has no entry of its own in ``data``.
    """
    G = nx.Graph()
    nodes_dict = {}

    # Pass 1: one Node (and one graph node) per entry.
    for name, info in data.items():
        value = info["value"]
        node = Node(name=name, parent=info["parent"], children=info["children"], value=value)
        nodes_dict[name] = node
        G.add_node(node, value=value)

    def resolve(child_name, parent_name):
        # Turn a bare KeyError into one that says which link is broken.
        try:
            return nodes_dict[child_name]
        except KeyError:
            raise KeyError(
                f"child {child_name!r} of node {parent_name!r} has no entry in data"
            ) from None

    # Pass 2: parent-child edges, plus sibling edges if requested.
    for node in nodes_dict.values():
        children = [resolve(child, node.name) for child in node.children]
        for child in children:
            G.add_edge(node, child)
        if connect_children:
            # nx.Graph is undirected, so each unordered sibling pair is enough.
            for first, second in itertools.combinations(children, 2):
                G.add_edge(first, second)
    return G, nodes_dict
def build_tree_from_file(fname: str, connect_children: bool = True):
    """Read ``fname`` with ``read_json`` and build the graph from its content.

    Args:
        fname: Path of a ``.json`` file in the schema ``build_tree_from_dict``
            expects.
        connect_children: Forwarded to ``build_tree_from_dict``. It defaults to
            True, matching the previous fixed behavior.

    Returns:
        ``(G, nodes_dict)`` as returned by ``build_tree_from_dict``.
    """
    data = read_json(fname)
    return build_tree_from_dict(data, connect_children=connect_children)
def num_edges_between_nodes(G, node1, node2):
    """Return the number of edges on a shortest (unweighted) path from node1 to node2.

    Uses networkx's path-length routine directly instead of building the path
    and subtracting one.

    Raises:
        networkx.NetworkXNoPath: If node2 is not reachable from node1.
        networkx.NodeNotFound: If either node is not in ``G``.
    """
    return nx.shortest_path_length(G, node1, node2)
def explore_bfs(G: "nx.Graph", source: "Node", nodes_dict: "dict[str, Node]"):
    """Explore from ``source`` along child links and return nodes in visit order.

    Children are looked up by name in ``nodes_dict``. The frontier is a
    double-ended queue: children with a truthy ``value`` are pushed to the
    front (explored next), others to the back. Valued nodes are therefore
    prioritized, and the plain breadth-first order applies among unvalued ones.

    Each node is returned at most once. This avoids duplicates when a node is
    reachable through several parents, and stops infinite loops if child links
    form a cycle.

    Args:
        G: Not used; kept for interface compatibility.
        source: Node to start from. It is always the first element returned.
        nodes_dict: Mapping from child name to node object. It must contain
            every child name reachable from ``source``.

    Returns:
        List of explored nodes in exploration order.

    Raises:
        KeyError: If a reachable child name is missing from ``nodes_dict``.
    """
    explored_nodes = []
    seen = set()
    frontier = deque([source])
    while frontier:
        node = frontier.popleft()
        if node in seen:
            continue
        seen.add(node)
        explored_nodes.append(node)
        # Sort names so the order doesn't depend on set iteration order
        # (string hashes are randomized per process).
        for child_name in sorted(node.children, key=str):
            child = nodes_dict[child_name]
            if child.value:
                frontier.appendleft(child)
            else:
                frontier.append(child)
    return explored_nodes
def from_list(node_list: list[Node], directional=True):
    """Build a path graph that chains the nodes of ``node_list`` in order.

    Each consecutive pair is joined by an edge whose ``label`` attribute is its
    1-based position: node_list[0] -> node_list[1] gets label 1,
    node_list[1] -> node_list[2] gets label 2, and so on.

    Args:
        node_list: Nodes in the order they should be chained.
        directional: Build an ``nx.DiGraph`` if True, otherwise an ``nx.Graph``.

    Returns:
        The constructed graph. Every node in ``node_list`` is included, even a
        lone one.
    """
    G = nx.DiGraph() if directional else nx.Graph()
    G.add_nodes_from(node_list)
    for label, (src, dst) in enumerate(zip(node_list, node_list[1:]), start=1):
        G.add_edge(src, dst, label=label)
    return G
def visualize_graph(
    graph: nx.Graph,
    layout_graph: nx.Graph,
    title="BFS Tree",
    fig_size=(30, 20),
    title_fontsize=20,
    edge_width=1,
    font_size=9,
    node_size=500,
    node_shape="o",
    prog="dot",
):
    """Draw ``graph`` via ``get_graph`` and display it with ``plt.show()``.

    This was previously a verbatim copy of ``get_graph``'s drawing code. It now
    delegates, so the two functions cannot drift apart. Positions come from a
    graphviz layout of ``layout_graph`` computed with program ``prog``; edge
    ``label`` attributes are drawn as edge labels.

    Args:
        graph: Graph whose nodes and edges are drawn.
        layout_graph: Graph used to compute node positions.
        title, fig_size, title_fontsize, edge_width, font_size, node_size,
        node_shape, prog: Forwarded unchanged to ``get_graph``.
    """
    get_graph(
        graph,
        layout_graph,
        title=title,
        fig_size=fig_size,
        title_fontsize=title_fontsize,
        edge_width=edge_width,
        font_size=font_size,
        node_size=node_size,
        node_shape=node_shape,
        prog=prog,
    )
    plt.show()
def _node_colors(count: int) -> list:
    """Return get_graph's per-node colors for ``count`` nodes, in node order.

    The first node is light green, the last red and the rest light blue.
    A single node is light green only; zero nodes yield an empty list.
    """
    if count <= 0:
        return []
    if count == 1:
        return ["lightgreen"]
    return ["lightgreen"] + ["lightblue"] * (count - 2) + ["red"]


def get_graph(
    graph: nx.Graph,
    layout_graph: nx.Graph,
    title="BFS Tree",
    fig_size=(30, 20),
    title_fontsize=20,
    edge_width=1,
    font_size=9,
    node_size=500,
    node_shape="o",
    prog="dot",
    root="root",
):
    """Draw ``graph`` on a new matplotlib figure and return ``(fig, ax)``.

    Node positions come from one graphviz layout of ``layout_graph``, reused
    for both the nodes and the edge labels. Every node of ``graph`` therefore
    needs a position in that layout. Edge ``label`` attributes are drawn as
    edge labels. Nodes are colored in ``graph.nodes`` order: the first light
    green, the last red and the rest light blue. All nodes share one size and
    one shape.

    Args:
        graph: Graph whose nodes and edges are drawn.
        layout_graph: Graph used to compute node positions.
        title: Axes title.
        fig_size: Figure size passed to ``plt.subplots``.
        title_fontsize: Font size of the title.
        edge_width: Line width of the edges.
        font_size: Font size of node and edge labels.
        node_size: Marker size applied to every node.
        node_shape: Matplotlib marker applied to every node (``"o"`` is a circle).
        prog: Graphviz program used for the layout.
        root: Passed to ``graphviz_layout`` as ``root`` (previously hard-coded
            to ``"root"``, which remains the default).

    Returns:
        ``(fig, ax)`` from ``plt.subplots``. The figure is not shown.
    """
    graphviz_args = (
        "-Goverlap=false -Gsplines=true -Gsep=0.1 -Gnodesep=0.1 "
        "-Gmaxiter=1000 -Gepsilon=0.0001 -Gstart=0"
    )
    fig, ax = plt.subplots(figsize=fig_size)
    ax.set_title(title, fontsize=title_fontsize)
    # Compute the layout once so nodes and edge labels share identical positions.
    pos = nx.nx_agraph.graphviz_layout(
        layout_graph, prog=prog, root=root, args=graphviz_args
    )
    nx.draw(
        graph,
        pos=pos,
        ax=ax,
        with_labels=True,
        node_color=_node_colors(graph.number_of_nodes()),
        edge_color="gray",
        width=edge_width,
        font_size=font_size,
        node_size=node_size,
        node_shape=node_shape,
    )
    nx.draw_networkx_edge_labels(
        graph,
        pos=pos,
        edge_labels=nx.get_edge_attributes(graph, "label"),
        font_size=font_size,
        ax=ax,  # draw on this figure's axes, not whatever is current
    )
    return fig, ax