Spaces:
Runtime error
Runtime error
none
committed on
Commit
·
045d7d4
0
Parent(s):
Working version of the streamlit animation
Browse files- README.md +1 -0
- streamlit_viz.py +254 -0
- train_classifier.py +86 -0
- viz_classifier.py +215 -0
README.md
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
The `id` column is baloney. There are lots of duplicates.
|
streamlit_viz.py
ADDED
@@ -0,0 +1,254 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import joblib
|
2 |
+
import time
|
3 |
+
|
4 |
+
import plotly.graph_objects as go
|
5 |
+
import streamlit as st
|
6 |
+
import pandas as pd
|
7 |
+
import numpy as np
|
8 |
+
|
9 |
+
# Feature names in the column order the classifier was trained on;
# each node's `feature_idx` indexes into this list.
# NOTE(review): must stay in sync with the training script's column
# order -- verify against train_classifier.py if either changes.
FEATS = [
    'srcip',
    'sport',
    'dstip',
    'dsport',
    'proto',
    #'state', I dropped this one when I trained the model
    'dur',
    'sbytes',
    'dbytes',
    'sttl',
    'dttl',
    'sloss',
    'dloss',
    'service',
    'Sload',
    'Dload',
    'Spkts',
    'Dpkts',
    'swin',
    'dwin',
    'stcpb',
    'dtcpb',
    'smeansz',
    'dmeansz',
    'trans_depth',
    'res_bdy_len',
    'Sjit',
    'Djit',
    'Stime',
    'Ltime',
    'Sintpkt',
    'Dintpkt',
    'tcprtt',
    'synack',
    'ackdat',
    'is_sm_ips_ports',
    'ct_state_ttl',
    'ct_flw_http_mthd',
    'is_ftp_login',
    'ct_ftp_cmd',
    'ct_srv_src',
    'ct_srv_dst',
    'ct_dst_ltm',
    'ct_src_ltm',
    'ct_src_dport_ltm',
    'ct_dst_sport_ltm',
    'ct_dst_src_ltm',
]
|
58 |
+
|
59 |
+
# CSS named colors used to color treemap tiles by split feature:
# build_labels_colors picks COLORS[FEATS.index(feat)], so all splits on
# the same feature share one color.
COLORS = [
    'aliceblue','aqua','aquamarine','azure',
    'bisque','black','blanchedalmond','blue',
    'blueviolet','brown','burlywood','cadetblue',
    'chartreuse','chocolate','coral','cornflowerblue',
    'cornsilk','crimson','cyan','darkblue','darkcyan',
    'darkgoldenrod','darkgray','darkgreen',
    'darkkhaki','darkmagenta','darkolivegreen','darkorange',
    'darkorchid','darkred','darksalmon','darkseagreen',
    'darkslateblue','darkslategray',
    'darkturquoise','darkviolet','deeppink','deepskyblue',
    'dimgray','dodgerblue',
    'forestgreen','fuchsia','gainsboro',
    'gold','goldenrod','gray','green',
    'greenyellow','honeydew','hotpink','indianred','indigo',
    'ivory','khaki','lavender','lavenderblush','lawngreen',
    'lemonchiffon','lightblue','lightcoral','lightcyan',
    'lightgoldenrodyellow','lightgray',
    'lightgreen','lightpink','lightsalmon','lightseagreen',
    'lightskyblue','lightslategray',
    'lightsteelblue','lightyellow','lime','limegreen',
    'linen','magenta','maroon','mediumaquamarine',
    'mediumblue','mediumorchid','mediumpurple',
    'mediumseagreen','mediumslateblue','mediumspringgreen',
    'mediumturquoise','mediumvioletred','midnightblue',
    'mintcream','mistyrose','moccasin','navy',
    'oldlace','olive','olivedrab','orange','orangered',
    'orchid','palegoldenrod','palegreen','paleturquoise',
    'palevioletred','papayawhip','peachpuff','peru','pink',
    'plum','powderblue','purple','red','rosybrown',
    'royalblue','saddlebrown','salmon','sandybrown',
    'seagreen','seashell','sienna','silver','skyblue',
    'slateblue','slategray','slategrey','snow','springgreen',
    'steelblue','tan','teal','thistle','tomato','turquoise',
    'violet','wheat','yellow','yellowgreen'
]
|
95 |
+
|
96 |
+
def build_parents(tree, visit_order, node_id2plot_id):
    """Locate each visited node's parent in the node table.

    Returns three lists aligned with `visit_order` (the root contributes
    None to all three): the parent's node id, the parent's plot id as a
    string (Plotly wants string ids), and 'l'/'r' for whether the node
    is the left or right child of that parent.
    """
    parents = [None]
    parent_plot_ids = [None]
    directions = [None]
    for node in visit_order[1:]:
        # a node appears in exactly one parent's `left` or `right` column
        right_hits = tree[tree['right'] == node].index
        if right_hits.empty:
            parent_id = tree[tree['left'] == node].index[0]
            side = 'l'
        else:
            parent_id = right_hits[0]
            side = 'r'
        parents.append(parent_id)
        parent_plot_ids.append(str(node_id2plot_id[parent_id]))
        directions.append(side)
    return parents, parent_plot_ids, directions
112 |
+
|
113 |
+
|
114 |
+
def build_labels_colors(tree, visit_order, parents, parent_plot_ids, directions):
    """Build the treemap label and tile color for every visited node.

    Each non-root node is labeled with the split stored on its *parent*
    (feature name and threshold) and which side of that split it sits on,
    e.g. "[3.L] sttl <= 61.5".  The tile color encodes the split feature,
    so all splits on the same feature share a color.  The root gets a
    title label and a white tile.

    The four sequence arguments are aligned element-for-element (as
    produced by breadth_first_traverse and build_parents).
    """
    labels = ['Histogram Gradient-Boosted Decision Tree']
    colors = ['white']
    for i, parent, parent_plot_id, direction in zip(
        visit_order,
        parents,
        parent_plot_ids,
        directions
    ):
        # skip the root: it has no parent split to describe
        if i == 0:
            continue
        # the split that produced this node lives on the parent row
        feat = FEATS[int(tree.loc[int(parent), 'feature_idx'])]
        thresh = tree.loc[int(parent), 'num_threshold']
        if direction == 'l':
            labels.append(f"[{parent_plot_id}.L] {feat} <= {thresh}")
        else:
            labels.append(f"[{parent_plot_id}.R] {feat} > {thresh}")
        # one color per feature so equal splits match visually
        colors.append(COLORS[FEATS.index(feat)])
    return labels, colors
|
139 |
+
|
140 |
+
|
141 |
+
def build_plot(tree):
    """Turn one estimator's node table into a go.Treemap.

    Plotly requires `parents` to be expressed in terms of `ids`, so each
    node is numbered by its breadth-first visit position and those
    positions (stringified) serve as the plot ids.
    https://stackoverflow.com/questions/64393535/python-plotly-treemap-ids-format-and-how-to-display-multiple-duplicated-labels-i
    """
    order = breadth_first_traverse(tree)
    node_id2plot_id = {node_id: pos for pos, node_id in enumerate(order)}
    parents, parent_plot_ids, directions = build_parents(
        tree, order, node_id2plot_id
    )
    labels, colors = build_labels_colors(
        tree, order, parents, parent_plot_ids, directions
    )
    # by construction this is simply ['0', '1', '2', ...]
    plot_ids = [str(node_id2plot_id[node_id]) for node_id in order]

    return go.Treemap(
        values=tree['count'].to_numpy(),
        labels=labels,
        ids=plot_ids,
        parents=parent_plot_ids,
        marker_colors=colors,
    )
|
158 |
+
|
159 |
+
|
160 |
+
def breadth_first_traverse(tree):
    """Return the tree's node ids in breadth-first (level) order.

    https://www.101computing.net/breadth-first-traversal-of-a-binary-tree/
    The iterative version makes more sense since the whole tree is in a
    table instead of nodes and pointers.  A child id of 0 means "no
    child" -- node 0 is the root, so it can never legitimately appear as
    a child.
    """
    # list.pop(0) shifts the whole list (O(n) per pop); deque.popleft()
    # is O(1), which matters once the trees get deep
    from collections import deque
    q = deque([0])
    visited_nodes = []
    while q:
        cur = q.popleft()
        visited_nodes.append(cur)

        if tree.loc[cur, 'left'] != 0:
            q.append(tree.loc[cur, 'left'])

        if tree.loc[cur, 'right'] != 0:
            q.append(tree.loc[cur, 'right'])

    return visited_nodes
|
179 |
+
|
180 |
+
|
181 |
+
def main():
    """Streamlit app: render every boosted tree in the saved
    HistGradientBoostingClassifier as a Plotly treemap, with a slider to
    step through estimators plus a Play-button animation across them.

    Expects `hgb_classifier.joblib` (written by train_classifier.py) in
    the working directory.
    """
    # load the data
    hgb = joblib.load('hgb_classifier.joblib')
    # one node-table DataFrame per boosting step
    # (each entry of _predictors holds a single tree for binary classification)
    trees = [pd.DataFrame(x[0].nodes) for x in hgb._predictors]
    # make the plots
    graph_objs = [build_plot(tree) for tree in trees]
    figures = [go.Figure(graph_obj) for graph_obj in graph_objs]
    frames = [go.Frame(data=graph_obj) for graph_obj in graph_objs]
    # show them with streamlit

    # this puts them all on the screen at once
    # like each new one shows up below the previous one
    # instead of replacing the previous one
    #for fig in figures:
    #    st.plotly_chart(fig)
    #    time.sleep(1)

    # This works the way I want
    # but the plot is tiny
    # also it recalculates all of the plots
    # every time the slider value changes
    #
    # I tried to cache the plots but build_plot() takes
    # a DataFrame which is mutable and therefore unhashable I guess
    # so it won't let me cache that function
    # I could pack the dataframe bytes to smuggle them past that check
    # but whatever
    idx = st.slider(
        label='which step to show',
        min_value=0,
        max_value=len(figures)-1,
        value=0,
        step=1
    )
    st.plotly_chart(figures[idx])
    st.markdown(f'## Tree {idx}')
    st.dataframe(trees[idx])

    # Maybe just show a Plotly animated chart
    # https://plotly.com/python/animations/#using-a-slider-and-buttons
    # They don't really document the animation stuff on their website
    # but it's in here
    # https://raw.githubusercontent.com/plotly/plotly.js/master/dist/plot-schema.json
    # I guess it's only in the JS docs and hasn't made it to the Python docs yet
    # https://plotly.com/javascript/animations/
    # trying to find stuff here instead
    # https://plotly.com/python-api-reference/generated/plotly.graph_objects.layout.updatemenu.html?highlight=updatemenu

    # this one finally set the speed
    # no mention of how they figured this out but thank goodness I found it
    # https://towardsdatascience.com/basic-animation-with-matplotlib-and-plotly-5eef4ad6c5aa
    ani_fig = go.Figure(
        data=graph_objs[0],
        frames=frames,
        layout=go.Layout(
            updatemenus=[{
                'type':'buttons',
                'buttons':[{
                    'label':'Play',
                    'method': 'animate',
                    'args':[None, {
                        # per-frame hold time and cross-frame tween, in ms
                        'frame': {'duration':5000},
                        'transition': {'duration': 2500}
                    }]
                }]
            }]
        )
    )
    st.plotly_chart(ani_fig)
|
250 |
+
|
251 |
+
# Run the app only when executed as a script, not on import.
if __name__=='__main__':
    main()
|
253 |
+
|
254 |
+
|
train_classifier.py
ADDED
@@ -0,0 +1,86 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import joblib
|
2 |
+
|
3 |
+
import pandas as pd
|
4 |
+
|
5 |
+
from sklearn.preprocessing import OrdinalEncoder, LabelEncoder
|
6 |
+
from sklearn.ensemble import RandomForestClassifier, HistGradientBoostingClassifier
|
7 |
+
from sklearn.metrics import classification_report
|
8 |
+
|
9 |
+
|
10 |
+
def main():
    """Train a HistGradientBoostingClassifier on train_data.csv, persist
    it to hgb_classifier.joblib, and print a classification report for
    test_data.csv.

    Both CSVs use '-' as their missing-value marker; encoders fit on the
    training data are reused on the test data.
    """
    train_df = pd.read_csv('train_data.csv', na_values='-')
    # `service` is about half-empty and the rest are completely full
    # one of the rows has `no` for `state` which isn't listed as an option in the description of the fields
    # I'm just going to delete that
    train_df = train_df.drop(columns=['id'])
    train_df = train_df.drop(index=train_df[train_df['state']=='no'].index)

    # It can predict `label` really well ~0.95 accuracy/f1/whatever other stat you care about
    # It does a lot worse trying to predict `attack_cat` b/c there are 10 classes
    # and some of them are not well-represented
    # so that might be more interesting to visualize
    # drop the alternate target so it can't leak into the features
    train_df.pop('attack_cat')
    y_enc = LabelEncoder().fit(train_df['label'])
    train_y = y_enc.transform(train_df.pop('label'))
    x_enc = OrdinalEncoder().fit(train_df)
    train_df = x_enc.transform(train_df)

    # Random forest doesn't handle NaNs
    # I could drop the `service` column or I can use the HistGradientBoostingClassifier
    # super helpful error message from sklearn pointed me to this list
    # https://scikit-learn.org/stable/modules/impute.html#estimators-that-handle-nan-values
    #rf = RandomForestClassifier()
    #rf.fit(train_df, train_y)

    # max_iter is the number of times it builds a gradient-boosted tree
    # so it's the number of estimators
    hgb = HistGradientBoostingClassifier(max_iter=10).fit(train_df, train_y)
    joblib.dump(hgb, 'hgb_classifier.joblib', compress=9)

    # evaluate on the held-out file with the encoders fit on training data
    test_df = pd.read_csv('test_data.csv', na_values='-')
    test_df = test_df.drop(columns=['id', 'attack_cat'])
    test_y = y_enc.transform(test_df.pop('label'))
    test_df = x_enc.transform(test_df)
    test_preds = hgb.predict(test_df)
    print(classification_report(test_y, test_preds))

    # I guess they took out the RF feature importance
    # or maybe that's only in XGBoost
    # you can still kind of get to it
    # with RandomForestClassifier.feature_importances_
    # or like this
    # https://scikit-learn.org/stable/auto_examples/ensemble/plot_forest_importances.html
    # but there's really nothing for the HistGradientBoostingClassifier
    # but you can get to the actual nodes for each predictor/estimator like this
    # hgb._predictors[i][0].nodes
    # and that has information gain metric for each node which might be viz-able
    # so that might be an interesting viz
    # like plot the whole forest
    # maybe only do like 10 estimators to keep it smaller
    # or stick with 100 and figure out a good way to viz big models
    # the first two estimators are almost identical
    # so maybe like plot the first estimator
    # and then fuzz the nodes by how much the other estimators differ
    # assuming there's some things they all agree on exactly and others where they differ a little bit
    # idk I don't really know how the algorithm works
    # the 96th estimator looks pretty different (I'm assuming from boosting)
    # so maybe like an evolution animation from the first to the last
    # to see the effect of the boosting
    # like plot the points and show how the decision boundary shifts with each generation
    # alongside an animation of the actual decision tree morphing each step
    # That might look too much like an animation of the model being trained though
    # which I guess that's sort of what it is so idk

    # https://scikit-learn.org/stable/modules/ensemble.html#interpretation-with-feature-importance

    # also
    # you can see what path a data point takes through the forest
    # with RandomForestClassifier.decision_path()
    # which might be really cool
    # to see like 10 trees and the path through each tree and what each tree predicted
83 |
+
# Run the training pipeline only when executed as a script, not on import.
if __name__ == '__main__':
    main()
|
85 |
+
|
86 |
+
|
viz_classifier.py
ADDED
@@ -0,0 +1,215 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import joblib
|
2 |
+
|
3 |
+
import pandas as pd
|
4 |
+
|
5 |
+
import plotly.graph_objects as go
|
6 |
+
import plotly.express as px
|
7 |
+
|
8 |
+
# Load the trained model (written by train_classifier.py); its
# per-estimator node tables drive the visualization below.
hgb = joblib.load('hgb_classifier.joblib')
# Feature names in the column order the classifier was trained on;
# each node's `feature_idx` indexes into this list.
FEATS = [
    'srcip',
    'sport',
    'dstip',
    'dsport',
    'proto',
    #'state', I dropped this one when I trained the model
    'dur',
    'sbytes',
    'dbytes',
    'sttl',
    'dttl',
    'sloss',
    'dloss',
    'service',
    'Sload',
    'Dload',
    'Spkts',
    'Dpkts',
    'swin',
    'dwin',
    'stcpb',
    'dtcpb',
    'smeansz',
    'dmeansz',
    'trans_depth',
    'res_bdy_len',
    'Sjit',
    'Djit',
    'Stime',
    'Ltime',
    'Sintpkt',
    'Dintpkt',
    'tcprtt',
    'synack',
    'ackdat',
    'is_sm_ips_ports',
    'ct_state_ttl',
    'ct_flw_http_mthd',
    'is_ftp_login',
    'ct_ftp_cmd',
    'ct_srv_src',
    'ct_srv_dst',
    'ct_dst_ltm',
    'ct_src_ltm',
    'ct_src_dport_ltm',
    'ct_dst_sport_ltm',
    'ct_dst_src_ltm',
]
|
58 |
+
|
59 |
+
# plotly only has the CSS named colors
# I don't think I can use xkcd colors
# I copied a bunch of CSS colors from somewhere online
# and then deleted whites and things that showed up too close on the tree
# this is not really a general solution, it just works for this specific tree
# I'll have to come up with a better colormap at some point
# Tiles are colored by split feature: COLORS[FEATS.index(feat)].
COLORS = [
    'aliceblue','aqua','aquamarine','azure',
    'bisque','black','blanchedalmond','blue',
    'blueviolet','brown','burlywood','cadetblue',
    'chartreuse','chocolate','coral','cornflowerblue',
    'cornsilk','crimson','cyan','darkblue','darkcyan',
    'darkgoldenrod','darkgray','darkgreen',
    'darkkhaki','darkmagenta','darkolivegreen','darkorange',
    'darkorchid','darkred','darksalmon','darkseagreen',
    'darkslateblue','darkslategray',
    'darkturquoise','darkviolet','deeppink','deepskyblue',
    'dimgray','dodgerblue',
    'forestgreen','fuchsia','gainsboro',
    'gold','goldenrod','gray','green',
    'greenyellow','honeydew','hotpink','indianred','indigo',
    'ivory','khaki','lavender','lavenderblush','lawngreen',
    'lemonchiffon','lightblue','lightcoral','lightcyan',
    'lightgoldenrodyellow','lightgray',
    'lightgreen','lightpink','lightsalmon','lightseagreen',
    'lightskyblue','lightslategray',
    'lightsteelblue','lightyellow','lime','limegreen',
    'linen','magenta','maroon','mediumaquamarine',
    'mediumblue','mediumorchid','mediumpurple',
    'mediumseagreen','mediumslateblue','mediumspringgreen',
    'mediumturquoise','mediumvioletred','midnightblue',
    'mintcream','mistyrose','moccasin','navy',
    'oldlace','olive','olivedrab','orange','orangered',
    'orchid','palegoldenrod','palegreen','paleturquoise',
    'palevioletred','papayawhip','peachpuff','peru','pink',
    'plum','powderblue','purple','red','rosybrown',
    'royalblue','saddlebrown','salmon','sandybrown',
    'seagreen','seashell','sienna','silver','skyblue',
    'slateblue','slategray','slategrey','snow','springgreen',
    'steelblue','tan','teal','thistle','tomato','turquoise',
    'violet','wheat','yellow','yellowgreen'
]
|
101 |
+
|
102 |
+
# One structured node array per boosting step.
trees = [x[0].nodes for x in hgb._predictors]

# the final tree definitely has a similar structure but is noticeably different
# that's really cool
# I think this will make a cool animation
# if I can figure it out
tree = pd.DataFrame(trees[0])
#tree = pd.DataFrame(trees[9])


# parents is going to be tricky
# I need get the index of whichever node has the current node listed in either left or right
parents = [None]
# keep track of whether each node is a left or right child of the parent in the list
directions = [None]
# it uses 0 to say "no left/right child"
# so I have to skip searching for node 0
# which is fine b/c node 0 is the root
for i in tree.index[1:]:
    # it seems to make a very even tree
    # so just guess it's in the right side
    # and that will be right half the time
    parent = tree[tree['right']==i].index
    if parent.empty:
        parents.append(str(tree[tree['left']==i].index[0]))
        directions.append('l')
    else:
        parents.append(str(parent[0]))
        directions.append('r')


# generate the labels
# and the colors
# (tree.index aligns element-for-element with parents/directions,
#  which were both built over tree.index with a None placeholder for the root)
labels = ['Histogram Gradient-Boosted Decision Tree']
colors = ['white']
for i, parent, direction in zip(
    tree.index.to_numpy(),
    parents,
    directions
):
    # skip the first one (the root)
    if i == 0:
        continue
    # the split that produced this node is stored on the parent row
    feat = FEATS[int(tree.loc[int(parent), 'feature_idx'])]
    thresh = tree.loc[int(parent), 'num_threshold']
    if direction == 'l':
        labels.append(f"[{i}] {feat} <= {thresh}")
    else:
        labels.append(f"[{i}] {feat} > {thresh}")

    # colors: one per split feature so equal splits match
    colors.append(COLORS[FEATS.index(feat)])


# actual plot
f = go.Figure(
    go.Treemap(
        values=tree['count'].to_numpy(),
        labels=labels,
        ids=tree.index.to_numpy(),
        parents=parents,
        marker_colors=colors,
    )
)

#f.update_layout(
#    treemapcolorway = ['pink']
#)

# NOTE(review): a bare breakpoint() was left here, which drops every run
# into the debugger; re-enable it only for interactive inspection.
#breakpoint()
|
177 |
+
|
178 |
+
|
179 |
+
# converting the ndarray with column names to a pandas df
|
180 |
+
# 3284 bytes as an ndarray
|
181 |
+
# 3300 bytes as a dataframe
|
182 |
+
# so they're the same size
|
183 |
+
# do I need to convert it to pandas? idk
|
184 |
+
# just curious
|
185 |
+
|
186 |
+
# https://linuxtut.com/en/ffb2e319db5545965933/
|
187 |
+
|
188 |
+
# https://github.com/scikit-learn/scikit-learn/blob/main/sklearn/ensemble/_hist_gradient_boosting/_predictor.pyx
|
189 |
+
# figuring out how the thing works
|
190 |
+
|
191 |
+
# `value` is the predicted class / value / whatever
|
192 |
+
# so if it's a leaf node, it returns that value as the prediction
|
193 |
+
# there are negative values in some of the leaves
|
194 |
+
# maybe the classes are +/-1 instead of 0/1?
|
195 |
+
|
196 |
+
# if the data value is <= `num_threshold` then it goes in the left node
|
197 |
+
# if it's > `num_threshold` then it goes in the right node
|
198 |
+
|
199 |
+
# okay and then all the leaves have feature_idx=0, num_threshold=0, left=0, right=0
|
200 |
+
# that makes sense
|
201 |
+
# still kind of annoying that they use 0 instead of np.nan but oh well
|
202 |
+
|
203 |
+
# also super super hard to figure out what the labels on the tree map should be
|
204 |
+
# like it has to check the parent's feature_idx and num_threshold
|
205 |
+
# which I guess isn't too bad once we have the list of parents already built
|
206 |
+
# except that I don't know whether a node is left or right from its parent
|
207 |
+
# hmmmm
|
208 |
+
|
209 |
+
|
210 |
+
|
211 |
+
|
212 |
+
|
213 |
+
|
214 |
+
|
215 |
+
|