File size: 2,999 Bytes
9d1fa85 265622b 9d1fa85 86d2882 265622b 86d2882 265622b 9d1fa85 86d2882 9d1fa85 86d2882 9d1fa85 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 |
import io
import pathlib

import gradio
from captum.attr import visualization
class Markdown(gradio.Markdown):
    """``gradio.Markdown`` that also accepts a file path or an open text stream.

    If ``value`` is a ``pathlib.Path``, the file is read as UTF-8 and its
    contents become the Markdown source. If it is an open text stream
    (any ``io.TextIOBase``, e.g. a file from ``open(..., "r")`` or a
    ``StringIO``), its remaining contents are read. Any other value is passed
    to ``gradio.Markdown`` unchanged.
    """

    def __init__(self, value, *args, **kwargs):
        """
        Args:
            value: Markdown text, a ``pathlib.Path`` to a Markdown file, or an
                open text-mode stream.
            *args, **kwargs: Forwarded unchanged to ``gradio.Markdown``.
        """
        if isinstance(value, pathlib.Path):
            # Explicit encoding: the default is locale-dependent (e.g. cp1252
            # on Windows), which would garble non-ASCII Markdown files.
            value = value.read_text(encoding="utf-8")
        elif isinstance(value, io.TextIOBase):
            # TextIOBase covers real text files as well as in-memory StringIO.
            # The stream is only read, not closed; that stays the caller's job.
            value = value.read()
        super().__init__(value, *args, **kwargs)
# Adapted from https://discuss.pytorch.org/t/using-scikit-learns-scalers-for-torchvision/53455
class PyTMinMaxScalerVectorized:
    """Min-max scale a tensor to [0, 1] independently along one dimension.

    Every 1-D slice taken along ``dimension`` is shifted by its minimum and
    divided by its range, so its smallest value maps to 0 and its largest to
    1. A slice whose values are all equal (zero range) maps to all zeros
    rather than NaN/inf.

    The input tensor is modified in place and also returned.
    """

    def __init__(self, dimension=-1):
        """
        Args:
            dimension: Dimension along which min and max are taken
                (default: the last dimension).
        """
        self.d = dimension

    def __call__(self, tensor):
        """Rescale ``tensor`` in place and return it.

        Args:
            tensor: Tensor to scale. It should be floating point: the in-place
                division cannot store fractional results in an integer tensor.

        Returns:
            The same tensor object, rescaled along ``self.d``.
        """
        d = self.d
        lo = tensor.min(dim=d, keepdim=True).values
        span = tensor.max(dim=d, keepdim=True).values - lo
        # Constant slices have zero range; dividing by 1 instead maps them to 0
        # rather than filling them with NaN (0 * inf) or inf.
        span = span.masked_fill(span == 0, 1)
        tensor.sub_(lo).div_(span)
        return tensor
# Adapted from captum.attr.visualization.visualize_text, which displays a
# Jupyter widget; this version returns the raw HTML string instead.
def visualize_text(datarecords, legend=False):
    """Render attribution records as an HTML table and return the markup.

    Each record becomes one table row with three cells: predicted class,
    attribution class, and the input tokens highlighted by their attribution
    scores. Captum's "True Label" and "Attribution Score" columns are
    intentionally omitted.

    Args:
        datarecords: Iterable of records exposing ``pred_class``,
            ``attr_class``, ``raw_input_ids`` and ``word_attributions``
            (presumably captum ``VisualizationDataRecord`` objects; TODO
            confirm against callers).
        legend: If true, a Negative/Neutral/Positive color legend is emitted
            immediately before the table.

    Returns:
        The HTML markup as a single string.
    """
    dom = []
    if legend:
        # Emitted before <table>: a <div> is not valid directly inside a
        # table, and browsers would hoist it above the table anyway.
        dom.append(
            '<div style="border-top: 1px solid; margin-top: 5px; '
            'padding-top: 5px; display: inline-block">'
        )
        dom.append("<b>Legend: </b>")
        for value, label in zip([-1, 0, 1], ["Negative", "Neutral", "Positive"]):
            color = visualization._get_color(value)
            dom.append(
                '<span style="display: inline-block; width: 10px; height: 10px; '
                'border: 1px solid; background-color: '
                f'{color}"></span> {label} '
            )
        dom.append("</div>")

    dom.append('<table style="width: 100%">')
    dom.append(
        "<tr>"
        "<th style='text-align: left'>Predicted Label</th>"
        "<th style='text-align: left'>Attribution Label</th>"
        "<th style='text-align: left'>Word Importance</th>"
        "</tr>"
    )
    for datarecord in datarecords:
        dom.append(
            "".join(
                [
                    "<tr>",
                    visualization.format_classname(datarecord.pred_class),
                    visualization.format_classname(datarecord.attr_class),
                    visualization.format_word_importances(
                        datarecord.raw_input_ids, datarecord.word_attributions
                    ),
                    "</tr>",
                ]
            )
        )
    dom.append("</table>")
    return "".join(dom)
|