# hgbdt-viz / viz_classifier.py
# From commit 045d7d4: "Working version of the streamlit animation"
import joblib
import pandas as pd
import plotly.graph_objects as go
import plotly.express as px
# Load the trained histogram gradient-boosting classifier (presumably a fitted
# scikit-learn HistGradientBoostingClassifier — its `_predictors` are read below).
# joblib.load unpickles the file, so only point it at a file you trust.
hgb = joblib.load(filename='hgb_classifier.joblib')
# Names of the model's input columns. A tree node's `feature_idx` is used as
# an index into this list (and into COLORS), so the order matters; it is
# presumably the column order of the training matrix — TODO confirm.
# NOTE: 'state' (originally between 'proto' and 'dur') was dropped when the
# model was trained, so it is intentionally absent here.
FEATS = """
    srcip sport dstip dsport proto
    dur sbytes dbytes sttl dttl sloss dloss service
    Sload Dload Spkts Dpkts swin dwin stcpb dtcpb smeansz dmeansz
    trans_depth res_bdy_len Sjit Djit Stime Ltime Sintpkt Dintpkt
    tcprtt synack ackdat is_sm_ips_ports ct_state_ttl ct_flw_http_mthd
    is_ftp_login ct_ftp_cmd ct_srv_src ct_srv_dst ct_dst_ltm ct_src_ltm
    ct_src_dport_ltm ct_dst_sport_ltm ct_dst_src_ltm
""".split()
# Plotly only understands the CSS named colours (xkcd colour names are not an
# option), so this palette is a hand-picked list of them.
# COLORS[k] is the colour for nodes created by a split on FEATS[k], so the
# list needs at least len(FEATS) entries and each name's POSITION matters.
# Whites were removed, as were colours that showed up too close together on
# the tree being plotted.
# Every entry must be a distinct colour: CSS has several pairs of names for
# the exact same RGB value ('aqua'/'cyan', 'fuchsia'/'magenta',
# 'slategray'/'slategrey'), and keeping both names of a pair would give two
# different features identical tiles. Only one name of each pair is kept.
# This is tuned for one specific tree, not a general solution; a proper
# colormap would be better at some point.
COLORS = [
    'aliceblue', 'aqua', 'aquamarine', 'azure',
    'bisque', 'black', 'blanchedalmond', 'blue',
    'blueviolet', 'brown', 'burlywood', 'cadetblue',
    'chartreuse', 'chocolate', 'coral', 'cornflowerblue',
    # 'firebrick' replaces 'cyan' (identical to 'aqua' at index 1); swapping
    # it in place keeps every other feature's colour where it was.
    'cornsilk', 'crimson', 'firebrick', 'darkblue', 'darkcyan',
    'darkgoldenrod', 'darkgray', 'darkgreen',
    'darkkhaki', 'darkmagenta', 'darkolivegreen', 'darkorange',
    'darkorchid', 'darkred', 'darksalmon', 'darkseagreen',
    'darkslateblue', 'darkslategray',
    'darkturquoise', 'darkviolet', 'deeppink', 'deepskyblue',
    'dimgray', 'dodgerblue',
    'forestgreen', 'fuchsia', 'gainsboro',
    'gold', 'goldenrod', 'gray', 'green',
    'greenyellow', 'honeydew', 'hotpink', 'indianred', 'indigo',
    'ivory', 'khaki', 'lavender', 'lavenderblush', 'lawngreen',
    'lemonchiffon', 'lightblue', 'lightcoral', 'lightcyan',
    'lightgoldenrodyellow', 'lightgray',
    'lightgreen', 'lightpink', 'lightsalmon', 'lightseagreen',
    'lightskyblue', 'lightslategray',
    'lightsteelblue', 'lightyellow', 'lime', 'limegreen',
    'linen', 'maroon', 'mediumaquamarine',  # 'magenta' dropped: same as 'fuchsia'
    'mediumblue', 'mediumorchid', 'mediumpurple',
    'mediumseagreen', 'mediumslateblue', 'mediumspringgreen',
    'mediumturquoise', 'mediumvioletred', 'midnightblue',
    'mintcream', 'mistyrose', 'moccasin', 'navy',
    'oldlace', 'olive', 'olivedrab', 'orange', 'orangered',
    'orchid', 'palegoldenrod', 'palegreen', 'paleturquoise',
    'palevioletred', 'papayawhip', 'peachpuff', 'peru', 'pink',
    'plum', 'powderblue', 'purple', 'red', 'rosybrown',
    'royalblue', 'saddlebrown', 'salmon', 'sandybrown',
    'seagreen', 'seashell', 'sienna', 'silver', 'skyblue',
    'slateblue', 'slategray', 'snow', 'springgreen',  # 'slategrey' dropped: same as 'slategray'
    'steelblue', 'tan', 'teal', 'thistle', 'tomato', 'turquoise',
    'violet', 'wheat', 'yellow', 'yellowgreen',
]
# Which boosting iteration's tree to draw (replaces toggling a commented-out
# line). The final tree has a similar structure to the first but is
# noticeably different, which should make a nice animation once several
# trees are rendered in sequence.
TREE_INDEX = 0

# One node table per boosting iteration. `_predictors` is a private
# scikit-learn attribute; each entry is indexed with [0] here, which is
# presumably the single predictor a binary classifier has per iteration
# (TODO confirm - a multiclass model would have one predictor per class).
trees = [iteration[0].nodes for iteration in hgb._predictors]

# `.nodes` is an ndarray with named columns, so it converts directly into a
# DataFrame with one column per node field (left, right, count, ...).
tree = pd.DataFrame(trees[TREE_INDEX])
def _parent_links(nodes):
    """Return ``(parents, directions)`` lists aligned with ``nodes.index``.

    ``parents[k]`` is the id of node k's parent, as a string (the form handed
    to Plotly's ``parents=``). ``directions[k]`` is ``'l'`` or ``'r'``
    depending on whether node k is its parent's left or right child. Node 0
    is the root, so both lists start with ``None``.

    The ``left``/``right`` columns use 0 to mean "no child". That is
    unambiguous because node 0 is the root and is never anyone's child.

    Raises KeyError if a non-root node is not referenced as any node's child.
    """
    # Invert the child pointers in one pass over the table. The previous
    # approach scanned the whole table once per node, which was quadratic.
    parent_of = {}
    for node_id, left, right in zip(nodes.index, nodes['left'], nodes['right']):
        if left:
            parent_of[int(left)] = (str(node_id), 'l')
        if right:
            parent_of[int(right)] = (str(node_id), 'r')

    parents = [None]
    directions = [None]
    for child in nodes.index[1:]:
        parent, direction = parent_of[int(child)]
        parents.append(parent)
        directions.append(direction)
    return parents, directions


parents, directions = _parent_links(tree)
# Label and colour every node. The root gets a fixed title and a white tile.
# Every other node is labelled with the condition its parent's split routes
# to it (left child: feature <= threshold, right child: feature > threshold)
# and is coloured by that split's feature.
labels = ['Histogram Gradient-Boosted Decision Tree']
colors = ['white']
for i, parent, direction in zip(tree.index[1:], parents[1:], directions[1:]):
    parent_id = int(parent)
    # The parent's feature_idx indexes both FEATS and COLORS directly, so no
    # FEATS.index() reverse lookup is needed.
    feat_idx = int(tree.loc[parent_id, 'feature_idx'])
    thresh = tree.loc[parent_id, 'num_threshold']
    op = '<=' if direction == 'l' else '>'
    labels.append(f"[{i}] {FEATS[feat_idx]} {op} {thresh}")
    colors.append(COLORS[feat_idx])
# Draw the tree as a treemap: one tile per node, nested inside its parent.
f = go.Figure(
    go.Treemap(
        # A node's `count` is the number of training samples that reached
        # it, so it already includes its whole subtree: a parent's count is
        # the sum of its children's. That is exactly what
        # branchvalues='total' expects. The default, 'remainder', would add
        # each parent's count on top of its children's and inflate deeper
        # subtrees. (Plotly hides a subtree whose children sum to more than
        # the parent, so this relies on counts being consistent.)
        values=tree['count'].to_numpy(),
        branchvalues='total',
        labels=labels,
        # Same string form as the entries of `parents`, so every parent
        # reference matches an id exactly.
        ids=[str(node_id) for node_id in tree.index],
        parents=parents,
        marker_colors=colors,
    )
)
# Deliberate stop: drop into the debugger to inspect the figure (for
# example, run `f.show()`). Set PYTHONBREAKPOINT=0 to run straight through.
breakpoint()
# converting the ndarray with column names to a pandas df:
# 3284 bytes as an ndarray
# 3300 bytes as a dataframe
# so they're about the same size
# do I need to convert it to pandas? idk
# just curious
# https://linuxtut.com/en/ffb2e319db5545965933/
# https://github.com/scikit-learn/scikit-learn/blob/main/sklearn/ensemble/_hist_gradient_boosting/_predictor.pyx
# figuring out how the thing works
# `value` is the predicted class / value / whatever
# so if it's a leaf node, it returns that value as the prediction
# there are negative values in some of the leaves
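# maybe the classes are +/-1 instead of 0/1?
# -> no: in gradient boosting a leaf's value is not a class label, it is that
#    tree's additive contribution to the raw (log-odds) score; the scores of
#    all the trees are summed and only then turned into a class probability,
#    so negative leaf values are expected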
# if the data value is <= `num_threshold` then it goes in the left node
# if it's > `num_threshold` then it goes in the right node
# okay and then all the leaves have feature_idx=0, num_threshold=0, left=0, right=0
# that makes sense
# still kind of annoying that they use 0 instead of np.nan but oh well
# also super super hard to figure out what the labels on the tree map should be
# like it has to check the parent's feature_idx and num_threshold
# which I guess isn't too bad once we have the list of parents already built
# except that I didn't know whether a node is the left or right child of its parent
# hmmmm (solved: `directions` records 'l'/'r' for each node while `parents` is built)