import io
import pathlib

import gradio
from captum.attr import visualization


class Markdown(gradio.Markdown):
    """``gradio.Markdown`` that also accepts a file as its ``value``.

    ``value`` may be:

    * a ``pathlib.Path`` -- the file's contents are read as UTF-8 text;
    * an open text stream (any ``io.TextIOBase``, e.g. the result of
      ``open(path)`` or an ``io.StringIO``) -- it is read from its current
      position to the end;
    * anything else -- passed to ``gradio.Markdown`` unchanged.
    """

    def __init__(self, value, *args, **kwargs):
        if isinstance(value, pathlib.Path):
            # Explicit encoding: the default depends on the platform locale.
            value = value.read_text(encoding="utf-8")
        elif isinstance(value, io.TextIOBase):
            value = value.read()
        super().__init__(value, *args, **kwargs)


# Adapted from
# https://discuss.pytorch.org/t/using-scikit-learns-scalers-for-torchvision/53455
class PyTMinMaxScalerVectorized:
    """Min-max scale a tensor to ``[0, 1]`` independently along one dimension.

    Every slice taken along ``dimension`` is shifted and divided so that its
    minimum becomes 0 and its maximum becomes 1. Slices whose values are all
    equal (zero range) become all zeros instead of NaN/inf.

    Args:
        dimension: dimension reduced over when computing min/max
            (default ``-1``, the last dimension).
    """

    def __init__(self, dimension=-1):
        self.d = dimension

    def __call__(self, tensor):
        """Scale ``tensor`` in place and return it.

        Args:
            tensor: a ``torch.Tensor``, modified in place. It presumably must
                be floating point, since the in-place division cannot store
                fractional results into an integer tensor.

        Returns:
            The same tensor object, now holding ``(x - min) / (max - min)``
            computed along ``self.d``.
        """
        lo = tensor.min(dim=self.d, keepdim=True)[0]
        hi = tensor.max(dim=self.d, keepdim=True)[0]
        span = hi - lo
        # A zero range would divide by zero. The numerator (x - lo) is already
        # 0 for such slices, so dividing by 1 maps them to 0.
        span.masked_fill_(span == 0, 1)
        return tensor.sub_(lo).div_(span)


# Adapted from captum.attr.visualization.visualize_text, which displays a
# Jupyter widget; gradio needs the raw HTML string instead.
def visualize_text(datarecords, legend=False):
    """Render attribution records as an HTML table and return the markup.

    The table has one header row and one row per record, with the columns
    "Predicted Label", "Attribution Label" and "Word Importance". The last
    column is rendered by captum's ``format_word_importances``.

    Args:
        datarecords: iterable of records (e.g. captum
            ``VisualizationDataRecord``) exposing ``pred_class``,
            ``attr_class``, ``raw_input_ids`` and ``word_attributions``.
        legend: if true, prepend a Negative/Neutral/Positive colour legend.

    Returns:
        The HTML markup as a ``str``.
    """
    # NOTE(review): the captum format_* helpers may not HTML-escape their
    # inputs; confirm before rendering untrusted text (e.g. user tokens).
    dom = []
    if legend:
        # Emitted before the table: a <div> is not valid directly inside
        # <table> (browsers would hoist it above the table anyway).
        dom.append(
            '<div style="border-top: 1px solid; margin-top: 5px; '
            'padding-top: 5px; display: inline-block">'
        )
        dom.append("<b>Legend: </b>")
        for value, label in zip([-1, 0, 1], ["Negative", "Neutral", "Positive"]):
            dom.append(
                '<span style="display: inline-block; width: 10px; height: 10px; '
                'border: 1px solid; background-color: {value}"></span> {label}  '.format(
                    # _get_color is a private captum helper and may change
                    # between captum releases.
                    value=visualization._get_color(value),
                    label=label,
                )
            )
        dom.append("</div>")

    dom.append("<table>")
    dom.append(
        "<tr>"
        "<th>Predicted Label</th>"
        "<th>Attribution Label</th>"
        "<th>Word Importance</th>"
        "</tr>"
    )
    for record in datarecords:
        dom.append("<tr>")
        dom.append(visualization.format_classname(record.pred_class))
        dom.append(visualization.format_classname(record.attr_class))
        dom.append(
            visualization.format_word_importances(
                record.raw_input_ids, record.word_attributions
            )
        )
        dom.append("</tr>")
    dom.append("</table>")
    return "".join(dom)