Spaces:
Running
Running
import json | |
import random | |
import string | |
import warnings | |
import numpy as np | |
from . import colors | |
try: | |
from IPython.display import HTML | |
from IPython.display import display as ipython_display | |
have_ipython = True | |
except ImportError: | |
have_ipython = False | |
# TODO: we should support text output explanations (from models that output text not numbers); this would require
# the force plot and the coloring to update based on mouseovers (or clicks to make it fixed) of the output text
def text(
    shap_values,
    num_starting_labels=0,
    grouping_threshold=0.01,
    separator="",
    xmin=None,
    xmax=None,
    cmax=None,
    display=True,
):
    """Plots an explanation of a string of text using coloring and interactive labels.

    The output is interactive HTML and you can click on any token to toggle the display of the
    SHAP value assigned to that token.

    Parameters
    ----------
    shap_values : Explanation
        Explanation-like object exposing ``values``, ``base_values``, ``data``,
        ``output_names`` and ``shape``. Supported layouts:

        * 1-D: a single text with a single output.
        * 2-D with ``output_names`` None or a string: several texts, one plot per row.
        * 2-D with a list of ``output_names``: one text with several outputs (output selector).
        * 3-D: several texts x several outputs (one output selector per row).
    num_starting_labels : int
        Number of tokens (sorted in descending order by corresponding SHAP values)
        that are uncovered in the initial view.
        When set to 0, all tokens are covered.
    grouping_threshold : float
        If the component substring effects are less than a ``grouping_threshold``
        fraction of an unlowered interaction effect, then we visualize the entire group
        as a single chunk. This is primarily used for explanations that were computed
        with fixed_context set to 1 or 0 when using the :class:`.explainers.Partition`
        explainer, since this causes interaction effects to be left on internal nodes
        rather than lowered.
    separator : string
        The string separator that joins tokens grouped by interaction effects and
        unbroken string spans. Defaults to the empty string ``""``.
    xmin : float
        Minimum shap value bound. Computed from the data when None.
    xmax : float
        Maximum shap value bound. Computed from the data when None.
    cmax : float
        Maximum absolute shap value for sample. Used for scaling colors for input tokens.
        Computed from the data when None.
    display: bool
        Whether to display or return html to further manipulate or embed. Default: ``True``

    Returns
    -------
    str or None
        None when ``display`` is True (the HTML is shown through IPython); otherwise the HTML string.

    Examples
    --------
    See `text plot examples <https://shap.readthedocs.io/en/latest/example_notebooks/api_examples/plots/text.html>`_.

    """

    def values_min_max(values, base_values):
        """Used to pick our axis limits (padded by 10% of the span on each side)."""
        fx = base_values + values.sum()
        lo = fx - values[values > 0].sum()
        hi = fx - values[values < 0].sum()
        top = max(abs(values.min()), abs(values.max()))
        d = hi - lo
        lo -= 0.1 * d
        hi += 0.1 * d
        return lo, hi, top

    def combined_bounds(explanations):
        """Return the (xmin, xmax, cmax) that covers every explanation in ``explanations``."""
        lo = hi = top = None
        for expl in explanations:
            vals, clust = unpack_shap_explanation_contents(expl)
            _, vals, _ = process_shap_values(expl.data, vals, grouping_threshold, separator, clust)
            lo_i, hi_i, top_i = values_min_max(vals, expl.base_values)
            if lo is None or lo_i < lo:
                lo = lo_i
            if hi is None or hi_i > hi:
                hi = hi_i
            if top is None or top_i > top:
                top = top_i
        return lo, hi, top

    def resolve_bounds(computed):
        """Keep every bound the caller supplied; fill the ones left as None from ``computed``."""
        lo, hi, top = computed
        return (
            lo if xmin is None else xmin,
            hi if xmax is None else xmax,
            top if cmax is None else cmax,
        )

    def rgba_for(scaled_value):
        """Map a (nominally [0, 1]) scaled value to an ``(r, g, b, a)`` tuple for CSS ``rgba``."""
        color = colors.red_transparent_blue(scaled_value)
        return (float(color[0]) * 255, float(color[1]) * 255, float(color[2]) * 255, float(color[3]))

    def row_header(i):
        """Dashed separator labelled ``[i]`` placed above each row of a multi-row plot."""
        return f"""
<br>
<hr style="height: 1px; background-color: #fff; border: none; margin-top: 18px; margin-bottom: 18px; border-top: 1px dashed #ccc;">
<div align="center" style="margin-top: -35px;"><div style="display: inline-block; background: #fff; padding: 5px; color: #999; font-family: monospace">[{i}]</div>
</div>
"""

    def render(explanation, **bounds):
        """Recursively render a sub-explanation as an HTML string (never displays)."""
        return text(
            explanation,
            num_starting_labels=num_starting_labels,
            grouping_threshold=grouping_threshold,
            separator=separator,
            display=False,
            **bounds,
        )

    def emit(html):
        """Display ``html`` through IPython (returning None) or hand it back to the caller."""
        if display:
            _ipython_display_html(html)
            return None
        return html

    # DOM-id prefix that keeps element ids unique when several plots share one page
    uuid = "".join(random.choices(string.ascii_lowercase, k=20))
    is_2d = len(shap_values.shape) == 2

    # several texts, each with a single output: one plot per row sharing common bounds
    if is_2d and (shap_values.output_names is None or isinstance(shap_values.output_names, str)):
        xmin, xmax, cmax = resolve_bounds(combined_bounds(shap_values))
        out = ""
        for i, row in enumerate(shap_values):
            out += row_header(i)
            out += render(row, xmin=xmin, xmax=xmax, cmax=cmax)
        return emit(out)

    # one text with several named outputs: clickable output selector plus one plot per output
    if is_2d and shap_values.output_names is not None:
        n_outputs = shap_values.shape[-1]
        xmin, xmax, cmax = resolve_bounds(combined_bounds(shap_values[:, i] for i in range(n_outputs)))

        out = f"""<div align='center'>
<script>
    document._hover_{uuid} = '_tp_{uuid}_output_0';
    document._zoom_{uuid} = undefined;
    function _output_onclick_{uuid}(i) {{
        var next_id = undefined;

        if (document._zoom_{uuid} !== undefined) {{
            document.getElementById(document._zoom_{uuid}+ '_zoom').style.display = 'none';

            if (document._zoom_{uuid} === '_tp_{uuid}_output_' + i) {{
                document.getElementById(document._zoom_{uuid}).style.display = 'block';
                document.getElementById(document._zoom_{uuid}+'_name').style.borderBottom = '3px solid #000000';
            }} else {{
                document.getElementById(document._zoom_{uuid}).style.display = 'none';
                document.getElementById(document._zoom_{uuid}+'_name').style.borderBottom = 'none';
            }}
        }}
        if (document._zoom_{uuid} !== '_tp_{uuid}_output_' + i) {{
            next_id = '_tp_{uuid}_output_' + i;
            document.getElementById(next_id).style.display = 'none';
            document.getElementById(next_id + '_zoom').style.display = 'block';
            document.getElementById(next_id+'_name').style.borderBottom = '3px solid #000000';
        }}
        document._zoom_{uuid} = next_id;
    }}
    function _output_onmouseover_{uuid}(i, el) {{
        if (document._zoom_{uuid} !== undefined) {{ return; }}
        if (document._hover_{uuid} !== undefined) {{
            document.getElementById(document._hover_{uuid} + '_name').style.borderBottom = 'none';
            document.getElementById(document._hover_{uuid}).style.display = 'none';
        }}
        document.getElementById('_tp_{uuid}_output_' + i).style.display = 'block';
        el.style.borderBottom = '3px solid #000000';
        document._hover_{uuid} = '_tp_{uuid}_output_' + i;
    }}
</script>
<div style=\"color: rgb(120,120,120); font-size: 12px;\">outputs</div>"""

        # color each output name by its model output relative to the largest |output|
        output_values = shap_values.values.sum(0) + shap_values.base_values
        output_max = np.max(np.abs(output_values))
        for i, name in enumerate(shap_values.output_names):
            color = rgba_for(0.5 + 0.5 * float(output_values[i]) / (float(output_max) + 1e-8))
            border = "3px solid #000000" if i == 0 else "none"
            out += f"""
<div style="display: inline; border-bottom: {border}; background: rgba{color}; border-radius: 3px; padding: 0px" id="_tp_{uuid}_output_{i}_name"
    onclick="_output_onclick_{uuid}({i})"
    onmouseover="_output_onmouseover_{uuid}({i}, this);">{_escape_html_text(str(name))}</div>"""

        out += "<br><br>"
        for i in range(len(shap_values.output_names)):
            shown = "block" if i == 0 else "none"
            # shared-bounds view (shown on hover) ...
            out += f"<div id='_tp_{uuid}_output_{i}' style='display: {shown};'>"
            out += render(shap_values[:, i], xmin=xmin, xmax=xmax, cmax=cmax)
            out += "</div>"
            # ... and a "zoomed" view with per-output bounds (shown on click)
            out += f"<div id='_tp_{uuid}_output_{i}_zoom' style='display: none;'>"
            out += render(shap_values[:, i])
            out += "</div>"
        out += "</div>"
        return emit(out)

    # several texts x several outputs: one output-selector plot per row, shared bounds
    if len(shap_values.shape) == 3:
        n_rows, n_outputs = shap_values.shape[0], shap_values.shape[-1]
        slices = (shap_values[j, :, i] for i in range(n_outputs) for j in range(n_rows))
        xmin, xmax, cmax = resolve_bounds(combined_bounds(slices))
        out = ""
        for i, row in enumerate(shap_values):
            out += row_header(i)
            out += render(row, xmin=xmin, xmax=xmax, cmax=cmax)
        return emit(out)

    # single text, single output
    xmin, xmax, cmax = resolve_bounds(values_min_max(shap_values.values, shap_values.base_values))

    values, clustering = unpack_shap_explanation_contents(shap_values)
    tokens, values, group_sizes = process_shap_values(
        shap_values.data, values, grouping_threshold, separator, clustering
    )

    # indices of the tokens whose value labels start uncovered
    top_inds = set(np.argsort(-np.abs(values))[:num_starting_labels].tolist())

    # tokens are user text: escape them before embedding in HTML/SVG markup
    encoded_tokens = [_escape_html_text(t).replace(" ##", "") for t in tokens]
    output_name = (
        _escape_html_text(shap_values.output_names) if isinstance(shap_values.output_names, str) else ""
    )
    out = svg_force_plot(
        values,
        shap_values.base_values,
        shap_values.base_values + values.sum(),
        encoded_tokens,
        uuid,
        xmin,
        xmax,
        output_name,
    )
    out += (
        "<div align='center'><div style=\"color: rgb(120,120,120); font-size: 12px; margin-top: -15px;\">inputs</div>"
    )

    for i in range(len(tokens)):
        color = rgba_for(0.5 + 0.5 * values[i] / (cmax + 1e-8))

        if i in top_inds:
            label_display, wrapper_display = "block", "inline-block"
        else:
            label_display, wrapper_display = "none", "inline"

        value_label = str(values[i].round(3))
        if group_sizes[i] != 1:
            value_label += " / " + str(group_sizes[i])

        # NOTE: the "</div\n><div" line breaks sit inside the closing tags on purpose, so no
        # whitespace text node separates the label div from the token div (the onclick handler
        # reaches the label through this.previousSibling).
        out += f"""<div style='display: {wrapper_display}; text-align: center;'
    ><div style='display: {label_display}; color: #999; padding-top: 0px; font-size: 12px;'>{value_label}</div
        ><div id='_tp_{uuid}_ind_{i}'
            style='display: inline; background: rgba{color}; border-radius: 3px; padding: 0px'
            onclick="
            if (this.previousSibling.style.display == 'none') {{
                this.previousSibling.style.display = 'block';
                this.parentNode.style.display = 'inline-block';
            }} else {{
                this.previousSibling.style.display = 'none';
                this.parentNode.style.display = 'inline';
            }}"
            onmouseover="document.getElementById('_fb_{uuid}_ind_{i}').style.opacity = 1; document.getElementById('_fs_{uuid}_ind_{i}').style.opacity = 1;"
            onmouseout="document.getElementById('_fb_{uuid}_ind_{i}').style.opacity = 0; document.getElementById('_fs_{uuid}_ind_{i}').style.opacity = 0;"
        >{encoded_tokens[i]}</div></div>"""

    out += "</div>"
    return emit(out)


def _escape_html_text(s):
    """Escape ``&``, ``<`` and ``>`` so ``s`` renders literally as HTML/SVG element content."""
    return s.replace("&", "&amp;").replace("<", "&lt;").replace(">", "&gt;")
def process_shap_values(tokens, values, grouping_threshold, separator, clustering=None, return_meta_data=False):
    """Collapse hierarchical (partition-tree) attributions into the token groups to display.

    When ``values`` has one entry per token it is returned unchanged. Otherwise ``values``
    holds one entry per node of the partition tree described by ``clustering``: leaves
    ``0..M-1`` are the tokens, and row ``k`` of ``clustering`` joins nodes
    ``clustering[k, 0]`` and ``clustering[k, 1]`` into node ``M + k``. The last entry of
    ``values`` is treated as the root. Subtrees whose interaction effect dominates their
    children (by a factor of ``1 / grouping_threshold``) are merged into a single token.

    Both tree passes use explicit stacks instead of recursion, so deep (e.g. chain-shaped)
    partition trees over long texts do not hit Python's recursion limit.

    Parameters
    ----------
    tokens : sequence of str
        The M input tokens.
    values : numpy.ndarray
        Either M per-token values, or one value per partition-tree node.
    grouping_threshold : float
        A subtree is merged when the largest per-token effect below it is smaller than
        ``grouping_threshold`` times its own per-token interaction effect.
    separator : str
        String used to join the tokens of a merged group.
    clustering : numpy.ndarray, optional
        Partition tree; required when ``len(values) != len(tokens)``.
    return_meta_data : bool
        Also return the token-to-node mapping and the ids of the displayed nodes.

    Returns
    -------
    tuple
        ``(tokens, values, group_sizes)``, extended with
        ``(token_id_to_node_id_mapping, collapsed_node_ids)`` when ``return_meta_data`` is True.

    Raises
    ------
    ValueError
        If the values are hierarchical but ``clustering`` is None.
    """
    M = len(tokens)

    if len(values) == M:
        # flat attributions: one value per token, nothing to regroup
        group_sizes = np.ones(M)
        token_id_to_node_id_mapping = np.arange(M)
        collapsed_node_ids = np.arange(M)
    else:
        # make sure we were given a partition tree
        if clustering is None:
            raise ValueError(
                "The length of the attribution values must match the number of "
                "tokens if shap_values.clustering is None! When passing hierarchical "
                "attributions the clustering is also required."
            )

        def children(node):
            """Left and right child ids of internal node ``node``."""
            return int(clustering[node - M, 0]), int(clustering[node - M, 1])

        # bottom-up pass: token membership, summed dividends below each node (lower_values)
        # and the largest per-token effect anywhere in each subtree (max_values)
        groups = [[i] for i in range(M)]
        lower_values = np.zeros(len(values))
        lower_values[:M] = values[:M]
        max_values = np.zeros(len(values))
        max_values[:M] = np.abs(values[:M])
        for k in range(clustering.shape[0]):
            node = M + k
            li, ri = children(node)
            groups.append(groups[li] + groups[ri])
            lower_values[node] = lower_values[li] + lower_values[ri] + values[node]
            max_values[node] = max(abs(values[node]) / len(groups[node]), max_values[li], max_values[ri])

        root = len(values) - 1

        # top-down pass: each child inherits half of the dividends accumulated above it
        upper_values = np.zeros(len(values))
        stack = [(root, 0)]
        while stack:
            node, inherited = stack.pop()
            upper_values[node] = inherited
            if node < M:
                continue
            inherited += values[node]
            li, ri = children(node)
            stack.append((li, inherited * 0.5))
            stack.append((ri, inherited * 0.5))

        # the group_values comes from the dividends above them and below them
        group_values = lower_values + upper_values

        # merge all the tokens in groups dominated by interaction effects (since we don't want to hide those)
        new_tokens = []
        new_values = []
        group_sizes = []
        token_id_to_node_id_mapping = np.zeros((M,))
        collapsed_node_ids = []

        # pre-order traversal, left child before right, so output order matches token order
        stack = [len(group_values) - 1]
        while stack:
            node = stack.pop()
            if 0 <= node < M:
                # leaves are always emitted as-is
                new_tokens.append(tokens[node])
                new_values.append(group_values[node])
                group_sizes.append(1)
                collapsed_node_ids.append(node)
                token_id_to_node_id_mapping[node] = node
                continue

            li, ri = children(node)
            dv = abs(values[node]) / len(groups[node])  # per-token dividend at this node
            if max(max_values[li], max_values[ri]) < dv * grouping_threshold:
                # the interaction level is too high: treat this whole group as one token
                new_tokens.append(separator.join(tokens[g] for g in groups[node]))
                new_values.append(group_values[node])
                group_sizes.append(len(groups[node]))
                collapsed_node_ids.append(node)
                token_id_to_node_id_mapping[groups[node]] = node
            else:
                # push right first so the left subtree is processed first
                stack.append(ri)
                stack.append(li)

        # replace the incoming parameters with the grouped versions
        tokens = np.array(new_tokens)
        values = np.array(new_values)
        group_sizes = np.array(group_sizes)
        collapsed_node_ids = np.array(collapsed_node_ids)

    if return_meta_data:
        return tokens, values, group_sizes, token_id_to_node_id_mapping, collapsed_node_ids
    return tokens, values, group_sizes
def svg_force_plot(values, base_values, fx, tokens, uuid, xmin, xmax, output_name):
    """Render the SVG force plot shown above the colored tokens of a text plot.

    Positive values are stacked as red segments ending at ``fx`` from the left, and negative
    values as blue segments starting at ``fx`` to the right. Each segment carries its token
    label and hover handlers that toggle the elements with ids ``_tp_{uuid}_ind_{i}``,
    ``_fs_{uuid}_ind_{i}`` and ``_fb_{uuid}_ind_{i}``.

    Parameters
    ----------
    values : numpy.ndarray
        Attribution value of each entry of ``tokens``.
    base_values : float
        Expected model output; drawn as the "base value" tick.
    fx : float
        Model output for this input; drawn as the bold ``f(inputs)`` tick.
    tokens : sequence of str
        Segment labels. They are inserted into the SVG verbatim, so the caller is expected
        to have HTML-escaped them.
    uuid : str
        Element-id prefix shared with the surrounding HTML.
    xmin, xmax : float
        Axis limits in model-output units.
    output_name : str
        Subscript of the ``f(inputs)`` label (inserted verbatim).

    Returns
    -------
    str
        The ``<svg>`` markup.
    """

    def xpos(xval):
        """Horizontal position of ``xval`` as a percentage of the plot width."""
        return 100 * (xval - xmin) / (xmax - xmin + 1e-8)

    def draw_tick_mark(xval, label=None, bold=False, backing=False):
        """Axis tick at ``xval`` with its numeric value and an optional caption above it."""
        t = ""
        t += f'<line x1="{xpos(xval)}%" y1="33" x2="{xpos(xval)}%" y2="37" style="stroke:rgb(150,150,150);stroke-width:1" />'
        if not bold:
            if backing:
                t += f'<text x="{xpos(xval)}%" y="27" font-size="13px" style="stroke:#ffffff;stroke-width:8px;" fill="rgb(255,255,255)" dominant-baseline="bottom" text-anchor="middle">{xval:g}</text>'
            t += f'<text x="{xpos(xval)}%" y="27" font-size="12px" fill="rgb(120,120,120)" dominant-baseline="bottom" text-anchor="middle">{xval:g}</text>'
        else:
            if backing:
                t += f'<text x="{xpos(xval)}%" y="27" font-size="13px" style="stroke:#ffffff;stroke-width:8px;" font-weight="bold" fill="rgb(255,255,255)" dominant-baseline="bottom" text-anchor="middle">{xval:g}</text>'
            t += f'<text x="{xpos(xval)}%" y="27" font-size="13px" font-weight="bold" fill="rgb(0,0,0)" dominant-baseline="bottom" text-anchor="middle">{xval:g}</text>'
        if label is not None:
            t += f'<text x="{xpos(xval)}%" y="10" font-size="12px" fill="rgb(120,120,120)" dominant-baseline="bottom" text-anchor="middle">{label}</text>'
        return t

    def hover_rect(x, width, ind):
        """Invisible rectangle that highlights token ``ind`` while the mouse is over it."""
        r = f'<rect x="{x}%" y="40" height="20" width="{width}%"'
        r += ' onmouseover="'
        r += f"document.getElementById('_tp_{uuid}_ind_{ind}').style.textDecoration = 'underline';"
        r += f"document.getElementById('_fs_{uuid}_ind_{ind}').style.opacity = 1;"
        r += f"document.getElementById('_fb_{uuid}_ind_{ind}').style.opacity = 1;"
        r += '"'
        r += ' onmouseout="'
        r += f"document.getElementById('_tp_{uuid}_ind_{ind}').style.textDecoration = 'none';"
        r += f"document.getElementById('_fs_{uuid}_ind_{ind}').style.opacity = 0;"
        r += f"document.getElementById('_fb_{uuid}_ind_{ind}').style.opacity = 0;"
        r += '" style="fill:rgb(0,0,0,0)" />'
        return r

    s = ""
    s += '<svg width="100%" height="80px">'

    ### x-axis marks ###
    s += '<line x1="0" y1="33" x2="100%" y2="33" style="stroke:rgb(150,150,150);stroke-width:1" />'

    # round the centre and tick spacing to ~2 significant digits of the axis span
    digits = int(round(1 - np.log10(xmax - xmin + 1e-8)))
    xcenter = round((xmax + xmin) / 2, digits)
    s += draw_tick_mark(xcenter)

    tick_interval = round((xmax - xmin) / 7, digits)
    if tick_interval <= 0:
        # coarse rounding can collapse the spacing to 0 (e.g. a span of ~3.2-3.5 rounds
        # span / 7 ~= 0.46 to zero digits); retry with one more digit before giving up
        tick_interval = round((xmax - xmin) / 7, digits + 1)
    side_buffer = (xmax - xmin) / 14
    if tick_interval > 0:
        # a zero spacing would stack every side tick on top of the centre tick
        for i in range(1, 10):
            pos = xcenter - i * tick_interval
            if pos < xmin + side_buffer:
                break
            s += draw_tick_mark(pos)
        for i in range(1, 10):
            pos = xcenter + i * tick_interval
            if pos > xmax - side_buffer:
                break
            s += draw_tick_mark(pos)

    s += draw_tick_mark(base_values, label="base value", backing=True)
    s += draw_tick_mark(
        fx, bold=True, label=f'f<tspan baseline-shift="sub" font-size="8px">{output_name}</tspan>(inputs)', backing=True
    )

    ### Positive value marks ###
    red = (float(colors.red_rgb[0]) * 255, float(colors.red_rgb[1]) * 255, float(colors.red_rgb[2]) * 255)
    light_red = (255, 195, 213)

    # draw base red bar, ending at fx
    x = fx - values[values > 0].sum()
    w = 100 * values[values > 0].sum() / (xmax - xmin + 1e-8)
    s += f'<rect x="{xpos(x)}%" width="{w}%" y="40" height="18" style="fill:rgb{red}; stroke-width:0; stroke:rgb(0,0,0)" />'

    # underline marks and text labels, largest |value| nearest fx
    pos = fx
    last_pos = pos
    inds = [i for i in np.argsort(-np.abs(values)) if values[i] > 0]
    for ind in inds:
        pos -= values[ind]

        # a line under the bar to animate
        s += f'<line x1="{xpos(pos)}%" x2="{xpos(last_pos)}%" y1="60" y2="60" id="_fb_{uuid}_ind_{ind}" style="stroke:rgb{red};stroke-width:2; opacity: 0"/>'

        # the value text shown on hover
        s += f'<text x="{(xpos(last_pos) + xpos(pos)) / 2}%" y="71" font-size="12px" id="_fs_{uuid}_ind_{ind}" fill="rgb{red}" style="opacity: 0" dominant-baseline="middle" text-anchor="middle">{values[ind].round(3)}</text>'

        # the token label, cropped to the segment and centered
        s += f'<svg x="{xpos(pos)}%" y="40" height="20" width="{xpos(last_pos) - xpos(pos)}%">'
        s += ' <svg x="0" y="0" width="100%" height="100%">'
        s += f' <text x="50%" y="9" font-size="12px" fill="rgb(255,255,255)" dominant-baseline="middle" text-anchor="middle">{tokens[ind].strip()}</text>'
        s += " </svg>"
        s += "</svg>"

        last_pos = pos

    # divider padding (covers the label text near the dividers); last_pos is deliberately
    # not reset: it is only read for i != 0, after this loop has updated it
    pos = fx
    for i, ind in enumerate(inds):
        pos -= values[ind]

        if i != 0:
            for j in range(4):
                s += f'<g transform="translate({2 * j - 8},0)">'
                s += f' <svg x="{xpos(last_pos)}%" y="40" height="18" overflow="visible" width="30">'
                s += f' <path d="M 0 -9 l 6 18 L 0 25" fill="none" style="stroke:rgb{red};stroke-width:2" />'
                s += " </svg>"
                s += "</g>"

        if i + 1 != len(inds):
            for j in range(4):
                s += f'<g transform="translate({2 * j},0)">'
                s += f' <svg x="{xpos(pos)}%" y="40" height="18" overflow="visible" width="30">'
                s += f' <path d="M 0 -9 l 6 18 L 0 25" fill="none" style="stroke:rgb{red};stroke-width:2" />'
                s += " </svg>"
                s += "</g>"

        last_pos = pos

    # center padding
    s += f'<rect transform="translate(-8,0)" x="{xpos(fx)}%" y="40" width="8" height="18" style="fill:rgb{red}"/>'

    # cover up a notch at the end of the red bar
    pos = fx - values[values > 0].sum()
    s += '<g transform="translate(-11.5,0)">'
    s += f' <svg x="{xpos(pos)}%" y="40" height="18" overflow="visible" width="30">'
    s += ' <path d="M 10 -9 l 6 18 L 10 25 L 0 25 L 0 -9" fill="#ffffff" style="stroke:rgb(255,255,255);stroke-width:2" />'
    s += " </svg>"
    s += "</g>"

    # light red divider lines and a rect to handle mouseover events
    pos = fx
    last_pos = pos
    for i, ind in enumerate(inds):
        pos -= values[ind]

        if i + 1 != len(inds):
            s += '<g transform="translate(-1.5,0)">'
            s += f' <svg x="{xpos(last_pos)}%" y="40" height="18" overflow="visible" width="30">'
            s += f' <path d="M 0 -9 l 6 18 L 0 25" fill="none" style="stroke:rgb{light_red};stroke-width:2" />'
            s += " </svg>"
            s += "</g>"

        s += hover_rect(xpos(pos), xpos(last_pos) - xpos(pos), ind)

        last_pos = pos

    ### Negative value marks ###
    blue = (float(colors.blue_rgb[0]) * 255, float(colors.blue_rgb[1]) * 255, float(colors.blue_rgb[2]) * 255)
    light_blue = (208, 230, 250)

    # draw base blue bar, starting at fx
    w = 100 * -values[values < 0].sum() / (xmax - xmin + 1e-8)
    s += f'<rect x="{xpos(fx)}%" width="{w}%" y="40" height="18" style="fill:rgb{blue}; stroke-width:0; stroke:rgb(0,0,0)" />'

    # underline marks and text labels, largest |value| nearest fx
    pos = fx
    last_pos = pos
    inds = [i for i in np.argsort(-np.abs(values)) if values[i] < 0]
    for ind in inds:
        pos -= values[ind]

        # a line under the bar to animate
        s += f'<line x1="{xpos(last_pos)}%" x2="{xpos(pos)}%" y1="60" y2="60" id="_fb_{uuid}_ind_{ind}" style="stroke:rgb{blue};stroke-width:2; opacity: 0"/>'

        # the value text shown on hover
        s += f'<text x="{(xpos(last_pos) + xpos(pos)) / 2}%" y="71" font-size="12px" fill="rgb{blue}" id="_fs_{uuid}_ind_{ind}" style="opacity: 0" dominant-baseline="middle" text-anchor="middle">{values[ind].round(3)}</text>'

        # the token label, cropped to the segment and centered
        s += f'<svg x="{xpos(last_pos)}%" y="40" height="20" width="{xpos(pos) - xpos(last_pos)}%">'
        s += ' <svg x="0" y="0" width="100%" height="100%">'
        s += f' <text x="50%" y="9" font-size="12px" fill="rgb(255,255,255)" dominant-baseline="middle" text-anchor="middle">{tokens[ind].strip()}</text>'
        s += " </svg>"
        s += "</svg>"

        last_pos = pos

    # divider padding (covers the label text near the dividers)
    pos = fx
    for i, ind in enumerate(inds):
        pos -= values[ind]

        if i != 0:
            for j in range(4):
                s += f'<g transform="translate({-2 * j + 2},0)">'
                s += f' <svg x="{xpos(last_pos)}%" y="40" height="18" overflow="visible" width="30">'
                s += f' <path d="M 8 -9 l -6 18 L 8 25" fill="none" style="stroke:rgb{blue};stroke-width:2" />'
                s += " </svg>"
                s += "</g>"

        if i + 1 != len(inds):
            for j in range(4):
                s += f'<g transform="translate(-{2 * j + 8},0)">'
                s += f' <svg x="{xpos(pos)}%" y="40" height="18" overflow="visible" width="30">'
                s += f' <path d="M 8 -9 l -6 18 L 8 25" fill="none" style="stroke:rgb{blue};stroke-width:2" />'
                s += " </svg>"
                s += "</g>"

        last_pos = pos

    # center padding
    s += f'<rect transform="translate(0,0)" x="{xpos(fx)}%" y="40" width="8" height="18" style="fill:rgb{blue}"/>'

    # cover up a notch at the end of the blue bar
    pos = fx - values[values < 0].sum()
    s += '<g transform="translate(-6.0,0)">'
    s += f' <svg x="{xpos(pos)}%" y="40" height="18" overflow="visible" width="30">'
    s += ' <path d="M 8 -9 l -6 18 L 8 25 L 20 25 L 20 -9" fill="#ffffff" style="stroke:rgb(255,255,255);stroke-width:2" />'
    s += " </svg>"
    s += "</g>"

    # light blue divider lines and a rect to handle mouseover events
    pos = fx
    last_pos = pos
    for i, ind in enumerate(inds):
        pos -= values[ind]

        if i + 1 != len(inds):
            s += '<g transform="translate(-6.0,0)">'
            s += f' <svg x="{xpos(pos)}%" y="40" height="18" overflow="visible" width="30">'
            s += f' <path d="M 8 -9 l -6 18 L 8 25" fill="none" style="stroke:rgb{light_blue};stroke-width:2" />'
            s += " </svg>"
            s += "</g>"

        s += hover_rect(xpos(last_pos), xpos(pos) - xpos(last_pos), ind)

        last_pos = pos

    s += "</svg>"

    return s
def text_old(shap_values, tokens, partition_tree=None, num_starting_labels=0, grouping_threshold=1, separator=""): | |
"""Plots an explanation of a string of text using coloring and interactive labels. | |
The output is interactive HTML and you can click on any token to toggle the display of the | |
SHAP value assigned to that token. | |
""" | |
# See if we got hierarchical input data. If we did then we need to reprocess the | |
# shap_values and tokens to get the groups we want to display | |
warnings.warn( | |
"This function is not used within the shap library and will therefore be removed in an upcoming release. " | |
"If you rely on this function, please open an issue: https://github.com/shap/shap/issues.", | |
FutureWarning, | |
) | |
M = len(tokens) | |
if len(shap_values) != M: | |
# make sure we were given a partition tree | |
if partition_tree is None: | |
raise ValueError( | |
"The length of the attribution values must match the number of " | |
"tokens if partition_tree is None! When passing hierarchical " | |
"attributions the partition_tree is also required." | |
) | |
# compute the groups, lower_values, and max_values | |
groups = [[i] for i in range(M)] | |
lower_values = np.zeros(len(shap_values)) | |
lower_values[:M] = shap_values[:M] | |
max_values = np.zeros(len(shap_values)) | |
max_values[:M] = np.abs(shap_values[:M]) | |
for i in range(partition_tree.shape[0]): | |
li = partition_tree[i, 0] | |
ri = partition_tree[i, 1] | |
groups.append(groups[li] + groups[ri]) | |
lower_values[M + i] = lower_values[li] + lower_values[ri] + shap_values[M + i] | |
max_values[i + M] = max(abs(shap_values[M + i]) / len(groups[M + i]), max_values[li], max_values[ri]) | |
# compute the upper_values | |
upper_values = np.zeros(len(shap_values)) | |
def lower_credit(upper_values, partition_tree, i, value=0): | |
if i < M: | |
upper_values[i] = value | |
return | |
li = partition_tree[i - M, 0] | |
ri = partition_tree[i - M, 1] | |
upper_values[i] = value | |
value += shap_values[i] | |
lower_credit(upper_values, partition_tree, li, value * 0.5) | |
lower_credit(upper_values, partition_tree, ri, value * 0.5) | |
lower_credit(upper_values, partition_tree, len(shap_values) - 1) | |
# the group_values comes from the dividends above them and below them | |
group_values = lower_values + upper_values | |
# merge all the tokens in groups dominated by interaction effects (since we don't want to hide those) | |
new_tokens = [] | |
new_shap_values = [] | |
group_sizes = [] | |
def merge_tokens(new_tokens, new_values, group_sizes, i): | |
# return at the leaves | |
if i < M and i >= 0: | |
new_tokens.append(tokens[i]) | |
new_values.append(group_values[i]) | |
group_sizes.append(1) | |
else: | |
# compute the dividend at internal nodes | |
li = partition_tree[i - M, 0] | |
ri = partition_tree[i - M, 1] | |
dv = abs(shap_values[i]) / len(groups[i]) | |
# if the interaction level is too high then just treat this whole group as one token | |
if dv > grouping_threshold * max(max_values[li], max_values[ri]): | |
new_tokens.append( | |
separator.join([tokens[g] for g in groups[li]]) | |
+ separator | |
+ separator.join([tokens[g] for g in groups[ri]]) | |
) | |
new_values.append(group_values[i] / len(groups[i])) | |
group_sizes.append(len(groups[i])) | |
# if interaction level is not too high we recurse | |
else: | |
merge_tokens(new_tokens, new_values, group_sizes, li) | |
merge_tokens(new_tokens, new_values, group_sizes, ri) | |
merge_tokens(new_tokens, new_shap_values, group_sizes, len(group_values) - 1) | |
# replance the incoming parameters with the grouped versions | |
tokens = np.array(new_tokens) | |
shap_values = np.array(new_shap_values) | |
group_sizes = np.array(group_sizes) | |
M = len(tokens) | |
else: | |
group_sizes = np.ones(M) | |
# build out HTML output one word one at a time | |
top_inds = np.argsort(-np.abs(shap_values))[:num_starting_labels] | |
maxv = shap_values.max() | |
minv = shap_values.min() | |
out = "" | |
for i in range(M): | |
scaled_value = 0.5 + 0.5 * shap_values[i] / max(abs(maxv), abs(minv)) | |
color = colors.red_transparent_blue(scaled_value) | |
color = (float(color[0]) * 255, float(color[1]) * 255, float(color[2]) * 255, float(color[3])) | |
# display the labels for the most important words | |
label_display = "none" | |
wrapper_display = "inline" | |
if i in top_inds: | |
label_display = "block" | |
wrapper_display = "inline-block" | |
# create the value_label string | |
value_label = "" | |
if group_sizes[i] == 1: | |
value_label = str(shap_values[i].round(3)) | |
else: | |
value_label = str((shap_values[i] * group_sizes[i]).round(3)) + " / " + str(group_sizes[i]) | |
# the HTML for this token | |
out += ( | |
"<div style='display: " | |
+ wrapper_display | |
+ "; text-align: center;'>" | |
+ "<div style='display: " | |
+ label_display | |
+ "; color: #999; padding-top: 0px; font-size: 12px;'>" | |
+ value_label | |
+ "</div>" | |
+ "<div " | |
+ "style='display: inline; background: rgba" | |
+ str(color) | |
+ "; border-radius: 3px; padding: 0px'" | |
+ "onclick=\"if (this.previousSibling.style.display == 'none') {" | |
+ "this.previousSibling.style.display = 'block';" | |
+ "this.parentNode.style.display = 'inline-block';" | |
+ "} else {" | |
+ "this.previousSibling.style.display = 'none';" | |
+ "this.parentNode.style.display = 'inline';" | |
+ "}" | |
+ '"' | |
+ ">" | |
+ tokens[i].replace("<", "<").replace(">", ">").replace(" ##", "") | |
+ "</div>" | |
+ "</div>" | |
) | |
return _ipython_display_html(out) | |
def text_to_text(shap_values):
    """Display an interactive visualization of a text-to-text explanation.

    Two views are rendered -- a saliency-plot table and an input/output heatmap --
    plus a drop-down that switches between them. The heatmap view is visible
    initially and the saliency plot starts hidden. Nothing is returned; the markup
    is handed to IPython for display.

    Parameters
    ----------
    shap_values : Explanation
        The text-to-text explanation, forwarded unchanged to ``saliency_plot`` and
        ``heatmap``.
    """
    # random prefix for element ids and the JavaScript function name, so several
    # visualizations rendered into the same page do not collide
    viz_id = "".join(random.choices(string.ascii_lowercase, k=20))

    saliency_markup = saliency_plot(shap_values)
    heatmap_markup = heatmap(shap_values)

    markup = f"""
    <html>
    <div id="{viz_id}_viz_container">
      <div id="{viz_id}_viz_header" style="padding:15px;border-style:solid;margin:5px;font-family:sans-serif;font-weight:bold;">
        Visualization Type:
        <select name="viz_type" id="{viz_id}_viz_type" onchange="selectVizType_{viz_id}(this)">
          <option value="heatmap" selected="selected">Input/Output - Heatmap</option>
          <option value="saliency-plot">Saliency Plot</option>
        </select>
      </div>
      <div id="{viz_id}_content" style="padding:15px;border-style:solid;margin:5px;">
          <div id = "{viz_id}_saliency_plot_container" class="{viz_id}_viz_container" style="display:none">
              {saliency_markup}
          </div>

          <div id = "{viz_id}_heatmap_container" class="{viz_id}_viz_container">
              {heatmap_markup}
          </div>
      </div>
    </div>
    </html>
    """

    # the drop-down handler hides every view, then reveals the selected one
    script = f"""
    <script>
        function selectVizType_{viz_id}(selectObject) {{

          /* Hide all viz */

            var elements = document.getElementsByClassName("{viz_id}_viz_container")
          for (var i = 0; i < elements.length; i++){{
              elements[i].style.display = 'none';
          }}

          var value = selectObject.value;
          if ( value === "saliency-plot" ){{
              document.getElementById('{viz_id}_saliency_plot_container').style.display  = "block";
          }}
          else if ( value === "heatmap" ) {{
              document.getElementById('{viz_id}_heatmap_container').style.display  = "block";
          }}
        }}
    </script>
    """

    _ipython_display_html(script + markup)
def saliency_plot(shap_values):
    """Build an HTML table showing how each input token contributes to each output token.

    Input tokens are first grouped with ``process_shap_values`` (using the values of
    the first output token). The per-token SHAP values inside each group are then
    summed. The resulting table has one column per (possibly merged) input token and
    one row per output token. Each cell shows the summed value, colored on a
    symmetric scale shared by all cells.

    Parameters
    ----------
    shap_values : Explanation
        Text-to-text explanation with ``data`` (input tokens), ``output_names``
        (output tokens) and ``values`` indexed as ``[input_token, output_token]``.

    Returns
    -------
    str
        The table markup wrapped in a ``<div>`` whose id carries a random prefix.
    """
    from html import escape

    uuid = "".join(random.choices(string.ascii_lowercase, k=20))

    def clean_token(token):
        # Escape "&", "<" and ">" so tokens such as "<pad>" are shown rather than parsed
        # as markup, then strip common subword-tokenizer markers.
        return escape(str(token), quote=False).replace(" ##", "").replace("▁", "").replace("Ġ", "")

    # group the input tokens; only the tokens and group sizes are needed here
    unpacked_values, clustering = unpack_shap_explanation_contents(shap_values)
    tokens, _, group_sizes, _, _ = process_shap_values(
        shap_values.data, unpacked_values[:, 0], 1, "", clustering, True
    )
    # The sizes are used as slice bounds below, so they must be integers.
    # NOTE(review): process_shap_values may return float sizes (e.g. np.ones) -- cast defensively.
    group_sizes = np.asarray(group_sizes, dtype=int)

    # Sum the attributions of the tokens in each group: one row per displayed input
    # token, one column per output token. Groups are consecutive runs of tokens.
    shap_matrix = shap_values.values
    group_ends = np.cumsum(group_sizes)
    group_starts = group_ends - group_sizes
    compressed_shap_matrix = np.zeros((group_sizes.shape[0], shap_matrix.shape[1]))
    for index, (start, end) in enumerate(zip(group_starts, group_ends)):
        compressed_shap_matrix[index, :] = np.sum(shap_matrix[start:end, :], axis=0)

    # symmetric color scale shared by all cells
    cmax = max(abs(compressed_shap_matrix.min()), abs(compressed_shap_matrix.max()))
    if cmax == 0:
        # every value is zero: avoid 0/0 (NaN) and map everything to the scale's center
        cmax = 1

    def cell_color(value):
        color = colors.red_transparent_blue(0.5 + 0.5 * value / cmax)
        return "rgba" + str((float(color[0]) * 255, float(color[1]) * 255, float(color[2]) * 255, float(color[3])))

    model_output = shap_values.output_names
    num_inputs, num_outputs = compressed_shap_matrix.shape

    parts = ['<table border = "1" cellpadding = "5" cellspacing = "5" style="overflow-x:scroll;display:block;">']
    # header row: an empty corner cell, then one column per (possibly merged) input token
    parts.append("<tr><th></th>")
    for j in range(num_inputs):
        parts.append("<th>" + clean_token(tokens[j]) + "</th>")
    parts.append("</tr>")

    # one row per output token
    for row_index in range(num_outputs):
        parts.append("<tr>")
        parts.append("<th>" + clean_token(model_output[row_index]) + "</th>")
        for col_index in range(num_inputs):
            value = compressed_shap_matrix[col_index, row_index]
            parts.append('<th style="background:' + cell_color(value) + '">' + str(round(value, 3)) + "</th>")
        parts.append("</tr>")
    parts.append("</table>")
    out = "".join(parts)

    # columns are input tokens and rows are output tokens
    saliency_plot_html = f"""
        <div id="{uuid}_saliency_plot" class="{uuid}_viz_content">
            <div style="margin:5px;font-family:sans-serif;font-weight:bold;">
                <span style="font-size: 20px;"> Saliency Plot </span>
                <br>
                x-axis: Input Text
                <br>
                y-axis: Output Text
            </div>
                {out}
        </div>
    """
    return saliency_plot_html
def heatmap(shap_values):
    """Build the interactive input/output heatmap markup for a text-to-text explanation.

    Every output token and every leaf input token becomes a hoverable, clickable
    element. Hovering over or clicking a token colors the tokens on the opposite side
    by their SHAP values. The colors and value labels are precomputed here and
    embedded as JSON, so the inline JavaScript only has to look them up by element id.

    Parameters
    ----------
    shap_values : Explanation
        Text-to-text explanation. ``data`` holds the input tokens, ``output_names``
        the output tokens, and ``values`` is indexed as ``[input_token, output_token]``.
        ``clustering`` is the list of merged node pairs used to nest the input tokens;
        ``None`` is treated as "no merges".

    Returns
    -------
    str
        The heatmap HTML followed by the ``<script>`` blocks that drive it.
    """
    from html import escape

    # constants
    TREE_NODE_KEY_TOKENS = "tokens"
    TREE_NODE_KEY_CHILDREN = "children"

    # random prefix for element ids and JavaScript names so several heatmaps on the
    # same page do not interfere with each other
    uuid = "".join(random.choices(string.ascii_lowercase, k=20))

    def clean_token(token):
        """Escape HTML-special characters, then strip common subword-tokenizer markers."""
        return escape(str(token), quote=False).replace(" ##", "").replace("▁", "").replace("Ġ", "")

    def get_color(shap_value, cmax):
        """Map a value in [-cmax, cmax] to an RGBA tuple (RGB scaled to 0-255, alpha as is)."""
        scaled_value = 0.5 + 0.5 * shap_value / cmax
        color = colors.red_transparent_blue(scaled_value)
        return (float(color[0]) * 255, float(color[1]) * 255, float(color[2]) * 255, float(color[3]))

    def process_text_to_text_shap_values(shap_values):
        """Group the input tokens separately for every output token.

        Returns the per-output grouping results and the largest absolute grouped value.
        """
        processed_values = []
        unpacked_values, clustering = unpack_shap_explanation_contents(shap_values)
        max_val = 0
        for index, _ in enumerate(shap_values.output_names):
            tokens, values, group_sizes, token_id_to_node_id_mapping, collapsed_node_ids = process_shap_values(
                shap_values.data, unpacked_values[:, index], 1, "", clustering, True
            )
            processed_values.append(
                {
                    "tokens": tokens,
                    "values": values,
                    "group_sizes": group_sizes,
                    "token_id_to_node_id_mapping": token_id_to_node_id_mapping,
                    "collapsed_node_ids": collapsed_node_ids,
                }
            )
            # track the largest magnitude, not just the largest positive value, so strongly
            # negative grouped values are also covered by the color scale
            max_val = max(max_val, np.max(np.abs(values)))
        return processed_values, max_val

    # unpack input tokens and output tokens
    model_input = shap_values.data
    model_output = shap_values.output_names
    processed_values, max_val = process_text_to_text_shap_values(shap_values)

    # shared symmetric color scale: the largest attribution magnitude in the explanation
    cmax = max(abs(shap_values.values.min()), abs(shap_values.values.max()), max_val)
    if cmax == 0:
        # every attribution is zero: avoid 0/0 (NaN) and map everything to the scale's center
        cmax = 1

    # dictionaries of precomputed background colors and label values, addressable by html element ids
    colors_dict = {}
    shap_values_dict = {}
    token_id_to_node_id_mapping = {}

    # input token -> output token color and label value mapping
    for row_index in range(len(model_input)):
        input_key = f"{uuid}_input_node_{row_index}_content"
        color_values = {}
        shap_values_list = {}
        for col_index in range(len(model_output)):
            value = shap_values.values[row_index][col_index]
            color_values[f"{uuid}_output_flat_token_{col_index}"] = "rgba" + str(get_color(value, cmax))
            # plain Python float: numpy float32 scalars are not JSON serializable
            shap_values_list[f"{uuid}_output_flat_value_label_{col_index}"] = round(float(value), 3)
        colors_dict[input_key] = color_values
        shap_values_dict[input_key] = shap_values_list

    # output token -> input token color and label value mapping
    for col_index in range(len(model_output)):
        processed = processed_values[col_index]
        output_key = f"{uuid}_output_flat_token_{col_index}"
        color_values = {}
        shap_values_list = {}
        for row_index in range(processed["collapsed_node_ids"].shape[0]):
            node_id = processed["collapsed_node_ids"][row_index]
            value = processed["values"][row_index]
            color_values[f"{uuid}_input_node_{node_id}_content"] = "rgba" + str(get_color(value, cmax))
            # grouped values are shown as "<per-token value>/<group size>"
            shap_label_value_str = str(round(value, 3))
            if processed["group_sizes"][row_index] > 1:
                shap_label_value_str += "/" + str(processed["group_sizes"][row_index])
            shap_values_list[f"{uuid}_input_node_{node_id}_label"] = shap_label_value_str
        colors_dict[output_key] = color_values
        shap_values_dict[output_key] = shap_values_list
        # which (possibly collapsed) node carries the label of each leaf input token
        token_id_to_node_id_mapping[output_key] = {
            f"{uuid}_input_node_{index}_content": f"{uuid}_input_node_{int(mapped_id)}_content"
            for index, mapped_id in enumerate(processed["token_id_to_node_id_mapping"].tolist())
        }

    # convert python dictionaries into json to be inserted into the runtime javascript environment
    colors_json = json.dumps(colors_dict)
    shap_values_json = json.dumps(shap_values_dict)
    token_id_to_node_id_mapping_json = json.dumps(token_id_to_node_id_mapping)

    javascript_values = (
        "<script> "
        f"colors_{uuid} = {colors_json}\n"
        f" shap_values_{uuid} = {shap_values_json}\n"
        f" token_id_to_node_id_mapping_{uuid} = {token_id_to_node_id_mapping_json}\n"
        "</script> \n "
    )

    def generate_tree(shap_values):
        """Nest the input tokens following the clustering; returns the remaining roots by node id."""
        num_tokens = len(shap_values.data)
        token_list = {}
        for index in range(num_tokens):
            token_list[str(index)] = {
                TREE_NODE_KEY_TOKENS: shap_values.data[index],
                TREE_NODE_KEY_CHILDREN: {},
            }

        clustering = shap_values.clustering
        if clustering is None:
            clustering = []
        # merged node ids continue after the leaf ids
        for counter, pair in enumerate(clustering, start=num_tokens):
            first_node = str(int(pair[0]))
            second_node = str(int(pair[1]))
            token_list[str(counter)] = {
                TREE_NODE_KEY_CHILDREN: {
                    first_node: token_list.pop(first_node),
                    second_node: token_list.pop(second_node),
                }
            }
        return token_list

    tree = generate_tree(shap_values)

    # generates the input token html elements
    # each element contains the label value (initially hidden) and the token text
    def populate_input_tree(input_index, token_list_subtree, input_text_html):
        content = token_list_subtree[input_index]
        input_text_html += (
            f'<div id="{uuid}_input_node_{input_index}_container" style="display:inline;text-align:center">'
        )
        input_text_html += (
            f'<div id="{uuid}_input_node_{input_index}_label" style="display:none; padding-top: 0px; font-size:12px;">'
        )
        input_text_html += "</div>"

        children = content[TREE_NODE_KEY_CHILDREN]
        if children:
            # internal node: a plain wrapper around its children
            input_text_html += f'<div id="{uuid}_input_node_{input_index}_content" style="display:inline;">'
            for child_index in children:
                input_text_html = populate_input_tree(child_index, children, input_text_html)
            input_text_html += "</div>"
        else:
            # leaf: the interactive token element
            input_text_html += (
                f'<div id="{uuid}_input_node_{input_index}_content" '
                "style='display: inline; background:transparent; border-radius: 3px; padding: 0px;cursor: default;cursor: pointer;' "
                f'onmouseover="onMouseHoverFlat_{uuid}(this.id)" '
                f'onmouseout="onMouseOutFlat_{uuid}(this.id)" '
                f'onclick="onMouseClickFlat_{uuid}(this.id)" '
                ">"
            )
            input_text_html += clean_token(content[TREE_NODE_KEY_TOKENS])
            input_text_html += "</div>"

        input_text_html += "</div>"
        return input_text_html

    input_text_html = ""
    # a complete clustering leaves exactly one root; rendering every root keeps all
    # tokens visible when the clustering is partial or missing
    for root_index in tree:
        input_text_html = populate_input_tree(root_index, tree, input_text_html)

    # generates the output token html elements
    output_text_html = ""
    for i, output_token in enumerate(model_output):
        output_text_html += (
            "<div style='display:inline; text-align:center;'>"
            f"<div id='{uuid}_output_flat_value_label_{i}' "
            "style='display:none;color: #999; padding-top: 0px; font-size:12px;'>"
            "</div>"
            f"<div id='{uuid}_output_flat_token_{i}' "
            "style='display: inline; background:transparent; border-radius: 3px; padding: 0px;cursor: default;cursor: pointer;' "
            f'onmouseover="onMouseHoverFlat_{uuid}(this.id)" '
            f'onmouseout="onMouseOutFlat_{uuid}(this.id)" '
            f'onclick="onMouseClickFlat_{uuid}(this.id)" '
            ">" + clean_token(output_token) + " </div>"
            "</div>"
        )

    heatmap_html = f"""
        <div id="{uuid}_heatmap" class="{uuid}_viz_content">
          <div id="{uuid}_heatmap_header" style="padding:15px;margin:5px;font-family:sans-serif;font-weight:bold;">
            <div style="display:inline">
              <span style="font-size: 20px;"> Input/Output - Heatmap </span>
            </div>
            <div style="display:inline;float:right">
              Layout :
              <select name="alignment" id="{uuid}_alignment" onchange="selectAlignment_{uuid}(this)">
                <option value="left-right" selected="selected">Left/Right</option>
                <option value="top-bottom">Top/Bottom</option>
              </select>
            </div>
          </div>
          <div id="{uuid}_heatmap_content" style="display:flex;">
            <div id="{uuid}_input_container" style="padding:15px;border-style:solid;margin:5px;flex:1;">
              <div id="{uuid}_input_header" style="margin:5px;font-weight:bold;font-family:sans-serif;margin-bottom:10px">
                Input Text
              </div>
              <div id="{uuid}_input_content" style="margin:5px;font-family:sans-serif;">
                  {input_text_html}
              </div>
            </div>
            <div id="{uuid}_output_container" style="padding:15px;border-style:solid;margin:5px;flex:1;">
              <div id="{uuid}_output_header" style="margin:5px;font-weight:bold;font-family:sans-serif;margin-bottom:10px">
                Output Text
              </div>
              <div id="{uuid}_output_content" style="margin:5px;font-family:sans-serif;">
                  {output_text_html}
              </div>
            </div>
          </div>
        </div>
    """

    heatmap_javascript = f"""
        <script>
            function selectAlignment_{uuid}(selectObject) {{
                var value = selectObject.value;
                if ( value === "left-right" ){{
                  document.getElementById('{uuid}_heatmap_content').style.display = "flex";
                }}
                else if ( value === "top-bottom" ) {{
                  document.getElementById('{uuid}_heatmap_content').style.display = "inline";
                }}
            }}

            var {uuid}_heatmap_flat_state = null;

            function onMouseHoverFlat_{uuid}(id) {{
                if ({uuid}_heatmap_flat_state === null) {{
                    setBackgroundColors_{uuid}(id);
                    document.getElementById(id).style.backgroundColor = "grey";
                }}

                if (getIdSide_{uuid}(id) === 'input' && getIdSide_{uuid}({uuid}_heatmap_flat_state) === 'output'){{

                    label_content_id = token_id_to_node_id_mapping_{uuid}[{uuid}_heatmap_flat_state][id];

                    if (document.getElementById(label_content_id).previousElementSibling.style.display == 'none'){{
                        document.getElementById(label_content_id).style.textShadow = "0px 0px 1px #000000";
                    }}

                }}

            }}

            function onMouseOutFlat_{uuid}(id) {{
                if ({uuid}_heatmap_flat_state === null) {{
                    cleanValuesAndColors_{uuid}(id);
                    document.getElementById(id).style.backgroundColor = "transparent";
                }}

                if (getIdSide_{uuid}(id) === 'input' && getIdSide_{uuid}({uuid}_heatmap_flat_state) === 'output'){{

                    label_content_id = token_id_to_node_id_mapping_{uuid}[{uuid}_heatmap_flat_state][id];

                    if (document.getElementById(label_content_id).previousElementSibling.style.display == 'none'){{
                        document.getElementById(label_content_id).style.textShadow = "inherit";
                    }}

                }}

            }}

            function onMouseClickFlat_{uuid}(id) {{
                if ({uuid}_heatmap_flat_state === id) {{

                    // If the clicked token was already selected

                    document.getElementById(id).style.backgroundColor = "transparent";
                    cleanValuesAndColors_{uuid}(id);
                    {uuid}_heatmap_flat_state = null;
                }}
                else {{
                    if ({uuid}_heatmap_flat_state === null) {{

                        // No token previously selected, new token clicked on

                        cleanValuesAndColors_{uuid}(id)
                        {uuid}_heatmap_flat_state = id;
                        document.getElementById(id).style.backgroundColor = "grey";
                        setLabelValues_{uuid}(id);
                        setBackgroundColors_{uuid}(id);
                    }}
                    else {{
                        if (getIdSide_{uuid}({uuid}_heatmap_flat_state) === getIdSide_{uuid}(id)) {{

                            // User clicked a token on the same side as the currently selected token

                            cleanValuesAndColors_{uuid}({uuid}_heatmap_flat_state)
                            document.getElementById({uuid}_heatmap_flat_state).style.backgroundColor = "transparent";
                            {uuid}_heatmap_flat_state = id;
                            document.getElementById(id).style.backgroundColor = "grey";
                            setLabelValues_{uuid}(id);
                            setBackgroundColors_{uuid}(id);
                        }}
                        else{{

                            if (getIdSide_{uuid}(id) === 'input') {{
                                label_content_id = token_id_to_node_id_mapping_{uuid}[{uuid}_heatmap_flat_state][id];

                                if (document.getElementById(label_content_id).previousElementSibling.style.display == 'none') {{
                                    document.getElementById(label_content_id).previousElementSibling.style.display = 'block';
                                    document.getElementById(label_content_id).parentNode.style.display = 'inline-block';
                                    document.getElementById(label_content_id).style.textShadow = "0px 0px 1px #000000";
                                  }}
                                else {{
                                    document.getElementById(label_content_id).previousElementSibling.style.display = 'none';
                                    document.getElementById(label_content_id).parentNode.style.display = 'inline';
                                    document.getElementById(label_content_id).style.textShadow  = "inherit";
                                  }}

                            }}
                            else {{
                                if (document.getElementById(id).previousElementSibling.style.display == 'none') {{
                                    document.getElementById(id).previousElementSibling.style.display = 'block';
                                    document.getElementById(id).parentNode.style.display = 'inline-block';
                                  }}
                                else {{
                                    document.getElementById(id).previousElementSibling.style.display = 'none';
                                    document.getElementById(id).parentNode.style.display = 'inline';
                                  }}
                            }}

                        }}
                    }}

                }}
            }}

            function setLabelValues_{uuid}(id) {{
                for(const token in shap_values_{uuid}[id]){{
                    document.getElementById(token).innerHTML = shap_values_{uuid}[id][token];
                    document.getElementById(token).nextElementSibling.title = 'SHAP Value : ' + shap_values_{uuid}[id][token];
                }}
            }}

            function setBackgroundColors_{uuid}(id) {{
                for(const token in colors_{uuid}[id]){{
                    document.getElementById(token).style.backgroundColor  = colors_{uuid}[id][token];
                }}
            }}

            function cleanValuesAndColors_{uuid}(id) {{
                for(const token in shap_values_{uuid}[id]){{
                    document.getElementById(token).innerHTML = "";
                    document.getElementById(token).nextElementSibling.title = "";
                }}
                 for(const token in colors_{uuid}[id]){{
                    document.getElementById(token).style.backgroundColor  = "transparent";
                    document.getElementById(token).previousElementSibling.style.display = 'none';
                    document.getElementById(token).parentNode.style.display = 'inline';
                    document.getElementById(token).style.textShadow  = "inherit";
                }}
            }}

            function getIdSide_{uuid}(id) {{
                if (id === null) {{
                    return 'null'
                }}
                return id.split("_")[1];
            }}
        </script>
    """

    return heatmap_html + heatmap_javascript + javascript_values
def unpack_shap_explanation_contents(shap_values):
    """Extract ``(values, clustering)`` from an explanation-like object.

    ``hierarchical_values`` are preferred when the object has that attribute and it
    is not ``None``; otherwise ``values`` is used. The chosen values are returned as a
    numpy array. ``clustering`` is ``None`` when the object has no such attribute.
    """
    hierarchical = getattr(shap_values, "hierarchical_values", None)
    chosen = shap_values.values if hierarchical is None else hierarchical
    return np.array(chosen), getattr(shap_values, "clustering", None)
def _ipython_display_html(data):
    """Render the HTML string ``data`` through IPython's display machinery.

    Returns whatever IPython's ``display`` returns.

    Raises
    ------
    ImportError
        If IPython could not be imported when this module was loaded.
    """
    if have_ipython:
        return ipython_display(HTML(data))
    msg = "IPython is required for this function but is not installed. Fix this with `pip install ipython`."
    raise ImportError(msg)