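# knee_report_checklist / treegraph.py
# Uploaded by Mahbodez in commit 1d80bec ("Upload 5 files"); file size 6.79 kB.
# (Header recovered from a file-hosting page; the viewer's "raw", "history",
# "blame", "contribute", "delete" links and "No virus" scan badge are omitted.)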
import json
from collections import deque

import matplotlib.pyplot as plt
import networkx as nx
class Node:
    """A named tree node that doubles as a hashable networkx graph vertex.

    Equality and hashing use only ``name``, so two ``Node`` objects with the
    same name are interchangeable as graph/dict keys.

    The instance ``__dict__`` (name, children, parent, value) is deliberately
    NOT shadowed by a method, so ``vars()``, ``pickle`` and ``copy`` work and
    ``__getstate__`` returns a real dict. Use ``to_dict()`` for a
    JSON-friendly view.

    Attributes:
        name: Node name; also what ``str()`` / ``repr()`` return.
        children: Set of child references. ``build_tree_from_dict`` fills it
            with child *names* (strings), not ``Node`` objects.
        parent: Parent reference; ``build_tree_from_dict`` stores the
            parent's name here.
        value: Arbitrary payload attached to the node.
    """

    def __init__(self, name: str, value=None, parent=None, children=None):
        self.name = name
        # Always build a fresh set: avoids the shared mutable-default pitfall
        # and never aliases the caller's container.
        self.children = set(children) if children is not None else set()
        self.parent = parent
        self.value = value

    def __repr__(self):
        return self.name

    def __str__(self):
        return self.name

    def __eq__(self, other):
        # Non-Node operands (e.g. a str key with the same hash in a dict of
        # Nodes) fall back to Python's default comparison instead of raising
        # AttributeError on ``other.name``.
        if not isinstance(other, Node):
            return NotImplemented
        return self.name == other.name

    def __hash__(self) -> int:
        return hash(self.name)

    def __getstate__(self):
        """Return the instance attribute dict as the pickle/copy state."""
        return self.__dict__

    @staticmethod
    def _ref_name(ref):
        """Reduce a child/parent reference to a name (a Node becomes its name)."""
        return ref.name if isinstance(ref, Node) else ref

    def to_dict(self) -> dict:
        """Return a JSON-serializable dict of the node's attributes.

        ``children`` becomes a sorted list (sets are not JSON-serializable and
        have no stable iteration order), and ``Node`` references in
        ``children``/``parent`` are reduced to their names.
        """
        return {
            "name": self.name,
            "children": sorted(
                (self._ref_name(child) for child in self.children), key=str
            ),
            "parent": self._ref_name(self.parent),
            "value": self.value,
        }

    def to_json(self):
        """
        Returns a JSON string representation of the node (built from ``to_dict``).
        """
        return json.dumps(self.to_dict())

    def add_child(self, child):
        """Add ``child`` to the children set (no-op if already present)."""
        self.children.add(child)

    def has_children(self):
        """Return True if the node has at least one child."""
        return len(self.children) > 0

    def set_parent(self, new_parent):
        """Replace the node's parent reference."""
        self.parent = new_parent

    def set_value(self, new_value):
        """Replace the node's payload value."""
        self.value = new_value
def read_json(fname: str) -> dict:
    """Load a ``.json`` file and return its top-level value as a dict.

    Args:
        fname: Path (str or path-like) to a file whose name ends in ``.json``.

    Returns:
        The decoded top-level JSON value, passed through ``dict(...)``.

    Raises:
        ValueError: if ``fname`` does not end with ``.json`` (raised
            explicitly because an ``assert`` is stripped under ``python -O``),
            or if the file is not valid JSON (``json.JSONDecodeError``).
    """
    if not str(fname).endswith(".json"):
        raise ValueError("File must be a json file")
    # JSON text is UTF-8 (RFC 8259); don't depend on the platform locale.
    with open(fname, "r", encoding="utf-8") as f:
        data = json.load(f)
    return dict(data)
def build_tree_from_dict(data: dict, connect_children: bool = True):
    """Build an undirected networkx graph from a name-keyed node description.

    Each key of ``data`` is a node name; each value is a mapping with the
    keys ``"value"`` (the node's payload), ``"parent"`` (the parent's name)
    and ``"children"`` (a list of child names). The resulting ``Node``
    objects keep those names in ``parent``/``children``, not ``Node``s.

    Args:
        data: Mapping of node name -> {"value", "parent", "children"}.
        connect_children: When True, children of the same node are also
            joined to each other by an edge, in addition to parent-child edges.

    Returns:
        ``(graph, nodes_by_name)``: the ``nx.Graph`` whose vertices are
        ``Node`` objects (each carrying a ``value`` node attribute), and a
        dict mapping every name to its ``Node``.

    Raises:
        KeyError: if an entry lacks one of the three keys, or lists a child
            name that has no entry of its own in ``data``.
    """
    graph = nx.Graph()
    nodes_by_name = {}

    # Pass 1: one Node per entry, registered in the graph with its value.
    for node_name, spec in data.items():
        node_value = spec["value"]
        new_node = Node(
            name=node_name,
            parent=spec["parent"],
            children=spec["children"],
            value=node_value,
        )
        nodes_by_name[node_name] = new_node
        graph.add_node(new_node, value=node_value)

    # Pass 2: parent->child edges, optionally followed by sibling edges.
    for parent_node in nodes_by_name.values():
        for child_name in parent_node.children:
            child_node = nodes_by_name[child_name]
            graph.add_edge(parent_node, child_node)
            if not connect_children:
                continue
            for sibling_name in parent_node.children:
                if sibling_name != child_name:
                    graph.add_edge(child_node, nodes_by_name[sibling_name])
    return graph, nodes_by_name
def build_tree_from_file(fname: str, connect_children: bool = True):
    """Read a node description from a ``.json`` file and build its graph.

    Thin wrapper: ``read_json(fname)`` followed by ``build_tree_from_dict``.

    Args:
        fname: Path to the ``.json`` file.
        connect_children: Forwarded to ``build_tree_from_dict``; when True
            (the previous, and still default, behavior) siblings are also
            linked to each other.

    Returns:
        The ``(graph, nodes_dict)`` tuple produced by ``build_tree_from_dict``.
    """
    data = read_json(fname)
    return build_tree_from_dict(data, connect_children=connect_children)
def num_edges_between_nodes(G, node1, node2):
    """Return the hop count of a shortest path from ``node1`` to ``node2``.

    Edges are treated as unweighted; any error networkx raises for a missing
    node or a disconnected pair propagates to the caller.
    """
    return nx.shortest_path_length(G, source=node1, target=node2)
def explore_bfs(G: nx.Graph, source: Node, nodes_dict: dict[str, Node]):
    """Walk the nodes reachable from ``source``, favouring valued children.

    Starting at ``source``, nodes are taken from the front of a frontier.
    A child whose ``value`` is truthy is pushed onto the FRONT of the
    frontier (explored next); other children go to the BACK (breadth-first).
    Each node is explored at most once, so malformed input with cycles or a
    node listed under two parents cannot loop forever or appear twice.

    Args:
        G: Not used by the traversal; kept for signature compatibility.
        source: Node to start from.
        nodes_dict: Maps the child names stored in ``Node.children`` to
            their ``Node`` objects.

    Returns:
        List of explored nodes in exploration order (``source`` first).

    Raises:
        KeyError: if a child name is missing from ``nodes_dict``.
    """
    explored_nodes = []
    seen = set()
    # deque: O(1) pushes/pops at both ends (list.pop(0)/insert(0) are O(n)).
    frontier = deque([source])
    while frontier:
        node = frontier.popleft()
        if node in seen:
            continue
        seen.add(node)
        explored_nodes.append(node)
        for child_name in node.children:
            child = nodes_dict[child_name]
            if child.value:
                frontier.appendleft(child)
            else:
                frontier.append(child)
    return explored_nodes
def from_list(node_list: list[Node], directional=True):
    """Chain ``node_list`` into a path graph with numbered edges.

    Every node is added, then each consecutive pair is joined by an edge
    whose ``label`` attribute is its 1-based position along the chain
    (first edge ``1`` ... last edge ``len(node_list) - 1``).

    Args:
        node_list: Nodes in the order they should be linked.
        directional: Build an ``nx.DiGraph`` when truthy, else an ``nx.Graph``.

    Returns:
        The newly built graph.
    """
    graph_cls = nx.DiGraph if directional else nx.Graph
    chain = graph_cls()
    chain.add_nodes_from(node_list)
    for label, (tail, head) in enumerate(zip(node_list, node_list[1:]), start=1):
        chain.add_edge(tail, head, label=label)
    return chain
def visualize_graph(
    graph: nx.Graph,
    layout_graph: nx.Graph,
    title="BFS Tree",
    fig_size=(30, 20),
    title_fontsize=20,
    edge_width=1,
    font_size=9,
    node_size=500,
    node_shape="o",
    prog="dot",
):
    """Draw ``graph`` with graphviz-computed positions and display it.

    Node positions come from running graphviz (``prog``) on ``layout_graph``
    with ``root="root"``, so ``layout_graph`` should contain every node of
    ``graph``. Nodes are coloured by their order in ``graph.nodes``: the
    first lightgreen, the last red, the rest lightblue. Edge ``"label"``
    attributes (as set by ``from_list``) are drawn on the edges.

    The drawing itself is delegated to ``get_graph`` (same parameters) so the
    two functions cannot drift apart; this function only adds ``plt.show()``.

    Args:
        graph: Graph to draw.
        layout_graph: Graph used only to compute node positions.
        title: Axes title.
        fig_size: ``figsize`` passed to ``plt.subplots``.
        title_fontsize: Font size of the title.
        edge_width: Line width of the edges.
        font_size: Font size for node labels and edge labels.
        node_size: Marker size used for every node.
        node_shape: Matplotlib marker for the nodes (``"o"`` by default).
        prog: Graphviz layout program.
    """
    get_graph(
        graph,
        layout_graph,
        title=title,
        fig_size=fig_size,
        title_fontsize=title_fontsize,
        edge_width=edge_width,
        font_size=font_size,
        node_size=node_size,
        node_shape=node_shape,
        prog=prog,
    )
    plt.show()
def _endpoint_node_colors(num_nodes: int) -> list:
    """Colours for a node sequence: first lightgreen, last red, rest lightblue."""
    if num_nodes == 0:
        return []
    if num_nodes == 1:
        return ["lightgreen"]
    return ["lightgreen"] + ["lightblue"] * (num_nodes - 2) + ["red"]


def get_graph(
    graph: nx.Graph,
    layout_graph: nx.Graph,
    title="BFS Tree",
    fig_size=(30, 20),
    title_fontsize=20,
    edge_width=1,
    font_size=9,
    node_size=500,
    node_shape="o",
    prog="dot",
):
    """Draw ``graph`` on a new matplotlib figure and return it without showing.

    Positions come from running graphviz (``prog``) on ``layout_graph`` with
    ``root="root"``, so ``layout_graph`` should contain every node of
    ``graph``. Nodes are coloured by their order in ``graph.nodes``: first
    lightgreen, last red, the rest lightblue. Edge ``"label"`` attributes
    (as produced by ``from_list``) are drawn on the edges.

    Args:
        graph: Graph to draw.
        layout_graph: Graph used only to compute node positions.
        title: Axes title.
        fig_size: ``figsize`` passed to ``plt.subplots``.
        title_fontsize: Font size of the title.
        edge_width: Line width of the edges.
        font_size: Font size for node labels and edge labels.
        node_size: Marker size used for every node.
        node_shape: Matplotlib marker for the nodes (``"o"`` by default).
        prog: Graphviz layout program.

    Returns:
        ``(fig, ax)`` from ``plt.subplots``; the caller shows/saves/closes it.
    """
    graphviz_args = (
        "-Goverlap=false -Gsplines=true -Gsep=0.1 -Gnodesep=0.1 "
        "-Gmaxiter=1000 -Gepsilon=0.0001 -Gstart=0"
    )
    fig, ax = plt.subplots(figsize=fig_size)
    ax.set_title(title, fontsize=title_fontsize)
    # Run graphviz once and reuse the positions for both nodes and edge labels.
    pos = nx.nx_agraph.graphviz_layout(
        layout_graph, prog=prog, root="root", args=graphviz_args
    )
    nx.draw(
        graph,
        pos=pos,
        ax=ax,
        with_labels=True,
        node_color=_endpoint_node_colors(len(graph.nodes)),
        edge_color="gray",
        width=edge_width,
        font_size=font_size,
        node_size=node_size,  # same fixed size for every node
        node_shape=node_shape,
    )
    nx.draw_networkx_edge_labels(
        graph,
        pos=pos,
        edge_labels=nx.get_edge_attributes(graph, "label"),
        font_size=font_size,
    )
    return fig, ax