cryptocalypse commited on
Commit
a4844a1
1 Parent(s): 7915d45

Create psychohistory.py

Browse files
Files changed (1) hide show
  1. psychohistory.py +260 -0
psychohistory.py ADDED
@@ -0,0 +1,260 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import matplotlib.pyplot as plt
2
+ from mpl_toolkits.mplot3d import Axes3D
3
+ import networkx as nx
4
+ import random
5
+ import numpy as np
6
+ import json
7
+ import sys
8
+
9
+ def generate_tree(current_x, current_y, depth, max_depth, max_nodes, x_range, G, parent=None, node_count_per_depth=None):
10
+ """Generates a tree of nodes with positions adjusted on the x-axis, and the number of nodes on the z-axis."""
11
+ if node_count_per_depth is None:
12
+ node_count_per_depth = {}
13
+
14
+ if depth not in node_count_per_depth:
15
+ node_count_per_depth[depth] = 0
16
+
17
+ if depth > max_depth:
18
+ return node_count_per_depth
19
+
20
+ num_children = random.randint(1, max_nodes)
21
+ x_positions = [current_x + i * x_range / (num_children + 1) for i in range(num_children)]
22
+
23
+ for x in x_positions:
24
+ # Add node to the graph
25
+ node_id = len(G.nodes)
26
+ node_count_per_depth[depth] += 1
27
+ prob = random.uniform(0, 1) # Assign random probability
28
+ G.add_node(node_id, pos=(x, prob, depth)) # Use `depth` for the z position
29
+ if parent is not None:
30
+ G.add_edge(parent, node_id)
31
+ # Recursively add child nodes
32
+ generate_tree(x, current_y + 1, depth + 1, max_depth, max_nodes, x_range, G, parent=node_id, node_count_per_depth=node_count_per_depth)
33
+
34
+ return node_count_per_depth
35
+
36
+ def build_graph_from_json(json_data, G):
37
+ """Builds a graph from JSON data."""
38
+ def add_event(parent_id, event_data, prob_level):
39
+ for key, value in event_data.get('events', {}).items():
40
+ # Add node
41
+ node_id = len(G.nodes)
42
+ prob = {'high_probability': 0.9, 'medium_probability': 0.5, 'low_probability': 0.1}[prob_level]
43
+ G.add_node(node_id, pos=(len(G.nodes), prob, len(G.nodes))) # Ensure each node has 'pos'
44
+ G.add_edge(parent_id, node_id)
45
+ # Add child events
46
+ add_event(node_id, {'events': value}, key)
47
+
48
+ root_id = len(G.nodes)
49
+ G.add_node(root_id, pos=(0, 0.5, 0)) # Root node with default medium probability
50
+ if len(G.nodes) > 1:
51
+ G.add_edge(-1, root_id) # Root node without a parent
52
+ data = json.loads(json_data)
53
+ add_event(root_id, data, 'medium_probability')
54
+
55
+ def find_paths(G):
56
+ """Finds the paths with the highest and lowest average probability, and the maximum and minimum duration in graph G."""
57
+ best_path = None
58
+ worst_path = None
59
+ longest_duration_path = None
60
+ shortest_duration_path = None
61
+ best_mean_prob = -1
62
+ worst_mean_prob = float('inf')
63
+ max_duration = -1
64
+ min_duration = float('inf')
65
+
66
+ for source in G.nodes:
67
+ for target in G.nodes:
68
+ if source != target:
69
+ all_paths = list(nx.all_simple_paths(G, source=source, target=target))
70
+ for path in all_paths:
71
+ # Check if all nodes in the path have the 'pos' attribute
72
+ if not all('pos' in G.nodes[node] for node in path):
73
+ continue # Skip paths with nodes missing the 'pos' attribute
74
+
75
+ # Calculate the average probability of the path
76
+ probabilities = [G.nodes[node]['pos'][1] for node in path] # Get probabilities of the nodes in the path
77
+ mean_prob = np.mean(probabilities)
78
+
79
+ # Evaluate the path with the highest average probability
80
+ if mean_prob > best_mean_prob:
81
+ best_mean_prob = mean_prob
82
+ best_path = path
83
+
84
+ # Evaluate the path with the lowest average probability
85
+ if mean_prob < worst_mean_prob:
86
+ worst_mean_prob = mean_prob
87
+ worst_path = path
88
+
89
+ # Calculate the duration of the path
90
+ x_positions = [G.nodes[node]['pos'][0] for node in path]
91
+ duration = max(x_positions) - min(x_positions)
92
+
93
+ # Evaluate the path with the maximum duration
94
+ if duration > max_duration:
95
+ max_duration = duration
96
+ longest_duration_path = path
97
+
98
+ # Evaluate the path with the minimum duration
99
+ if duration < min_duration:
100
+ min_duration = duration
101
+ shortest_duration_path = path
102
+
103
+ return best_path, best_mean_prob, worst_path, worst_mean_prob, longest_duration_path, shortest_duration_path
104
+
105
+ def draw_path_3d(G, path, filename='path_plot_3d.png', highlight_color='blue'):
106
+ """Draws only the specific path in 3D using networkx and matplotlib and saves the figure to a file."""
107
+ # Create a subgraph containing only the nodes and edges of the path
108
+ H = G.subgraph(path).copy()
109
+
110
+ pos = nx.get_node_attributes(G, 'pos')
111
+
112
+ # Get data for 3D visualization
113
+ x_vals, y_vals, z_vals = zip(*[pos[node] for node in path])
114
+
115
+ fig = plt.figure(figsize=(16, 12))
116
+ ax = fig.add_subplot(111, projection='3d')
117
+
118
+ # Assign colors to the nodes based on probability
119
+ node_colors = []
120
+ for node in path:
121
+ prob = G.nodes[node]['pos'][1]
122
+ if prob < 0.33:
123
+ node_colors.append('red')
124
+ elif prob < 0.67:
125
+ node_colors.append('blue')
126
+ else:
127
+ node_colors.append('green')
128
+
129
+ # Draw nodes
130
+ ax.scatter(x_vals, y_vals, z_vals, c=node_colors, s=700, edgecolors='black', alpha=0.7)
131
+
132
+ # Draw edges
133
+ for edge in H.edges():
134
+ x_start, y_start, z_start = pos[edge[0]]
135
+ x_end, y_end, z_end = pos[edge[1]]
136
+ ax.plot([x_start, x_end], [y_start, y_end], [z_start, z_end], color=highlight_color, lw=2)
137
+
138
+ # Add labels to the nodes
139
+ for node, (x, y, z) in pos.items():
140
+ if node in path:
141
+ ax.text(x, y, z, str(node), fontsize=12, color='black')
142
+
143
+ # Adjust labels and title
144
+ ax.set_xlabel('Time (weeks)')
145
+ ax.set_ylabel('Event Probability')
146
+ ax.set_zlabel('Event Number')
147
+ ax.set_title('Event Tree in 3D - Path')
148
+
149
+ plt.savefig(filename, bbox_inches='tight') # Save to a file with adjusted margins
150
+ plt.close() # Close the figure to free up resources
151
+
152
+ def draw_global_tree_3d(G, filename='global_tree.png'):
153
+ """Draws the entire graph in 3D using networkx and matplotlib and saves the figure to a file."""
154
+ pos = nx.get_node_attributes(G, 'pos')
155
+
156
+ # Get data for 3D visualization
157
+ x_vals, y_vals, z_vals = zip(*pos.values())
158
+
159
+ fig = plt.figure(figsize=(16, 12))
160
+ ax = fig.add_subplot(111, projection='3d')
161
+
162
+ # Assign colors to the nodes based on probability
163
+ node_colors = []
164
+ for node, (x, prob, z) in pos.items():
165
+ if prob < 0.33:
166
+ node_colors.append('red')
167
+ elif prob < 0.67:
168
+ node_colors.append('blue')
169
+ else:
170
+ node_colors.append('green')
171
+
172
+ # Draw nodes
173
+ ax.scatter(x_vals, y_vals, z_vals, c=node_colors, s=700, edgecolors='black', alpha=0.7)
174
+
175
+ # Draw edges
176
+ for edge in G.edges():
177
+ x_start, y_start, z_start = pos[edge[0]]
178
+ x_end, y_end, z_end = pos[edge[1]]
179
+ ax.plot([x_start, x_end], [y_start, y_end], [z_start, z_end], color='gray', lw=2)
180
+
181
+ # Add labels to the nodes
182
+ for node, (x, y, z) in pos.items():
183
+ ax.text(x, y, z, str(node), fontsize=12, color='black')
184
+
185
+ # Adjust labels and title
186
+ ax.set_xlabel('Time (weeks)')
187
+ ax.set_ylabel('Event Probability')
188
+ ax.set_zlabel('Event Number')
189
+ ax.set_title('Event Tree in 3D')
190
+
191
+ plt.savefig(filename, bbox_inches='tight') # Save to a file with adjusted margins
192
+ plt.close() # Close the figure to free up resources
193
+
194
+ def main(mode, input_file=None):
195
+ G = nx.DiGraph()
196
+
197
+ if mode == 'random':
198
+ starting_x = 0
199
+ starting_y = 0
200
+ max_depth = 5 # Maximum tree depth
201
+ max_nodes = 3 # Maximum number of child nodes
202
+ x_range = 10 # Maximum range for node x positions
203
+
204
+ # Generate the tree and get the node count per depth
205
+ generate_tree(starting_x, starting_y, 0, max_depth, max_nodes, x_range, G)
206
+ elif mode == 'json' and input_file:
207
+ with open(input_file, 'r') as file:
208
+ json_data = file.read()
209
+ build_graph_from_json(json_data, G)
210
+ else:
211
+ print("Invalid mode or input file not provided.")
212
+ return
213
+
214
+ # Find relevant paths
215
+ best_path, best_mean_prob, worst_path, worst_mean_prob, longest_duration_path, shortest_duration_path = find_paths(G)
216
+
217
+ # Print the results
218
+ if best_path:
219
+ print(f"\nPath with the highest average probability:")
220
+ print(" -> ".join(map(str, best_path)))
221
+ print(f"Average probability: {best_mean_prob:.2f}")
222
+
223
+ if worst_path:
224
+ print(f"\nPath with the lowest average probability:")
225
+ print(" -> ".join(map(str, worst_path)))
226
+ print(f"Average probability: {worst_mean_prob:.2f}")
227
+
228
+ if longest_duration_path:
229
+ print(f"\nPath with the longest duration:")
230
+ print(" -> ".join(map(str, longest_duration_path)))
231
+ print(f"Duration: {max(G.nodes[node]['pos'][0] for node in longest_duration_path) - min(G.nodes[node]['pos'][0] for node in longest_duration_path):.2f}")
232
+
233
+ if shortest_duration_path:
234
+ print(f"\nPath with the shortest duration:")
235
+ print(" -> ".join(map(str, shortest_duration_path)))
236
+ print(f"Duration: {max(G.nodes[node]['pos'][0] for node in shortest_duration_path) - min(G.nodes[node]['pos'][0] for node in shortest_duration_path):.2f}")
237
+
238
+ # Save the global visualization
239
+ draw_global_tree_3d(G, filename='global_tree.png')
240
+
241
+ # Draw and save the 3D figure for each relevant path
242
+ if best_path:
243
+ draw_path_3d(G, path=best_path, filename='best_path.png', highlight_color='blue')
244
+
245
+ if worst_path:
246
+ draw_path_3d(G, path=worst_path, filename='worst_path.png', highlight_color='red')
247
+
248
+ if longest_duration_path:
249
+ draw_path_3d(G, path=longest_duration_path, filename='longest_duration_path.png', highlight_color='green')
250
+
251
+ if shortest_duration_path:
252
+ draw_path_3d(G, path=shortest_duration_path, filename='shortest_duration_path.png', highlight_color='purple')
253
+
254
+ if __name__ == "__main__":
255
+ if len(sys.argv) < 2:
256
+ print("Usage: python script.py <mode> [json_file]")
257
+ else:
258
+ mode = sys.argv[1]
259
+ input_file = sys.argv[2] if len(sys.argv) > 2 else None
260
+ main(mode, input_file)