Spaces:
Runtime error
Runtime error
File size: 6,792 Bytes
1d80bec |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 |
import itertools
import json
import os
from collections import deque

import matplotlib.pyplot as plt
import networkx as nx
class Node:
    """A named tree node with an optional value, a parent and a set of children.

    A node's identity is its name: two nodes with the same name compare equal
    and hash the same, so they can serve as dict keys or networkx graph nodes.
    ``parent`` and the entries of ``children`` are stored exactly as given;
    ``build_tree_from_dict`` stores node *names* (strings) there, but ``Node``
    objects are accepted as well.
    """

    def __init__(self, name: str, value=None, parent=None, children: "list | None" = None):
        self.name = name
        # Always build a fresh set: the caller's list is never aliased, and the
        # default is None rather than [] to avoid a shared mutable default.
        self.children = set(children) if children is not None else set()
        self.parent = parent
        self.value = value

    def __repr__(self):
        return self.name

    def __str__(self):
        return self.name

    def __eq__(self, other):
        # Equality is by name only. For non-Node operands (e.g. `node == "a"`)
        # defer to Python instead of raising AttributeError on `other.name`.
        if not isinstance(other, Node):
            return NotImplemented
        return self.name == other.name

    def __hash__(self) -> int:
        return hash(self.name)

    def __getstate__(self):
        # Pickle/copy state: the plain instance attribute dictionary.
        return self.__dict__

    @staticmethod
    def _ref(obj):
        """Return ``obj.name`` if *obj* is a Node, otherwise *obj* unchanged."""
        return obj.name if isinstance(obj, Node) else obj

    def to_dict(self) -> dict:
        """Return a JSON-serialisable dict of the node's attributes.

        ``parent`` and every child are reduced to their names when they are
        ``Node`` objects, and the children set is emitted as a sorted list
        (sets and Node objects cannot be JSON-encoded). ``value`` is passed
        through unchanged.
        """
        return {
            "name": self.name,
            "children": sorted((self._ref(child) for child in self.children), key=str),
            "parent": self._ref(self.parent),
            "value": self.value,
        }

    def to_json(self, **kwargs) -> str:
        """Return a JSON string representation of the node (see ``to_dict``).

        Extra keyword arguments (e.g. ``indent=2``) are forwarded to
        ``json.dumps``.
        """
        return json.dumps(self.to_dict(), **kwargs)

    def add_child(self, child):
        """Add *child* (a name or a Node) to this node's children set."""
        self.children.add(child)

    def has_children(self) -> bool:
        """Return True if the node has at least one child."""
        return bool(self.children)

    def set_parent(self, new_parent):
        """Replace the node's parent (a name or a Node)."""
        self.parent = new_parent

    def set_value(self, new_value):
        """Replace the node's value."""
        self.value = new_value
def read_json(fname: str) -> dict:
    """Load a ``.json`` file and return its top-level content as a dict.

    Args:
        fname: path to the file, as ``str`` or any ``os.PathLike``. It must
            end with ``.json``.

    Returns:
        ``dict(data)`` of the decoded JSON document.

    Raises:
        AssertionError: if *fname* does not end with ``.json``.
        OSError: if the file cannot be opened.
        json.JSONDecodeError: if the file does not contain valid JSON.
        TypeError / ValueError: if the top-level JSON value cannot be
            converted with ``dict()``.
    """
    path = os.fspath(fname)
    # Explicit raise instead of `assert` so the check survives `python -O`;
    # AssertionError is kept so existing callers that catch it still work.
    if not path.endswith(".json"):
        raise AssertionError("File must be a json file")
    # JSON text is UTF-8 (RFC 8259); don't depend on the locale's default encoding.
    with open(path, "r", encoding="utf-8") as f:
        data = json.load(f)
    return dict(data)
def build_tree_from_dict(data: dict, connect_children: bool = True):
    """Build an undirected networkx graph and a name -> Node index from a dict.

    Every key of *data* is a node name; its value is a dict that may contain:

    * ``"value"``    - the node's value (default ``None``),
    * ``"parent"``   - the parent's name (default ``None``),
    * ``"children"`` - a list of the children's names (default: none).

    Args:
        data: mapping from node name to node description, as above.
        connect_children: if True, also add an edge between every pair of
            children of the same node (siblings).

    Returns:
        ``(G, nodes_dict)``: ``G`` is an ``nx.Graph`` whose nodes are ``Node``
        objects, each with a ``value`` node attribute; ``nodes_dict`` maps
        each name to its ``Node``.

    Raises:
        KeyError: if a node lists a child name that is not a key of *data*.
    """
    G = nx.Graph()
    nodes_dict = {}

    # Pass 1: create every node first so pass 2 can resolve child names.
    for name, info in data.items():
        value = info.get("value")
        node = Node(
            name=name,
            parent=info.get("parent"),
            children=info.get("children") or [],
            value=value,
        )
        nodes_dict[name] = node
        G.add_node(node, value=value)

    # Pass 2: parent-child edges, plus sibling edges if requested.
    for node in nodes_dict.values():
        children = [nodes_dict[child] for child in node.children]
        G.add_edges_from((node, child) for child in children)
        if connect_children:
            # The graph is undirected, so each unordered sibling pair needs
            # exactly one edge.
            G.add_edges_from(itertools.combinations(children, 2))
    return G, nodes_dict
def build_tree_from_file(fname: str, connect_children: bool = True):
    """Read a tree description from a ``.json`` file and build its graph.

    Combines ``read_json`` and ``build_tree_from_dict``; the file's top-level
    object maps node names to ``{"value", "parent", "children"}`` dicts.

    Args:
        fname: path to a ``.json`` file.
        connect_children: forwarded to ``build_tree_from_dict``; when True,
            siblings are also connected to each other.

    Returns:
        The ``(G, nodes_dict)`` pair returned by ``build_tree_from_dict``.
    """
    data = read_json(fname)
    return build_tree_from_dict(data, connect_children=connect_children)
def num_edges_between_nodes(G, node1, node2):
    """Return how many edges lie on a shortest path from *node1* to *node2*.

    This is an unweighted hop count (0 when both arguments are the same
    node). Any networkx error, e.g. when no path exists, propagates to the
    caller.
    """
    hops = nx.shortest_path_length(G, source=node1, target=node2)
    return hops
def explore_bfs(G: "nx.Graph", source: "Node", nodes_dict: "dict[str, Node]"):
    """Explore the nodes reachable from *source*, visiting valued nodes early.

    The traversal follows ``node.children`` (child names, looked up in
    *nodes_dict*) breadth-first. A child with a truthy ``value`` is pushed to
    the FRONT of the queue so it is explored next; any other child goes to
    the back, as in plain BFS. Each node is explored at most once, so shared
    children or cycles in the child links cannot cause repeats or an endless
    loop. Children are processed in sorted name order, so the result does not
    depend on set iteration order.

    Args:
        G: unused by this function; kept for interface compatibility.
        source: node to start from.
        nodes_dict: mapping from node name to node object.

    Returns:
        List of nodes in the order they were explored (``source`` first).

    Raises:
        KeyError: if a child name is missing from *nodes_dict*.
    """
    explored_nodes = []
    seen = set()
    # deque: O(1) push/pop at both ends (list.pop(0)/insert(0) are O(n)).
    queue = deque([source])
    while queue:
        node = queue.popleft()
        if node in seen:
            # A node may be queued more than once before it is explored.
            continue
        seen.add(node)
        explored_nodes.append(node)
        # `children` is a set of names; sort so the order does not depend on
        # per-process string-hash randomisation.
        for child_name in sorted(node.children, key=str):
            child = nodes_dict[child_name]
            if child.value:
                queue.appendleft(child)
            else:
                queue.append(child)
    return explored_nodes
def from_list(node_list: list[Node], directional=True):
    """Chain the given nodes into a path graph with numbered edges.

    Consecutive nodes are joined by an edge whose ``label`` attribute is its
    1-based position along the path. The edge from the first node to the
    second is labelled 1, and the final edge is labelled ``len(node_list) - 1``.

    Args:
        node_list: nodes in path order. Any iterable is accepted; it is
            materialised once, so generators work too.
        directional: if True, build an ``nx.DiGraph`` with edges pointing
            from each node to the next; otherwise build an undirected
            ``nx.Graph``.

    Returns:
        The new graph, with nodes added in the order given.
    """
    nodes = list(node_list)
    G = nx.DiGraph() if directional else nx.Graph()
    G.add_nodes_from(nodes)
    for label, (src, dst) in enumerate(zip(nodes, nodes[1:]), start=1):
        G.add_edge(src, dst, label=label)
    return G
def visualize_graph(
    graph: nx.Graph,
    layout_graph: nx.Graph,
    title="BFS Tree",
    fig_size=(30, 20),
    title_fontsize=20,
    edge_width=1,
    font_size=9,
    node_size=500,
    node_shape="o",
    prog="dot",
):
    """Draw *graph* on a new figure and display it with ``plt.show()``.

    Drawing is delegated to ``get_graph`` with exactly these arguments.
    Node positions come from a Graphviz layout of *layout_graph* (program
    *prog*, ``root="root"``), and edges are annotated with their ``label``
    attribute. Nodes are coloured by insertion order in *graph*: the first
    is light green, the last is red, and the rest are light blue.

    Args:
        graph: graph to draw.
        layout_graph: graph whose Graphviz layout supplies node positions.
        title, fig_size, title_fontsize: figure title and size.
        edge_width, font_size, node_size, node_shape: drawing options passed
            to networkx.
        prog: Graphviz program used for the layout (e.g. ``"dot"``).
    """
    get_graph(
        graph,
        layout_graph,
        title=title,
        fig_size=fig_size,
        title_fontsize=title_fontsize,
        edge_width=edge_width,
        font_size=font_size,
        node_size=node_size,
        node_shape=node_shape,
        prog=prog,
    )
    plt.show()
def get_graph(
    graph: nx.Graph,
    layout_graph: nx.Graph,
    title="BFS Tree",
    fig_size=(30, 20),
    title_fontsize=20,
    edge_width=1,
    font_size=9,
    node_size=500,
    node_shape="o",
    prog="dot",
):
    """Draw *graph*, including its edge ``label`` attributes, on a new figure.

    Node positions are taken from a Graphviz layout of *layout_graph*
    (program *prog*, ``root="root"``), which may be a different graph from
    *graph*. The positions are looked up for every node of *graph*. Nodes are
    coloured by their insertion order in *graph*: the first is light green,
    the last is red, and everything in between is light blue.

    Args:
        graph: graph to draw.
        layout_graph: graph whose Graphviz layout supplies node positions.
        title, fig_size, title_fontsize: figure title and size.
        edge_width, font_size, node_size, node_shape: drawing options passed
            to networkx (``node_shape`` is a matplotlib marker; ``"o"`` is a
            circle).
        prog: Graphviz program used for the layout (e.g. ``"dot"``).

    Returns:
        ``(fig, ax)`` of the new matplotlib figure; it is neither shown nor
        closed here.
    """
    graphviz_args = "-Goverlap=false -Gsplines=true -Gsep=0.1 -Gnodesep=0.1 -Gmaxiter=1000 -Gepsilon=0.0001 -Gstart=0"
    fig, ax = plt.subplots(figsize=fig_size)
    ax.set_title(title, fontsize=title_fontsize)

    # Run Graphviz once and reuse the positions for nodes and edge labels.
    pos = nx.nx_agraph.graphviz_layout(
        layout_graph, prog=prog, root="root", args=graphviz_args
    )

    # First node light green, last node red, everything in between light blue.
    n_nodes = len(graph.nodes)
    if n_nodes < 2:
        node_colors = ["lightgreen"]
    else:
        node_colors = ["lightgreen"] + ["lightblue"] * (n_nodes - 2) + ["red"]

    nx.draw(
        graph,
        ax=ax,
        pos=pos,
        with_labels=True,
        node_color=node_colors,
        edge_color="gray",
        width=edge_width,
        font_size=font_size,
        node_size=node_size,
        node_shape=node_shape,
    )
    nx.draw_networkx_edge_labels(
        graph,
        pos=pos,
        edge_labels=nx.get_edge_attributes(graph, "label"),
        font_size=font_size,
    )
    return fig, ax
|