File size: 6,792 Bytes
1d80bec
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
import json
from collections import deque

import matplotlib.pyplot as plt
import networkx as nx


class Node:
    """A named tree node whose equality and hashing depend on ``name`` only.

    When created by ``build_tree_from_dict``, ``parent`` holds the parent's
    name and ``children`` holds child names. The class itself accepts any
    values for them, and ``add_child`` may also be given ``Node`` objects.
    """

    def __init__(self, name: str, value=None, parent=None, children: "list | None" = None):
        """Create a node.

        Args:
            name: Identifier of the node; also used as its ``str``/``repr``.
            value: Arbitrary payload.
            parent: The parent reference (by convention its name), or ``None``.
            children: Iterable of children, copied into a new ``set``. ``None``
                (the default) means no children.
        """
        self.name = name
        # Use ``None`` rather than a mutable ``[]`` default, and always build a
        # fresh set so nodes never share, or alias, the caller's container.
        self.children = set() if children is None else set(children)
        self.parent = parent
        self.value = value

    def __repr__(self):
        return self.name

    def __str__(self):
        return self.name

    def __eq__(self, other):
        # Duck-typed: any object with a ``name`` compares by name. For other
        # objects (e.g. plain strings), return NotImplemented so Python falls
        # back to its default comparison instead of raising AttributeError.
        try:
            other_name = other.name
        except AttributeError:
            return NotImplemented
        return self.name == other_name

    def __hash__(self) -> int:
        # Must agree with __eq__: equal names imply equal hashes.
        return hash(self.name)

    def __getstate__(self):
        """Return a shallow copy of the instance attributes (used by pickle/copy)."""
        return dict(vars(self))

    @staticmethod
    def _ref_name(ref):
        """Return ``ref.name`` if ``ref`` is a Node, otherwise ``ref`` unchanged."""
        return ref.name if isinstance(ref, Node) else ref

    def to_dict(self) -> dict:
        """Return a JSON-friendly dict of the node's attributes.

        ``children`` becomes a list sorted by string form, because sets are not
        JSON serializable and have no stable order. Node references in
        ``children`` and ``parent`` are replaced by their names.
        """
        return {
            "name": self.name,
            "children": sorted((self._ref_name(c) for c in self.children), key=str),
            "parent": self._ref_name(self.parent),
            "value": self.value,
        }

    def to_json(self):
        """Return a JSON string representation of the node.

        Raises:
            TypeError: If ``value`` (or a child/parent reference) is not JSON
                serializable.
        """
        return json.dumps(self.to_dict())

    def add_child(self, child):
        """Add ``child`` to this node's set of children."""
        self.children.add(child)

    def has_children(self):
        """Return True if this node has at least one child."""
        return len(self.children) > 0

    def set_parent(self, new_parent):
        """Replace this node's parent reference."""
        self.parent = new_parent

    def set_value(self, new_value):
        """Replace this node's value."""
        self.value = new_value


def read_json(fname: str) -> dict:
    """Load a ``.json`` file and return its parsed content as a ``dict``.

    Args:
        fname: Path to a file whose name ends with ``.json``.

    Returns:
        ``dict(data)``, where ``data`` is the parsed JSON document.

    Raises:
        ValueError: If ``fname`` does not end with ``.json``, or if the file
            content is not valid JSON (``json.JSONDecodeError``).
        OSError: If the file cannot be opened.
    """
    # An explicit raise still runs under ``python -O``, unlike ``assert``.
    if not fname.endswith(".json"):
        raise ValueError("File must be a json file")
    # JSON text is UTF-8 (RFC 8259), so don't depend on the platform locale.
    with open(fname, "r", encoding="utf-8") as f:
        data = json.load(f)
    return dict(data)


def build_tree_from_dict(data: dict, connect_children: bool = True):
    """Build an undirected networkx graph and a name -> Node index from ``data``.

    ``data`` maps each node's name to a dict with the keys:
      - ``"value"``: the node's value, also stored as the graph node's
        ``value`` attribute;
      - ``"parent"``: the parent's name;
      - ``"children"``: a list of the children's names.

    Every parent is joined to each of its children. When ``connect_children``
    is true, every pair of siblings is joined as well.

    Returns:
        ``(G, nodes_dict)``. ``G`` is an ``nx.Graph`` whose nodes are ``Node``
        objects, and ``nodes_dict`` maps each name to its ``Node``.

    Raises:
        KeyError: If an entry lacks one of the three keys, or if a child name
            has no entry of its own in ``data``.
    """
    graph = nx.Graph()
    nodes_dict = {}

    # Pass 1: create one Node per entry and register it in the index and the graph.
    for node_name, attrs in data.items():
        node_value = attrs["value"]
        parent_name = attrs["parent"]
        child_names = attrs["children"]
        new_node = Node(
            name=node_name, parent=parent_name, children=child_names, value=node_value
        )
        nodes_dict[node_name] = new_node
        graph.add_node(new_node, value=node_value)

    # Pass 2: add parent-child edges, each optionally followed by sibling edges.
    for parent_node in nodes_dict.values():
        siblings = parent_node.children
        for child_name in siblings:
            child_node = nodes_dict[child_name]
            graph.add_edge(parent_node, child_node)
            if not connect_children:
                continue
            for other_name in siblings:
                if other_name == child_name:
                    continue
                graph.add_edge(child_node, nodes_dict[other_name])

    return graph, nodes_dict


def build_tree_from_file(fname: str, connect_children: bool = True):
    """Read a JSON tree description from ``fname`` and build its graph.

    Args:
        fname: Path to a ``.json`` file in the format accepted by
            ``build_tree_from_dict``.
        connect_children: Forwarded to ``build_tree_from_dict``. When true
            (the previous, and still default, behavior), siblings are also
            joined by edges.

    Returns:
        The ``(graph, nodes_dict)`` pair returned by ``build_tree_from_dict``.
    """
    data = read_json(fname)
    return build_tree_from_dict(data, connect_children=connect_children)


def num_edges_between_nodes(G, node1, node2):
    """Count the edges on a shortest path from ``node1`` to ``node2`` in ``G``.

    Edges are treated as unweighted, so this is the hop distance; it is 0 when
    both arguments are the same node. Any networkx error for a missing node or
    an unreachable target propagates to the caller.
    """
    hop_count = nx.shortest_path_length(G, node1, node2)
    return hop_count


def explore_bfs(G: nx.Graph, source: Node, nodes_dict: dict[str, Node]):
    """Traverse from ``source``, exploring children with a truthy value first.

    Children are looked up by name in ``nodes_dict``. A child whose ``value``
    is truthy is pushed to the front of the work queue, so it is explored next.
    Any other child is appended to the back, which gives breadth-first order
    for those children.

    Args:
        G: Unused; kept so existing callers keep working.
        source: Node to start from.
        nodes_dict: Mapping from node name to ``Node``, e.g. the second value
            returned by ``build_tree_from_dict``.

    Returns:
        The nodes in the order they were explored. Each node name appears at
        most once, even if the input contains cycles or shared children.

    Raises:
        KeyError: If a child name is missing from ``nodes_dict``.
    """
    explored_nodes = []
    seen_names = set()
    # deque gives O(1) operations at both ends; list.pop(0) and insert(0, ...) are O(n).
    queue = deque([source])
    while queue:
        node = queue.popleft()
        # Without this guard, a cycle in the input would loop forever, and a
        # child listed under two parents would be explored twice.
        if node.name in seen_names:
            continue
        seen_names.add(node.name)
        explored_nodes.append(node)
        # ``children`` is a set, and set order for strings varies between
        # interpreter runs (hash randomization). Sort for a reproducible traversal.
        for child in sorted(node.children, key=str):
            child_node = nodes_dict[child]
            if child_node.value:
                queue.appendleft(child_node)
            else:
                queue.append(child_node)
    return explored_nodes


def from_list(node_list: list[Node], directional=True):
    """Build a path graph that chains the items of ``node_list`` in order.

    Each pair of consecutive nodes is joined by an edge whose ``label``
    attribute is that edge's 1-based position along the chain: the first edge
    is 1 and the last is ``len(node_list) - 1``.

    Args:
        node_list: Nodes to chain, in order.
        directional: If true, build an ``nx.DiGraph`` with edges pointing from
            each node to the next one; otherwise build an undirected ``nx.Graph``.

    Returns:
        The newly built graph.
    """
    graph_cls = nx.DiGraph if directional else nx.Graph
    path = graph_cls()
    path.add_nodes_from(node_list)
    for position in range(1, len(node_list)):
        tail = node_list[position - 1]
        head = node_list[position]
        path.add_edge(tail, head, label=position)
    return path


# Extra Graphviz command-line attributes shared by every layout call below.
_GRAPHVIZ_ARGS = (
    "-Goverlap=false -Gsplines=true -Gsep=0.1 -Gnodesep=0.1 "
    "-Gmaxiter=1000 -Gepsilon=0.0001 -Gstart=0"
)


def _node_colors(num_nodes: int) -> list[str]:
    """Return per-node colors: first node lightgreen, last red, others lightblue.

    Colors follow the graph's node iteration order. With fewer than two nodes,
    only ``["lightgreen"]`` is returned.
    """
    if num_nodes > 2:
        return ["lightgreen"] + ["lightblue"] * (num_nodes - 2) + ["red"]
    if num_nodes == 2:
        return ["lightgreen", "red"]
    return ["lightgreen"]


def _draw_graph(
    graph,
    layout_graph,
    *,
    title,
    fig_size,
    title_fontsize,
    edge_width,
    font_size,
    node_size,
    node_shape,
    prog,
    root,
):
    """Draw ``graph`` on a new figure, positioned by a Graphviz layout of ``layout_graph``.

    Returns the ``(fig, ax)`` pair created by ``plt.subplots``.
    """
    fig, ax = plt.subplots(figsize=fig_size)
    ax.set_title(title, fontsize=title_fontsize)
    # Run the external Graphviz layout once, and reuse the positions for both
    # the node/edge pass and the edge-label pass so the labels line up.
    pos = nx.nx_agraph.graphviz_layout(
        layout_graph, prog=prog, root=root, args=_GRAPHVIZ_ARGS
    )
    nx.draw(
        graph,
        pos=pos,
        ax=ax,
        with_labels=True,
        node_color=_node_colors(len(graph.nodes)),
        edge_color="gray",
        width=edge_width,
        font_size=font_size,
        node_size=node_size,  # the same size for every node
        node_shape=node_shape,  # matplotlib marker code; the default "o" is a circle
    )
    # Edge labels come from each edge's "label" attribute (e.g. as set by from_list).
    nx.draw_networkx_edge_labels(
        graph,
        pos=pos,
        edge_labels=nx.get_edge_attributes(graph, "label"),
        font_size=font_size,
        ax=ax,
    )
    return fig, ax


def visualize_graph(
    graph: nx.Graph,
    layout_graph: nx.Graph,
    title="BFS Tree",
    fig_size=(30, 20),
    title_fontsize=20,
    edge_width=1,
    font_size=9,
    node_size=500,
    node_shape="o",
    prog="dot",
    root="root",
):
    """Draw ``graph`` (see ``get_graph``) and display it with ``plt.show()``.

    Takes the same parameters as ``get_graph``. Returns None.
    """
    _draw_graph(
        graph,
        layout_graph,
        title=title,
        fig_size=fig_size,
        title_fontsize=title_fontsize,
        edge_width=edge_width,
        font_size=font_size,
        node_size=node_size,
        node_shape=node_shape,
        prog=prog,
        root=root,
    )
    plt.show()


def get_graph(
    graph: nx.Graph,
    layout_graph: nx.Graph,
    title="BFS Tree",
    fig_size=(30, 20),
    title_fontsize=20,
    edge_width=1,
    font_size=9,
    node_size=500,
    node_shape="o",
    prog="dot",
    root="root",
):
    """Draw ``graph`` on a new matplotlib figure and return ``(fig, ax)``.

    Node positions come from a Graphviz layout of ``layout_graph``, which may
    differ from ``graph`` (e.g. a denser graph used only for placement). The
    first node in iteration order is drawn lightgreen, the last red, and the
    rest lightblue. Edge ``label`` attributes are drawn as edge labels.

    Args:
        graph: Graph to draw.
        layout_graph: Graph passed to ``nx.nx_agraph.graphviz_layout``.
        title: Axes title.
        fig_size: Figure size passed to ``plt.subplots``.
        title_fontsize: Font size of the title.
        edge_width: Line width of the edges.
        font_size: Font size of node and edge labels.
        node_size: Marker size used for every node.
        node_shape: Matplotlib marker code for nodes.
        prog: Graphviz program used for the layout (e.g. ``"dot"``).
        root: Forwarded as ``graphviz_layout(root=...)``. The default
            ``"root"`` matches the previously hard-coded value.

    Returns:
        The ``(fig, ax)`` pair holding the drawing.
    """
    return _draw_graph(
        graph,
        layout_graph,
        title=title,
        fig_size=fig_size,
        title_fontsize=title_fontsize,
        edge_width=edge_width,
        font_size=font_size,
        node_size=node_size,
        node_shape=node_shape,
        prog=prog,
        root=root,
    )