none committed on
Commit
045d7d4
·
0 Parent(s):

Working version of the streamlit animation

Browse files
Files changed (4) hide show
  1. README.md +1 -0
  2. streamlit_viz.py +254 -0
  3. train_classifier.py +86 -0
  4. viz_classifier.py +215 -0
README.md ADDED
@@ -0,0 +1 @@
 
 
1
+ The `id` column is baloney. There are lots of duplicates.
streamlit_viz.py ADDED
@@ -0,0 +1,254 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import joblib
2
+ import time
3
+
4
+ import plotly.graph_objects as go
5
+ import streamlit as st
6
+ import pandas as pd
7
+ import numpy as np
8
+
9
# Column names of the training features, in model training order;
# node['feature_idx'] in each predictor's tree table indexes into this list.
FEATS = [
    'srcip',
    'sport',
    'dstip',
    'dsport',
    'proto',
    #'state', I dropped this one when I trained the model
    # NOTE(review): train_classifier.py only drops the *row* where
    # state=='no', not the 'state' column itself -- confirm these
    # feature indices actually line up with the trained model.
    'dur',
    'sbytes',
    'dbytes',
    'sttl',
    'dttl',
    'sloss',
    'dloss',
    'service',
    'Sload',
    'Dload',
    'Spkts',
    'Dpkts',
    'swin',
    'dwin',
    'stcpb',
    'dtcpb',
    'smeansz',
    'dmeansz',
    'trans_depth',
    'res_bdy_len',
    'Sjit',
    'Djit',
    'Stime',
    'Ltime',
    'Sintpkt',
    'Dintpkt',
    'tcprtt',
    'synack',
    'ackdat',
    'is_sm_ips_ports',
    'ct_state_ttl',
    'ct_flw_http_mthd',
    'is_ftp_login',
    'ct_ftp_cmd',
    'ct_srv_src',
    'ct_srv_dst',
    'ct_dst_ltm',
    'ct_src_ltm',
    'ct_src_dport_ltm',
    'ct_dst_sport_ltm',
    'ct_dst_src_ltm',
    ]

# One CSS named color per feature (parallel to FEATS by index) so every
# split on the same feature gets the same color in the treemap.
# Plotly only accepts CSS named colors here; whites and near-duplicate
# shades were pruned by hand, so this palette is tuned to this specific
# tree rather than being a general colormap.
COLORS = [
    'aliceblue','aqua','aquamarine','azure',
    'bisque','black','blanchedalmond','blue',
    'blueviolet','brown','burlywood','cadetblue',
    'chartreuse','chocolate','coral','cornflowerblue',
    'cornsilk','crimson','cyan','darkblue','darkcyan',
    'darkgoldenrod','darkgray','darkgreen',
    'darkkhaki','darkmagenta','darkolivegreen','darkorange',
    'darkorchid','darkred','darksalmon','darkseagreen',
    'darkslateblue','darkslategray',
    'darkturquoise','darkviolet','deeppink','deepskyblue',
    'dimgray','dodgerblue',
    'forestgreen','fuchsia','gainsboro',
    'gold','goldenrod','gray','green',
    'greenyellow','honeydew','hotpink','indianred','indigo',
    'ivory','khaki','lavender','lavenderblush','lawngreen',
    'lemonchiffon','lightblue','lightcoral','lightcyan',
    'lightgoldenrodyellow','lightgray',
    'lightgreen','lightpink','lightsalmon','lightseagreen',
    'lightskyblue','lightslategray',
    'lightsteelblue','lightyellow','lime','limegreen',
    'linen','magenta','maroon','mediumaquamarine',
    'mediumblue','mediumorchid','mediumpurple',
    'mediumseagreen','mediumslateblue','mediumspringgreen',
    'mediumturquoise','mediumvioletred','midnightblue',
    'mintcream','mistyrose','moccasin','navy',
    'oldlace','olive','olivedrab','orange','orangered',
    'orchid','palegoldenrod','palegreen','paleturquoise',
    'palevioletred','papayawhip','peachpuff','peru','pink',
    'plum','powderblue','purple','red','rosybrown',
    'royalblue','saddlebrown','salmon','sandybrown',
    'seagreen','seashell','sienna','silver','skyblue',
    'slateblue','slategray','slategrey','snow','springgreen',
    'steelblue','tan','teal','thistle','tomato','turquoise',
    'violet','wheat','yellow','yellowgreen'
    ]
95
+
96
def build_parents(tree, visit_order, node_id2plot_id):
    """Find the parent of every node visited in `visit_order`.

    Parameters
    ----------
    tree : pandas.DataFrame with 'left'/'right' child-index columns
        (0 means "no child" in the sklearn node table).
    visit_order : list of node indices, root (0) first.
    node_id2plot_id : dict mapping node index -> treemap plot id.

    Returns
    -------
    (parents, parent_plot_ids, directions) : three lists aligned with
    `visit_order`; the root gets None in all three, every other entry is
    the parent node index, the parent's plot id as a string, and 'l'/'r'
    for whether the node is its parent's left or right child.
    """
    # Build a child -> (parent, direction) lookup in one pass over the
    # table instead of scanning the whole frame twice per node
    # (O(n) instead of O(n^2)).
    child_to_parent = {}
    for parent_id, row in tree.iterrows():
        if row['left'] != 0:  # 0 == "no child" in the sklearn node table
            child_to_parent[row['left']] = (parent_id, 'l')
        if row['right'] != 0:
            child_to_parent[row['right']] = (parent_id, 'r')

    parents = [None]
    parent_plot_ids = [None]
    directions = [None]
    for i in visit_order[1:]:
        parent_id, direction = child_to_parent[i]
        parents.append(parent_id)
        parent_plot_ids.append(str(node_id2plot_id[parent_id]))
        directions.append(direction)
    return parents, parent_plot_ids, directions
112
+
113
+
114
def build_labels_colors(tree, visit_order, parents, parent_plot_ids, directions):
    """Build the treemap label and color for every node in `visit_order`.

    The split condition that *produces* a node lives on its parent's row
    (feature_idx / num_threshold), so each label is derived from the parent
    plus whether this node is the left (<=) or right (>) child.

    Returns (labels, colors), both aligned with `visit_order`; the root
    gets a fixed title label and a white tile.
    """
    labels = ['Histogram Gradient-Boosted Decision Tree']
    colors = ['white']
    for i, parent, parent_plot_id, direction in zip(
        visit_order,
        parents,
        parent_plot_ids,
        directions
    ):
        # skip the first one (the root): it has no parent split to describe
        if i == 0:
            continue
        # (the previous version also fetched tree.loc[i] here but never
        # used it -- everything needed comes from the parent row)
        feat = FEATS[int(tree.loc[int(parent), 'feature_idx'])]
        thresh = tree.loc[int(parent), 'num_threshold']
        if direction == 'l':
            labels.append(f"[{parent_plot_id}.L] {feat} <= {thresh}")
        else:
            labels.append(f"[{parent_plot_id}.R] {feat} > {thresh}")

        # one stable color per feature so repeated splits match up visually
        colors.append(COLORS[FEATS.index(feat)])
    return labels, colors
139
+
140
+
141
def build_plot(tree):
    """Construct a plotly Treemap trace for one decision-tree node table.

    https://stackoverflow.com/questions/64393535/python-plotly-treemap-ids-format-and-how-to-display-multiple-duplicated-labels-i
    When `ids` is supplied, `parents` must be expressed in terms of `ids`
    as well, so everything below works in plot-id space.
    """
    order = breadth_first_traverse(tree)
    plot_id_of = {node: pos for pos, node in enumerate(order)}
    parents, parent_plot_ids, directions = build_parents(tree, order, plot_id_of)
    labels, colors = build_labels_colors(tree, order, parents, parent_plot_ids, directions)
    # the plot ids come out as the strings '0', '1', '2', ...
    trace_ids = [str(plot_id_of[node]) for node in order]

    return go.Treemap(
        values=tree['count'].to_numpy(),
        labels=labels,
        ids=trace_ids,
        parents=parent_plot_ids,
        marker_colors=colors,
    )
158
+
159
+
160
def breadth_first_traverse(tree):
    """Return the node indices of `tree` in breadth-first order.

    https://www.101computing.net/breadth-first-traversal-of-a-binary-tree/
    Iterative version makes more sense since I have the whole tree in a
    table instead of just nodes and pointers.

    `tree` is a DataFrame with 'left'/'right' columns where 0 means
    "no child" (node 0 is always the root, so 0 never appears as a child).
    """
    # deque gives O(1) popleft; list.pop(0) shifts the whole list (O(n))
    from collections import deque
    q = deque([0])
    visited_nodes = []
    while q:
        cur = q.popleft()
        visited_nodes.append(cur)

        left = tree.loc[cur, 'left']
        if left != 0:
            q.append(left)

        right = tree.loc[cur, 'right']
        if right != 0:
            q.append(right)

    return visited_nodes
179
+
180
+
181
def main():
    """Streamlit app: browse (via slider) and animate the individual trees
    of a trained HistGradientBoostingClassifier as plotly treemaps."""
    # load the data: one entry of _predictors per boosting iteration,
    # whose [0].nodes is the structured array describing that tree
    hgb = joblib.load('hgb_classifier.joblib')
    trees = [pd.DataFrame(x[0].nodes) for x in hgb._predictors]
    # make the plots: one trace per tree, wrapped both as standalone
    # figures (for the slider) and as animation frames
    graph_objs = [build_plot(tree) for tree in trees]
    figures = [go.Figure(graph_obj) for graph_obj in graph_objs]
    frames = [go.Frame(data=graph_obj) for graph_obj in graph_objs]
    # show them with streamlit

    # this puts them all on the screen at once
    # like each new one shows up below the previous one
    # instead of replacing the previous one
    #for fig in figures:
    #    st.plotly_chart(fig)
    #    time.sleep(1)

    # This works the way I want
    # but the plot is tiny
    # also it recalculates all of the plots
    # every time the slider value changes
    #
    # I tried to cache the plots but build_plot() takes
    # a DataFrame which is mutable and therefore unhashable I guess
    # so it won't let me cache that function
    # I could pack the dataframe bytes to smuggle them past that check
    # but whatever
    idx = st.slider(
        label='which step to show',
        min_value=0,
        max_value=len(figures)-1,
        value=0,
        step=1
    )
    st.plotly_chart(figures[idx])
    st.markdown(f'## Tree {idx}')
    st.dataframe(trees[idx])

    # Maybe just show a Plotly animated chart
    # https://plotly.com/python/animations/#using-a-slider-and-buttons
    # They don't really document the animation stuff on their website
    # but it's in here
    # https://raw.githubusercontent.com/plotly/plotly.js/master/dist/plot-schema.json
    # I guess it's only in the JS docs and hasn't made it to the Python docs yet
    # https://plotly.com/javascript/animations/
    # trying to find stuff here instead
    # https://plotly.com/python-api-reference/generated/plotly.graph_objects.layout.updatemenu.html?highlight=updatemenu

    # this one finally set the speed
    # no mention of how they figured this out but thank goodness I found it
    # https://towardsdatascience.com/basic-animation-with-matplotlib-and-plotly-5eef4ad6c5aa
    ani_fig = go.Figure(
        data=graph_objs[0],
        frames=frames,
        layout=go.Layout(
            updatemenus=[{
                'type':'buttons',
                'buttons':[{
                    'label':'Play',
                    'method': 'animate',
                    # durations are in milliseconds
                    'args':[None, {
                        'frame': {'duration':5000},
                        'transition': {'duration': 2500}
                    }]
                }]
            }]
        )
    )
    st.plotly_chart(ani_fig)

if __name__=='__main__':
    main()
253
+
254
+
train_classifier.py ADDED
@@ -0,0 +1,86 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import joblib
2
+
3
+ import pandas as pd
4
+
5
+ from sklearn.preprocessing import OrdinalEncoder, LabelEncoder
6
+ from sklearn.ensemble import RandomForestClassifier, HistGradientBoostingClassifier
7
+ from sklearn.metrics import classification_report
8
+
9
+
10
def main():
    """Train a HistGradientBoostingClassifier on the network-flow CSVs,
    persist it to hgb_classifier.joblib, and print a test-set
    classification report."""
    train_df = pd.read_csv('train_data.csv', na_values='-')
    # `service` is about half-empty and the rest are completely full.
    # One of the rows has `no` for `state` which isn't listed as an option
    # in the description of the fields, so just delete that row.
    # `id` is full of duplicates (see README), so drop the column.
    train_df = train_df.drop(columns=['id'])
    train_df = train_df.drop(index=train_df[train_df['state']=='no'].index)

    # It can predict `label` really well (~0.95 accuracy/f1/whatever other
    # stat you care about). It does a lot worse trying to predict
    # `attack_cat` b/c there are 10 classes and some of them are not
    # well-represented, so that might be more interesting to visualize.
    # Pop `attack_cat` so the model can't cheat with it; the popped column
    # itself is unused, so it isn't kept in a variable.
    train_df.pop('attack_cat')
    y_enc = LabelEncoder().fit(train_df['label'])
    train_y = y_enc.transform(train_df.pop('label'))
    x_enc = OrdinalEncoder().fit(train_df)
    train_df = x_enc.transform(train_df)

    # Random forest doesn't handle NaNs.
    # I could drop the `service` column or I can use HistGradientBoostingClassifier.
    # Super helpful error message from sklearn pointed me to this list:
    # https://scikit-learn.org/stable/modules/impute.html#estimators-that-handle-nan-values
    #rf = RandomForestClassifier()
    #rf.fit(train_df, train_y)

    # max_iter is the number of times it builds a gradient-boosted tree,
    # so it's the number of estimators
    hgb = HistGradientBoostingClassifier(max_iter=10).fit(train_df, train_y)
    joblib.dump(hgb, 'hgb_classifier.joblib', compress=9)

    # evaluate on the held-out set, reusing the encoders fit on train
    test_df = pd.read_csv('test_data.csv', na_values='-')
    test_df = test_df.drop(columns=['id', 'attack_cat'])
    test_y = y_enc.transform(test_df.pop('label'))
    test_df = x_enc.transform(test_df)
    test_preds = hgb.predict(test_df)
    print(classification_report(test_y, test_preds))

    # There's no built-in feature importance for HistGradientBoostingClassifier
    # (RandomForestClassifier has .feature_importances_, or see
    # https://scikit-learn.org/stable/auto_examples/ensemble/plot_forest_importances.html
    # and https://scikit-learn.org/stable/modules/ensemble.html#interpretation-with-feature-importance)
    # but you can get to the actual nodes for each predictor/estimator like
    #   hgb._predictors[i][0].nodes
    # which has a per-node gain metric -- that's what the viz scripts plot.
    # Viz ideas: animate the evolution of the first to the last estimator to
    # see the effect of boosting, or trace a data point's path through each
    # tree (cf. RandomForestClassifier.decision_path()).


if __name__ == '__main__':
    main()
85
+
86
+
viz_classifier.py ADDED
@@ -0,0 +1,215 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import joblib
2
+
3
+ import pandas as pd
4
+
5
+ import plotly.graph_objects as go
6
+ import plotly.express as px
7
+
8
# Load the trained model; each entry of hgb._predictors is one boosting
# iteration and [0].nodes is the structured array describing its tree.
hgb = joblib.load('hgb_classifier.joblib')
# Feature column names in model training order;
# node['feature_idx'] in the tree tables indexes into this list.
FEATS = [
    'srcip',
    'sport',
    'dstip',
    'dsport',
    'proto',
    #'state', I dropped this one when I trained the model
    # NOTE(review): train_classifier.py only drops the *row* where
    # state=='no', not the 'state' column -- confirm indices line up.
    'dur',
    'sbytes',
    'dbytes',
    'sttl',
    'dttl',
    'sloss',
    'dloss',
    'service',
    'Sload',
    'Dload',
    'Spkts',
    'Dpkts',
    'swin',
    'dwin',
    'stcpb',
    'dtcpb',
    'smeansz',
    'dmeansz',
    'trans_depth',
    'res_bdy_len',
    'Sjit',
    'Djit',
    'Stime',
    'Ltime',
    'Sintpkt',
    'Dintpkt',
    'tcprtt',
    'synack',
    'ackdat',
    'is_sm_ips_ports',
    'ct_state_ttl',
    'ct_flw_http_mthd',
    'is_ftp_login',
    'ct_ftp_cmd',
    'ct_srv_src',
    'ct_srv_dst',
    'ct_dst_ltm',
    'ct_src_ltm',
    'ct_src_dport_ltm',
    'ct_dst_sport_ltm',
    'ct_dst_src_ltm',
    ]

# One color per feature, parallel to FEATS by index.
# plotly only has the CSS named colors
# I don't think I can use xkcd colors
# I copied a bunch of CSS colors from somewhere online
# and then deleted whites and things that showed up too close on the tree
# this is not really a general solution, it just works for this specific tree
# I'll have to come up with a better colormap at some point
COLORS = [
    'aliceblue','aqua','aquamarine','azure',
    'bisque','black','blanchedalmond','blue',
    'blueviolet','brown','burlywood','cadetblue',
    'chartreuse','chocolate','coral','cornflowerblue',
    'cornsilk','crimson','cyan','darkblue','darkcyan',
    'darkgoldenrod','darkgray','darkgreen',
    'darkkhaki','darkmagenta','darkolivegreen','darkorange',
    'darkorchid','darkred','darksalmon','darkseagreen',
    'darkslateblue','darkslategray',
    'darkturquoise','darkviolet','deeppink','deepskyblue',
    'dimgray','dodgerblue',
    'forestgreen','fuchsia','gainsboro',
    'gold','goldenrod','gray','green',
    'greenyellow','honeydew','hotpink','indianred','indigo',
    'ivory','khaki','lavender','lavenderblush','lawngreen',
    'lemonchiffon','lightblue','lightcoral','lightcyan',
    'lightgoldenrodyellow','lightgray',
    'lightgreen','lightpink','lightsalmon','lightseagreen',
    'lightskyblue','lightslategray',
    'lightsteelblue','lightyellow','lime','limegreen',
    'linen','magenta','maroon','mediumaquamarine',
    'mediumblue','mediumorchid','mediumpurple',
    'mediumseagreen','mediumslateblue','mediumspringgreen',
    'mediumturquoise','mediumvioletred','midnightblue',
    'mintcream','mistyrose','moccasin','navy',
    'oldlace','olive','olivedrab','orange','orangered',
    'orchid','palegoldenrod','palegreen','paleturquoise',
    'palevioletred','papayawhip','peachpuff','peru','pink',
    'plum','powderblue','purple','red','rosybrown',
    'royalblue','saddlebrown','salmon','sandybrown',
    'seagreen','seashell','sienna','silver','skyblue',
    'slateblue','slategray','slategrey','snow','springgreen',
    'steelblue','tan','teal','thistle','tomato','turquoise',
    'violet','wheat','yellow','yellowgreen'
    ]
101
+
102
# One node table per boosting iteration (structured numpy arrays).
trees = [x[0].nodes for x in hgb._predictors]

# the final tree definitely has a similar structure but is noticeably different
# that's really cool
# I think this will make a cool animation
# if I can figure it out
tree = pd.DataFrame(trees[0])
#tree = pd.DataFrame(trees[9])



# parents is going to be tricky
# I need to get the index of whichever node has the current node listed
# in either `left` or `right`

parents = [None]
# keep track of whether each node is a left or right child of the parent in the list
directions = [None]
# it uses 0 to say "no left/right child"
# so I have to skip searching for node 0
# which is fine b/c node 0 is the root
for i in tree.index[1:]:
    # it seems to make a very even tree
    # so just guess it's in the right side
    # and that will be right half the time
    parent = tree[tree['right']==i].index
    if parent.empty:
        # not a right child anywhere, so it must be a left child
        parents.append(str(tree[tree['left']==i].index[0]))
        directions.append('l')
    else:
        parents.append(str(parent[0]))
        directions.append('r')
133
+
134
+
135
# generate the labels and the colors for every node.
# The split that *created* a node is stored on its parent's row
# (feature_idx / num_threshold), so each label reads from the parent.
# (The previous version also zipped in tree.iterrows() and unpacked the
# row into `node`, but never used it -- dropped.)
labels = ['Histogram Gradient-Boosted Decision Tree']
colors = ['white']
for i, parent, direction in zip(
    tree.index.to_numpy(),
    parents,
    directions
):
    # skip the first one (the root): it has no parent split to describe
    if i == 0:
        continue
    feat = FEATS[int(tree.loc[int(parent), 'feature_idx'])]
    thresh = tree.loc[int(parent), 'num_threshold']
    if direction == 'l':
        labels.append(f"[{i}] {feat} <= {thresh}")
    else:
        labels.append(f"[{i}] {feat} > {thresh}")

    # colors: one stable color per feature so repeated splits match up
    offset = FEATS.index(feat)
    colors.append(COLORS[offset])
159
+
160
+
161
# actual plot: treemap of the whole tree, node sizes by sample count
f = go.Figure(
    go.Treemap(
        values=tree['count'].to_numpy(),
        labels=labels,
        ids=tree.index.to_numpy(),
        parents=parents,
        marker_colors=colors,
    )
)

#f.update_layout(
#    treemapcolorway = ['pink']
#)

# Render the figure. (A leftover `breakpoint()` debug stop used to sit
# here for interactively poking at `f`; removed so the script runs
# unattended -- re-add a debugger manually if needed.)
f.show()
177
+
178
+
179
+ # converting the ndarray with column names to a pandas df
180
+ # 3284 bytes as an ndarry
181
+ # 3300 bytes as a dataframe
182
+ # so they're the same size
183
+ # do I need to convert it to pandas? idk
184
+ # just curious
185
+
186
+ # https://linuxtut.com/en/ffb2e319db5545965933/
187
+
188
+ # https://github.com/scikit-learn/scikit-learn/blob/main/sklearn/ensemble/_hist_gradient_boosting/_predictor.pyx
189
+ # figuring out how the thing works
190
+
191
+ # `value` is the predicted class / value / whatever
192
+ # so if it's a leaf node, it returns that value as the prediction
193
+ # there are negative values in some of the leaves
194
+ # maybe the classes are +/-1 instead of 0/1?
195
+
196
+ # if the data value is <= `num_threshold` then it goes in the left node
197
+ # if it's > `num_threshold` then it goes in the right node
198
+
199
+ # okay and then all the leaves have feature_idx=0, num_threshold=0, left=0, right=0
200
+ # that makes sense
201
+ # still kind of annoying that they use 0 instead of np.nan but oh well
202
+
203
+ # also super super hard to figure out what the labels on the tree map should be
204
+ # like it has to check the parent's feature_idx and num_threshold
205
+ # which I guess isn't too bad once we have the list of parents already built
206
+ # except that I don't know whether a node is left or right from its parent
207
+ # hmmmm
208
+
209
+
210
+
211
+
212
+
213
+
214
+
215
+