File size: 2,999 Bytes
9d1fa85
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
265622b
9d1fa85
 
86d2882
265622b
 
86d2882
265622b
9d1fa85
 
 
 
 
 
86d2882
 
 
 
 
 
 
9d1fa85
86d2882
 
 
9d1fa85
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
import io
import pathlib

import gradio
from captum.attr import visualization

class Markdown(gradio.Markdown):
    """``gradio.Markdown`` that can also take its content from a file.

    ``value`` may be:

    * a ``pathlib.Path`` -- the file is read with ``Path.read_text()``;
    * an open text stream (any ``io.TextIOBase``, e.g. the result of
      ``open(..., "r")`` or an ``io.StringIO``) -- its remaining content is
      read with ``.read()``; the stream is not closed here;
    * anything else -- forwarded to ``gradio.Markdown`` unchanged.

    All further positional and keyword arguments are passed through to
    ``gradio.Markdown.__init__``.
    """

    def __init__(self, value, *args, **kwargs):
        if isinstance(value, pathlib.Path):
            value = value.read_text()
        elif isinstance(value, io.TextIOBase):
            # TextIOBase is the common base of TextIOWrapper and StringIO,
            # so every kind of text stream is accepted, not just real files.
            value = value.read()
        super().__init__(value, *args, **kwargs)

# from https://discuss.pytorch.org/t/using-scikit-learns-scalers-for-torchvision/53455
class PyTMinMaxScalerVectorized(object):
    """
    Min-max scale a tensor to the range [0, 1] along one dimension.

    Minimum and maximum are taken along ``dimension`` (default: the last
    one) with ``keepdim=True``, so every slice along that dimension is
    rescaled independently: its minimum maps to 0 and its maximum to 1.
    A slice whose values are all equal has no range to scale and maps to 0.

    The tensor is modified in place and the same object is returned.
    """

    def __init__(self, dimension=-1):
        self.d = dimension

    def __call__(self, tensor):
        """Rescale ``tensor`` in place along ``self.d`` and return it."""
        d = self.d
        lowest = tensor.min(dim=d, keepdim=True)[0]
        highest = tensor.max(dim=d, keepdim=True)[0]
        span = highest - lowest
        # A constant slice has span 0; dividing by it would yield NaN.
        # Using 1 instead leaves (x - min) == 0 for that slice.
        span = span.masked_fill(span == 0, 1)
        tensor.sub_(lowest).div_(span)
        return tensor

# copied out of captum because we need raw html instead of a jupyter widget
def visualize_text(datarecords, legend=False):
    """Render attribution records as an HTML string.

    Adapted from captum's ``visualize_text``, reduced to three columns:
    predicted label, attribution label and word importance.

    Args:
        datarecords: iterable of records exposing ``pred_class``,
            ``attr_class``, ``raw_input_ids`` and ``word_attributions``.
            One table row is produced per record.
        legend: when true, a colour legend (Negative / Neutral / Positive)
            is emitted in front of the table.

    Returns:
        A string holding the optional legend ``<div>`` followed by a
        ``<table>`` with one header row and one row per record.
    """
    rows = [
        "<tr>"
        "<th style='text-align: left'>Predicted Label</th>"
        "<th style='text-align: left'>Attribution Label</th>"
        "<th style='text-align: left'>Word Importance</th>"
        "</tr>"
    ]
    for datarecord in datarecords:
        rows.append(
            "".join(
                [
                    "<tr>",
                    visualization.format_classname(datarecord.pred_class),
                    visualization.format_classname(datarecord.attr_class),
                    visualization.format_word_importances(
                        datarecord.raw_input_ids, datarecord.word_attributions
                    ),
                    "</tr>",
                ]
            )
        )

    dom = []

    if legend:
        # A <div> is not allowed as a direct child of <table>; browsers move
        # such content in front of the table anyway, so emit it there.
        dom.append(
            '<div style="border-top: 1px solid; margin-top: 5px; \
            padding-top: 5px; display: inline-block">'
        )
        dom.append("<b>Legend: </b>")

        for value, label in zip([-1, 0, 1], ["Negative", "Neutral", "Positive"]):
            dom.append(
                '<span style="display: inline-block; width: 10px; height: 10px; \
                border: 1px solid; background-color: \
                {value}"></span> {label}  '.format(
                    value=visualization._get_color(value), label=label
                )
            )
        dom.append("</div>")

    dom.append('<table style="width: 100%">')
    dom.extend(rows)
    dom.append("</table>")

    return "".join(dom)