# hgbdt-viz / viz_classifier.py
# Working version of the streamlit animation
import joblib
import pandas as pd
import plotly.graph_objects as go
import plotly.express as px
# Load the fitted model that was saved at training time.
hgb = joblib.load('hgb_classifier.joblib')
# ('state' was dropped from the features when the model was trained.)
# NOTE(review): the note above and the color notes below presumably belong to
# a feature-name list (FEATS, used further down) and a CSS color list that are
# not visible in this copy of the file -- confirm against the full source.
# Color notes: plotly only accepts CSS named colors, so xkcd colors are out.
# The color list was copied from a CSS color list found online, then trimmed
# of whites and of colors that looked too similar on the tree. It is tuned to
# this specific tree rather than being a general solution; a better colormap
# is still a TODO.

# Collect the raw node array of every boosting iteration. `_predictors` is a
# private scikit-learn attribute; only element [0] of each iteration's entry
# is kept here.
trees = [
    iteration_predictors[0].nodes
    for iteration_predictors in hgb._predictors
]
# The final tree has a similar structure to the first one but is noticeably
# different -- stepping through `trees` should make a nice animation.
# For now only the first tree is visualized (trees[9] was also tried).
tree = pd.DataFrame(data=trees[0])
# Build parent links for the nodes of `tree`.
# parents is going to be tricky: for each node I need to get the index of
# whichever node lists it in either its 'left' or its 'right' column.
parents = [None]
# Keep track of whether each node is a left or right child of its parent.
# Both lists start with None for the root (node 0) and are meant to stay
# index-aligned with each other.
directions = [None]
# The node table uses 0 to mean "no left/right child" (every leaf has
# left=0 and right=0), so searching for node 0 would match every leaf.
# Skip it; that is fine because node 0 is the root and has no parent.
for i in tree.index[1:]:
# The tree seems to be very even, so check the 'right' column first;
# the expectation is that this lookup succeeds about half the time.
# NOTE(review): each lookup scans the whole table (O(n) per node, O(n^2)
# overall); a {child: parent} dict built once from the 'left'/'right'
# columns would avoid the repeated scans.
parent = tree[tree['right']==i].index
# `parent` is an empty Index when no node lists i as its right child.
# NOTE(review): the rest of this loop is missing from this copy --
# presumably it falls back to tree['left'] and appends to `parents` and
# `directions`; confirm against the full file.
if parent.empty:
# Generate the treemap's per-node display labels and colors.
# Entry 0 is the root: it gets a fixed title and a white fill.
# NOTE(review): this copy of the loop is incomplete -- the zip(...) arguments,
# the rest of the `if i == 0:` branch, and whatever uses `offset` are missing.
# `FEATS` is not defined in the visible code either; presumably it is the list
# of training feature names -- confirm against the full file.
labels = ['Histogram Gradient-Boosted Decision Tree']
colors = ['white']
for i, node, parent, direction in zip(
# skip the first one (the root); presumably this is zipped with `parents` and
# `directions`, whose root entries are None, so int(parent) below could not
# work for node 0 -- TODO confirm
if i == 0:
node = node[1]
# Label each child by the split its parent made: look up the parent's split
# feature index and threshold in `tree`.
feat = FEATS[int(tree.loc[int(parent), 'feature_idx'])]
thresh = tree.loc[int(parent), 'num_threshold']
# Values <= num_threshold go to the left child and values > num_threshold go
# to the right child, so the comparison in the label depends on the direction.
# NOTE(review): presumably the "> thresh" append sits under an `else:` that
# was lost from this copy -- confirm.
if direction == 'l':
labels.append(f"[{i}] {feat} <= {thresh}")
labels.append(f"[{i}] {feat} > {thresh}")
# colors: position of the split feature in FEATS, presumably used to pick an
# entry from the color list
offset = FEATS.index(feat)
# actual plot: build the plotly figure
# NOTE(review): the arguments to go.Figure(...) are missing from this copy;
# presumably they pass the labels/parents/colors built above -- confirm.
# The commented-out treemapcolorway line looks like an abandoned styling
# experiment.
f = go.Figure(
# treemapcolorway = ['pink']
# converting the ndarray with column names to a pandas df:
# 3284 bytes as an ndarray
# 3300 bytes as a dataframe
# so they're about the same size
# do I need to convert it to pandas? not sure
# just curious
# https://linuxtut.com/en/ffb2e319db5545965933/
# https://github.com/scikit-learn/scikit-learn/blob/main/sklearn/ensemble/_hist_gradient_boosting/_predictor.pyx
# figuring out how the thing works
# `value` is the predicted class / value / whatever
# so if it's a leaf node, it returns that value as the prediction
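# there are negative values in some of the leaves
# maybe the classes are +/-1 instead of 0/1?
# (probably not: for a classifier each leaf value is likely a raw additive
# score, e.g. log-odds, that is summed over all trees before the link
# function is applied, so negative values are expected -- confirm in
# _predictor.pyx linked above)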
# if the data value is <= `num_threshold` then it goes in the left node
# if it's > `num_threshold` then it goes in the right node
# okay and then all the leaves have feature_idx=0, num_threshold=0, left=0, right=0
# that makes sense
# still kind of annoying that they use 0 instead of np.nan but oh well
# also super super hard to figure out what the labels on the tree map should be
# like it has to check the parent's feature_idx and num_threshold
# which I guess isn't too bad once we have the list of parents already built
# except that I don't know whether a node is left or right from its parent
# hmmmm